├── .gitignore ├── LICENSE ├── README.md ├── geomfmaps ├── config.yaml ├── eval.py ├── model.py ├── shape_matching_dataset.py ├── train.py └── utils.py └── requirement.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Souhaib Attaiki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![report](https://img.shields.io/badge/arxiv-report-green)](https://arxiv.org/pdf/2003.14286.pdf) 2 | 3 | :warning: :rotating_light: this code base is no longer maintained :confused: 4 | 5 | # GeomFmaps-pytorch 6 | A minimalist pytorch implementation of: "Deep Geometric Functional Maps: Robust Feature Learning for Shape Correspondence" [[1]](#bookmark-references), which appeared in [CVPR 2020](http://cvpr2020.thecvf.com/). 7 | 8 | ## Installation 9 | This implementation runs on python >= 3.7, use pip to install dependencies: 10 | ``` 11 | pip3 install -r requirement.txt 12 | ``` 13 | 14 | ## Download data & preprocessing 15 | The preprocessing code will be added later. 16 | For the moment, we refer the reader to the [original implementation](https://github.com/LIX-shape-analysis/GeomFmaps) of GeomFmaps to download the data and the preprocessing code. 17 | 18 | It should be noted that for each dataset (faust, scape, etc), this module expects the dataset folder to contain 3 folders: 19 | 20 | * `off` folder: this folder contains the meshes 21 | * `spectral` folder: this folder contains the Laplace-Beltrami related data. It is composed of files having the same names as the files in the `off` folder. Each file is a `.mat` containing a `dict` with three keys: `evals`, `evecs` and `evecs_trans`. These files are created by the preprocessing code. 22 | * `corres` folder: this folder contains the ".vts" files necessary for the calculation of the ground truth maps. 
23 | 24 | ## Usage 25 | Use the `config.yaml` file to specify the hyperparameters as well as the dataset to be used. 26 | 27 | Use the `train.py` script to train the GeomFmaps model. 28 | ``` 29 | python3 train.py 30 | ``` 31 | 32 | References 33 | --------------------- 34 | [1] [Deep Geometric Functional Maps: Robust Feature Learning for Shape Correspondence](https://arxiv.org/pdf/2003.14286.pdf) 35 | -------------------------------------------------------------------------------- /geomfmaps/config.yaml: -------------------------------------------------------------------------------- 1 | dataroot: /path/to/dataset/folder/ 2 | 3 | neig: 30 4 | n_train: 80 # number of shapes used in training 5 | max_train: 100 # max number of shapes used by the model (train - test) 6 | 7 | 8 | pre_transforms: 9 | - transform: GridSampling3D 10 | params: 11 | size: 0.02 12 | train_transforms: 13 | - transform: Random3AxisRotation 14 | params: 15 | apply_rotation: True 16 | rot_x: 0 17 | rot_y: 360 18 | rot_z: 0 19 | - transform: RandomNoise 20 | params: 21 | sigma: 0.01 22 | clip: 0.05 23 | - transform: RandomScaleAnisotropic 24 | params: 25 | scales: [0.9,1.1] 26 | - transform: AddOnes 27 | - transform: AddFeatsByKeys 28 | params: 29 | list_add_to_x: [True] 30 | feat_names: ["ones"] 31 | delete_feats: [True] 32 | test_transforms: 33 | - transform: AddOnes 34 | - transform: AddFeatsByKeys 35 | params: 36 | list_add_to_x: [True] 37 | feat_names: ["ones"] 38 | delete_feats: [True] 39 | 40 | # model params 41 | lambda_: 1e-3 42 | in_grid_size: 0.02 43 | n_feat: 128 44 | 45 | # general 46 | no_cuda: False 47 | batch_size: 8 48 | n_cpu: 8 49 | n_epochs: 20 50 | lr: 1e-3 51 | checkpoint_interval: 5 52 | log_interval: 20 53 | savedir: path/to/savedir/ 54 | evaldir: path/to/evaldir/ 55 | -------------------------------------------------------------------------------- /geomfmaps/eval.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | import 
def eval_model(model_path, params):
    """Run a trained GeomFmapNet on the test split and save the estimated functional maps.

    Arguments:
        model_path {string} -- path to a ``state_dict`` checkpoint written by ``torch.save``.
        params {omegaconf.DictConfig} -- configuration. This function reads ``no_cuda``,
            ``evaldir``, ``n_feat``, ``in_grid_size``, ``lambda_`` and ``n_cpu``; the whole
            object is also handed to ``ShapeMatchingDatasetWrapper``.

    Side effects:
        Creates ``params.evaldir`` if needed and writes ``fmaps.pt`` inside it: a list of
        dicts with keys ``C_est`` (CPU tensor, neig x neig), ``source`` and ``target``.
    """
    if torch.cuda.is_available() and not params.no_cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    os.makedirs(params.evaldir, exist_ok=True)

    # create model
    model = GeomFmapNet(params.n_feat, params.in_grid_size, params.lambda_).to(device)
    # map_location lets a checkpoint trained on GPU be evaluated on a CPU-only machine
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # create dataset
    testset = ShapeMatchingDatasetWrapper(params, train=False)
    testloader = testset.get_dataloader(model.feature_extractor, batch_size=1,
                                        shuffle=False, num_workers=params.n_cpu, precompute_multi_scale=True)

    to_save = []
    # A torch DataLoader exposes the wrapped dataset as `.dataset`; it has no `_dataset` attribute.
    dataset = testloader.dataset
    used_names, combinations = dataset.used_names, dataset.combinations
    for i, batch in enumerate(testloader):
        batch = batch.to(device)
        with torch.no_grad():
            # The model returns batch_size x neig x neig (batch_size is 1 here). Tensor.t() only
            # accepts tensors with at most 2 dimensions, so the batch dimension is dropped first.
            # The result is moved to CPU so the saved file can be loaded without a GPU.
            C_est = model(batch).squeeze(0).t().cpu()

        # save
        target, source = used_names[combinations[i][0]], used_names[combinations[i][1]]
        to_save.append({'C_est': C_est, "source": source, "target": target})

    torch.save(to_save, os.path.join(params.evaldir, "fmaps.pt"))


if __name__ == "__main__":
    params = OmegaConf.load("config.yaml")
    PATH = "path/to/model.pt"

    eval_model(PATH, params)
class FMRegNet(nn.Module):
    """Functional map regularizer layer of GeomFmaps.

    Takes the descriptors computed for a batch of shape pairs and returns, for every pair,
    the functional map obtained by solving one regularized linear system per row.
    """

    def __init__(self, lambda_=1e-3):
        """Init layer.

        Keyword Arguments:
            lambda_ {float} -- regularization parameter (default: {1e-3})
        """
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, data, features):
        """One pass in Regularizer layer.

        Arguments:
            data -- batch object. The code reads ``nv`` (number of points per shape, shapes
                stored as consecutive (x, y) pairs), ``evecs_x`` (only its second dimension,
                used as ``neig``), ``evals_x`` (flat, ``neig`` values per shape) and
                ``evecs_trans_x`` (rows aligned with the rows of ``features``).
            features {torch.Tensor} -- descriptors, one row per point of the batch.

        Returns:
            torch.Tensor -- Functional map matrix from source to target. Size: batch_size x neig x neig
        """
        neig = data.evecs_x.shape[1]
        proj_x, proj_y, spectrum_x, spectrum_y = [], [], [], []
        offset = 0
        for i in range(0, len(data.nv), 2):
            n_x, n_y = int(data.nv[i]), int(data.nv[i + 1])
            # eigenvalues of both shapes live in the same flat `evals_x` buffer
            spectrum_x.append(data.evals_x[neig * i: neig * (i + 1)])
            spectrum_y.append(data.evals_x[neig * (i + 1): neig * (i + 2)])
            # project the descriptors of x, then of y, on their spectral bases
            proj_x.append(data.evecs_trans_x[offset: offset + n_x].T @ features[offset: offset + n_x])
            offset += n_x
            proj_y.append(data.evecs_trans_x[offset: offset + n_y].T @ features[offset: offset + n_y])
            offset += n_y

        A = torch.stack(proj_x)  # batch_size x neig x n_feat
        B = torch.stack(proj_y)  # batch_size x neig x n_feat
        evals_x = torch.stack(spectrum_x)  # batch_size x neig
        evals_y = torch.stack(spectrum_y)  # batch_size x neig

        A_t = A.transpose(1, 2)
        A_A_t = torch.bmm(A, A_t)
        B_A_t = torch.bmm(B, A_t)

        # D[b, i, j] = (evals_x[b, j] - evals_y[b, i]) ** 2
        D = (evals_x.unsqueeze(1) - evals_y.unsqueeze(2)) ** 2

        # Row i of C solves (A A^T + lambda * diag(D[:, i, :])) c = (B A^T)[:, i, :].
        # All rows of all pairs are solved in one batched call; solving the systems directly
        # is better conditioned than forming explicit inverses.
        systems = A_A_t.unsqueeze(1) + self.lambda_ * torch.diag_embed(D)
        C = torch.linalg.solve(systems, B_A_t.unsqueeze(-1)).squeeze(-1)

        return C


