├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   ├── config.py
│   ├── test_config.yaml
│   └── train_config.yaml
├── data
│   ├── ModelNet40.py
│   ├── __init__.py
│   └── preprocess.py
├── doc
│   └── pipeline.PNG
├── models
│   ├── MeshNet.py
│   ├── __init__.py
│   └── layers.py
├── requirements.txt
├── test.py
├── train.py
└── utils
    ├── __init__.py
    └── retrival.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | ckpt_root/ 3 | dataset/ 4 | *.pkl 5 | */__pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yue's Group of THU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MeshNet: Mesh Neural Network for 3D Shape Representation 2 | Created by Yutong Feng, Yifan Feng, Haoxuan You, Xibin Zhao, Yue Gao from Tsinghua University. 3 | 4 | ![pipeline](doc/pipeline.PNG) 5 | ### Introduction 6 | 7 | This work was published at AAAI 2019. We propose a novel framework, MeshNet, for 3D shape representation. MeshNet learns directly from mesh data and achieves satisfying performance compared with both traditional mesh-based methods and representative methods based on other types of data. See the [paper](https://ojs.aaai.org/index.php/AAAI/article/view/4840/4713) for a more detailed introduction. 8 | 9 | Mesh is an important and powerful type of data for 3D shapes. Due to the complexity and irregularity of mesh data, little effort has been made in recent years to use meshes for 3D shape representation. We propose a mesh neural network, named MeshNet, to learn 3D shape representation directly from mesh data. Face units and feature splitting are introduced to address the complexity and irregularity problems. We have applied MeshNet to 3D shape classification and retrieval. Experimental results and comparisons with state-of-the-art methods show that MeshNet achieves satisfying classification and retrieval performance, which indicates the effectiveness of the proposed method for 3D shape representation. 10 | 11 | In this repository, we release the code and data for training a Mesh Neural Network for classification and retrieval on the ModelNet40 dataset. 12 | 13 | ### Update 14 | **[2021/12]** We have released an updated version in which MeshNet achieves 92.75% classification accuracy on ModelNet40. 
The results are based on a better-simplified version of ModelNet40, named "Manifold40", with watertight meshes and 500 faces per model. We also provide a more stable training script that reaches this performance. See the Usage section for details. 15 | 16 | ### Usage 17 | 18 | #### Installation 19 | You can install the required packages as follows. This code has been tested with Python 3.8 and CUDA 11.1. 20 | ``` 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | #### Data Preparation 25 | MeshNet requires a pre-processed version of ModelNet40 with simplified and reorganized mesh data. To start training quickly, we recommend using our [pre-processed ModelNet40 dataset](https://cloud.tsinghua.edu.cn/f/77436a9afd294a52b492/?dl=1) and setting "data_root" in `config/train_config.yaml` and `config/test_config.yaml` to the path of the downloaded dataset. To use the default path, run 26 | ``` 27 | wget --content-disposition https://cloud.tsinghua.edu.cn/f/77436a9afd294a52b492/?dl=1 28 | mkdir dataset 29 | unzip -d dataset/ ModelNet40_processed.zip 30 | rm ModelNet40_processed.zip 31 | ``` 32 | 33 | Our pre-processing works as follows. The original dataset is from [ModelNet](http://modelnet.cs.princeton.edu/). First, we simplify each mesh model to no more than `max_faces` faces. We now recommend using the [Manifold40](https://cloud.tsinghua.edu.cn/f/2a292c598af94265a0b8/?dl=1) version, which has watertight meshes and `max_faces=500`. Then we reorganize the dataset into the format required by MeshNet and store each model as an `XXX.npz` file. Each reorganized file contains two arrays: 34 | - "faces": the center position, the three vertex positions and the normal vector of each face, stored as 15 columns per face (3 + 9 + 3). 35 | - "neighbors": the indices of the three faces that share an edge with each face; a face with no neighbor across an edge uses its own index. 36 | 37 | To create and use your own dataset, simplify your models, save them in `.obj` format, and use `data/preprocess.py` to convert them into the required `.npz` format. 
Note that `max_faces` in the config files should be set to the maximum number of faces among all of your simplified mesh models. As written, both `data/preprocess.py` (which hard-codes `max_faces = 500`) and the on-the-fly `.obj` loading in `data/ModelNet40.py` reject any model whose face count differs from `max_faces`. 38 | 39 | #### Evaluation 40 | The pretrained MeshNet weights are available as a [pretrained model](https://cloud.tsinghua.edu.cn/f/33bfdc6f103340daa86a/?dl=1). Download the weight file, set "load_model" in `config/test_config.yaml` to its path, and run the test script: 41 | ``` 42 | wget --content-disposition https://cloud.tsinghua.edu.cn/f/33bfdc6f103340daa86a/?dl=1 43 | python test.py 44 | ``` 45 | 46 | #### Training 47 | 48 | To train and evaluate MeshNet for classification and retrieval: 49 | 50 | ```bash 51 | python train.py 52 | ``` 53 | 54 | You can modify `config/train_config.yaml` to customize training, including the CUDA devices to use, the data augmentation flag and the MeshNet hyper-parameters. 55 | 56 | 57 | ### Citation 58 | 59 | If you find our work useful in your research, please consider citing: 60 | 61 | ``` 62 | @inproceedings{feng2019meshnet, 63 | title={Meshnet: Mesh neural network for 3d shape representation}, 64 | author={Feng, Yutong and Feng, Yifan and You, Haoxuan and Zhao, Xibin and Gao, Yue}, 65 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 66 | volume={33}, 67 | number={01}, 68 | pages={8279--8286}, 69 | year={2019} 70 | } 71 | ``` 72 | 73 | ### License 74 | 75 | Our code is released under the MIT License (see the LICENSE file for details). 
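### Example: Building the `.npz` Arrays

The pre-processing described above (implemented in `data/preprocess.py` and in `process_mesh` in `data/ModelNet40.py`) can be sketched in plain NumPy. This is an illustrative sketch, not the repository's code. The helper names `normalize`, `face_features` and `edge_neighbors` are hypothetical. Face normals are computed here with a cross product over the vertex winding order instead of pymeshlab's `face_normal_matrix`, which is assumed to give the same directions for consistently oriented meshes. When several faces share an edge, this sketch takes the lowest index, whereas the repository returns whichever face the set yields first.

```python
import numpy as np


def normalize(vertices):
    """Center on the bounding-box midpoint and scale so the farthest vertex lies at distance 1."""
    v = np.asarray(vertices, dtype=float)
    v = v - (v.max(axis=0) + v.min(axis=0)) / 2
    return v / np.sqrt((v ** 2).sum(axis=1).max())


def face_features(vertices, faces):
    """Return an (N, 15) array per face: center (3), corners x1..z3 (9), unit normal (3)."""
    tri = np.asarray(vertices, dtype=float)[faces]  # (N, 3 vertices, 3 coords)
    centers = tri.mean(axis=1)
    corners = tri.reshape(len(faces), 9)  # x1 y1 z1 x2 y2 z2 x3 y3 z3
    # Assumption: normal direction follows the right-hand rule over (v1, v2, v3).
    normals = np.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0])
    normals /= np.linalg.norm(normals, axis=1, keepdims=True)
    return np.concatenate([centers, corners, normals], axis=1)


def edge_neighbors(faces):
    """Return an (N, 3) array: the face across edges (v1, v2), (v2, v3), (v3, v1).

    A face with no neighbor across an edge uses its own index, like find_neighbor.
    """
    faces = np.asarray(faces)
    faces_of_vertex = [set() for _ in range(faces.max() + 1)]
    for i, face in enumerate(faces):
        for v in face:
            faces_of_vertex[v].add(i)
    neighbors = []
    for i, (v1, v2, v3) in enumerate(faces):
        row = []
        for a, b in ((v1, v2), (v2, v3), (v3, v1)):
            shared = (faces_of_vertex[a] & faces_of_vertex[b]) - {i}
            row.append(min(shared) if shared else i)  # hypothetical tie-break: lowest index
        neighbors.append(row)
    return np.array(neighbors, dtype=np.int64)


if __name__ == "__main__":
    # A unit tetrahedron: 4 vertices, 4 outward-facing triangles.
    verts = normalize([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tris = np.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])
    faces_arr = face_features(verts, tris)
    neighbors_arr = edge_neighbors(tris)
    print(faces_arr.shape, neighbors_arr.shape)
    # np.savez("example.npz", faces=faces_arr, neighbors=neighbors_arr) would store the same two keys.
```

For a watertight (closed, manifold) mesh every edge is shared by exactly two faces, so the tie-breaking rule never matters; the self-index fallback only appears on open boundaries (a lone triangle, for instance, gets `[[0, 0, 0]]`).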
76 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import get_train_config, get_test_config 2 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import yaml 4 | 5 | 6 | def _check_dir(dir, make_dir=True): 7 | if not osp.exists(dir): 8 | if make_dir: 9 | print('Creating directory {}'.format(dir)) 10 | os.mkdir(dir) 11 | else: 12 | raise Exception('Directory does not exist: {}'.format(dir)) 13 | 14 | 15 | def get_train_config(config_file='config/train_config.yaml'): 16 | with open(config_file, 'r') as f: 17 | cfg = yaml.load(f, Loader=yaml.loader.SafeLoader) 18 | 19 | _check_dir(cfg['dataset']['data_root'], make_dir=False) 20 | _check_dir(cfg['ckpt_root']) 21 | 22 | return cfg 23 | 24 | 25 | def get_test_config(config_file='config/test_config.yaml'): 26 | with open(config_file, 'r') as f: 27 | cfg = yaml.load(f, Loader=yaml.loader.SafeLoader) 28 | 29 | _check_dir(cfg['dataset']['data_root'], make_dir=False) 30 | 31 | return cfg 32 | -------------------------------------------------------------------------------- /config/test_config.yaml: -------------------------------------------------------------------------------- 1 | # CUDA 2 | cuda_devices: '0' 3 | 4 | # dataset 5 | dataset: 6 | data_root: 'dataset/ModelNet40_processed' 7 | augment_data: false 8 | max_faces: 500 9 | 10 | # model 11 | load_model: 'MeshNet_ModelNet40_250e_bs128_lr6e-4.pkl' 12 | 13 | # MeshNet 14 | MeshNet: 15 | structural_descriptor: 16 | num_kernel: 64 17 | sigma: 0.2 18 | mesh_convolution: 19 | aggregation_method: 'Concat' # Concat/Max/Average 20 | mask_ratio: 0.95 21 | dropout: 0.5 22 | num_classes: 40 23 | 24 | # test config 25 | batch_size: 128 26 | retrieval_on: true 27 | 
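The `mesh_convolution.aggregation_method` and `mask_ratio` options in these config files are consumed by `MeshConvolution` in `models/layers.py` and by `MeshNet.forward` in `models/MeshNet.py`. The NumPy sketch below mirrors only the parameter-free part of that logic, for a single mesh with no batch dimension and without the learned MLPs; the `'Concat'` branch is omitted because it depends on the learned `concat_mlp`. Fancy indexing `structural_fea[:, neighbor_index]` stands in for the `torch.gather` call in the original, and the function names `aggregate` and `num_kept_faces` are hypothetical.

```python
import numpy as np


def aggregate(structural_fea, neighbor_index, method="Max"):
    """Pool each face's features with those of its three neighbors.

    structural_fea: (C, N) per-face features; neighbor_index: (N, 3) face indices.
    Mirrors the 'Max' and 'Average' branches of MeshConvolution (no MLP, no batch dim).
    """
    neighbors = structural_fea[:, neighbor_index]  # (C, N, 3)
    stacked = np.concatenate([structural_fea[:, :, None], neighbors], axis=2)  # (C, N, 4)
    if method == "Max":
        return stacked.max(axis=2)
    if method == "Average":
        return stacked.sum(axis=2) / 4  # the face itself plus its three neighbors
    raise ValueError("unsupported aggregation method: {}".format(method))


def num_kept_faces(num_faces, mask_ratio):
    """Faces kept by the random training-time mask before global max pooling."""
    return int(num_faces * (1 - mask_ratio))


if __name__ == "__main__":
    fea = np.array([[1.0, 2.0, 3.0]])  # C = 1 channel, N = 3 faces
    nbr = np.array([[1, 2, 0], [0, 2, 1], [0, 1, 1]])
    print(aggregate(fea, nbr, "Max"), aggregate(fea, nbr, "Average"))
    print(num_kept_faces(500, 0.95))
```

With the shipped `max_faces: 500` and `mask_ratio: 0.95`, only `int(500 * 0.05) = 25` randomly chosen faces feed the global max pooling in each training forward pass; in evaluation mode all faces are used.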
-------------------------------------------------------------------------------- /config/train_config.yaml: -------------------------------------------------------------------------------- 1 | # CUDA 2 | cuda_devices: '0,1' # multi-gpu training is available 3 | 4 | # dataset 5 | dataset: 6 | data_root: 'dataset/ModelNet40_processed' 7 | max_faces: 500 8 | augment_data: true 9 | jitter_sigma: 0.01 10 | jitter_clip: 0.05 11 | 12 | # result 13 | ckpt_root: 'ckpt_root' 14 | 15 | # MeshNet 16 | MeshNet: 17 | structural_descriptor: 18 | num_kernel: 64 19 | sigma: 0.2 20 | mesh_convolution: 21 | aggregation_method: 'Concat' # Concat/Max/Average 22 | mask_ratio: 0.95 23 | dropout: 0.5 24 | num_classes: 40 25 | 26 | # train 27 | seed: 0 28 | lr: 0.0006 29 | momentum: 0.9 30 | weight_decay: 0.0005 31 | batch_size: 128 32 | max_epoch: 250 33 | optimizer: 'adamw' # sgd/adamw 34 | scheduler: 'cos' # step/cos 35 | milestones: [30, 60, 90] 36 | gamma: 0.1 37 | retrieval_on: true # enable evaluating retrieval performance during training 38 | save_steps: 10 39 | -------------------------------------------------------------------------------- /data/ModelNet40.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.utils.data as data 5 | import pymeshlab 6 | from data.preprocess import find_neighbor 7 | 8 | type_to_index_map = { 9 | 'night_stand': 0, 'range_hood': 1, 'plant': 2, 'chair': 3, 'tent': 4, 10 | 'curtain': 5, 'piano': 6, 'dresser': 7, 'desk': 8, 'bed': 9, 11 | 'sink': 10, 'laptop':11, 'flower_pot': 12, 'car': 13, 'stool': 14, 12 | 'vase': 15, 'monitor': 16, 'airplane': 17, 'stairs': 18, 'glass_box': 19, 13 | 'bottle': 20, 'guitar': 21, 'cone': 22, 'toilet': 23, 'bathtub': 24, 14 | 'wardrobe': 25, 'radio': 26, 'person': 27, 'xbox': 28, 'bowl': 29, 15 | 'cup': 30, 'door': 31, 'tv_stand': 32, 'mantel': 33, 'sofa': 34, 16 | 'keyboard': 35, 'bookshelf': 36, 'bench': 37, 'table': 
38, 'lamp': 39 17 | } 18 | 19 | 20 | class ModelNet40(data.Dataset): 21 | 22 | def __init__(self, cfg, part='train'): 23 | self.root = cfg['data_root'] 24 | self.max_faces = cfg['max_faces'] 25 | self.part = part 26 | self.augment_data = cfg['augment_data'] 27 | if self.augment_data: 28 | self.jitter_sigma = cfg['jitter_sigma'] 29 | self.jitter_clip = cfg['jitter_clip'] 30 | 31 | self.data = [] 32 | for type in os.listdir(self.root): 33 | if type not in type_to_index_map.keys(): 34 | continue 35 | type_index = type_to_index_map[type] 36 | type_root = os.path.join(os.path.join(self.root, type), part) 37 | for filename in os.listdir(type_root): 38 | if filename.endswith('.npz') or filename.endswith('.obj'): 39 | self.data.append((os.path.join(type_root, filename), type_index)) 40 | 41 | def __getitem__(self, i): 42 | path, type = self.data[i] 43 | if path.endswith('.npz'): 44 | data = np.load(path) 45 | face = data['faces'] 46 | neighbor_index = data['neighbors'] 47 | else: 48 | face, neighbor_index = process_mesh(path, self.max_faces) 49 | if face is None: 50 | return self.__getitem__(0) 51 | 52 | # data augmentation 53 | if self.augment_data and self.part == 'train': 54 | # jitter 55 | jittered_data = np.clip(self.jitter_sigma * np.random.randn(*face[:, :3].shape), -1 * self.jitter_clip, self.jitter_clip) 56 | face = np.concatenate((face[:, :3] + jittered_data, face[:, 3:]), 1) 57 | 58 | # fill for n < max_faces with randomly picked faces 59 | num_point = len(face) 60 | if num_point < self.max_faces: 61 | fill_face = [] 62 | fill_neighbor_index = [] 63 | for i in range(self.max_faces - num_point): 64 | index = np.random.randint(0, num_point) 65 | fill_face.append(face[index]) 66 | fill_neighbor_index.append(neighbor_index[index]) 67 | face = np.concatenate((face, np.array(fill_face))) 68 | neighbor_index = np.concatenate((neighbor_index, np.array(fill_neighbor_index))) 69 | 70 | # to tensor 71 | face = torch.from_numpy(face).float() 72 | neighbor_index = 
torch.from_numpy(neighbor_index).long() 73 | target = torch.tensor(type, dtype=torch.long) 74 | 75 | # reorganize 76 | face = face.permute(1, 0).contiguous() 77 | centers, corners, normals = face[:3], face[3:12], face[12:] 78 | corners = corners - torch.cat([centers, centers, centers], 0) 79 | 80 | return centers, corners, normals, neighbor_index, target 81 | 82 | def __len__(self): 83 | return len(self.data) 84 | 85 | 86 | def process_mesh(path, max_faces): 87 | ms = pymeshlab.MeshSet() 88 | ms.clear() 89 | 90 | # load mesh 91 | ms.load_new_mesh(path) 92 | mesh = ms.current_mesh() 93 | 94 | # # clean up 95 | # mesh, _ = pymesh.remove_isolated_vertices(mesh) 96 | # mesh, _ = pymesh.remove_duplicated_vertices(mesh) 97 | 98 | # get elements 99 | vertices = mesh.vertex_matrix() 100 | faces = mesh.face_matrix() 101 | 102 | if faces.shape[0] != max_faces: # occurs only once in the Manifold40 train set 103 | print("Skipping model with {} faces (expected {}): {}".format(faces.shape[0], max_faces, path)) 104 | return None, None 105 | 106 | # move to center 107 | center = (np.max(vertices, 0) + np.min(vertices, 0)) / 2 108 | vertices -= center 109 | 110 | # normalize 111 | max_len = np.max(vertices[:, 0]**2 + vertices[:, 1]**2 + vertices[:, 2]**2) 112 | vertices /= np.sqrt(max_len) 113 | 114 | # get normal vector 115 | ms.clear() 116 | mesh = pymeshlab.Mesh(vertices, faces) 117 | ms.add_mesh(mesh) 118 | face_normal = ms.current_mesh().face_normal_matrix() 119 | 120 | # get neighbors 121 | faces_contain_this_vertex = [] 122 | for i in range(len(vertices)): 123 | faces_contain_this_vertex.append(set([])) 124 | centers = [] 125 | corners = [] 126 | for i in range(len(faces)): 127 | [v1, v2, v3] = faces[i] 128 | x1, y1, z1 = vertices[v1] 129 | x2, y2, z2 = vertices[v2] 130 | x3, y3, z3 = vertices[v3] 131 | centers.append([(x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3, (z1 + z2 + z3) / 3]) 132 | corners.append([x1, y1, z1, x2, y2, z2, x3, y3, z3]) 133 | faces_contain_this_vertex[v1].add(i) 134 | 
faces_contain_this_vertex[v2].add(i) 135 | faces_contain_this_vertex[v3].add(i) 136 | 137 | neighbors = [] 138 | for i in range(len(faces)): 139 | [v1, v2, v3] = faces[i] 140 | n1 = find_neighbor(faces, faces_contain_this_vertex, v1, v2, i) 141 | n2 = find_neighbor(faces, faces_contain_this_vertex, v2, v3, i) 142 | n3 = find_neighbor(faces, faces_contain_this_vertex, v3, v1, i) 143 | neighbors.append([n1, n2, n3]) 144 | 145 | centers = np.array(centers) 146 | corners = np.array(corners) 147 | faces = np.concatenate([centers, corners, face_normal], axis=1) 148 | neighbors = np.array(neighbors) 149 | 150 | return faces, neighbors 151 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .ModelNet40 import ModelNet40 2 | -------------------------------------------------------------------------------- /data/preprocess.py: -------------------------------------------------------------------------------- 1 | import pymeshlab 2 | import numpy as np 3 | from pathlib import Path 4 | from rich.progress import track 5 | 6 | 7 | def find_neighbor(faces, faces_contain_this_vertex, vf1, vf2, except_face): 8 | for i in (faces_contain_this_vertex[vf1] & faces_contain_this_vertex[vf2]): 9 | if i != except_face: 10 | face = faces[i].tolist() 11 | face.remove(vf1) 12 | face.remove(vf2) 13 | return i 14 | 15 | return except_face 16 | 17 | if __name__ == '__main__': 18 | root = Path('dataset/Manifold40') 19 | new_root = Path('dataset/ModelNet40_processed') 20 | max_faces = 500 21 | shape_list = sorted(list(root.glob('*/*/*.obj'))) 22 | ms = pymeshlab.MeshSet() 23 | 24 | for shape_dir in track(shape_list): 25 | out_dir = new_root / shape_dir.relative_to(root).with_suffix('.npz') 26 | # if out_dir.exists(): 27 | # continue 28 | out_dir.parent.mkdir(parents=True, exist_ok=True) 29 | 30 | ms.clear() 31 | # load mesh 32 | ms.load_new_mesh(str(shape_dir)) 
33 | mesh = ms.current_mesh() 34 | 35 | # # clean up 36 | # mesh, _ = pymesh.remove_isolated_vertices(mesh) 37 | # mesh, _ = pymesh.remove_duplicated_vertices(mesh) 38 | 39 | # get elements 40 | vertices = mesh.vertex_matrix() 41 | faces = mesh.face_matrix() 42 | 43 | if faces.shape[0] != max_faces: 44 | print("Skipping model with {} faces (expected {}): {}".format(faces.shape[0], max_faces, shape_dir)) 45 | continue 46 | 47 | # move to center 48 | center = (np.max(vertices, 0) + np.min(vertices, 0)) / 2 49 | vertices -= center 50 | 51 | # normalize 52 | max_len = np.max(vertices[:, 0]**2 + vertices[:, 1]**2 + vertices[:, 2]**2) 53 | vertices /= np.sqrt(max_len) 54 | 55 | # get normal vector 56 | ms.clear() 57 | mesh = pymeshlab.Mesh(vertices, faces) 58 | ms.add_mesh(mesh) 59 | face_normal = ms.current_mesh().face_normal_matrix() 60 | 61 | # get neighbors 62 | faces_contain_this_vertex = [] 63 | for i in range(len(vertices)): 64 | faces_contain_this_vertex.append(set([])) 65 | centers = [] 66 | corners = [] 67 | for i in range(len(faces)): 68 | [v1, v2, v3] = faces[i] 69 | x1, y1, z1 = vertices[v1] 70 | x2, y2, z2 = vertices[v2] 71 | x3, y3, z3 = vertices[v3] 72 | centers.append([(x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3, (z1 + z2 + z3) / 3]) 73 | corners.append([x1, y1, z1, x2, y2, z2, x3, y3, z3]) 74 | faces_contain_this_vertex[v1].add(i) 75 | faces_contain_this_vertex[v2].add(i) 76 | faces_contain_this_vertex[v3].add(i) 77 | 78 | neighbors = [] 79 | for i in range(len(faces)): 80 | [v1, v2, v3] = faces[i] 81 | n1 = find_neighbor(faces, faces_contain_this_vertex, v1, v2, i) 82 | n2 = find_neighbor(faces, faces_contain_this_vertex, v2, v3, i) 83 | n3 = find_neighbor(faces, faces_contain_this_vertex, v3, v1, i) 84 | neighbors.append([n1, n2, n3]) 85 | 86 | centers = np.array(centers) 87 | corners = np.array(corners) 88 | faces = np.concatenate([centers, corners, face_normal], axis=1) 89 | neighbors = np.array(neighbors) 90 | 91 | np.savez(str(out_dir), faces=faces, 
neighbors=neighbors) 92 | -------------------------------------------------------------------------------- /doc/pipeline.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/MeshNet/70f9115a121cef71f62d774088771337c3beaf4b/doc/pipeline.PNG -------------------------------------------------------------------------------- /models/MeshNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models import SpatialDescriptor, StructuralDescriptor, MeshConvolution 4 | 5 | 6 | class MeshNet(nn.Module): 7 | 8 | def __init__(self, cfg, require_fea=False): 9 | super(MeshNet, self).__init__() 10 | self.require_fea = require_fea 11 | 12 | self.spatial_descriptor = SpatialDescriptor() 13 | self.structural_descriptor = StructuralDescriptor(cfg['structural_descriptor']) 14 | self.mesh_conv1 = MeshConvolution(cfg['mesh_convolution'], 64, 131, 256, 256) 15 | self.mesh_conv2 = MeshConvolution(cfg['mesh_convolution'], 256, 256, 512, 512) 16 | self.fusion_mlp = nn.Sequential( 17 | nn.Conv1d(1024, 1024, 1), 18 | nn.BatchNorm1d(1024), 19 | nn.ReLU(), 20 | ) 21 | self.concat_mlp = nn.Sequential( 22 | nn.Conv1d(1792, 1024, 1), 23 | nn.BatchNorm1d(1024), 24 | nn.ReLU(), 25 | ) 26 | self.mask_ratio = cfg['mask_ratio'] 27 | self.classifier = nn.Sequential( 28 | nn.Linear(1024, 512), 29 | nn.ReLU(), 30 | nn.Dropout(p=cfg['dropout']), 31 | nn.Linear(512, 256), 32 | nn.ReLU(), 33 | nn.Dropout(p=cfg['dropout']), 34 | nn.Linear(256, cfg['num_classes']) 35 | ) 36 | 37 | def forward(self, centers, corners, normals, neighbor_index): 38 | spatial_fea0 = self.spatial_descriptor(centers) 39 | structural_fea0 = self.structural_descriptor(corners, normals, neighbor_index) 40 | 41 | spatial_fea1, structural_fea1 = self.mesh_conv1(spatial_fea0, structural_fea0, neighbor_index) 42 | spatial_fea2, structural_fea2 = self.mesh_conv2(spatial_fea1, 
structural_fea1, neighbor_index) 43 | spatial_fea3 = self.fusion_mlp(torch.cat([spatial_fea2, structural_fea2], 1)) 44 | 45 | fea = self.concat_mlp(torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1)) # b, c, n 46 | if self.training: 47 | fea = fea[:, :, torch.randperm(fea.size(2))[:int(fea.size(2) * (1 - self.mask_ratio))]] 48 | fea = torch.max(fea, dim=2)[0] 49 | fea = fea.reshape(fea.size(0), -1) 50 | fea = self.classifier[:-1](fea) 51 | cls = self.classifier[-1:](fea) 52 | 53 | if self.require_fea: 54 | return cls, torch.nn.functional.normalize(fea, p=2, dim=1) # per-sample L2 normalization, independent of batch composition 55 | else: 56 | return cls 57 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import SpatialDescriptor, StructuralDescriptor, MeshConvolution 2 | from .MeshNet import MeshNet 3 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class FaceRotateConvolution(nn.Module): 8 | 9 | def __init__(self): 10 | super(FaceRotateConvolution, self).__init__() 11 | self.rotate_mlp = nn.Sequential( 12 | nn.Conv1d(6, 32, 1), 13 | nn.BatchNorm1d(32), 14 | nn.ReLU(), 15 | nn.Conv1d(32, 32, 1), 16 | nn.BatchNorm1d(32), 17 | nn.ReLU() 18 | ) 19 | self.fusion_mlp = nn.Sequential( 20 | nn.Conv1d(32, 64, 1), 21 | nn.BatchNorm1d(64), 22 | nn.ReLU(), 23 | nn.Conv1d(64, 64, 1), 24 | nn.BatchNorm1d(64), 25 | nn.ReLU() 26 | ) 27 | 28 | def forward(self, corners): 29 | 30 | fea = (self.rotate_mlp(corners[:, :6]) + 31 | self.rotate_mlp(corners[:, 3:9]) + 32 | self.rotate_mlp(torch.cat([corners[:, 6:], corners[:, :3]], 1))) / 3 33 | 34 | return self.fusion_mlp(fea) 35 | 36 | 37 | class FaceKernelCorrelation(nn.Module): 38 | 39 | def __init__(self, 
num_kernel=64, sigma=0.2): 40 | super(FaceKernelCorrelation, self).__init__() 41 | self.num_kernel = num_kernel 42 | self.sigma = sigma 43 | self.weight_alpha = Parameter(torch.rand(1, num_kernel, 4) * np.pi) 44 | self.weight_beta = Parameter(torch.rand(1, num_kernel, 4) * 2 * np.pi) 45 | self.bn = nn.BatchNorm1d(num_kernel) 46 | self.relu = nn.ReLU() 47 | 48 | def forward(self, normals, neighbor_index): 49 | 50 | b, _, n = normals.size() 51 | 52 | center = normals.unsqueeze(2).expand(-1, -1, self.num_kernel, -1).unsqueeze(4) 53 | neighbor = torch.gather(normals.unsqueeze(3).expand(-1, -1, -1, 3), 2, 54 | neighbor_index.unsqueeze(1).expand(-1, 3, -1, -1)) 55 | neighbor = neighbor.unsqueeze(2).expand(-1, -1, self.num_kernel, -1, -1) 56 | 57 | fea = torch.cat([center, neighbor], 4) 58 | fea = fea.unsqueeze(5).expand(-1, -1, -1, -1, -1, 4) 59 | weight = torch.cat([torch.sin(self.weight_alpha) * torch.cos(self.weight_beta), 60 | torch.sin(self.weight_alpha) * torch.sin(self.weight_beta), 61 | torch.cos(self.weight_alpha)], 0) 62 | weight = weight.unsqueeze(0).expand(b, -1, -1, -1) 63 | weight = weight.unsqueeze(3).expand(-1, -1, -1, n, -1) 64 | weight = weight.unsqueeze(4).expand(-1, -1, -1, -1, 4, -1) 65 | 66 | dist = torch.sum((fea - weight)**2, 1) 67 | fea = torch.sum(torch.sum(np.e**(dist / (-2 * self.sigma**2)), 4), 3) / 16 68 | 69 | return self.relu(self.bn(fea)) 70 | 71 | 72 | class SpatialDescriptor(nn.Module): 73 | 74 | def __init__(self): 75 | super(SpatialDescriptor, self).__init__() 76 | 77 | self.spatial_mlp = nn.Sequential( 78 | nn.Conv1d(3, 64, 1), 79 | nn.BatchNorm1d(64), 80 | nn.ReLU(), 81 | nn.Conv1d(64, 64, 1), 82 | nn.BatchNorm1d(64), 83 | nn.ReLU(), 84 | ) 85 | 86 | def forward(self, centers): 87 | return self.spatial_mlp(centers) 88 | 89 | 90 | class StructuralDescriptor(nn.Module): 91 | 92 | def __init__(self, cfg): 93 | super(StructuralDescriptor, self).__init__() 94 | 95 | self.FRC = FaceRotateConvolution() 96 | self.FKC = 
FaceKernelCorrelation(cfg['num_kernel'], cfg['sigma']) 97 | self.structural_mlp = nn.Sequential( 98 | nn.Conv1d(64 + 3 + cfg['num_kernel'], 131, 1), 99 | nn.BatchNorm1d(131), 100 | nn.ReLU(), 101 | nn.Conv1d(131, 131, 1), 102 | nn.BatchNorm1d(131), 103 | nn.ReLU(), 104 | ) 105 | 106 | def forward(self, corners, normals, neighbor_index): 107 | structural_fea1 = self.FRC(corners) 108 | structural_fea2 = self.FKC(normals, neighbor_index) 109 | 110 | return self.structural_mlp(torch.cat([structural_fea1, structural_fea2, normals], 1)) 111 | 112 | 113 | class MeshConvolution(nn.Module): 114 | 115 | def __init__(self, cfg, spatial_in_channel, structural_in_channel, spatial_out_channel, structural_out_channel): 116 | super(MeshConvolution, self).__init__() 117 | 118 | self.spatial_in_channel = spatial_in_channel 119 | self.structural_in_channel = structural_in_channel 120 | self.spatial_out_channel = spatial_out_channel 121 | self.structural_out_channel = structural_out_channel 122 | 123 | assert cfg['aggregation_method'] in ['Concat', 'Max', 'Average'] 124 | self.aggregation_method = cfg['aggregation_method'] 125 | 126 | self.combination_mlp = nn.Sequential( 127 | nn.Conv1d(self.spatial_in_channel + self.structural_in_channel, self.spatial_out_channel, 1), 128 | nn.BatchNorm1d(self.spatial_out_channel), 129 | nn.ReLU(), 130 | ) 131 | 132 | if self.aggregation_method == 'Concat': 133 | self.concat_mlp = nn.Sequential( 134 | nn.Conv2d(self.structural_in_channel * 2, self.structural_in_channel, 1), 135 | nn.BatchNorm2d(self.structural_in_channel), 136 | nn.ReLU(), 137 | ) 138 | 139 | self.aggregation_mlp = nn.Sequential( 140 | nn.Conv1d(self.structural_in_channel, self.structural_out_channel, 1), 141 | nn.BatchNorm1d(self.structural_out_channel), 142 | nn.ReLU(), 143 | ) 144 | 145 | def forward(self, spatial_fea, structural_fea, neighbor_index): 146 | b, _, n = spatial_fea.size() 147 | 148 | # Combination 149 | spatial_fea = self.combination_mlp(torch.cat([spatial_fea, 
structural_fea], 1)) 150 | 151 | # Aggregation 152 | if self.aggregation_method == 'Concat': 153 | structural_fea = torch.cat([structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 154 | torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 155 | neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 156 | -1, -1))], 1) 157 | structural_fea = self.concat_mlp(structural_fea) 158 | structural_fea = torch.max(structural_fea, 3)[0] 159 | 160 | elif self.aggregation_method == 'Max': 161 | structural_fea = torch.cat([structural_fea.unsqueeze(3), 162 | torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 163 | neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 164 | -1, -1))], 3) 165 | structural_fea = torch.max(structural_fea, 3)[0] 166 | 167 | elif self.aggregation_method == 'Average': 168 | structural_fea = torch.cat([structural_fea.unsqueeze(3), 169 | torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 170 | neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 171 | -1, -1))], 3) 172 | structural_fea = torch.sum(structural_fea, dim=3) / 4 173 | 174 | structural_fea = self.aggregation_mlp(structural_fea) 175 | 176 | return spatial_fea, structural_fea 177 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.0 2 | PyYAML==6.0 3 | pymeshlab==2021.10 4 | rich==10.16.1 5 | scipy 6 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | from config import get_test_config 8 | from data import ModelNet40 9 | from models import MeshNet 10 | from utils.retrival import append_feature, 
calculate_map 11 | 12 | 13 | cfg = get_test_config() 14 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg['cuda_devices'] 15 | 16 | 17 | data_set = ModelNet40(cfg=cfg['dataset'], part='test') 18 | data_loader = data.DataLoader(data_set, batch_size=cfg['batch_size'], num_workers=4, shuffle=True, pin_memory=False) 19 | 20 | 21 | def test_model(model): 22 | 23 | correct_num = 0 24 | ft_all, lbl_all = None, None 25 | 26 | with torch.no_grad(): 27 | for i, (centers, corners, normals, neighbor_index, targets) in enumerate(data_loader): 28 | centers = centers.cuda() 29 | corners = corners.cuda() 30 | normals = normals.cuda() 31 | neighbor_index = neighbor_index.cuda() 32 | targets = targets.cuda() 33 | 34 | outputs, feas = model(centers, corners, normals, neighbor_index) 35 | _, preds = torch.max(outputs, 1) 36 | 37 | correct_num += (preds == targets).float().sum() 38 | 39 | if cfg['retrieval_on']: 40 | ft_all = append_feature(ft_all, feas.detach().cpu()) 41 | lbl_all = append_feature(lbl_all, targets.detach().cpu(), flaten=True) 42 | 43 | print('Accuracy: {:.4f}'.format(float(correct_num) / len(data_set))) 44 | if cfg['retrieval_on']: 45 | print('mAP: {:.4f}'.format(calculate_map(ft_all, lbl_all))) 46 | 47 | 48 | if __name__ == '__main__': 49 | 50 | model = MeshNet(cfg=cfg['MeshNet'], require_fea=True) 51 | model.cuda() 52 | model = nn.DataParallel(model) 53 | model.load_state_dict(torch.load(cfg['load_model'])) 54 | model.eval() 55 | 56 | test_model(model) 57 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.utils.data as data 9 | import torch.backends.cudnn as cudnn 10 | import math 11 | import numpy as np 12 | from config import get_train_config 13 | from data import 
ModelNet40 14 | from models import MeshNet 15 | from utils.retrival import append_feature, calculate_map 16 | 17 | 18 | cfg = get_train_config() 19 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg['cuda_devices'] 20 | 21 | # seed 22 | seed = cfg['seed'] 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | 30 | # dataset 31 | data_set = { 32 | x: ModelNet40(cfg=cfg['dataset'], part=x) for x in ['train', 'test'] 33 | } 34 | data_loader = { 35 | x: data.DataLoader(data_set[x], batch_size=cfg['batch_size'], num_workers=4, shuffle=True, pin_memory=False) 36 | for x in ['train', 'test'] 37 | } 38 | 39 | 40 | def train_model(model, criterion, optimizer, scheduler, cfg): 41 | 42 | best_acc = 0.0 43 | best_map = 0.0 44 | best_model_wts = copy.deepcopy(model.state_dict()) 45 | 46 | for epoch in range(1, cfg['max_epoch'] + 1): # run epochs 1..max_epoch inclusive 47 | 48 | print('-' * 60) 49 | print('Epoch: {} / {}'.format(epoch, cfg['max_epoch'])) 50 | print('-' * 60) 51 | 52 | # adjust_learning_rate(cfg, epoch, optimizer) 53 | for phrase in ['train', 'test']: 54 | 55 | if phrase == 'train': 56 | model.train() 57 | else: 58 | model.eval() 59 | 60 | running_loss = 0.0 61 | running_corrects = 0 62 | ft_all, lbl_all = None, None 63 | 64 | for i, (centers, corners, normals, neighbor_index, targets) in enumerate(data_loader[phrase]): 65 | centers = centers.cuda() 66 | corners = corners.cuda() 67 | normals = normals.cuda() 68 | neighbor_index = neighbor_index.cuda() 69 | targets = targets.cuda() 70 | 71 | with torch.set_grad_enabled(phrase == 'train'): 72 | outputs, feas = model(centers, corners, normals, neighbor_index) 73 | _, preds = torch.max(outputs, 1) 74 | loss = criterion(outputs, targets) 75 | 76 | if phrase == 'train': 77 | optimizer.zero_grad() 78 | loss.backward() 79 | optimizer.step() 80 | 81 | if phrase == 'test' and cfg['retrieval_on']: 82 | ft_all = 
append_feature(ft_all, feas.detach().cpu()) 83 | lbl_all = append_feature(lbl_all, targets.detach().cpu(), flaten=True) 84 | 85 | running_loss += loss.item() * centers.size(0) 86 | running_corrects += torch.sum(preds == targets.data) 87 | 88 | epoch_loss = running_loss / len(data_set[phrase]) 89 | epoch_acc = running_corrects.double() / len(data_set[phrase]) 90 | 91 | if phrase == 'train': 92 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phrase, epoch_loss, epoch_acc)) 93 | scheduler.step() 94 | 95 | if phrase == 'test': 96 | if epoch_acc > best_acc: 97 | best_acc = epoch_acc 98 | best_model_wts = copy.deepcopy(model.state_dict()) 99 | print_info = '{} Loss: {:.4f} Acc: {:.4f} (best {:.4f})'.format(phrase, epoch_loss, epoch_acc, best_acc) 100 | 101 | if cfg['retrieval_on']: 102 | epoch_map = calculate_map(ft_all, lbl_all) 103 | if epoch_map > best_map: 104 | best_map = epoch_map 105 | print_info += ' mAP: {:.4f}'.format(epoch_map) 106 | 107 | if epoch % cfg['save_steps'] == 0: 108 | torch.save(copy.deepcopy(model.state_dict()), os.path.join(cfg['ckpt_root'], '{}.pkl'.format(epoch))) 109 | 110 | print(print_info) 111 | 112 | print('Best val acc: {:.4f}'.format(best_acc)) 113 | print('Config: {}'.format(cfg)) 114 | 115 | return best_model_wts 116 | 117 | 118 | if __name__ == '__main__': 119 | 120 | # prepare model 121 | model = MeshNet(cfg=cfg['MeshNet'], require_fea=True) 122 | model.cuda() 123 | model = nn.DataParallel(model) 124 | 125 | # criterion 126 | criterion = nn.CrossEntropyLoss() 127 | 128 | # optimizer 129 | if cfg['optimizer'] == 'sgd': 130 | optimizer = optim.SGD(model.parameters(), lr=cfg['lr'], momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) 131 | else: 132 | optimizer = optim.AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay']) 133 | 134 | # scheduler 135 | if cfg['scheduler'] == 'step': 136 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones']) 137 | else: 138 | scheduler = 
optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['max_epoch']) 139 | 140 | # start training 141 | if not os.path.exists(cfg['ckpt_root']): 142 | os.mkdir(cfg['ckpt_root']) 143 | best_model_wts = train_model(model, criterion, optimizer, scheduler, cfg) 144 | torch.save(best_model_wts, os.path.join(cfg['ckpt_root'], 'MeshNet_best.pkl')) 145 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .retrival import append_feature, calculate_map 2 | -------------------------------------------------------------------------------- /utils/retrival.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy 4 | import scipy.spatial 5 | 6 | 7 | def append_feature(raw, data, flaten=False): 8 | data = np.array(data) 9 | if flaten: 10 | data = data.reshape(-1, 1) 11 | if raw is None: 12 | raw = np.array(data) 13 | else: 14 | raw = np.vstack((raw, data)) 15 | return raw 16 | 17 | 18 | def calculate_map(fts, lbls, dis_mat=None): 19 | return map_score(fts, fts, lbls, lbls) 20 | 21 | 22 | def acc_score(y_true, y_pred, average="micro"): 23 | if isinstance(y_true, list): 24 | y_true = np.array(y_true) 25 | if isinstance(y_pred, list): 26 | y_pred = np.array(y_pred) 27 | if average == "micro": 28 | # overall 29 | return np.mean(y_true == y_pred) 30 | elif average == "macro": 31 | # average of each class 32 | cls_acc = [] 33 | for cls_idx in np.unique(y_true): 34 | cls_acc.append(np.mean(y_pred[y_true==cls_idx]==cls_idx)) 35 | return np.mean(np.array(cls_acc)) 36 | else: 37 | raise NotImplementedError 38 | 39 | def cdist(fts_a, fts_b, metric): 40 | if metric == 'inner': 41 | return np.matmul(fts_a, fts_b.T) 42 | else: 43 | return scipy.spatial.distance.cdist(fts_a, fts_b, metric) 44 | 45 | def map_score(fts_a, fts_b, lbl_a, lbl_b, metric='cosine'): 46 | dist = 
cdist(fts_a, fts_b, metric) 47 | res = map_from_dist(dist, lbl_a, lbl_b) 48 | return res 49 | 50 | 51 | def map_from_dist(dist, lbl_a, lbl_b): 52 | n_a, n_b = dist.shape 53 | s_idx = dist.argsort() 54 | 55 | res = [] 56 | for i in range(n_a): 57 | order = s_idx[i] 58 | p = 0.0 59 | r = 0.0 60 | for j in range(n_b): 61 | if lbl_a[i] == lbl_b[order[j]]: 62 | r += 1 63 | p += (r / (j + 1)) 64 | if r > 0: 65 | res.append(p/r) 66 | else: 67 | res.append(0) 68 | return np.mean(res) 69 | --------------------------------------------------------------------------------
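The retrieval metric in `utils/retrival.py` ranks all test samples by cosine distance to each query and averages the precision at every relevant rank. Because `calculate_map` passes the same features as queries and gallery, each query also retrieves itself. The sketch below is a hypothetical, vectorised restatement of that average-precision definition (the name `mean_average_precision` and the `exclude_self` switch are not part of the repository), assuming only NumPy and SciPy, to show on toy data how much the self-match can lift the score.

```python
import numpy as np
from scipy.spatial.distance import cdist


def mean_average_precision(fts, lbls, exclude_self=False):
    """Hypothetical helper using the same AP definition as map_from_dist.

    Every sample is a query; the gallery is the full set, ranked by cosine distance.
    With exclude_self=True the query's own entry is dropped from its ranking.
    """
    fts = np.asarray(fts, dtype=float)
    lbls = np.asarray(lbls).reshape(-1)      # accept (N,) or (N, 1) labels
    order = cdist(fts, fts, 'cosine').argsort(axis=1)

    aps = []
    for i, ranking in enumerate(order):
        if exclude_self:
            ranking = ranking[ranking != i]  # remove the query itself
        hits = lbls[ranking] == lbls[i]      # relevance of each ranked item
        if not hits.any():
            aps.append(0.0)                  # no relevant item: AP = 0, as in the repo
            continue
        ranks = np.flatnonzero(hits) + 1     # 1-based ranks of the relevant items
        # precision at the k-th relevant item is k / rank_k
        aps.append(np.mean(np.arange(1, len(ranks) + 1) / ranks))
    return float(np.mean(aps))


# Toy data: two class-0 features with one class-1 feature between them.
fts = [[1.0, 0.0], [0.0, 1.0], [0.8, 0.6]]
lbls = [0, 0, 1]
print(mean_average_precision(fts, lbls))                     # 8/9, about 0.889
print(mean_average_precision(fts, lbls, exclude_self=True))  # 1/3, about 0.333
```

On this toy set, counting the self-match raises mAP from 1/3 to 8/9. With thousands of test shapes the effect is much smaller, but it is worth keeping in mind when comparing against numbers computed with the query removed from the gallery.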
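`acc_score` in `utils/retrival.py` is not called by `train.py` or `test.py`, which report overall (micro) accuracy only. On an imbalanced dataset such as ModelNet40, the per-class (macro) average can differ from it. A minimal sketch mirroring the two branches of `acc_score`, under the hypothetical name `accuracy` and with made-up toy labels:

```python
import numpy as np


def accuracy(y_true, y_pred, average="micro"):
    """Hypothetical stand-in mirroring acc_score in utils/retrival.py."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if average == "micro":
        # fraction of all samples predicted correctly
        return float(np.mean(y_true == y_pred))
    if average == "macro":
        # accuracy within each class first, then an unweighted mean over classes
        per_class = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
        return float(np.mean(per_class))
    raise NotImplementedError(average)


# Imbalanced toy labels: class 0 has 8 samples, class 1 has 2.
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 8 + [0, 1]                 # one of the two class-1 samples is wrong
print(accuracy(y_true, y_pred, "micro"))  # 9 / 10 = 0.9
print(accuracy(y_true, y_pred, "macro"))  # (1.0 + 0.5) / 2 = 0.75
```

A single mistake on the rare class costs 10 points of macro accuracy here but only 10 points of micro accuracy spread over the whole set, which is why the two numbers are usually reported separately.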