├── .idea ├── .gitignore ├── vcs.xml ├── modules.xml ├── other.xml ├── misc.xml ├── inspectionProfiles │ └── Project_Default.xml └── 3DMM.iml ├── avgModel.mat ├── requirements.txt ├── resources ├── icon.png └── gitScreen1.png ├── data └── avgModel_bh_1779_NE.mat ├── registration_parameters.conf ├── pointRegistration ├── registration_parameters.conf ├── registration_param.py ├── batchRegistration.py ├── registration.py ├── model.py ├── file3D.py └── displacementMap.py ├── main.py ├── graphicInterface ├── upper_toolbar_controls.py ├── file_dialogs.py ├── rotatable_figure.py ├── plot_button_collection.py ├── window.py ├── console.py ├── show_displacement.py ├── plot_interactive_figure.py ├── conf_param_window.py ├── upper_toolbar.py ├── plot_figure.py └── main_widget.py ├── morpmodel ├── RP.py ├── Matrix_operations.py ├── _3DMM.py └── util_for_graphic.py └── README.md /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Default ignored files 3 | /workspace.xml -------------------------------------------------------------------------------- /avgModel.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/avgModel.mat -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.15.0 2 | scipy>=1.3 3 | matplotlib 4 | PyQt5 5 | h5py 6 | pycpd 7 | 8 | -------------------------------------------------------------------------------- /resources/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/resources/icon.png -------------------------------------------------------------------------------- /resources/gitScreen1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/resources/gitScreen1.png -------------------------------------------------------------------------------- /data/avgModel_bh_1779_NE.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/data/avgModel_bh_1779_NE.mat -------------------------------------------------------------------------------- /registration_parameters.conf: -------------------------------------------------------------------------------- 1 | [PARAMETERS] 2 | tolerance = 0.001 3 | max_iterations = 100 4 | sigma2 = None 5 | w = 0.0 6 | 7 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /pointRegistration/registration_parameters.conf: -------------------------------------------------------------------------------- 1 | # Configuration file for CPD rigid registration, view readme for insights 2 | # This is a comment. 
3 | [PARAMETERS] 4 | tolerance=0.001 5 | max_iterations=100 6 | sigma2=None 7 | w=0 -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QApplication 2 | from graphicInterface.window import * 3 | import sys 4 | 5 | if __name__ == "__main__": 6 | 7 | app = QApplication(sys.argv) 8 | 9 | Logger() 10 | RegistrationParameters() 11 | 12 | w = Window() 13 | 14 | sys.exit(app.exec_()) 15 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 19 | -------------------------------------------------------------------------------- /.idea/3DMM.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | 16 | 18 | -------------------------------------------------------------------------------- /graphicInterface/upper_toolbar_controls.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QPushButton, QComboBox 2 | 3 | 4 | class ControlPushButton(QPushButton): 5 | 6 | def __init__(self, text="Button", callback=None, 
enabled=True): 7 | super().__init__(text) 8 | self.clicked.connect(callback) 9 | self.setEnabled(enabled) 10 | 11 | 12 | class PercentageComboBox(QComboBox): 13 | 14 | def __init__(self, ): 15 | super().__init__() 16 | for x in range(100, 20, -10): 17 | self.addItem(str(x) + "%", x) 18 | 19 | 20 | class RegistrationMethodsCombobox(QComboBox): 21 | 22 | def __init__(self): 23 | super(RegistrationMethodsCombobox, self).__init__() 24 | # self.addItem("ICP", 0) 25 | self.addItem("CPD - Rigid", 1) 26 | self.addItem("CPD - Affine", 2) 27 | self.addItem("CPD - Deformable", 3) 28 | -------------------------------------------------------------------------------- /graphicInterface/file_dialogs.py: -------------------------------------------------------------------------------- 1 | from PyQt5.Qt import QFileDialog 2 | 3 | 4 | def load_file_dialog(parent, filters, multiple_files=False): 5 | dlg = QFileDialog() 6 | options = dlg.Options() 7 | options |= QFileDialog.DontUseNativeDialog 8 | if multiple_files: 9 | file_name, _ = dlg.getOpenFileNames(parent, "Load a model", "", filters, "File WRML (*.wrl)", options=options) 10 | else: 11 | file_name, _ = dlg.getOpenFileName(parent, "Load a model", "", filters, "File WRML (*.wrl)", options=options) 12 | if file_name == "": 13 | return None 14 | return file_name 15 | 16 | 17 | def save_file_dialog(parent, filters): 18 | dlg = QFileDialog() 19 | options = dlg.Options() 20 | options |= dlg.DontUseNativeDialog 21 | filename, _ = dlg.getSaveFileName(parent, None, "Save model", filter=filters, options=options) 22 | if filename == "": 23 | return None 24 | return filename 25 | -------------------------------------------------------------------------------- /graphicInterface/rotatable_figure.py: -------------------------------------------------------------------------------- 1 | from graphicInterface.plot_button_collection import PlotButtonCollection 2 | from graphicInterface.plot_figure import PlotFigure 3 | from pointRegistration.model 
import Model 4 | 5 | 6 | class RotatableFigure(PlotFigure): 7 | 8 | def __init__(self, parent, model=None, landmarks=True, title=None, secondary_model=None): 9 | self.rotate_buttons = None 10 | super().__init__(parent, model, landmarks, title) 11 | self.registered = False 12 | self.secondary_model = secondary_model 13 | 14 | def rotate(self, axis, theta, registered=False): 15 | if self.model is not None: 16 | self.clear() 17 | self.model.rotate(axis, theta) 18 | if self.secondary_model is not None: 19 | self.secondary_model.rotate(axis, theta) 20 | 21 | self.draw_data() 22 | 23 | def load_model(self, model): 24 | super().load_model(model) 25 | if self.rotate_buttons is None: 26 | self.rotate_buttons = PlotButtonCollection(self.rotate, self) 27 | 28 | def set_secondary_model(self, model): 29 | self.secondary_model = Model.from_model(model) 30 | 31 | def draw_data(self, clear=False): 32 | if self.secondary_model is not None: 33 | self.load_data(self.secondary_model.points, "r") 34 | if self.secondary_model.landmarks is not None: 35 | self.load_data(self.secondary_model.landmarks, "black") 36 | super().draw_data(clear=clear) 37 | -------------------------------------------------------------------------------- /morpmodel/RP.py: -------------------------------------------------------------------------------- 1 | class RP: 2 | def __init__(self): 3 | self.width = 640 4 | self.height = 486 5 | self.gamma = 0 6 | self.theta = 0 7 | self.phi = 0 8 | self.alpha = 0 9 | self.t2d = [0,0] 10 | self.camera_pos = [0,0,3400] 11 | self.scale_scene = 0.0 12 | #rselfp.object_size = 0.615 * 512 13 | self.shift_object = [0,0,-46125] 14 | #rp.shift_object = [0;0;0]; 15 | self.shift_world = [0,0,0] 16 | self.scale = 0.001 17 | self.ac_g = [1,1,1] 18 | self.ac_c = 1 19 | self.ac_o = [0,0,0] 20 | #self.ambient_col = 0.6*[1,1,1] 21 | #rp.rotm = eye(3) ??? 
22 | self.use_rotm = 0 23 | self.do_remap = 0 24 | self.dir_light = [] 25 | self.do_specular = 0.1 26 | self.phong_exp = 8 27 | self.specular = 0.1*255 28 | self.do_cast_shadows = 1 29 | self.sbufsize = 200 30 | # projection method 31 | self.proj = 'perspective' 32 | # if scale_scene == 0, then f is used: 33 | self.f = 6000 34 | # is 1 for grey level images and 3 for color images 35 | self.n_chan = 3 36 | self.backface_culling = 2; # 2 = default for current projection 37 | # can be 'phong', 'global_illum' or 'no_illum' 38 | #self.illum_method = 'phong' 39 | #self.global_illum.brdf = 'lambert' 40 | #self.global_illum.envmap = struct([]) 41 | #self.global_illum.light_probe = [] 42 | self.ablend = [] # no blending performed 43 | -------------------------------------------------------------------------------- /morpmodel/Matrix_operations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import norm 3 | 4 | class Matrix_op: 5 | X_aligned_model = [] 6 | X_after_training = [] 7 | X_res = [] 8 | def __init__(self, Components, aligned_models_data): 9 | if aligned_models_data is None: 10 | self.X_after_training = Components 11 | self.reshape(Components) 12 | X_aligned_model = None 13 | if Components is None: 14 | self.X_aligned_model = np.array(aligned_models_data) # 2-d Dimension and saved in memory with the fortran memory order 15 | self.X_after_training = [] 16 | self.X_res = [] 17 | 18 | def mean(self): 19 | self.X_aligned_model = self.X_aligned_model - self.X_aligned_model.mean(axis=1, keepdims=True) # axis=1 for get row's mean (axis=0 for columns) 20 | 21 | def normalization(self): 22 | self.X_aligned_model = norm(self.X_aligned_model, axis=1,ord=1) #L1-norm 23 | 24 | def transpose(self): 25 | self.X_aligned_model = np.transpose(self.X_aligned_model) 26 | 27 | def reshape(self,D): 28 | #self.X_res = np.empty((6704,3,300))# define numpy array 29 | self.X_res = 
np.empty((int(D.shape[0]/3),3,D.shape[1])) 30 | for c in range(D.shape[1]): 31 | comp = np.transpose(np.array(D[:,c])) 32 | _app = np.reshape(np.transpose(comp),(3,int(D.shape[0]/3)), order='F') 33 | comp = np.transpose(_app) 34 | self.X_res[:,:,c] = comp 35 | 36 | class Vector_op: 37 | V = [] 38 | def __init__(self,v_init): 39 | self.V = v_init 40 | 41 | def scale(self,mx,mn): 42 | min_w = np.amin(self.V) 43 | max_w = np.amax(self.V) 44 | self.V = (((self.V-min_w)*(mx-mn))/(max_w-min_w)) + mn 45 | -------------------------------------------------------------------------------- /graphicInterface/plot_button_collection.py: -------------------------------------------------------------------------------- 1 | from matplotlib.widgets import Button 2 | from matplotlib import pyplot 3 | 4 | 5 | class PlotButtonCollection(object): 6 | 7 | def __init__(self, callback, parent): 8 | self.callback = callback 9 | 10 | ax_yminus = pyplot.axes([0, 0.07, 0.09, 0.05]) 11 | ax_yplus = pyplot.axes( [0.1, 0.07, 0.09, 0.05]) 12 | ax_xminus = pyplot.axes([0, 0.13, 0.09, 0.05]) 13 | ax_xplus = pyplot.axes( [0.1, 0.13, 0.09, 0.05]) 14 | ax_zminus = pyplot.axes([0, 0.01, 0.09, 0.05]) 15 | ax_zplus = pyplot.axes( [0.1, 0.01, 0.09, 0.05]) 16 | 17 | self.b_yplus = Button(ax_yplus, '+15°') 18 | self.b_yplus.on_clicked(self.rotate_y_plus) 19 | 20 | self.b_yminus = Button(ax_yminus, 'Y -15°') 21 | self.b_yminus.on_clicked(self.rotate_y_minus) 22 | 23 | self.b_xplus = Button(ax_xplus, '+15°') 24 | self.b_xplus.on_clicked(self.rotate_x_plus) 25 | 26 | self.b_xminus = Button(ax_xminus, 'X -15°') 27 | self.b_xminus.on_clicked(self.rotate_x_minus) 28 | 29 | self.b_zminus = Button(ax_zminus, 'Z -15°') 30 | self.b_zminus.on_clicked(self.rotate_z_minus) 31 | 32 | self.b_zplus = Button(ax_zplus, '+15°') 33 | self.b_zplus.on_clicked(self.rotate_z_plus) 34 | 35 | def rotate_x_plus(self, event): 36 | self.callback('x', 15) 37 | 38 | def rotate_x_minus(self, event): 39 | self.callback('x', -15) 40 | 41 | def 
rotate_y_plus(self, event): 42 | self.callback('y', 15) 43 | 44 | def rotate_y_minus(self, event): 45 | self.callback('y', -15) 46 | 47 | def rotate_z_plus(self, event): 48 | self.callback('z', +15) 49 | 50 | def rotate_z_minus(self, event): 51 | self.callback('z', -15) 52 | -------------------------------------------------------------------------------- /graphicInterface/window.py: -------------------------------------------------------------------------------- 1 | from graphicInterface.main_widget import * 2 | from graphicInterface.console import Logger 3 | from pointRegistration.registration_param import RegistrationParameters 4 | from PyQt5.QtGui import QIcon 5 | from graphicInterface.conf_param_window import EditConfWindow 6 | 7 | 8 | class Window(QMainWindow): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | self.setWindowIcon(QIcon('resources/icon.png')) 13 | Logger() 14 | self.statusLabel = QLabel("") 15 | self.initUI() 16 | 17 | def initUI(self): 18 | self.mainWidget = MainWidget(self) 19 | self.setCentralWidget(self.mainWidget) 20 | self.resize(1000, 600) 21 | self.center() 22 | statusBar = QStatusBar() 23 | self.setStatusBar(statusBar) 24 | statusBar.addWidget(self.statusLabel) 25 | statusBar.addPermanentWidget(ConfigLabel()) 26 | self.setWindowTitle('Shape Registrator') 27 | self.setStatus("Ready.") 28 | self.show() 29 | 30 | def center(self): 31 | qr = self.frameGeometry() 32 | cp = QDesktopWidget().availableGeometry().center() 33 | qr.moveCenter(cp) 34 | self.move(qr.topLeft()) 35 | 36 | def setStatus(self, message): 37 | self.statusLabel.setText(message) 38 | 39 | def setStatusReady(self): 40 | self.setStatus("Ready") 41 | 42 | 43 | class ConfigLabel(QLabel): 44 | 45 | def __init__(self): 46 | super(ConfigLabel, self).__init__() 47 | self.update_parameter_label() 48 | 49 | def mouseDoubleClickEvent(self, *args, **kwargs): 50 | window = EditConfWindow(self) 51 | window.show() 52 | 53 | def update_parameter_label(self): 54 | ss = "(Double 
click to edit) |" + RegistrationParameters.to_string() 55 | self.setText(ss) 56 | -------------------------------------------------------------------------------- /pointRegistration/registration_param.py: -------------------------------------------------------------------------------- 1 | from configparser import ConfigParser 2 | from graphicInterface.console import Logger 3 | 4 | 5 | class RegistrationParameters: 6 | 7 | instance = None 8 | 9 | class __PrivateParams: # Singleton 10 | 11 | def __init__(self): 12 | c = ConfigParser() 13 | c.read('registration_parameters.conf') 14 | ps = c['PARAMETERS'] 15 | # Now read every key 16 | self.tolerance = float(ps['tolerance']) 17 | self.max_iterations = int(ps['max_iterations']) 18 | self.sigma2 = None if ps['sigma2'] == 'None' else float(ps['sigma2']) 19 | self.w = float(ps['w']) 20 | self.params = {'tolerance': self.tolerance, 21 | 'max_iterations': self.max_iterations, 22 | 'sigma2': self.sigma2, 23 | 'w': self.w} 24 | 25 | def get_param(self, key): 26 | try: 27 | return self.params[key] 28 | except KeyError: 29 | return None 30 | 31 | def write_on_file(self): 32 | with open('registration_parameters.conf', 'w+') as conf_file: 33 | conf = ConfigParser() 34 | conf['PARAMETERS'] = {} 35 | for key, value in self.params.items(): 36 | conf['PARAMETERS'][str(key)] = str(value) 37 | conf.write(conf_file) 38 | Logger.addRow("Configuration file updated") 39 | 40 | def get_params(self): 41 | return self.params 42 | 43 | def set_params(self, key, value): 44 | self.params[key] = value 45 | 46 | def __init__(self): 47 | if RegistrationParameters.instance is None: 48 | RegistrationParameters.instance = self.__PrivateParams() 49 | 50 | @staticmethod 51 | def get_params(): 52 | """ Returns a dictionary with all the registration parameters. 
""" 53 | return RegistrationParameters.instance.get_params() 54 | 55 | @staticmethod 56 | def get_param(key): 57 | return RegistrationParameters.instance.get_param(key) 58 | 59 | @staticmethod 60 | def set_param(key, value): 61 | RegistrationParameters.instance.set_params(key, value) 62 | 63 | @staticmethod 64 | def write_on_file(): 65 | RegistrationParameters.instance.write_on_file() 66 | 67 | @staticmethod 68 | def to_string(): 69 | param = RegistrationParameters.instance.get_params() 70 | ss = "" 71 | for k in param: 72 | ss += k + " : " + str(param[k]) + " | " 73 | 74 | return ss[0:-2] 75 | 76 | -------------------------------------------------------------------------------- /graphicInterface/console.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QPlainTextEdit 2 | from PyQt5.QtGui import QTextCursor 3 | import datetime 4 | import time 5 | import os 6 | 7 | 8 | class Logger: 9 | 10 | instance = None 11 | lock = False 12 | 13 | class _PrivateLog(QPlainTextEdit): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.setReadOnly(True) 18 | self.setLineWrapMode(QPlainTextEdit.NoWrap) 19 | self.cursor = QTextCursor(self.document()) 20 | self.setTextCursor(self.cursor) 21 | 22 | def save_onfile(self): 23 | now = datetime.datetime.now() 24 | filename = ("log-%d-%d_%d-%d-%d.txt" % (now.hour, now.minute, now.year, now.month, now.day)) 25 | file = open(filename, "w") 26 | file.write(self.toPlainText()) 27 | file.close() 28 | 29 | def __init__(self): 30 | if Logger.instance is None: 31 | Logger.instance = self._PrivateLog() 32 | 33 | @staticmethod 34 | def addRow(row): 35 | if Logger.instance is not None: 36 | Logger.instance.insertPlainText(row+"\n") 37 | time.sleep(0.1) 38 | Logger.instance.moveCursor(QTextCursor.End) 39 | time.sleep(0.1) 40 | print(row) 41 | 42 | @staticmethod 43 | def setParent(parent): 44 | if Logger.instance is not None: 45 | Logger.instance.setParent(parent) 46 | 47 | 
@staticmethod 48 | def save_log(): 49 | if Logger.instance is not None: 50 | Logger.instance.save_onfile() 51 | 52 | 53 | class suppress_stdout_stderr(object): 54 | """ 55 | A context manager for doing a "deep suppression" of stdout and stderr in 56 | Python, i.e. will suppress all print, even if the print originates in a 57 | compiled C/Fortran sub-function. 58 | This will not suppress raised exceptions, since exceptions are printed 59 | to stderr just before a script exits, and after the context manager has 60 | exited. 61 | """ 62 | 63 | def __init__(self): 64 | # Open a pair of null files 65 | self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] 66 | # Save the actual stdout (1) and stderr (2) file descriptors. 67 | self.save_fds = [os.dup(1), os.dup(2)] 68 | 69 | def __enter__(self): 70 | # Assign the null pointers to stdout and stderr. 71 | os.dup2(self.null_fds[0], 1) 72 | os.dup2(self.null_fds[1], 2) 73 | 74 | def __exit__(self, *_): 75 | # Re-assign the real stdout/stderr back to (1) and (2) 76 | os.dup2(self.save_fds[0], 1) 77 | os.dup2(self.save_fds[1], 2) 78 | # Close the null files 79 | for fd in self.null_fds + self.save_fds: 80 | os.close(fd) -------------------------------------------------------------------------------- /graphicInterface/show_displacement.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtGui import QIcon 2 | from PyQt5.QtWidgets import QMainWindow, QLabel, QGridLayout, QWidget, QPushButton 3 | from graphicInterface.rotatable_figure import RotatableFigure 4 | from graphicInterface.upper_toolbar_controls import ControlPushButton 5 | from graphicInterface.file_dialogs import * 6 | 7 | 8 | class DisplacementMapWindow(QMainWindow): 9 | 10 | def __init__(self, parent, model): 11 | super().__init__(parent=parent) 12 | self.setWindowIcon(QIcon('resources/icon.png')) 13 | self.statusLabel = QLabel("Ready") 14 | self.setWindowTitle("Displacement Map") 15 | self.model = model 
16 | self.rot_figure = None 17 | self.initUI() 18 | 19 | def initUI(self): 20 | central_widget = QWidget(self) 21 | grid_central = QGridLayout(self) 22 | central_widget.setLayout(grid_central) 23 | self.setCentralWidget(central_widget) 24 | self.setLayout(grid_central) 25 | self.rot_figure = RotatableFigure(parent=self, model=self.model, title="Displacement Map") 26 | self.toolbar = LowerToolbar(self) 27 | grid_central.addWidget(self.rot_figure, 0, 0, 19, 1) 28 | grid_central.addWidget(self.toolbar, 20, 0, 1, 1) 29 | if self.model is not None: 30 | self.toolbar.save_button.setEnabled(True) 31 | self.resize(600, 600) 32 | self.rot_figure.draw_data(clear=True) 33 | self.show() 34 | 35 | def save_displacement_map(self): 36 | filters = "Serialized Python Obj (*.pickle);;H5Py Compatible file (*.h5py)" 37 | filename = save_file_dialog(self, filters) 38 | if filename is not None: 39 | self.model.save_model(filename) 40 | 41 | def load_displacement_map(self): 42 | filters = "H5Py Compatible file (*.h5py)" 43 | filename = load_file_dialog(self, filters) 44 | if filename is not None: 45 | from pointRegistration.displacementMap import DisplacementMap 46 | self.model = DisplacementMap.load_model(filename) 47 | self.rot_figure.load_model(self.model) 48 | self.rot_figure.draw_data() 49 | self.toolbar.save_button.setEnabled(True) 50 | 51 | 52 | class LowerToolbar(QWidget): 53 | 54 | def __init__(self, parent): 55 | super().__init__(parent) 56 | self.parent = parent 57 | self.initUI() 58 | 59 | def initUI(self): 60 | self.layout = QGridLayout() 61 | self.setLayout(self.layout) 62 | self.save_button = ControlPushButton("Save on file", self.parent.save_displacement_map, False) 63 | load_button = ControlPushButton("Load from file..", self.parent.load_displacement_map) 64 | self.layout.addWidget(self.save_button, 0, 1, 1, 1) 65 | self.layout.addWidget(load_button, 0, 3, 1, 1) -------------------------------------------------------------------------------- 
/graphicInterface/plot_interactive_figure.py: -------------------------------------------------------------------------------- 1 | import matplotlib.patches as patches 2 | import numpy as np 3 | from matplotlib.widgets import RectangleSelector 4 | from scipy import spatial 5 | 6 | from graphicInterface.console import suppress_stdout_stderr 7 | from graphicInterface.plot_figure import PlotFigure 8 | 9 | 10 | class PlotInteractiveFigure(PlotFigure): 11 | 12 | def __init__(self, parent, model=None, landmarks=True, title=None): 13 | super().__init__(parent, model, landmarks, title) 14 | self.myTree = None 15 | self.RS = RectangleSelector(self.ax, self.square_select_callback, drawtype='box', useblit=True, button=[1, 3], 16 | minspanx=5, minspany=5, spancoords='pixels', interactive=True) 17 | 18 | def load_model(self, model): 19 | self.myTree = None 20 | super().load_model(model) 21 | 22 | def select_nearest_pixel(self, x_coord, y_coord): 23 | if self.myTree is None: 24 | print("Calculating 2DTree...") 25 | self.myTree = spatial.cKDTree(self.model.landmarks_3D[:, 0:2]) # costruisce il KDTree con i punti del Model 26 | 27 | dist, index = self.myTree.query([[x_coord, y_coord]], k=1) 28 | if dist < 5: 29 | self.landmarks_colors[index[0]] = "y" if self.landmarks_colors[index[0]] == "r" else "r" 30 | self.draw() 31 | if self.parent() is not None: 32 | self.parent().landmark_selected(self.landmarks_colors) 33 | 34 | def square_select_callback(self, eclick, erelease): 35 | # eclick and erelease are the press and release events 36 | x1, y1 = eclick.xdata, eclick.ydata 37 | x2, y2 = erelease.xdata, erelease.ydata 38 | rect = patches.Rectangle((min(x1, x2), min(y1, y2)), np.abs(x1 - x2), np.abs(y1 - y2), 39 | linewidth=1, edgecolor='r', facecolor='none', fill=True) 40 | self.get_ax().add_patch(rect) 41 | self.select_area(min(x1, x2), min(y1, y2), np.abs(x1 - x2), np.abs(y1 - y2)) 42 | self.draw() 43 | 44 | def there_are_points_highlighted(self): 45 | return 
self.model.has_registration_points() 46 | 47 | def select_area(self, x_coord, y_coord, width, height): 48 | with suppress_stdout_stderr(): 49 | x_data = self.model.points[:, 0] 50 | y_data = self.model.points[:, 1] 51 | 52 | x_ind = np.where((x_coord <= x_data) & (x_data <= x_coord + width)) 53 | y_ind = np.where((y_coord <= y_data) & (y_data <= y_coord + height)) 54 | 55 | ind = np.intersect1d(np.array(x_ind), np.array(y_ind), assume_unique=True) 56 | self.highlight_data(ind) 57 | 58 | def highlight_data(self, indices): 59 | if indices[0] != -1: 60 | self.model.add_registration_points(indices) 61 | self.draw_data() 62 | else: 63 | self.model.init_registration_points() 64 | self.draw_data() 65 | -------------------------------------------------------------------------------- /pointRegistration/batchRegistration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from threading import Thread 3 | from graphicInterface.console import Logger 4 | from pycpd.rigid_registration import rigid_registration 5 | from pointRegistration.model import Model 6 | from functools import partial 7 | import time 8 | import datetime 9 | import os 10 | 11 | 12 | class BatchRegistrationThread(Thread): 13 | 14 | def __init__(self, source, target_list, percentage, final_callback): 15 | Thread.__init__(self) 16 | self.source_model = source 17 | self.target_list = target_list 18 | self.perc = percentage 19 | self.finalCallback = final_callback 20 | self.should_stop = False 21 | Logger.addRow("Starting Batch Thread..") 22 | 23 | def run(self): 24 | source = self.source_model.get_registration_points() 25 | 26 | try: 27 | for targ in self.target_list: 28 | Logger.addRow("Batch %d of %d:" % (self.target_list.index(targ) + 1, len(self.target_list))) 29 | path_wrl = targ[0:len(targ) - 3] + "bnd" 30 | t = Model(targ, path_wrl) 31 | target = Model.decimate(t.points, self.perc) 32 | Logger.addRow("Points decimated.") 33 | if t.landmarks is not 
None: 34 | target = np.concatenate((target, t.landmarks), axis=0) 35 | reg = rigid_registration(**{'X': source, 'Y': target}) 36 | meth = "CPD Rigid" 37 | 38 | Logger.addRow("Starting registration with " + meth + ", using " + str(self.perc) + "% of points.") 39 | model = Model() 40 | 41 | reg_time = time.time() 42 | 43 | # Se si vuole visualizzare i progressi usare questa versione 44 | # data, reg_param = reg.register(partial(self.drawCallback, ax=None)) 45 | data, reg_param = reg.register(partial(self.log, ax=None)) 46 | 47 | model.set_points(reg.transform_point_cloud(self.source_model.model_data)) 48 | 49 | model.registration_params = reg_param 50 | if t.landmarks is not None: 51 | model.set_landmarks(data[target.shape[0] - t.landmarks.shape[0]: data.shape[0]]) 52 | model.filename = t.filename 53 | # model.centerData() 54 | model.compute_displacement_map(target, 3) 55 | now = datetime.datetime.now() 56 | save_filename = "RIGID_REG_{0}_{1}_{2}_{3}_{4}.mat" 57 | save_path = os.path.join("results", save_filename.format(now.day, now.month, now.year, now.hour, 58 | now.minute)) 59 | model.save_model(save_path) 60 | model.shoot_displacement_map(save_path) 61 | Logger.addRow("Took " + str(round(time.time() - reg_time, 3)) + "s.") 62 | 63 | except Exception as ex: 64 | Logger.addRow(str(ex)) 65 | print(ex) 66 | finally: 67 | self.finalCallback() 68 | 69 | def log(self, iteration, error, X, Y, ax): 70 | sss = "Iteration #" + str(iteration) + " error: " + str(error) 71 | Logger.addRow(sss) 72 | if self.should_stop: 73 | raise Exception("Registration has been stopped") 74 | 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shape Registrator 2 | 3 | *A tool for 3D face models aligment. 
With a GUI.* 4 | 5 | ![Application screenshot](https://github.com/rickie95/3DMMRegistration/blob/master/resources/gitScreen1.png) 6 | 7 | Shape Registrator works with 3D models and supports the **.mat** (Matlab), **.wrl** (VRML) and **.off** (Object File Format) file formats. 8 | 9 | ## 2019 Update: 10 | - The displacement map can now be saved as a Pickle file. 11 | - The target model can be rotated, in order to provide a good starting point for the CPD algorithm. 12 | - The source model is now shown in the target section while it is being registered. 13 | 14 | ### Installation 15 | 16 | 1. Clone the repo. 17 | 2. `pip install -r requirements.txt` 18 | 3. Run main.py 19 | 20 | ### Usage 21 | 22 | 1. Load two models; you can choose between the .mat, .wrl + .bnd and .off formats. 23 | 2. Select the amount of points to be used and the CPD version: rigid, affine or deformable. 24 | 3. Run "Registrate" and wait for the result. 25 | 26 | In the log box you can see the registration error and other messages. 27 | 28 | ### About file formats: 29 | 30 | **.mat** files have to contain a '3ddata' field storing the 3D point coordinates (an Nx3 matrix) and a 'landmarks3d' field which stores the landmark coordinates (an Nx3 matrix). 31 | 32 | **.wrl** files should be paired with a .bnd file storing the landmark coordinates. If a .png image with the same name is available, it will be used as the background of the plot tool. 33 | 34 | **.off** files must contain 3D point coordinates; faces are optional and will be ignored. 35 | 36 | ## Insights 37 | 38 | At the very heart of the application there's the *Coherent Point Drift* algorithm [1], which registers two point sets regardless of the nature of the transformation. 39 | In fact, CPD provides three kinds of transformation: **rigid**, **affine** and **deformable**, with increasing degrees of freedom.
40 | 41 | It's possible to control the number of points involved in the registration process: you can choose to keep all of them (100%) or to reduce them down to 30% of the original number. The sampling policy used is *uniform random*. 42 | 43 | Reducing the number of points makes the registration process faster, but - only in the affine and deformable cases - the resulting transformation matrices can't be used to transform the entire point set. However, we can take advantage of this option in the rigid case. 44 | 45 | ## Packages required 46 | *A requirements.txt is provided for use with pip* 47 | 48 | Package | Version 49 | --------|-------- 50 | numpy | >=1.15.0 51 | scipy | >=1.3 52 | matplotlib| ~2.2.3 53 | PyQt5| ~5.11.2 54 | h5py| ~2.8 55 | pycpd| ~1.0.3 56 | 57 | So far, the application has been tested with Python 3.7 x64 on Windows 10 and Ubuntu 16.04 LTS. It should work on macOS as well, since PyQt5 is cross-platform. 58 | 59 | ## References 60 | [1] Andriy Myronenko and Xubo Song, "*Point Set Registration: Coherent Point Drift*", IEEE Trans. on Pattern Analysis and Machine Intelligence, vol. 32, issue 12, pp.
2262-2275, 15 May 2009 {[link](https://arxiv.org/pdf/0905.2635.pdf)} 61 | 62 | [2] Alessandro Soci and Gabriele Barlacchi, [AlessandroSoci/3DMM-Facial-Expression-from-Webcam](https://github.com/AlessandroSoci/3DMM-Facial-Expression-from-Webcam) 63 | 64 | [3] Alessandro Sestini and Francesco Lombardi, [fralomba/Facial-Expression-Prediction](https://github.com/fralomba/Facial-Expression-Prediction) 65 | -------------------------------------------------------------------------------- /graphicInterface/conf_param_window.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QLabel, QMainWindow, QWidget, QGridLayout, QGroupBox, QLineEdit 2 | from PyQt5.QtGui import QIntValidator, QDoubleValidator, QValidator 3 | from graphicInterface.upper_toolbar_controls import ControlPushButton 4 | from pointRegistration.registration_param import RegistrationParameters 5 | 6 | 7 | class EditConfWindow(QMainWindow): 8 | 9 | def __init__(self, parent): 10 | super().__init__(parent=parent) 11 | self.setWindowTitle("Edit configuration") 12 | self.setCentralWidget(ConfEditCentralWidget(self)) 13 | self.resize(400, 200) 14 | self.show() 15 | 16 | 17 | class ConfEditCentralWidget(QWidget): 18 | 19 | def __init__(self, parent): 20 | super(ConfEditCentralWidget, self).__init__(parent=parent) 21 | general_layout = QGridLayout() 22 | self.setLayout(general_layout) 23 | group = QGroupBox("Configuration parameters") 24 | general_layout.addWidget(group, 0, 0, 1, 6) 25 | 26 | general_layout.addWidget(ControlPushButton("Apply", self.apply_changes), 1, 4, 1, 1) 27 | general_layout.addWidget(ControlPushButton("Close", self.parent().close), 1, 5, 1, 1) 28 | 29 | layout_conf = QGridLayout() 30 | group.setLayout(layout_conf) 31 | 32 | layout_conf.addWidget(QLabel("Tolerance"), 0, 0) 33 | layout_conf.addWidget(QLabel("Max iterations for Local Registration"), 1, 0) 34 | layout_conf.addWidget(QLabel("Sigma2"), 2, 0) 35 | weight_label = 
QLabel("Weight ?") 36 | weight_label.setToolTip("Uniform distribution component in GMM for Expectation-Maximizazion step. " 37 | "Must be in interval (0,1)") 38 | layout_conf.addWidget(weight_label, 3, 0) 39 | 40 | self.tolerance_input = ValidatedLineEdit('tolerance', DoubleNoneValidator(0.0000001, 1000, 7)) 41 | layout_conf.addWidget(self.tolerance_input, 0, 1) 42 | 43 | self.max_iterations_input = ValidatedLineEdit('max_iterations', IntValidator(1, 999)) 44 | layout_conf.addWidget(self.max_iterations_input, 1, 1) 45 | 46 | self.sigma_input = ValidatedLineEdit('sigma2', DoubleNoneValidator(0, 1, 6, nullable=True)) 47 | layout_conf.addWidget(self.sigma_input, 2, 1) 48 | 49 | self.weight_input = ValidatedLineEdit('w', DoubleNoneValidator(-0.0000001, 1, 6)) 50 | layout_conf.addWidget(self.weight_input, 3, 1) 51 | 52 | self.inputs = [self.tolerance_input, self.max_iterations_input, self.sigma_input, self.weight_input] 53 | 54 | def apply_changes(self): 55 | for value_input in self.inputs: 56 | if value_input.is_valid(): 57 | RegistrationParameters.set_param(value_input.key, value_input.get_value()) 58 | else: 59 | print(f"Parameter for {value_input.key} is invalid, using previous value.") 60 | RegistrationParameters.write_on_file() 61 | self.parent().parent().update_parameter_label() 62 | 63 | 64 | class ValidatedLineEdit(QLineEdit): 65 | 66 | def __init__(self, key, validator=None): 67 | super(ValidatedLineEdit, self).__init__() 68 | self.key = key 69 | self.setText(str(RegistrationParameters.get_param(key))) 70 | self.validator = validator 71 | 72 | def get_value(self): 73 | text = self.text() 74 | 75 | if text.lower().find("none") > -1: 76 | return None 77 | 78 | return self.validator.type(text) 79 | 80 | def is_valid(self): 81 | state = self.validator.validate(self.text()) 82 | if state == QValidator.Invalid: 83 | return False 84 | return True 85 | 86 | 87 | class IntValidator(QValidator): 88 | 89 | def __init__(self, bottom, top): 90 | super(IntValidator, 
self).__init__() 91 | self.bottom = bottom 92 | self.top = top 93 | self.type = int 94 | 95 | def validate(self, p_str, p_int=None): 96 | try: 97 | if self.top > int(p_str) > self.bottom: 98 | return QValidator.Acceptable 99 | except ValueError: 100 | return QValidator.Invalid 101 | 102 | return QValidator.Invalid 103 | 104 | 105 | class DoubleNoneValidator(QValidator): 106 | 107 | def __init__(self, bottom, top, digits, nullable=False): 108 | super(DoubleNoneValidator, self).__init__() 109 | self.bottom = bottom 110 | self.top = top 111 | self.digits = digits 112 | self.type = float 113 | self.nullable = nullable 114 | 115 | def validate(self, p_str, p_int=None): 116 | if self.nullable and p_str.lower().find('none') > -1: 117 | return QValidator.Acceptable 118 | try: 119 | if self.top > float(p_str) > self.bottom and len(p_str) <= (self.digits + 2): 120 | return QValidator.Acceptable 121 | except ValueError: 122 | return QValidator.Invalid 123 | 124 | return QValidator.Invalid 125 | -------------------------------------------------------------------------------- /pointRegistration/registration.py: -------------------------------------------------------------------------------- 1 | from pointRegistration.registration_param import RegistrationParameters 2 | from pointRegistration.model import Model 3 | from graphicInterface.console import Logger 4 | from threading import Thread 5 | from functools import partial 6 | from pycpd import * 7 | import numpy as np 8 | import time 9 | 10 | 11 | class Registration(Thread): 12 | 13 | def __init__(self, method, source_model, target_model, perc, callback, iteration_callback=None): 14 | Thread.__init__(self) 15 | self.method = method 16 | self.source_model = source_model 17 | self.target_model = target_model 18 | self.percentage = perc 19 | self.callback = callback 20 | self.should_stop = False 21 | self.iteration_callback = iteration_callback 22 | self.registration_method = None 23 | 24 | def run(self): 25 | source = 
self.source_model.get_registration_points() 26 | target = Model.decimate(self.target_model.points, self.percentage) 27 | Logger.addRow("Points decimated.") 28 | if self.target_model.landmarks is not None: 29 | target = np.concatenate((target, self.target_model.landmarks), axis=0) 30 | Logger.addRow("Landmarks added.") 31 | 32 | ps = RegistrationParameters().get_params() 33 | 34 | if self.method == 1: # CPD - RIGID 35 | self.registration_method = rigid_registration(**{'X': target, 'Y': source, 'sigma2': ps['sigma2'], 36 | 'max_iterations': ps['max_iterations'], 37 | 'tolerance': ps['tolerance'], 'w': ps['w']}) 38 | method = "CPD Rigid" 39 | if self.method == 2: # CPD - AFFINE 40 | self.registration_method = affine_registration(**{'X': target, 'Y': source, 'sigma2': ps['sigma2'], 41 | 'max_iterations': ps['max_iterations'], 42 | 'tolerance': ps['tolerance'], 'w': ps['w']}) 43 | method = "CPD Affine" 44 | if self.method == 3: # CPD - DEFORMABLE 45 | self.registration_method = deformable_registration(**{'X': target, 'Y': source, 'sigma2': ps['sigma2'], 46 | 'max_iterations': ps['max_iterations'], 47 | 'tolerance': ps['tolerance'], 'w': ps['w']}) 48 | method = "CPD Deformable" 49 | 50 | Logger.addRow("Starting registration with " + method + ", using " + str(self.percentage) + "% of points.") 51 | model = Model() 52 | reg_time = time.time() 53 | 54 | try: 55 | self.registration_method.register(partial(self.interruptable_wrapper, ax=None)) 56 | model = self.aligned_model(model) 57 | except InterruptedException as ex: 58 | Logger.addRow(str(ex)) 59 | model = self.aligned_model(model) 60 | except Exception as ex: 61 | Logger.addRow("Err: " + str(ex)) 62 | model = self.target_model # Fail: back with the original target model 63 | finally: 64 | Logger.addRow("Took "+str(round(time.time()-reg_time, 3))+"s.") 65 | self.callback(model) 66 | 67 | def aligned_model(self, model): 68 | """ 69 | Transforms all data points and landmarks of source model, applying the transformation 
obtained during the 70 | registration phase. 71 | :param model: A initialized and empty Model object 72 | :return: The input model, filled with transformed data points, landmarks, filename, registration parameters 73 | and displacement map oomputed between transformed source and target. 74 | """ 75 | model.registration_params = self.registration_method.get_registration_parameters() 76 | points = self.registration_method.transform_point_cloud(self.source_model.points) 77 | if self.source_model.landmarks is not None: 78 | landmarks = self.registration_method.transform_point_cloud(self.source_model.landmarks) 79 | model.set_landmarks(landmarks) 80 | 81 | model.set_points(points) 82 | model.filename = self.target_model.filename 83 | 84 | # model.compute_displacement_map(self.target_model, 3) 85 | return model 86 | 87 | def stop(self): 88 | self.should_stop = True 89 | 90 | def log(self, iteration, error): 91 | row = "Iteration #" + str(iteration) + " error: " + str(error) 92 | Logger.addRow(row) 93 | 94 | def interruptable_wrapper(self, **kwargs): 95 | if self.should_stop: 96 | raise InterruptedException("Registration has been stopped") 97 | 98 | if self.iteration_callback is None: 99 | self.log(**kwargs) 100 | else: 101 | self.iteration_callback(**kwargs) 102 | 103 | 104 | class InterruptedException(Exception): 105 | pass 106 | -------------------------------------------------------------------------------- /graphicInterface/upper_toolbar.py: -------------------------------------------------------------------------------- 1 | from graphicInterface.file_dialogs import * 2 | from PyQt5.Qt import * 3 | from graphicInterface.console import Logger 4 | from graphicInterface.upper_toolbar_controls import * 5 | 6 | 7 | class UpperToolbar(QWidget): 8 | 9 | def __init__(self, parent): 10 | super().__init__(parent) 11 | self.parent = parent 12 | self.layout = QGridLayout() 13 | self.setLayout(self.layout) 14 | 15 | # "Registration" group 16 | registration_group = 
QGroupBox("Registration") 17 | registration_group_layout = QGridLayout() 18 | registration_group.setLayout(registration_group_layout) 19 | label = QLabel("Registration method:") 20 | label.setAlignment(Qt.AlignCenter) 21 | registration_group_layout.addWidget(label, 0, 0) 22 | 23 | self.registration_method_combobox = RegistrationMethodsCombobox() 24 | registration_group_layout.addWidget(self.registration_method_combobox, 0, 1) 25 | 26 | self.start_registration_button = ControlPushButton("Start Registration", self.registrate, False)  # presumably the last argument is the initial enabled state — confirm in upper_toolbar_controls 27 | registration_group_layout.addWidget(self.start_registration_button, 0, 2) 28 | 29 | self.stop_registration_button = ControlPushButton("Stop", self.stop_registration, False) 30 | registration_group_layout.addWidget(self.stop_registration_button, 1, 2) 31 | 32 | point_percentage_label = QLabel("Target's points used:") 33 | point_percentage_label.setAlignment(Qt.AlignCenter) 34 | registration_group_layout.addWidget(point_percentage_label, 1, 0) 35 | 36 | self.point_percentage_combobox = PercentageComboBox() 37 | registration_group_layout.addWidget(self.point_percentage_combobox, 1, 1) 38 | 39 | # "Model" group 40 | model_group = QGroupBox("Models") 41 | model_group_layout = QGridLayout() 42 | model_group.setLayout(model_group_layout) 43 | 44 | model_group_layout.addWidget(ControlPushButton("Load Source", self.load_source, True), 0, 0) 45 | model_group_layout.addWidget(ControlPushButton("Restore", self.restore, True), 1, 0) 46 | model_group_layout.addWidget(ControlPushButton("Load Target", self.load_target, True), 0, 1) 47 | 48 | self.save_target_btn = ControlPushButton("Save Target", self.save_target, False) 49 | model_group_layout.addWidget(self.save_target_btn, 1, 1) 50 | 51 | model_group_layout.addWidget(ControlPushButton("Batch registration", self.batch_reg, False), 0, 2)  # NOTE(review): batch_reg relies on a load_file method this class does not define 52 | 53 | self.show_displacement_btn = ControlPushButton("Show Displacement Map", self.show_displacement, True) 54 |
model_group_layout.addWidget(self.show_displacement_btn, 1, 2) 55 | 56 | # LOGGER GROUP 57 | logger_group = QGroupBox("Log") 58 | logger_group_layout = QGridLayout() 59 | logger_group.setLayout(logger_group_layout) 60 | 61 | save_log_btn = QPushButton("Save log on file") 62 | save_log_btn.clicked.connect(self.savelog_onfile) 63 | logger_group_layout.addWidget(Logger.instance) 64 | logger_group_layout.addWidget(save_log_btn) 65 | 66 | self.layout.addWidget(registration_group, 0, 0) 67 | self.layout.addWidget(model_group, 0, 1) 68 | self.layout.addWidget(logger_group, 0, 2) 69 | 70 | self.layout.setColumnStretch(0, 2) 71 | self.layout.setColumnStretch(1, 1) 72 | self.layout.setColumnStretch(2, 4) 73 | 74 | def registrate(self):  # forwards the selected method and point percentage to the parent widget 75 | method = self.registration_method_combobox.currentData() 76 | percent = self.point_percentage_combobox.currentData() 77 | try: 78 | self.parent.registrate(method, percent) 79 | self.start_registration_button.setEnabled(False) 80 | self.stop_registration_button.setEnabled(True) 81 | except Exception as ex: 82 | print(ex)  # MainWidget.registrate raises after showing an error dialog (no highlighted points or no target); buttons are left unchanged 83 | 84 | def show_displacement(self): 85 | self.parent.show_displacement_map() 86 | 87 | def savelog_onfile(self): 88 | self.parent.savelog_onfile() 89 | 90 | def stop_registration(self): 91 | self.parent.stop_registration_thread() 92 | self.stop_registration_button.setEnabled(False) 93 | 94 | def restore(self): 95 | self.parent.restore() 96 | 97 | @pyqtSlot() 98 | def batch_reg(self): 99 | file_names = self.load_file(multiple_files=True)  # NOTE(review): no load_file method is defined on UpperToolbar in this file, so this raises AttributeError if reached — confirm the intended file_dialogs helper 100 | if file_names: 101 | method = self.registration_method_combobox.currentData() 102 | percent = self.point_percentage_combobox.currentData() 103 | try: 104 | self.parent.registrate_batch(method, percent, file_names) 105 | self.start_registration_button.setEnabled(False) 106 | self.stop_registration_button.setEnabled(True) 107 | except Exception as ex: 108 | print(ex) 109 | 110 | @pyqtSlot() 111 | def
load_source(self): 116 | self.parent.load_source() 117 | 118 | @pyqtSlot() 119 | def save_target(self): 120 | self.parent.save_target() 121 | -------------------------------------------------------------------------------- /graphicInterface/plot_figure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PyQt5.QtWidgets import QSizePolicy 3 | from matplotlib import pyplot 4 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 5 | 6 | from graphicInterface.console import suppress_stdout_stderr 7 | 8 | 9 | class PlotFigure(FigureCanvas): 10 | 11 | def __init__(self, parent, model=None, landmarks=True, title=None): 12 | self.fig, self.ax = pyplot.subplots() 13 | self.title = title 14 | FigureCanvas.__init__(self, self.fig) 15 | FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding) 16 | FigureCanvas.updateGeometry(self) 17 | self.setParent(parent) 18 | self.draw_landmarks = landmarks 19 | self.bgImage = None 20 | self.model = model 21 | self.drawDisplacement = False 22 | self.landmarks_colors = None 23 | self.data_colors = None 24 | self.scale_width = None 25 | self.scale_height = None 26 | self.legend_handlers = [] 27 | self.legend_has_been_drawn = True 28 | 29 | if model is not None: 30 | self.load_model(model) 31 | 32 | def load_model(self, model): 33 | self.model = model 34 | self.drawDisplacement = False 35 | 36 | def load_data(self, data, color, marker='o', points_size=0.5, label=""):  # plots columns 0 and 1 of data; a non-empty label gets the point count appended and its handle stored for the legend 37 | if data is None or data.shape[0] == 0: 38 | return 39 | append = False 40 | if label != "": 41 | label += f" ({data.shape[0]})" 42 | append = True 43 | max_v = np.max(data[:, 2]) 44 | min_v = np.min(data[:, 2]) 45 | sizes = np.copy(data[:, 2]) 46 | self.sizes = ((sizes + np.abs(min_v)) / np.abs(max_v)) + 0.5  # NOTE(review): self.sizes is not read in plot_figure.py; this divides by zero when max_v == 0 47 | legend_handler, = self.ax.plot(data[:, 0], data[:, 1], c=color, marker=marker, linestyle='None', 48 | markersize=points_size, label=label) 49 | if
append: 50 | self.legend_handlers.append(legend_handler) 51 | 52 | def load_displacement(self): 53 | self.ax.plot(self.model.displacement_map[:, 0], self.model.displacement_map[:, 1]) 54 | 55 | def load_image(self): 56 | if self.bgImage is not None: 57 | img = pyplot.imread(self.bgImage) 58 | self.ax.imshow(img, extent=[-self.scale_height * 1.05, self.scale_height * 1.03, -self.scale_height * 1.03, 59 | self.scale_height * 1.05]) # SX DX BOTTOM UP 60 | 61 | def show_displacement(self): 62 | self.drawDisplacement = True if self.model.displacement_map is not None else False 63 | self.draw_landmarks = False 64 | 65 | def clear(self): 66 | self.ax.cla() 67 | 68 | def draw_data(self, clear=False): 69 | if clear: 70 | self.ax.cla() 71 | if self.title is not None: 72 | self.ax.set_title(self.title) 73 | if self.model is not None: 74 | self.ax.autoscale() 75 | self.ax.set_aspect('equal') 76 | self.ax.set_xlabel('X axis') 77 | self.ax.set_ylabel('Y axis') 78 | 79 | self.load_data(self.model.points, self.model.points_color, label="Points") 80 | if self.model.landmarks is not None: 81 | self.load_data(self.model.landmarks, self.model.landmarks_color, points_size=5, label="Landmarks") 82 | try: 83 | self.load_data(self.model.points[self.model.registration_points], 'y', label="Registration Points") 84 | except Exception as ex: 85 | pass 86 | 87 | try: 88 | self.load_data(self.model.missed_points, self.model.missed_points_color, marker='v', 89 | label="Missed Points") 90 | self.load_data(self.model.missed_landmarks, self.model.missed_landmarks_color, size=5, marker='v', 91 | label="Missed Landmarks") 92 | except Exception as ex: 93 | pass 94 | """if len(self.legend_handlers) > 0 and self.legend_has_been_drawn: 95 | self.ax.legend(handles=self.legend_handlers) 96 | self.legend_has_been_drawn = False 97 | """ 98 | self.load_image() 99 | self.draw() 100 | 101 | def draw(self, clear=False): 102 | if clear is True: 103 | self.clear() 104 | super().draw() 105 | self.flush_events() 106 
| 107 | def restore_model(self): 108 | self.ax.cla() 109 | self.draw_data() 110 | 111 | def landmarks(self, l): 112 | self.draw_landmarks = l 113 | 114 | def set_landmarks_colors(self, colors): 115 | self.model.landmarks_color = colors 116 | self.draw_data() 117 | 118 | def set_data_colors(self, colors): 119 | self.model.points_color = colors 120 | self.draw_data() 121 | 122 | def get_ax(self): 123 | return self.ax 124 | 125 | def update_plot_callback(self, iteration, error, X, Y, ax): 126 | self.ax.cla() 127 | print(iteration, error) 128 | self.ax.scatter(Y[:, 0], Y[:, 1], 0.1, c='r') 129 | self.ax.scatter(X[:, 0], X[:, 1], 0.1, c='b') 130 | if self.bgImage is not None: 131 | img = pyplot.imread(self.bgImage) 132 | self.ax.imshow(img, extent=[-self.model.rangeY/2 * 1.05, self.model.rangeY/2 * 1.03, 133 | -self.model.rangeY/2 * 1.03, self.model.rangeY/2 * 1.05]) 134 | try: 135 | self.ax.text(0.87, 0.92, 'Iteration: {:d}\nError: {:06.4f}'.format(iteration, error), 136 | horizontalalignment='center', verticalalignment='center', transform=self.ax.transAxes, 137 | fontsize='x-large') 138 | except Exception as ex: 139 | print(ex) 140 | with suppress_stdout_stderr(): 141 | super().draw() 142 | self.flush_events() 143 | 144 | def has_model(self): 145 | if self.model is not None: 146 | return True 147 | 148 | return False 149 | -------------------------------------------------------------------------------- /pointRegistration/model.py: -------------------------------------------------------------------------------- 1 | import ntpath 2 | import os 3 | 4 | import h5py 5 | import numpy as np 6 | from scipy.spatial.transform import Rotation 7 | 8 | from graphicInterface.console import Logger 9 | from pointRegistration import file3D 10 | 11 | 12 | class Model: 13 | 14 | def __init__(self, path_data=None): 15 | """ 16 | Create a model from a .wrl file (3D points) and a .bnd file (landmarks) or load 17 | both elements from a .mat file. 
18 | :param path_data: path to the .wrl file 19 | :param path_landmarks: path to .bnd file (only for .wrl/.bnd case) 20 | :param image: path to the associated image (opt) 21 | """ 22 | self.landmarks = None 23 | self.landmarks_color = None 24 | self.points = None 25 | self.points_color = None 26 | 27 | self.highlighted_color = "y" 28 | 29 | self.init_attributes() 30 | if path_data is not None: 31 | self.load_model(path_data) 32 | self.center_model() 33 | 34 | @classmethod 35 | def from_model(cls, model): 36 | copy = cls() 37 | copy.set_points(np.copy(model.points)) 38 | if model.points_color is not None: 39 | copy.points_color = model.points_color 40 | if model.landmarks is not None: 41 | copy.set_landmarks(np.copy(model.landmarks)) 42 | if model.landmarks_color is not None: 43 | copy.landmarks_color = model.landmarks_color 44 | copy.filename = model.filename 45 | return copy 46 | 47 | def init_attributes(self): 48 | self.bgImage = None 49 | self.registration_points = None 50 | self.registration_params = None 51 | self.displacement_map = None 52 | self.rangeX = None 53 | self.rangeY = None 54 | self.filename = None 55 | 56 | def center_model(self): 57 | # Not always scans are perfectly centered. 
58 | # Median is not sensible to big extreme value clusters 59 | points_median = np.median(self.points, axis=0) 60 | self.points -= points_median 61 | if self.landmarks is not None: 62 | self.landmarks -= points_median 63 | 64 | def set_points(self, data): 65 | self.points = data 66 | self.rangeX = np.ptp(self.points[:, 0]) 67 | self.rangeY = np.ptp(self.points[:, 1]) 68 | self.points_color = "b" 69 | 70 | def set_landmarks(self, land): 71 | self.landmarks = land 72 | self.landmarks_color = "r" 73 | 74 | def add_registration_points(self, reg_points): 75 | if reg_points[0] == -1: 76 | self.registration_points = np.empty((0, 3), dtype=int) 77 | 78 | self.registration_points = np.unique(np.append(self.registration_points, reg_points)) 79 | 80 | def init_registration_points(self): 81 | self.registration_points = np.empty((0, 3), dtype=int) 82 | 83 | def get_registration_points(self): 84 | self.registration_points = np.unique(self.registration_points) 85 | return np.array(self.points[self.registration_points]) 86 | 87 | def has_registration_points(self): 88 | if self.registration_points.shape[0] > 0: 89 | return True 90 | return False 91 | 92 | def save_model(self, filename): 93 | model = {"model_data": self.points} 94 | 95 | if self.landmarks is not None: 96 | model["landmarks3D"] = self.landmarks 97 | if self.displacement_map is not None: 98 | model["displacement_map"] = self.displacement_map 99 | if self.registration_params is not None: 100 | for i in range(len(self.registration_params)): 101 | model[str("reg_param"+str(i))] = self.registration_params[i] 102 | 103 | file3D.save_file(filename, model) 104 | Logger.addRow(str("File saved: " + filename)) 105 | 106 | def compute_displacement_map(self, target_model, distance): 107 | from pointRegistration.displacementMap import DisplacementMap 108 | if target_model is None: 109 | return None 110 | return DisplacementMap.compute_map(source_model=self, target_model=target_model, max_dist=distance) 111 | 112 | def 
rotate(self, axis, theta): 113 | self.points = Model.rotate_model(axis, theta, self.points) 114 | if self.landmarks is not None: 115 | self.landmarks = Model.rotate_model(axis, theta, self.landmarks) 116 | 117 | @staticmethod 118 | def rotate_model(axis, theta, data): 119 | theta = np.radians(theta) 120 | cos_t = np.cos(theta) 121 | sin_t = np.sin(theta) 122 | if axis == 'x': 123 | rotation_matrix = Rotation.from_quat([0, sin_t, 0, cos_t]) 124 | if axis == 'y': 125 | rotation_matrix = Rotation.from_quat([sin_t, 0, 0, cos_t]) 126 | if axis == 'z': 127 | rotation_matrix = Rotation.from_quat([0, 0, sin_t, cos_t]) 128 | return rotation_matrix.apply(data) 129 | 130 | def load_model(self, path_data): 131 | self.filename, self.file_extension = os.path.splitext(path_data) 132 | 133 | if os.path.exists(self.filename + ".png"): 134 | self.bgImage = self.filename + ".png" 135 | 136 | if self.file_extension == ".mat": 137 | file = h5py.File(path_data, 'r') 138 | try: 139 | self.set_points(np.transpose(np.array(file["avgModel"]))) 140 | self.set_landmarks(np.transpose(np.array(file["landmarks3D"]))) 141 | except Exception as ex: 142 | print("File is not compatible:" + ex) 143 | 144 | if self.file_extension == ".wrl": 145 | self.set_points(file3D.load_wrml(path_data)) 146 | self.set_landmarks(file3D.load_bnd(self.filename + ".bnd")) 147 | if self.bgImage is not None and os.path.exists(self.bgImage): 148 | self.bgImage = self.filename[:-3] + "F2D.png" 149 | 150 | if self.file_extension == ".off": 151 | self.set_points(file3D.load_off(path_data)) 152 | 153 | row = "Model loaded: " + str(self.points.shape[0]) + " points" 154 | 155 | if self.landmarks is not None: 156 | row += " and " + str(self.landmarks.shape[0]) + " landmarks." 
157 | 158 | self.init_registration_points() 159 | Logger.addRow(row) 160 | 161 | @staticmethod 162 | def __path_leaf__(path): 163 | head, tail = ntpath.split(path) 164 | return tail or ntpath.basename(head) 165 | 166 | @staticmethod 167 | def decimate(old_array, percentage): 168 | if percentage >= 100: 169 | return old_array 170 | 171 | le, _ = old_array.shape 172 | useful_range = np.arange(le) 173 | np.random.shuffle(useful_range) 174 | limit = int(le / 100 * percentage) 175 | new_arr = np.empty((limit, 3)) 176 | rr = np.arange(limit) 177 | for count in rr: 178 | new_arr[count] = old_array[useful_range[count]] 179 | 180 | return new_arr 181 | -------------------------------------------------------------------------------- /pointRegistration/file3D.py: -------------------------------------------------------------------------------- 1 | from graphicInterface.console import Logger 2 | import numpy as np 3 | import os 4 | import h5py 5 | 6 | 7 | def load_wrml(path): 8 | """ 9 | Restituisce un array Nx3 con le coordinate 3D dei punti contenuti in un file .wrl. 10 | 11 | Attenzione, il file deve avere questa precisa struttura: 12 | **** roba***** 13 | ************* 14 | point [ 15 | x.xxx y.yyyyyy z.zzzzz, 16 | x.xxx y.yyyyyy z.zzzzz, 17 | **** 18 | **** 19 | **** 20 | x.xxx y.yyyyyy z.zzzzz, 21 | x.xxx y.yyyyyy z.zzzzz 22 | ] 23 | ****** roba ****** 24 | ****************** 25 | 26 | Notare che manca l'ultima virgola 27 | 28 | Ai fini del caricamento il resto del contenuto non è importante, ma essendo questo 29 | un parser a riga robusto come un bicchiere di cristallo è consigliabile assicurarsi 30 | che il ile rispetti la struttura descritta prima di dare colpa all'implementazione 31 | da veri n00b. 
32 | 33 | :param path: pathname del file (" pippo.wrl") 34 | :return: un array numpy Nx3, dove N è il numero dei punti contenuti nel file 35 | """ 36 | 37 | file = load_file(path) 38 | if file is None: 39 | return None 40 | 41 | while file.readline().strip() != "point [": 42 | pass 43 | 44 | # arrivo alla prima riga dei punti 45 | points_num = 0 46 | points = np.empty((1000, 3)) 47 | line = file.readline().strip() 48 | while line.strip() != "]": 49 | x, y, z = line[0:len(line)-1].split(" ") 50 | points[points_num] = [float(x), float(y), float(z)] 51 | points_num +=1 52 | if points_num == points.size/3: 53 | points = np.concatenate([points, np.empty((1000, 3))]) 54 | line = file.readline().strip() 55 | return points[0:points_num, :] 56 | 57 | 58 | def load_bnd(path): 59 | """ 60 | Restituisce un array Nx3 con le coordinate 3D dei landmarks contenuti in un file .bnd. 61 | 62 | Attenzione, il file deve avere questa precisa struttura: 63 | nnnn x.xxx y.yyyyyy z.zzzzz 64 | nnnn x.xxx y.yyyyyy z.zzzzz 65 | nnnn x.xxx y.yyyyyy z.zzzzz 66 | ******* 67 | ******* 68 | ******* 69 | nnnn x.xxx y.yyyyyy z.zzzzz 70 | nnnn x.xxx y.yyyyyy z.zzzzz 71 | 72 | Ai fini del caricamento il resto del contenuto non è importante, ma essendo questo 73 | un parser a riga robusto come un bicchiere di cristallo è consigliabile assicurarsi 74 | che il ile rispetti la struttura descritta prima di dare colpa all'implementazione 75 | da veri n00b. 
76 | 77 | :param path: pathname del file (" pippo.wrl") 78 | :return: un array numpy Nx3, dove N è il numero dei punti contenuti nel file 79 | """ 80 | 81 | file = load_file(path) 82 | if file is None: 83 | return None 84 | 85 | # arrivo alla prima riga dei punti 86 | points_num = 0 87 | points = np.empty((10, 3)) 88 | line = file.readline().strip() 89 | line = line.replace("\t\t", " ") 90 | while line.strip() != "": 91 | nn, x, y, z = line.split(" ") 92 | points[points_num] = [float(x), float(y), float(z)] 93 | points_num += 1 94 | if points_num == points.size/3: 95 | points = np.concatenate([points, np.empty((10, 3))]) 96 | line = file.readline().replace("\t\t", " ").strip() 97 | return points[0:points_num, :] 98 | 99 | 100 | def load_off(file, faces_required=False): 101 | """ 102 | Reads vertices and faces from an off file. 103 | 104 | :param file: path to file to read 105 | :type file: str 106 | :param faces_required: True if the function should return faces also 107 | :type: bool 108 | :return: vertices and faces as lists of tuples 109 | :rtype: [(float)], [(int)] 110 | """ 111 | 112 | assert os.path.exists(file) 113 | 114 | with open(file, 'r') as fp: 115 | lines = fp.readlines() 116 | lines = [line.strip() for line in lines] 117 | 118 | assert (lines[0] == 'OFF'), "Invalid preambole" 119 | 120 | parts = lines[1].split(' ') 121 | assert (len(parts) == 3), "Need exactly 3 parameters on 2nd line (n_vertices, n_faces, n_edges)." 
122 | 123 | num_vertices = int(parts[0]) 124 | assert num_vertices > 0 125 | 126 | num_faces = int(parts[1]) 127 | if faces_required: 128 | assert num_faces > 0 129 | 130 | vertices = [] 131 | for i in range(num_vertices): 132 | vertex = lines[2 + i].split(' ') 133 | vertex = [float(point) for point in vertex] 134 | assert (len(vertex) == 3), str("Invalid vertex row on line " + str(i)) 135 | 136 | vertices.append(vertex) 137 | 138 | if num_vertices > len(vertices): 139 | row = "WARNING: some vertices were not loaded correctly: {0} declared vs {1} loaded." 140 | Logger.addRow(row.format(num_vertices, len(vertices))) 141 | 142 | vertices = np.asarray(vertices) 143 | 144 | if faces_required: 145 | faces = [] 146 | for i in range(num_faces): 147 | face = lines[2 + num_vertices + i].split(' ') 148 | face = [int(index) for index in face] 149 | 150 | assert face[0] == len(face) - 1 151 | for index in face: 152 | assert 0 <= index < num_vertices 153 | 154 | assert len(face) > 1 155 | 156 | faces.append(face) 157 | return vertices, faces 158 | 159 | return vertices 160 | 161 | 162 | def load_file(path): 163 | try: 164 | file = open(path, "r") 165 | except FileNotFoundError as ex: 166 | print("Can't find the file") 167 | return None 168 | except FileExistsError as ex: 169 | print("File exists but got troubles") 170 | return None 171 | except Exception as ex: 172 | print("Something gone wrong") 173 | return None 174 | 175 | return file 176 | 177 | 178 | def save_file(filepath, model): 179 | filename, file_extension = os.path.splitext(filepath) 180 | 181 | if file_extension == '.off': 182 | save_off(filepath, model) 183 | if file_extension == '.wrl': 184 | save_wrl(filepath, model) 185 | if file_extension == '.mat': 186 | save_mat(filepath, model) 187 | 188 | 189 | def save_off(filepath, model): 190 | with open(filepath, "w") as file: 191 | file.write("OFF\n") 192 | n_points = int(model["model_data"].size / 3) 193 | file.write(str(n_points) + " 0 0\n") 194 | row = "{0} {1} 
{2}\n" 195 | for index in range(n_points): 196 | x, y, z = model["model_data"][index] 197 | file.write(row.format(x, y, z)) 198 | file.close() 199 | 200 | 201 | def save_mat(filepath, model): 202 | f = h5py.File(filepath, "w") 203 | for key, value in model.items(): 204 | f.create_dataset(key, data=value) 205 | f.close() 206 | 207 | 208 | def save_wrl(filepath, model): 209 | pass -------------------------------------------------------------------------------- /graphicInterface/main_widget.py: -------------------------------------------------------------------------------- 1 | import os 2 | from graphicInterface.plot_interactive_figure import PlotInteractiveFigure 3 | from graphicInterface.rotatable_figure import RotatableFigure 4 | from graphicInterface.show_displacement import DisplacementMapWindow 5 | from graphicInterface.upper_toolbar import * 6 | from pointRegistration.batchRegistration import BatchRegistrationThread 7 | from pointRegistration.model import Model 8 | from pointRegistration.registration import Registration 9 | 10 | 11 | class MainWidget(QWidget): 12 | def __init__(self, parent): 13 | super().__init__(parent) 14 | Logger.addRow(str("Starting up..")) 15 | self.source_model = Model(os.path.join(".", "data", "avgModel_bh_1779_NE.mat")) 16 | self.target_model = None 17 | self.sx_widget = None 18 | self.dx_widget = None 19 | self.registration_thread = None 20 | self.toolbar = None 21 | self.registrated = None 22 | Logger.addRow(str("Ready.")) 23 | self.initUI() 24 | 25 | def initUI(self): 26 | grid_central = QGridLayout(self) 27 | self.setLayout(grid_central) 28 | self.sx_widget = PlotInteractiveFigure(self, self.source_model, title="Source") 29 | self.dx_widget = RotatableFigure(self, None, title="Target") 30 | grid_central.addWidget(self.sx_widget, 1, 0, 1, 2) 31 | self.sx_widget.draw_data() 32 | grid_central.addWidget(self.dx_widget, 1, 2, 1, 2) 33 | self.dx_widget.draw_data() 34 | self.toolbar = UpperToolbar(self) 35 | 
grid_central.addWidget(self.toolbar, 0, 0, 1, 4) 36 | grid_central.setRowStretch(0, 1) 37 | grid_central.setRowStretch(1, 30) 38 | 39 | def load_target(self): 40 | filters = "OFF Files (*.off);;WRML Files (*.wrml);;MAT Files (*.mat)" 41 | file_name = load_file_dialog(self, filters) 42 | if file_name is None: 43 | return 44 | 45 | self.toolbar.start_registration_button.setEnabled(True) 46 | self.target_model = Model(file_name) 47 | self.dx_widget.load_model(self.target_model) 48 | self.dx_widget.draw_data(clear=True) 49 | Logger.addRow(str("File loaded correctly: " + file_name)) 50 | self.toolbar.save_target_btn.setEnabled(True) 51 | 52 | def load_source(self): 53 | filters = "OFF Files (*.off);;WRML Files (*.wrml);;MAT Files (*.mat)" 54 | file_name = load_file_dialog(self, filters) 55 | if file_name is None: 56 | return 57 | 58 | self.toolbar.start_registration_button.setEnabled(True) 59 | self.source_model = Model(file_name) 60 | self.sx_widget.load_model(self.source_model) 61 | self.sx_widget.draw_data() 62 | Logger.addRow(str("File loaded correctly: " + file_name)) 63 | 64 | def restore(self): 65 | self.restore_highlight() 66 | if self.dx_widget.has_model(): 67 | self.restore_target() 68 | self.toolbar.start_registration_button.setEnabled(True) 69 | 70 | def restore_highlight(self): 71 | self.sx_widget.highlight_data([-1]) 72 | 73 | def restore_target(self): 74 | self.dx_widget.restore_model() 75 | 76 | def landmark_selected(self, colors): 77 | self.dx_widget.set_landmarks_colors(colors) 78 | 79 | def data_selected(self, x_coord, y_coord, width, height): # apply color to target 80 | self.dx_widget.select_area(x_coord, y_coord, width, height) 81 | 82 | def registrate(self, method, percentage): 83 | if not self.sx_widget.there_are_points_highlighted(): # names are everything 84 | QMessageBox.critical(self, 'Error', "No rigid points have been selected.") 85 | raise Exception("No rigid points selected") 86 | 87 | if self.dx_widget.model is None: 88 | 
QMessageBox.critical(self, 'Error', "Please, load a target model.") 89 | raise Exception("Target model is not present.") 90 | 91 | if self.registration_thread is None: 92 | self.toolbar.show_displacement_btn.setEnabled(False) 93 | self.parent().setStatus("Busy...") 94 | self.registration_thread = Registration(method, self.source_model, self.target_model, percentage, 95 | self.registration_completed_callback, 96 | self.dx_widget.update_plot_callback) 97 | self.registration_thread.start() 98 | 99 | def stop_registration_thread(self): 100 | if self.registration_thread is not None: 101 | Logger.addRow(str("Trying to stop registration thread...")) 102 | self.registration_thread.stop() 103 | 104 | def show_displacement_map(self): 105 | DisplacementMapWindow(self.parent(), self.source_model.compute_displacement_map(self.target_model, 3)) # FIXME 106 | self.parent().setStatusReady() 107 | 108 | def save_target(self): 109 | if self.target_model is None: 110 | QMessageBox.critical(self, 'Error', "The source model was not registered yet.") 111 | return 112 | 113 | filters = "MAT File (*.mat);;OFF File (*.off);;" 114 | filename = save_file_dialog(self, filters) 115 | if filename is None: 116 | return 117 | self.target_model.save_model(filename) #fixme controllare che venga realmente salvato 118 | 119 | def registration_completed_callback(self, model): 120 | Logger.addRow(str("Registration completed.")) 121 | self.target_model.bgImage = self.dx_widget.bgImage 122 | self.dx_widget.clear() 123 | 124 | self.dx_widget.set_secondary_model(model) 125 | self.dx_widget.load_model(self.target_model) 126 | self.target_model = model 127 | 128 | self.dx_widget.draw_data() 129 | self.parent().setStatusReady() 130 | self.registration_thread = None 131 | self.toolbar.stop_registration_button.setEnabled(False) 132 | self.toolbar.show_displacement_btn.setEnabled(True) 133 | self.registrated = True 134 | self.parent().setStatus("Registration completed, displacement map available. 
Click Show Displacement Map.") 135 | # Target and Source are now plotted in target widget 136 | 137 | def registrate_batch_callback(self): 138 | try: 139 | Logger.addRow(str("Registration completed.")) 140 | self.registration_thread = None 141 | self.parent().setStatus("Ready.") 142 | except Exception as ex: 143 | print(ex) 144 | 145 | def registrate_batch(self, method, percentage, filenames): 146 | 147 | if not self.sx_widget.there_are_points_highlighted(): # names are everything 148 | QMessageBox.critical(self, 'Error', "No rigid points have been selected.") 149 | raise Exception("No rigid points selected") 150 | 151 | if self.registration_thread is None and self.sx_widget.there_are_points_highlighted(): 152 | self.parent().setStatus("Busy...") 153 | 154 | self.registration_thread = BatchRegistrationThread(self.sx_widget.model, filenames, percentage, 155 | self.registrate_batch_callback) 156 | self.registration_thread.start() 157 | 158 | @staticmethod 159 | def savelog_onfile(): 160 | Logger.save_log() 161 | -------------------------------------------------------------------------------- /pointRegistration/displacementMap.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import pickle # Look Morty, I'm a pickle! 
import h5py
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import cKDTree

from graphicInterface.console import Logger
from pointRegistration.model import Model


def _query_tree(tree, points, max_dist):
    """Query *tree* for the nearest neighbour of each row of *points*, on all CPUs.

    SciPy 1.6 renamed ``n_jobs`` to ``workers`` and SciPy 1.9 removed
    ``n_jobs``; try the new keyword first and fall back for older SciPy.
    """
    n_workers = multiprocessing.cpu_count()
    try:
        return tree.query(points, distance_upper_bound=max_dist, workers=n_workers)
    except TypeError:
        return tree.query(points, distance_upper_bound=max_dist, n_jobs=n_workers)


def _split_hit_missed(source, target, max_dist):
    """Split the rows of *source* into (hit, missed).

    A source row is "hit" when it is the nearest neighbour, within *max_dist*,
    of at least one row of *target*; every other source row is "missed".
    """
    tree = cKDTree(source)
    _, indices = _query_tree(tree, target, max_dist)
    # cKDTree reports "no neighbour within distance_upper_bound" as index == len(source).
    hit_idx = np.unique(indices[indices < source.shape[0]])
    return source[hit_idx], np.delete(source, hit_idx, axis=0)


class DisplacementMap(Model):
    """Model made of the source points matched / not matched by a target, with their plot colours."""

    def __init__(self, points, landmarks, points_color, landmarks_color, missed_points, missed_points_color,
                 missed_landmarks, missed_landmarks_color):
        """
        :param points: matched ("hit") source points
        :param landmarks: matched source landmarks, or None
        :param missed_points: source points without a target neighbour
        :param missed_landmarks: source landmarks without a target neighbour, or None
        The ``*_color`` arguments are colour strings stored for plotting.
        """
        self.init_attributes()
        self.set_points(points)
        self.set_landmarks(landmarks)
        self.points_color = points_color
        self.landmarks_color = landmarks_color
        self.missed_points = missed_points
        self.missed_points_color = missed_points_color
        self.missed_landmarks = missed_landmarks
        self.missed_landmarks_color = missed_landmarks_color

        # Extent of all (hit + missed) points along x and y.
        self.rangeX = np.ptp(np.append(self.points[:, 0], self.missed_points[:, 0]))
        self.rangeY = np.ptp(np.append(self.points[:, 1], self.missed_points[:, 1]))

    @classmethod
    def compute_map(cls, source_model, target_model, max_dist=2):
        """Build the displacement map of *source_model* with respect to *target_model*.

        :param source_model
        :type source_model Model

        :param target_model
        :type target_model Model

        :param max_dist: maximum distance for a source point to count as matched
        :type max_dist number
        """
        hit_points, missed_points = _split_hit_missed(source_model.points, target_model.points, max_dist)

        hit_landmarks = None
        missed_landmarks = None
        landmarks_color = None
        missed_landmarks_color = None
        if source_model.landmarks is not None and target_model.landmarks is not None:
            hit_landmarks, missed_landmarks = _split_hit_missed(source_model.landmarks, target_model.landmarks,
                                                                max_dist)
            landmarks_color = "r"
            # "p" is not a matplotlib colour code; "m" (magenta) is the closest valid one.
            missed_landmarks_color = "m"

        Logger.addRow("Displacement map created.")
        return cls(points=hit_points, landmarks=hit_landmarks, points_color="b",
                   landmarks_color=landmarks_color, missed_points=missed_points, missed_landmarks=missed_landmarks,
                   missed_points_color="y", missed_landmarks_color=missed_landmarks_color)

    def save_model(self, filename):
        """Save the map; the format is chosen by extension (``.h5py`` or ``.pickle``)."""
        _, file_extension = os.path.splitext(filename)

        if file_extension == ".h5py":
            self.save_h5py(filename)
        elif file_extension == ".pickle":
            self.save_pickle(filename)
        else:
            Logger.addRow(f"ERROR: unsupported extension '{file_extension}', displacement map was not saved.")

    def save_pickle(self, filename):
        """Pickle the missed and the hit points (landmarks appended) to ``<name>_missed_points<ext>``
        and ``<name>_hit_points<ext>``."""
        name, file_extension = os.path.splitext(filename)
        missed_map = self.missed_points
        hit_map = self.points
        if self.landmarks is not None:
            # np.concatenate takes a sequence of arrays; a second positional argument is the axis.
            missed_map = np.concatenate((missed_map, self.missed_landmarks))
            hit_map = np.concatenate((hit_map, self.landmarks))
        for suffix, data in (("_missed_points", missed_map), ("_hit_points", hit_map)):
            # pickle.dump needs an open binary file, not a path.
            with open(str(name + suffix + file_extension), "wb") as handle:
                pickle.dump(data, handle)

    def save_h5py(self, filename):
        """Write points, missed points, optional landmarks and colours to an HDF5 file.

        Errors are logged, not raised.
        """
        def encode_ss_utf_8(ss):
            return np.void(ss.encode('utf-8'))

        try:
            with h5py.File(filename, "w") as file:
                file.create_dataset("points", data=self.points)
                file.create_dataset("missed_points", data=self.missed_points)
                file.attrs["points_color"] = encode_ss_utf_8(self.points_color)
                file.attrs["missed_points_color"] = encode_ss_utf_8(self.missed_points_color)
                if self.landmarks is not None:
                    file.create_dataset("landmarks", data=self.landmarks)
                    file.create_dataset("missed_landmarks", data=self.missed_landmarks)
                    file.attrs["landmarks_color"] = encode_ss_utf_8(self.landmarks_color)
                    file.attrs["missed_landmarks_color"] = encode_ss_utf_8(self.missed_landmarks_color)
        except Exception as ex:
            Logger.addRow("ERROR: Displacement map was not saved. =>" + str(ex))
        else:
            # Only report success when nothing failed.
            Logger.addRow(f'File {filename} saved correctly.')

    @classmethod
    def load_model(cls, filename):
        """Load a map written by save_h5py.

        Errors are logged and then re-raised; landmarks are optional in the file.
        """
        def decode(attribute):
            return attribute.tobytes().decode('utf-8')

        try:
            with h5py.File(filename, 'r') as file:
                # ``dataset[()]`` reads the whole dataset (``Dataset.value`` was removed in h5py 3).
                points = file["points"][()]
                missed_points = file["missed_points"][()]
                points_color = decode(file.attrs["points_color"])
                missed_points_color = decode(file.attrs["missed_points_color"])

                if "landmarks" in file:
                    landmarks = file["landmarks"][()]
                    missed_landmarks = file["missed_landmarks"][()]
                    landmarks_color = decode(file.attrs["landmarks_color"])
                    missed_landmarks_color = decode(file.attrs["missed_landmarks_color"])
                else:
                    Logger.addRow("INFO: Displacement model has no landmarks.")
                    landmarks = None
                    landmarks_color = None
                    missed_landmarks = None
                    missed_landmarks_color = None
        except FileNotFoundError:
            Logger.addRow(f"ERROR: {filename} not found, check path.")
            raise
        except Exception as ex:
            Logger.addRow(f"ERROR: Problem during opening of {filename} =>" + str(ex))
            raise

        displacement_model = cls(points=points, landmarks=landmarks,
                                 missed_points=missed_points,
                                 missed_landmarks=missed_landmarks,
                                 points_color=points_color,
                                 landmarks_color=landmarks_color,
                                 missed_points_color=missed_points_color,
                                 missed_landmarks_color=missed_landmarks_color)

        Logger.addRow(f'File {filename} load correctly.')
        return displacement_model

    def shoot_displacement_map(self, filepath):
        """Save an x/y scatter plot of the matched points next to *filepath*, with a .png extension."""
        # NOTE(review): the class defines no ``displacement_map`` attribute; the matched
        # points are plotted instead. Confirm this is the intended content.
        plt.scatter(self.points[:, 0], self.points[:, 1], s=0.5)
        plt.savefig(os.path.splitext(filepath)[0] + ".png")
        plt.close()

    def rotate(self, axis, theta):
        """Rotate hit points/landmarks (via Model) and the missed ones by *theta* around *axis*."""
        super().rotate(axis, theta)
        try:
            self.missed_points = Model.rotate_model(axis, theta, self.missed_points)
            if self.missed_landmarks is not None:
                self.missed_landmarks = Model.rotate_model(axis, theta, self.missed_landmarks)
        except Exception as ex:
            # Best effort: keep the GUI alive, but record the failure.
            Logger.addRow("ERROR: " + str(ex))


# --------------------------------------------------------------------------------
# /morpmodel/_3DMM.py:
# --------------------------------------------------------------------------------
import numpy as np
from numpy.linalg import pinv
from numpy.linalg import inv
from scipy.spatial import ConvexHull


class _3DMM:
    """Fitting routines for a 3D morphable model: pose, shape coefficients, visibility."""

    # Result of the most recent opt_3DMM_fast call; each call stores a fresh dict on the instance.
    result = {}

    def opt_3DMM_fast(self, Weights, Components, Components_res, landmarks3D, id_landmarks_3D, landImage, avgFace,
                      _lambda, rounds, r, C_dist, ex_to_me):
        """Alternate pose estimation and shape deformation ``rounds + 1`` times, starting from *avgFace*.

        Returns a dict with keys A, S, R, T (final pose), defShape, alpha and visIdx
        (0-based indices of the vertices visible from the final pose).
        """
        index = np.array(id_landmarks_3D, dtype=np.intp)
        n_land = landmarks3D.shape[0]

        def landmarks_of(shape):
            return np.reshape(shape[index, :], (n_land, 3), order='F')

        shape = avgFace
        alpha = None
        # The first pass starts from the average face; the next `rounds` passes refine the deformed shape.
        for _ in range(max(rounds, 0) + 1):
            land = landmarks_of(shape)
            _, S, R, T = self.estimate_pose(land, landImage)
            proj = np.transpose(self.getProjectedVertex(land, S, R, T))
            alpha = self.alphaEstimation(landImage, proj, Components_res, id_landmarks_3D, S, R, T, Weights, _lambda)
            shape = np.transpose(self.deform_3D_shape_fast(np.transpose(shape), Components, alpha, ex_to_me))

        A, S, R, T = self.estimate_pose(landmarks_of(shape), landImage)
        vis_idx = self.estimateVis_vertex(shape, R, C_dist, r)

        # A fresh dict per call: a shared class-level dict was mutated under previous callers' feet.
        self.result = {
            "A": A,
            "S": S,
            "R": R,
            "T": T,
            "defShape": shape,
            "alpha": alpha,
            "visIdx": vis_idx,
        }
        return self.result

    def estimate_pose(self, landModel, landImage):
        """Estimate the affine map ``landImage ~ A @ landModel + t`` from centred landmarks.

        A is factored through QR as ``A = S @ R``. Returns ``[A, S, R, t]``.
        """
        baricM = np.mean(landModel, axis=0)
        P = np.transpose(landModel - baricM)
        baricI = np.mean(landImage, axis=0)
        p = np.transpose(landImage - baricI)

        A = p.dot(pinv(P))
        Q, U = np.linalg.qr(np.transpose(A), mode='complete')
        S = np.transpose(U)
        R = np.transpose(Q)
        t = np.transpose(baricI) - A.dot(np.transpose(baricM))
        return [A, S, R, t]

    def getProjectedVertex(self, vertex, S, R, t):
        """Return ``S @ R @ vertex + t`` with vertices as columns.

        *t* may be 2-D or 3-D.
        """
        v = self.transVertex(vertex)
        return S.dot(R.dot(v)) + np.reshape(t, (-1, 1), order='F')

    def transVertex(self, vertex):
        """Return *vertex* with 3 rows, transposing it when the first dimension is not 3."""
        if vertex.shape[0] != 3:
            return np.transpose(vertex)
        return vertex

    @staticmethod
    def _regularised_solve(Y, X, Weights, _lambda):
        """Return ``inv(Y'Y + diag(_lambda / Weights)) @ Y' @ X``.

        This is a ridge solve with per-component weights.
        """
        with np.errstate(divide='ignore'):
            reg = np.diag(_lambda / np.ravel(Weights))
        return inv(np.transpose(Y).dot(Y) + reg).dot(np.transpose(Y)).dot(X)

    def alphaEstimation(self, landImage, projLandModel, Components_res, id_landmarks_3D, S, R, t, Weights, _lambda):
        """Estimate the deformation coefficients that move the projected landmarks onto *landImage*."""
        X = (landImage - projLandModel).flatten(order='F')
        index = np.array(id_landmarks_3D, dtype=np.intp)
        n_comp = Components_res.shape[2]
        Y = np.zeros((X.shape[0], n_comp))
        for c in range(n_comp):
            comp = Components_res[index, :, c].reshape(landImage.shape[0], 3, order='F')
            Y[:, c] = np.transpose(self.getProjectedVertex(comp, S, R, t)).flatten(order='F')
        if _lambda == 0:
            # MATLAB's Y\X, i.e. least squares (ndarray has no ``divide`` method).
            return np.linalg.lstsq(Y, X, rcond=None)[0]
        return self._regularised_solve(Y, X, Weights, _lambda)

    def alpha_estimation_test(self, avg_colors, target_colors, Components, Weights, _lambda):
        """Regularised estimate of the coefficients mapping *avg_colors* onto *target_colors*."""
        avg_colors = np.reshape(avg_colors, (avg_colors.shape[0], 1))
        target_colors = np.reshape(target_colors, (target_colors.shape[0], 1))
        return self._regularised_solve(Components, target_colors - avg_colors, Weights, _lambda)

    def deform_3D_shape_fast(self, mean_face, eigenvecs, alpha, ex_to_me):
        """Return ``mean_face + reshape(eigenvecs @ alpha, (3, n), order='F')``.

        *mean_face* has shape (3, n). *ex_to_me* is unused.
        """
        dim = eigenvecs.shape[0] // 3
        offsets = eigenvecs.dot(np.ravel(alpha))
        return mean_face + np.reshape(offsets, (3, dim), order='F')

    def estimateVis_vertex(self, vertex, R, C_dist, r):
        """Return the 0-based indices (column vector) of the vertices visible from the rotated viewpoint."""
        viewPoint_front = np.array([0, 0, C_dist]).reshape(1, 3, order='F')
        viewPoint = np.transpose(self.rotatePointCloud(viewPoint_front, R, []))
        return self.HPR(vertex, viewPoint, r)

    def rotatePointCloud(self, P, R, t):
        """Return ``R @ P.T``; the translation *t* is currently ignored."""
        return np.dot(R, np.transpose(P))

    def HPR(self, p, C, param):
        """Hidden Point Removal (Katz et al.).

        Returns the 0-based indices of the points in *p* visible from viewpoint *C*,
        as a column vector. *param* sets the flipping-sphere radius as
        ``max_norm * 10**param``.
        """
        num_pts, dim = p.shape
        p = p - np.reshape(C, (1, dim))  # move the viewpoint to the origin
        norm_p = np.linalg.norm(p, axis=1).reshape(-1, 1)
        radius = np.amax(norm_p) * (10.0 ** param)
        flipped = p + 2 * (radius - norm_p) * p / norm_p  # spherical flipping
        hull = ConvexHull(np.vstack([flipped, np.zeros((1, dim))]))
        visible = np.unique(hull.vertices)
        # hull.vertices is already 0-based; the origin added above has index num_pts.
        visible = visible[visible != num_pts]
        return visible.reshape(-1, 1)

    def getVisLand(self, vertex, landImage, visIdx, id_landmarks_3D):
        """Return ``[visLand, id_Vland, Nvis]``.

        visLand holds the visible landmark vertex ids, id_Vland their positions in
        *id_landmarks_3D*, and Nvis the positions of the landmarks that are not visible.
        """
        visLand = np.intersect1d(visIdx, id_landmarks_3D)
        id_Vland = np.nonzero(np.isin(np.ravel(id_landmarks_3D), visIdx))[0]
        # All landmark positions 0..n-1 (np.empty previously left element 0 uninitialised).
        Nvis = np.setxor1d(np.arange(landImage.shape[0]), id_Vland)
        return [visLand, id_Vland, Nvis]


# --------------------------------------------------------------------------------
# /morpmodel/util_for_graphic.py:
# --------------------------------------------------------------------------------
import numpy as np
from scipy.interpolate import LinearNDInterpolator
from numpy import dot, mean, std, argsort
from numpy.linalg import eigh
from scipy.spatial import ConvexHull
from matplotlib import path


class graphic_tools:
    """Rendering and texture helpers built on top of a _3DMM instance."""

    _3DMM_obj = ()

    def __init__(self, _3dmm_obj):
        # A _3DMM instance (or None); denseResampling needs it for projection.
        self._3DMM_obj = _3dmm_obj

    def render3DMM(self, xIm, yIm, rgb, w, h):
        """Render a w x h uint8 image by interpolating *rgb* at projected points (xIm, yIm).

        Pixels outside the convex hull of the rounded points are set to 0.
        """
        x_grid, y_grid = np.meshgrid(np.arange(w), np.arange(h))
        coords = self._2linear_vects(xIm, yIm)
        channels = [LinearNDInterpolator(coords, rgb[:, c])(x_grid, y_grid) for c in range(3)]

        pts = np.unique(np.round(np.column_stack([xIm, yIm])), axis=0)
        hull = ConvexHull(pts)
        bb = np.array(hull.vertices, dtype=np.intp)
        outside = np.logical_not(self.inpolygon(x_grid, y_grid, pts[bb, 0], pts[bb, 1]))

        img = np.dstack(channels)
        img[outside] = 0
        return np.array(img, dtype=np.uint8)

    def inpolygon(self, xq, yq, xv, yv):
        """Like MATLAB inpolygon: a boolean mask, shaped like *xq*, of query points inside the polygon (xv, yv)."""
        shape = xq.shape
        queries = np.column_stack((xq.reshape(-1), yq.reshape(-1)))
        polygon = path.Path(np.column_stack((xv.reshape(-1), yv.reshape(-1))))
        return polygon.contains_points(queries).reshape(shape)

    def denseResampling(self, defShape, proj_shape, img, S, R, t, visIdx):
        """Resample the visible part of *defShape* on a 1-pixel grid of the projection.

        Returns ``[mod3d, colors]``: the resampled 3D points and their colours sampled from *img*.
        """
        grid_step = 1
        max_sh = np.amax(proj_shape, axis=0)
        min_sh = np.amin(proj_shape, axis=0)
        Xsampling = np.around(np.arange(min_sh[0], max_sh[0], grid_step, dtype='float'))
        Ysampling = np.around(np.arange(max_sh[1], min_sh[1], -grid_step, dtype='float'))

        print("SAMPLING THE GRID")
        x_grid, y_grid = np.meshgrid(Xsampling, Ysampling)
        print("3D LOCATION INTERPOLATION")
        index = np.array(visIdx, dtype=np.intp)
        coords = self._2linear_vects(proj_shape[index, 0], proj_shape[index, 1])
        interpolators = [LinearNDInterpolator(coords, defShape[index, c]) for c in range(3)]
        print("PIXEL SAMPLING")
        xq = x_grid.flatten(order='F')
        yq = y_grid.flatten(order='F')
        samples = [f(xq, yq) for f in interpolators]
        print("IMAGE BUILDING")
        mod3d = np.column_stack(samples)
        # Drop grid points outside the triangulation; one shared mask keeps x, y and z aligned.
        mod3d = mod3d[~np.isnan(mod3d).any(axis=1)]
        print("FINAL RENDERING")
        projMod3d = np.transpose(self._3DMM_obj.getProjectedVertex(mod3d, S, R, t))
        colors = self.getRGBtexture(np.round(projMod3d), img)
        colors = self._3DMM_obj.transVertex(colors)
        print("DONE DENSE RESAMPLING")
        return [mod3d, colors]

    def resampleRGB(self, P, colors, step):
        """Interpolate per-point *colors* (3 x n) at points *P* (rows 0/1 = x/y) on a regular grid.

        Returns ``(img, Xsampling, Ysampling)``; pixels outside the data hull are 0.
        """
        Xsampling = np.arange(np.amin(P[0, :]), np.amax(P[0, :]), step, dtype='float')
        Ysampling = np.arange(np.amax(P[1, :]), np.amin(P[1, :]), -step, dtype='float')
        x, y = np.meshgrid(Xsampling, Ysampling)
        coords = self._2linear_vects(P[0, :], P[1, :])
        channels = [LinearNDInterpolator(coords, colors[c, :])(x, y) for c in range(3)]
        img = np.dstack(channels)
        # Zero every pixel where the red channel fell outside the convex hull (NaN).
        img[np.isnan(channels[0])] = 0
        return img, Xsampling, Ysampling

    def _2linear_vects(self, X, Y):
        """Build the (n, 2) float coordinate array expected by LinearNDInterpolator.

        X and Y may be 1-D or column vectors; for 2-D input only the first column is used.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        x_col = X[:, 0] if X.ndim > 1 else X
        y_col = Y[:, 0] if Y.ndim > 1 else Y
        return np.column_stack((x_col, y_col)).astype(float)

    def getRGBtexture(self, coordTex, tex):
        """Sample *tex* (H x W x 3) at 1-based pixel coordinates (x, y) given in the first two columns.

        Coordinates are clamped into the image. Returns float colours (n x 3) scaled to [0, 1].
        """
        n_rows, n_cols = tex.shape[0], tex.shape[1]
        Xtex = np.clip(np.round(coordTex[:, 0]), 1, n_cols)
        Ytex = np.clip(np.round(coordTex[:, 1]), 1, n_rows)
        # 0-based column-major linear index. The MATLAB-style 1-based index was off by one
        # and overflowed the array at the bottom-right pixel.
        I = ((Ytex - 1) + (Xtex - 1) * n_rows).astype(int)
        channels = [tex[:, :, c].flatten(order='F')[I] for c in range(3)]
        return np.column_stack(channels).astype(float) / 255.0

    def renderFaceLossLess(self, defShape, projShape, img, S, R, T, renderingStep, visIdx):
        """Dense-resample the face, then render its frontal view.

        Returns ``(frontal_view, colors, mod3d)``.
        """
        [mod3d, colors] = self.denseResampling(defShape, projShape, img, S, R, T, visIdx)
        frontal_view = self.build_image(np.transpose(mod3d), np.transpose(defShape), colors, renderingStep)
        return frontal_view, colors, mod3d

    def build_image(self, P_f, P, colors, step):
        """Render a uint8 image of *colors* (3 x n, in [0, 1]) at points *P_f*.

        The sampling grid spans the x/y extent of *P*.
        """
        print("START BUILD FRONTAL VIEW")
        Xsampling = np.arange(np.amin(P[0, :]), np.amax(P[0, :]), step, dtype='float')
        Ysampling = np.arange(np.amax(P[1, :]), np.amin(P[1, :]), -step, dtype='float')
        x, y = np.meshgrid(Xsampling, Ysampling)
        coords = self._2linear_vects(P_f[0, :], P_f[1, :])
        # Image channel c comes from colour row c (R, G, B); G and B used to be swapped here.
        channels = [np.nan_to_num(LinearNDInterpolator(coords, colors[c, :])(x, y)) for c in range(3)]
        img = np.dstack(channels) * 255
        print("DONE")
        return np.array(img, dtype=np.uint8)

    def resize_imgs(self, img, size_row, size_col):
        """Bring *img* to size_row x size_col.

        Extra rows/columns are deleted; missing ones are zero-padded at the bottom/right.
        """
        rows = img.shape[0]
        cols = img.shape[1]
        if rows > size_row:
            # NOTE(review): deleting index i from a shrinking array removes rows 0, 2, 4, ...
            # rather than cropping a contiguous band; kept as-is, confirm intent.
            for i in range(rows - size_row):
                img = np.delete(img, i, axis=0)
        elif size_row > rows:
            img = np.vstack([img, np.zeros((size_row - rows, cols, 3))])
        if cols > size_col:
            for i in range(cols - size_col):
                img = np.delete(img, i, axis=1)
        elif size_col > cols:
            img = np.hstack([img, np.zeros((img.shape[0], size_col - cols, 3))])
        return img

    def avgModel(self, object):
        """Return the uint8 per-pixel mean of the ``frontalView`` images of the items in *object*."""
        rows = object[0].frontalView.shape[0]
        cols = object[0].frontalView.shape[1]
        total = np.zeros((rows, cols, 3))
        # Accumulate every item exactly once, in float: summing uint8 frames wraps at 255.
        for i in range(len(object)):
            total += object[i].frontalView[:, :, :3]
        return np.array(total / len(object), dtype=np.uint8)

    def colors_AVG(self, object):
        """Return the mean of the ``colors`` arrays of the items in *object*."""
        cumulative_matrix_col = object[0].colors
        # Start from item 1: item 0 is already in the accumulator.
        for i in range(1, len(object)):
            cumulative_matrix_col = cumulative_matrix_col + object[i].colors
        return cumulative_matrix_col / len(object)

    def deform_texture_fast(self, mean, eigenves, alpha, ex_to_ne):
        """Return ``mean + reshape(eigenves @ alpha, (3, n), order='F')``.

        *ex_to_ne* is unused.
        """
        dim = eigenves.shape[0] // 3  # integer division: reshape rejects float sizes
        offsets = eigenves.dot(np.ravel(alpha))
        return mean + np.reshape(offsets, (3, dim), order='F')

    def cov(self, X):
        """
        Covariance matrix
        note: specifically for mean-centered data
        note: numpy's `cov` uses N-1 as normalization
        """
        return dot(X.T, X) / X.shape[0]

    def pca(self, data, pc_count=None):
        """
        Principal component analysis using eigenvalues
        note: this mean-centers and auto-scales the data (in-place)
        """
        data -= mean(data, 0)
        data /= std(data, 0)
        C = self.cov(data)
        E, V = eigh(C)
        key = argsort(E)[::-1][:pc_count]
        E, V = E[key], V[:, key]
        U = dot(data, V)  # used to be dot(V.T, data.T).T
        return U, E, V

    def PIL2array(self, img):
        """Convert an RGB PIL image to an (H, W, 3) uint8 array."""
        return np.array(img.getdata(), np.uint8).reshape(img.size[1], img.size[0], 3)

    def mse(self, imageA, imageB):
        """Mean squared error between two same-sized images (lower means more similar)."""
        err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
        err /= float(imageA.shape[0] * imageA.shape[1])
        return err

    def compare_imgs(self, i1, i2):
        """Mean absolute per-component difference of two PIL images, as a percentage of 255."""
        pairs = zip(i1.getdata(), i2.getdata())
        n_bands = len(i1.getbands())
        if n_bands == 1:
            # for gray-scale jpegs
            dif = sum(abs(p1 - p2) for p1, p2 in pairs)
        else:
            dif = sum(abs(c1 - c2) for p1, p2 in pairs for c1, c2 in zip(p1, p2))
        # Normalise by the real number of components (was always 3, wrong for L/RGBA images).
        ncomponents = i1.size[0] * i1.size[1] * n_bands
        return (dif / 255.0 * 100) / ncomponents

    def mean(self, obj):
        """Return the mean ``ssim_D`` over the items of *obj*."""
        return sum(item.ssim_D for item in obj) / len(obj)

    def plot_mesh(self, vertex, face, texture, texture_coords):
        """Map the texture of the first mesh triangle onto a (65, 65, 3) patch.

        Uses barycentric interpolation. Only the first triangle is processed;
        *vertex* is currently unused.
        """
        TF = face
        VT2 = texture_coords
        I = texture
        sizep = 64
        lambda1, lambda2, lambda3, jind = self.calculateBarycentricInterpolationValues(sizep)
        jind = np.array(jind, dtype=np.intp)
        # One flattened (column-major) texture per channel; each channel now reads its own plane.
        texture_channels = [I[:, :, c].flatten(order='F') for c in range(3)]

        J = np.zeros((sizep + 1, sizep + 1, 3))
        for i in range(1):
            Vt = VT2[TF[i, :], :]
            # Texture coordinates of the triangle; the third vertex is repeated as in a MATLAB surface.
            xy = np.array([[Vt[0, 0], Vt[0, 1]],
                           [Vt[1, 0], Vt[1, 1]],
                           [Vt[2, 0], Vt[2, 1]],
                           [Vt[2, 0], Vt[2, 1]]])
            pos = np.zeros([lambda1.shape[0], 2])
            pos[:, 0] = (xy[0, 0] * lambda1 + xy[1, 0] * lambda2 + xy[2, 0] * lambda3).reshape((pos.shape[0]))
            pos[:, 1] = (xy[0, 1] * lambda1 + xy[1, 1] * lambda2 + xy[2, 1] * lambda3).reshape((pos.shape[0]))
            pos = np.round(pos)
            pos[:, 0] = np.minimum(pos[:, 0], I.shape[0])
            pos[:, 1] = np.minimum(pos[:, 1], I.shape[1])
            posind = np.array((pos[:, 0] - 1) + (pos[:, 1] - 1) * I.shape[0] + 1, dtype=np.intp)

            J = np.zeros((sizep + 1, sizep + 1, 3))
            for c, channel in enumerate(texture_channels):
                patch = np.zeros((sizep + 1) ** 2)
                patch[jind - 1] = channel[posind - 1]
                # NOTE(review): jind is a column-major index, but this reshape is row-major;
                # the patch may come out transposed. Confirm against the MATLAB original.
                J[:, :, c] = patch.reshape((J.shape[0], J.shape[0]))
        return J

    def calculateBarycentricInterpolationValues(self, sizep):
        """Barycentric weights of every pixel of a (sizep+1)^2 grid w.r.t. the triangle
        (sizep, sizep), (sizep, 0), (0, 0), plus the matching 1-based linear target indices."""
        x1, y1 = sizep, sizep
        x2, y2 = sizep, 0
        x3, y3 = 0, 0
        detT = (x1 - x3) * (y2 - y3) - (x2 - x3) * (y1 - y3)
        x, y = np.meshgrid(np.arange(0, sizep + 1), np.arange(0, sizep + 1))
        x = x.flatten().reshape((-1, 1))
        y = y.flatten().reshape((-1, 1))
        lambda1 = ((y2 - y3) * (x - x3) + (x3 - x2) * (y - y3)) / detT
        lambda2 = ((y3 - y1) * (x - x3) + (x1 - x3) * (y - y3)) / detT
        lambda3 = 1 - lambda1 - lambda2

        jx, jy = np.meshgrid(sizep - np.arange(0, sizep + 1) + 1, sizep - np.arange(0, sizep + 1) + 1)
        jind = (jx.flatten() - 1) + (jy.flatten() - 1) * (sizep + 1) + 1
        return lambda1.reshape((-1, 1)), lambda2.reshape((-1, 1)), lambda3.reshape((-1, 1)), jind

# --------------------------------------------------------------------------------