class FeatureRegressor(torch.nn.Module):
    """Regress ``n_feat``-dimensional descriptors from backbone features.

    The input goes through a first MLP that raises the channel count to ``3 * n_feat``
    (optionally followed by dropout), then through a second MLP that brings it to ``n_feat``.

    Parameters
    ----------
    in_features -
        size of the input channel
    n_feat: number of output features
    dropout_proba: dropout probability after the first MLP; a falsy value disables dropout
    bn_momentum: value forwarded as ``bn_momentum`` to the first MLP
    """

    def __init__(self, in_features, n_feat, dropout_proba=0.5, bn_momentum=0.1):
        super().__init__()

        up_factor = 3

        self.channel_rasing = MLP(
            [in_features, n_feat * up_factor], bn_momentum=bn_momentum, bias=False
        )
        if dropout_proba:
            self.channel_rasing.add_module("Dropout", torch.nn.Dropout(p=dropout_proba))

        self.final_mlp = MLP([n_feat * up_factor, n_feat], bias=True)

    def forward(self, features, **kwargs):
        """Return the regressed descriptors; ``features`` must be 2-dimensional."""
        assert features.dim() == 2
        features = self.channel_rasing(features)
        features = self.final_mlp(features)

        return features


class KPConvFeatureExtractor(torch.nn.Module):
    """KPConv U-Net backbone followed by a ``FeatureRegressor`` head."""

    def __init__(self, n_feat, in_grid_size):
        super().__init__()

        self.unet = KPConv(
            architecture="unet",
            input_nc=0,
            num_layers=4,
            in_grid_size=in_grid_size
        )
        self.feature_regressor = FeatureRegressor(self.unet.output_nc, n_feat)

    @property
    def conv_type(self):
        """ This is needed by the dataset to infer which batch collate should be used"""
        return self.unet.conv_type

    def forward(self, data):
        """Return the descriptors for ``data``; they are also kept in ``self.output``."""
        # Forward through unet and feature_regressor
        data_features = self.unet(data)
        self.output = self.feature_regressor(data_features.x)

        return self.output

    def get_spatial_ops(self):
        """Forward to the backbone's ``get_spatial_ops``."""
        return self.unet.get_spatial_ops()


