├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── classification_data.py
│   └── segmentation_data.py
├── docs
│   ├── imgs
│   │   ├── T18.png
│   │   ├── T252.png
│   │   ├── T76.png
│   │   ├── alien.gif
│   │   ├── coseg_alien.png
│   │   ├── coseg_chair.png
│   │   ├── coseg_vase.png
│   │   ├── cubes.png
│   │   ├── cubes2.png
│   │   ├── input_edge_features.png
│   │   ├── mesh_conv.png
│   │   ├── mesh_pool_unpool.png
│   │   ├── meshcnn_overview.png
│   │   ├── shrec16_train.png
│   │   ├── shrec__10_0.png
│   │   ├── shrec__14_0.png
│   │   └── shrec__2_0.png
│   ├── index.html
│   └── mainpage.css
├── environment.yml
├── models
│   ├── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── mesh.py
│   │   ├── mesh_conv.py
│   │   ├── mesh_pool.py
│   │   ├── mesh_prepare.py
│   │   ├── mesh_union.py
│   │   └── mesh_unpool.py
│   ├── mesh_classifier.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── scripts
│   ├── coseg_seg
│   │   ├── get_data.sh
│   │   ├── get_pretrained.sh
│   │   ├── test.sh
│   │   ├── train.sh
│   │   └── view.sh
│   ├── cubes
│   │   ├── get_data.sh
│   │   ├── get_pretrained.sh
│   │   ├── test.sh
│   │   ├── train.sh
│   │   └── view.sh
│   ├── dataprep
│   │   └── blender_process.py
│   ├── human_seg
│   │   ├── get_data.sh
│   │   ├── get_pretrained.sh
│   │   ├── test.sh
│   │   ├── train.sh
│   │   └── view.sh
│   ├── shrec
│   │   ├── get_data.sh
│   │   ├── get_pretrained.sh
│   │   ├── test.sh
│   │   ├── train.sh
│   │   └── view.sh
│   └── test_general.py
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── mesh_viewer.py
    ├── util.py
    └── writer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 | *.pyc
3 | *.m~
4 |
5 | # data files
6 | *.obj
7 | checkpoints
8 | datasets
9 | runs
10 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | notifications:
2 | email:
3 | on_success: never
4 | on_failure: always
5 | language: python
6 | python:
7 | - "3.6"
8 | cache: pip
9 | install:
10 | - sudo apt-get update
11 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
12 | - bash miniconda.sh -b -p $HOME/anaconda3
13 | - source "$HOME/anaconda3/etc/profile.d/conda.sh"
14 | - hash -r
15 | - conda config --set always_yes yes --set changeps1 no
16 | - conda update -q conda
17 | # Useful for debugging any issues with conda
18 | - conda info -a
19 |
20 | # create meshcnn env
21 | - conda env create -f environment.yml
22 | - conda activate meshcnn
23 | script:
24 | - python -m pytest scripts/test_general.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Rana Hanocka
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # MeshCNN in PyTorch
5 |
6 |
7 | ### SIGGRAPH 2019 [[Paper]](https://bit.ly/meshcnn) [[Project Page]](https://ranahanocka.github.io/MeshCNN/)
8 |
9 | MeshCNN is a general-purpose deep neural network for 3D triangular meshes, which can be used for tasks such as 3D shape classification or segmentation. This framework includes convolution, pooling and unpooling layers which are applied directly on the mesh edges.
10 |
11 |
12 |
13 | The code was written by [Rana Hanocka](https://www.cs.tau.ac.il/~hanocka/) and [Amir Hertz](http://pxcm.org/) with support from [Noa Fish](http://www.cs.tau.ac.il/~noafish/).
14 |
15 | # Getting Started
16 |
17 | ### Installation
18 | - Clone this repo:
19 | ```bash
20 | git clone https://github.com/ranahanocka/MeshCNN.git
21 | cd MeshCNN
22 | ```
23 | - Install dependencies: [PyTorch](https://pytorch.org/) version 1.2. Optional: [tensorboardX](https://github.com/lanpa/tensorboardX) for training plots.
24 | - Via new conda environment `conda env create -f environment.yml` (creates an environment called meshcnn)
25 |
26 | ### 3D Shape Classification on SHREC
27 | Download the dataset
28 | ```bash
29 | bash ./scripts/shrec/get_data.sh
30 | ```
31 |
32 | Run training (if using the conda env, first activate it, e.g. ```source activate meshcnn```)
33 | ```bash
34 | bash ./scripts/shrec/train.sh
35 | ```
36 |
37 | To view the training loss plots, in another terminal run ```tensorboard --logdir runs``` and click [http://localhost:6006](http://localhost:6006).
38 |
39 | Run test and export the intermediate pooled meshes:
40 | ```bash
41 | bash ./scripts/shrec/test.sh
42 | ```
43 |
44 | Visualize the network-learned edge collapses:
45 | ```bash
46 | bash ./scripts/shrec/view.sh
47 | ```
48 |
49 | An example of collapses for a mesh:
50 |
51 |
52 |
53 | Note: you can also get pre-trained weights by running ```bash ./scripts/shrec/get_pretrained.sh```.
54 |
55 | In order to use the pre-trained weights, first run ```train.sh```, which computes and saves the mean / standard deviation of the training data.
56 |
57 |
58 | ### 3D Shape Segmentation on Humans
59 | Same as above: download the dataset, run training, get the pretrained weights, run the test, and view the results
60 | ```bash
61 | bash ./scripts/human_seg/get_data.sh
62 | bash ./scripts/human_seg/train.sh
63 | bash ./scripts/human_seg/get_pretrained.sh
64 | bash ./scripts/human_seg/test.sh
65 | bash ./scripts/human_seg/view.sh
66 | ```
67 |
68 | Some segmentation result examples:
69 |
70 |
71 |
72 | ### Additional Datasets
73 | The same scripts also exist for COSEG segmentation in ```scripts/coseg_seg``` and cubes classification in ```scripts/cubes```.
74 |
75 | # More Info
76 | Check out the [MeshCNN wiki](https://github.com/ranahanocka/MeshCNN/wiki) for more details. Specifically, see info on [segmentation](https://github.com/ranahanocka/MeshCNN/wiki/Segmentation) and [data processing](https://github.com/ranahanocka/MeshCNN/wiki/Data-Processing).
77 |
78 | # Other implementations
79 | - [Point2Mesh tensorflow reimplementation](https://github.com/dcharatan/point2mesh-reimplementation), which also contains MeshCNN
80 | - [MedMeshCNN](https://github.com/Divya9Sasidharan/MedMeshCNN), handles meshes with 170k edges
81 |
82 | # Citation
83 | If you find this code useful, please consider citing our paper
84 | ```
85 | @article{hanocka2019meshcnn,
86 | title={MeshCNN: A Network with an Edge},
87 | author={Hanocka, Rana and Hertz, Amir and Fish, Noa and Giryes, Raja and Fleishman, Shachar and Cohen-Or, Daniel},
88 | journal={ACM Transactions on Graphics (TOG)},
89 | volume={38},
90 | number={4},
91 | pages = {90:1--90:12},
92 | year={2019},
93 | publisher={ACM}
94 | }
95 | ```
96 |
97 |
98 | # Questions / Issues
99 | If you have questions or issues running this code, please open an issue so we know to fix it.
100 |
101 | # Acknowledgments
102 | This code design was adapted from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
103 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_dataset import collate_fn
3 |
4 | def CreateDataset(opt):
5 | """loads dataset class"""
6 |
7 |     if opt.dataset_mode == 'segmentation':
8 |         from data.segmentation_data import SegmentationData
9 |         return SegmentationData(opt)
10 |     elif opt.dataset_mode == 'classification':
11 |         from data.classification_data import ClassificationData
12 |         return ClassificationData(opt)
13 |     raise ValueError('unknown dataset_mode [%s]' % opt.dataset_mode)
14 |
15 |
16 | class DataLoader:
17 | """multi-threaded data loading"""
18 |
19 | def __init__(self, opt):
20 | self.opt = opt
21 | self.dataset = CreateDataset(opt)
22 | self.dataloader = torch.utils.data.DataLoader(
23 | self.dataset,
24 | batch_size=opt.batch_size,
25 | shuffle=not opt.serial_batches,
26 | num_workers=int(opt.num_threads),
27 | collate_fn=collate_fn)
28 |
29 | def __len__(self):
30 | return min(len(self.dataset), self.opt.max_dataset_size)
31 |
32 | def __iter__(self):
33 | for i, data in enumerate(self.dataloader):
34 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
35 | break
36 | yield data
37 |
--------------------------------------------------------------------------------
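A minimal usage sketch of the loader above, assuming `opt` is an already-parsed options namespace carrying the fields `DataLoader` reads (`dataset_mode`, `dataroot`, `phase`, `batch_size`, `serial_batches`, `num_threads`, `max_dataset_size`); the function name is illustrative:

```python
from data import DataLoader

def iterate_epoch(opt):
    data_loader = DataLoader(opt)           # CreateDataset + torch DataLoader
    for i, data in enumerate(data_loader):  # `data` is the dict built by collate_fn
        # data['edge_features']: (batch, 5, ninput_edges) numpy array
        # data['mesh']: array of Mesh objects; data['label']: per-sample labels
        print(i, data['edge_features'].shape)
```
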
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import numpy as np
3 | import pickle
4 | import os
5 |
6 | class BaseDataset(data.Dataset):
7 |
8 | def __init__(self, opt):
9 | self.opt = opt
10 | self.mean = 0
11 | self.std = 1
12 | self.ninput_channels = None
13 | super(BaseDataset, self).__init__()
14 |
15 | def get_mean_std(self):
16 | """ Computes Mean and Standard Deviation from Training Data
17 | If mean/std file doesn't exist, will compute one
18 | :returns
19 | mean: N-dimensional mean
20 | std: N-dimensional standard deviation
21 | ninput_channels: N
22 | (here N=5)
23 | """
24 |
25 | mean_std_cache = os.path.join(self.root, 'mean_std_cache.p')
26 | if not os.path.isfile(mean_std_cache):
27 | print('computing mean std from train data...')
28 |             # don't run augmentation during mean/std computation
29 | num_aug = self.opt.num_aug
30 | self.opt.num_aug = 1
31 | mean, std = np.array(0), np.array(0)
32 | for i, data in enumerate(self):
33 | if i % 500 == 0:
34 | print('{} of {}'.format(i, self.size))
35 | features = data['edge_features']
36 | mean = mean + features.mean(axis=1)
37 | std = std + features.std(axis=1)
38 | mean = mean / (i + 1)
39 | std = std / (i + 1)
40 | transform_dict = {'mean': mean[:, np.newaxis], 'std': std[:, np.newaxis],
41 | 'ninput_channels': len(mean)}
42 | with open(mean_std_cache, 'wb') as f:
43 | pickle.dump(transform_dict, f)
44 | print('saved: ', mean_std_cache)
45 | self.opt.num_aug = num_aug
46 | # open mean / std from file
47 | with open(mean_std_cache, 'rb') as f:
48 | transform_dict = pickle.load(f)
49 | print('loaded mean / std from cache')
50 | self.mean = transform_dict['mean']
51 | self.std = transform_dict['std']
52 | self.ninput_channels = transform_dict['ninput_channels']
53 |
54 |
55 | def collate_fn(batch):
56 | """Creates mini-batch tensors
57 | We should build custom collate_fn rather than using default collate_fn
58 | """
59 | meta = {}
60 | keys = batch[0].keys()
61 | for key in keys:
62 | meta.update({key: np.array([d[key] for d in batch])})
63 | return meta
--------------------------------------------------------------------------------
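A toy check of what `collate_fn` above produces (real samples also carry a `'mesh'` entry, omitted here for brevity):

```python
import numpy as np
from data.base_dataset import collate_fn

batch = [
    {'edge_features': np.zeros((5, 8)), 'label': 0},
    {'edge_features': np.ones((5, 8)), 'label': 1},
]
meta = collate_fn(batch)
assert meta['edge_features'].shape == (2, 5, 8)  # stacked along a new batch axis
assert meta['label'].tolist() == [0, 1]
```
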
/data/classification_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data.base_dataset import BaseDataset
4 | from util.util import is_mesh_file, pad
5 | from models.layers.mesh import Mesh
6 |
7 | class ClassificationData(BaseDataset):
8 |
9 | def __init__(self, opt):
10 | BaseDataset.__init__(self, opt)
11 | self.opt = opt
12 | self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
13 | self.root = opt.dataroot
14 | self.dir = os.path.join(opt.dataroot)
15 | self.classes, self.class_to_idx = self.find_classes(self.dir)
16 | self.paths = self.make_dataset_by_class(self.dir, self.class_to_idx, opt.phase)
17 | self.nclasses = len(self.classes)
18 | self.size = len(self.paths)
19 | self.get_mean_std()
20 | # modify for network later.
21 | opt.nclasses = self.nclasses
22 | opt.input_nc = self.ninput_channels
23 |
24 | def __getitem__(self, index):
25 | path = self.paths[index][0]
26 | label = self.paths[index][1]
27 | mesh = Mesh(file=path, opt=self.opt, hold_history=False, export_folder=self.opt.export_folder)
28 | meta = {'mesh': mesh, 'label': label}
29 | # get edge features
30 | edge_features = mesh.extract_features()
31 | edge_features = pad(edge_features, self.opt.ninput_edges)
32 | meta['edge_features'] = (edge_features - self.mean) / self.std
33 | return meta
34 |
35 | def __len__(self):
36 | return self.size
37 |
38 |     # used when the dataset folders are organized by class
39 | @staticmethod
40 | def find_classes(dir):
41 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
42 | classes.sort()
43 | class_to_idx = {classes[i]: i for i in range(len(classes))}
44 | return classes, class_to_idx
45 |
46 | @staticmethod
47 | def make_dataset_by_class(dir, class_to_idx, phase):
48 | meshes = []
49 | dir = os.path.expanduser(dir)
50 | for target in sorted(os.listdir(dir)):
51 | d = os.path.join(dir, target)
52 | if not os.path.isdir(d):
53 | continue
54 | for root, _, fnames in sorted(os.walk(d)):
55 | for fname in sorted(fnames):
56 | if is_mesh_file(fname) and (root.count(phase)==1):
57 | path = os.path.join(root, fname)
58 | item = (path, class_to_idx[target])
59 | meshes.append(item)
60 | return meshes
61 |
--------------------------------------------------------------------------------
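For reference, `find_classes` and `make_dataset_by_class` above expect `dataroot` to hold one folder per class, each with one subfolder per phase; the class names below are illustrative (the SHREC download follows this pattern):

```
dataroot/
├── alien/
│   ├── train/   # .obj files used when opt.phase == 'train'
│   └── test/    # .obj files used when opt.phase == 'test'
└── camel/
    ├── train/
    └── test/
```
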
/data/segmentation_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data.base_dataset import BaseDataset
4 | from util.util import is_mesh_file, pad
5 | import numpy as np
6 | from models.layers.mesh import Mesh
7 |
8 | class SegmentationData(BaseDataset):
9 |
10 | def __init__(self, opt):
11 | BaseDataset.__init__(self, opt)
12 | self.opt = opt
13 | self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
14 | self.root = opt.dataroot
15 | self.dir = os.path.join(opt.dataroot, opt.phase)
16 | self.paths = self.make_dataset(self.dir)
17 | self.seg_paths = self.get_seg_files(self.paths, os.path.join(self.root, 'seg'), seg_ext='.eseg')
18 | self.sseg_paths = self.get_seg_files(self.paths, os.path.join(self.root, 'sseg'), seg_ext='.seseg')
19 | self.classes, self.offset = self.get_n_segs(os.path.join(self.root, 'classes.txt'), self.seg_paths)
20 | self.nclasses = len(self.classes)
21 | self.size = len(self.paths)
22 | self.get_mean_std()
23 |         # modify for network later.
24 | opt.nclasses = self.nclasses
25 | opt.input_nc = self.ninput_channels
26 |
27 | def __getitem__(self, index):
28 | path = self.paths[index]
29 | mesh = Mesh(file=path, opt=self.opt, hold_history=True, export_folder=self.opt.export_folder)
30 | meta = {}
31 | meta['mesh'] = mesh
32 | label = read_seg(self.seg_paths[index]) - self.offset
33 | label = pad(label, self.opt.ninput_edges, val=-1, dim=0)
34 | meta['label'] = label
35 | soft_label = read_sseg(self.sseg_paths[index])
36 | meta['soft_label'] = pad(soft_label, self.opt.ninput_edges, val=-1, dim=0)
37 | # get edge features
38 | edge_features = mesh.extract_features()
39 | edge_features = pad(edge_features, self.opt.ninput_edges)
40 | meta['edge_features'] = (edge_features - self.mean) / self.std
41 | return meta
42 |
43 | def __len__(self):
44 | return self.size
45 |
46 | @staticmethod
47 | def get_seg_files(paths, seg_dir, seg_ext='.seg'):
48 | segs = []
49 | for path in paths:
50 | segfile = os.path.join(seg_dir, os.path.splitext(os.path.basename(path))[0] + seg_ext)
51 | assert(os.path.isfile(segfile))
52 | segs.append(segfile)
53 | return segs
54 |
55 | @staticmethod
56 | def get_n_segs(classes_file, seg_files):
57 | if not os.path.isfile(classes_file):
58 | all_segs = np.array([], dtype='float64')
59 | for seg in seg_files:
60 | all_segs = np.concatenate((all_segs, read_seg(seg)))
61 | segnames = np.unique(all_segs)
62 | np.savetxt(classes_file, segnames, fmt='%d')
63 | classes = np.loadtxt(classes_file)
64 | offset = classes[0]
65 | classes = classes - offset
66 | return classes, offset
67 |
68 | @staticmethod
69 | def make_dataset(path):
70 | meshes = []
71 | assert os.path.isdir(path), '%s is not a valid directory' % path
72 |
73 | for root, _, fnames in sorted(os.walk(path)):
74 | for fname in fnames:
75 | if is_mesh_file(fname):
76 | path = os.path.join(root, fname)
77 | meshes.append(path)
78 |
79 | return meshes
80 |
81 |
82 | def read_seg(seg):
83 | seg_labels = np.loadtxt(open(seg, 'r'), dtype='float64')
84 | return seg_labels
85 |
86 |
87 | def read_sseg(sseg_file):
88 | sseg_labels = read_seg(sseg_file)
89 | sseg_labels = np.array(sseg_labels > 0, dtype=np.int32)
90 | return sseg_labels
--------------------------------------------------------------------------------
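A small self-contained illustration of the two readers above: a `.eseg` file holds one hard label per edge (one per line), while a `.seseg` file holds a row of soft per-class values per edge, which `read_sseg` binarizes. The file contents here are made up for the demo:

```python
from data.segmentation_data import read_seg, read_sseg

with open('/tmp/demo.eseg', 'w') as f:
    f.write('1\n1\n2\n3\n')
with open('/tmp/demo.seseg', 'w') as f:
    f.write('0.9 0.1 0.0\n0.0 0.7 0.3\n')

print(read_seg('/tmp/demo.eseg'))    # [1. 1. 2. 3.]
print(read_sseg('/tmp/demo.seseg'))  # [[1 1 0] [0 1 1]]
```
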
/docs/imgs/T18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/T18.png
--------------------------------------------------------------------------------
/docs/imgs/T252.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/T252.png
--------------------------------------------------------------------------------
/docs/imgs/T76.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/T76.png
--------------------------------------------------------------------------------
/docs/imgs/alien.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/alien.gif
--------------------------------------------------------------------------------
/docs/imgs/coseg_alien.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/coseg_alien.png
--------------------------------------------------------------------------------
/docs/imgs/coseg_chair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/coseg_chair.png
--------------------------------------------------------------------------------
/docs/imgs/coseg_vase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/coseg_vase.png
--------------------------------------------------------------------------------
/docs/imgs/cubes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/cubes.png
--------------------------------------------------------------------------------
/docs/imgs/cubes2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/cubes2.png
--------------------------------------------------------------------------------
/docs/imgs/input_edge_features.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/input_edge_features.png
--------------------------------------------------------------------------------
/docs/imgs/mesh_conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/mesh_conv.png
--------------------------------------------------------------------------------
/docs/imgs/mesh_pool_unpool.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/mesh_pool_unpool.png
--------------------------------------------------------------------------------
/docs/imgs/meshcnn_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/meshcnn_overview.png
--------------------------------------------------------------------------------
/docs/imgs/shrec16_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec16_train.png
--------------------------------------------------------------------------------
/docs/imgs/shrec__10_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec__10_0.png
--------------------------------------------------------------------------------
/docs/imgs/shrec__14_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec__14_0.png
--------------------------------------------------------------------------------
/docs/imgs/shrec__2_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/docs/imgs/shrec__2_0.png
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
(Text content of the project page; the HTML markup was lost in extraction.)

Siggraph 2019
MeshCNN: A Network with an Edge

Abstract
Polygonal meshes provide an efficient representation for 3D shapes. They explicitly capture both shape surface and topology, and leverage non-uniformity to represent large flat regions as well as sharp, intricate features. This non-uniformity and irregularity, however, inhibits mesh analysis efforts using neural networks that combine convolution and pooling operations. In this paper, we utilize the unique properties of the mesh for a direct analysis of 3D shapes using MeshCNN, a convolutional neural network designed specifically for triangular meshes. Analogous to classic CNNs, MeshCNN combines specialized convolution and pooling layers that operate on the mesh edges, by leveraging their intrinsic geodesic connections. Convolutions are applied on edges and the four edges of their incident triangles, and pooling is applied via an edge collapse operation that retains surface topology, thereby generating new mesh connectivity for the subsequent convolutions. MeshCNN learns which edges to collapse, thus forming a task-driven process where the network exposes and expands the important features while discarding the redundant ones. We demonstrate the effectiveness of our task-driven pooling on various learning tasks applied to 3D meshes.

Video
(embedded video)

The Layers of MeshCNN
In MeshCNN the edges of a mesh are analogous to pixels in an image, since they are the basic building blocks for all CNN operations. Just as images start with a basic input feature (an RGB value per pixel), MeshCNN starts with a few basic geometric features per edge. The input edge feature is a 5-dimensional vector for every edge: the dihedral angle, two inner angles and two edge-length ratios for each face.

Input Edge Features

MeshCNN learns features on the edges of the mesh, since every edge is incident to exactly two faces (triangles), which defines a natural fixed-sized convolutional neighborhood of four edges.

Mesh Convolution

Learned convolutional filters are applied on each edge feature vector and its 4 one-ring neighbors. The consistent face normal order is used to apply a symmetric convolution operation, which learns edge features that are invariant to rotations, translations and uniform scale. Mesh pooling downsamples the number of features in the network by performing an edge collapse on the learned edge features. The new edge neighbors are computed dynamically inside the network and used in the next convolutions.

Mesh Pooling & Unpooling

For fully-convolutional tasks (such as segmentation), a mesh unpooling operation is used to restore the original mesh resolution.

Results
Learned Simplifications on Cube Dataset
Learned Simplifications on Shrec Dataset
Human Segmentation Results
Coseg Segmentation Results

Download Datasets

Contact
Rana at Hanocka dot com
--------------------------------------------------------------------------------
/docs/mainpage.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Lato', sans-serif;
3 | font-weight: 300;
4 | color: #333;
5 | font-size: 16px;
6 | }
7 | h1 {
8 | font-size: 40px;
9 | color: #555;
10 | font-weight: 400;
11 | text-align: center;
12 | margin: 0;
13 | padding: 0;
14 | margin-top: 30px;
15 | margin-bottom: 10px;
16 | }
17 | .authors {
18 | color: #222;
19 | font-size: 24px;
20 | font-weight: 300;
21 | text-align: center;
22 | margin: 0;
23 | padding: 0;
24 | margin-bottom: 0px;
25 | }
26 | .logoimg {
27 | text-align: center;
28 | margin-bottom: 30px;
29 | }
30 | .container-fluid {
31 | margin-top: 5px;
32 | margin-bottom: 5px;
33 | }
34 | .container {
35 | margin-top: 10px;
36 | }
37 | #footer {
38 | margin-bottom: 100px;
39 | }
40 | .thumbs {
41 | -webkit-box-shadow: 1px 1px 3px #999;
42 | -moz-box-shadow: 1px 1px 3px #999;
43 | box-shadow: 1px 1px 3px #999;
44 | margin-bottom: 20px;
45 | }
46 | h2 {
47 | font-size: 24px;
48 | font-weight: 900;
49 | border-bottom: 1px solid #999;
50 | margin-bottom: 20px;
51 | }
52 |
53 |
54 | .text-primary {
55 | color: #5da2d5 !important;
56 | }
57 | .text-primary:hover {
58 | color: #f3d250 !important;
59 | opacity: 1.0;
60 | }
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: meshcnn
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.6.8
7 | - cython=0.27.3
8 | - pytorch=1.2.0
9 | - numpy=1.15.0
10 | - matplotlib=3.0.3
11 | - pip
12 | - pip:
13 | - git+https://github.com/lanpa/tensorboardX.git
14 | - pytest==5.1.1
15 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | def create_model(opt):
2 | from .mesh_classifier import ClassifierModel # todo - get rid of this ?
3 | model = ClassifierModel(opt)
4 | return model
5 |
--------------------------------------------------------------------------------
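A hypothetical sketch of how this factory is consumed (`opt` is assumed to be parsed options; the per-batch method names are assumed from the CycleGAN-style design credited in the README):

```python
from models import create_model

model = create_model(opt)  # ClassifierModel covers classification and segmentation
# per-batch training step (method names assumed):
#   model.set_input(data)
#   model.optimize_parameters()
```
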
/models/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/models/layers/__init__.py
--------------------------------------------------------------------------------
/models/layers/mesh.py:
--------------------------------------------------------------------------------
1 | from tempfile import mkstemp
2 | from shutil import move
3 | import torch
4 | import numpy as np
5 | import os
6 | from models.layers.mesh_union import MeshUnion
7 | from models.layers.mesh_prepare import fill_mesh
8 |
9 |
10 | class Mesh:
11 |
12 | def __init__(self, file=None, opt=None, hold_history=False, export_folder=''):
13 | self.vs = self.v_mask = self.filename = self.features = self.edge_areas = None
14 | self.edges = self.gemm_edges = self.sides = None
15 | self.pool_count = 0
16 | fill_mesh(self, file, opt)
17 | self.export_folder = export_folder
18 | self.history_data = None
19 | if hold_history:
20 | self.init_history()
21 | self.export()
22 |
23 | def extract_features(self):
24 | return self.features
25 |
26 | def merge_vertices(self, edge_id):
27 | self.remove_edge(edge_id)
28 | edge = self.edges[edge_id]
29 | v_a = self.vs[edge[0]]
30 | v_b = self.vs[edge[1]]
31 |         # average the two endpoints in place (v_a is a view into self.vs)
32 |         v_a += v_b
33 |         v_a /= 2
34 | self.v_mask[edge[1]] = False
35 | mask = self.edges == edge[1]
36 | self.ve[edge[0]].extend(self.ve[edge[1]])
37 | self.edges[mask] = edge[0]
38 |
39 | def remove_vertex(self, v):
40 | self.v_mask[v] = False
41 |
42 | def remove_edge(self, edge_id):
43 | vs = self.edges[edge_id]
44 | for v in vs:
45 | if edge_id not in self.ve[v]:
46 | print(self.ve[v])
47 | print(self.filename)
48 | self.ve[v].remove(edge_id)
49 |
50 | def clean(self, edges_mask, groups):
51 | edges_mask = edges_mask.astype(bool)
52 | torch_mask = torch.from_numpy(edges_mask.copy())
53 | self.gemm_edges = self.gemm_edges[edges_mask]
54 | self.edges = self.edges[edges_mask]
55 | self.sides = self.sides[edges_mask]
56 | new_ve = []
57 | edges_mask = np.concatenate([edges_mask, [False]])
58 | new_indices = np.zeros(edges_mask.shape[0], dtype=np.int32)
59 | new_indices[-1] = -1
60 | new_indices[edges_mask] = np.arange(0, np.ma.where(edges_mask)[0].shape[0])
61 | self.gemm_edges[:, :] = new_indices[self.gemm_edges[:, :]]
62 | for v_index, ve in enumerate(self.ve):
63 | update_ve = []
64 | # if self.v_mask[v_index]:
65 | for e in ve:
66 | update_ve.append(new_indices[e])
67 | new_ve.append(update_ve)
68 | self.ve = new_ve
69 | self.__clean_history(groups, torch_mask)
70 | self.pool_count += 1
71 | self.export()
72 |
73 |
74 | def export(self, file=None, vcolor=None):
75 | if file is None:
76 | if self.export_folder:
77 | filename, file_extension = os.path.splitext(self.filename)
78 | file = '%s/%s_%d%s' % (self.export_folder, filename, self.pool_count, file_extension)
79 | else:
80 | return
81 | faces = []
82 | vs = self.vs[self.v_mask]
83 | gemm = np.array(self.gemm_edges)
84 | new_indices = np.zeros(self.v_mask.shape[0], dtype=np.int32)
85 | new_indices[self.v_mask] = np.arange(0, np.ma.where(self.v_mask)[0].shape[0])
86 | for edge_index in range(len(gemm)):
87 | cycles = self.__get_cycle(gemm, edge_index)
88 | for cycle in cycles:
89 | faces.append(self.__cycle_to_face(cycle, new_indices))
90 | with open(file, 'w+') as f:
91 | for vi, v in enumerate(vs):
92 | vcol = ' %f %f %f' % (vcolor[vi, 0], vcolor[vi, 1], vcolor[vi, 2]) if vcolor is not None else ''
93 | f.write("v %f %f %f%s\n" % (v[0], v[1], v[2], vcol))
94 | for face_id in range(len(faces) - 1):
95 | f.write("f %d %d %d\n" % (faces[face_id][0] + 1, faces[face_id][1] + 1, faces[face_id][2] + 1))
96 | f.write("f %d %d %d" % (faces[-1][0] + 1, faces[-1][1] + 1, faces[-1][2] + 1))
97 | for edge in self.edges:
98 | f.write("\ne %d %d" % (new_indices[edge[0]] + 1, new_indices[edge[1]] + 1))
99 |
100 | def export_segments(self, segments):
101 | if not self.export_folder:
102 | return
103 | cur_segments = segments
104 | for i in range(self.pool_count + 1):
105 | filename, file_extension = os.path.splitext(self.filename)
106 | file = '%s/%s_%d%s' % (self.export_folder, filename, i, file_extension)
107 | fh, abs_path = mkstemp()
108 | edge_key = 0
109 | with os.fdopen(fh, 'w') as new_file:
110 | with open(file) as old_file:
111 | for line in old_file:
112 | if line[0] == 'e':
113 | new_file.write('%s %d' % (line.strip(), cur_segments[edge_key]))
114 | if edge_key < len(cur_segments):
115 | edge_key += 1
116 | new_file.write('\n')
117 | else:
118 | new_file.write(line)
119 | os.remove(file)
120 | move(abs_path, file)
121 | if i < len(self.history_data['edges_mask']):
122 | cur_segments = segments[:len(self.history_data['edges_mask'][i])]
123 | cur_segments = cur_segments[self.history_data['edges_mask'][i]]
124 |
125 | def __get_cycle(self, gemm, edge_id):
126 | cycles = []
127 | for j in range(2):
128 | next_side = start_point = j * 2
129 | next_key = edge_id
130 | if gemm[edge_id, start_point] == -1:
131 | continue
132 | cycles.append([])
133 | for i in range(3):
134 | tmp_next_key = gemm[next_key, next_side]
135 | tmp_next_side = self.sides[next_key, next_side]
136 | tmp_next_side = tmp_next_side + 1 - 2 * (tmp_next_side % 2)
137 | gemm[next_key, next_side] = -1
138 | gemm[next_key, next_side + 1 - 2 * (next_side % 2)] = -1
139 | next_key = tmp_next_key
140 | next_side = tmp_next_side
141 | cycles[-1].append(next_key)
142 | return cycles
143 |
144 | def __cycle_to_face(self, cycle, v_indices):
145 | face = []
146 | for i in range(3):
147 | v = list(set(self.edges[cycle[i]]) & set(self.edges[cycle[(i + 1) % 3]]))[0]
148 | face.append(v_indices[v])
149 | return face
150 |
151 | def init_history(self):
152 | self.history_data = {
153 | 'groups': [],
154 | 'gemm_edges': [self.gemm_edges.copy()],
155 | 'occurrences': [],
156 | 'old2current': np.arange(self.edges_count, dtype=np.int32),
157 | 'current2old': np.arange(self.edges_count, dtype=np.int32),
158 | 'edges_mask': [torch.ones(self.edges_count,dtype=torch.bool)],
159 | 'edges_count': [self.edges_count],
160 | }
161 | if self.export_folder:
162 | self.history_data['collapses'] = MeshUnion(self.edges_count)
163 |
164 | def union_groups(self, source, target):
165 | if self.export_folder and self.history_data:
166 | self.history_data['collapses'].union(self.history_data['current2old'][source], self.history_data['current2old'][target])
167 | return
168 |
169 | def remove_group(self, index):
170 | if self.history_data is not None:
171 | self.history_data['edges_mask'][-1][self.history_data['current2old'][index]] = 0
172 | self.history_data['old2current'][self.history_data['current2old'][index]] = -1
173 | if self.export_folder:
174 | self.history_data['collapses'].remove_group(self.history_data['current2old'][index])
175 |
176 | def get_groups(self):
177 | return self.history_data['groups'].pop()
178 |
179 | def get_occurrences(self):
180 | return self.history_data['occurrences'].pop()
181 |
182 | def __clean_history(self, groups, pool_mask):
183 | if self.history_data is not None:
184 | mask = self.history_data['old2current'] != -1
185 | self.history_data['old2current'][mask] = np.arange(self.edges_count, dtype=np.int32)
186 | self.history_data['current2old'][0: self.edges_count] = np.ma.where(mask)[0]
187 | if self.export_folder != '':
188 | self.history_data['edges_mask'].append(self.history_data['edges_mask'][-1].clone())
189 | self.history_data['occurrences'].append(groups.get_occurrences())
190 | self.history_data['groups'].append(groups.get_groups(pool_mask))
191 | self.history_data['gemm_edges'].append(self.gemm_edges.copy())
192 | self.history_data['edges_count'].append(self.edges_count)
193 |
194 | def unroll_gemm(self):
195 | self.history_data['gemm_edges'].pop()
196 | self.gemm_edges = self.history_data['gemm_edges'][-1]
197 | self.history_data['edges_count'].pop()
198 | self.edges_count = self.history_data['edges_count'][-1]
199 |
200 | def get_edge_areas(self):
201 | return self.edge_areas
202 |
--------------------------------------------------------------------------------
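A hypothetical loading sketch for the `Mesh` class above; the path is illustrative, and `opt` must carry the fields `fill_mesh` reads (e.g. `num_aug` and the augmentation flags in `mesh_prepare`):

```python
from models.layers.mesh import Mesh

mesh = Mesh(file='datasets/shrec_16/alien/train/T1.obj', opt=opt,
            hold_history=False, export_folder='')
features = mesh.extract_features()  # numpy array, 5 x edges_count
print(mesh.edges_count, features.shape)
```
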
/models/layers/mesh_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class MeshConv(nn.Module):
6 | """ Computes convolution between edges and 4 incident (1-ring) edge neighbors
7 |     the forward pass takes:
8 | x: edge features (Batch x Features x Edges)
9 | mesh: list of mesh data-structure (len(mesh) == Batch)
10 | and applies convolution
11 | """
12 | def __init__(self, in_channels, out_channels, k=5, bias=True):
13 | super(MeshConv, self).__init__()
14 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, k), bias=bias)
15 | self.k = k
16 |
17 | def __call__(self, edge_f, mesh):
18 | return self.forward(edge_f, mesh)
19 |
20 | def forward(self, x, mesh):
21 | x = x.squeeze(-1)
22 | G = torch.cat([self.pad_gemm(i, x.shape[2], x.device) for i in mesh], 0)
23 | # build 'neighborhood image' and apply convolution
24 | G = self.create_GeMM(x, G)
25 | x = self.conv(G)
26 | return x
27 |
28 | def flatten_gemm_inds(self, Gi):
29 | (b, ne, nn) = Gi.shape
30 | ne += 1
31 | batch_n = torch.floor(torch.arange(b * ne, device=Gi.device).float() / ne).view(b, ne)
32 | add_fac = batch_n * ne
33 | add_fac = add_fac.view(b, ne, 1)
34 | add_fac = add_fac.repeat(1, 1, nn)
35 | # flatten Gi
36 | Gi = Gi.float() + add_fac[:, 1:, :]
37 | return Gi
38 |
39 | def create_GeMM(self, x, Gi):
40 | """ gathers the edge features (x) with from the 1-ring indices (Gi)
41 | applys symmetric functions to handle order invariance
42 | returns a 'fake image' which can use 2d convolution on
43 | output dimensions: Batch x Channels x Edges x 5
44 | """
45 | Gishape = Gi.shape
46 | # pad the first row of every sample in batch with zeros
47 | padding = torch.zeros((x.shape[0], x.shape[1], 1), requires_grad=True, device=x.device)
48 | # padding = padding.to(x.device)
49 | x = torch.cat((padding, x), dim=2)
50 | Gi = Gi + 1 #shift
51 |
52 | # first flatten indices
53 | Gi_flat = self.flatten_gemm_inds(Gi)
54 | Gi_flat = Gi_flat.view(-1).long()
55 | #
56 | odim = x.shape
57 | x = x.permute(0, 2, 1).contiguous()
58 | x = x.view(odim[0] * odim[2], odim[1])
59 |
60 | f = torch.index_select(x, dim=0, index=Gi_flat)
61 | f = f.view(Gishape[0], Gishape[1], Gishape[2], -1)
62 | f = f.permute(0, 3, 1, 2)
63 |
64 | # apply the symmetric functions for an equivariant conv
65 | x_1 = f[:, :, :, 1] + f[:, :, :, 3]
66 | x_2 = f[:, :, :, 2] + f[:, :, :, 4]
67 | x_3 = torch.abs(f[:, :, :, 1] - f[:, :, :, 3])
68 | x_4 = torch.abs(f[:, :, :, 2] - f[:, :, :, 4])
69 | f = torch.stack([f[:, :, :, 0], x_1, x_2, x_3, x_4], dim=3)
70 | return f
71 |
72 | def pad_gemm(self, m, xsz, device):
73 | """ extracts one-ring neighbors (4x) -> m.gemm_edges
74 | which is of size #edges x 4
75 | add the edge_id itself to make #edges x 5
76 | then pad to desired size e.g., xsz x 5
77 | """
78 | padded_gemm = torch.tensor(m.gemm_edges, device=device).float()
79 | padded_gemm = padded_gemm.requires_grad_()
80 | padded_gemm = torch.cat((torch.arange(m.edges_count, device=device).float().unsqueeze(1), padded_gemm), dim=1)
81 | # pad using F
82 | padded_gemm = F.pad(padded_gemm, (0, 0, 0, xsz - m.edges_count), "constant", 0)
83 | padded_gemm = padded_gemm.unsqueeze(0)
84 | return padded_gemm
85 |
--------------------------------------------------------------------------------
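The four symmetric ops at the end of `create_GeMM` are what make the convolution invariant to the two valid traversal orders of the one-ring, (a, b, c, d) vs. (c, d, a, b). A quick standalone check of that property:

```python
import torch

def symmetric_feats(e, a, b, c, d):
    # mirrors the stack at the end of create_GeMM
    return torch.stack([e, a + c, b + d, (a - c).abs(), (b - d).abs()])

e, a, b, c, d = torch.randn(5)
assert torch.allclose(symmetric_feats(e, a, b, c, d),
                      symmetric_feats(e, c, d, a, b))
```
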
/models/layers/mesh_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from threading import Thread
4 | from models.layers.mesh_union import MeshUnion
5 | import numpy as np
6 | from heapq import heappop, heapify
7 |
8 |
9 | class MeshPool(nn.Module):
10 |
11 | def __init__(self, target, multi_thread=False):
12 | super(MeshPool, self).__init__()
13 | self.__out_target = target
14 | self.__multi_thread = multi_thread
15 | self.__fe = None
16 | self.__updated_fe = None
17 | self.__meshes = None
18 | self.__merge_edges = [-1, -1]
19 |
20 | def __call__(self, fe, meshes):
21 | return self.forward(fe, meshes)
22 |
23 | def forward(self, fe, meshes):
24 | self.__updated_fe = [[] for _ in range(len(meshes))]
25 | pool_threads = []
26 | self.__fe = fe
27 | self.__meshes = meshes
28 | # iterate over batch
29 | for mesh_index in range(len(meshes)):
30 | if self.__multi_thread:
31 | pool_threads.append(Thread(target=self.__pool_main, args=(mesh_index,)))
32 | pool_threads[-1].start()
33 | else:
34 | self.__pool_main(mesh_index)
35 | if self.__multi_thread:
36 | for mesh_index in range(len(meshes)):
37 | pool_threads[mesh_index].join()
38 | out_features = torch.cat(self.__updated_fe).view(len(meshes), -1, self.__out_target)
39 | return out_features
40 |
41 | def __pool_main(self, mesh_index):
42 | mesh = self.__meshes[mesh_index]
43 | queue = self.__build_queue(self.__fe[mesh_index, :, :mesh.edges_count], mesh.edges_count)
44 | # recycle = []
45 | # last_queue_len = len(queue)
46 | last_count = mesh.edges_count + 1
47 |         mask = np.ones(mesh.edges_count, dtype=bool)
48 | edge_groups = MeshUnion(mesh.edges_count, self.__fe.device)
49 | while mesh.edges_count > self.__out_target:
50 | value, edge_id = heappop(queue)
51 | edge_id = int(edge_id)
52 | if mask[edge_id]:
53 | self.__pool_edge(mesh, edge_id, mask, edge_groups)
54 | mesh.clean(mask, edge_groups)
55 | fe = edge_groups.rebuild_features(self.__fe[mesh_index], mask, self.__out_target)
56 | self.__updated_fe[mesh_index] = fe
57 |
58 | def __pool_edge(self, mesh, edge_id, mask, edge_groups):
59 | if self.has_boundaries(mesh, edge_id):
60 | return False
61 | elif self.__clean_side(mesh, edge_id, mask, edge_groups, 0)\
62 | and self.__clean_side(mesh, edge_id, mask, edge_groups, 2) \
63 | and self.__is_one_ring_valid(mesh, edge_id):
64 | self.__merge_edges[0] = self.__pool_side(mesh, edge_id, mask, edge_groups, 0)
65 | self.__merge_edges[1] = self.__pool_side(mesh, edge_id, mask, edge_groups, 2)
66 | mesh.merge_vertices(edge_id)
67 | mask[edge_id] = False
68 | MeshPool.__remove_group(mesh, edge_groups, edge_id)
69 | mesh.edges_count -= 1
70 | return True
71 | else:
72 | return False
73 |
74 | def __clean_side(self, mesh, edge_id, mask, edge_groups, side):
75 | if mesh.edges_count <= self.__out_target:
76 | return False
77 | invalid_edges = MeshPool.__get_invalids(mesh, edge_id, edge_groups, side)
78 | while len(invalid_edges) != 0 and mesh.edges_count > self.__out_target:
79 |             self.__remove_triplet(mesh, mask, edge_groups, invalid_edges)
80 | if mesh.edges_count <= self.__out_target:
81 | return False
82 | if self.has_boundaries(mesh, edge_id):
83 | return False
84 | invalid_edges = self.__get_invalids(mesh, edge_id, edge_groups, side)
85 | return True
86 |
87 | @staticmethod
88 | def has_boundaries(mesh, edge_id):
89 | for edge in mesh.gemm_edges[edge_id]:
90 | if edge == -1 or -1 in mesh.gemm_edges[edge]:
91 | return True
92 | return False
93 |
94 |
95 | @staticmethod
96 | def __is_one_ring_valid(mesh, edge_id):
97 | v_a = set(mesh.edges[mesh.ve[mesh.edges[edge_id, 0]]].reshape(-1))
98 | v_b = set(mesh.edges[mesh.ve[mesh.edges[edge_id, 1]]].reshape(-1))
99 | shared = v_a & v_b - set(mesh.edges[edge_id])
100 | return len(shared) == 2
101 |
102 | def __pool_side(self, mesh, edge_id, mask, edge_groups, side):
103 | info = MeshPool.__get_face_info(mesh, edge_id, side)
104 | key_a, key_b, side_a, side_b, _, other_side_b, _, other_keys_b = info
105 | self.__redirect_edges(mesh, key_a, side_a - side_a % 2, other_keys_b[0], mesh.sides[key_b, other_side_b])
106 | self.__redirect_edges(mesh, key_a, side_a - side_a % 2 + 1, other_keys_b[1], mesh.sides[key_b, other_side_b + 1])
107 | MeshPool.__union_groups(mesh, edge_groups, key_b, key_a)
108 | MeshPool.__union_groups(mesh, edge_groups, edge_id, key_a)
109 | mask[key_b] = False
110 | MeshPool.__remove_group(mesh, edge_groups, key_b)
111 | mesh.remove_edge(key_b)
112 | mesh.edges_count -= 1
113 | return key_a
114 |
115 | @staticmethod
116 | def __get_invalids(mesh, edge_id, edge_groups, side):
117 | info = MeshPool.__get_face_info(mesh, edge_id, side)
118 | key_a, key_b, side_a, side_b, other_side_a, other_side_b, other_keys_a, other_keys_b = info
119 | shared_items = MeshPool.__get_shared_items(other_keys_a, other_keys_b)
120 | if len(shared_items) == 0:
121 | return []
122 | else:
123 | assert (len(shared_items) == 2)
124 | middle_edge = other_keys_a[shared_items[0]]
125 | update_key_a = other_keys_a[1 - shared_items[0]]
126 | update_key_b = other_keys_b[1 - shared_items[1]]
127 | update_side_a = mesh.sides[key_a, other_side_a + 1 - shared_items[0]]
128 | update_side_b = mesh.sides[key_b, other_side_b + 1 - shared_items[1]]
129 | MeshPool.__redirect_edges(mesh, edge_id, side, update_key_a, update_side_a)
130 | MeshPool.__redirect_edges(mesh, edge_id, side + 1, update_key_b, update_side_b)
131 | MeshPool.__redirect_edges(mesh, update_key_a, MeshPool.__get_other_side(update_side_a), update_key_b, MeshPool.__get_other_side(update_side_b))
132 | MeshPool.__union_groups(mesh, edge_groups, key_a, edge_id)
133 | MeshPool.__union_groups(mesh, edge_groups, key_b, edge_id)
134 | MeshPool.__union_groups(mesh, edge_groups, key_a, update_key_a)
135 | MeshPool.__union_groups(mesh, edge_groups, middle_edge, update_key_a)
136 | MeshPool.__union_groups(mesh, edge_groups, key_b, update_key_b)
137 | MeshPool.__union_groups(mesh, edge_groups, middle_edge, update_key_b)
138 | return [key_a, key_b, middle_edge]
139 |
140 | @staticmethod
141 | def __redirect_edges(mesh, edge_a_key, side_a, edge_b_key, side_b):
142 | mesh.gemm_edges[edge_a_key, side_a] = edge_b_key
143 | mesh.gemm_edges[edge_b_key, side_b] = edge_a_key
144 | mesh.sides[edge_a_key, side_a] = side_b
145 | mesh.sides[edge_b_key, side_b] = side_a
146 |
147 | @staticmethod
148 | def __get_shared_items(list_a, list_b):
149 | shared_items = []
150 | for i in range(len(list_a)):
151 | for j in range(len(list_b)):
152 | if list_a[i] == list_b[j]:
153 | shared_items.extend([i, j])
154 | return shared_items
155 |
156 | @staticmethod
157 | def __get_other_side(side):
158 | return side + 1 - 2 * (side % 2)
159 |
160 | @staticmethod
161 | def __get_face_info(mesh, edge_id, side):
162 | key_a = mesh.gemm_edges[edge_id, side]
163 | key_b = mesh.gemm_edges[edge_id, side + 1]
164 | side_a = mesh.sides[edge_id, side]
165 | side_b = mesh.sides[edge_id, side + 1]
166 | other_side_a = (side_a - (side_a % 2) + 2) % 4
167 | other_side_b = (side_b - (side_b % 2) + 2) % 4
168 | other_keys_a = [mesh.gemm_edges[key_a, other_side_a], mesh.gemm_edges[key_a, other_side_a + 1]]
169 | other_keys_b = [mesh.gemm_edges[key_b, other_side_b], mesh.gemm_edges[key_b, other_side_b + 1]]
170 | return key_a, key_b, side_a, side_b, other_side_a, other_side_b, other_keys_a, other_keys_b
171 |
172 | @staticmethod
173 |     def __remove_triplet(mesh, mask, edge_groups, invalid_edges):
174 | vertex = set(mesh.edges[invalid_edges[0]])
175 | for edge_key in invalid_edges:
176 | vertex &= set(mesh.edges[edge_key])
177 | mask[edge_key] = False
178 | MeshPool.__remove_group(mesh, edge_groups, edge_key)
179 | mesh.edges_count -= 3
180 | vertex = list(vertex)
181 | assert(len(vertex) == 1)
182 | mesh.remove_vertex(vertex[0])
183 |
184 | def __build_queue(self, features, edges_count):
185 | # delete edges with smallest norm
186 | squared_magnitude = torch.sum(features * features, 0)
187 | if squared_magnitude.shape[-1] != 1:
188 | squared_magnitude = squared_magnitude.unsqueeze(-1)
189 | edge_ids = torch.arange(edges_count, device=squared_magnitude.device, dtype=torch.float32).unsqueeze(-1)
190 | heap = torch.cat((squared_magnitude, edge_ids), dim=-1).tolist()
191 | heapify(heap)
192 | return heap
193 |
194 | @staticmethod
195 | def __union_groups(mesh, edge_groups, source, target):
196 | edge_groups.union(source, target)
197 | mesh.union_groups(source, target)
198 |
199 | @staticmethod
200 | def __remove_group(mesh, edge_groups, index):
201 | edge_groups.remove_group(index)
202 | mesh.remove_group(index)
203 |
204 |
--------------------------------------------------------------------------------
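A rough sketch of how pooling composes with convolution in a network; shapes follow the code above, while `edge_features` and `meshes` (the padded feature tensor and the batch's Mesh list) are assumed to come from the data loader:

```python
from models.layers.mesh_conv import MeshConv
from models.layers.mesh_pool import MeshPool

conv = MeshConv(in_channels=5, out_channels=16)
pool = MeshPool(target=600)  # collapse weakest-norm edges until 600 remain
# fe = conv(edge_features, meshes).squeeze(-1)  # (batch, 16, edges)
# fe = pool(fe, meshes)                         # (batch, 16, 600); meshes shrink in place
```
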
/models/layers/mesh_prepare.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 |
5 |
6 | def fill_mesh(mesh2fill, file: str, opt):
7 | load_path = get_mesh_path(file, opt.num_aug)
8 | if os.path.exists(load_path):
9 | mesh_data = np.load(load_path, encoding='latin1', allow_pickle=True)
10 | else:
11 | mesh_data = from_scratch(file, opt)
12 | np.savez_compressed(load_path, gemm_edges=mesh_data.gemm_edges, vs=mesh_data.vs, edges=mesh_data.edges,
13 | edges_count=mesh_data.edges_count, ve=mesh_data.ve, v_mask=mesh_data.v_mask,
14 | filename=mesh_data.filename, sides=mesh_data.sides,
15 | edge_lengths=mesh_data.edge_lengths, edge_areas=mesh_data.edge_areas,
16 | features=mesh_data.features)
17 | mesh2fill.vs = mesh_data['vs']
18 | mesh2fill.edges = mesh_data['edges']
19 | mesh2fill.gemm_edges = mesh_data['gemm_edges']
20 | mesh2fill.edges_count = int(mesh_data['edges_count'])
21 | mesh2fill.ve = mesh_data['ve']
22 | mesh2fill.v_mask = mesh_data['v_mask']
23 | mesh2fill.filename = str(mesh_data['filename'])
24 | mesh2fill.edge_lengths = mesh_data['edge_lengths']
25 | mesh2fill.edge_areas = mesh_data['edge_areas']
26 | mesh2fill.features = mesh_data['features']
27 | mesh2fill.sides = mesh_data['sides']
28 |
29 | def get_mesh_path(file: str, num_aug: int):
30 | filename, _ = os.path.splitext(file)
31 | dir_name = os.path.dirname(filename)
32 | prefix = os.path.basename(filename)
33 | load_dir = os.path.join(dir_name, 'cache')
34 | load_file = os.path.join(load_dir, '%s_%03d.npz' % (prefix, np.random.randint(0, num_aug)))
35 | if not os.path.isdir(load_dir):
36 | os.makedirs(load_dir, exist_ok=True)
37 | return load_file
38 |
39 | def from_scratch(file, opt):
40 |
41 | class MeshPrep:
42 | def __getitem__(self, item):
43 |             return getattr(self, item)
44 |
45 | mesh_data = MeshPrep()
46 | mesh_data.vs = mesh_data.edges = None
47 | mesh_data.gemm_edges = mesh_data.sides = None
48 | mesh_data.edges_count = None
49 | mesh_data.ve = None
50 | mesh_data.v_mask = None
51 | mesh_data.filename = 'unknown'
52 | mesh_data.edge_lengths = None
53 | mesh_data.edge_areas = []
54 | mesh_data.vs, faces = fill_from_file(mesh_data, file)
55 | mesh_data.v_mask = np.ones(len(mesh_data.vs), dtype=bool)
56 | faces, face_areas = remove_non_manifolds(mesh_data, faces)
57 | if opt.num_aug > 1:
58 | faces = augmentation(mesh_data, opt, faces)
59 | build_gemm(mesh_data, faces, face_areas)
60 | if opt.num_aug > 1:
61 | post_augmentation(mesh_data, opt)
62 | mesh_data.features = extract_features(mesh_data)
63 | return mesh_data
64 |
65 | def fill_from_file(mesh, file):
66 | mesh.filename = ntpath.split(file)[1]
67 | mesh.fullfilename = file
68 | vs, faces = [], []
69 |     # context manager closes the file even if parsing fails
70 |     with open(file) as f:
71 |         for line in f:
72 |             splitted_line = line.strip().split()
73 |             if not splitted_line:
74 |                 continue
75 |             elif splitted_line[0] == 'v':
76 |                 vs.append([float(v) for v in splitted_line[1:4]])
77 |             elif splitted_line[0] == 'f':
78 |                 face_vertex_ids = [int(c.split('/')[0]) for c in splitted_line[1:]]
79 |                 assert len(face_vertex_ids) == 3
80 |                 face_vertex_ids = [(ind - 1) if (ind >= 0) else (len(vs) + ind)
81 |                                    for ind in face_vertex_ids]
82 |                 faces.append(face_vertex_ids)
83 |
84 | vs = np.asarray(vs)
85 | faces = np.asarray(faces, dtype=int)
86 | assert np.logical_and(faces >= 0, faces < len(vs)).all()
87 | return vs, faces
88 |
89 |
90 | def remove_non_manifolds(mesh, faces):
91 | mesh.ve = [[] for _ in mesh.vs]
92 | edges_set = set()
93 | mask = np.ones(len(faces), dtype=bool)
94 | _, face_areas = compute_face_normals_and_areas(mesh, faces)
95 | for face_id, face in enumerate(faces):
96 | if face_areas[face_id] == 0:
97 | mask[face_id] = False
98 | continue
99 | faces_edges = []
100 |         is_non_manifold = False  # flags a face whose directed edge was already seen
101 |         for i in range(3):
102 |             cur_edge = (face[i], face[(i + 1) % 3])
103 |             if cur_edge in edges_set:
104 |                 is_non_manifold = True
105 |                 break
106 |             else:
107 |                 faces_edges.append(cur_edge)
108 |         if is_non_manifold:
109 |             mask[face_id] = False
110 |         else:
111 |             for edge in faces_edges:
112 |                 edges_set.add(edge)
113 | return faces[mask], face_areas[mask]
114 |
115 |
116 | def build_gemm(mesh, faces, face_areas):
117 | """
118 | gemm_edges: array (#E x 4) of the 4 one-ring neighbors for each edge
119 | sides: array (#E x 4) indices (values of: 0,1,2,3) indicating where an edge is in the gemm_edge entry of the 4 neighboring edges
120 | for example edge i -> gemm_edges[gemm_edges[i], sides[i]] == [i, i, i, i]
121 | """
122 | mesh.ve = [[] for _ in mesh.vs]
123 | edge_nb = []
124 | sides = []
125 | edge2key = dict()
126 | edges = []
127 | edges_count = 0
128 | nb_count = []
129 | for face_id, face in enumerate(faces):
130 | faces_edges = []
131 | for i in range(3):
132 | cur_edge = (face[i], face[(i + 1) % 3])
133 | faces_edges.append(cur_edge)
134 | for idx, edge in enumerate(faces_edges):
135 | edge = tuple(sorted(list(edge)))
136 | faces_edges[idx] = edge
137 | if edge not in edge2key:
138 | edge2key[edge] = edges_count
139 | edges.append(list(edge))
140 | edge_nb.append([-1, -1, -1, -1])
141 | sides.append([-1, -1, -1, -1])
142 | mesh.ve[edge[0]].append(edges_count)
143 | mesh.ve[edge[1]].append(edges_count)
144 | mesh.edge_areas.append(0)
145 | nb_count.append(0)
146 | edges_count += 1
147 | mesh.edge_areas[edge2key[edge]] += face_areas[face_id] / 3
148 | for idx, edge in enumerate(faces_edges):
149 | edge_key = edge2key[edge]
150 | edge_nb[edge_key][nb_count[edge_key]] = edge2key[faces_edges[(idx + 1) % 3]]
151 | edge_nb[edge_key][nb_count[edge_key] + 1] = edge2key[faces_edges[(idx + 2) % 3]]
152 | nb_count[edge_key] += 2
153 | for idx, edge in enumerate(faces_edges):
154 | edge_key = edge2key[edge]
155 | sides[edge_key][nb_count[edge_key] - 2] = nb_count[edge2key[faces_edges[(idx + 1) % 3]]] - 1
156 | sides[edge_key][nb_count[edge_key] - 1] = nb_count[edge2key[faces_edges[(idx + 2) % 3]]] - 2
157 | mesh.edges = np.array(edges, dtype=np.int32)
158 | mesh.gemm_edges = np.array(edge_nb, dtype=np.int64)
159 | mesh.sides = np.array(sides, dtype=np.int64)
160 | mesh.edges_count = edges_count
161 |     mesh.edge_areas = np.array(mesh.edge_areas, dtype=np.float32) / np.sum(face_areas)  # todo: what's the difference between edge_areas and edge_lengths?
162 |
163 |
164 | def compute_face_normals_and_areas(mesh, faces):
165 | face_normals = np.cross(mesh.vs[faces[:, 1]] - mesh.vs[faces[:, 0]],
166 | mesh.vs[faces[:, 2]] - mesh.vs[faces[:, 1]])
167 | face_areas = np.sqrt((face_normals ** 2).sum(axis=1))
168 | face_normals /= face_areas[:, np.newaxis]
169 | assert (not np.any(face_areas[:, np.newaxis] == 0)), 'has zero area face: %s' % mesh.filename
170 | face_areas *= 0.5
171 | return face_normals, face_areas
172 |
173 |
174 | # Data augmentation methods
175 | def augmentation(mesh, opt, faces=None):
176 | if hasattr(opt, 'scale_verts') and opt.scale_verts:
177 | scale_verts(mesh)
178 | if hasattr(opt, 'flip_edges') and opt.flip_edges:
179 | faces = flip_edges(mesh, opt.flip_edges, faces)
180 | return faces
181 |
182 |
183 | def post_augmentation(mesh, opt):
184 | if hasattr(opt, 'slide_verts') and opt.slide_verts:
185 | slide_verts(mesh, opt.slide_verts)
186 |
187 |
188 | def slide_verts(mesh, prct):
189 | edge_points = get_edge_points(mesh)
190 | dihedral = dihedral_angle(mesh, edge_points).squeeze() #todo make fixed_division epsilon=0
191 | thr = np.mean(dihedral) + np.std(dihedral)
192 | vids = np.random.permutation(len(mesh.ve))
193 | target = int(prct * len(vids))
194 | shifted = 0
195 | for vi in vids:
196 | if shifted < target:
197 | edges = mesh.ve[vi]
198 | if min(dihedral[edges]) > 2.65:
199 | edge = mesh.edges[np.random.choice(edges)]
200 | vi_t = edge[1] if vi == edge[0] else edge[0]
201 | nv = mesh.vs[vi] + np.random.uniform(0.2, 0.5) * (mesh.vs[vi_t] - mesh.vs[vi])
202 | mesh.vs[vi] = nv
203 | shifted += 1
204 | else:
205 | break
206 | mesh.shifted = shifted / len(mesh.ve)
207 |
208 |
209 | def scale_verts(mesh, mean=1, var=0.1):
210 | for i in range(mesh.vs.shape[1]):
211 | mesh.vs[:, i] = mesh.vs[:, i] * np.random.normal(mean, var)
212 |
213 |
214 | def angles_from_faces(mesh, edge_faces, faces):
215 | normals = [None, None]
216 | for i in range(2):
217 | edge_a = mesh.vs[faces[edge_faces[:, i], 2]] - mesh.vs[faces[edge_faces[:, i], 1]]
218 | edge_b = mesh.vs[faces[edge_faces[:, i], 1]] - mesh.vs[faces[edge_faces[:, i], 0]]
219 | normals[i] = np.cross(edge_a, edge_b)
220 | div = fixed_division(np.linalg.norm(normals[i], ord=2, axis=1), epsilon=0)
221 | normals[i] /= div[:, np.newaxis]
222 | dot = np.sum(normals[0] * normals[1], axis=1).clip(-1, 1)
223 | angles = np.pi - np.arccos(dot)
224 | return angles
225 |
226 |
227 | def flip_edges(mesh, prct, faces):
228 | edge_count, edge_faces, edges_dict = get_edge_faces(faces)
229 | dihedral = angles_from_faces(mesh, edge_faces[:, 2:], faces)
230 | edges2flip = np.random.permutation(edge_count)
231 | # print(dihedral.min())
232 | # print(dihedral.max())
233 | target = int(prct * edge_count)
234 | flipped = 0
235 | for edge_key in edges2flip:
236 | if flipped == target:
237 | break
238 | if dihedral[edge_key] > 2.7:
239 | edge_info = edge_faces[edge_key]
240 | if edge_info[3] == -1:
241 | continue
242 | new_edge = tuple(sorted(list(set(faces[edge_info[2]]) ^ set(faces[edge_info[3]]))))
243 | if new_edge in edges_dict:
244 | continue
245 | new_faces = np.array(
246 | [[edge_info[1], new_edge[0], new_edge[1]], [edge_info[0], new_edge[0], new_edge[1]]])
247 | if check_area(mesh, new_faces):
248 | del edges_dict[(edge_info[0], edge_info[1])]
249 | edge_info[:2] = [new_edge[0], new_edge[1]]
250 | edges_dict[new_edge] = edge_key
251 | rebuild_face(faces[edge_info[2]], new_faces[0])
252 | rebuild_face(faces[edge_info[3]], new_faces[1])
253 | for i, face_id in enumerate([edge_info[2], edge_info[3]]):
254 | cur_face = faces[face_id]
255 | for j in range(3):
256 | cur_edge = tuple(sorted((cur_face[j], cur_face[(j + 1) % 3])))
257 | if cur_edge != new_edge:
258 | cur_edge_key = edges_dict[cur_edge]
259 | for idx, face_nb in enumerate(
260 | [edge_faces[cur_edge_key, 2], edge_faces[cur_edge_key, 3]]):
261 | if face_nb == edge_info[2 + (i + 1) % 2]:
262 | edge_faces[cur_edge_key, 2 + idx] = face_id
263 | flipped += 1
264 | # print(flipped)
265 | return faces
266 |
267 |
268 | def rebuild_face(face, new_face):
269 | new_point = list(set(new_face) - set(face))[0]
270 | for i in range(3):
271 | if face[i] not in new_face:
272 | face[i] = new_point
273 | break
274 | return face
275 |
276 | def check_area(mesh, faces):
277 | face_normals = np.cross(mesh.vs[faces[:, 1]] - mesh.vs[faces[:, 0]],
278 | mesh.vs[faces[:, 2]] - mesh.vs[faces[:, 1]])
279 | face_areas = np.sqrt((face_normals ** 2).sum(axis=1))
280 | face_areas *= 0.5
281 | return face_areas[0] > 0 and face_areas[1] > 0
282 |
283 |
284 | def get_edge_faces(faces):
285 | edge_count = 0
286 | edge_faces = []
287 | edge2keys = dict()
288 | for face_id, face in enumerate(faces):
289 | for i in range(3):
290 | cur_edge = tuple(sorted((face[i], face[(i + 1) % 3])))
291 | if cur_edge not in edge2keys:
292 | edge2keys[cur_edge] = edge_count
293 | edge_count += 1
294 | edge_faces.append(np.array([cur_edge[0], cur_edge[1], -1, -1]))
295 | edge_key = edge2keys[cur_edge]
296 | if edge_faces[edge_key][2] == -1:
297 | edge_faces[edge_key][2] = face_id
298 | else:
299 | edge_faces[edge_key][3] = face_id
300 | return edge_count, np.array(edge_faces), edge2keys
301 |
302 |
303 | def set_edge_lengths(mesh, edge_points=None):
304 | if edge_points is not None:
305 | edge_points = get_edge_points(mesh)
306 | edge_lengths = np.linalg.norm(mesh.vs[edge_points[:, 0]] - mesh.vs[edge_points[:, 1]], ord=2, axis=1)
307 | mesh.edge_lengths = edge_lengths
308 |
309 |
310 | def extract_features(mesh):
311 | features = []
312 | edge_points = get_edge_points(mesh)
313 | set_edge_lengths(mesh, edge_points)
314 | with np.errstate(divide='raise'):
315 | try:
316 | for extractor in [dihedral_angle, symmetric_opposite_angles, symmetric_ratios]:
317 | feature = extractor(mesh, edge_points)
318 | features.append(feature)
319 | return np.concatenate(features, axis=0)
320 | except Exception as e:
321 | print(e)
322 | raise ValueError(mesh.filename, 'bad features')
323 |
324 |
325 | def dihedral_angle(mesh, edge_points):
326 | normals_a = get_normals(mesh, edge_points, 0)
327 | normals_b = get_normals(mesh, edge_points, 3)
328 | dot = np.sum(normals_a * normals_b, axis=1).clip(-1, 1)
329 | angles = np.expand_dims(np.pi - np.arccos(dot), axis=0)
330 | return angles
331 |
332 |
333 | def symmetric_opposite_angles(mesh, edge_points):
334 | """ computes two angles: one for each face shared between the edge
335 | the angle is in each face opposite the edge
336 | sort handles order ambiguity
337 | """
338 | angles_a = get_opposite_angles(mesh, edge_points, 0)
339 | angles_b = get_opposite_angles(mesh, edge_points, 3)
340 | angles = np.concatenate((np.expand_dims(angles_a, 0), np.expand_dims(angles_b, 0)), axis=0)
341 | angles = np.sort(angles, axis=0)
342 | return angles
343 |
344 |
345 | def symmetric_ratios(mesh, edge_points):
346 | """ computes two ratios: one for each face shared between the edge
347 | the ratio is between the height / base (edge) of each triangle
348 | sort handles order ambiguity
349 | """
350 | ratios_a = get_ratios(mesh, edge_points, 0)
351 | ratios_b = get_ratios(mesh, edge_points, 3)
352 | ratios = np.concatenate((np.expand_dims(ratios_a, 0), np.expand_dims(ratios_b, 0)), axis=0)
353 | return np.sort(ratios, axis=0)
354 |
355 |
356 | def get_edge_points(mesh):
357 | """ returns: edge_points (#E x 4) tensor, with four vertex ids per edge
358 | for example: edge_points[edge_id, 0] and edge_points[edge_id, 1] are the two vertices which define edge_id
359 | each adjacent face to edge_id has another vertex, which is edge_points[edge_id, 2] or edge_points[edge_id, 3]
360 | """
361 | edge_points = np.zeros([mesh.edges_count, 4], dtype=np.int32)
362 | for edge_id, edge in enumerate(mesh.edges):
363 | edge_points[edge_id] = get_side_points(mesh, edge_id)
364 | # edge_points[edge_id, 3:] = mesh.get_side_points(edge_id, 2)
365 | return edge_points
366 |
367 |
368 | def get_side_points(mesh, edge_id):
369 | # if mesh.gemm_edges[edge_id, side] == -1:
370 | # return mesh.get_side_points(edge_id, ((side + 2) % 4))
371 | # else:
372 | edge_a = mesh.edges[edge_id]
373 |
374 | if mesh.gemm_edges[edge_id, 0] == -1:
375 | edge_b = mesh.edges[mesh.gemm_edges[edge_id, 2]]
376 | edge_c = mesh.edges[mesh.gemm_edges[edge_id, 3]]
377 | else:
378 | edge_b = mesh.edges[mesh.gemm_edges[edge_id, 0]]
379 | edge_c = mesh.edges[mesh.gemm_edges[edge_id, 1]]
380 | if mesh.gemm_edges[edge_id, 2] == -1:
381 | edge_d = mesh.edges[mesh.gemm_edges[edge_id, 0]]
382 | edge_e = mesh.edges[mesh.gemm_edges[edge_id, 1]]
383 | else:
384 | edge_d = mesh.edges[mesh.gemm_edges[edge_id, 2]]
385 | edge_e = mesh.edges[mesh.gemm_edges[edge_id, 3]]
386 | first_vertex = 0
387 | second_vertex = 0
388 | third_vertex = 0
389 | if edge_a[1] in edge_b:
390 | first_vertex = 1
391 | if edge_b[1] in edge_c:
392 | second_vertex = 1
393 | if edge_d[1] in edge_e:
394 | third_vertex = 1
395 | return [edge_a[first_vertex], edge_a[1 - first_vertex], edge_b[second_vertex], edge_d[third_vertex]]
396 |
397 |
398 | def get_normals(mesh, edge_points, side):
399 | edge_a = mesh.vs[edge_points[:, side // 2 + 2]] - mesh.vs[edge_points[:, side // 2]]
400 | edge_b = mesh.vs[edge_points[:, 1 - side // 2]] - mesh.vs[edge_points[:, side // 2]]
401 | normals = np.cross(edge_a, edge_b)
402 | div = fixed_division(np.linalg.norm(normals, ord=2, axis=1), epsilon=0.1)
403 | normals /= div[:, np.newaxis]
404 | return normals
405 |
406 | def get_opposite_angles(mesh, edge_points, side):
407 | edges_a = mesh.vs[edge_points[:, side // 2]] - mesh.vs[edge_points[:, side // 2 + 2]]
408 | edges_b = mesh.vs[edge_points[:, 1 - side // 2]] - mesh.vs[edge_points[:, side // 2 + 2]]
409 |
410 | edges_a /= fixed_division(np.linalg.norm(edges_a, ord=2, axis=1), epsilon=0.1)[:, np.newaxis]
411 | edges_b /= fixed_division(np.linalg.norm(edges_b, ord=2, axis=1), epsilon=0.1)[:, np.newaxis]
412 | dot = np.sum(edges_a * edges_b, axis=1).clip(-1, 1)
413 | return np.arccos(dot)
414 |
415 |
416 | def get_ratios(mesh, edge_points, side):
417 | edges_lengths = np.linalg.norm(mesh.vs[edge_points[:, side // 2]] - mesh.vs[edge_points[:, 1 - side // 2]],
418 | ord=2, axis=1)
419 | point_o = mesh.vs[edge_points[:, side // 2 + 2]]
420 | point_a = mesh.vs[edge_points[:, side // 2]]
421 | point_b = mesh.vs[edge_points[:, 1 - side // 2]]
422 | line_ab = point_b - point_a
423 | projection_length = np.sum(line_ab * (point_o - point_a), axis=1) / fixed_division(
424 | np.linalg.norm(line_ab, ord=2, axis=1), epsilon=0.1)
425 | closest_point = point_a + (projection_length / edges_lengths)[:, np.newaxis] * line_ab
426 | d = np.linalg.norm(point_o - closest_point, ord=2, axis=1)
427 | return d / edges_lengths
428 |
429 | def fixed_division(to_div, epsilon):
430 | if epsilon == 0:
431 | to_div[to_div == 0] = 0.1
432 | else:
433 | to_div += epsilon
434 | return to_div
435 |
--------------------------------------------------------------------------------
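Taken together, the three extractors in mesh_prepare.py give every edge a 5-channel descriptor: one dihedral angle, two sorted opposite angles, and two sorted height/base ratios. A shape-level sketch with dummy arrays, assuming an example edge count E:

    import numpy as np

    E = 750  # example value, e.g. the default --ninput_edges
    dihedral = np.zeros((1, E))          # dihedral_angle
    opposite_angles = np.zeros((2, E))   # symmetric_opposite_angles (sorted pair)
    ratios = np.zeros((2, E))            # symmetric_ratios (sorted pair)
    features = np.concatenate([dihedral, opposite_angles, ratios], axis=0)
    assert features.shape == (5, E)      # 5 input channels per edge
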
/models/layers/mesh_union.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import ConstantPad2d
3 |
4 |
5 | class MeshUnion:
6 | def __init__(self, n, device=torch.device('cpu')):
7 | self.__size = n
8 | self.rebuild_features = self.rebuild_features_average
9 | self.groups = torch.eye(n, device=device)
10 |
11 | def union(self, source, target):
12 | self.groups[target, :] += self.groups[source, :]
13 |
14 | def remove_group(self, index):
15 | return
16 |
17 | def get_group(self, edge_key):
18 | return self.groups[edge_key, :]
19 |
20 | def get_occurrences(self):
21 | return torch.sum(self.groups, 0)
22 |
23 | def get_groups(self, tensor_mask):
24 | self.groups = torch.clamp(self.groups, 0, 1)
25 | return self.groups[tensor_mask, :]
26 |
27 | def rebuild_features_average(self, features, mask, target_edges):
28 | self.prepare_groups(features, mask)
29 | fe = torch.matmul(features.squeeze(-1), self.groups)
30 | occurrences = torch.sum(self.groups, 0).expand(fe.shape)
31 | fe = fe / occurrences
32 | padding_b = target_edges - fe.shape[1]
33 | if padding_b > 0:
34 | padding_b = ConstantPad2d((0, padding_b, 0, 0), 0)
35 | fe = padding_b(fe)
36 | return fe
37 |
38 | def prepare_groups(self, features, mask):
39 | tensor_mask = torch.from_numpy(mask)
40 | self.groups = torch.clamp(self.groups[tensor_mask, :], 0, 1).transpose_(1, 0)
41 | padding_a = features.shape[1] - self.groups.shape[0]
42 | if padding_a > 0:
43 | padding_a = ConstantPad2d((0, 0, 0, padding_a), 0)
44 | self.groups = padding_a(self.groups)
45 |
--------------------------------------------------------------------------------
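A minimal usage sketch of MeshUnion on a toy 4-edge mesh; the values in the comments follow from the identity initialization of `groups` above:

    import torch
    from models.layers.mesh_union import MeshUnion

    union = MeshUnion(4)             # groups starts as the 4x4 identity
    union.union(source=0, target=1)  # record that edge 0 was collapsed into edge 1
    print(union.get_group(1))        # tensor([1., 1., 0., 0.]): edge 1's group now covers edges 0 and 1
    print(union.get_occurrences())   # tensor([2., 1., 1., 1.]): column sums, i.e. group memberships per edge
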
/models/layers/mesh_unpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | class MeshUnpool(nn.Module):
7 | def __init__(self, unroll_target):
8 | super(MeshUnpool, self).__init__()
9 | self.unroll_target = unroll_target
10 |
11 | def __call__(self, features, meshes):
12 | return self.forward(features, meshes)
13 |
14 | def pad_groups(self, group, unroll_start):
15 | start, end = group.shape
16 | padding_rows = unroll_start - start
17 | padding_cols = self.unroll_target - end
18 |         if padding_rows != 0 or padding_cols != 0:
19 | padding = nn.ConstantPad2d((0, padding_cols, 0, padding_rows), 0)
20 | group = padding(group)
21 | return group
22 |
23 | def pad_occurrences(self, occurrences):
24 | padding = self.unroll_target - occurrences.shape[0]
25 | if padding != 0:
26 | padding = nn.ConstantPad1d((0, padding), 1)
27 | occurrences = padding(occurrences)
28 | return occurrences
29 |
30 | def forward(self, features, meshes):
31 | batch_size, nf, edges = features.shape
32 | groups = [self.pad_groups(mesh.get_groups(), edges) for mesh in meshes]
33 | unroll_mat = torch.cat(groups, dim=0).view(batch_size, edges, -1)
34 | occurrences = [self.pad_occurrences(mesh.get_occurrences()) for mesh in meshes]
35 | occurrences = torch.cat(occurrences, dim=0).view(batch_size, 1, -1)
36 | occurrences = occurrences.expand(unroll_mat.shape)
37 | unroll_mat = unroll_mat / occurrences
38 | unroll_mat = unroll_mat.to(features.device)
39 | for mesh in meshes:
40 | mesh.unroll_gemm()
41 | return torch.matmul(features, unroll_mat)
42 |
--------------------------------------------------------------------------------
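At its core, MeshUnpool.forward is a batched matrix product against the recorded pooling history. A shape-only sketch with dummy tensors (batch size, channel count, and edge counts are assumed example values; the real unroll matrix comes from mesh.get_groups()):

    import torch

    features = torch.randn(1, 8, 300)             # (batch, channels, pooled edges)
    unroll_mat = torch.rand(1, 300, 600)           # (batch, pooled edges, unroll_target); dummy stand-in
    unpooled = torch.matmul(features, unroll_mat)  # (batch, channels, unroll_target)
    assert unpooled.shape == (1, 8, 600)
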
/models/mesh_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from . import networks
3 | from os.path import join
4 | from util.util import seg_accuracy, print_network
5 |
6 |
7 | class ClassifierModel:
8 | """ Class for training Model weights
9 |
10 | :args opt: structure containing configuration params
11 | e.g.,
12 |     --dataset_mode -> classification / segmentation
13 | --arch -> network type
14 | """
15 | def __init__(self, opt):
16 | self.opt = opt
17 | self.gpu_ids = opt.gpu_ids
18 | self.is_train = opt.is_train
19 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
20 | self.save_dir = join(opt.checkpoints_dir, opt.name)
21 | self.optimizer = None
22 | self.edge_features = None
23 | self.labels = None
24 | self.mesh = None
25 | self.soft_label = None
26 | self.loss = None
27 |
28 | #
29 | self.nclasses = opt.nclasses
30 |
31 | # load/define networks
32 | self.net = networks.define_classifier(opt.input_nc, opt.ncf, opt.ninput_edges, opt.nclasses, opt,
33 | self.gpu_ids, opt.arch, opt.init_type, opt.init_gain)
34 | self.net.train(self.is_train)
35 | self.criterion = networks.define_loss(opt).to(self.device)
36 |
37 | if self.is_train:
38 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
39 | self.scheduler = networks.get_scheduler(self.optimizer, opt)
40 | print_network(self.net)
41 |
42 | if not self.is_train or opt.continue_train:
43 | self.load_network(opt.which_epoch)
44 |
45 | def set_input(self, data):
46 | input_edge_features = torch.from_numpy(data['edge_features']).float()
47 | labels = torch.from_numpy(data['label']).long()
48 | # set inputs
49 | self.edge_features = input_edge_features.to(self.device).requires_grad_(self.is_train)
50 | self.labels = labels.to(self.device)
51 | self.mesh = data['mesh']
52 | if self.opt.dataset_mode == 'segmentation' and not self.is_train:
53 | self.soft_label = torch.from_numpy(data['soft_label'])
54 |
55 |
56 | def forward(self):
57 | out = self.net(self.edge_features, self.mesh)
58 | return out
59 |
60 | def backward(self, out):
61 | self.loss = self.criterion(out, self.labels)
62 | self.loss.backward()
63 |
64 | def optimize_parameters(self):
65 | self.optimizer.zero_grad()
66 | out = self.forward()
67 | self.backward(out)
68 | self.optimizer.step()
69 |
70 |
71 | ##################
72 |
73 | def load_network(self, which_epoch):
74 | """load model from disk"""
75 | save_filename = '%s_net.pth' % which_epoch
76 | load_path = join(self.save_dir, save_filename)
77 | net = self.net
78 | if isinstance(net, torch.nn.DataParallel):
79 | net = net.module
80 | print('loading the model from %s' % load_path)
81 |         # if using PyTorch newer than 0.4 (e.g., built from
82 |         # GitHub source), you can remove str() on self.device
83 | state_dict = torch.load(load_path, map_location=str(self.device))
84 | if hasattr(state_dict, '_metadata'):
85 | del state_dict._metadata
86 | net.load_state_dict(state_dict)
87 |
88 |
89 | def save_network(self, which_epoch):
90 | """save model to disk"""
91 | save_filename = '%s_net.pth' % (which_epoch)
92 | save_path = join(self.save_dir, save_filename)
93 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
94 | torch.save(self.net.module.cpu().state_dict(), save_path)
95 | self.net.cuda(self.gpu_ids[0])
96 | else:
97 | torch.save(self.net.cpu().state_dict(), save_path)
98 |
99 | def update_learning_rate(self):
100 | """update learning rate (called once every epoch)"""
101 | self.scheduler.step()
102 | lr = self.optimizer.param_groups[0]['lr']
103 | print('learning rate = %.7f' % lr)
104 |
105 | def test(self):
106 | """tests model
107 | returns: number correct and total number
108 | """
109 | with torch.no_grad():
110 | out = self.forward()
111 | # compute number of correct
112 | pred_class = out.data.max(1)[1]
113 | label_class = self.labels
114 | self.export_segmentation(pred_class.cpu())
115 | correct = self.get_accuracy(pred_class, label_class)
116 | return correct, len(label_class)
117 |
118 | def get_accuracy(self, pred, labels):
119 | """computes accuracy for classification / segmentation """
120 | if self.opt.dataset_mode == 'classification':
121 | correct = pred.eq(labels).sum()
122 | elif self.opt.dataset_mode == 'segmentation':
123 | correct = seg_accuracy(pred, self.soft_label, self.mesh)
124 | return correct
125 |
126 | def export_segmentation(self, pred_seg):
127 | if self.opt.dataset_mode == 'segmentation':
128 | for meshi, mesh in enumerate(self.mesh):
129 | mesh.export_segments(pred_seg[meshi, :])
130 |
--------------------------------------------------------------------------------
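For reference, a sketch of the batch dict that set_input consumes, with assumed example shapes (the real dict comes from the data loader; 'mesh' holds the loader's Mesh instances, and segmentation batches additionally carry 'soft_label'):

    import numpy as np

    data = {
        'edge_features': np.zeros((2, 5, 750), dtype=np.float32),  # (batch, input_nc, ninput_edges)
        'label': np.zeros(2, dtype=np.int64),                      # one class id per mesh
        'mesh': [None, None],                                      # placeholders for Mesh instances
    }
    # model.set_input(data)
    # logits = model.forward()  # (batch, nclasses)
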
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | from models.layers.mesh_conv import MeshConv
7 | import torch.nn.functional as F
8 | from models.layers.mesh_pool import MeshPool
9 | from models.layers.mesh_unpool import MeshUnpool
10 |
11 |
12 | ###############################################################################
13 | # Helper Functions
14 | ###############################################################################
15 |
16 |
17 | def get_norm_layer(norm_type='instance', num_groups=1):
18 | if norm_type == 'batch':
19 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
20 | elif norm_type == 'instance':
21 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
22 | elif norm_type == 'group':
23 | norm_layer = functools.partial(nn.GroupNorm, affine=True, num_groups=num_groups)
24 | elif norm_type == 'none':
25 | norm_layer = NoNorm
26 | else:
27 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
28 | return norm_layer
29 |
30 | def get_norm_args(norm_layer, nfeats_list):
31 | if hasattr(norm_layer, '__name__') and norm_layer.__name__ == 'NoNorm':
32 | norm_args = [{'fake': True} for f in nfeats_list]
33 | elif norm_layer.func.__name__ == 'GroupNorm':
34 | norm_args = [{'num_channels': f} for f in nfeats_list]
35 |     elif norm_layer.func.__name__ == 'BatchNorm2d':
36 | norm_args = [{'num_features': f} for f in nfeats_list]
37 | else:
38 | raise NotImplementedError('normalization layer [%s] is not found' % norm_layer.func.__name__)
39 | return norm_args
40 |
41 | class NoNorm(nn.Module):  # todo: reimplement with an abstract class and pass
42 | def __init__(self, fake=True):
43 | self.fake = fake
44 | super(NoNorm, self).__init__()
45 | def forward(self, x):
46 | return x
47 | def __call__(self, x):
48 | return self.forward(x)
49 |
50 | def get_scheduler(optimizer, opt):
51 | if opt.lr_policy == 'lambda':
52 | def lambda_rule(epoch):
53 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
54 | return lr_l
55 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
56 | elif opt.lr_policy == 'step':
57 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
58 | elif opt.lr_policy == 'plateau':
59 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
60 | else:
61 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
62 | return scheduler
63 |
64 |
65 | def init_weights(net, init_type, init_gain):
66 | def init_func(m):
67 | classname = m.__class__.__name__
68 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
69 | if init_type == 'normal':
70 | init.normal_(m.weight.data, 0.0, init_gain)
71 | elif init_type == 'xavier':
72 | init.xavier_normal_(m.weight.data, gain=init_gain)
73 | elif init_type == 'kaiming':
74 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
75 | elif init_type == 'orthogonal':
76 | init.orthogonal_(m.weight.data, gain=init_gain)
77 | else:
78 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
79 | elif classname.find('BatchNorm2d') != -1:
80 | init.normal_(m.weight.data, 1.0, init_gain)
81 | init.constant_(m.bias.data, 0.0)
82 | net.apply(init_func)
83 |
84 |
85 | def init_net(net, init_type, init_gain, gpu_ids):
86 | if len(gpu_ids) > 0:
87 | assert(torch.cuda.is_available())
88 | net.cuda(gpu_ids[0])
89 | net = net.cuda()
90 | net = torch.nn.DataParallel(net, gpu_ids)
91 | if init_type != 'none':
92 | init_weights(net, init_type, init_gain)
93 | return net
94 |
95 |
96 | def define_classifier(input_nc, ncf, ninput_edges, nclasses, opt, gpu_ids, arch, init_type, init_gain):
97 | net = None
98 | norm_layer = get_norm_layer(norm_type=opt.norm, num_groups=opt.num_groups)
99 |
100 | if arch == 'mconvnet':
101 | net = MeshConvNet(norm_layer, input_nc, ncf, nclasses, ninput_edges, opt.pool_res, opt.fc_n,
102 | opt.resblocks)
103 | elif arch == 'meshunet':
104 | down_convs = [input_nc] + ncf
105 | up_convs = ncf[::-1] + [nclasses]
106 | pool_res = [ninput_edges] + opt.pool_res
107 | net = MeshEncoderDecoder(pool_res, down_convs, up_convs, blocks=opt.resblocks,
108 | transfer_data=True)
109 | else:
110 | raise NotImplementedError('Encoder model name [%s] is not recognized' % arch)
111 | return init_net(net, init_type, init_gain, gpu_ids)
112 |
113 | def define_loss(opt):
114 | if opt.dataset_mode == 'classification':
115 | loss = torch.nn.CrossEntropyLoss()
116 | elif opt.dataset_mode == 'segmentation':
117 | loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
118 | return loss
119 |
120 | ##############################################################################
121 | # Classes For Classification / Segmentation Networks
122 | ##############################################################################
123 |
124 | class MeshConvNet(nn.Module):
125 | """Network for learning a global shape descriptor (classification)
126 | """
127 | def __init__(self, norm_layer, nf0, conv_res, nclasses, input_res, pool_res, fc_n,
128 | nresblocks=3):
129 | super(MeshConvNet, self).__init__()
130 | self.k = [nf0] + conv_res
131 | self.res = [input_res] + pool_res
132 | norm_args = get_norm_args(norm_layer, self.k[1:])
133 |
134 | for i, ki in enumerate(self.k[:-1]):
135 | setattr(self, 'conv{}'.format(i), MResConv(ki, self.k[i + 1], nresblocks))
136 | setattr(self, 'norm{}'.format(i), norm_layer(**norm_args[i]))
137 | setattr(self, 'pool{}'.format(i), MeshPool(self.res[i + 1]))
138 |
139 |
140 | self.gp = torch.nn.AvgPool1d(self.res[-1])
141 | # self.gp = torch.nn.MaxPool1d(self.res[-1])
142 | self.fc1 = nn.Linear(self.k[-1], fc_n)
143 | self.fc2 = nn.Linear(fc_n, nclasses)
144 |
145 | def forward(self, x, mesh):
146 |
147 | for i in range(len(self.k) - 1):
148 | x = getattr(self, 'conv{}'.format(i))(x, mesh)
149 | x = F.relu(getattr(self, 'norm{}'.format(i))(x))
150 | x = getattr(self, 'pool{}'.format(i))(x, mesh)
151 |
152 | x = self.gp(x)
153 | x = x.view(-1, self.k[-1])
154 |
155 | x = F.relu(self.fc1(x))
156 | x = self.fc2(x)
157 | return x
158 |
159 | class MResConv(nn.Module):
160 | def __init__(self, in_channels, out_channels, skips=1):
161 | super(MResConv, self).__init__()
162 | self.in_channels = in_channels
163 | self.out_channels = out_channels
164 | self.skips = skips
165 | self.conv0 = MeshConv(self.in_channels, self.out_channels, bias=False)
166 | for i in range(self.skips):
167 | setattr(self, 'bn{}'.format(i + 1), nn.BatchNorm2d(self.out_channels))
168 | setattr(self, 'conv{}'.format(i + 1),
169 | MeshConv(self.out_channels, self.out_channels, bias=False))
170 |
171 | def forward(self, x, mesh):
172 | x = self.conv0(x, mesh)
173 | x1 = x
174 | for i in range(self.skips):
175 | x = getattr(self, 'bn{}'.format(i + 1))(F.relu(x))
176 | x = getattr(self, 'conv{}'.format(i + 1))(x, mesh)
177 | x += x1
178 | x = F.relu(x)
179 | return x
180 |
181 |
182 | class MeshEncoderDecoder(nn.Module):
183 | """Network for fully-convolutional tasks (segmentation)
184 | """
185 | def __init__(self, pools, down_convs, up_convs, blocks=0, transfer_data=True):
186 | super(MeshEncoderDecoder, self).__init__()
187 | self.transfer_data = transfer_data
188 | self.encoder = MeshEncoder(pools, down_convs, blocks=blocks)
189 | unrolls = pools[:-1].copy()
190 | unrolls.reverse()
191 | self.decoder = MeshDecoder(unrolls, up_convs, blocks=blocks, transfer_data=transfer_data)
192 |
193 | def forward(self, x, meshes):
194 | fe, before_pool = self.encoder((x, meshes))
195 | fe = self.decoder((fe, meshes), before_pool)
196 | return fe
197 |
198 | def __call__(self, x, meshes):
199 | return self.forward(x, meshes)
200 |
201 | class DownConv(nn.Module):
202 | def __init__(self, in_channels, out_channels, blocks=0, pool=0):
203 | super(DownConv, self).__init__()
204 | self.bn = []
205 | self.pool = None
206 | self.conv1 = MeshConv(in_channels, out_channels)
207 | self.conv2 = []
208 | for _ in range(blocks):
209 | self.conv2.append(MeshConv(out_channels, out_channels))
210 | self.conv2 = nn.ModuleList(self.conv2)
211 | for _ in range(blocks + 1):
212 | self.bn.append(nn.InstanceNorm2d(out_channels))
213 | self.bn = nn.ModuleList(self.bn)
214 | if pool:
215 | self.pool = MeshPool(pool)
216 |
217 | def __call__(self, x):
218 | return self.forward(x)
219 |
220 | def forward(self, x):
221 | fe, meshes = x
222 | x1 = self.conv1(fe, meshes)
223 | if self.bn:
224 | x1 = self.bn[0](x1)
225 | x1 = F.relu(x1)
226 | x2 = x1
227 | for idx, conv in enumerate(self.conv2):
228 | x2 = conv(x1, meshes)
229 | if self.bn:
230 | x2 = self.bn[idx + 1](x2)
231 | x2 = x2 + x1
232 | x2 = F.relu(x2)
233 | x1 = x2
234 | x2 = x2.squeeze(3)
235 | before_pool = None
236 | if self.pool:
237 | before_pool = x2
238 | x2 = self.pool(x2, meshes)
239 | return x2, before_pool
240 |
241 |
242 | class UpConv(nn.Module):
243 | def __init__(self, in_channels, out_channels, blocks=0, unroll=0, residual=True,
244 | batch_norm=True, transfer_data=True):
245 | super(UpConv, self).__init__()
246 | self.residual = residual
247 | self.bn = []
248 | self.unroll = None
249 | self.transfer_data = transfer_data
250 | self.up_conv = MeshConv(in_channels, out_channels)
251 | if transfer_data:
252 | self.conv1 = MeshConv(2 * out_channels, out_channels)
253 | else:
254 | self.conv1 = MeshConv(out_channels, out_channels)
255 | self.conv2 = []
256 | for _ in range(blocks):
257 | self.conv2.append(MeshConv(out_channels, out_channels))
258 | self.conv2 = nn.ModuleList(self.conv2)
259 | if batch_norm:
260 | for _ in range(blocks + 1):
261 | self.bn.append(nn.InstanceNorm2d(out_channels))
262 | self.bn = nn.ModuleList(self.bn)
263 | if unroll:
264 | self.unroll = MeshUnpool(unroll)
265 |
266 | def __call__(self, x, from_down=None):
267 | return self.forward(x, from_down)
268 |
269 | def forward(self, x, from_down):
270 | from_up, meshes = x
271 | x1 = self.up_conv(from_up, meshes).squeeze(3)
272 | if self.unroll:
273 | x1 = self.unroll(x1, meshes)
274 | if self.transfer_data:
275 | x1 = torch.cat((x1, from_down), 1)
276 | x1 = self.conv1(x1, meshes)
277 | if self.bn:
278 | x1 = self.bn[0](x1)
279 | x1 = F.relu(x1)
280 | x2 = x1
281 | for idx, conv in enumerate(self.conv2):
282 | x2 = conv(x1, meshes)
283 | if self.bn:
284 | x2 = self.bn[idx + 1](x2)
285 | if self.residual:
286 | x2 = x2 + x1
287 | x2 = F.relu(x2)
288 | x1 = x2
289 | x2 = x2.squeeze(3)
290 | return x2
291 |
292 |
293 | class MeshEncoder(nn.Module):
294 | def __init__(self, pools, convs, fcs=None, blocks=0, global_pool=None):
295 | super(MeshEncoder, self).__init__()
296 | self.fcs = None
297 | self.convs = []
298 | for i in range(len(convs) - 1):
299 | if i + 1 < len(pools):
300 | pool = pools[i + 1]
301 | else:
302 | pool = 0
303 | self.convs.append(DownConv(convs[i], convs[i + 1], blocks=blocks, pool=pool))
304 | self.global_pool = None
305 | if fcs is not None:
306 | self.fcs = []
307 | self.fcs_bn = []
308 | last_length = convs[-1]
309 | if global_pool is not None:
310 | if global_pool == 'max':
311 | self.global_pool = nn.MaxPool1d(pools[-1])
312 | elif global_pool == 'avg':
313 | self.global_pool = nn.AvgPool1d(pools[-1])
314 | else:
315 | assert False, 'global_pool %s is not defined' % global_pool
316 | else:
317 | last_length *= pools[-1]
318 | if fcs[0] == last_length:
319 | fcs = fcs[1:]
320 | for length in fcs:
321 | self.fcs.append(nn.Linear(last_length, length))
322 | self.fcs_bn.append(nn.InstanceNorm1d(length))
323 | last_length = length
324 | self.fcs = nn.ModuleList(self.fcs)
325 | self.fcs_bn = nn.ModuleList(self.fcs_bn)
326 | self.convs = nn.ModuleList(self.convs)
327 | reset_params(self)
328 |
329 | def forward(self, x):
330 | fe, meshes = x
331 | encoder_outs = []
332 | for conv in self.convs:
333 | fe, before_pool = conv((fe, meshes))
334 | encoder_outs.append(before_pool)
335 | if self.fcs is not None:
336 | if self.global_pool is not None:
337 | fe = self.global_pool(fe)
338 | fe = fe.contiguous().view(fe.size()[0], -1)
339 | for i in range(len(self.fcs)):
340 | fe = self.fcs[i](fe)
341 | if self.fcs_bn:
342 | x = fe.unsqueeze(1)
343 | fe = self.fcs_bn[i](x).squeeze(1)
344 | if i < len(self.fcs) - 1:
345 | fe = F.relu(fe)
346 | return fe, encoder_outs
347 |
348 | def __call__(self, x):
349 | return self.forward(x)
350 |
351 |
352 | class MeshDecoder(nn.Module):
353 | def __init__(self, unrolls, convs, blocks=0, batch_norm=True, transfer_data=True):
354 | super(MeshDecoder, self).__init__()
355 | self.up_convs = []
356 | for i in range(len(convs) - 2):
357 | if i < len(unrolls):
358 | unroll = unrolls[i]
359 | else:
360 | unroll = 0
361 | self.up_convs.append(UpConv(convs[i], convs[i + 1], blocks=blocks, unroll=unroll,
362 | batch_norm=batch_norm, transfer_data=transfer_data))
363 | self.final_conv = UpConv(convs[-2], convs[-1], blocks=blocks, unroll=False,
364 | batch_norm=batch_norm, transfer_data=False)
365 | self.up_convs = nn.ModuleList(self.up_convs)
366 | reset_params(self)
367 |
368 | def forward(self, x, encoder_outs=None):
369 | fe, meshes = x
370 | for i, up_conv in enumerate(self.up_convs):
371 | before_pool = None
372 | if encoder_outs is not None:
373 | before_pool = encoder_outs[-(i+2)]
374 | fe = up_conv((fe, meshes), before_pool)
375 | fe = self.final_conv((fe, meshes))
376 | return fe
377 |
378 | def __call__(self, x, encoder_outs=None):
379 | return self.forward(x, encoder_outs)
380 |
381 | def reset_params(model): # todo replace with my init
382 | for i, m in enumerate(model.modules()):
383 | weight_init(m)
384 |
385 | def weight_init(m):
386 | if isinstance(m, nn.Conv2d):
387 | nn.init.xavier_normal_(m.weight)
388 |         if m.bias is not None: nn.init.constant_(m.bias, 0)  # guard: some MeshConv layers use bias=False
389 |
--------------------------------------------------------------------------------
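A hypothetical sketch of building the classification network directly through define_classifier, assuming CPU-only execution (empty gpu_ids) and example hyperparameters in the spirit of the shrec scripts:

    from types import SimpleNamespace
    from models import networks

    opt = SimpleNamespace(norm='group', num_groups=16, pool_res=[600, 450, 300, 180],
                          fc_n=100, resblocks=1)
    net = networks.define_classifier(input_nc=5, ncf=[64, 128, 256, 256],
                                     ninput_edges=750, nclasses=30, opt=opt,
                                     gpu_ids=[], arch='mconvnet',
                                     init_type='normal', init_gain=0.02)
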
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 |
7 | class BaseOptions:
8 | def __init__(self):
9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 | self.initialized = False
11 |
12 | def initialize(self):
13 | # data params
14 | self.parser.add_argument('--dataroot', required=True, help='path to meshes (should have subfolders train, test)')
15 | self.parser.add_argument('--dataset_mode', choices={"classification", "segmentation"}, default='classification')
16 | self.parser.add_argument('--ninput_edges', type=int, default=750, help='# of input edges (will include dummy edges)')
17 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples per epoch')
18 | # network params
19 | self.parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
20 | self.parser.add_argument('--arch', type=str, default='mconvnet', help='selects network to use') #todo add choices
21 | self.parser.add_argument('--resblocks', type=int, default=0, help='# of res blocks')
22 | self.parser.add_argument('--fc_n', type=int, default=100, help='# between fc and nclasses') #todo make generic
23 | self.parser.add_argument('--ncf', nargs='+', default=[16, 32, 32], type=int, help='conv filters')
24 | self.parser.add_argument('--pool_res', nargs='+', default=[1140, 780, 580], type=int, help='pooling res')
25 |         self.parser.add_argument('--norm', type=str, default='batch', help='normalization type: batch | instance | group | none')
26 | self.parser.add_argument('--num_groups', type=int, default=16, help='# of groups for groupnorm')
27 | self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
28 | self.parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
29 | # general params
30 | self.parser.add_argument('--num_threads', default=3, type=int, help='# threads for loading data')
31 |         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
32 | self.parser.add_argument('--name', type=str, default='debug', help='name of the experiment. It decides where to store samples and models')
33 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
34 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes meshes in order, otherwise takes them randomly')
35 | self.parser.add_argument('--seed', type=int, help='if specified, uses seed')
36 | # visualization params
37 | self.parser.add_argument('--export_folder', type=str, default='', help='exports intermediate collapses to this folder')
38 | #
39 | self.initialized = True
40 |
41 | def parse(self):
42 | if not self.initialized:
43 | self.initialize()
44 | self.opt, unknown = self.parser.parse_known_args()
45 | self.opt.is_train = self.is_train # train or test
46 |
47 | str_ids = self.opt.gpu_ids.split(',')
48 | self.opt.gpu_ids = []
49 | for str_id in str_ids:
50 | id = int(str_id)
51 | if id >= 0:
52 | self.opt.gpu_ids.append(id)
53 | # set gpu ids
54 | if len(self.opt.gpu_ids) > 0:
55 | torch.cuda.set_device(self.opt.gpu_ids[0])
56 |
57 | args = vars(self.opt)
58 |
59 | if self.opt.seed is not None:
60 | import numpy as np
61 | import random
62 | torch.manual_seed(self.opt.seed)
63 | np.random.seed(self.opt.seed)
64 | random.seed(self.opt.seed)
65 |
66 | if self.opt.export_folder:
67 | self.opt.export_folder = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.export_folder)
68 | util.mkdir(self.opt.export_folder)
69 |
70 | if self.is_train:
71 | print('------------ Options -------------')
72 | for k, v in sorted(args.items()):
73 | print('%s: %s' % (str(k), str(v)))
74 | print('-------------- End ----------------')
75 |
76 | # save to the disk
77 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
78 | util.mkdir(expr_dir)
79 |
80 | file_name = os.path.join(expr_dir, 'opt.txt')
81 | with open(file_name, 'wt') as opt_file:
82 | opt_file.write('------------ Options -------------\n')
83 | for k, v in sorted(args.items()):
84 | opt_file.write('%s: %s\n' % (str(k), str(v)))
85 | opt_file.write('-------------- End ----------------\n')
86 | return self.opt
87 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
8 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') #todo delete.
9 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
10 | self.parser.add_argument('--num_aug', type=int, default=1, help='# of augmentation files')
11 | self.is_train = False
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TrainOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 | self.parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console')
7 | self.parser.add_argument('--save_latest_freq', type=int, default=250, help='frequency of saving the latest results')
8 | self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
9 | self.parser.add_argument('--run_test_freq', type=int, default=1, help='frequency of running test in training script')
10 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
11 |         self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
12 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
13 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
14 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
15 | self.parser.add_argument('--niter_decay', type=int, default=500, help='# of iter to linearly decay learning rate to zero')
16 | self.parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
17 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
18 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
19 | self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
20 | # data augmentation stuff
21 | self.parser.add_argument('--num_aug', type=int, default=10, help='# of augmentation files')
22 | self.parser.add_argument('--scale_verts', action='store_true', help='non-uniformly scale the mesh e.g., in x, y or z')
23 | self.parser.add_argument('--slide_verts', type=float, default=0, help='percent vertices which will be shifted along the mesh surface')
24 | self.parser.add_argument('--flip_edges', type=float, default=0, help='percent of edges to randomly flip')
25 | # tensorboard visualization
26 | self.parser.add_argument('--no_vis', action='store_true', help='will not use tensorboard')
27 | self.parser.add_argument('--verbose_plot', action='store_true', help='plots network weights, etc.')
28 | self.is_train = True
29 |
--------------------------------------------------------------------------------
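A minimal sketch of parsing these options programmatically instead of via a shell script (argument values are examples; parse() resolves gpu_ids and, for training, prints the options and writes opt.txt under checkpoints/<name>):

    import sys
    from options.train_options import TrainOptions

    sys.argv = ['train.py', '--dataroot', 'datasets/shrec_16',
                '--name', 'shrec16', '--gpu_ids', '-1']  # -1 selects CPU
    opt = TrainOptions().parse()
    print(opt.lr, opt.batch_size)  # defaults: 0.0002 16
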
/scripts/coseg_seg/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATADIR='datasets' #location where data gets downloaded to
4 |
5 | echo "downloading the data and putting it in: " $DATADIR
6 | mkdir -p $DATADIR && cd $DATADIR
7 | wget https://www.dropbox.com/s/34vy4o5fthhz77d/coseg.tar.gz
8 | tar -xzvf coseg.tar.gz && rm coseg.tar.gz
--------------------------------------------------------------------------------
/scripts/coseg_seg/get_pretrained.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CHECKPOINT=checkpoints/coseg_aliens
4 | mkdir -p $CHECKPOINT
5 |
6 | #gets the pretrained weights
7 | wget https://www.dropbox.com/s/er7my13k9dwg9ii/coseg_aliens_wts.tar.gz
8 | tar -xzvf coseg_aliens_wts.tar.gz && rm coseg_aliens_wts.tar.gz
9 | mv latest_net.pth $CHECKPOINT
10 | echo "downloaded pretrained weights to" $CHECKPOINT
--------------------------------------------------------------------------------
/scripts/coseg_seg/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the test and export collapses
4 | python test.py \
5 | --dataroot datasets/coseg_aliens \
6 | --name coseg_aliens \
7 | --arch meshunet \
8 | --dataset_mode segmentation \
9 | --ncf 32 64 128 256 \
10 | --ninput_edges 2280 \
11 | --pool_res 1800 1350 600 \
12 | --resblocks 3 \
13 | --batch_size 12 \
14 | --export_folder meshes \
--------------------------------------------------------------------------------
/scripts/coseg_seg/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the training
4 | python train.py \
5 | --dataroot datasets/coseg_aliens \
6 | --name coseg_aliens \
7 | --arch meshunet \
8 | --dataset_mode segmentation \
9 | --ncf 32 64 128 256 \
10 | --ninput_edges 2280 \
11 | --pool_res 1800 1350 600 \
12 | --resblocks 3 \
13 | --lr 0.001 \
14 | --batch_size 12 \
15 | --num_aug 20 \
16 | --slide_verts 0.2 \
17 |
18 |
19 | #
20 | # python train.py --dataroot datasets/coseg_vases --name coseg_vases --arch meshunet --dataset_mode segmentation \
21 | #   --ncf 32 64 128 256 --ninput_edges 1500 --pool_res 1050 600 300 --resblocks 3 --lr 0.001 --batch_size 12 --num_aug 20
--------------------------------------------------------------------------------
/scripts/coseg_seg/view.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python util/mesh_viewer.py \
4 | --files \
5 | checkpoints/coseg_aliens/meshes/142_0.obj \
6 | checkpoints/coseg_aliens/meshes/142_2.obj \
7 | checkpoints/coseg_aliens/meshes/142_3.obj \
--------------------------------------------------------------------------------
/scripts/cubes/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATADIR='datasets' #location where data gets downloaded to
4 |
5 | # get data
6 | mkdir -p $DATADIR && cd $DATADIR
7 | wget https://www.dropbox.com/s/2bxs5f9g60wa0wr/cubes.tar.gz
8 | tar -xzvf cubes.tar.gz && rm cubes.tar.gz
9 | echo "downloaded the data and put it in: " $DATADIR
--------------------------------------------------------------------------------
/scripts/cubes/get_pretrained.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CHECKPOINT='checkpoints/cubes'
4 |
5 | # get pretrained model
6 | mkdir -p $CHECKPOINT
7 | wget https://www.dropbox.com/s/fg7wum39bmlxr7w/cubes_wts.tar.gz
8 | tar -xzvf cubes_wts.tar.gz && rm cubes_wts.tar.gz
9 | mv latest_net.pth $CHECKPOINT
10 | echo "downloaded pretrained weights to" $CHECKPOINT
--------------------------------------------------------------------------------
/scripts/cubes/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the test and export collapses
4 | python test.py \
5 | --dataroot datasets/cubes \
6 | --name cubes \
7 | --ncf 64 128 256 256 \
8 | --pool_res 600 450 300 210 \
9 | --norm group \
10 | --resblocks 1 \
11 | --export_folder meshes \
--------------------------------------------------------------------------------
/scripts/cubes/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the training
4 | python train.py \
5 | --dataroot datasets/cubes \
6 | --name cubes \
7 | --ncf 64 128 256 256 \
8 | --pool_res 600 450 300 210 \
9 | --norm group \
10 | --resblocks 1 \
11 | --flip_edges 0.2 \
12 | --slide_verts 0.2 \
13 | --num_aug 20 \
--------------------------------------------------------------------------------
/scripts/cubes/view.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python util/mesh_viewer.py \
4 | --files checkpoints/cubes/meshes/horseshoe_4_0.obj \
5 | checkpoints/cubes/meshes/horseshoe_4_2.obj \
6 | checkpoints/cubes/meshes/horseshoe_4_3.obj \
7 | checkpoints/cubes/meshes/horseshoe_4_4.obj
--------------------------------------------------------------------------------
/scripts/dataprep/blender_process.py:
--------------------------------------------------------------------------------
1 | import bpy
2 | import os
3 | import sys
4 |
5 |
6 | '''
7 | Simplifies mesh to target number of faces
8 | Requires Blender 2.8
9 | Author: Rana Hanocka
10 |
11 | @input:
12 |     <obj_file>
13 |     <target_faces> number of target faces
14 |     <export_name> name of simplified .obj file
15 |
16 | @output:
17 | simplified mesh .obj
18 | to run it from cmd line:
19 | /opt/blender/blender --background --python blender_process.py /home/rana/koala.obj 1000 /home/rana/koala_1000.obj
20 | '''
21 |
22 | class Process:
23 | def __init__(self, obj_file, target_faces, export_name):
24 | mesh = self.load_obj(obj_file)
25 | self.simplify(mesh, target_faces)
26 | self.export_obj(mesh, export_name)
27 |
28 | def load_obj(self, obj_file):
29 | bpy.ops.import_scene.obj(filepath=obj_file, axis_forward='-Z', axis_up='Y', filter_glob="*.obj;*.mtl", use_edges=True,
30 | use_smooth_groups=True, use_split_objects=False, use_split_groups=False,
31 | use_groups_as_vgroups=False, use_image_search=True, split_mode='ON')
32 | ob = bpy.context.selected_objects[0]
33 | return ob
34 |
35 | def subsurf(self, mesh):
36 | # subdivide mesh
37 | bpy.context.view_layer.objects.active = mesh
38 | mod = mesh.modifiers.new(name='Subsurf', type='SUBSURF')
39 | mod.subdivision_type = 'SIMPLE'
40 | bpy.ops.object.modifier_apply(modifier=mod.name)
41 | # now triangulate
42 |         mod = mesh.modifiers.new(name='Triangulate', type='TRIANGULATE')
43 | bpy.ops.object.modifier_apply(modifier=mod.name)
44 |
45 | def simplify(self, mesh, target_faces):
46 | bpy.context.view_layer.objects.active = mesh
47 | mod = mesh.modifiers.new(name='Decimate', type='DECIMATE')
48 | bpy.context.object.modifiers['Decimate'].use_collapse_triangulate = True
49 | #
50 | nfaces = len(mesh.data.polygons)
51 | if nfaces < target_faces:
52 | self.subsurf(mesh)
53 | nfaces = len(mesh.data.polygons)
54 | ratio = target_faces / float(nfaces)
55 |         mod.ratio = float('%.6g' % ratio)  # round to 6 significant digits
56 | print('faces: ', mod.face_count, mod.ratio)
57 | bpy.ops.object.modifier_apply(modifier=mod.name)
58 |
59 |
60 | def export_obj(self, mesh, export_name):
61 | outpath = os.path.dirname(export_name)
62 | if not os.path.isdir(outpath): os.makedirs(outpath)
63 | print('EXPORTING', export_name)
64 | bpy.ops.object.select_all(action='DESELECT')
65 | mesh.select_set(state=True)
66 | bpy.ops.export_scene.obj(filepath=export_name, check_existing=False, filter_glob="*.obj;*.mtl",
67 | use_selection=True, use_animation=False, use_mesh_modifiers=True, use_edges=True,
68 | use_smooth_groups=False, use_smooth_groups_bitflags=False, use_normals=True,
69 | use_uvs=False, use_materials=False, use_triangles=True, use_nurbs=False,
70 | use_vertex_groups=False, use_blen_objects=True, group_by_object=False,
71 | group_by_material=False, keep_vertex_order=True, global_scale=1, path_mode='AUTO',
72 | axis_forward='-Z', axis_up='Y')
73 |
74 | obj_file = sys.argv[-3]
75 | target_faces = int(sys.argv[-2])
76 | export_name = sys.argv[-1]
77 |
78 |
79 | print('args: ', obj_file, target_faces, export_name)
80 | blender = Process(obj_file, target_faces, export_name)
81 |
--------------------------------------------------------------------------------
/scripts/human_seg/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATADIR='datasets' #location where data gets downloaded to
4 |
5 | # get data
6 | echo "downloading the data and putting it in: " $DATADIR
7 | mkdir -p $DATADIR && cd $DATADIR
8 | wget https://www.dropbox.com/s/s3n05sw0zg27fz3/human_seg.tar.gz
9 | tar -xzvf human_seg.tar.gz && rm human_seg.tar.gz
--------------------------------------------------------------------------------
/scripts/human_seg/get_pretrained.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CHECKPOINT='checkpoints/human_seg'
4 | mkdir -p $CHECKPOINT
5 |
6 | wget https://www.dropbox.com/s/8i26y7cpi6st2ra/human_seg_wts.tar.gz
7 | tar -xzvf human_seg_wts.tar.gz && rm human_seg_wts.tar.gz
8 | mv latest_net.pth $CHECKPOINT
9 | echo "downloaded pretrained weights to" $CHECKPOINT
10 |
--------------------------------------------------------------------------------
/scripts/human_seg/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the test and export collapses
4 | python test.py \
5 | --dataroot datasets/human_seg \
6 | --name human_seg \
7 | --arch meshunet \
8 | --dataset_mode segmentation \
9 | --ncf 32 64 128 256 \
10 | --ninput_edges 2280 \
11 | --pool_res 1800 1350 600 \
12 | --resblocks 3 \
13 | --batch_size 12 \
14 | --export_folder meshes \
--------------------------------------------------------------------------------
/scripts/human_seg/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the training
4 | python train.py \
5 | --dataroot datasets/human_seg \
6 | --name human_seg \
7 | --arch meshunet \
8 | --dataset_mode segmentation \
9 | --ncf 32 64 128 256 \
10 | --ninput_edges 2280 \
11 | --pool_res 1800 1350 600 \
12 | --resblocks 3 \
13 | --batch_size 12 \
14 | --lr 0.001 \
15 | --num_aug 20 \
16 | --slide_verts 0.2 \
--------------------------------------------------------------------------------
/scripts/human_seg/view.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python util/mesh_viewer.py \
4 | --files \
5 | checkpoints/human_seg/meshes/shrec__14_0.obj
--------------------------------------------------------------------------------
/scripts/shrec/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATADIR='datasets' #location where data gets downloaded to
4 |
5 | # get data
6 | mkdir -p $DATADIR && cd $DATADIR
7 | wget https://www.dropbox.com/s/w16st84r6wc57u7/shrec_16.tar.gz
8 | tar -xzvf shrec_16.tar.gz && rm shrec_16.tar.gz
9 | echo "downloaded the data and put it in: " $DATADIR
10 |
--------------------------------------------------------------------------------
/scripts/shrec/get_pretrained.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CHECKPOINT='checkpoints/shrec16'
4 |
5 | mkdir -p $CHECKPOINT
6 | wget https://www.dropbox.com/s/wqq1qxj4fjbpfas/shrec16_wts.tar.gz
7 | tar -xzvf shrec16_wts.tar.gz && rm shrec16_wts.tar.gz
8 | mv latest_net.pth $CHECKPOINT
9 | echo "downloaded pretrained weights to" $CHECKPOINT
--------------------------------------------------------------------------------
/scripts/shrec/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the test and export collapses
4 | python test.py \
5 | --dataroot datasets/shrec_16 \
6 | --name shrec16 \
7 | --ncf 64 128 256 256 \
8 | --pool_res 600 450 300 180 \
9 | --norm group \
10 | --resblocks 1 \
11 | --export_folder meshes \
--------------------------------------------------------------------------------
/scripts/shrec/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the training
4 | python train.py \
5 | --dataroot datasets/shrec_16 \
6 | --name shrec16 \
7 | --ncf 64 128 256 256 \
8 | --pool_res 600 450 300 180 \
9 | --norm group \
10 | --resblocks 1 \
11 | --flip_edges 0.2 \
12 | --slide_verts 0.2 \
13 | --num_aug 20 \
14 | --niter_decay 100 \
--------------------------------------------------------------------------------
/scripts/shrec/view.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python util/mesh_viewer.py \
4 | --files \
5 | checkpoints/shrec16/meshes/T74_0.obj \
6 | checkpoints/shrec16/meshes/T74_3.obj \
7 | checkpoints/shrec16/meshes/T74_4.obj
--------------------------------------------------------------------------------
/scripts/test_general.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 | import shutil
4 | import glob
5 | import subprocess
6 | '''
7 | scripts for unit testing
8 | '''
9 |
10 |
11 | def get_data(dset):
12 | dpaths = glob.glob('./datasets/{}*'.format(dset))
13 | [shutil.rmtree(d) for d in dpaths]
14 | cmd = './scripts/{}/get_data.sh > /dev/null 2>&1'.format(dset)
15 | os.system(cmd)
16 |
17 | def add_args(file, temp_file, new_args):
18 | with open(file) as f:
19 | tokens = f.readlines()
20 |     # append the new args to the copied script (run_train uses this to train for a single iteration)
21 | tokens[-1] = tokens[-1] + '\n'
22 | for arg in new_args:
23 | tokens.append(arg)
24 | with open(temp_file, 'w') as f:
25 | f.writelines(tokens)
26 |
27 | def run_train(dset):
28 | train_file = './scripts/{}/train.sh'.format(dset)
29 | temp_train_file = './scripts/{}/train_temp.sh'.format(dset)
30 | p = subprocess.run(['cp', '-p', '--preserve', train_file, temp_train_file])
31 | add_args(train_file, temp_train_file, ['--niter_decay 0 \\\n', '--niter 1 \\\n', '--max_dataset_size 2 \\\n', '--gpu_ids -1 \\'])
32 | cmd = "bash -c 'source ~/anaconda3/bin/activate ~/anaconda3/envs/meshcnn && {} > /dev/null 2>&1'".format(temp_train_file)
33 | os.system(cmd)
34 | os.remove(temp_train_file)
35 |
36 | def get_pretrained(dset):
37 | cmd = './scripts/{}/get_pretrained.sh > /dev/null 2>&1'.format(dset)
38 | os.system(cmd)
39 |
40 | def run_test(dset):
41 | test_file = './scripts/{}/test.sh'.format(dset)
42 | temp_test_file = './scripts/{}/test_temp.sh'.format(dset)
43 | p = subprocess.run(['cp', '-p', '--preserve', test_file, temp_test_file])
44 | add_args(test_file, temp_test_file, ['--gpu_ids -1 \\'])
45 | # now run inference
46 | cmd = "bash -c 'source ~/anaconda3/bin/activate ~/anaconda3/envs/meshcnn && {}'".format(temp_test_file)
47 | proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True)
48 | (_out, err) = proc.communicate()
49 | out = str(_out)
50 | idf0 = 'TEST ACC: ['
51 | token = out[out.find(idf0) + len(idf0):]
52 | idf1 = '%]'
53 | accs = token[:token.find(idf1)]
54 | acc = float(accs)
55 | if dset == 'shrec':
56 | assert acc == 99.167, "shrec accuracy was {} and not 99.167".format(acc)
57 | if dset == 'human_seg':
58 | assert acc == 92.554, "human_seg accuracy was {} and not 92.554".format(acc)
59 | os.remove(temp_test_file)
60 |
61 | def run_dataset(dset):
62 | get_data(dset)
63 | run_train(dset)
64 | get_pretrained(dset)
65 | run_test(dset)
66 |
67 | def test_shrec():
68 | run_dataset('shrec')
69 |
70 | def test_human_seg():
71 | run_dataset('human_seg')
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from options.test_options import TestOptions
2 | from data import DataLoader
3 | from models import create_model
4 | from util.writer import Writer
5 |
6 |
7 | def run_test(epoch=-1):
8 | print('Running Test')
9 | opt = TestOptions().parse()
10 | opt.serial_batches = True # no shuffle
11 | dataset = DataLoader(opt)
12 | model = create_model(opt)
13 | writer = Writer(opt)
14 | # test
15 | writer.reset_counter()
16 | for i, data in enumerate(dataset):
17 | model.set_input(data)
18 | ncorrect, nexamples = model.test()
19 | writer.update_counter(ncorrect, nexamples)
20 | writer.print_acc(epoch, writer.acc)
21 | return writer.acc
22 |
23 |
24 | if __name__ == '__main__':
25 | run_test()
26 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data import DataLoader
4 | from models import create_model
5 | from util.writer import Writer
6 | from test import run_test
7 |
8 | if __name__ == '__main__':
9 | opt = TrainOptions().parse()
10 | dataset = DataLoader(opt)
11 | dataset_size = len(dataset)
12 | print('#training meshes = %d' % dataset_size)
13 |
14 | model = create_model(opt)
15 | writer = Writer(opt)
16 | total_steps = 0
17 |
18 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
19 | epoch_start_time = time.time()
20 | iter_data_time = time.time()
21 | epoch_iter = 0
22 |
23 | for i, data in enumerate(dataset):
24 | iter_start_time = time.time()
25 |             if total_steps % opt.print_freq == 0:
26 |                 t_data = iter_start_time - iter_data_time  # time spent loading this batch
27 | total_steps += opt.batch_size
28 | epoch_iter += opt.batch_size
29 | model.set_input(data)
30 | model.optimize_parameters()
31 |
32 | if total_steps % opt.print_freq == 0:
33 | loss = model.loss
34 | t = (time.time() - iter_start_time) / opt.batch_size
35 | writer.print_current_losses(epoch, epoch_iter, loss, t, t_data)
36 | writer.plot_loss(loss, epoch, epoch_iter, dataset_size)
37 |
38 | if i % opt.save_latest_freq == 0:
39 | print('saving the latest model (epoch %d, total_steps %d)' %
40 | (epoch, total_steps))
41 | model.save_network('latest')
42 |
43 | iter_data_time = time.time()
44 | if epoch % opt.save_epoch_freq == 0:
45 | print('saving the model at the end of epoch %d, iters %d' %
46 | (epoch, total_steps))
47 | model.save_network('latest')
48 | model.save_network(epoch)
49 |
50 | print('End of epoch %d / %d \t Time Taken: %d sec' %
51 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
52 | model.update_learning_rate()
53 | if opt.verbose_plot:
54 | writer.plot_model_wts(model, epoch)
55 |
56 | if epoch % opt.run_test_freq == 0:
57 | acc = run_test(epoch)
58 | writer.plot_acc(acc, epoch)
59 |
60 | writer.close()
61 |
--------------------------------------------------------------------------------
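The loop above runs for opt.niter + opt.niter_decay epochs and calls model.update_learning_rate() once per epoch. In the pix2pix-style codebases this training loop follows, those two options usually mean: hold the learning rate constant for the first niter epochs, then decay it linearly to zero over the remaining niter_decay epochs. A sketch of that lambda rule, assuming the scheduler in models/networks.py (not shown in this section) follows the same convention:

import torch

def make_lambda_scheduler(optimizer, niter, niter_decay, epoch_count=1):
    # constant LR for the first `niter` epochs, then linear decay to zero
    # over the next `niter_decay` epochs
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)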
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ranahanocka/MeshCNN/5bf0b899d48eb204b9b73bc1af381be20f4d7df1/util/__init__.py
--------------------------------------------------------------------------------
/util/mesh_viewer.py:
--------------------------------------------------------------------------------
1 | import mpl_toolkits.mplot3d as a3
2 | import matplotlib.colors as colors
3 | import pylab as pl
4 | import numpy as np
5 |
6 | V = np.array
7 | r2h = lambda x: colors.rgb2hex(tuple(map(lambda y: y / 255., x)))
8 | surface_color = r2h((255, 230, 205))
9 | edge_color = r2h((90, 90, 90))
10 | edge_colors = (r2h((15, 167, 175)), r2h((230, 81, 81)), r2h((142, 105, 252)), r2h((248, 235, 57)),
11 | r2h((51, 159, 255)), r2h((225, 117, 231)), r2h((97, 243, 185)), r2h((161, 183, 196)))
12 |
13 |
14 |
15 |
16 | def init_plot():
17 | ax = pl.figure().add_subplot(111, projection='3d')
18 |     # hide the axes; thanks to
19 | # https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/
20 | ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
21 | ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
22 | ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
23 | # Get rid of the spines
24 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
25 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
26 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
27 | # Get rid of the ticks
28 | ax.set_xticks([])
29 | ax.set_yticks([])
30 | ax.set_zticks([])
31 | return (ax, [np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf])
32 |
33 |
34 | def update_lim(mesh, plot):
35 | vs = mesh[0]
36 | for i in range(3):
37 | plot[1][2 * i] = min(plot[1][2 * i], vs[:, i].min())
38 | plot[1][2 * i + 1] = max(plot[1][2 * i], vs[:, i].max())
39 | return plot
40 |
41 |
42 | def update_plot(mesh, plot):
43 | if plot is None:
44 | plot = init_plot()
45 | return update_lim(mesh, plot)
46 |
47 |
48 | def surfaces(mesh, plot):
49 | vs, faces, edges = mesh
50 | vtx = vs[faces]
51 | edgecolor = edge_color if not len(edges) else 'none'
52 |     tri = a3.art3d.Poly3DCollection(vtx, facecolors=surface_color + '55', edgecolors=edgecolor,
53 | linewidths=.5, linestyles='dashdot')
54 | plot[0].add_collection3d(tri)
55 | return plot
56 |
57 |
58 | def segments(mesh, plot):
59 | vs, _, edges = mesh
60 | for edge_c, edge_group in enumerate(edges):
61 | for edge_idx in edge_group:
62 | edge = vs[edge_idx]
63 | line = a3.art3d.Line3DCollection([edge], linewidths=.5, linestyles='dashdot')
64 | line.set_color(edge_colors[edge_c % len(edge_colors)])
65 | plot[0].add_collection3d(line)
66 | return plot
67 |
68 |
69 | def plot_mesh(mesh, *whats, show=True, plot=None):
70 | for what in [update_plot] + list(whats):
71 | plot = what(mesh, plot)
72 | if show:
73 | li = max(plot[1][1], plot[1][3], plot[1][5])
74 | plot[0].auto_scale_xyz([0, li], [0, li], [0, li])
75 | pl.tight_layout()
76 | pl.show()
77 | return plot
78 |
79 |
80 | def parse_obje(obj_file, scale_by):
81 | vs = []
82 | faces = []
83 | edges = []
84 |
85 | def add_to_edges():
86 | if edge_c >= len(edges):
87 | for _ in range(len(edges), edge_c + 1):
88 | edges.append([])
89 | edges[edge_c].append(edge_v)
90 |
91 | def fix_vertices():
92 | nonlocal vs, scale_by
93 | vs = V(vs)
94 | z = vs[:, 2].copy()
95 | vs[:, 2] = vs[:, 1]
96 | vs[:, 1] = z
97 | max_range = 0
98 | for i in range(3):
99 | min_value = np.min(vs[:, i])
100 | max_value = np.max(vs[:, i])
101 | max_range = max(max_range, max_value - min_value)
102 | vs[:, i] -= min_value
103 | if not scale_by:
104 | scale_by = max_range
105 | vs /= scale_by
106 |
107 | with open(obj_file) as f:
108 | for line in f:
109 | line = line.strip()
110 | splitted_line = line.split()
111 | if not splitted_line:
112 | continue
113 | elif splitted_line[0] == 'v':
114 | vs.append([float(v) for v in splitted_line[1:]])
115 | elif splitted_line[0] == 'f':
116 | faces.append([int(c) - 1 for c in splitted_line[1:]])
117 | elif splitted_line[0] == 'e':
118 | if len(splitted_line) >= 4:
119 | edge_v = [int(c) - 1 for c in splitted_line[1:-1]]
120 | edge_c = int(splitted_line[-1])
121 | add_to_edges()
122 |
123 | vs = V(vs)
124 | fix_vertices()
125 | faces = V(faces, dtype=int)
126 | edges = [V(c, dtype=int) for c in edges]
127 | return (vs, faces, edges), scale_by
128 |
129 |
130 | def view_meshes(*files, offset=.2):
131 | plot = None
132 | max_x = 0
133 | scale = 0
134 | for file in files:
135 | mesh, scale = parse_obje(file, scale)
136 | max_x_current = mesh[0][:, 0].max()
137 | mesh[0][:, 0] += max_x + offset
138 | plot = plot_mesh(mesh, surfaces, segments, plot=plot, show=file == files[-1])
139 | max_x += max_x_current + offset
140 |
141 |
142 | if __name__ == '__main__':
143 | import argparse
144 | parser = argparse.ArgumentParser("view meshes")
145 | parser.add_argument('--files', nargs='+', default=['checkpoints/human_seg/meshes/shrec__14_0.obj',
146 | 'checkpoints/human_seg/meshes/shrec__14_3.obj'], type=str,
147 | help="list of 1 or more .obj files")
148 | args = parser.parse_args()
149 |
150 | # view meshes
151 | view_meshes(*args.files)
152 |
153 |
--------------------------------------------------------------------------------
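parse_obje above reads a lightly extended Wavefront format: standard v (vertex) and f (face) records, plus e records in which the leading 1-indexed integers name the edge's vertices and the final integer selects the color group drawn by segments(). A tiny hand-made example of such a file (contents illustrative):

v 0.0 0.0 0.0
v 1.0 0.0 0.0
v 0.0 1.0 0.0
f 1 2 3
e 1 2 0
e 2 3 1
e 3 1 0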
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | import os
5 |
6 |
7 | def mkdir(path):
8 | if not os.path.exists(path):
9 | os.makedirs(path)
10 |
11 | MESH_EXTENSIONS = [
12 | '.obj',
13 | ]
14 |
15 | def is_mesh_file(filename):
16 | return any(filename.endswith(extension) for extension in MESH_EXTENSIONS)
17 |
18 | def pad(input_arr, target_length, val=0, dim=1):
19 | shp = input_arr.shape
20 | npad = [(0, 0) for _ in range(len(shp))]
21 | npad[dim] = (0, target_length - shp[dim])
22 | return np.pad(input_arr, pad_width=npad, mode='constant', constant_values=val)
23 |
24 | def seg_accuracy(predicted, ssegs, meshes):
25 | correct = 0
26 | ssegs = ssegs.squeeze(-1)
27 | correct_mat = ssegs.gather(2, predicted.cpu().unsqueeze(dim=2))
28 | for mesh_id, mesh in enumerate(meshes):
29 | correct_vec = correct_mat[mesh_id, :mesh.edges_count, 0]
30 | edge_areas = torch.from_numpy(mesh.get_edge_areas())
31 | correct += (correct_vec.float() * edge_areas).sum()
32 | return correct
33 |
34 | def print_network(net):
35 | """Print the total number of parameters in the network
36 | Parameters:
37 | network
38 | """
39 | print('---------- Network initialized -------------')
40 | num_params = 0
41 | for param in net.parameters():
42 | num_params += param.numel()
43 | print('[Network] Total number of parameters : %.3f M' % (num_params / 1e6))
44 | print('-----------------------------------------------')
45 |
46 | def get_heatmap_color(value, minimum=0, maximum=1):
47 | minimum, maximum = float(minimum), float(maximum)
48 | ratio = 2 * (value-minimum) / (maximum - minimum)
49 | b = int(max(0, 255*(1 - ratio)))
50 | r = int(max(0, 255*(ratio - 1)))
51 | g = 255 - b - r
52 | return r, g, b
53 |
54 |
55 | def normalize_np_array(np_array):
56 | min_value = np.min(np_array)
57 | max_value = np.max(np_array)
58 | return (np_array - min_value) / (max_value - min_value)
59 |
60 |
61 | def calculate_entropy(np_array):
62 | entropy = 0
63 |     np_array = np_array / np.sum(np_array)  # normalize a copy; avoid mutating the caller's array
64 | for a in np_array:
65 | if a != 0:
66 | entropy -= a * np.log(a)
67 | entropy /= np.log(np_array.shape[0])
68 | return entropy
69 |
--------------------------------------------------------------------------------
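Two of these helpers are easiest to read from a worked example: pad grows one axis of an array up to target_length (dim=1 by default) by appending constant values, and calculate_entropy returns the Shannon entropy normalized by log(n), so a uniform distribution scores exactly 1.0. The values below are illustrative:

import numpy as np
from util.util import pad, calculate_entropy

a = np.ones((5, 3))
print(pad(a, 8).shape)       # (5, 8): three zero columns appended along dim=1

p = np.array([0.25, 0.25, 0.25, 0.25])
print(calculate_entropy(p))  # 1.0: a uniform distribution has maximal normalized entropy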
/util/writer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | try:
5 | from tensorboardX import SummaryWriter
6 | except ImportError as error:
7 |     print('tensorboardX is not installed; visualization will not be available')
8 | SummaryWriter = None
9 |
10 | class Writer:
11 | def __init__(self, opt):
12 | self.name = opt.name
13 | self.opt = opt
14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
15 | self.log_name = os.path.join(self.save_dir, 'loss_log.txt')
16 | self.testacc_log = os.path.join(self.save_dir, 'testacc_log.txt')
17 | self.start_logs()
18 | self.nexamples = 0
19 | self.ncorrect = 0
20 |         # the tensorboardX display is only created for training runs with visualization enabled
21 | if opt.is_train and not opt.no_vis and SummaryWriter is not None:
22 | self.display = SummaryWriter(comment=opt.name)
23 | else:
24 | self.display = None
25 |
26 | def start_logs(self):
27 | """ creates test / train log files """
28 | if self.opt.is_train:
29 | with open(self.log_name, "a") as log_file:
30 | now = time.strftime("%c")
31 | log_file.write('================ Training Loss (%s) ================\n' % now)
32 | else:
33 | with open(self.testacc_log, "a") as log_file:
34 | now = time.strftime("%c")
35 | log_file.write('================ Testing Acc (%s) ================\n' % now)
36 |
37 | def print_current_losses(self, epoch, i, losses, t, t_data):
38 | """ prints train loss to terminal / file """
39 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) loss: %.3f ' \
40 | % (epoch, i, t, t_data, losses.item())
41 | print(message)
42 | with open(self.log_name, "a") as log_file:
43 | log_file.write('%s\n' % message)
44 |
45 | def plot_loss(self, loss, epoch, i, n):
46 | iters = i + (epoch - 1) * n
47 | if self.display:
48 | self.display.add_scalar('data/train_loss', loss, iters)
49 |
50 | def plot_model_wts(self, model, epoch):
51 | if self.opt.is_train and self.display:
52 | for name, param in model.net.named_parameters():
53 | self.display.add_histogram(name, param.clone().cpu().data.numpy(), epoch)
54 |
55 | def print_acc(self, epoch, acc):
56 | """ prints test accuracy to terminal / file """
57 | message = 'epoch: {}, TEST ACC: [{:.5} %]\n' \
58 | .format(epoch, acc * 100)
59 | print(message)
60 | with open(self.testacc_log, "a") as log_file:
61 | log_file.write('%s\n' % message)
62 |
63 | def plot_acc(self, acc, epoch):
64 | if self.display:
65 | self.display.add_scalar('data/test_acc', acc, epoch)
66 |
67 | def reset_counter(self):
68 | """
69 | counts # of correct examples
70 | """
71 | self.ncorrect = 0
72 | self.nexamples = 0
73 |
74 | def update_counter(self, ncorrect, nexamples):
75 | self.ncorrect += ncorrect
76 | self.nexamples += nexamples
77 |
78 | @property
79 | def acc(self):
80 | return float(self.ncorrect) / self.nexamples
81 |
82 | def close(self):
83 | if self.display is not None:
84 | self.display.close()
85 |
--------------------------------------------------------------------------------
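The counter trio above (reset_counter / update_counter / acc) is what test.py uses to accumulate accuracy across batches. A minimal sketch of that flow; the opt namespace here is a hand-rolled stand-in carrying only the attributes Writer touches in eval mode, and the batch results are illustrative:

import os
from types import SimpleNamespace
from util.writer import Writer

opt = SimpleNamespace(name='demo', checkpoints_dir='checkpoints',
                      is_train=False, no_vis=True)
os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)

writer = Writer(opt)
writer.reset_counter()
for ncorrect, nexamples in [(45, 50), (48, 50)]:
    writer.update_counter(ncorrect, nexamples)
print(writer.acc)  # 0.93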