├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── data ├── __init__.py ├── base_dataset.py ├── classification_data.py └── segmentation_data.py ├── docs ├── imgs │ ├── T18.png │ ├── T252.png │ ├── T76.png │ ├── alien.gif │ ├── coseg_alien.png │ ├── coseg_chair.png │ ├── coseg_vase.png │ ├── cubes.png │ ├── cubes2.png │ ├── input_edge_features.png │ ├── mesh_conv.png │ ├── mesh_pool_unpool.png │ ├── meshcnn_overview.png │ ├── shrec16_train.png │ ├── shrec__10_0.png │ ├── shrec__14_0.png │ └── shrec__2_0.png ├── index.html └── mainpage.css ├── environment.yml ├── models ├── __init__.py ├── layers │ ├── __init__.py │ ├── mesh.py │ ├── mesh_conv.py │ ├── mesh_pool.py │ ├── mesh_prepare.py │ ├── mesh_union.py │ └── mesh_unpool.py ├── mesh_classifier.py └── networks.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── scripts ├── coseg_seg │ ├── get_data.sh │ ├── get_pretrained.sh │ ├── test.sh │ ├── train.sh │ └── view.sh ├── cubes │ ├── get_data.sh │ ├── get_pretrained.sh │ ├── test.sh │ ├── train.sh │ └── view.sh ├── dataprep │ └── blender_process.py ├── human_seg │ ├── get_data.sh │ ├── get_pretrained.sh │ ├── test.sh │ ├── train.sh │ └── view.sh ├── shrec │ ├── get_data.sh │ ├── get_pretrained.sh │ ├── test.sh │ ├── train.sh │ └── view.sh └── test_general.py ├── test.py ├── train.py └── util ├── __init__.py ├── mesh_viewer.py ├── util.py └── writer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | *.pyc 3 | *.m~ 4 | 5 | # data files 6 | *.obj 7 | checkpoints 8 | datasets 9 | runs 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | notifications: 2 | email: 3 | on_success: never 4 | on_failure: always 5 | language: python 6 | python: 7 | - "3.6" 8 | cache: pip 9 | install: 10 | - sudo apt-get update 11 | - wget 
https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 12 | - bash miniconda.sh -b -p $HOME/anaconda3 13 | - source "$HOME/anaconda3/etc/profile.d/conda.sh" 14 | - hash -r 15 | - conda config --set always_yes yes --set changeps1 no 16 | - conda update -q conda 17 | # Useful for debugging any issues with conda 18 | - conda info -a 19 | 20 | # create meshcnn env 21 | - conda env create -f environment.yml 22 | - conda activate meshcnn 23 | script: 24 | - python -m pytest scripts/test_general.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Rana Hanocka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |


3 | 4 | # MeshCNN in PyTorch 5 | 6 | 7 | ### SIGGRAPH 2019 [[Paper]](https://bit.ly/meshcnn) [[Project Page]](https://ranahanocka.github.io/MeshCNN/)
8 | 9 | MeshCNN is a general-purpose deep neural network for 3D triangular meshes, which can be used for tasks such as 3D shape classification or segmentation. This framework includes convolution, pooling and unpooling layers which are applied directly on the mesh edges. 10 | 11 |
12 | 13 | The code was written by [Rana Hanocka](https://www.cs.tau.ac.il/~hanocka/) and [Amir Hertz](http://pxcm.org/) with support from [Noa Fish](http://www.cs.tau.ac.il/~noafish/). 14 | 15 | # Getting Started 16 | 17 | ### Installation 18 | - Clone this repo: 19 | ```bash 20 | git clone https://github.com/ranahanocka/MeshCNN.git 21 | cd MeshCNN 22 | ``` 23 | - Install dependencies: [PyTorch](https://pytorch.org/) version 1.2. Optional: [tensorboardX](https://github.com/lanpa/tensorboardX) for training plots. 24 | - Alternatively, install everything in a new conda environment with `conda env create -f environment.yml` (this creates an environment called meshcnn). 25 | 26 | ### 3D Shape Classification on SHREC 27 | Download the dataset: 28 | ```bash 29 | bash ./scripts/shrec/get_data.sh 30 | ``` 31 | 32 | Run training (if using the conda env, first activate it, e.g. ```source activate meshcnn```): 33 | ```bash 34 | bash ./scripts/shrec/train.sh 35 | ``` 36 | 37 | To view the training loss plots, run ```tensorboard --logdir runs``` in another terminal and open [http://localhost:6006](http://localhost:6006). 38 | 39 | Run the test and export the intermediate pooled meshes: 40 | ```bash 41 | bash ./scripts/shrec/test.sh 42 | ``` 43 | 44 | Visualize the network-learned edge collapses: 45 | ```bash 46 | bash ./scripts/shrec/view.sh 47 | ``` 48 | 49 | An example of collapses for a mesh: 50 | 51 | 52 | 53 | Note that you can also download pre-trained weights using ```bash ./scripts/shrec/get_pretrained.sh```. 54 | 55 | To use the pre-trained weights, first run ```train.sh```, which computes and saves the mean / standard deviation of the training data.
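The cached statistics are plain per-channel averages. The sketch below mirrors the logic of `BaseDataset.get_mean_std` in `data/base_dataset.py` and the normalization applied in the datasets' `__getitem__`, but outside the dataset classes. The helper names `compute_mean_std` and `normalize` are hypothetical, and the sketch assumes each mesh's edge features are a NumPy array shaped `(channels, edges)`.

```python
import numpy as np


def compute_mean_std(feature_list):
    """Per-channel statistics averaged over meshes (hypothetical helper).

    feature_list: sequence of (C, E) arrays, one per training mesh.
    Returns (mean, std), each shaped (C, 1) so they broadcast over edges.
    """
    means = np.stack([f.mean(axis=1) for f in feature_list])  # (num_meshes, C)
    stds = np.stack([f.std(axis=1) for f in feature_list])    # (num_meshes, C)
    return means.mean(axis=0)[:, np.newaxis], stds.mean(axis=0)[:, np.newaxis]


def normalize(edge_features, mean, std):
    """Same transform the datasets apply to edge features in __getitem__."""
    return (edge_features - mean) / std


rng = np.random.RandomState(0)
train = [rng.randn(5, 750) for _ in range(3)]  # toy data: 5 channels, 750 edges each
mean, std = compute_mean_std(train)
print(mean.shape, std.shape)  # (5, 1) (5, 1)
```

Averaging per-mesh standard deviations is not the same as the standard deviation of all training edges pooled together; the sketch reproduces what the cache stores rather than the pooled statistic.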
56 | 57 | 58 | ### 3D Shape Segmentation on Humans 59 | As above, download the dataset / run training / get pre-trained weights / run the test / view the results: 60 | ```bash 61 | bash ./scripts/human_seg/get_data.sh 62 | bash ./scripts/human_seg/train.sh 63 | bash ./scripts/human_seg/get_pretrained.sh 64 | bash ./scripts/human_seg/test.sh 65 | bash ./scripts/human_seg/view.sh 66 | ``` 67 | 68 | Some segmentation result examples: 69 | 70 | 71 | 72 | ### Additional Datasets 73 | The same scripts also exist for COSEG segmentation in ```scripts/coseg_seg``` and cubes classification in ```scripts/cubes```. 74 | 75 | # More Info 76 | Check out the [MeshCNN wiki](https://github.com/ranahanocka/MeshCNN/wiki) for more details. Specifically, see info on [segmentation](https://github.com/ranahanocka/MeshCNN/wiki/Segmentation) and [data processing](https://github.com/ranahanocka/MeshCNN/wiki/Data-Processing). 77 | 78 | # Other implementations 79 | - [Point2Mesh TensorFlow reimplementation](https://github.com/dcharatan/point2mesh-reimplementation), which also contains MeshCNN 80 | - [MedMeshCNN](https://github.com/Divya9Sasidharan/MedMeshCNN), which handles meshes with 170k edges 81 | 82 | # Citation 83 | If you find this code useful, please consider citing our paper: 84 | ``` 85 | @article{hanocka2019meshcnn, 86 | title={MeshCNN: A Network with an Edge}, 87 | author={Hanocka, Rana and Hertz, Amir and Fish, Noa and Giryes, Raja and Fleishman, Shachar and Cohen-Or, Daniel}, 88 | journal={ACM Transactions on Graphics (TOG)}, 89 | volume={38}, 90 | number={4}, 91 | pages = {90:1--90:12}, 92 | year={2019}, 93 | publisher={ACM} 94 | } 95 | ``` 96 | 97 | 98 | # Questions / Issues 99 | If you have questions or issues running this code, please open an issue so we can fix it. 100 | 101 | # Acknowledgments 102 | This code design was adopted from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
103 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_dataset import collate_fn 3 | 4 | def CreateDataset(opt): 5 | """loads dataset class""" 6 | 7 | if opt.dataset_mode == 'segmentation': 8 | from data.segmentation_data import SegmentationData 9 | dataset = SegmentationData(opt) 10 | elif opt.dataset_mode == 'classification': 11 | from data.classification_data import ClassificationData 12 | dataset = ClassificationData(opt) 13 | return dataset 14 | 15 | 16 | class DataLoader: 17 | """multi-threaded data loading""" 18 | 19 | def __init__(self, opt): 20 | self.opt = opt 21 | self.dataset = CreateDataset(opt) 22 | self.dataloader = torch.utils.data.DataLoader( 23 | self.dataset, 24 | batch_size=opt.batch_size, 25 | shuffle=not opt.serial_batches, 26 | num_workers=int(opt.num_threads), 27 | collate_fn=collate_fn) 28 | 29 | def __len__(self): 30 | return min(len(self.dataset), self.opt.max_dataset_size) 31 | 32 | def __iter__(self): 33 | for i, data in enumerate(self.dataloader): 34 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 35 | break 36 | yield data 37 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import pickle 4 | import os 5 | 6 | class BaseDataset(data.Dataset): 7 | 8 | def __init__(self, opt): 9 | self.opt = opt 10 | self.mean = 0 11 | self.std = 1 12 | self.ninput_channels = None 13 | super(BaseDataset, self).__init__() 14 | 15 | def get_mean_std(self): 16 | """ Computes Mean and Standard Deviation from Training Data 17 | If mean/std file doesn't exist, will compute one 18 | :returns 19 | mean: N-dimensional mean 20 | std: N-dimensional standard deviation 21 | 
ninput_channels: N 22 | (here N=5) 23 | """ 24 | 25 | mean_std_cache = os.path.join(self.root, 'mean_std_cache.p') 26 | if not os.path.isfile(mean_std_cache): 27 | print('computing mean std from train data...') 28 | # doesn't run augmentation during m/std computation 29 | num_aug = self.opt.num_aug 30 | self.opt.num_aug = 1 31 | mean, std = np.array(0), np.array(0) 32 | for i, data in enumerate(self): 33 | if i % 500 == 0: 34 | print('{} of {}'.format(i, self.size)) 35 | features = data['edge_features'] 36 | mean = mean + features.mean(axis=1) 37 | std = std + features.std(axis=1) 38 | mean = mean / (i + 1) 39 | std = std / (i + 1) 40 | transform_dict = {'mean': mean[:, np.newaxis], 'std': std[:, np.newaxis], 41 | 'ninput_channels': len(mean)} 42 | with open(mean_std_cache, 'wb') as f: 43 | pickle.dump(transform_dict, f) 44 | print('saved: ', mean_std_cache) 45 | self.opt.num_aug = num_aug 46 | # open mean / std from file 47 | with open(mean_std_cache, 'rb') as f: 48 | transform_dict = pickle.load(f) 49 | print('loaded mean / std from cache') 50 | self.mean = transform_dict['mean'] 51 | self.std = transform_dict['std'] 52 | self.ninput_channels = transform_dict['ninput_channels'] 53 | 54 | 55 | def collate_fn(batch): 56 | """Creates mini-batch tensors 57 | We should build custom collate_fn rather than using default collate_fn 58 | """ 59 | meta = {} 60 | keys = batch[0].keys() 61 | for key in keys: 62 | meta.update({key: np.array([d[key] for d in batch])}) 63 | return meta -------------------------------------------------------------------------------- /data/classification_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data.base_dataset import BaseDataset 4 | from util.util import is_mesh_file, pad 5 | from models.layers.mesh import Mesh 6 | 7 | class ClassificationData(BaseDataset): 8 | 9 | def __init__(self, opt): 10 | BaseDataset.__init__(self, opt) 11 | self.opt = opt 12 | 
self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') 13 | self.root = opt.dataroot 14 | self.dir = os.path.join(opt.dataroot) 15 | self.classes, self.class_to_idx = self.find_classes(self.dir) 16 | self.paths = self.make_dataset_by_class(self.dir, self.class_to_idx, opt.phase) 17 | self.nclasses = len(self.classes) 18 | self.size = len(self.paths) 19 | self.get_mean_std() 20 | # modify for network later. 21 | opt.nclasses = self.nclasses 22 | opt.input_nc = self.ninput_channels 23 | 24 | def __getitem__(self, index): 25 | path = self.paths[index][0] 26 | label = self.paths[index][1] 27 | mesh = Mesh(file=path, opt=self.opt, hold_history=False, export_folder=self.opt.export_folder) 28 | meta = {'mesh': mesh, 'label': label} 29 | # get edge features 30 | edge_features = mesh.extract_features() 31 | edge_features = pad(edge_features, self.opt.ninput_edges) 32 | meta['edge_features'] = (edge_features - self.mean) / self.std 33 | return meta 34 | 35 | def __len__(self): 36 | return self.size 37 | 38 | # this is when the folders are organized by class... 
39 | @staticmethod 40 | def find_classes(dir): 41 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 42 | classes.sort() 43 | class_to_idx = {classes[i]: i for i in range(len(classes))} 44 | return classes, class_to_idx 45 | 46 | @staticmethod 47 | def make_dataset_by_class(dir, class_to_idx, phase): 48 | meshes = [] 49 | dir = os.path.expanduser(dir) 50 | for target in sorted(os.listdir(dir)): 51 | d = os.path.join(dir, target) 52 | if not os.path.isdir(d): 53 | continue 54 | for root, _, fnames in sorted(os.walk(d)): 55 | for fname in sorted(fnames): 56 | if is_mesh_file(fname) and (root.count(phase)==1): 57 | path = os.path.join(root, fname) 58 | item = (path, class_to_idx[target]) 59 | meshes.append(item) 60 | return meshes 61 | -------------------------------------------------------------------------------- /data/segmentation_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data.base_dataset import BaseDataset 4 | from util.util import is_mesh_file, pad 5 | import numpy as np 6 | from models.layers.mesh import Mesh 7 | 8 | class SegmentationData(BaseDataset): 9 | 10 | def __init__(self, opt): 11 | BaseDataset.__init__(self, opt) 12 | self.opt = opt 13 | self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') 14 | self.root = opt.dataroot 15 | self.dir = os.path.join(opt.dataroot, opt.phase) 16 | self.paths = self.make_dataset(self.dir) 17 | self.seg_paths = self.get_seg_files(self.paths, os.path.join(self.root, 'seg'), seg_ext='.eseg') 18 | self.sseg_paths = self.get_seg_files(self.paths, os.path.join(self.root, 'sseg'), seg_ext='.seseg') 19 | self.classes, self.offset = self.get_n_segs(os.path.join(self.root, 'classes.txt'), self.seg_paths) 20 | self.nclasses = len(self.classes) 21 | self.size = len(self.paths) 22 | self.get_mean_std() 23 | # # modify for network later. 
24 | opt.nclasses = self.nclasses 25 | opt.input_nc = self.ninput_channels 26 | 27 | def __getitem__(self, index): 28 | path = self.paths[index] 29 | mesh = Mesh(file=path, opt=self.opt, hold_history=True, export_folder=self.opt.export_folder) 30 | meta = {} 31 | meta['mesh'] = mesh 32 | label = read_seg(self.seg_paths[index]) - self.offset 33 | label = pad(label, self.opt.ninput_edges, val=-1, dim=0) 34 | meta['label'] = label 35 | soft_label = read_sseg(self.sseg_paths[index]) 36 | meta['soft_label'] = pad(soft_label, self.opt.ninput_edges, val=-1, dim=0) 37 | # get edge features 38 | edge_features = mesh.extract_features() 39 | edge_features = pad(edge_features, self.opt.ninput_edges) 40 | meta['edge_features'] = (edge_features - self.mean) / self.std 41 | return meta 42 | 43 | def __len__(self): 44 | return self.size 45 | 46 | @staticmethod 47 | def get_seg_files(paths, seg_dir, seg_ext='.seg'): 48 | segs = [] 49 | for path in paths: 50 | segfile = os.path.join(seg_dir, os.path.splitext(os.path.basename(path))[0] + seg_ext) 51 | assert(os.path.isfile(segfile)) 52 | segs.append(segfile) 53 | return segs 54 | 55 | @staticmethod 56 | def get_n_segs(classes_file, seg_files): 57 | if not os.path.isfile(classes_file): 58 | all_segs = np.array([], dtype='float64') 59 | for seg in seg_files: 60 | all_segs = np.concatenate((all_segs, read_seg(seg))) 61 | segnames = np.unique(all_segs) 62 | np.savetxt(classes_file, segnames, fmt='%d') 63 | classes = np.loadtxt(classes_file) 64 | offset = classes[0] 65 | classes = classes - offset 66 | return classes, offset 67 | 68 | @staticmethod 69 | def make_dataset(path): 70 | meshes = [] 71 | assert os.path.isdir(path), '%s is not a valid directory' % path 72 | 73 | for root, _, fnames in sorted(os.walk(path)): 74 | for fname in fnames: 75 | if is_mesh_file(fname): 76 | path = os.path.join(root, fname) 77 | meshes.append(path) 78 | 79 | return meshes 80 | 81 | 82 | def read_seg(seg): 83 | seg_labels = np.loadtxt(open(seg, 'r'), 
dtype='float64') 84 | return seg_labels 85 | 86 | 87 | def read_sseg(sseg_file): 88 | sseg_labels = read_seg(sseg_file) 89 | sseg_labels = np.array(sseg_labels > 0, dtype=np.int32) 90 | return sseg_labels -------------------------------------------------------------------------------- /docs/imgs/T18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/T18.png -------------------------------------------------------------------------------- /docs/imgs/T252.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/T252.png -------------------------------------------------------------------------------- /docs/imgs/T76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/T76.png -------------------------------------------------------------------------------- /docs/imgs/alien.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/alien.gif -------------------------------------------------------------------------------- /docs/imgs/coseg_alien.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/coseg_alien.png -------------------------------------------------------------------------------- /docs/imgs/coseg_chair.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/coseg_chair.png -------------------------------------------------------------------------------- /docs/imgs/coseg_vase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/coseg_vase.png -------------------------------------------------------------------------------- /docs/imgs/cubes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/cubes.png -------------------------------------------------------------------------------- /docs/imgs/cubes2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/cubes2.png -------------------------------------------------------------------------------- /docs/imgs/input_edge_features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/input_edge_features.png -------------------------------------------------------------------------------- /docs/imgs/mesh_conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/mesh_conv.png -------------------------------------------------------------------------------- /docs/imgs/mesh_pool_unpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/mesh_pool_unpool.png 
-------------------------------------------------------------------------------- /docs/imgs/meshcnn_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/meshcnn_overview.png -------------------------------------------------------------------------------- /docs/imgs/shrec16_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec16_train.png -------------------------------------------------------------------------------- /docs/imgs/shrec__10_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec__10_0.png -------------------------------------------------------------------------------- /docs/imgs/shrec__14_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec__14_0.png -------------------------------------------------------------------------------- /docs/imgs/shrec__2_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec__2_0.png -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | MeshCNN 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
14 |
15 |

SIGGRAPH 2019

16 |

MeshCNN: A Network with an Edge

17 |
18 | Rana Hanocka1 19 |   20 | Amir Hertz1 21 |   22 | Noa Fish1 23 |   24 | Raja Giryes1 25 |   26 | Shachar Fleishman2 27 |   28 | Daniel Cohen-Or1 29 |
30 | 1Tel Aviv University     2Amazon

31 |
32 |
33 | 34 |
35 |
36 | 37 |
38 |
39 | 40 |
41 | 42 |
43 |
44 | 45 |
46 |
47 | 48 | 49 | 50 |

Paper

51 |
52 |
53 | 54 |
55 |
56 | 57 | 58 | 59 |

Code

60 |
61 |
62 | 63 |
64 |
65 | 66 | 67 | 68 |

Slides

69 |
70 |
71 | 72 |
73 |
74 | 75 |

Bibtex

76 |
77 |
78 |
79 |
80 | 81 |
82 | 83 |
84 |

Abstract

85 | Polygonal meshes provide an efficient representation for 3D shapes. They explicitly capture both shape surface and topology, 86 | and leverage non-uniformity to represent large flat regions as well as sharp, intricate features. This non-uniformity 87 | and irregularity, however, inhibit mesh analysis efforts using neural networks that combine convolution and pooling 88 | operations. In this paper, we utilize the unique properties of the mesh for a direct analysis of 3D shapes using MeshCNN, 89 | a convolutional neural network designed specifically for triangular meshes. Analogous to classic CNNs, MeshCNN combines 90 | specialized convolution and pooling layers that operate on the mesh edges, by leveraging their intrinsic geodesic 91 | connections. Convolutions are applied on edges and the four edges of their incident triangles, and pooling is applied 92 | via an edge collapse operation that retains surface topology, thereby generating new mesh connectivity for the 93 | subsequent convolutions. MeshCNN learns which edges to collapse, thus forming a task-driven process where the network 94 | exposes and expands the important features while discarding the redundant ones. We demonstrate the effectiveness 95 | of our task-driven pooling on various learning tasks applied to 3D meshes. 96 |
97 | 98 |
99 |

Video

100 |
101 |
102 | 103 |
104 |
105 |
106 | 107 |
108 |

The Layers of MeshCNN

109 | In MeshCNN, the edges of a mesh are analogous to pixels in an image, since they are the basic building blocks 110 | for all CNN operations. Just as images start with a basic input feature, an RGB value per pixel, 111 | MeshCNN starts with a few basic geometric features per edge. The input edge feature is a 5-dimensional vector 112 | for every edge: the dihedral angle, two inner angles, and two edge-length ratios (one inner angle and one ratio for each incident face). 113 |
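The five input features of a single edge can be sketched in NumPy. This is an illustrative sketch, not the repository's implementation: the helper names are hypothetical, and it assumes (a) the dihedral angle is pi minus the angle between the two face normals, so a flat pair of faces gives pi, (b) each edge-length ratio is the triangle's height (distance from the opposite vertex to the edge) divided by the edge length, and (c) each pair of per-face values is sorted so the result does not depend on which face is listed first.

```python
import numpy as np


def _angle(u, v):
    """Angle in radians between two 3D vectors (clipped for numerical safety)."""
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))


def edge_features(a, b, c, d):
    """Five features of edge (a, b) shared by faces (a, b, c) and (b, a, d).

    a, b are the edge endpoints; c and d are the vertices opposite the edge.
    Returns [dihedral, inner_1, inner_2, ratio_1, ratio_2].
    """
    a, b, c, d = (np.asarray(p, dtype=float) for p in (a, b, c, d))
    edge = b - a
    length = np.linalg.norm(edge)
    # Normals of the two consistently oriented incident faces.
    n1 = np.cross(b - a, c - a)
    n2 = np.cross(a - b, d - b)
    # Assumed convention: pi minus the angle between normals (flat pair -> pi).
    dihedral = np.pi - _angle(n1, n2)
    # Inner angle at the vertex opposite the edge, one per face, sorted.
    inner = np.sort([_angle(a - c, b - c), _angle(a - d, b - d)])
    # Assumed ratio: triangle height over edge length, one per face, sorted.
    heights = [np.linalg.norm(np.cross(edge, p - a)) / length for p in (c, d)]
    ratios = np.sort(np.array(heights) / length)
    return np.concatenate([[dihedral], inner, ratios])


# Two right triangles lying flat in the xy-plane on either side of the x-axis.
print(edge_features((0, 0, 0), (1, 0, 0), (0, 1, 0), (0, -1, 0)))
```

Because every quantity is an angle or a ratio of lengths, the features do not change under rotation, translation or uniform scaling of the mesh.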
114 | 115 |

Input Edge Features

116 |
117 | 118 | MeshCNN learns features on the edges of the mesh, since in a manifold mesh every (non-boundary) edge is incident to exactly two faces (triangles), 119 | which defines a natural fixed-size convolutional neighborhood of four edges. 120 |
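The fixed four-edge neighborhood follows directly from the face list. The sketch below (a hypothetical helper `edge_neighborhoods` in plain Python) collects, for every undirected edge, the two other edges of each face it belongs to. It assumes a consistently oriented triangle mesh and does not treat boundary edges specially; they simply end up with only two neighbors.

```python
from collections import defaultdict


def edge_neighborhoods(faces):
    """Map each undirected edge to the other edges of its incident faces.

    faces: iterable of (i, j, k) vertex-index triples.
    Edges are returned as sorted vertex pairs. An edge shared by two faces
    gets four neighbours (two per face); a boundary edge gets only two.
    """
    neighbors = defaultdict(list)
    for face in faces:
        edges = [tuple(sorted((face[i], face[(i + 1) % 3]))) for i in range(3)]
        for i, e in enumerate(edges):
            # The remaining two edges of this face, following its winding order.
            neighbors[e].extend([edges[(i + 1) % 3], edges[(i + 2) % 3]])
    return dict(neighbors)


# A tetrahedron is closed, so every edge is shared by exactly two faces.
tetra = [(0, 1, 2), (0, 3, 1), (0, 2, 3), (1, 3, 2)]
nbrs = edge_neighborhoods(tetra)
print(nbrs[(0, 1)])  # [(1, 2), (0, 2), (0, 3), (1, 3)]
```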
121 | 122 |

Mesh Convolution

123 |
124 | Learned convolutional filters are applied to each edge feature vector and its four one-ring neighbors. 125 | The consistent face normal order is used to apply a symmetric convolution operation, which is invariant to the ordering ambiguity of the neighbors; 126 | together with the input features, this yields learned edge features that are invariant to rotations, translations and uniform scale. 127 | Mesh pooling reduces the number of edges (and hence features) in the network by performing edge collapses guided by the learned 128 | edge features. The new edge neighbors are computed dynamically inside the network and used in the subsequent convolutions. 129 |
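A symmetric edge convolution in the spirit of the paragraph above can be sketched with NumPy. Following the MeshCNN paper, the neighbors a, b, c, d of edge e are combined into the order-invariant terms |a - c|, a + c, |b - d| and b + d before a learned linear map is applied. The array layout, the `neighbors` index table and the function name are assumptions for illustration, and boundary edges (which the real network pads) are not handled.

```python
import numpy as np


def mesh_conv(features, neighbors, weight, bias):
    """Symmetric edge convolution (illustrative sketch).

    features:  (C_in, E) array, one feature vector per edge.
    neighbors: (E, 4) int array of neighbor edge indices [a, b, c, d].
    weight:    (C_out, C_in, 5) array; bias: (C_out,) array.
    Returns a (C_out, E) array.
    """
    a, b, c, d = (features[:, neighbors[:, k]] for k in range(4))
    # |a - c|, a + c, |b - d|, b + d are unchanged when (a, b, c, d) is
    # replaced by (c, d, a, b), the ordering ambiguity between the two faces.
    stacked = np.stack(
        [features, np.abs(a - c), a + c, np.abs(b - d), b + d], axis=-1
    )  # (C_in, E, 5)
    # Contract over input channels (i) and the 5 neighborhood terms (k).
    return np.einsum('oik,iek->oe', weight, stacked) + bias[:, None]


rng = np.random.RandomState(0)
feats = rng.randn(5, 6)                  # 5 input channels, 6 edges
nb = rng.randint(0, 6, size=(6, 4))      # toy neighbor table
out = mesh_conv(feats, nb, rng.randn(8, 5, 5), rng.randn(8))
print(out.shape)  # (8, 6)
```

Applying the convolution with the neighbor columns permuted to (c, d, a, b) gives the same output, which is the property the symmetric terms are chosen for.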
130 | 131 |

Mesh Pooling & Unpooling

132 |
133 | For fully convolutional tasks (such as segmentation), a mesh unpooling operation is used to restore the 134 | original mesh resolution. 135 |
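Two ingredients of pooling and unpooling can be sketched in isolation. The first function orders edges by the squared magnitude of their learned features, smallest first, which is how we understand MeshCNN prioritizes collapses. The second restores a feature vector for every original edge from a 0/1 group matrix recording which original edges were merged into each pooled edge, averaging when an edge belongs to several groups. Both are simplified, hypothetical helpers: they do not perform the edge collapse itself or rebuild the mesh connectivity.

```python
import numpy as np


def collapse_order(features):
    """Edge indices sorted by squared feature magnitude, smallest first.

    features: (C, E) array of learned edge features.
    """
    return np.argsort(np.sum(features ** 2, axis=0), kind='stable')


def unpool(pooled_features, groups):
    """Map pooled edge features back to the original edges.

    pooled_features: (C, E_pooled) array.
    groups: (E_pooled, E_orig) 0/1 array; groups[j, i] = 1 if original edge i
            was merged into pooled edge j.
    Each original edge receives the mean of the pooled edges it belongs to.
    """
    occurrences = groups.sum(axis=0, keepdims=True)  # (1, E_orig)
    # Guard against edges that belong to no group (they stay zero).
    return pooled_features @ groups / np.maximum(occurrences, 1)


feats = np.array([[3.0, 0.0, 1.0], [4.0, 0.0, 1.0]])
print(collapse_order(feats))  # [1 2 0]
groups = np.array([[1, 1, 0, 0], [0, 0, 1, 1]])
print(unpool(np.array([[2.0, 5.0]]), groups))  # [[2. 2. 5. 5.]]
```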
136 | 137 |
138 |

Results

139 |
140 | 141 | 142 |

Learned Simplifications on Cube Dataset

143 |
144 | 145 |
146 |
147 | 148 |

Learned Simplifications on Shrec Dataset

149 |
150 | 151 |
152 | 153 | 154 | 155 |

Human Segmentation Results

156 |
157 | 158 |
159 |
160 |
161 | 162 |

Coseg Segmentation Results

163 |
164 | 165 |
166 | 167 |
168 |

Download Datasets

169 |
170 | COSEG segmentation dataset
171 | Human Segmentation dataset
172 | Cubes classification dataset
173 | Shrec classification dataset
174 | See our GitHub page for how to run our code on these datasets. 175 |
176 |
177 | 178 |
179 |

Contact

180 |
181 | Rana at Hanocka dot com 182 |
183 |
184 | 185 | 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /docs/mainpage.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Lato', sans-serif; 3 | font-weight: 300; 4 | color: #333; 5 | font-size: 16px; 6 | } 7 | h1 { 8 | font-size: 40px; 9 | color: #555; 10 | font-weight: 400; 11 | text-align: center; 12 | margin: 0; 13 | padding: 0; 14 | margin-top: 30px; 15 | margin-bottom: 10px; 16 | } 17 | .authors { 18 | color: #222; 19 | font-size: 24px; 20 | font-weight: 300; 21 | text-align: center; 22 | margin: 0; 23 | padding: 0; 24 | margin-bottom: 0px; 25 | } 26 | .logoimg { 27 | text-align: center; 28 | margin-bottom: 30px; 29 | } 30 | .container-fluid { 31 | margin-top: 5px; 32 | margin-bottom: 5px; 33 | } 34 | .container { 35 | margin-top: 10px; 36 | } 37 | #footer { 38 | margin-bottom: 100px; 39 | } 40 | .thumbs { 41 | -webkit-box-shadow: 1px 1px 3px #999; 42 | -moz-box-shadow: 1px 1px 3px #999; 43 | box-shadow: 1px 1px 3px #999; 44 | margin-bottom: 20px; 45 | } 46 | h2 { 47 | font-size: 24px; 48 | font-weight: 900; 49 | border-bottom: 1px solid #999; 50 | margin-bottom: 20px; 51 | } 52 | 53 | 54 | .text-primary { 55 | color: #5da2d5 !important; 56 | } 57 | .text-primary:hover { 58 | color: #f3d250 !important; 59 | opacity: 1.0; 60 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: meshcnn 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.6.8 7 | - cython=0.27.3 8 | - pytorch=1.2.0 9 | - numpy=1.15.0 10 | - matplotlib=3.0.3 11 | - pip 12 | - pip: 13 | - git+https://github.com/lanpa/tensorboardX.git 14 | - pytest==5.1.1 15 | -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | def create_model(opt): 2 | from .mesh_classifier import ClassifierModel # todo - get rid of this ? 3 | model = ClassifierModel(opt) 4 | return model 5 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/models/layers/__init__.py -------------------------------------------------------------------------------- /models/layers/mesh.py: -------------------------------------------------------------------------------- 1 | from tempfile import mkstemp 2 | from shutil import move 3 | import torch 4 | import numpy as np 5 | import os 6 | from models.layers.mesh_union import MeshUnion 7 | from models.layers.mesh_prepare import fill_mesh 8 | 9 | 10 | class Mesh: 11 | 12 | def __init__(self, file=None, opt=None, hold_history=False, export_folder=''): 13 | self.vs = self.v_mask = self.filename = self.features = self.edge_areas = None 14 | self.edges = self.gemm_edges = self.sides = None 15 | self.pool_count = 0 16 | fill_mesh(self, file, opt) 17 | self.export_folder = export_folder 18 | self.history_data = None 19 | if hold_history: 20 | self.init_history() 21 | self.export() 22 | 23 | def extract_features(self): 24 | return self.features 25 | 26 | def merge_vertices(self, edge_id): 27 | self.remove_edge(edge_id) 28 | edge = self.edges[edge_id] 29 | v_a = self.vs[edge[0]] 30 | v_b = self.vs[edge[1]] 31 | # update pA 32 | v_a.__iadd__(v_b) 33 | v_a.__itruediv__(2) 34 | self.v_mask[edge[1]] = False 35 | mask = self.edges == edge[1] 36 | self.ve[edge[0]].extend(self.ve[edge[1]]) 37 | self.edges[mask] = edge[0] 38 | 39 | def remove_vertex(self, v): 40 | self.v_mask[v] = False 41 | 42 | def remove_edge(self, edge_id): 43 | vs = self.edges[edge_id] 44 | for v in vs: 45 | 
if edge_id not in self.ve[v]: 46 | print(self.ve[v]) 47 | print(self.filename) 48 | self.ve[v].remove(edge_id) 49 | 50 | def clean(self, edges_mask, groups): 51 | edges_mask = edges_mask.astype(bool) 52 | torch_mask = torch.from_numpy(edges_mask.copy()) 53 | self.gemm_edges = self.gemm_edges[edges_mask] 54 | self.edges = self.edges[edges_mask] 55 | self.sides = self.sides[edges_mask] 56 | new_ve = [] 57 | edges_mask = np.concatenate([edges_mask, [False]]) 58 | new_indices = np.zeros(edges_mask.shape[0], dtype=np.int32) 59 | new_indices[-1] = -1 60 | new_indices[edges_mask] = np.arange(0, np.ma.where(edges_mask)[0].shape[0]) 61 | self.gemm_edges[:, :] = new_indices[self.gemm_edges[:, :]] 62 | for v_index, ve in enumerate(self.ve): 63 | update_ve = [] 64 | # if self.v_mask[v_index]: 65 | for e in ve: 66 | update_ve.append(new_indices[e]) 67 | new_ve.append(update_ve) 68 | self.ve = new_ve 69 | self.__clean_history(groups, torch_mask) 70 | self.pool_count += 1 71 | self.export() 72 | 73 | 74 | def export(self, file=None, vcolor=None): 75 | if file is None: 76 | if self.export_folder: 77 | filename, file_extension = os.path.splitext(self.filename) 78 | file = '%s/%s_%d%s' % (self.export_folder, filename, self.pool_count, file_extension) 79 | else: 80 | return 81 | faces = [] 82 | vs = self.vs[self.v_mask] 83 | gemm = np.array(self.gemm_edges) 84 | new_indices = np.zeros(self.v_mask.shape[0], dtype=np.int32) 85 | new_indices[self.v_mask] = np.arange(0, np.ma.where(self.v_mask)[0].shape[0]) 86 | for edge_index in range(len(gemm)): 87 | cycles = self.__get_cycle(gemm, edge_index) 88 | for cycle in cycles: 89 | faces.append(self.__cycle_to_face(cycle, new_indices)) 90 | with open(file, 'w+') as f: 91 | for vi, v in enumerate(vs): 92 | vcol = ' %f %f %f' % (vcolor[vi, 0], vcolor[vi, 1], vcolor[vi, 2]) if vcolor is not None else '' 93 | f.write("v %f %f %f%s\n" % (v[0], v[1], v[2], vcol)) 94 | for face_id in range(len(faces) - 1): 95 | f.write("f %d %d %d\n" % 
(faces[face_id][0] + 1, faces[face_id][1] + 1, faces[face_id][2] + 1)) 96 | f.write("f %d %d %d" % (faces[-1][0] + 1, faces[-1][1] + 1, faces[-1][2] + 1)) 97 | for edge in self.edges: 98 | f.write("\ne %d %d" % (new_indices[edge[0]] + 1, new_indices[edge[1]] + 1)) 99 | 100 | def export_segments(self, segments): 101 | if not self.export_folder: 102 | return 103 | cur_segments = segments 104 | for i in range(self.pool_count + 1): 105 | filename, file_extension = os.path.splitext(self.filename) 106 | file = '%s/%s_%d%s' % (self.export_folder, filename, i, file_extension) 107 | fh, abs_path = mkstemp() 108 | edge_key = 0 109 | with os.fdopen(fh, 'w') as new_file: 110 | with open(file) as old_file: 111 | for line in old_file: 112 | if line[0] == 'e': 113 | new_file.write('%s %d' % (line.strip(), cur_segments[edge_key])) 114 | if edge_key < len(cur_segments): 115 | edge_key += 1 116 | new_file.write('\n') 117 | else: 118 | new_file.write(line) 119 | os.remove(file) 120 | move(abs_path, file) 121 | if i < len(self.history_data['edges_mask']): 122 | cur_segments = segments[:len(self.history_data['edges_mask'][i])] 123 | cur_segments = cur_segments[self.history_data['edges_mask'][i]] 124 | 125 | def __get_cycle(self, gemm, edge_id): 126 | cycles = [] 127 | for j in range(2): 128 | next_side = start_point = j * 2 129 | next_key = edge_id 130 | if gemm[edge_id, start_point] == -1: 131 | continue 132 | cycles.append([]) 133 | for i in range(3): 134 | tmp_next_key = gemm[next_key, next_side] 135 | tmp_next_side = self.sides[next_key, next_side] 136 | tmp_next_side = tmp_next_side + 1 - 2 * (tmp_next_side % 2) 137 | gemm[next_key, next_side] = -1 138 | gemm[next_key, next_side + 1 - 2 * (next_side % 2)] = -1 139 | next_key = tmp_next_key 140 | next_side = tmp_next_side 141 | cycles[-1].append(next_key) 142 | return cycles 143 | 144 | def __cycle_to_face(self, cycle, v_indices): 145 | face = [] 146 | for i in range(3): 147 | v = list(set(self.edges[cycle[i]]) & 
set(self.edges[cycle[(i + 1) % 3]]))[0] 148 | face.append(v_indices[v]) 149 | return face 150 | 151 | def init_history(self): 152 | self.history_data = { 153 | 'groups': [], 154 | 'gemm_edges': [self.gemm_edges.copy()], 155 | 'occurrences': [], 156 | 'old2current': np.arange(self.edges_count, dtype=np.int32), 157 | 'current2old': np.arange(self.edges_count, dtype=np.int32), 158 | 'edges_mask': [torch.ones(self.edges_count,dtype=torch.bool)], 159 | 'edges_count': [self.edges_count], 160 | } 161 | if self.export_folder: 162 | self.history_data['collapses'] = MeshUnion(self.edges_count) 163 | 164 | def union_groups(self, source, target): 165 | if self.export_folder and self.history_data: 166 | self.history_data['collapses'].union(self.history_data['current2old'][source], self.history_data['current2old'][target]) 167 | return 168 | 169 | def remove_group(self, index): 170 | if self.history_data is not None: 171 | self.history_data['edges_mask'][-1][self.history_data['current2old'][index]] = 0 172 | self.history_data['old2current'][self.history_data['current2old'][index]] = -1 173 | if self.export_folder: 174 | self.history_data['collapses'].remove_group(self.history_data['current2old'][index]) 175 | 176 | def get_groups(self): 177 | return self.history_data['groups'].pop() 178 | 179 | def get_occurrences(self): 180 | return self.history_data['occurrences'].pop() 181 | 182 | def __clean_history(self, groups, pool_mask): 183 | if self.history_data is not None: 184 | mask = self.history_data['old2current'] != -1 185 | self.history_data['old2current'][mask] = np.arange(self.edges_count, dtype=np.int32) 186 | self.history_data['current2old'][0: self.edges_count] = np.ma.where(mask)[0] 187 | if self.export_folder != '': 188 | self.history_data['edges_mask'].append(self.history_data['edges_mask'][-1].clone()) 189 | self.history_data['occurrences'].append(groups.get_occurrences()) 190 | self.history_data['groups'].append(groups.get_groups(pool_mask)) 191 | 
self.history_data['gemm_edges'].append(self.gemm_edges.copy()) 192 | self.history_data['edges_count'].append(self.edges_count) 193 | 194 | def unroll_gemm(self): 195 | self.history_data['gemm_edges'].pop() 196 | self.gemm_edges = self.history_data['gemm_edges'][-1] 197 | self.history_data['edges_count'].pop() 198 | self.edges_count = self.history_data['edges_count'][-1] 199 | 200 | def get_edge_areas(self): 201 | return self.edge_areas 202 | -------------------------------------------------------------------------------- /models/layers/mesh_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MeshConv(nn.Module): 6 | """ Computes convolution between edges and 4 incident (1-ring) edge neighbors 7 | in the forward pass takes: 8 | x: edge features (Batch x Features x Edges) 9 | mesh: list of mesh data-structure (len(mesh) == Batch) 10 | and applies convolution 11 | """ 12 | def __init__(self, in_channels, out_channels, k=5, bias=True): 13 | super(MeshConv, self).__init__() 14 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, k), bias=bias) 15 | self.k = k 16 | 17 | def __call__(self, edge_f, mesh): 18 | return self.forward(edge_f, mesh) 19 | 20 | def forward(self, x, mesh): 21 | x = x.squeeze(-1) 22 | G = torch.cat([self.pad_gemm(i, x.shape[2], x.device) for i in mesh], 0) 23 | # build 'neighborhood image' and apply convolution 24 | G = self.create_GeMM(x, G) 25 | x = self.conv(G) 26 | return x 27 | 28 | def flatten_gemm_inds(self, Gi): 29 | (b, ne, nn) = Gi.shape 30 | ne += 1 31 | batch_n = torch.floor(torch.arange(b * ne, device=Gi.device).float() / ne).view(b, ne) 32 | add_fac = batch_n * ne 33 | add_fac = add_fac.view(b, ne, 1) 34 | add_fac = add_fac.repeat(1, 1, nn) 35 | # flatten Gi 36 | Gi = Gi.float() + add_fac[:, 1:, :] 37 | return Gi 38 | 39 | def create_GeMM(self, x, Gi): 40 | """ gathers 
the edge features (x) from the 1-ring indices (Gi) 41 | applies symmetric functions to handle order invariance 42 | returns a 'fake image' to which a 2d convolution can be applied 43 | output dimensions: Batch x Channels x Edges x 5 44 | """ 45 | Gishape = Gi.shape 46 | # pad the first row of every sample in batch with zeros 47 | padding = torch.zeros((x.shape[0], x.shape[1], 1), requires_grad=True, device=x.device) 48 | # padding = padding.to(x.device) 49 | x = torch.cat((padding, x), dim=2) 50 | Gi = Gi + 1 # shift by one so that missing neighbors (-1) index the zero-padding row 51 | 52 | # first flatten indices 53 | Gi_flat = self.flatten_gemm_inds(Gi) 54 | Gi_flat = Gi_flat.view(-1).long() 55 | # 56 | odim = x.shape 57 | x = x.permute(0, 2, 1).contiguous() 58 | x = x.view(odim[0] * odim[2], odim[1]) 59 | 60 | f = torch.index_select(x, dim=0, index=Gi_flat) 61 | f = f.view(Gishape[0], Gishape[1], Gishape[2], -1) 62 | f = f.permute(0, 3, 1, 2) 63 | 64 | # apply the symmetric functions for an equivariant conv 65 | x_1 = f[:, :, :, 1] + f[:, :, :, 3] 66 | x_2 = f[:, :, :, 2] + f[:, :, :, 4] 67 | x_3 = torch.abs(f[:, :, :, 1] - f[:, :, :, 3]) 68 | x_4 = torch.abs(f[:, :, :, 2] - f[:, :, :, 4]) 69 | f = torch.stack([f[:, :, :, 0], x_1, x_2, x_3, x_4], dim=3) 70 | return f 71 | 72 | def pad_gemm(self, m, xsz, device): 73 | """ extracts one-ring neighbors (4x) -> m.gemm_edges 74 | which is of size #edges x 4 75 | add the edge_id itself to make #edges x 5 76 | then pad to desired size e.g., xsz x 5 77 | """ 78 | padded_gemm = torch.tensor(m.gemm_edges, device=device).float() 79 | padded_gemm = padded_gemm.requires_grad_() 80 | padded_gemm = torch.cat((torch.arange(m.edges_count, device=device).float().unsqueeze(1), padded_gemm), dim=1) 81 | # pad using F 82 | padded_gemm = F.pad(padded_gemm, (0, 0, 0, xsz - m.edges_count), "constant", 0) 83 | padded_gemm = padded_gemm.unsqueeze(0) 84 | return padded_gemm 85 | -------------------------------------------------------------------------------- /models/layers/mesh_pool.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from threading import Thread 4 | from models.layers.mesh_union import MeshUnion 5 | import numpy as np 6 | from heapq import heappop, heapify 7 | 8 | 9 | class MeshPool(nn.Module): 10 | 11 | def __init__(self, target, multi_thread=False): 12 | super(MeshPool, self).__init__() 13 | self.__out_target = target 14 | self.__multi_thread = multi_thread 15 | self.__fe = None 16 | self.__updated_fe = None 17 | self.__meshes = None 18 | self.__merge_edges = [-1, -1] 19 | 20 | def __call__(self, fe, meshes): 21 | return self.forward(fe, meshes) 22 | 23 | def forward(self, fe, meshes): 24 | self.__updated_fe = [[] for _ in range(len(meshes))] 25 | pool_threads = [] 26 | self.__fe = fe 27 | self.__meshes = meshes 28 | # iterate over batch 29 | for mesh_index in range(len(meshes)): 30 | if self.__multi_thread: 31 | pool_threads.append(Thread(target=self.__pool_main, args=(mesh_index,))) 32 | pool_threads[-1].start() 33 | else: 34 | self.__pool_main(mesh_index) 35 | if self.__multi_thread: 36 | for mesh_index in range(len(meshes)): 37 | pool_threads[mesh_index].join() 38 | out_features = torch.cat(self.__updated_fe).view(len(meshes), -1, self.__out_target) 39 | return out_features 40 | 41 | def __pool_main(self, mesh_index): 42 | mesh = self.__meshes[mesh_index] 43 | queue = self.__build_queue(self.__fe[mesh_index, :, :mesh.edges_count], mesh.edges_count) 44 | # recycle = [] 45 | # last_queue_len = len(queue) 46 | last_count = mesh.edges_count + 1 47 | mask = np.ones(mesh.edges_count, dtype=bool) # builtin bool: the np.bool alias was removed in NumPy 1.24 48 | edge_groups = MeshUnion(mesh.edges_count, self.__fe.device) 49 | while mesh.edges_count > self.__out_target: 50 | value, edge_id = heappop(queue) 51 | edge_id = int(edge_id) 52 | if mask[edge_id]: 53 | self.__pool_edge(mesh, edge_id, mask, edge_groups) 54 | mesh.clean(mask, edge_groups) 55 | fe = edge_groups.rebuild_features(self.__fe[mesh_index], mask,
self.__out_target) 56 | self.__updated_fe[mesh_index] = fe 57 | 58 | def __pool_edge(self, mesh, edge_id, mask, edge_groups): 59 | if self.has_boundaries(mesh, edge_id): 60 | return False 61 | elif self.__clean_side(mesh, edge_id, mask, edge_groups, 0)\ 62 | and self.__clean_side(mesh, edge_id, mask, edge_groups, 2) \ 63 | and self.__is_one_ring_valid(mesh, edge_id): 64 | self.__merge_edges[0] = self.__pool_side(mesh, edge_id, mask, edge_groups, 0) 65 | self.__merge_edges[1] = self.__pool_side(mesh, edge_id, mask, edge_groups, 2) 66 | mesh.merge_vertices(edge_id) 67 | mask[edge_id] = False 68 | MeshPool.__remove_group(mesh, edge_groups, edge_id) 69 | mesh.edges_count -= 1 70 | return True 71 | else: 72 | return False 73 | 74 | def __clean_side(self, mesh, edge_id, mask, edge_groups, side): 75 | if mesh.edges_count <= self.__out_target: 76 | return False 77 | invalid_edges = MeshPool.__get_invalids(mesh, edge_id, edge_groups, side) 78 | while len(invalid_edges) != 0 and mesh.edges_count > self.__out_target: 79 | self.__remove_triplete(mesh, mask, edge_groups, invalid_edges) 80 | if mesh.edges_count <= self.__out_target: 81 | return False 82 | if self.has_boundaries(mesh, edge_id): 83 | return False 84 | invalid_edges = self.__get_invalids(mesh, edge_id, edge_groups, side) 85 | return True 86 | 87 | @staticmethod 88 | def has_boundaries(mesh, edge_id): 89 | for edge in mesh.gemm_edges[edge_id]: 90 | if edge == -1 or -1 in mesh.gemm_edges[edge]: 91 | return True 92 | return False 93 | 94 | 95 | @staticmethod 96 | def __is_one_ring_valid(mesh, edge_id): 97 | v_a = set(mesh.edges[mesh.ve[mesh.edges[edge_id, 0]]].reshape(-1)) 98 | v_b = set(mesh.edges[mesh.ve[mesh.edges[edge_id, 1]]].reshape(-1)) 99 | shared = v_a & v_b - set(mesh.edges[edge_id]) 100 | return len(shared) == 2 101 | 102 | def __pool_side(self, mesh, edge_id, mask, edge_groups, side): 103 | info = MeshPool.__get_face_info(mesh, edge_id, side) 104 | key_a, key_b, side_a, side_b, _, other_side_b, _, 
other_keys_b = info 105 | self.__redirect_edges(mesh, key_a, side_a - side_a % 2, other_keys_b[0], mesh.sides[key_b, other_side_b]) 106 | self.__redirect_edges(mesh, key_a, side_a - side_a % 2 + 1, other_keys_b[1], mesh.sides[key_b, other_side_b + 1]) 107 | MeshPool.__union_groups(mesh, edge_groups, key_b, key_a) 108 | MeshPool.__union_groups(mesh, edge_groups, edge_id, key_a) 109 | mask[key_b] = False 110 | MeshPool.__remove_group(mesh, edge_groups, key_b) 111 | mesh.remove_edge(key_b) 112 | mesh.edges_count -= 1 113 | return key_a 114 | 115 | @staticmethod 116 | def __get_invalids(mesh, edge_id, edge_groups, side): 117 | info = MeshPool.__get_face_info(mesh, edge_id, side) 118 | key_a, key_b, side_a, side_b, other_side_a, other_side_b, other_keys_a, other_keys_b = info 119 | shared_items = MeshPool.__get_shared_items(other_keys_a, other_keys_b) 120 | if len(shared_items) == 0: 121 | return [] 122 | else: 123 | assert (len(shared_items) == 2) 124 | middle_edge = other_keys_a[shared_items[0]] 125 | update_key_a = other_keys_a[1 - shared_items[0]] 126 | update_key_b = other_keys_b[1 - shared_items[1]] 127 | update_side_a = mesh.sides[key_a, other_side_a + 1 - shared_items[0]] 128 | update_side_b = mesh.sides[key_b, other_side_b + 1 - shared_items[1]] 129 | MeshPool.__redirect_edges(mesh, edge_id, side, update_key_a, update_side_a) 130 | MeshPool.__redirect_edges(mesh, edge_id, side + 1, update_key_b, update_side_b) 131 | MeshPool.__redirect_edges(mesh, update_key_a, MeshPool.__get_other_side(update_side_a), update_key_b, MeshPool.__get_other_side(update_side_b)) 132 | MeshPool.__union_groups(mesh, edge_groups, key_a, edge_id) 133 | MeshPool.__union_groups(mesh, edge_groups, key_b, edge_id) 134 | MeshPool.__union_groups(mesh, edge_groups, key_a, update_key_a) 135 | MeshPool.__union_groups(mesh, edge_groups, middle_edge, update_key_a) 136 | MeshPool.__union_groups(mesh, edge_groups, key_b, update_key_b) 137 | MeshPool.__union_groups(mesh, edge_groups, middle_edge, 
update_key_b) 138 | return [key_a, key_b, middle_edge] 139 | 140 | @staticmethod 141 | def __redirect_edges(mesh, edge_a_key, side_a, edge_b_key, side_b): 142 | mesh.gemm_edges[edge_a_key, side_a] = edge_b_key 143 | mesh.gemm_edges[edge_b_key, side_b] = edge_a_key 144 | mesh.sides[edge_a_key, side_a] = side_b 145 | mesh.sides[edge_b_key, side_b] = side_a 146 | 147 | @staticmethod 148 | def __get_shared_items(list_a, list_b): 149 | shared_items = [] 150 | for i in range(len(list_a)): 151 | for j in range(len(list_b)): 152 | if list_a[i] == list_b[j]: 153 | shared_items.extend([i, j]) 154 | return shared_items 155 | 156 | @staticmethod 157 | def __get_other_side(side): 158 | return side + 1 - 2 * (side % 2) 159 | 160 | @staticmethod 161 | def __get_face_info(mesh, edge_id, side): 162 | key_a = mesh.gemm_edges[edge_id, side] 163 | key_b = mesh.gemm_edges[edge_id, side + 1] 164 | side_a = mesh.sides[edge_id, side] 165 | side_b = mesh.sides[edge_id, side + 1] 166 | other_side_a = (side_a - (side_a % 2) + 2) % 4 167 | other_side_b = (side_b - (side_b % 2) + 2) % 4 168 | other_keys_a = [mesh.gemm_edges[key_a, other_side_a], mesh.gemm_edges[key_a, other_side_a + 1]] 169 | other_keys_b = [mesh.gemm_edges[key_b, other_side_b], mesh.gemm_edges[key_b, other_side_b + 1]] 170 | return key_a, key_b, side_a, side_b, other_side_a, other_side_b, other_keys_a, other_keys_b 171 | 172 | @staticmethod 173 | def __remove_triplete(mesh, mask, edge_groups, invalid_edges): 174 | vertex = set(mesh.edges[invalid_edges[0]]) 175 | for edge_key in invalid_edges: 176 | vertex &= set(mesh.edges[edge_key]) 177 | mask[edge_key] = False 178 | MeshPool.__remove_group(mesh, edge_groups, edge_key) 179 | mesh.edges_count -= 3 180 | vertex = list(vertex) 181 | assert(len(vertex) == 1) 182 | mesh.remove_vertex(vertex[0]) 183 | 184 | def __build_queue(self, features, edges_count): 185 | # delete edges with smallest norm 186 | squared_magnitude = torch.sum(features * features, 0) 187 | if 
squared_magnitude.shape[-1] != 1: 188 | squared_magnitude = squared_magnitude.unsqueeze(-1) 189 | edge_ids = torch.arange(edges_count, device=squared_magnitude.device, dtype=torch.float32).unsqueeze(-1) 190 | heap = torch.cat((squared_magnitude, edge_ids), dim=-1).tolist() 191 | heapify(heap) 192 | return heap 193 | 194 | @staticmethod 195 | def __union_groups(mesh, edge_groups, source, target): 196 | edge_groups.union(source, target) 197 | mesh.union_groups(source, target) 198 | 199 | @staticmethod 200 | def __remove_group(mesh, edge_groups, index): 201 | edge_groups.remove_group(index) 202 | mesh.remove_group(index) 203 | 204 | -------------------------------------------------------------------------------- /models/layers/mesh_prepare.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | 5 | 6 | def fill_mesh(mesh2fill, file: str, opt): 7 | load_path = get_mesh_path(file, opt.num_aug) 8 | if os.path.exists(load_path): 9 | mesh_data = np.load(load_path, encoding='latin1', allow_pickle=True) 10 | else: 11 | mesh_data = from_scratch(file, opt) 12 | np.savez_compressed(load_path, gemm_edges=mesh_data.gemm_edges, vs=mesh_data.vs, edges=mesh_data.edges, 13 | edges_count=mesh_data.edges_count, ve=mesh_data.ve, v_mask=mesh_data.v_mask, 14 | filename=mesh_data.filename, sides=mesh_data.sides, 15 | edge_lengths=mesh_data.edge_lengths, edge_areas=mesh_data.edge_areas, 16 | features=mesh_data.features) 17 | mesh2fill.vs = mesh_data['vs'] 18 | mesh2fill.edges = mesh_data['edges'] 19 | mesh2fill.gemm_edges = mesh_data['gemm_edges'] 20 | mesh2fill.edges_count = int(mesh_data['edges_count']) 21 | mesh2fill.ve = mesh_data['ve'] 22 | mesh2fill.v_mask = mesh_data['v_mask'] 23 | mesh2fill.filename = str(mesh_data['filename']) 24 | mesh2fill.edge_lengths = mesh_data['edge_lengths'] 25 | mesh2fill.edge_areas = mesh_data['edge_areas'] 26 | mesh2fill.features = mesh_data['features'] 27 | 
mesh2fill.sides = mesh_data['sides'] 28 | 29 | def get_mesh_path(file: str, num_aug: int): 30 | filename, _ = os.path.splitext(file) 31 | dir_name = os.path.dirname(filename) 32 | prefix = os.path.basename(filename) 33 | load_dir = os.path.join(dir_name, 'cache') 34 | load_file = os.path.join(load_dir, '%s_%03d.npz' % (prefix, np.random.randint(0, num_aug))) 35 | if not os.path.isdir(load_dir): 36 | os.makedirs(load_dir, exist_ok=True) 37 | return load_file 38 | 39 | def from_scratch(file, opt): 40 | 41 | class MeshPrep: 42 | def __getitem__(self, item): 43 | return getattr(self, item) # dict-style access to attributes, e.g. mesh_data['vs'] 44 | 45 | mesh_data = MeshPrep() 46 | mesh_data.vs = mesh_data.edges = None 47 | mesh_data.gemm_edges = mesh_data.sides = None 48 | mesh_data.edges_count = None 49 | mesh_data.ve = None 50 | mesh_data.v_mask = None 51 | mesh_data.filename = 'unknown' 52 | mesh_data.edge_lengths = None 53 | mesh_data.edge_areas = [] 54 | mesh_data.vs, faces = fill_from_file(mesh_data, file) 55 | mesh_data.v_mask = np.ones(len(mesh_data.vs), dtype=bool) 56 | faces, face_areas = remove_non_manifolds(mesh_data, faces) 57 | if opt.num_aug > 1: 58 | faces = augmentation(mesh_data, opt, faces) 59 | build_gemm(mesh_data, faces, face_areas) 60 | if opt.num_aug > 1: 61 | post_augmentation(mesh_data, opt) 62 | mesh_data.features = extract_features(mesh_data) 63 | return mesh_data 64 | 65 | def fill_from_file(mesh, file): 66 | mesh.filename = ntpath.split(file)[1] 67 | mesh.fullfilename = file 68 | vs, faces = [], [] 69 | f = open(file) 70 | for line in f: 71 | line = line.strip() 72 | splitted_line = line.split() 73 | if not splitted_line: 74 | continue 75 | elif splitted_line[0] == 'v': 76 | vs.append([float(v) for v in splitted_line[1:4]]) 77 | elif splitted_line[0] == 'f': 78 | face_vertex_ids = [int(c.split('/')[0]) for c in splitted_line[1:]] 79 | assert len(face_vertex_ids) == 3 80 | face_vertex_ids = [(ind - 1) if (ind >= 0) else (len(vs) + ind) 81 | for ind in face_vertex_ids] 82 |
faces.append(face_vertex_ids) 83 | f.close() 84 | vs = np.asarray(vs) 85 | faces = np.asarray(faces, dtype=int) 86 | assert np.logical_and(faces >= 0, faces < len(vs)).all() 87 | return vs, faces 88 | 89 | 90 | def remove_non_manifolds(mesh, faces): 91 | mesh.ve = [[] for _ in mesh.vs] 92 | edges_set = set() 93 | mask = np.ones(len(faces), dtype=bool) 94 | _, face_areas = compute_face_normals_and_areas(mesh, faces) 95 | for face_id, face in enumerate(faces): 96 | if face_areas[face_id] == 0: 97 | mask[face_id] = False 98 | continue 99 | faces_edges = [] 100 | is_non_manifold = False 101 | for i in range(3): 102 | cur_edge = (face[i], face[(i + 1) % 3]) 103 | if cur_edge in edges_set: # directed edge already used by another face -> non-manifold 104 | is_non_manifold = True 105 | break 106 | else: 107 | faces_edges.append(cur_edge) 108 | if is_non_manifold: 109 | mask[face_id] = False 110 | else: 111 | for idx, edge in enumerate(faces_edges): 112 | edges_set.add(edge) 113 | return faces[mask], face_areas[mask] 114 | 115 | 116 | def build_gemm(mesh, faces, face_areas): 117 | """ 118 | gemm_edges: array (#E x 4) of the 4 one-ring neighbors for each edge 119 | sides: array (#E x 4) indices (values of: 0,1,2,3) indicating where an edge is in the gemm_edges entry of the 4 neighboring edges 120 | for example edge i -> gemm_edges[gemm_edges[i], sides[i]] == [i, i, i, i] 121 | """ 122 | mesh.ve = [[] for _ in mesh.vs] 123 | edge_nb = [] 124 | sides = [] 125 | edge2key = dict() 126 | edges = [] 127 | edges_count = 0 128 | nb_count = [] 129 | for face_id, face in enumerate(faces): 130 | faces_edges = [] 131 | for i in range(3): 132 | cur_edge = (face[i], face[(i + 1) % 3]) 133 | faces_edges.append(cur_edge) 134 | for idx, edge in enumerate(faces_edges): 135 | edge = tuple(sorted(list(edge))) 136 | faces_edges[idx] = edge 137 | if edge not in edge2key: 138 | edge2key[edge] = edges_count 139 | edges.append(list(edge)) 140 | edge_nb.append([-1, -1, -1, -1]) 141 | sides.append([-1, -1, -1, -1]) 142 | mesh.ve[edge[0]].append(edges_count) 143 |
mesh.ve[edge[1]].append(edges_count) 144 | mesh.edge_areas.append(0) 145 | nb_count.append(0) 146 | edges_count += 1 147 | mesh.edge_areas[edge2key[edge]] += face_areas[face_id] / 3 148 | for idx, edge in enumerate(faces_edges): 149 | edge_key = edge2key[edge] 150 | edge_nb[edge_key][nb_count[edge_key]] = edge2key[faces_edges[(idx + 1) % 3]] 151 | edge_nb[edge_key][nb_count[edge_key] + 1] = edge2key[faces_edges[(idx + 2) % 3]] 152 | nb_count[edge_key] += 2 153 | for idx, edge in enumerate(faces_edges): 154 | edge_key = edge2key[edge] 155 | sides[edge_key][nb_count[edge_key] - 2] = nb_count[edge2key[faces_edges[(idx + 1) % 3]]] - 1 156 | sides[edge_key][nb_count[edge_key] - 1] = nb_count[edge2key[faces_edges[(idx + 2) % 3]]] - 2 157 | mesh.edges = np.array(edges, dtype=np.int32) 158 | mesh.gemm_edges = np.array(edge_nb, dtype=np.int64) 159 | mesh.sides = np.array(sides, dtype=np.int64) 160 | mesh.edges_count = edges_count 161 | mesh.edge_areas = np.array(mesh.edge_areas, dtype=np.float32) / np.sum(face_areas) # todo: what's the difference between edge_areas and edge_lengths?
162 | 163 | 164 | def compute_face_normals_and_areas(mesh, faces): 165 | face_normals = np.cross(mesh.vs[faces[:, 1]] - mesh.vs[faces[:, 0]], 166 | mesh.vs[faces[:, 2]] - mesh.vs[faces[:, 1]]) 167 | face_areas = np.sqrt((face_normals ** 2).sum(axis=1)) 168 | face_normals /= face_areas[:, np.newaxis] 169 | assert (not np.any(face_areas[:, np.newaxis] == 0)), 'has zero area face: %s' % mesh.filename 170 | face_areas *= 0.5 171 | return face_normals, face_areas 172 | 173 | 174 | # Data augmentation methods 175 | def augmentation(mesh, opt, faces=None): 176 | if hasattr(opt, 'scale_verts') and opt.scale_verts: 177 | scale_verts(mesh) 178 | if hasattr(opt, 'flip_edges') and opt.flip_edges: 179 | faces = flip_edges(mesh, opt.flip_edges, faces) 180 | return faces 181 | 182 | 183 | def post_augmentation(mesh, opt): 184 | if hasattr(opt, 'slide_verts') and opt.slide_verts: 185 | slide_verts(mesh, opt.slide_verts) 186 | 187 | 188 | def slide_verts(mesh, prct): 189 | edge_points = get_edge_points(mesh) 190 | dihedral = dihedral_angle(mesh, edge_points).squeeze() #todo make fixed_division epsilon=0 191 | thr = np.mean(dihedral) + np.std(dihedral) 192 | vids = np.random.permutation(len(mesh.ve)) 193 | target = int(prct * len(vids)) 194 | shifted = 0 195 | for vi in vids: 196 | if shifted < target: 197 | edges = mesh.ve[vi] 198 | if min(dihedral[edges]) > 2.65: 199 | edge = mesh.edges[np.random.choice(edges)] 200 | vi_t = edge[1] if vi == edge[0] else edge[0] 201 | nv = mesh.vs[vi] + np.random.uniform(0.2, 0.5) * (mesh.vs[vi_t] - mesh.vs[vi]) 202 | mesh.vs[vi] = nv 203 | shifted += 1 204 | else: 205 | break 206 | mesh.shifted = shifted / len(mesh.ve) 207 | 208 | 209 | def scale_verts(mesh, mean=1, var=0.1): 210 | for i in range(mesh.vs.shape[1]): 211 | mesh.vs[:, i] = mesh.vs[:, i] * np.random.normal(mean, var) 212 | 213 | 214 | def angles_from_faces(mesh, edge_faces, faces): 215 | normals = [None, None] 216 | for i in range(2): 217 | edge_a = mesh.vs[faces[edge_faces[:, i], 2]] 
- mesh.vs[faces[edge_faces[:, i], 1]] 218 | edge_b = mesh.vs[faces[edge_faces[:, i], 1]] - mesh.vs[faces[edge_faces[:, i], 0]] 219 | normals[i] = np.cross(edge_a, edge_b) 220 | div = fixed_division(np.linalg.norm(normals[i], ord=2, axis=1), epsilon=0) 221 | normals[i] /= div[:, np.newaxis] 222 | dot = np.sum(normals[0] * normals[1], axis=1).clip(-1, 1) 223 | angles = np.pi - np.arccos(dot) 224 | return angles 225 | 226 | 227 | def flip_edges(mesh, prct, faces): 228 | edge_count, edge_faces, edges_dict = get_edge_faces(faces) 229 | dihedral = angles_from_faces(mesh, edge_faces[:, 2:], faces) 230 | edges2flip = np.random.permutation(edge_count) 231 | # print(dihedral.min()) 232 | # print(dihedral.max()) 233 | target = int(prct * edge_count) 234 | flipped = 0 235 | for edge_key in edges2flip: 236 | if flipped == target: 237 | break 238 | if dihedral[edge_key] > 2.7: 239 | edge_info = edge_faces[edge_key] 240 | if edge_info[3] == -1: 241 | continue 242 | new_edge = tuple(sorted(list(set(faces[edge_info[2]]) ^ set(faces[edge_info[3]])))) 243 | if new_edge in edges_dict: 244 | continue 245 | new_faces = np.array( 246 | [[edge_info[1], new_edge[0], new_edge[1]], [edge_info[0], new_edge[0], new_edge[1]]]) 247 | if check_area(mesh, new_faces): 248 | del edges_dict[(edge_info[0], edge_info[1])] 249 | edge_info[:2] = [new_edge[0], new_edge[1]] 250 | edges_dict[new_edge] = edge_key 251 | rebuild_face(faces[edge_info[2]], new_faces[0]) 252 | rebuild_face(faces[edge_info[3]], new_faces[1]) 253 | for i, face_id in enumerate([edge_info[2], edge_info[3]]): 254 | cur_face = faces[face_id] 255 | for j in range(3): 256 | cur_edge = tuple(sorted((cur_face[j], cur_face[(j + 1) % 3]))) 257 | if cur_edge != new_edge: 258 | cur_edge_key = edges_dict[cur_edge] 259 | for idx, face_nb in enumerate( 260 | [edge_faces[cur_edge_key, 2], edge_faces[cur_edge_key, 3]]): 261 | if face_nb == edge_info[2 + (i + 1) % 2]: 262 | edge_faces[cur_edge_key, 2 + idx] = face_id 263 | flipped += 1 264 | # 
print(flipped) 265 | return faces 266 | 267 | 268 | def rebuild_face(face, new_face): 269 | new_point = list(set(new_face) - set(face))[0] 270 | for i in range(3): 271 | if face[i] not in new_face: 272 | face[i] = new_point 273 | break 274 | return face 275 | 276 | def check_area(mesh, faces): 277 | face_normals = np.cross(mesh.vs[faces[:, 1]] - mesh.vs[faces[:, 0]], 278 | mesh.vs[faces[:, 2]] - mesh.vs[faces[:, 1]]) 279 | face_areas = np.sqrt((face_normals ** 2).sum(axis=1)) 280 | face_areas *= 0.5 281 | return face_areas[0] > 0 and face_areas[1] > 0 282 | 283 | 284 | def get_edge_faces(faces): 285 | edge_count = 0 286 | edge_faces = [] 287 | edge2keys = dict() 288 | for face_id, face in enumerate(faces): 289 | for i in range(3): 290 | cur_edge = tuple(sorted((face[i], face[(i + 1) % 3]))) 291 | if cur_edge not in edge2keys: 292 | edge2keys[cur_edge] = edge_count 293 | edge_count += 1 294 | edge_faces.append(np.array([cur_edge[0], cur_edge[1], -1, -1])) 295 | edge_key = edge2keys[cur_edge] 296 | if edge_faces[edge_key][2] == -1: 297 | edge_faces[edge_key][2] = face_id 298 | else: 299 | edge_faces[edge_key][3] = face_id 300 | return edge_count, np.array(edge_faces), edge2keys 301 | 302 | 303 | def set_edge_lengths(mesh, edge_points=None): 304 | if edge_points is None: # compute the edge points only when the caller did not supply them 305 | edge_points = get_edge_points(mesh) 306 | edge_lengths = np.linalg.norm(mesh.vs[edge_points[:, 0]] - mesh.vs[edge_points[:, 1]], ord=2, axis=1) 307 | mesh.edge_lengths = edge_lengths 308 | 309 | 310 | def extract_features(mesh): 311 | features = [] 312 | edge_points = get_edge_points(mesh) 313 | set_edge_lengths(mesh, edge_points) 314 | with np.errstate(divide='raise'): 315 | try: 316 | for extractor in [dihedral_angle, symmetric_opposite_angles, symmetric_ratios]: 317 | feature = extractor(mesh, edge_points) 318 | features.append(feature) 319 | return np.concatenate(features, axis=0) 320 | except Exception as e: 321 | print(e) 322 | raise ValueError(mesh.filename, 'bad features') 323 | 324 |
| 325 | def dihedral_angle(mesh, edge_points): 326 | normals_a = get_normals(mesh, edge_points, 0) 327 | normals_b = get_normals(mesh, edge_points, 3) 328 | dot = np.sum(normals_a * normals_b, axis=1).clip(-1, 1) 329 | angles = np.expand_dims(np.pi - np.arccos(dot), axis=0) 330 | return angles 331 | 332 | 333 | def symmetric_opposite_angles(mesh, edge_points): 334 | """ computes two angles: one for each of the two faces sharing the edge 335 | each is the angle in that face opposite the edge 336 | sorting removes the ambiguity in face order 337 | """ 338 | angles_a = get_opposite_angles(mesh, edge_points, 0) 339 | angles_b = get_opposite_angles(mesh, edge_points, 3) 340 | angles = np.concatenate((np.expand_dims(angles_a, 0), np.expand_dims(angles_b, 0)), axis=0) 341 | angles = np.sort(angles, axis=0) 342 | return angles 343 | 344 | 345 | def symmetric_ratios(mesh, edge_points): 346 | """ computes two ratios: one for each of the two faces sharing the edge 347 | each is the ratio height / base (the edge) of that triangle 348 | sorting removes the ambiguity in face order 349 | """ 350 | ratios_a = get_ratios(mesh, edge_points, 0) 351 | ratios_b = get_ratios(mesh, edge_points, 3) 352 | ratios = np.concatenate((np.expand_dims(ratios_a, 0), np.expand_dims(ratios_b, 0)), axis=0) 353 | return np.sort(ratios, axis=0) 354 | 355 | 356 | def get_edge_points(mesh): 357 | """ returns: edge_points (#E x 4) array, with four vertex ids per edge 358 | for example: edge_points[edge_id, 0] and edge_points[edge_id, 1] are the two vertices which define edge_id 359 | each adjacent face to edge_id has another vertex, which is edge_points[edge_id, 2] or edge_points[edge_id, 3] 360 | """ 361 | edge_points = np.zeros([mesh.edges_count, 4], dtype=np.int32) 362 | for edge_id, edge in enumerate(mesh.edges): 363 | edge_points[edge_id] = get_side_points(mesh, edge_id) 364 | # edge_points[edge_id, 3:] = mesh.get_side_points(edge_id, 2) 365 | return edge_points 366 | 367 | 368 | def get_side_points(mesh, edge_id): 369 | # if
mesh.gemm_edges[edge_id, side] == -1: 370 | # return mesh.get_side_points(edge_id, ((side + 2) % 4)) 371 | # else: 372 | edge_a = mesh.edges[edge_id] 373 | 374 | if mesh.gemm_edges[edge_id, 0] == -1: 375 | edge_b = mesh.edges[mesh.gemm_edges[edge_id, 2]] 376 | edge_c = mesh.edges[mesh.gemm_edges[edge_id, 3]] 377 | else: 378 | edge_b = mesh.edges[mesh.gemm_edges[edge_id, 0]] 379 | edge_c = mesh.edges[mesh.gemm_edges[edge_id, 1]] 380 | if mesh.gemm_edges[edge_id, 2] == -1: 381 | edge_d = mesh.edges[mesh.gemm_edges[edge_id, 0]] 382 | edge_e = mesh.edges[mesh.gemm_edges[edge_id, 1]] 383 | else: 384 | edge_d = mesh.edges[mesh.gemm_edges[edge_id, 2]] 385 | edge_e = mesh.edges[mesh.gemm_edges[edge_id, 3]] 386 | first_vertex = 0 387 | second_vertex = 0 388 | third_vertex = 0 389 | if edge_a[1] in edge_b: 390 | first_vertex = 1 391 | if edge_b[1] in edge_c: 392 | second_vertex = 1 393 | if edge_d[1] in edge_e: 394 | third_vertex = 1 395 | return [edge_a[first_vertex], edge_a[1 - first_vertex], edge_b[second_vertex], edge_d[third_vertex]] 396 | 397 | 398 | def get_normals(mesh, edge_points, side): 399 | edge_a = mesh.vs[edge_points[:, side // 2 + 2]] - mesh.vs[edge_points[:, side // 2]] 400 | edge_b = mesh.vs[edge_points[:, 1 - side // 2]] - mesh.vs[edge_points[:, side // 2]] 401 | normals = np.cross(edge_a, edge_b) 402 | div = fixed_division(np.linalg.norm(normals, ord=2, axis=1), epsilon=0.1) 403 | normals /= div[:, np.newaxis] 404 | return normals 405 | 406 | def get_opposite_angles(mesh, edge_points, side): 407 | edges_a = mesh.vs[edge_points[:, side // 2]] - mesh.vs[edge_points[:, side // 2 + 2]] 408 | edges_b = mesh.vs[edge_points[:, 1 - side // 2]] - mesh.vs[edge_points[:, side // 2 + 2]] 409 | 410 | edges_a /= fixed_division(np.linalg.norm(edges_a, ord=2, axis=1), epsilon=0.1)[:, np.newaxis] 411 | edges_b /= fixed_division(np.linalg.norm(edges_b, ord=2, axis=1), epsilon=0.1)[:, np.newaxis] 412 | dot = np.sum(edges_a * edges_b, axis=1).clip(-1, 1) 413 | return 
np.arccos(dot) 414 | 415 | 416 | def get_ratios(mesh, edge_points, side): 417 | edges_lengths = np.linalg.norm(mesh.vs[edge_points[:, side // 2]] - mesh.vs[edge_points[:, 1 - side // 2]], 418 | ord=2, axis=1) 419 | point_o = mesh.vs[edge_points[:, side // 2 + 2]] 420 | point_a = mesh.vs[edge_points[:, side // 2]] 421 | point_b = mesh.vs[edge_points[:, 1 - side // 2]] 422 | line_ab = point_b - point_a 423 | projection_length = np.sum(line_ab * (point_o - point_a), axis=1) / fixed_division( 424 | np.linalg.norm(line_ab, ord=2, axis=1), epsilon=0.1) 425 | closest_point = point_a + (projection_length / edges_lengths)[:, np.newaxis] * line_ab 426 | d = np.linalg.norm(point_o - closest_point, ord=2, axis=1) 427 | return d / edges_lengths 428 | 429 | def fixed_division(to_div, epsilon): 430 | if epsilon == 0: 431 | to_div[to_div == 0] = 0.1 432 | else: 433 | to_div += epsilon 434 | return to_div 435 | -------------------------------------------------------------------------------- /models/layers/mesh_union.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import ConstantPad2d 3 | 4 | 5 | class MeshUnion: 6 | def __init__(self, n, device=torch.device('cpu')): 7 | self.__size = n 8 | self.rebuild_features = self.rebuild_features_average 9 | self.groups = torch.eye(n, device=device) 10 | 11 | def union(self, source, target): 12 | self.groups[target, :] += self.groups[source, :] 13 | 14 | def remove_group(self, index): 15 | return 16 | 17 | def get_group(self, edge_key): 18 | return self.groups[edge_key, :] 19 | 20 | def get_occurrences(self): 21 | return torch.sum(self.groups, 0) 22 | 23 | def get_groups(self, tensor_mask): 24 | self.groups = torch.clamp(self.groups, 0, 1) 25 | return self.groups[tensor_mask, :] 26 | 27 | def rebuild_features_average(self, features, mask, target_edges): 28 | self.prepare_groups(features, mask) 29 | fe = torch.matmul(features.squeeze(-1), self.groups) 30 | occurrences = 
torch.sum(self.groups, 0).expand(fe.shape) 31 | fe = fe / occurrences 32 | padding_b = target_edges - fe.shape[1] 33 | if padding_b > 0: 34 | padding_b = ConstantPad2d((0, padding_b, 0, 0), 0) 35 | fe = padding_b(fe) 36 | return fe 37 | 38 | def prepare_groups(self, features, mask): 39 | tensor_mask = torch.from_numpy(mask) 40 | self.groups = torch.clamp(self.groups[tensor_mask, :], 0, 1).transpose_(1, 0) 41 | padding_a = features.shape[1] - self.groups.shape[0] 42 | if padding_a > 0: 43 | padding_a = ConstantPad2d((0, 0, 0, padding_a), 0) 44 | self.groups = padding_a(self.groups) 45 | -------------------------------------------------------------------------------- /models/layers/mesh_unpool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class MeshUnpool(nn.Module): 7 | def __init__(self, unroll_target): 8 | super(MeshUnpool, self).__init__() 9 | self.unroll_target = unroll_target 10 | 11 | def __call__(self, features, meshes): 12 | return self.forward(features, meshes) 13 | 14 | def pad_groups(self, group, unroll_start): 15 | start, end = group.shape 16 | padding_rows = unroll_start - start 17 | padding_cols = self.unroll_target - end 18 | if padding_rows != 0 or padding_cols !=0: 19 | padding = nn.ConstantPad2d((0, padding_cols, 0, padding_rows), 0) 20 | group = padding(group) 21 | return group 22 | 23 | def pad_occurrences(self, occurrences): 24 | padding = self.unroll_target - occurrences.shape[0] 25 | if padding != 0: 26 | padding = nn.ConstantPad1d((0, padding), 1) 27 | occurrences = padding(occurrences) 28 | return occurrences 29 | 30 | def forward(self, features, meshes): 31 | batch_size, nf, edges = features.shape 32 | groups = [self.pad_groups(mesh.get_groups(), edges) for mesh in meshes] 33 | unroll_mat = torch.cat(groups, dim=0).view(batch_size, edges, -1) 34 | occurrences = [self.pad_occurrences(mesh.get_occurrences()) for mesh in meshes] 35 | occurrences = 
torch.cat(occurrences, dim=0).view(batch_size, 1, -1) 36 | occurrences = occurrences.expand(unroll_mat.shape) 37 | unroll_mat = unroll_mat / occurrences 38 | unroll_mat = unroll_mat.to(features.device) 39 | for mesh in meshes: 40 | mesh.unroll_gemm() 41 | return torch.matmul(features, unroll_mat) 42 | -------------------------------------------------------------------------------- /models/mesh_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import networks 3 | from os.path import join 4 | from util.util import seg_accuracy, print_network 5 | 6 | 7 | class ClassifierModel: 8 | """ Class for training model weights 9 | 10 | :args opt: structure containing configuration params 11 | e.g., 12 | --dataset_mode -> classification / segmentation 13 | --arch -> network type 14 | """ 15 | def __init__(self, opt): 16 | self.opt = opt 17 | self.gpu_ids = opt.gpu_ids 18 | self.is_train = opt.is_train 19 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 20 | self.save_dir = join(opt.checkpoints_dir, opt.name) 21 | self.optimizer = None 22 | self.edge_features = None 23 | self.labels = None 24 | self.mesh = None 25 | self.soft_label = None 26 | self.loss = None 27 | 28 | # 29 | self.nclasses = opt.nclasses 30 | 31 | # load/define networks 32 | self.net = networks.define_classifier(opt.input_nc, opt.ncf, opt.ninput_edges, opt.nclasses, opt, 33 | self.gpu_ids, opt.arch, opt.init_type, opt.init_gain) 34 | self.net.train(self.is_train) 35 | self.criterion = networks.define_loss(opt).to(self.device) 36 | 37 | if self.is_train: 38 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 39 | self.scheduler = networks.get_scheduler(self.optimizer, opt) 40 | print_network(self.net) 41 | 42 | if not self.is_train or opt.continue_train: 43 | self.load_network(opt.which_epoch) 44 | 45 | def set_input(self, data): 46 |
input_edge_features = torch.from_numpy(data['edge_features']).float() 47 | labels = torch.from_numpy(data['label']).long() 48 | # set inputs 49 | self.edge_features = input_edge_features.to(self.device).requires_grad_(self.is_train) 50 | self.labels = labels.to(self.device) 51 | self.mesh = data['mesh'] 52 | if self.opt.dataset_mode == 'segmentation' and not self.is_train: 53 | self.soft_label = torch.from_numpy(data['soft_label']) 54 | 55 | 56 | def forward(self): 57 | out = self.net(self.edge_features, self.mesh) 58 | return out 59 | 60 | def backward(self, out): 61 | self.loss = self.criterion(out, self.labels) 62 | self.loss.backward() 63 | 64 | def optimize_parameters(self): 65 | self.optimizer.zero_grad() 66 | out = self.forward() 67 | self.backward(out) 68 | self.optimizer.step() 69 | 70 | 71 | ################## 72 | 73 | def load_network(self, which_epoch): 74 | """load model from disk""" 75 | save_filename = '%s_net.pth' % which_epoch 76 | load_path = join(self.save_dir, save_filename) 77 | net = self.net 78 | if isinstance(net, torch.nn.DataParallel): 79 | net = net.module 80 | print('loading the model from %s' % load_path) 81 | # if you are using PyTorch newer than 0.4 (e.g., built from 82 | # GitHub source), you can remove str() on self.device 83 | state_dict = torch.load(load_path, map_location=str(self.device)) 84 | if hasattr(state_dict, '_metadata'): 85 | del state_dict._metadata 86 | net.load_state_dict(state_dict) 87 | 88 | 89 | def save_network(self, which_epoch): 90 | """save model to disk""" 91 | save_filename = '%s_net.pth' % (which_epoch) 92 | save_path = join(self.save_dir, save_filename) 93 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 94 | torch.save(self.net.module.cpu().state_dict(), save_path) 95 | self.net.cuda(self.gpu_ids[0]) 96 | else: 97 | torch.save(self.net.cpu().state_dict(), save_path) 98 | 99 | def update_learning_rate(self): 100 | """update learning rate (called once every epoch)""" 101 | self.scheduler.step() 102 | lr =
self.optimizer.param_groups[0]['lr'] 103 | print('learning rate = %.7f' % lr) 104 | 105 | def test(self): 106 | """tests model 107 | returns: number correct and total number 108 | """ 109 | with torch.no_grad(): 110 | out = self.forward() 111 | # compute number of correct 112 | pred_class = out.data.max(1)[1] 113 | label_class = self.labels 114 | self.export_segmentation(pred_class.cpu()) 115 | correct = self.get_accuracy(pred_class, label_class) 116 | return correct, len(label_class) 117 | 118 | def get_accuracy(self, pred, labels): 119 | """computes accuracy for classification / segmentation """ 120 | if self.opt.dataset_mode == 'classification': 121 | correct = pred.eq(labels).sum() 122 | elif self.opt.dataset_mode == 'segmentation': 123 | correct = seg_accuracy(pred, self.soft_label, self.mesh) 124 | return correct 125 | 126 | def export_segmentation(self, pred_seg): 127 | if self.opt.dataset_mode == 'segmentation': 128 | for meshi, mesh in enumerate(self.mesh): 129 | mesh.export_segments(pred_seg[meshi, :]) 130 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | from models.layers.mesh_conv import MeshConv 7 | import torch.nn.functional as F 8 | from models.layers.mesh_pool import MeshPool 9 | from models.layers.mesh_unpool import MeshUnpool 10 | 11 | 12 | ############################################################################### 13 | # Helper Functions 14 | ############################################################################### 15 | 16 | 17 | def get_norm_layer(norm_type='instance', num_groups=1): 18 | if norm_type == 'batch': 19 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 20 | elif norm_type == 'instance': 21 | norm_layer = functools.partial(nn.InstanceNorm2d, 
affine=False) 22 | elif norm_type == 'group': 23 | norm_layer = functools.partial(nn.GroupNorm, affine=True, num_groups=num_groups) 24 | elif norm_type == 'none': 25 | norm_layer = NoNorm 26 | else: 27 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 28 | return norm_layer 29 | 30 | def get_norm_args(norm_layer, nfeats_list): 31 | if hasattr(norm_layer, '__name__') and norm_layer.__name__ == 'NoNorm': 32 | norm_args = [{'fake': True} for f in nfeats_list] 33 | elif norm_layer.func.__name__ == 'GroupNorm': 34 | norm_args = [{'num_channels': f} for f in nfeats_list] 35 | elif norm_layer.func.__name__ in ('BatchNorm2d', 'InstanceNorm2d'): 36 | norm_args = [{'num_features': f} for f in nfeats_list] 37 | else: 38 | raise NotImplementedError('normalization layer [%s] is not found' % norm_layer.func.__name__) 39 | return norm_args 40 | 41 | class NoNorm(nn.Module): #todo with abstractclass and pass 42 | def __init__(self, fake=True): 43 | super(NoNorm, self).__init__() 44 | self.fake = fake 45 | def forward(self, x): 46 | return x 47 | def __call__(self, x): 48 | return self.forward(x) 49 | 50 | def get_scheduler(optimizer, opt): 51 | if opt.lr_policy == 'lambda': 52 | def lambda_rule(epoch): 53 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 54 | return lr_l 55 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 56 | elif opt.lr_policy == 'step': 57 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 58 | elif opt.lr_policy == 'plateau': 59 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 60 | else: 61 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 62 | return scheduler 63 | 64 | 65 | def init_weights(net, init_type, init_gain): 66 | def init_func(m): 67 | classname = m.__class__.__name__ 68 | if hasattr(m, 'weight') and 
(classname.find('Conv') != -1 or
classname.find('Linear') != -1): 69 | if init_type == 'normal': 70 | init.normal_(m.weight.data, 0.0, init_gain) 71 | elif init_type == 'xavier': 72 | init.xavier_normal_(m.weight.data, gain=init_gain) 73 | elif init_type == 'kaiming': 74 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 75 | elif init_type == 'orthogonal': 76 | init.orthogonal_(m.weight.data, gain=init_gain) 77 | else: 78 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 79 | elif classname.find('BatchNorm2d') != -1: 80 | init.normal_(m.weight.data, 1.0, init_gain) 81 | init.constant_(m.bias.data, 0.0) 82 | net.apply(init_func) 83 | 84 | 85 | def init_net(net, init_type, init_gain, gpu_ids): 86 | if len(gpu_ids) > 0: 87 | assert(torch.cuda.is_available()) 88 | net.cuda(gpu_ids[0]) 89 | net = net.cuda() 90 | net = torch.nn.DataParallel(net, gpu_ids) 91 | if init_type != 'none': 92 | init_weights(net, init_type, init_gain) 93 | return net 94 | 95 | 96 | def define_classifier(input_nc, ncf, ninput_edges, nclasses, opt, gpu_ids, arch, init_type, init_gain): 97 | net = None 98 | norm_layer = get_norm_layer(norm_type=opt.norm, num_groups=opt.num_groups) 99 | 100 | if arch == 'mconvnet': 101 | net = MeshConvNet(norm_layer, input_nc, ncf, nclasses, ninput_edges, opt.pool_res, opt.fc_n, 102 | opt.resblocks) 103 | elif arch == 'meshunet': 104 | down_convs = [input_nc] + ncf 105 | up_convs = ncf[::-1] + [nclasses] 106 | pool_res = [ninput_edges] + opt.pool_res 107 | net = MeshEncoderDecoder(pool_res, down_convs, up_convs, blocks=opt.resblocks, 108 | transfer_data=True) 109 | else: 110 | raise NotImplementedError('Encoder model name [%s] is not recognized' % arch) 111 | return init_net(net, init_type, init_gain, gpu_ids) 112 | 113 | def define_loss(opt): 114 | if opt.dataset_mode == 'classification': 115 | loss = torch.nn.CrossEntropyLoss() 116 | elif opt.dataset_mode == 'segmentation': 117 | loss = torch.nn.CrossEntropyLoss(ignore_index=-1) 118 | return loss 119 
| 120 | ############################################################################## 121 | # Classes For Classification / Segmentation Networks 122 | ############################################################################## 123 | 124 | class MeshConvNet(nn.Module): 125 | """Network for learning a global shape descriptor (classification) 126 | """ 127 | def __init__(self, norm_layer, nf0, conv_res, nclasses, input_res, pool_res, fc_n, 128 | nresblocks=3): 129 | super(MeshConvNet, self).__init__() 130 | self.k = [nf0] + conv_res 131 | self.res = [input_res] + pool_res 132 | norm_args = get_norm_args(norm_layer, self.k[1:]) 133 | 134 | for i, ki in enumerate(self.k[:-1]): 135 | setattr(self, 'conv{}'.format(i), MResConv(ki, self.k[i + 1], nresblocks)) 136 | setattr(self, 'norm{}'.format(i), norm_layer(**norm_args[i])) 137 | setattr(self, 'pool{}'.format(i), MeshPool(self.res[i + 1])) 138 | 139 | 140 | self.gp = torch.nn.AvgPool1d(self.res[-1]) 141 | # self.gp = torch.nn.MaxPool1d(self.res[-1]) 142 | self.fc1 = nn.Linear(self.k[-1], fc_n) 143 | self.fc2 = nn.Linear(fc_n, nclasses) 144 | 145 | def forward(self, x, mesh): 146 | 147 | for i in range(len(self.k) - 1): 148 | x = getattr(self, 'conv{}'.format(i))(x, mesh) 149 | x = F.relu(getattr(self, 'norm{}'.format(i))(x)) 150 | x = getattr(self, 'pool{}'.format(i))(x, mesh) 151 | 152 | x = self.gp(x) 153 | x = x.view(-1, self.k[-1]) 154 | 155 | x = F.relu(self.fc1(x)) 156 | x = self.fc2(x) 157 | return x 158 | 159 | class MResConv(nn.Module): 160 | def __init__(self, in_channels, out_channels, skips=1): 161 | super(MResConv, self).__init__() 162 | self.in_channels = in_channels 163 | self.out_channels = out_channels 164 | self.skips = skips 165 | self.conv0 = MeshConv(self.in_channels, self.out_channels, bias=False) 166 | for i in range(self.skips): 167 | setattr(self, 'bn{}'.format(i + 1), nn.BatchNorm2d(self.out_channels)) 168 | setattr(self, 'conv{}'.format(i + 1), 169 | MeshConv(self.out_channels, 
self.out_channels, bias=False)) 170 | 171 | def forward(self, x, mesh): 172 | x = self.conv0(x, mesh) 173 | x1 = x 174 | for i in range(self.skips): 175 | x = getattr(self, 'bn{}'.format(i + 1))(F.relu(x)) 176 | x = getattr(self, 'conv{}'.format(i + 1))(x, mesh) 177 | x += x1 178 | x = F.relu(x) 179 | return x 180 | 181 | 182 | class MeshEncoderDecoder(nn.Module): 183 | """Network for fully-convolutional tasks (segmentation) 184 | """ 185 | def __init__(self, pools, down_convs, up_convs, blocks=0, transfer_data=True): 186 | super(MeshEncoderDecoder, self).__init__() 187 | self.transfer_data = transfer_data 188 | self.encoder = MeshEncoder(pools, down_convs, blocks=blocks) 189 | unrolls = pools[:-1].copy() 190 | unrolls.reverse() 191 | self.decoder = MeshDecoder(unrolls, up_convs, blocks=blocks, transfer_data=transfer_data) 192 | 193 | def forward(self, x, meshes): 194 | fe, before_pool = self.encoder((x, meshes)) 195 | fe = self.decoder((fe, meshes), before_pool) 196 | return fe 197 | 198 | def __call__(self, x, meshes): 199 | return self.forward(x, meshes) 200 | 201 | class DownConv(nn.Module): 202 | def __init__(self, in_channels, out_channels, blocks=0, pool=0): 203 | super(DownConv, self).__init__() 204 | self.bn = [] 205 | self.pool = None 206 | self.conv1 = MeshConv(in_channels, out_channels) 207 | self.conv2 = [] 208 | for _ in range(blocks): 209 | self.conv2.append(MeshConv(out_channels, out_channels)) 210 | self.conv2 = nn.ModuleList(self.conv2) 211 | for _ in range(blocks + 1): 212 | self.bn.append(nn.InstanceNorm2d(out_channels)) 213 | self.bn = nn.ModuleList(self.bn) 214 | if pool: 215 | self.pool = MeshPool(pool) 216 | 217 | def __call__(self, x): 218 | return self.forward(x) 219 | 220 | def forward(self, x): 221 | fe, meshes = x 222 | x1 = self.conv1(fe, meshes) 223 | if self.bn: 224 | x1 = self.bn[0](x1) 225 | x1 = F.relu(x1) 226 | x2 = x1 227 | for idx, conv in enumerate(self.conv2): 228 | x2 = conv(x1, meshes) 229 | if self.bn: 230 | x2 = 
self.bn[idx + 1](x2) 231 | x2 = x2 + x1 232 | x2 = F.relu(x2) 233 | x1 = x2 234 | x2 = x2.squeeze(3) 235 | before_pool = None 236 | if self.pool: 237 | before_pool = x2 238 | x2 = self.pool(x2, meshes) 239 | return x2, before_pool 240 | 241 | 242 | class UpConv(nn.Module): 243 | def __init__(self, in_channels, out_channels, blocks=0, unroll=0, residual=True, 244 | batch_norm=True, transfer_data=True): 245 | super(UpConv, self).__init__() 246 | self.residual = residual 247 | self.bn = [] 248 | self.unroll = None 249 | self.transfer_data = transfer_data 250 | self.up_conv = MeshConv(in_channels, out_channels) 251 | if transfer_data: 252 | self.conv1 = MeshConv(2 * out_channels, out_channels) 253 | else: 254 | self.conv1 = MeshConv(out_channels, out_channels) 255 | self.conv2 = [] 256 | for _ in range(blocks): 257 | self.conv2.append(MeshConv(out_channels, out_channels)) 258 | self.conv2 = nn.ModuleList(self.conv2) 259 | if batch_norm: 260 | for _ in range(blocks + 1): 261 | self.bn.append(nn.InstanceNorm2d(out_channels)) 262 | self.bn = nn.ModuleList(self.bn) 263 | if unroll: 264 | self.unroll = MeshUnpool(unroll) 265 | 266 | def __call__(self, x, from_down=None): 267 | return self.forward(x, from_down) 268 | 269 | def forward(self, x, from_down): 270 | from_up, meshes = x 271 | x1 = self.up_conv(from_up, meshes).squeeze(3) 272 | if self.unroll: 273 | x1 = self.unroll(x1, meshes) 274 | if self.transfer_data: 275 | x1 = torch.cat((x1, from_down), 1) 276 | x1 = self.conv1(x1, meshes) 277 | if self.bn: 278 | x1 = self.bn[0](x1) 279 | x1 = F.relu(x1) 280 | x2 = x1 281 | for idx, conv in enumerate(self.conv2): 282 | x2 = conv(x1, meshes) 283 | if self.bn: 284 | x2 = self.bn[idx + 1](x2) 285 | if self.residual: 286 | x2 = x2 + x1 287 | x2 = F.relu(x2) 288 | x1 = x2 289 | x2 = x2.squeeze(3) 290 | return x2 291 | 292 | 293 | class MeshEncoder(nn.Module): 294 | def __init__(self, pools, convs, fcs=None, blocks=0, global_pool=None): 295 | super(MeshEncoder, self).__init__() 
296 | self.fcs = None 297 | self.convs = [] 298 | for i in range(len(convs) - 1): 299 | if i + 1 < len(pools): 300 | pool = pools[i + 1] 301 | else: 302 | pool = 0 303 | self.convs.append(DownConv(convs[i], convs[i + 1], blocks=blocks, pool=pool)) 304 | self.global_pool = None 305 | if fcs is not None: 306 | self.fcs = [] 307 | self.fcs_bn = [] 308 | last_length = convs[-1] 309 | if global_pool is not None: 310 | if global_pool == 'max': 311 | self.global_pool = nn.MaxPool1d(pools[-1]) 312 | elif global_pool == 'avg': 313 | self.global_pool = nn.AvgPool1d(pools[-1]) 314 | else: 315 | assert False, 'global_pool %s is not defined' % global_pool 316 | else: 317 | last_length *= pools[-1] 318 | if fcs[0] == last_length: 319 | fcs = fcs[1:] 320 | for length in fcs: 321 | self.fcs.append(nn.Linear(last_length, length)) 322 | self.fcs_bn.append(nn.InstanceNorm1d(length)) 323 | last_length = length 324 | self.fcs = nn.ModuleList(self.fcs) 325 | self.fcs_bn = nn.ModuleList(self.fcs_bn) 326 | self.convs = nn.ModuleList(self.convs) 327 | reset_params(self) 328 | 329 | def forward(self, x): 330 | fe, meshes = x 331 | encoder_outs = [] 332 | for conv in self.convs: 333 | fe, before_pool = conv((fe, meshes)) 334 | encoder_outs.append(before_pool) 335 | if self.fcs is not None: 336 | if self.global_pool is not None: 337 | fe = self.global_pool(fe) 338 | fe = fe.contiguous().view(fe.size()[0], -1) 339 | for i in range(len(self.fcs)): 340 | fe = self.fcs[i](fe) 341 | if self.fcs_bn: 342 | x = fe.unsqueeze(1) 343 | fe = self.fcs_bn[i](x).squeeze(1) 344 | if i < len(self.fcs) - 1: 345 | fe = F.relu(fe) 346 | return fe, encoder_outs 347 | 348 | def __call__(self, x): 349 | return self.forward(x) 350 | 351 | 352 | class MeshDecoder(nn.Module): 353 | def __init__(self, unrolls, convs, blocks=0, batch_norm=True, transfer_data=True): 354 | super(MeshDecoder, self).__init__() 355 | self.up_convs = [] 356 | for i in range(len(convs) - 2): 357 | if i < len(unrolls): 358 | unroll = unrolls[i] 
359 | else: 360 | unroll = 0 361 | self.up_convs.append(UpConv(convs[i], convs[i + 1], blocks=blocks, unroll=unroll, 362 | batch_norm=batch_norm, transfer_data=transfer_data)) 363 | self.final_conv = UpConv(convs[-2], convs[-1], blocks=blocks, unroll=False, 364 | batch_norm=batch_norm, transfer_data=False) 365 | self.up_convs = nn.ModuleList(self.up_convs) 366 | reset_params(self) 367 | 368 | def forward(self, x, encoder_outs=None): 369 | fe, meshes = x 370 | for i, up_conv in enumerate(self.up_convs): 371 | before_pool = None 372 | if encoder_outs is not None: 373 | before_pool = encoder_outs[-(i+2)] 374 | fe = up_conv((fe, meshes), before_pool) 375 | fe = self.final_conv((fe, meshes)) 376 | return fe 377 | 378 | def __call__(self, x, encoder_outs=None): 379 | return self.forward(x, encoder_outs) 380 | 381 | def reset_params(model): # todo replace with my init 382 | for i, m in enumerate(model.modules()): 383 | weight_init(m) 384 | 385 | def weight_init(m): 386 | if isinstance(m, nn.Conv2d): 387 | nn.init.xavier_normal_(m.weight) 388 | nn.init.constant_(m.bias, 0) 389 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | 7 | class BaseOptions: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | self.initialized = False 11 | 12 | def initialize(self): 13 | # data params 14 | self.parser.add_argument('--dataroot', required=True, help='path to meshes (should have subfolders train, 
test)') 15 | self.parser.add_argument('--dataset_mode', choices={"classification", "segmentation"}, default='classification') 16 | self.parser.add_argument('--ninput_edges', type=int, default=750, help='# of input edges (will include dummy edges)') 17 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples per epoch') 18 | # network params 19 | self.parser.add_argument('--batch_size', type=int, default=16, help='input batch size') 20 | self.parser.add_argument('--arch', type=str, default='mconvnet', help='selects network to use') #todo add choices 21 | self.parser.add_argument('--resblocks', type=int, default=0, help='# of res blocks') 22 | self.parser.add_argument('--fc_n', type=int, default=100, help='# between fc and nclasses') #todo make generic 23 | self.parser.add_argument('--ncf', nargs='+', default=[16, 32, 32], type=int, help='conv filters') 24 | self.parser.add_argument('--pool_res', nargs='+', default=[1140, 780, 580], type=int, help='pooling res') 25 | self.parser.add_argument('--norm', type=str, default='batch',help='instance normalization or batch normalization or group normalization') 26 | self.parser.add_argument('--num_groups', type=int, default=16, help='# of groups for groupnorm') 27 | self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 28 | self.parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 29 | # general params 30 | self.parser.add_argument('--num_threads', default=3, type=int, help='# threads for loading data') 31 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 32 | self.parser.add_argument('--name', type=str, default='debug', help='name of the experiment. 
It decides where to store samples and models') 33 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 34 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes meshes in order, otherwise takes them randomly') 35 | self.parser.add_argument('--seed', type=int, help='if specified, uses seed') 36 | # visualization params 37 | self.parser.add_argument('--export_folder', type=str, default='', help='exports intermediate collapses to this folder') 38 | # 39 | self.initialized = True 40 | 41 | def parse(self): 42 | if not self.initialized: 43 | self.initialize() 44 | self.opt, unknown = self.parser.parse_known_args() 45 | self.opt.is_train = self.is_train # train or test 46 | 47 | str_ids = self.opt.gpu_ids.split(',') 48 | self.opt.gpu_ids = [] 49 | for str_id in str_ids: 50 | id = int(str_id) 51 | if id >= 0: 52 | self.opt.gpu_ids.append(id) 53 | # set gpu ids 54 | if len(self.opt.gpu_ids) > 0: 55 | torch.cuda.set_device(self.opt.gpu_ids[0]) 56 | 57 | args = vars(self.opt) 58 | 59 | if self.opt.seed is not None: 60 | import numpy as np 61 | import random 62 | torch.manual_seed(self.opt.seed) 63 | np.random.seed(self.opt.seed) 64 | random.seed(self.opt.seed) 65 | 66 | if self.opt.export_folder: 67 | self.opt.export_folder = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.export_folder) 68 | util.mkdir(self.opt.export_folder) 69 | 70 | if self.is_train: 71 | print('------------ Options -------------') 72 | for k, v in sorted(args.items()): 73 | print('%s: %s' % (str(k), str(v))) 74 | print('-------------- End ----------------') 75 | 76 | # save to the disk 77 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 78 | util.mkdir(expr_dir) 79 | 80 | file_name = os.path.join(expr_dir, 'opt.txt') 81 | with open(file_name, 'wt') as opt_file: 82 | opt_file.write('------------ Options -------------\n') 83 | for k, v in sorted(args.items()): 84 | 
opt_file.write('%s: %s\n' % (str(k), str(v))) 85 | opt_file.write('-------------- End ----------------\n') 86 | return self.opt 87 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 8 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') #todo delete. 9 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 10 | self.parser.add_argument('--num_aug', type=int, default=1, help='# of augmentation files') 11 | self.is_train = False -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | self.parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console') 7 | self.parser.add_argument('--save_latest_freq', type=int, default=250, help='frequency of saving the latest results') 8 | self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 9 | self.parser.add_argument('--run_test_freq', type=int, default=1, help='frequency of running test in training script') 10 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 11 | self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, 
we save the model by epoch_count, epoch_count+save_latest_freq, ...') 12 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 13 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 14 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 15 | self.parser.add_argument('--niter_decay', type=int, default=500, help='# of iter to linearly decay learning rate to zero') 16 | self.parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam') 17 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 18 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') 19 | self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 20 | # data augmentation stuff 21 | self.parser.add_argument('--num_aug', type=int, default=10, help='# of augmentation files') 22 | self.parser.add_argument('--scale_verts', action='store_true', help='non-uniformly scale the mesh e.g., in x, y or z') 23 | self.parser.add_argument('--slide_verts', type=float, default=0, help='percent of vertices which will be shifted along the mesh surface') 24 | self.parser.add_argument('--flip_edges', type=float, default=0, help='percent of edges to randomly flip') 25 | # tensorboard visualization 26 | self.parser.add_argument('--no_vis', action='store_true', help='will not use tensorboard') 27 | self.parser.add_argument('--verbose_plot', action='store_true', help='plots network weights, etc.') 28 | self.is_train = True 29 | -------------------------------------------------------------------------------- /scripts/coseg_seg/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 
DATADIR='datasets' #location where
data gets downloaded to 4 | 5 | echo "downloading the data and putting it in: " $DATADIR 6 | mkdir -p $DATADIR && cd $DATADIR 7 | wget https://www.dropbox.com/s/34vy4o5fthhz77d/coseg.tar.gz 8 | tar -xzvf coseg.tar.gz && rm coseg.tar.gz -------------------------------------------------------------------------------- /scripts/coseg_seg/get_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CHECKPOINT=checkpoints/coseg_aliens 4 | mkdir -p $CHECKPOINT 5 | 6 | #gets the pretrained weights 7 | wget https://www.dropbox.com/s/er7my13k9dwg9ii/coseg_aliens_wts.tar.gz 8 | tar -xzvf coseg_aliens_wts.tar.gz && rm coseg_aliens_wts.tar.gz 9 | mv latest_net.pth $CHECKPOINT 10 | echo "downloaded pretrained weights to" $CHECKPOINT -------------------------------------------------------------------------------- /scripts/coseg_seg/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the test and export collapses 4 | python test.py \ 5 | --dataroot datasets/coseg_aliens \ 6 | --name coseg_aliens \ 7 | --arch meshunet \ 8 | --dataset_mode segmentation \ 9 | --ncf 32 64 128 256 \ 10 | --ninput_edges 2280 \ 11 | --pool_res 1800 1350 600 \ 12 | --resblocks 3 \ 13 | --batch_size 12 \ 14 | --export_folder meshes \ -------------------------------------------------------------------------------- /scripts/coseg_seg/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the training 4 | python train.py \ 5 | --dataroot datasets/coseg_aliens \ 6 | --name coseg_aliens \ 7 | --arch meshunet \ 8 | --dataset_mode segmentation \ 9 | --ncf 32 64 128 256 \ 10 | --ninput_edges 2280 \ 11 | --pool_res 1800 1350 600 \ 12 | --resblocks 3 \ 13 | --lr 0.001 \ 14 | --batch_size 12 \ 15 | --num_aug 20 \ 16 | --slide_verts 0.2 \ 17 | 18 | 19 | # 20 | # python train.py --dataroot 
datasets/coseg_vases --name coseg_vases --arch meshunet --dataset_mode 21 | segmentation --ncf 32 64 128 256 --ninput_edges 1500 --pool_res 1050 600 300 --resblocks 3 --lr 0.001 --batch_size 12 --num_aug 20 -------------------------------------------------------------------------------- /scripts/coseg_seg/view.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python util/mesh_viewer.py \ 4 | --files \ 5 | checkpoints/coseg_aliens/meshes/142_0.obj \ 6 | checkpoints/coseg_aliens/meshes/142_2.obj \ 7 | checkpoints/coseg_aliens/meshes/142_3.obj \ -------------------------------------------------------------------------------- /scripts/cubes/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATADIR='datasets' #location where data gets downloaded to 4 | 5 | # get data 6 | mkdir -p $DATADIR && cd $DATADIR 7 | wget https://www.dropbox.com/s/2bxs5f9g60wa0wr/cubes.tar.gz 8 | tar -xzvf cubes.tar.gz && rm cubes.tar.gz 9 | echo "downloaded the data and put it in: " $DATADIR -------------------------------------------------------------------------------- /scripts/cubes/get_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CHECKPOINT='checkpoints/cubes' 4 | 5 | # get pretrained model 6 | mkdir -p $CHECKPOINT 7 | wget https://www.dropbox.com/s/fg7wum39bmlxr7w/cubes_wts.tar.gz 8 | tar -xzvf cubes_wts.tar.gz && rm cubes_wts.tar.gz 9 | mv latest_net.pth $CHECKPOINT 10 | echo "downloaded pretrained weights to" $CHECKPOINT -------------------------------------------------------------------------------- /scripts/cubes/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the test and export collapses 4 | python test.py \ 5 | --dataroot datasets/cubes \ 6 | --name cubes \ 7 | --ncf 64 128 256 256 
\ 8 | --pool_res 600 450 300 210 \ 9 | --norm group \ 10 | --resblocks 1 \ 11 | --export_folder meshes \ -------------------------------------------------------------------------------- /scripts/cubes/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the training 4 | python train.py \ 5 | --dataroot datasets/cubes \ 6 | --name cubes \ 7 | --ncf 64 128 256 256 \ 8 | --pool_res 600 450 300 210 \ 9 | --norm group \ 10 | --resblocks 1 \ 11 | --flip_edges 0.2 \ 12 | --slide_verts 0.2 \ 13 | --num_aug 20 \ -------------------------------------------------------------------------------- /scripts/cubes/view.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python util/mesh_viewer.py \ 4 | --files checkpoints/cubes/meshes/horseshoe_4_0.obj \ 5 | checkpoints/cubes/meshes/horseshoe_4_2.obj \ 6 | checkpoints/cubes/meshes/horseshoe_4_3.obj \ 7 | checkpoints/cubes/meshes/horseshoe_4_4.obj -------------------------------------------------------------------------------- /scripts/dataprep/blender_process.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import os 3 | import sys 4 | 5 | 6 | ''' 7 | Simplifies mesh to target number of faces 8 | Requires Blender 2.8 9 | Author: Rana Hanocka 10 | 11 | @input: 12 | 13 | number of target faces 14 | name of simplified .obj file 15 | 16 | @output: 17 | simplified mesh .obj 18 | to run it from cmd line: 19 | /opt/blender/blender --background --python blender_process.py /home/rana/koala.obj 1000 /home/rana/koala_1000.obj 20 | ''' 21 | 22 | class Process: 23 | def __init__(self, obj_file, target_faces, export_name): 24 | mesh = self.load_obj(obj_file) 25 | self.simplify(mesh, target_faces) 26 | self.export_obj(mesh, export_name) 27 | 28 | def load_obj(self, obj_file): 29 | bpy.ops.import_scene.obj(filepath=obj_file, axis_forward='-Z', 
axis_up='Y', filter_glob="*.obj;*.mtl", use_edges=True, 30 | use_smooth_groups=True, use_split_objects=False, use_split_groups=False, 31 | use_groups_as_vgroups=False, use_image_search=True, split_mode='ON') 32 | ob = bpy.context.selected_objects[0] 33 | return ob 34 | 35 | def subsurf(self, mesh): 36 | # subdivide mesh 37 | bpy.context.view_layer.objects.active = mesh 38 | mod = mesh.modifiers.new(name='Subsurf', type='SUBSURF') 39 | mod.subdivision_type = 'SIMPLE' 40 | bpy.ops.object.modifier_apply(modifier=mod.name) 41 | # now triangulate 42 | mod = mesh.modifiers.new(name='Triangulate', type='TRIANGULATE') 43 | bpy.ops.object.modifier_apply(modifier=mod.name) 44 | 45 | def simplify(self, mesh, target_faces): 46 | bpy.context.view_layer.objects.active = mesh 47 | mod = mesh.modifiers.new(name='Decimate', type='DECIMATE') 48 | bpy.context.object.modifiers['Decimate'].use_collapse_triangulate = True 49 | # 50 | nfaces = len(mesh.data.polygons) 51 | if nfaces < target_faces: 52 | self.subsurf(mesh) 53 | nfaces = len(mesh.data.polygons) 54 | ratio = target_faces / float(nfaces) 55 | mod.ratio = float('%s' % ('%.6g' % (ratio))) 56 | print('faces: ', mod.face_count, mod.ratio) 57 | bpy.ops.object.modifier_apply(modifier=mod.name) 58 | 59 | 60 | def export_obj(self, mesh, export_name): 61 | outpath = os.path.dirname(export_name) 62 | if outpath and not os.path.isdir(outpath): os.makedirs(outpath) 63 | print('EXPORTING', export_name) 64 | bpy.ops.object.select_all(action='DESELECT') 65 | mesh.select_set(state=True) 66 | bpy.ops.export_scene.obj(filepath=export_name, check_existing=False, filter_glob="*.obj;*.mtl", 67 | use_selection=True, use_animation=False, use_mesh_modifiers=True, use_edges=True, 68 | use_smooth_groups=False, use_smooth_groups_bitflags=False, use_normals=True, 69 | use_uvs=False, use_materials=False, use_triangles=True, use_nurbs=False, 70 | use_vertex_groups=False, use_blen_objects=True, group_by_object=False, 71 | group_by_material=False,
keep_vertex_order=True, global_scale=1, path_mode='AUTO', 72 | axis_forward='-Z', axis_up='Y') 73 | 74 | obj_file = sys.argv[-3] 75 | target_faces = int(sys.argv[-2]) 76 | export_name = sys.argv[-1] 77 | 78 | 79 | print('args: ', obj_file, target_faces, export_name) 80 | blender = Process(obj_file, target_faces, export_name) 81 | -------------------------------------------------------------------------------- /scripts/human_seg/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATADIR='datasets' #location where data gets downloaded to 4 | 5 | # get data 6 | echo "downloading the data and putting it in: " $DATADIR 7 | mkdir -p $DATADIR && cd $DATADIR 8 | wget https://www.dropbox.com/s/s3n05sw0zg27fz3/human_seg.tar.gz 9 | tar -xzvf human_seg.tar.gz && rm human_seg.tar.gz -------------------------------------------------------------------------------- /scripts/human_seg/get_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CHECKPOINT='checkpoints/human_seg' 4 | mkdir -p $CHECKPOINT 5 | 6 | wget https://www.dropbox.com/s/8i26y7cpi6st2ra/human_seg_wts.tar.gz 7 | tar -xzvf human_seg_wts.tar.gz && rm human_seg_wts.tar.gz 8 | mv latest_net.pth $CHECKPOINT 9 | echo "downloaded pretrained weights to" $CHECKPOINT 10 | -------------------------------------------------------------------------------- /scripts/human_seg/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the test and export collapses 4 | python test.py \ 5 | --dataroot datasets/human_seg \ 6 | --name human_seg \ 7 | --arch meshunet \ 8 | --dataset_mode segmentation \ 9 | --ncf 32 64 128 256 \ 10 | --ninput_edges 2280 \ 11 | --pool_res 1800 1350 600 \ 12 | --resblocks 3 \ 13 | --batch_size 12 \ 14 | --export_folder meshes \ 
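`Process.simplify` in `scripts/dataprep/blender_process.py` passes Blender's Decimate modifier a ratio computed as `target_faces / nfaces` and rounded to six significant digits. The arithmetic can be checked without `bpy`. The following is a minimal sketch that mirrors that computation. `decimate_ratio` is a hypothetical helper name and does not exist in the repository. The sketch assumes `nfaces >= target_faces`, because the original script subdivides the mesh first whenever it has fewer faces than the target.

```python
def decimate_ratio(nfaces: int, target_faces: int) -> float:
    """Hypothetical helper: the Decimate ratio as computed in Process.simplify.

    Assumes nfaces >= target_faces (the original script subdivides the mesh
    first when it has fewer faces than the target).
    """
    if nfaces <= 0:
        raise ValueError("mesh has no faces")
    ratio = target_faces / float(nfaces)
    # Keep 6 significant digits, equivalent to float('%s' % ('%.6g' % ratio))
    return float('%.6g' % ratio)


# e.g. simplifying a 3000-face mesh down to 1000 faces
print(decimate_ratio(3000, 1000))  # 0.333333
```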
-------------------------------------------------------------------------------- /scripts/human_seg/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the training 4 | python train.py \ 5 | --dataroot datasets/human_seg \ 6 | --name human_seg \ 7 | --arch meshunet \ 8 | --dataset_mode segmentation \ 9 | --ncf 32 64 128 256 \ 10 | --ninput_edges 2280 \ 11 | --pool_res 1800 1350 600 \ 12 | --resblocks 3 \ 13 | --batch_size 12 \ 14 | --lr 0.001 \ 15 | --num_aug 20 \ 16 | --slide_verts 0.2 \ -------------------------------------------------------------------------------- /scripts/human_seg/view.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python util/mesh_viewer.py \ 4 | --files \ 5 | checkpoints/human_seg/meshes/shrec__14_0.obj -------------------------------------------------------------------------------- /scripts/shrec/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATADIR='datasets' #location where data gets downloaded to 4 | 5 | # get data 6 | mkdir -p $DATADIR && cd $DATADIR 7 | wget https://www.dropbox.com/s/w16st84r6wc57u7/shrec_16.tar.gz 8 | tar -xzvf shrec_16.tar.gz && rm shrec_16.tar.gz 9 | echo "downloaded the data and put it in: " $DATADIR 10 | -------------------------------------------------------------------------------- /scripts/shrec/get_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CHECKPOINT='checkpoints/shrec16' 4 | 5 | mkdir -p $CHECKPOINT 6 | wget https://www.dropbox.com/s/wqq1qxj4fjbpfas/shrec16_wts.tar.gz 7 | tar -xzvf shrec16_wts.tar.gz && rm shrec16_wts.tar.gz 8 | mv latest_net.pth $CHECKPOINT 9 | echo "downloaded pretrained weights to" $CHECKPOINT --------------------------------------------------------------------------------
/scripts/shrec/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the test and export collapses 4 | python test.py \ 5 | --dataroot datasets/shrec_16 \ 6 | --name shrec16 \ 7 | --ncf 64 128 256 256 \ 8 | --pool_res 600 450 300 180 \ 9 | --norm group \ 10 | --resblocks 1 \ 11 | --export_folder meshes \ -------------------------------------------------------------------------------- /scripts/shrec/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the training 4 | python train.py \ 5 | --dataroot datasets/shrec_16 \ 6 | --name shrec16 \ 7 | --ncf 64 128 256 256 \ 8 | --pool_res 600 450 300 180 \ 9 | --norm group \ 10 | --resblocks 1 \ 11 | --flip_edges 0.2 \ 12 | --slide_verts 0.2 \ 13 | --num_aug 20 \ 14 | --niter_decay 100 \ -------------------------------------------------------------------------------- /scripts/shrec/view.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python util/mesh_viewer.py \ 4 | --files \ 5 | checkpoints/shrec16/meshes/T74_0.obj \ 6 | checkpoints/shrec16/meshes/T74_3.obj \ 7 | checkpoints/shrec16/meshes/T74_4.obj -------------------------------------------------------------------------------- /scripts/test_general.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import shutil 4 | import glob 5 | import subprocess 6 | ''' 7 | scripts for unit testing 8 | ''' 9 | 10 | 11 | def get_data(dset): 12 | dpaths = glob.glob('./datasets/{}*'.format(dset)) 13 | [shutil.rmtree(d) for d in dpaths] 14 | cmd = './scripts/{}/get_data.sh > /dev/null 2>&1'.format(dset) 15 | os.system(cmd) 16 | 17 | def add_args(file, temp_file, new_args): 18 | with open(file) as f: 19 | tokens = f.readlines() 20 | # now make the config so it only trains for one iteration 21 | 
tokens[-1] = tokens[-1] + '\n' 22 | for arg in new_args: 23 | tokens.append(arg) 24 | with open(temp_file, 'w') as f: 25 | f.writelines(tokens) 26 | 27 | def run_train(dset): 28 | train_file = './scripts/{}/train.sh'.format(dset) 29 | temp_train_file = './scripts/{}/train_temp.sh'.format(dset) 30 | p = subprocess.run(['cp', '-p', '--preserve', train_file, temp_train_file]) 31 | add_args(train_file, temp_train_file, ['--niter_decay 0 \\\n', '--niter 1 \\\n', '--max_dataset_size 2 \\\n', '--gpu_ids -1 \\']) 32 | cmd = "bash -c 'source ~/anaconda3/bin/activate ~/anaconda3/envs/meshcnn && {} > /dev/null 2>&1'".format(temp_train_file) 33 | os.system(cmd) 34 | os.remove(temp_train_file) 35 | 36 | def get_pretrained(dset): 37 | cmd = './scripts/{}/get_pretrained.sh > /dev/null 2>&1'.format(dset) 38 | os.system(cmd) 39 | 40 | def run_test(dset): 41 | test_file = './scripts/{}/test.sh'.format(dset) 42 | temp_test_file = './scripts/{}/test_temp.sh'.format(dset) 43 | p = subprocess.run(['cp', '-p', '--preserve', test_file, temp_test_file]) 44 | add_args(test_file, temp_test_file, ['--gpu_ids -1 \\']) 45 | # now run inference 46 | cmd = "bash -c 'source ~/anaconda3/bin/activate ~/anaconda3/envs/meshcnn && {}'".format(temp_test_file) 47 | proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True) 48 | (_out, err) = proc.communicate() 49 | out = str(_out) 50 | idf0 = 'TEST ACC: [' 51 | token = out[out.find(idf0) + len(idf0):] 52 | idf1 = '%]' 53 | accs = token[:token.find(idf1)] 54 | acc = float(accs) 55 | if dset == 'shrec': 56 | assert acc == 99.167, "shrec accuracy was {} and not 99.167".format(acc) 57 | if dset == 'human_seg': 58 | assert acc == 92.554, "human_seg accuracy was {} and not 92.554".format(acc) 59 | os.remove(temp_test_file) 60 | 61 | def run_dataset(dset): 62 | get_data(dset) 63 | run_train(dset) 64 | get_pretrained(dset) 65 | run_test(dset) 66 | 67 | def test_shrec(): 68 | run_dataset('shrec') 69 | 70 | def test_human_seg(): 71 | run_dataset('human_seg') 
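`run_test` in `scripts/test_general.py` extracts the accuracy from the subprocess output. It slices the text between `'TEST ACC: ['` and `'%]'`, which matches the format printed by `Writer.print_acc` (`'epoch: {}, TEST ACC: [{:.5} %]'`). A regular expression does the same job and fails loudly when the line is missing. The following is a minimal sketch. `parse_test_acc` is a hypothetical helper name, and the sketch assumes the captured output has already been decoded to `str`.

```python
import re

# Matches the line printed by Writer.print_acc, e.g. "epoch: -1, TEST ACC: [99.167 %]"
_ACC_RE = re.compile(r"TEST ACC: \[\s*([0-9.]+)\s*%\]")


def parse_test_acc(output: str) -> float:
    """Hypothetical helper: return the accuracy (in percent) from test.py output."""
    match = _ACC_RE.search(output)
    if match is None:
        raise ValueError("no 'TEST ACC' line found in output")
    return float(match.group(1))


print(parse_test_acc("Running Test\nepoch: -1, TEST ACC: [99.167 %]\n"))  # 99.167
```

Unlike the slicing approach, a missing accuracy line raises a clear `ValueError` instead of a confusing `float()` conversion error.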
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from options.test_options import TestOptions 2 | from data import DataLoader 3 | from models import create_model 4 | from util.writer import Writer 5 | 6 | 7 | def run_test(epoch=-1): 8 | print('Running Test') 9 | opt = TestOptions().parse() 10 | opt.serial_batches = True # no shuffle 11 | dataset = DataLoader(opt) 12 | model = create_model(opt) 13 | writer = Writer(opt) 14 | # test 15 | writer.reset_counter() 16 | for i, data in enumerate(dataset): 17 | model.set_input(data) 18 | ncorrect, nexamples = model.test() 19 | writer.update_counter(ncorrect, nexamples) 20 | writer.print_acc(epoch, writer.acc) 21 | return writer.acc 22 | 23 | 24 | if __name__ == '__main__': 25 | run_test() 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import DataLoader 4 | from models import create_model 5 | from util.writer import Writer 6 | from test import run_test 7 | 8 | if __name__ == '__main__': 9 | opt = TrainOptions().parse() 10 | dataset = DataLoader(opt) 11 | dataset_size = len(dataset) 12 | print('#training meshes = %d' % dataset_size) 13 | 14 | model = create_model(opt) 15 | writer = Writer(opt) 16 | total_steps = 0 17 | 18 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 19 | epoch_start_time = time.time() 20 | iter_data_time = time.time() 21 | epoch_iter = 0 22 | 23 | for i, data in enumerate(dataset): 24 | iter_start_time = time.time() 25 | if total_steps % opt.print_freq == 0: 26 | t_data = iter_start_time - iter_data_time 27 | total_steps += opt.batch_size 28 | epoch_iter += opt.batch_size 29 | model.set_input(data) 30 | model.optimize_parameters() 31 | 32 | if 
total_steps % opt.print_freq == 0: 33 | loss = model.loss 34 | t = (time.time() - iter_start_time) / opt.batch_size 35 | writer.print_current_losses(epoch, epoch_iter, loss, t, t_data) 36 | writer.plot_loss(loss, epoch, epoch_iter, dataset_size) 37 | 38 | if i % opt.save_latest_freq == 0: 39 | print('saving the latest model (epoch %d, total_steps %d)' % 40 | (epoch, total_steps)) 41 | model.save_network('latest') 42 | 43 | iter_data_time = time.time() 44 | if epoch % opt.save_epoch_freq == 0: 45 | print('saving the model at the end of epoch %d, iters %d' % 46 | (epoch, total_steps)) 47 | model.save_network('latest') 48 | model.save_network(epoch) 49 | 50 | print('End of epoch %d / %d \t Time Taken: %d sec' % 51 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 52 | model.update_learning_rate() 53 | if opt.verbose_plot: 54 | writer.plot_model_wts(model, epoch) 55 | 56 | if epoch % opt.run_test_freq == 0: 57 | acc = run_test(epoch) 58 | writer.plot_acc(acc, epoch) 59 | 60 | writer.close() 61 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/util/__init__.py -------------------------------------------------------------------------------- /util/mesh_viewer.py: -------------------------------------------------------------------------------- 1 | import mpl_toolkits.mplot3d as a3 2 | import matplotlib.colors as colors 3 | import pylab as pl 4 | import numpy as np 5 | 6 | V = np.array 7 | r2h = lambda x: colors.rgb2hex(tuple(map(lambda y: y / 255., x))) 8 | surface_color = r2h((255, 230, 205)) 9 | edge_color = r2h((90, 90, 90)) 10 | edge_colors = (r2h((15, 167, 175)), r2h((230, 81, 81)), r2h((142, 105, 252)), r2h((248, 235, 57)), 11 | r2h((51, 159, 255)), r2h((225, 117, 231)), r2h((97, 243, 185)), r2h((161, 183, 196))) 12 | 13 
| 14 | 15 | 16 | def init_plot(): 17 | ax = pl.figure().add_subplot(111, projection='3d') 18 | # hide axis, thanks to 19 | # https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/ 20 | ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 21 | ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 22 | ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 23 | # Get rid of the spines 24 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 25 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 26 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 27 | # Get rid of the ticks 28 | ax.set_xticks([]) 29 | ax.set_yticks([]) 30 | ax.set_zticks([]) 31 | return (ax, [np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf]) 32 | 33 | 34 | def update_lim(mesh, plot): 35 | vs = mesh[0] 36 | for i in range(3): 37 | plot[1][2 * i] = min(plot[1][2 * i], vs[:, i].min()) 38 | plot[1][2 * i + 1] = max(plot[1][2 * i + 1], vs[:, i].max()) 39 | return plot 40 | 41 | 42 | def update_plot(mesh, plot): 43 | if plot is None: 44 | plot = init_plot() 45 | return update_lim(mesh, plot) 46 | 47 | 48 | def surfaces(mesh, plot): 49 | vs, faces, edges = mesh 50 | vtx = vs[faces] 51 | edgecolor = edge_color if not len(edges) else 'none' 52 | tri = a3.art3d.Poly3DCollection(vtx, facecolors=surface_color +'55', edgecolors=edgecolor, 53 | linewidths=.5, linestyles='dashdot') 54 | plot[0].add_collection3d(tri) 55 | return plot 56 | 57 | 58 | def segments(mesh, plot): 59 | vs, _, edges = mesh 60 | for edge_c, edge_group in enumerate(edges): 61 | for edge_idx in edge_group: 62 | edge = vs[edge_idx] 63 | line = a3.art3d.Line3DCollection([edge], linewidths=.5, linestyles='dashdot') 64 | line.set_color(edge_colors[edge_c % len(edge_colors)]) 65 | plot[0].add_collection3d(line) 66 | return plot 67 | 68 | 69 | def plot_mesh(mesh, *whats, show=True, plot=None): 70 | for what in [update_plot] + list(whats): 71 | plot = what(mesh, plot) 72 | if show: 73 | li = max(plot[1][1], plot[1][3], plot[1][5])
74 | plot[0].auto_scale_xyz([0, li], [0, li], [0, li]) 75 | pl.tight_layout() 76 | pl.show() 77 | return plot 78 | 79 | 80 | def parse_obje(obj_file, scale_by): 81 | vs = [] 82 | faces = [] 83 | edges = [] 84 | 85 | def add_to_edges(): 86 | if edge_c >= len(edges): 87 | for _ in range(len(edges), edge_c + 1): 88 | edges.append([]) 89 | edges[edge_c].append(edge_v) 90 | 91 | def fix_vertices(): 92 | nonlocal vs, scale_by 93 | vs = V(vs) 94 | z = vs[:, 2].copy() 95 | vs[:, 2] = vs[:, 1] 96 | vs[:, 1] = z 97 | max_range = 0 98 | for i in range(3): 99 | min_value = np.min(vs[:, i]) 100 | max_value = np.max(vs[:, i]) 101 | max_range = max(max_range, max_value - min_value) 102 | vs[:, i] -= min_value 103 | if not scale_by: 104 | scale_by = max_range 105 | vs /= scale_by 106 | 107 | with open(obj_file) as f: 108 | for line in f: 109 | line = line.strip() 110 | splitted_line = line.split() 111 | if not splitted_line: 112 | continue 113 | elif splitted_line[0] == 'v': 114 | vs.append([float(v) for v in splitted_line[1:]]) 115 | elif splitted_line[0] == 'f': 116 | faces.append([int(c) - 1 for c in splitted_line[1:]]) 117 | elif splitted_line[0] == 'e': 118 | if len(splitted_line) >= 4: 119 | edge_v = [int(c) - 1 for c in splitted_line[1:-1]] 120 | edge_c = int(splitted_line[-1]) 121 | add_to_edges() 122 | 123 | vs = V(vs) 124 | fix_vertices() 125 | faces = V(faces, dtype=int) 126 | edges = [V(c, dtype=int) for c in edges] 127 | return (vs, faces, edges), scale_by 128 | 129 | 130 | def view_meshes(*files, offset=.2): 131 | plot = None 132 | max_x = 0 133 | scale = 0 134 | for file in files: 135 | mesh, scale = parse_obje(file, scale) 136 | max_x_current = mesh[0][:, 0].max() 137 | mesh[0][:, 0] += max_x + offset 138 | plot = plot_mesh(mesh, surfaces, segments, plot=plot, show=file == files[-1]) 139 | max_x += max_x_current + offset 140 | 141 | 142 | if __name__=='__main__': 143 | import argparse 144 | parser = argparse.ArgumentParser("view meshes") 145 | 
parser.add_argument('--files', nargs='+', default=['checkpoints/human_seg/meshes/shrec__14_0.obj', 146 | 'checkpoints/human_seg/meshes/shrec__14_3.obj'], type=str, 147 | help="list of 1 or more .obj files") 148 | args = parser.parse_args() 149 | 150 | # view meshes 151 | view_meshes(*args.files) 152 | 153 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | import os 5 | 6 | 7 | def mkdir(path): 8 | if not os.path.exists(path): 9 | os.makedirs(path) 10 | 11 | MESH_EXTENSIONS = [ 12 | '.obj', 13 | ] 14 | 15 | def is_mesh_file(filename): 16 | return any(filename.endswith(extension) for extension in MESH_EXTENSIONS) 17 | 18 | def pad(input_arr, target_length, val=0, dim=1): 19 | shp = input_arr.shape 20 | npad = [(0, 0) for _ in range(len(shp))] 21 | npad[dim] = (0, target_length - shp[dim]) 22 | return np.pad(input_arr, pad_width=npad, mode='constant', constant_values=val) 23 | 24 | def seg_accuracy(predicted, ssegs, meshes): 25 | correct = 0 26 | ssegs = ssegs.squeeze(-1) 27 | correct_mat = ssegs.gather(2, predicted.cpu().unsqueeze(dim=2)) 28 | for mesh_id, mesh in enumerate(meshes): 29 | correct_vec = correct_mat[mesh_id, :mesh.edges_count, 0] 30 | edge_areas = torch.from_numpy(mesh.get_edge_areas()) 31 | correct += (correct_vec.float() * edge_areas).sum() 32 | return correct 33 | 34 | def print_network(net): 35 | """Print the total number of parameters in the network 36 | Parameters: 37 | network 38 | """ 39 | print('---------- Network initialized -------------') 40 | num_params = 0 41 | for param in net.parameters(): 42 | num_params += param.numel() 43 | print('[Network] Total number of parameters : %.3f M' % (num_params / 1e6)) 44 | print('-----------------------------------------------') 45 | 46 | def get_heatmap_color(value, minimum=0, maximum=1): 47 | 
minimum, maximum = float(minimum), float(maximum) 48 | ratio = 2 * (value-minimum) / (maximum - minimum) 49 | b = int(max(0, 255*(1 - ratio))) 50 | r = int(max(0, 255*(ratio - 1))) 51 | g = 255 - b - r 52 | return r, g, b 53 | 54 | 55 | def normalize_np_array(np_array): 56 | min_value = np.min(np_array) 57 | max_value = np.max(np_array) 58 | return (np_array - min_value) / (max_value - min_value) 59 | 60 | 61 | def calculate_entropy(np_array): 62 | entropy = 0 63 | np_array = np_array / np.sum(np_array) 64 | for a in np_array: 65 | if a != 0: 66 | entropy -= a * np.log(a) 67 | entropy /= np.log(np_array.shape[0]) 68 | return entropy 69 | -------------------------------------------------------------------------------- /util/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | try: 5 | from tensorboardX import SummaryWriter 6 | except ImportError as error: 7 | print("tensorboardX not installed, visualization won't be available") 8 | SummaryWriter = None 9 | 10 | class Writer: 11 | def __init__(self, opt): 12 | self.name = opt.name 13 | self.opt = opt 14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 15 | self.log_name = os.path.join(self.save_dir, 'loss_log.txt') 16 | self.testacc_log = os.path.join(self.save_dir, 'testacc_log.txt') 17 | self.start_logs() 18 | self.nexamples = 0 19 | self.ncorrect = 0 20 | # 21 | if opt.is_train and not opt.no_vis and SummaryWriter is not None: 22 | self.display = SummaryWriter(comment=opt.name) 23 | else: 24 | self.display = None 25 | 26 | def start_logs(self): 27 | """ creates test / train log files """ 28 | if self.opt.is_train: 29 | with open(self.log_name, "a") as log_file: 30 | now = time.strftime("%c") 31 | log_file.write('================ Training Loss (%s) ================\n' % now) 32 | else: 33 | with open(self.testacc_log, "a") as log_file: 34 | now = time.strftime("%c") 35 | log_file.write('================ Testing Acc (%s)
================\n' % now) 36 | 37 | def print_current_losses(self, epoch, i, losses, t, t_data): 38 | """ prints train loss to terminal / file """ 39 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) loss: %.3f ' \ 40 | % (epoch, i, t, t_data, losses.item()) 41 | print(message) 42 | with open(self.log_name, "a") as log_file: 43 | log_file.write('%s\n' % message) 44 | 45 | def plot_loss(self, loss, epoch, i, n): 46 | iters = i + (epoch - 1) * n 47 | if self.display: 48 | self.display.add_scalar('data/train_loss', loss, iters) 49 | 50 | def plot_model_wts(self, model, epoch): 51 | if self.opt.is_train and self.display: 52 | for name, param in model.net.named_parameters(): 53 | self.display.add_histogram(name, param.clone().cpu().data.numpy(), epoch) 54 | 55 | def print_acc(self, epoch, acc): 56 | """ prints test accuracy to terminal / file """ 57 | message = 'epoch: {}, TEST ACC: [{:.5} %]\n' \ 58 | .format(epoch, acc * 100) 59 | print(message) 60 | with open(self.testacc_log, "a") as log_file: 61 | log_file.write('%s\n' % message) 62 | 63 | def plot_acc(self, acc, epoch): 64 | if self.display: 65 | self.display.add_scalar('data/test_acc', acc, epoch) 66 | 67 | def reset_counter(self): 68 | """ 69 | resets the counts of correct and total examples 70 | """ 71 | self.ncorrect = 0 72 | self.nexamples = 0 73 | 74 | def update_counter(self, ncorrect, nexamples): 75 | self.ncorrect += ncorrect 76 | self.nexamples += nexamples 77 | 78 | @property 79 | def acc(self): 80 | return float(self.ncorrect) / self.nexamples 81 | 82 | def close(self): 83 | if self.display is not None: 84 | self.display.close() 85 | --------------------------------------------------------------------------------
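`calculate_entropy` in `util/util.py` works in three steps:

1. It normalizes its input to a probability distribution.
2. It sums `-p * log(p)` over the non-zero entries.
3. It divides by `log(n)`, so the result lies between 0 (all mass on one entry) and 1 (uniform).

The following is a minimal pure-Python sketch of that normalized entropy. `normalized_entropy` is a hypothetical name and not part of the repository. The sketch builds the probabilities without modifying the caller's list, and it rejects inputs with fewer than two entries, because `log(1) == 0` would make the final division undefined.

```python
import math
from typing import Sequence


def normalized_entropy(weights: Sequence[float]) -> float:
    """Hypothetical helper: entropy of weights / sum(weights), divided by log(n)."""
    n = len(weights)
    total = float(sum(weights))
    if n < 2 or total <= 0:
        raise ValueError("need at least two weights with a positive sum")
    entropy = 0.0
    for w in weights:
        p = w / total  # probability of this entry; the input is left untouched
        if p > 0:
            entropy -= p * math.log(p)
    return entropy / math.log(n)


print(normalized_entropy([1, 1, 1, 1]))  # uniform distribution, approximately 1.0
```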