class GeomFmapNet(nn.Module):
    """Full model: descriptor extraction followed by the functional map layer."""

    def __init__(self, n_feat, in_grid_size, lambda_):
        super().__init__()
        self.feature_extractor = KPConvFeatureExtractor(n_feat=n_feat, in_grid_size=in_grid_size)

        self.fmreg_net = FMRegNet(lambda_=lambda_)

    def forward(self, batch):
        """Return the functional maps estimated for ``batch``."""
        features = self.feature_extractor(batch)
        C = self.fmreg_net(batch, features)
        return C
super().__init__() 131 | self.feature_extractor = KPConvFeatureExtractor(n_feat=n_feat, in_grid_size=in_grid_size) 132 | 133 | self.fmreg_net = FMRegNet(lambda_=lambda_) 134 | 135 | def forward(self, batch): 136 | features = self.feature_extractor(batch) 137 | C = self.fmreg_net(batch, features) 138 | return C 139 | -------------------------------------------------------------------------------- /geomfmaps/shape_matching_dataset.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | from pathlib import Path 3 | from itertools import permutations 4 | from functools import partial 5 | import hashlib 6 | # 3p 7 | from omegaconf import OmegaConf 8 | from tqdm import tqdm 9 | import numpy as np 10 | import scipy.io as sio 11 | import torch 12 | import torch_geometric 13 | from torch_geometric.data import Data, InMemoryDataset 14 | from torch_geometric.data.dataset import __repr__ 15 | from torch_points3d.core.data_transform import SaveOriginalPosId 16 | from torch_points3d.datasets.base_dataset import BaseDataset 17 | from torch_points3d.datasets.batch import SimpleBatch 18 | from torch_points3d.datasets.multiscale_data import MultiScaleBatch 19 | from torch_points3d.utils.enums import ConvolutionFormat 20 | from torch_points3d.utils.config import ConvolutionFormatFactory 21 | # project 22 | from utils import read_mesh 23 | 24 | 25 | class ShapeMatchingDataset(InMemoryDataset): 26 | """Abstract class for shape matching dataset""" 27 | def __init__(self, 28 | dataroot, transform=None, pre_transform=None, 29 | neig=30, n_train=80, max_train=100, train=True): 30 | 31 | """Init dataset. 
32 | Arguments: 33 | dataroot {string} -- Path to dataset 34 | Keyword Arguments: 35 | neig {int} -- number of eigenvectors used for representation (default: {30}) 36 | transform {object} -- set of transforms to apply to dataset in training and testing (default: {None}) 37 | pre_transform {object} -- set of transforms to apply to dataset before training starts (default: {None}) 38 | n_train {int} -- Number of shapes used in training (default: {80}) 39 | max_train {int} -- Total Number of shapes used in training & testing (default: {100}) 40 | train {bool} -- set dataset to training mode (default: {True}) 41 | """ 42 | 43 | assert max_train >= n_train, f"max_train={max_train} is smaller than n_train={n_train}" 44 | # dataset path 45 | self.dataset_root = Path(dataroot) 46 | self.samples_path = (self.dataset_root / "off").resolve() 47 | self.spectral_path = (self.dataset_root / "spectral").resolve() 48 | self.processed_path = (self.dataset_root / "processed").resolve() 49 | self.raw_path = (self.dataset_root / "raw").resolve() 50 | self.raw_path.mkdir(parents=True, exist_ok=True) 51 | 52 | # params 53 | self.neig = neig 54 | self.train = train 55 | 56 | # train/test dataset 57 | self.sample_names = sorted([x for x in self.samples_path.iterdir() if x.is_file()]) 58 | self.spectral_names = sorted([x for x in self.spectral_path.iterdir() if x.is_file()]) 59 | self.corres_path = (self.dataset_root / "corres").resolve() 60 | 61 | # draw samples 62 | self.all_ind = list(range(len(self.sample_names))) 63 | if train: 64 | self.chosen_indices = self.all_ind[:n_train] 65 | else: 66 | self.chosen_indices = self.all_ind[n_train:max_train] 67 | 68 | # load data to ram (small dataset) 69 | self.used_names = sorted([self.sample_names[i].stem.split("_")[-1] for i in self.chosen_indices]) 70 | self.split_name = '_'.join(self.used_names) + __repr__(pre_transform) 71 | self.split_name = hashlib.sha1(self.split_name.encode()).hexdigest() 72 | print(f"Using: {self.used_names}") 73 | if 
not train: 74 | self.samples = [self.load_sample(self.sample_names[i]) for i in tqdm(self.chosen_indices, desc="Loading samples")] 75 | 76 | self.evecs = [self.load_spectral(self.spectral_names[i])[1] for i in tqdm(self.chosen_indices, desc="Loading evecs")] 77 | self.combinations = list(permutations(range(len(self.chosen_indices)), 2)) 78 | 79 | # load vts 80 | self.vts_names = sorted([x for x in self.corres_path.iterdir() 81 | if x.is_file() and not any(y in str(x) for y in ["sampleID", "sym"])]) 82 | 83 | self.chosen_vts = [self.vts_names[i] for i in self.chosen_indices] 84 | self.vts = [np.loadtxt(v_path, dtype=np.int32) - 1 for v_path in tqdm(self.chosen_vts, desc="Loading vts")] 85 | 86 | super().__init__(dataroot, transform, pre_transform) 87 | 88 | self.data, self.slices = self.load_data(self.processed_path / f"{self.split_name}.pt") 89 | 90 | def load_data(self, path): 91 | '''This function is used twice to load data for both raw and pre_transformed 92 | ''' 93 | data, slices = torch.load(path) 94 | 95 | return data, slices 96 | 97 | def load_sample(self, path): 98 | """Load and normalize a mesh." 99 | Arguments: 100 | path {string} -- path to mesh file. 101 | Returns: 102 | torch.Tensor -- Tensor containing vertices of shape. Size: `n_points x 3` 103 | """ 104 | verts, _ = read_mesh(path) 105 | return torch.Tensor(verts) 106 | 107 | def load_spectral(self, path): 108 | """Load spectral data at `path`. 109 | The data is stored in a dict. This dict has the following keys: 110 | evals: eigen values. shape: neig x 1. 111 | evecs: eigen vectors. shape: `num_vertices` x neig. 112 | evecs_trans: transposed eigen vectors. shape: neig x `num_vertices`. 113 | Arguments: 114 | path {string} -- path to load spectral data from. 115 | Returns: 116 | tuple(torch.Tensor, torch.Tensor, torch.Tensor) -- spectral data. 
117 | """ 118 | mat = sio.loadmat(path) 119 | return (torch.Tensor(mat['evals']).flatten()[:self.neig].float(), 120 | torch.Tensor(mat['evecs'])[:, :self.neig].float(), 121 | torch.Tensor(mat['evecs_trans'])[:self.neig, :].T.float()) 122 | 123 | def load_c(self, i, j): 124 | """Compute functional map matrix from shape `i` to shape `j`. 125 | Arguments: 126 | i {int} -- index of source shape. 127 | j {int} -- index of target shape. 128 | Returns: 129 | torch.Tensor -- Tensor representing the functional map. Size: `n_eig x n_eig`. 130 | """ 131 | # load eigen vectors & vts 132 | evec_i, evec_j = self.evecs[i], self.evecs[j] 133 | vts_i, vts_j = self.vts[i], self.vts[j] 134 | 135 | # compute C 136 | evec_i_a, evec_j_a = evec_i[vts_i], evec_j[vts_j] 137 | C_i_j = np.linalg.lstsq(evec_i_a, evec_j_a, rcond=None)[0] 138 | return torch.Tensor(C_i_j.T) 139 | 140 | def __len__(self): 141 | return len(self.combinations) 142 | 143 | def __getitem__(self, index): 144 | idx1, idx2 = self.combinations[index] 145 | 146 | # load pointcloud 147 | sample_x, sample_y = self.get(idx1), self.get(idx2) 148 | sample_x = sample_x if self.transform is None else self.transform(sample_x) 149 | sample_y = sample_y if self.transform is None else self.transform(sample_y) 150 | # load ground truth functional map 151 | C_gt = self.load_c(idx1, idx2) 152 | 153 | # continue with data class 154 | sample_x.C_gt = C_gt.unsqueeze(0) 155 | 156 | return sample_x, sample_y 157 | 158 | def _process(self): 159 | if (self.processed_path / f"{self.split_name}.pt").is_file(): # pragma: no cover 160 | return 161 | 162 | print('Processing...') 163 | 164 | self.processed_path.mkdir(parents=True, exist_ok=True) 165 | self.process() 166 | 167 | path = self.processed_path / 'pre_transform.pt' 168 | torch.save(__repr__(self.pre_transform), path) 169 | 170 | print('Done!') 171 | 172 | def process(self): 173 | data_raw_list, data_list = self._process_filenames() 174 | 175 | self._save_data_list(data_list, 
self.processed_path / f"{self.split_name}.pt") 176 | self._save_data_list(data_raw_list, self.raw_path / f"{self.split_name}.pt", save_bool=len(data_raw_list) > 0) 177 | 178 | def _process_filenames(self): 179 | data_raw_list = [] 180 | data_list = [] 181 | 182 | has_pre_transform = self.pre_transform is not None 183 | 184 | id_scan = -1 185 | for idx in tqdm(self.chosen_indices): 186 | id_scan += 1 187 | pos = self.load_sample(self.sample_names[idx]) 188 | evals_x, evecs_x, evecs_trans_x = self.load_spectral(self.spectral_names[idx]) 189 | x = None 190 | id_scan_tensor = torch.from_numpy(np.asarray([id_scan])).clone() 191 | data = Data(pos=pos, x=x, evals_x=evals_x, evecs_x=evecs_x, evecs_trans_x=evecs_trans_x, id_scan=id_scan_tensor) 192 | data = SaveOriginalPosId()(data) 193 | data_raw_list.append(data.clone() if has_pre_transform else data) 194 | if has_pre_transform: 195 | data = self.pre_transform(data) 196 | data.nv = data.pos.shape[0] # number of vertices 197 | data_list.append(data) 198 | if not has_pre_transform: 199 | return [], data_raw_list 200 | return data_raw_list, data_list 201 | 202 | def _save_data_list(self, datas, path_to_datas, save_bool=True): 203 | if save_bool: 204 | torch.save(self.collate(datas), path_to_datas) 205 | 206 | 207 | class ShapeMatchingDatasetWrapper(BaseDataset): 208 | """ Wrapper around ShapeNet that creates shape matching datasets. 209 | Parameters 210 | ---------- 211 | dataset_opt: omegaconf.DictConfig 212 | Config dictionary that should contain 213 | - dataroot 214 | - pre_transforms 215 | - train_transforms 216 | - test_transforms 217 | """ 218 | 219 | def __init__(self, hparams, train): 220 | """Init shape matching wrapper. 
221 | Args: 222 | hparams (omegaconf): hydra dataset config 223 | train (bool): indicates if this is a training dataset 224 | """ 225 | 226 | self.train = train 227 | self.use_data_class = True 228 | self.pre_transform = None 229 | self.train_transform, self.test_transform = None, None 230 | 231 | hparams.pre_transforms = hparams.pre_transforms 232 | hparams.train_transforms = hparams.train_transforms 233 | hparams.test_transforms = hparams.test_transforms 234 | 235 | params = OmegaConf.create(hparams) 236 | super().__init__(params) 237 | 238 | transform = self.train_transform if train else self.test_transform 239 | 240 | self._dataset = ShapeMatchingDataset( 241 | hparams.dataroot, pre_transform=self.pre_transform, transform=transform, 242 | neig=hparams.neig, n_train=hparams.n_train, max_train=hparams.max_train, 243 | train=train 244 | ) 245 | 246 | self._test_loaders = [] 247 | 248 | def get_dataloader(self, model, batch_size, shuffle, num_workers, precompute_multi_scale=True): 249 | if not self.use_data_class: 250 | return torch.utils.data.DataLoader(self._dataset, batch_size=batch_size, 251 | shuffle=shuffle and not self.train_sampler, num_workers=num_workers) 252 | conv_type = model.conv_type 253 | self._batch_size = batch_size 254 | 255 | batch_collate_function = self.__class__._get_collate_function(conv_type, precompute_multi_scale) 256 | dataloader = partial( 257 | torch.utils.data.DataLoader, collate_fn=batch_collate_function, worker_init_fn=lambda _: np.random.seed() 258 | ) 259 | 260 | self._dataloader = dataloader( 261 | self._dataset, 262 | batch_size=batch_size, 263 | shuffle=shuffle and not self.train_sampler, 264 | num_workers=num_workers, 265 | sampler=self.train_sampler, 266 | ) 267 | 268 | if precompute_multi_scale: # check if this excuted 269 | self.set_strategies(model) 270 | 271 | return self._dataloader 272 | 273 | @staticmethod 274 | def _get_collate_function(conv_type, is_multiscale): 275 | if is_multiscale: 276 | if conv_type.lower() == 
def train(params):
    """Train GeomFmapNet on the training split and periodically save its weights.

    Arguments:
        params {omegaconf.DictConfig} -- configuration. This function reads ``no_cuda``,
            ``savedir``, ``n_feat``, ``in_grid_size``, ``lambda_``, ``lr``, ``batch_size``,
            ``n_cpu``, ``n_epochs``, ``log_interval`` and ``checkpoint_interval``; the whole
            object is also handed to ``ShapeMatchingDatasetWrapper``.

    Side effects:
        Creates ``params.savedir`` if needed and writes ``epoch<N>.pth`` state dicts into it.
    """
    if torch.cuda.is_available() and not params.no_cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    os.makedirs(params.savedir, exist_ok=True)

    # create model
    model = GeomFmapNet(params.n_feat, params.in_grid_size, params.lambda_).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)

    # create dataset
    trainset = ShapeMatchingDatasetWrapper(params, train=True)
    trainloader = trainset.get_dataloader(model.feature_extractor, batch_size=params.batch_size,
                                          shuffle=True, num_workers=params.n_cpu, precompute_multi_scale=True)

    # Training loop
    iterations = 0
    for epoch in range(1, params.n_epochs + 1):
        model.train()
        for i, batch in enumerate(trainloader):
            batch = batch.to(device)

            # do iteration
            optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                C_est = model(batch)
                loss = frobenius_loss(C_est, batch.C_gt)
                loss.backward()
                optimizer.step()

            # log and save model
            iterations += 1
            if iterations % params.log_interval == 0:
                # .item() logs the scalar value instead of the tensor repr
                print(f"#epoch:{epoch}, #batch:{i + 1}, #iteration:{iterations}, fmap loss:{loss.item()}")

        # `epoch` is already 1-based: testing `epoch + 1` would checkpoint one epoch early
        # and never save the last epoch when n_epochs is a multiple of the interval.
        if epoch % params.checkpoint_interval == 0:
            torch.save(model.state_dict(), os.path.join(params.savedir, 'epoch{}.pth'.format(epoch)))


if __name__ == "__main__":
    params = OmegaConf.load("config.yaml")
    train(params)
# stdlib
from pathlib import Path
# 3p
import numpy as np
import torch


def read_off(file):
    """Read a mesh stored in OFF format.

    Blank lines and lines starting with ``#`` are ignored, and fields may be separated by
    any amount of whitespace.

    Arguments:
        file {string or Path} -- path to the ``.off`` file.

    Returns:
        tuple(np.ndarray, np.ndarray) -- vertices (one row per vertex) and faces (one row of
        vertex indices per face, without the leading vertex count).

    Raises:
        ValueError -- if the header is not ``OFF`` or the file is shorter than announced.
    """
    with open(file, "r") as stream:
        stripped = (line.strip() for line in stream)
        rows = [line for line in stripped if line and not line.startswith("#")]

    if not rows or rows[0] != "OFF":
        raise ValueError("Not a valid OFF header")
    if len(rows) < 2:
        raise ValueError("OFF file has no element counts line")

    counts = [int(token) for token in rows[1].split()]
    if len(counts) < 2:
        raise ValueError("OFF counts line must hold at least the vertex and face counts")
    n_verts, n_faces = counts[0], counts[1]

    vert_rows = rows[2: 2 + n_verts]
    face_rows = rows[2 + n_verts: 2 + n_verts + n_faces]
    if len(vert_rows) != n_verts or len(face_rows) != n_faces:
        raise ValueError("OFF file is truncated")

    verts = [[float(token) for token in row.split()] for row in vert_rows]
    # the first integer of a face row is its vertex count, which is dropped
    faces = [[int(token) for token in row.split()][1:] for row in face_rows]

    return np.array(verts), np.array(faces)


def write_off(file, verts, faces):
    """Write a mesh in OFF format.

    Arguments:
        file {string or Path} -- destination path.
        verts {np.ndarray} -- vertices, one row per vertex.
        faces {np.ndarray} -- faces, one row of vertex indices per face.
    """
    with open(file, "w") as stream:
        stream.write("OFF\n")
        # the third count (edges) is always written as 0
        stream.write(f"{verts.shape[0]} {faces.shape[0]} 0\n")
        for vert in verts:
            stream.write(f"{' '.join(map(str, vert))}\n")
        for face in faces:
            stream.write(f"{len(face)} {' '.join(map(str, face))}\n")


def read_mesh(file):
    """Read a mesh, choosing the parser from the file extension (only ``.off`` for now).

    Returns:
        tuple(np.ndarray, np.ndarray) -- vertices and faces, as returned by ``read_off``.

    Raises:
        NotImplementedError -- if the extension is not supported.
    """
    file = Path(file)
    if file.suffix == ".off":
        return read_off(file)
    raise NotImplementedError(f"File extension '{file.suffix}' not implemented yet!")


def write_mesh(file, verts, faces):
    """Write a mesh, choosing the writer from the file extension (only ``.off`` for now).

    Raises:
        NotImplementedError -- if the extension is not supported.
    """
    file = Path(file)
    if file.suffix == ".off":
        write_off(file, verts, faces)
    else:
        raise NotImplementedError(f"File extension '{file.suffix}' not implemented yet!")
44 | 45 | 46 | def frobenius_loss(a, b): 47 | """Compute the Frobenius loss between a and b.""" 48 | 49 | loss = torch.sum((a - b) ** 2, axis=(1, 2)) 50 | return torch.mean(loss) 51 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch 2 | omegaconf 3 | torch-points3d 4 | tqdm 5 | numpy 6 | scipy --------------------------------------------------------------------------------