├── .gitignore ├── LICENSE ├── README.md ├── data ├── HumanSeg.py ├── ModelNet.py ├── load_npz.py └── load_obj.py ├── docker └── Dockerfile ├── images ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png └── 6.png ├── models ├── __init__.py ├── layers │ ├── cpp_extension │ │ ├── meshgraph.cpp │ │ └── setup.py │ ├── mesh.py │ ├── mesh_cpp_extension.py │ ├── mesh_graph_conv.py │ ├── mesh_net.py │ ├── mesh_net_with_out_neigbour.py │ └── struct_conv.py ├── mesh_graph.py ├── networks.py └── optimizer │ └── adabound.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── preprocess ├── __init__.py └── preprocess.py ├── script ├── human_seg │ ├── get_data.sh │ ├── get_pretrained.sh │ ├── test.sh │ ├── train.sh │ └── view.sh ├── modelnet10 │ └── train.sh └── modelnet40_graph │ ├── test.sh │ └── train.sh ├── test.py ├── train.py └── util ├── __init__.py ├── mesh_viewer.py ├── util.py └── writer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | feature_img/ 105 | .mypy_cache/ 106 | datasets/ 107 | cal_data.py 108 | runs/ 109 | ckpt/ 110 | ckpt_root/ 111 | demo.py 112 | *.jpg 113 | *.png 114 | *.txt 115 | test_* 116 | demo.py 117 | draw_features.py 118 | !images/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Cery D 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MeshGraph in PyTorch 2 | 3 | Transfroming Mesh data to Mesh Graph Tology through the idea of Finite Element, Paper is published on IEEE [url](https://ieeexplore.ieee.org/document/9253518). 4 | ## Transform to Topology Graph 5 | ![transfrom](images/5.png) 6 | ## Network Structure 7 | ![network](images/4.png) 8 | # Getting Started 9 | 10 | ### Installation 11 | - Clone this repo: 12 | ``` bash 13 | git clone https://github.com/JsBlueCat/MeshGraph.git 14 | cd MeshVertexNet 15 | ``` 16 | - Install dependencies: [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) and [docker](https://docs.docker.com/get-started/) 17 | 18 | - First intall docker image 19 | 20 | ```bash 21 | cd docker 22 | docker build -t your/docker:meshgraph. 23 | ``` 24 | 25 | - then run docker image 26 | ```bash 27 | docker run --rm -it --runtime=nvidia --shm-size 16G -e DISPLAY=unix$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix -v /your/path/to/MeshGraph/:/meshgraph your/docker:meshgraph bash 28 | ``` 29 | 30 | 31 | ### 3D Shape Classification on ModelNet40 32 | - get the dataset if you fail download it from python [[Model40]](https://drive.google.com/uc?export=download&confirm=HB4c&id=1o9pyskkKMxuomI5BWuLjCG2nSv5iePZz) 33 | - put the zip file in datasets/modelnet40_graph 34 | 35 | ```bash 36 | cd /meshvertex 37 | sh script/modelnet40_graph/train.sh 38 | ``` 39 | 40 | 41 | ### Classification Acc. On ModelNet40 42 | | Method | Modality | Acc | 43 | | ----------------------- |:--------:|:--------:| 44 | | 3DShapeNets | volume | 77.3% | 45 | | Voxnet | volume | 83% | 46 | | O-cnn | volume | 89.9% | 47 | | Mvcnn | view | 90.1% | 48 | | Mlvcnn | view | 94.16% | 49 | | Pointnet | Point cloud | 89.2% | 50 | | Meshnet | Mesh | 91% | 51 | | Ours with SAGE | Mesh | 94.3% +0.5% | 52 | - run test 53 | - get the dataset if you fail download it from python [[Model40]](https://drive.google.com/uc?export=download&confirm=HB4c&id=1o9pyskkKMxuomI5BWuLjCG2nSv5iePZz) 54 | - put the zip file in datasets/modelnet40_graph 55 | - download the weight file from [[weights]](https://drive.google.com/file/d/11JOiaTOBCykCYgZKw24qcD6r1Tzz7dvu/view?usp=sharing) and put it in your ckpt_root/40_graph/ and run 56 | ``` bash 57 | sh script/modelnet40_graph/test.sh 58 | ``` 59 | - the result will be like 60 | ``` bash 61 | root@2730e382330f:/meshvertex# sh script/modelnet40_graph/test.sh 62 | Running Test 63 | loading the model from ./ckpt_root/40_graph/final_net.pth 64 | epoch: -1, TEST ACC: [94.49 %] 65 | ``` 66 | ![result1](images/2.png) 67 | ![result2](images/3.png) 68 | # Train on your Dataset 69 | ### Coming soon 70 | 71 | # Some 3D Reconstruct conducted on 3d face 72 | ![face](images/6.png) 73 | 74 | # Credit 75 | 76 | ### MeshGraphNet: An Effective 3D Polygon Mesh Recognition With Topology Reconstruction 77 | An Ping Song ; Xin Yi Di; Xiao Kang Xu; Zi Heng Song
78 | 79 | **Abstract**
80 | Three-dimensional polygon mesh recognition has a significant impact on current computer graphics. However, its application to some real-life fields, such as unmanned driving and medical image processing, has been restricted due to the lack of inner-interactivity, shift-invariance, and numerical uncertainty of mesh surfaces. In this paper, an interconnected topological dual graph that extracts adjacent information from each triangular face of a polygon mesh is constructed, in order to address the above issues. On the basis of the algebraic topological graph, we propose a mesh graph neural network, called MeshGraphNet, to effectively extract features from mesh data. In this concept, the graph node-unit and correlation between every two dual graph vertexes are defined, the concept of aggregating features extracted from geodesically adjacent nodes is introduced, and a graph neural network with available and effective blocks is proposed. With these methods, MeshGraphNet performs well in 3D shape representation by avoiding the lack of inner-interactivity, shift-invariance, and the numerical uncertainty problems of mesh data. We conduct extensive 3D shape classification experiments and provide visualizations of the features extracted from the fully connected layers. The results demonstrate that our method performs better than state-of-the-art methods and improves the recognition accuracy by 4–4.5%. 81 | 82 | If you find this code useful, please consider citing our paper 83 | 84 | ``` bash 85 | @article{Song2020MeshGraphNetAE, 86 | title={MeshGraphNet: An Effective 3D Polygon Mesh Recognition With Topology Reconstruction}, 87 | author={An Ping Song and Xin Yi Di and X. Xu and Zi Heng Song}, 88 | journal={IEEE Access}, 89 | year={2020}, 90 | volume={8}, 91 | pages={205181-205189} 92 | } 93 | ``` -------------------------------------------------------------------------------- /data/HumanSeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | 5 | 6 | import torch 7 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 8 | from data.load_obj import read_obj 9 | 10 | 11 | class HumanSeg(InMemoryDataset): 12 | def __init__(self, 13 | root, 14 | name='seg', 15 | train=True, 16 | transform=None, 17 | pre_transform=None, 18 | pre_filter=None): 19 | assert name in ['seg'] 20 | self.name = name 21 | self.train = train 22 | super(HumanSeg, self).__init__(root, transform, pre_transform, 23 | pre_filter) 24 | path = self.processed_paths[0] if train else self.processed_paths[1] 25 | self.data, self.slices = torch.load(path) 26 | 27 | @property 28 | def raw_file_names(self): 29 | return [ 30 | 'seg', 'sseg', 'test', 'train' 31 | ] 32 | 33 | @property 34 | def processed_file_names(self): 35 | return ['training.pt', 'test.pt'] 36 | 37 | def download(self): 38 | path = download_url( 39 | 'https://www.dropbox.com/s/s3n05sw0zg27fz3/human_seg.tar.gz', self.root) 40 | extract_zip(path, self.root) 41 | os.unlink(path) 42 | folder = osp.join(self.root, 'human_seg') 43 | os.rename(folder, self.raw_dir) 44 | 45 | def process(self): 46 | torch.save(self.process_set('train'), self.processed_paths[0]) 47 | torch.save(self.process_set('test'), self.processed_paths[1]) 48 | 49 | def process_set(self, dataset): 50 | categories = glob.glob(osp.join(self.raw_dir, '*', '')) 51 | categories = sorted([x.split(os.sep)[-2] for x in categories]) 52 | data_list = [] 53 | seg_folder = osp.join(self.raw_dir, 'seg') 54 | sseg_folder = osp.join(self.raw_dir, 'sseg') 55 | data_folder = osp.join(self.raw_dir, dataset) 56 | paths = glob.glob('{}/*.obj'.format(data_folder)) 57 | for path in paths: 58 | data = read_obj(path) 59 | print(data) 60 | 61 | if self.pre_filter is not None: 62 | data_list = [d for d in data_list if self.pre_filter(d)] 63 | 64 | if self.pre_transform is not None: 65 | data_list = [self.pre_transform(d) for d in data_list] 66 | 67 | return self.collate(data_list) 68 | 69 | def __repr__(self): 70 | return '{}{}({})'.format(self.__class__.__name__, self.name, len(self)) 71 | -------------------------------------------------------------------------------- /data/ModelNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | 5 | 6 | import torch 7 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 8 | from data.load_npz import read_npz 9 | 10 | 11 | class ModelNet(InMemoryDataset): 12 | r"""The ModelNet10/40 datasets from the `"3D ShapeNets: A Deep 13 | Representation for Volumetric Shapes" 14 | `_ paper, 15 | containing CAD models of 10 and 40 categories, respectively. 16 | 17 | .. note:: 18 | 19 | Data objects hold mesh faces instead of edge indices. 20 | To convert the mesh to a graph, use the 21 | :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. 22 | To convert the mesh to a point cloud, use the 23 | :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to 24 | sample a fixed number of points on the mesh faces according to their 25 | face area. 26 | 27 | Args: 28 | root (string): Root directory where the dataset should be saved. 29 | name (string, optional): The name of the dataset (:obj:`"10"` for 30 | ModelNet10, :obj:`"40"` for ModelNet40). (default: :obj:`"10"`) 31 | train (bool, optional): If :obj:`True`, loads the training dataset, 32 | otherwise the test dataset. (default: :obj:`True`) 33 | transform (callable, optional): A function/transform that takes in an 34 | :obj:`torch_geometric.data.Data` object and returns a transformed 35 | version. The data object will be transformed before every access. 36 | (default: :obj:`None`) 37 | pre_transform (callable, optional): A function/transform that takes in 38 | an :obj:`torch_geometric.data.Data` object and returns a 39 | transformed version. The data object will be transformed before 40 | being saved to disk. (default: :obj:`None`) 41 | pre_filter (callable, optional): A function that takes in an 42 | :obj:`torch_geometric.data.Data` object and returns a boolean 43 | value, indicating whether the data object should be included in the 44 | final dataset. (default: :obj:`None`) 45 | """ 46 | 47 | urls = { 48 | '10': 49 | 'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip', 50 | '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip', 51 | '40_graph': 'https://drive.google.com/uc?export=download&confirm=HB4c&id=1o9pyskkKMxuomI5BWuLjCG2nSv5iePZz' 52 | } 53 | 54 | def __init__(self, 55 | root, 56 | name='10', 57 | train=True, 58 | transform=None, 59 | pre_transform=None, 60 | pre_filter=None): 61 | assert name in ['10', '40', '40_graph'] 62 | self.name = name 63 | self.train = train 64 | super(ModelNet, self).__init__(root, transform, pre_transform, 65 | pre_filter) 66 | path = self.processed_paths[0] if train else self.processed_paths[1] 67 | self.data, self.slices = torch.load(path) 68 | 69 | @property 70 | def raw_file_names(self): 71 | return [ 72 | 'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 73 | 'night_stand', 'sofa', 'table', 'toilet' 74 | ] 75 | 76 | @property 77 | def processed_file_names(self): 78 | return ['training.pt', 'test.pt'] 79 | 80 | def download(self): 81 | path = download_url(self.urls[self.name], self.root) 82 | extract_zip(path, self.root) 83 | os.unlink(path) 84 | folder = osp.join(self.root, 'ModelNet{}'.format(self.name)) 85 | os.rename(folder, self.raw_dir) 86 | 87 | def process(self): 88 | torch.save(self.process_set('train'), self.processed_paths[0]) 89 | torch.save(self.process_set('test'), self.processed_paths[1]) 90 | 91 | def process_set(self, dataset): 92 | categories = glob.glob(osp.join(self.raw_dir, '*', '')) 93 | categories = sorted([x.split(os.sep)[-2] for x in categories]) 94 | 95 | data_list = [] 96 | for target, category in enumerate(categories): 97 | folder = osp.join(self.raw_dir, category, dataset) 98 | paths = glob.glob('{}/{}_*.npz'.format(folder, category)) 99 | for path in paths: 100 | data = read_npz(path, self.train) 101 | data.y = torch.tensor([target]) 102 | data_list.append(data) 103 | 104 | if self.pre_filter is not None: 105 | data_list = [d for d in data_list if self.pre_filter(d)] 106 | 107 | if self.pre_transform is not None: 108 | data_list = [self.pre_transform(d) for d in data_list] 109 | 110 | return self.collate(data_list) 111 | 112 | def __repr__(self): 113 | return '{}{}({})'.format(self.__class__.__name__, self.name, len(self)) 114 | -------------------------------------------------------------------------------- /data/load_npz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch_geometric.data import Data 4 | 5 | 6 | def read_npz(path, train): 7 | with np.load(path) as f: 8 | return parse_npz(f, train) 9 | 10 | 11 | def parse_npz(f, train): 12 | face = f['face'] 13 | neighbor_index = f['neighbor_index'] 14 | 15 | # data augmentation 16 | if train: 17 | sigma, clip = 0.01, 0.05 18 | jittered_data = np.clip( 19 | sigma * np.random.randn(*face[:, :12].shape), -1 * clip, clip) 20 | face = np.concatenate((face[:, :12] + jittered_data, face[:, 12:]), 1) 21 | 22 | # fill for n < max_faces with randomly picked faces 23 | num_point = len(face) 24 | if num_point < 1024: 25 | fill_face = [] 26 | fill_neighbor_index = [] 27 | for i in range(1024 - num_point): 28 | index = np.random.randint(0, num_point) 29 | fill_face.append(face[index]) 30 | fill_neighbor_index.append(neighbor_index[index]) 31 | face = np.concatenate((face, np.array(fill_face))) 32 | neighbor_index = np.concatenate( 33 | (neighbor_index, np.array(fill_neighbor_index))) 34 | 35 | # to tensor 36 | face = torch.from_numpy(face).float() 37 | neighbor_index = torch.from_numpy(neighbor_index).long() 38 | index = torch.arange(face.size(0)).unsqueeze(dim=1).repeat(1, 3) 39 | gather_index = torch.tensor([0, 3, 1, 4, 2, 5]).repeat(face.size(0), 1) 40 | edge_index = torch.cat([neighbor_index, index], dim=1).gather( 41 | 1, gather_index).view(-1, 2).permute(1, 0) 42 | 43 | 44 | # reorganize 45 | face = face.permute(1, 0) 46 | centers, corners, normals = face[:3], face[3:12], face[12:] 47 | 48 | 49 | # # get the sod of each faces 50 | # ''' 51 | # w(e) = cos(ni/||ni|| * nj/||nj||)^-1 52 | # ''' 53 | # start_point, end_point = edge_index[0, :], edge_index[1, :] 54 | 55 | # print(start_point.size()) 56 | # print(normals.size()) 57 | 58 | 59 | 60 | 61 | corners = corners - torch.cat([centers, centers, centers], 0) 62 | 63 | features = torch.cat([centers, corners, normals], dim=0).permute(1, 0) 64 | 65 | data = Data(x=features, edge_index=edge_index, pos=neighbor_index) 66 | return data 67 | -------------------------------------------------------------------------------- /data/load_obj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | 4 | 5 | def yield_file(in_file): 6 | f = open(in_file) 7 | buf = f.read() 8 | f.close() 9 | for b in buf.split('\n'): 10 | if b.startswith('v '): 11 | yield ['v', [float(x) for x in b.split(" ")[1:]]] 12 | elif b.startswith('f '): 13 | triangles = b.split(' ')[1:] 14 | # -1 as .obj is base 1 but the Data class expects base 0 indices 15 | yield ['f', [int(t.split("/")[0]) - 1 for t in triangles]] 16 | else: 17 | yield ['', ""] 18 | 19 | 20 | def read_obj(in_file): 21 | vertices = [] 22 | faces = [] 23 | 24 | for k, v in yield_file(in_file): 25 | if k == 'v': 26 | vertices.append(v) 27 | elif k == 'f': 28 | faces.append(v) 29 | 30 | if not len(faces) or not len(vertices): 31 | return None 32 | 33 | pos = torch.tensor(vertices, dtype=torch.float) 34 | face = torch.tensor(faces, dtype=torch.long).t().contiguous() 35 | 36 | data = Data(pos=pos, face=face) 37 | 38 | return data -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDA="10.2" 2 | ARG CUDNN="7" 3 | 4 | FROM nvidia/cuda:${CUDA}-cudnn${CUDNN}-devel-ubuntu16.04 5 | # FROM pytorch/pytorch:latest 6 | 7 | RUN echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections 8 | 9 | # install basics 10 | RUN apt-get update -y \ 11 | && apt-get install -y apt-utils git curl ca-certificates bzip2 cmake tree htop bmon iotop g++ \ 12 | && apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev wget 13 | 14 | # Install Miniconda 15 | RUN cd / && wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 16 | && chmod +x /Miniconda3-latest-Linux-x86_64.sh \ 17 | && /Miniconda3-latest-Linux-x86_64.sh -b -p /miniconda \ 18 | && rm /Miniconda3-latest-Linux-x86_64.sh 19 | 20 | ENV PATH=/miniconda/bin:$PATH 21 | 22 | # Create a Python 3.6 environment 23 | RUN /miniconda/bin/conda install -y conda-build \ 24 | && /miniconda/bin/conda create -y --name py37 python=3.7 \ 25 | && /miniconda/bin/conda clean -ya 26 | 27 | ENV CONDA_DEFAULT_ENV=py37 28 | ENV CONDA_PREFIX=/miniconda/envs/$CONDA_DEFAULT_ENV 29 | ENV PATH=$CONDA_PREFIX/bin:$PATH 30 | ENV CONDA_AUTO_UPDATE_CONDA=false 31 | 32 | RUN conda install -y ipython 33 | RUN pip install --default-timeout=1000 numpy ninja yacs cython matplotlib opencv-python tqdm pyyaml tensorboardX hiddenlayer -i http://pypi.douban.com/simple --trusted-host pypi.douban.com 34 | 35 | # install pytorch 36 | # RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 37 | # RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 38 | # RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 39 | # RUN conda config --set show_channel_urls yes && conda install pytorch torchvision cudatoolkit=${CUDA} -c pytorch \ 40 | # && conda clean -ya 41 | # RUN conda install pytorch torchvision cudatoolkit=${CUDA} -c pytorch \ 42 | # && conda clean -ya 43 | 44 | # set cuda path 45 | ENV PATH=/usr/local/cuda/bin:$PATH 46 | ENV CPATH=/usr/local/cuda/include:$CPATH 47 | ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 48 | ENV DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH 49 | 50 | RUN pip install torch-scatter==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com 51 | RUN pip install torch-sparse==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com 52 | RUN pip install torch-cluster==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com 53 | RUN pip install torch-spline-conv==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com 54 | RUN pip install torch-geometric -i http://pypi.douban.com/simple --trusted-host pypi.douban.com 55 | RUN pip install torchvision==0.6.0 -i http://pypi.douban.com/simple --trusted-host pypi.douban.com 56 | 57 | WORKDIR /meshgraph 58 | -------------------------------------------------------------------------------- /images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/1.png -------------------------------------------------------------------------------- /images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/2.png -------------------------------------------------------------------------------- /images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/3.png -------------------------------------------------------------------------------- /images/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/4.png -------------------------------------------------------------------------------- /images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/5.png -------------------------------------------------------------------------------- /images/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/6.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | def create_model(opt): 2 | from .mesh_graph import mesh_graph 3 | model = mesh_graph(opt) 4 | return model 5 | -------------------------------------------------------------------------------- /models/layers/cpp_extension/meshgraph.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | at::Tensor get_connect_matrix(const at::Tensor &faces, const int64_t &num_nodes) 4 | { 5 | /* 6 | get connect matrix 7 | */ 8 | std::vector> node_list(num_nodes, std::vector(0)); 9 | #pragma omp parallel for 10 | for (int64_t i = 0; i < faces.size(0); i++) 11 | { 12 | int64_t *col = faces[i].data(); 13 | #pragma omp parallel for 14 | for (int j = 0; j < 3; j++) 15 | { 16 | node_list[col[j]].push_back(i); 17 | } 18 | } 19 | at::Tensor result = at::zeros(2, faces.options()); 20 | #pragma omp parallel for 21 | for (int64_t it = 0; it < num_nodes; it++) 22 | { 23 | #pragma omp parallel for 24 | for (int64_t i = 0; i < node_list[it].size() - 1; i++) 25 | { 26 | #pragma omp parallel for 27 | for (int64_t j = i + 1; j < node_list[it].size(); j++) 28 | { 29 | std::cout << i << " " << j << std::endl; 30 | auto temp = at::empty(2, faces.options()); 31 | temp[0] = i; 32 | temp[1] = j; 33 | result = at::cat({result, temp}, 0); 34 | } 35 | } 36 | } 37 | 38 | return result.view({-1, 2}); 39 | } 40 | 41 | std::string test() 42 | { 43 | std::cout << "hello world" << std::endl; 44 | return "hello_world"; 45 | } 46 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 47 | { 48 | m.def("get_connect_matrix", &get_connect_matrix, "get_connect_matrix"); 49 | m.def("test", &test, "test function"); 50 | } 51 | -------------------------------------------------------------------------------- /models/layers/cpp_extension/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from torch.utils import cpp_extension 3 | 4 | 5 | setup(name='meshgraph_cpp', 6 | ext_modules=[cpp_extension.CppExtension( 7 | 'meshgraph_cpp', ['meshgraph.cpp'])], 8 | cmdclass={'build_ext': cpp_extension.BuildExtension}) 9 | -------------------------------------------------------------------------------- /models/layers/mesh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | 5 | 6 | class Mesh: 7 | ''' 8 | mesh graph obj contains pos edge_index and x 9 | x reprensent feature in face obj 10 | edge_index is the sparse matrix of adj matrix 11 | pos represent position matrix 12 | ''' 13 | 14 | def __init__(self, vertexes, faces): 15 | self.vertexes = vertexes 16 | self.faces = faces.t().long() 17 | self.edge_index = self.nodes = None 18 | self.num_nodes = self.vertexes.size(0) 19 | # find point and face indices 20 | self.sorted_point_to_face = None 21 | self.sorted_point_to_face_index_dict = None 22 | # normalize vertexes 23 | self.normlize_vertices() 24 | 25 | # create graph 26 | self.create_graph() 27 | 28 | def create_graph(self): 29 | # conat center ox oy oz norm 30 | pos = self.vertexes[self.faces] 31 | point_x, point_y, point_z = pos[:, 0, :], pos[:, 1, :], pos[:, 2, :] 32 | centers = get_inner_center_vec( 33 | point_x, point_y, point_z 34 | ) 35 | temp = centers.view(-1) 36 | if torch.sum(torch.isnan(temp), dim=0) > 0: 37 | raise('center ------------------------------- nan') 38 | ox, oy, oz = get_three_vec(centers, point_x, point_y, point_z) 39 | norm = get_unit_norm_vec(point_x, point_y, point_z) 40 | temp = norm.view(-1) 41 | if torch.sum(torch.isnan(temp), dim=0) > 0: 42 | raise('norm ------------------------------- nan') 43 | # cat the vecter 44 | self.nodes = torch.cat((centers, ox, oy, oz, norm), dim=1) 45 | 46 | is_nan = self.nodes.view(-1) 47 | if torch.sum(torch.isnan(is_nan), dim=0) > 0: 48 | raise('contain nan ') 49 | self.get_connect_matrix() 50 | 51 | def normlize_vertices(self): 52 | ''' move vertices to center 53 | ''' 54 | center = (torch.max(self.vertexes, dim=0)[0] + 55 | torch.min(self.vertexes, dim=0)[0])/2 56 | self.vertexes -= center 57 | max_len = torch.max(self.vertexes[:, 0]**2 + 58 | self.vertexes[:, 1]**2 + self.vertexes[:, 2]**2).item() 59 | self.vertexes /= math.sqrt(max_len) 60 | 61 | def get_connect_matrix(self): 62 | node_list = [[] for _ in range(self.num_nodes)] 63 | result = [] 64 | for i, v in enumerate(self.faces): 65 | v0, v1, v2 = v 66 | node_list[v0].append(i) 67 | node_list[v1].append(i) 68 | node_list[v2].append(i) 69 | for i in node_list: 70 | for p in range(0, len(i)-1): 71 | for q in range(p+1, len(i)): 72 | result.append([p, q]) 73 | self_loop = torch.arange(self.nodes.size(0)).repeat(2, 1).t().long() 74 | result = torch.tensor(result).long() 75 | self.edge_index = torch.cat([self_loop, result], dim=0) 76 | 77 | 78 | def get_inner_center_vec(v1, v2, v3): 79 | ''' 80 | v1 v2 v3 represent 3 vertexes of triangle 81 | v1 (n,3) 82 | ''' 83 | return (v1+v2+v3)/3 84 | 85 | 86 | def get_distance_vec(v1, v2): 87 | ''' 88 | get distance between 89 | vecter_1 and vecter_2 90 | v1 : (x,y,z) 91 | ''' 92 | return torch.sqrt(torch.sum((v1-v2)**2, dim=1)) 93 | 94 | 95 | def get_unit_norm_vec(v1, v2, v3): 96 | ''' 97 | xy X xz 98 | (y1z2-y2z1,x2z1-x1z2,x1y2-x2y1) 99 | ''' 100 | xy = v2-v1 101 | xz = v3-v1 102 | x1, y1, z1 = xy[:, 0], xy[:, 1], xy[:, 2] 103 | x2, y2, z2 = xz[:, 0], xz[:, 1], xz[:, 2] 104 | norm = torch.stack((y1*z2-y2*z1, x2*z1-x1*z2, x1*y2-x2*y1), dim=1) 105 | return norm 106 | 107 | 108 | def get_three_vec(center, v1, v2, v3): 109 | ''' 110 | return ox oy oz vector 111 | ''' 112 | return v1-center, v2-center, v3-center 113 | -------------------------------------------------------------------------------- /models/layers/mesh_cpp_extension.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Mesh: 6 | ''' 7 | mesh graph obj contains pos edge_index and x 8 | x reprensent feature in face obj 9 | edge_index is the sparse matrix of adj matrix 10 | pos represent position matrix 11 | ''' 12 | 13 | def __init__(self, vertexes, faces, meshgraph_cpp): 14 | self.vertexes = vertexes 15 | self.faces = faces.t() 16 | self.edge_index = self.nodes = None 17 | self.num_nodes = self.vertexes.size(0) 18 | self.meshgraph_cpp = meshgraph_cpp 19 | # find point and face indices 20 | self.sorted_point_to_face = None 21 | self.sorted_point_to_face_index_dict = None 22 | # create graph 23 | self.create_graph() 24 | 25 | def create_graph(self): 26 | # conat center ox oy oz norm 27 | pos = self.vertexes[self.faces] 28 | point_x, point_y, point_z = pos[:, 0, :], pos[:, 1, :], pos[:, 2, :] 29 | centers = get_inner_center_vec( 30 | point_x, point_y, point_z 31 | ) 32 | ox, oy, oz = get_three_vec(centers, point_x, point_y, point_z) 33 | norm = get_unit_norm_vec(point_x, point_y, point_z) 34 | # cat the vecter 35 | self.nodes = torch.cat((centers, ox, oy, oz, norm), dim=1) 36 | self.edge_index = self.meshgraph_cpp.get_connect_matrix( 37 | self.faces, self.num_nodes 38 | ) 39 | 40 | 41 | def get_inner_center_vec(v1, v2, v3): 42 | ''' 43 | v1 v2 v3 represent 3 vertexes of triangle 44 | v1 (n,3) 45 | ''' 46 | a = get_distance_vec(v2, v3) 47 | b = get_distance_vec(v3, v1) 48 | c = get_distance_vec(v1, v2) 49 | x = torch.stack((v1[:, 0], v2[:, 0], v3[:, 0]), dim=1) 50 | y = torch.stack((v1[:, 1], v2[:, 1], v3[:, 1]), dim=1) 51 | z = torch.stack((v1[:, 2], v2[:, 2], v3[:, 2]), dim=1) 52 | dis = torch.stack((a, b, c), dim=1) 53 | return torch.stack(( 54 | torch.sum((x * dis) / (a+b+c).repeat(3, 1).t(), dim=1), 55 | torch.sum((y * dis) / (a+b+c).repeat(3, 1).t(), dim=1), 56 | torch.sum((z * dis) / (a+b+c).repeat(3, 1).t(), dim=1), 57 | ), dim=1) 58 | 59 | 60 | def get_distance_vec(v1, v2): 61 | ''' 62 | get distance between 63 | vecter_1 and vecter_2 64 | v1 : (x,y,z) 65 | ''' 66 | return torch.sqrt(torch.sum((v1-v2)**2, dim=1)) 67 | 68 | 69 | def get_unit_norm_vec(v1, v2, v3): 70 | ''' 71 | xy X xz 72 | (y1z2-y2z1,x2z1-x1z2,x1y2-x2y1) 73 | ''' 74 | xy = v2-v1 75 | xz = v3-v1 76 | x1, y1, z1 = xy[:, 0], xy[:, 1], xy[:, 2] 77 | x2, y2, z2 = xz[:, 0], xz[:, 1], xz[:, 2] 78 | norm = torch.stack((y1*z2-y2*z1, x2*z1-x1*z2, x1*y2-x2*y1), dim=1) 79 | vec_len = torch.sqrt(torch.sum(norm, dim=1)) 80 | return norm / vec_len.repeat(3, 1).t() 81 | 82 | 83 | def get_three_vec(center, v1, v2, v3): 84 | ''' 85 | return ox oy oz vector 86 | ''' 87 | return v1-center, v2-center, v3-center 88 | -------------------------------------------------------------------------------- /models/layers/mesh_graph_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from torch.nn.parameter import Parameter 6 | from torch_scatter import scatter_add 7 | from torch_geometric.nn import GINConv, SAGEConv 8 | 9 | 10 | class FaceRotateConvolution(nn.Module): 11 | 12 | def __init__(self): 13 | super(FaceRotateConvolution, self).__init__() 14 | self.rotate_mlp = nn.Sequential( 15 | nn.Conv1d(6, 32, 1), 16 | nn.BatchNorm1d(32), 17 | nn.ReLU(), 18 | nn.Conv1d(32, 32, 1), 19 | nn.BatchNorm1d(32), 20 | nn.ReLU() 21 | ) 22 | self.fusion_mlp = nn.Sequential( 23 | nn.Conv1d(32, 64, 1), 24 | nn.BatchNorm1d(64), 25 | nn.ReLU(), 26 | nn.Conv1d(64, 64, 1), 27 | nn.BatchNorm1d(64), 28 | nn.ReLU() 29 | ) 30 | 31 | def forward(self, corners): 32 | fea = (self.rotate_mlp(corners[:, :6]) + 33 | self.rotate_mlp(corners[:, 3:9]) + 34 | self.rotate_mlp(torch.cat([corners[:, 6:], corners[:, :3]], 1))) / 3 35 | return self.fusion_mlp(fea) 36 | 37 | 38 | class FaceKernelCorrelation(nn.Module): 39 | 40 | def __init__(self, num_kernel=64, sigma=0.2): 41 | super(FaceKernelCorrelation, self).__init__() 42 | self.num_kernel = num_kernel 43 | self.sigma = sigma 44 | self.weight_alpha = Parameter(torch.rand(1, num_kernel, 4) * np.pi) 45 | self.weight_beta = Parameter(torch.rand(1, num_kernel, 4) * 2 * np.pi) 46 | self.bn = nn.BatchNorm1d(num_kernel) 47 | self.relu = nn.ReLU() 48 | 49 | def forward(self, normals, neighbor_index): 50 | 51 | b, _, n = normals.size() 52 | 53 | center = normals.unsqueeze( 54 | 2).expand(-1, -1, self.num_kernel, -1).unsqueeze(4) 55 | neighbor = torch.gather(normals.unsqueeze(3).expand(-1, -1, -1, 3), 2, 56 | neighbor_index.unsqueeze(1).expand(-1, 3, -1, -1)) 57 | neighbor = neighbor.unsqueeze( 58 | 2).expand(-1, -1, self.num_kernel, -1, -1) 59 | 60 | fea = torch.cat([center, neighbor], 4) 61 | fea = fea.unsqueeze(5).expand(-1, -1, -1, -1, -1, 4) 62 | weight = torch.cat([torch.sin(self.weight_alpha) * torch.cos(self.weight_beta), 63 | torch.sin(self.weight_alpha) * 64 | torch.sin(self.weight_beta), 65 | torch.cos(self.weight_alpha)], 0) 66 | weight = weight.unsqueeze(0).expand(b, -1, -1, -1) 67 | weight = weight.unsqueeze(3).expand(-1, -1, -1, n, -1) 68 | weight = weight.unsqueeze(4).expand(-1, -1, -1, -1, 4, -1) 69 | 70 | dist = torch.sum((fea - weight)**2, 1) 71 | fea = torch.sum( 72 | torch.sum(np.e**(dist / (-2 * self.sigma**2)), 4), 3) / 16 73 | 74 | return self.relu(self.bn(fea)) 75 | 76 | 77 | class SpatialDescriptor(nn.Module): 78 | 79 | def __init__(self): 80 | super(SpatialDescriptor, self).__init__() 81 | 82 | self.spatial_mlp = nn.Sequential( 83 | nn.Conv1d(3, 64, 1), 84 | nn.BatchNorm1d(64), 85 | nn.ReLU(), 86 | nn.Conv1d(64, 64, 1), 87 | nn.BatchNorm1d(64), 88 | nn.ReLU(), 89 | ) 90 | 91 | def forward(self, centers): 92 | return self.spatial_mlp(centers) 93 | 94 | 95 | class StructuralDescriptor(nn.Module): 96 | 97 | def __init__(self): 98 | super(StructuralDescriptor, self).__init__() 99 | 100 | self.FRC = FaceRotateConvolution() 101 | self.FKC = FaceKernelCorrelation(64, 0.2) 102 | self.structural_mlp = nn.Sequential( 103 | nn.Conv1d(64 + 3, 131, 1), 104 | nn.BatchNorm1d(131), 105 | nn.ReLU(), 106 | nn.Conv1d(131, 131, 1), 107 | nn.BatchNorm1d(131), 108 | nn.ReLU(), 109 | ) 110 | 111 | def forward(self, corners, normals): 112 | structural_fea1 = self.FRC(corners) 113 | # structural_fea2 = self.FKC(normals, neighbor_index) 114 | 115 | return self.structural_mlp(torch.cat([structural_fea1, normals], 1)) 116 | 117 | 118 | class MeshConvolution(nn.Module): 119 | 120 | def __init__(self, spatial_in_channel, structural_in_channel, spatial_out_channel, structural_out_channel): 121 | super(MeshConvolution, self).__init__() 122 | 123 | self.spatial_in_channel = spatial_in_channel 124 | self.structural_in_channel = structural_in_channel 125 | self.spatial_out_channel = spatial_out_channel 126 | self.structural_out_channel = structural_out_channel 127 | 128 | self.aggregation_method = 'Concat' 129 | 130 | self.combination_mlp = nn.Sequential( 131 | nn.Conv1d(self.spatial_in_channel + 132 | self.structural_in_channel, self.spatial_out_channel, 1), 133 | nn.BatchNorm1d(self.spatial_out_channel), 134 | nn.ReLU(), 135 | ) 136 | 137 | # if self.aggregation_method == 'Concat': 138 | # self.concat_mlp = nn.Sequential( 139 | # nn.Conv2d(self.structural_in_channel * 2, 140 | # self.structural_in_channel, 1), 141 | # nn.BatchNorm2d(self.structural_in_channel), 142 | # nn.ReLU(), 143 | # ) 144 | 145 | # self.concat_mlp = nn.Sequential( 146 | # nn.Conv1d(self.structural_in_channel, 147 | # self.structural_out_channel, 1), 148 | # nn.BatchNorm1d(self.spatial_out_channel), 149 | # nn.ReLU(), 150 | # ) 151 | 152 | self.aggregation_mlp = nn.Sequential( 153 | nn.Conv1d(self.structural_in_channel, 154 | self.structural_out_channel, 1), 155 | nn.BatchNorm1d(self.structural_out_channel), 156 | nn.ReLU(), 157 | ) 158 | 159 | def forward(self, spatial_fea, structural_fea,_): 160 | b, _, n = spatial_fea.size() 161 | 162 | # Combination 163 | spatial_fea = self.combination_mlp( 164 | torch.cat([spatial_fea, structural_fea], 1)) 165 | 166 | # # Aggregation 167 | # if self.aggregation_method == 'Concat': 168 | # structural_fea = torch.cat([structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 169 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 170 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 171 | # -1, -1))], 1) 172 | # structural_fea = self.concat_mlp(structural_fea) 173 | # structural_fea = torch.max(structural_fea, 3)[0] 174 | 175 | # elif self.aggregation_method == 'Max': 176 | # structural_fea = torch.cat([structural_fea.unsqueeze(3), 177 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 178 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 179 | # -1, -1))], 3) 180 | # structural_fea = torch.max(structural_fea, 3)[0] 181 | 182 | # elif self.aggregation_method == 'Average': 183 | # structural_fea = torch.cat([structural_fea.unsqueeze(3), 184 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 185 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 186 | # -1, -1))], 3) 187 | # structural_fea = torch.sum(structural_fea, dim=3) / 4 188 | 189 | structural_fea = self.aggregation_mlp(structural_fea) 190 | 191 | return spatial_fea, structural_fea 192 | 193 | 194 | class NormGraphConv(nn.Module): 195 | def __init__(self): 196 | super(NormGraphConv, self).__init__() 197 | self.graph_conv = SAGEConv(3, 64) 198 | self.norm_mlp = nn.Sequential( 199 | nn.Conv1d(64, 64, 1), 200 | nn.BatchNorm1d(64), 201 | nn.ReLU() 202 | ) 203 | 204 | def forward(self, x, edge_index, opt): 205 | result = self.graph_conv(x, edge_index) 206 | # x = x.view(opt.batch_size, -1, 3) # n 1024 3 207 | # result = None 208 | # for i in range(opt.batch_size): 209 | # temp = self.graph_conv(x[i], edge_index) # 1024 3 -> 1024 64 210 | # if i == 0: 211 | # result = temp 212 | # else: 213 | # result = torch.cat([result, temp], dim=0) # 1024 * i 64 214 | result = result.view(opt.batch_size, -1, 215 | 64).transpose(1, 2) # 64 64 1024 216 | return self.norm_mlp(result) 217 | -------------------------------------------------------------------------------- /models/layers/mesh_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class FaceRotateConvolution(nn.Module): 7 | 8 | def __init__(self): 9 | super(FaceRotateConvolution, self).__init__() 10 | self.rotate_mlp = nn.Sequential( 11 | nn.Conv1d(6, 32, 1), 12 | nn.BatchNorm1d(32), 13 | nn.ReLU(), 14 | nn.Conv1d(32, 32, 1), 15 | nn.BatchNorm1d(32), 16 | nn.ReLU() 17 | ) 18 | self.fusion_mlp = nn.Sequential( 19 | nn.Conv1d(32, 64, 1), 20 | nn.BatchNorm1d(64), 21 | nn.ReLU(), 22 | nn.Conv1d(64, 64, 1), 23 | nn.BatchNorm1d(64), 24 | nn.ReLU() 25 | ) 26 | 27 | def forward(self, corners): 28 | fea = (self.rotate_mlp(corners[:, :6]) + 29 | self.rotate_mlp(corners[:, 3:9]) + 30 | self.rotate_mlp(torch.cat([corners[:, 6:], corners[:, :3]], 1))) / 3 31 | return self.fusion_mlp(fea) 32 | 33 | 34 | class SpatialDescriptor(nn.Module): 35 | 36 | def __init__(self): 37 | super(SpatialDescriptor, self).__init__() 38 | 39 | self.spatial_mlp = nn.Sequential( 40 | nn.Conv1d(3, 64, 1), 41 | nn.BatchNorm1d(64), 42 | nn.ReLU(), 43 | nn.Conv1d(64, 64, 1), 44 | nn.BatchNorm1d(64), 45 | nn.ReLU(), 46 | ) 47 | 48 | def forward(self, centers): 49 | return self.spatial_mlp(centers) 50 | 51 | 52 | class StructuralDescriptor(nn.Module): 53 | 54 | def __init__(self): 55 | super(StructuralDescriptor, self).__init__() 56 | 57 | self.FRC = FaceRotateConvolution() 58 | self.structural_mlp = nn.Sequential( 59 | nn.Conv1d(64+3+64, 131, 1), 60 | nn.BatchNorm1d(131), 61 | nn.ReLU(), 62 | nn.Conv1d(131, 131, 1), 63 | nn.BatchNorm1d(131), 64 | nn.ReLU(), 65 | ) 66 | 67 | def forward(self, corners, normals, extra_norm): 68 | structural_fea1 = self.FRC(corners) 69 | 70 | return self.structural_mlp(torch.cat([structural_fea1, normals, extra_norm], 1)) 71 | 72 | 73 | class MeshConvolution(nn.Module): 74 | 75 | def __init__(self, spatial_in_channel, structural_in_channel, spatial_out_channel, structural_out_channel): 76 | super(MeshConvolution, self).__init__() 77 | 78 | self.spatial_in_channel = spatial_in_channel 79 | self.structural_in_channel = structural_in_channel 80 | self.spatial_out_channel = spatial_out_channel 81 | self.structural_out_channel = structural_out_channel 82 | 83 | self.combination_mlp = nn.Sequential( 84 | nn.Conv1d(self.spatial_in_channel + 85 | self.structural_in_channel, self.spatial_out_channel, 1), 86 | nn.GroupNorm(32, self.spatial_out_channel), 87 | nn.ReLU(), 88 | ) 89 | 90 | self.aggregation_mlp = nn.Sequential( 91 | nn.Conv1d(self.structural_in_channel, 92 | self.structural_out_channel, 1), 93 | nn.GroupNorm(32, self.structural_out_channel), 94 | nn.ReLU(), 95 | ) 96 | 97 | def forward(self, spatial_fea, structural_fea, _): 98 | b, _, n = spatial_fea.size() 99 | 100 | # Combination 101 | spatial_fea = self.combination_mlp( 102 | torch.cat([spatial_fea, structural_fea], 1)) 103 | 104 | # Aggregation 105 | structural_fea = self.aggregation_mlp(structural_fea) 106 | 107 | return spatial_fea, structural_fea 108 | -------------------------------------------------------------------------------- /models/layers/mesh_net_with_out_neigbour.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class FaceRotateConvolution(nn.Module): 8 | 9 | def __init__(self): 10 | super(FaceRotateConvolution, self).__init__() 11 | self.rotate_mlp = nn.Sequential( 12 | nn.Conv1d(6, 32, 1), 13 | nn.BatchNorm1d(32), 14 | nn.ReLU(), 15 | nn.Conv1d(32, 32, 1), 16 | nn.BatchNorm1d(32), 17 | nn.ReLU() 18 | ) 19 | self.fusion_mlp = nn.Sequential( 20 | nn.Conv1d(32, 64, 1), 21 | nn.BatchNorm1d(64), 22 | nn.ReLU(), 23 | nn.Conv1d(64, 64, 1), 24 | nn.BatchNorm1d(64), 25 | nn.ReLU() 26 | ) 27 | 28 | def forward(self, corners): 29 | fea = (self.rotate_mlp(corners[:, :6]) + 30 | self.rotate_mlp(corners[:, 3:9]) + 31 | self.rotate_mlp(torch.cat([corners[:, 6:], corners[:, :3]], 1))) / 3 32 | return self.fusion_mlp(fea) 33 | 34 | 35 | class FaceKernelCorrelation(nn.Module): 36 | 37 | def __init__(self, num_kernel=64, sigma=0.2): 38 | super(FaceKernelCorrelation, self).__init__() 39 | self.num_kernel = num_kernel 40 | self.sigma = sigma 41 | self.weight_alpha = Parameter(torch.rand(1, num_kernel, 4) * np.pi) 42 | self.weight_beta = Parameter(torch.rand(1, num_kernel, 4) * 2 * np.pi) 43 | self.bn = nn.BatchNorm1d(num_kernel) 44 | self.relu = nn.ReLU() 45 | 46 | def forward(self, normals, neighbor_index): 47 | 48 | b, _, n = normals.size() 49 | 50 | center = normals.unsqueeze( 51 | 2).expand(-1, -1, self.num_kernel, -1).unsqueeze(4) 52 | neighbor = torch.gather(normals.unsqueeze(3).expand(-1, -1, -1, 3), 2, 53 | neighbor_index.unsqueeze(1).expand(-1, 3, -1, -1)) 54 | neighbor = neighbor.unsqueeze( 55 | 2).expand(-1, -1, self.num_kernel, -1, -1) 56 | 57 | fea = torch.cat([center, neighbor], 4) 58 | fea = fea.unsqueeze(5).expand(-1, -1, -1, -1, -1, 4) 59 | weight = torch.cat([torch.sin(self.weight_alpha) * torch.cos(self.weight_beta), 60 | torch.sin(self.weight_alpha) * 61 | torch.sin(self.weight_beta), 62 | torch.cos(self.weight_alpha)], 0) 63 | weight = weight.unsqueeze(0).expand(b, -1, -1, -1) 64 | weight = weight.unsqueeze(3).expand(-1, -1, -1, n, -1) 65 | weight = weight.unsqueeze(4).expand(-1, -1, -1, -1, 4, -1) 66 | 67 | dist = torch.sum((fea - weight)**2, 1) 68 | fea = torch.sum( 69 | torch.sum(np.e**(dist / (-2 * self.sigma**2)), 4), 3) / 16 70 | 71 | return self.relu(self.bn(fea)) 72 | 73 | 74 | class SpatialDescriptor(nn.Module): 75 | 76 | def __init__(self): 77 | super(SpatialDescriptor, self).__init__() 78 | 79 | self.spatial_mlp = nn.Sequential( 80 | nn.Conv1d(3, 64, 1), 81 | nn.BatchNorm1d(64), 82 | nn.ReLU(), 83 | nn.Conv1d(64, 64, 1), 84 | nn.BatchNorm1d(64), 85 | nn.ReLU(), 86 | ) 87 | 88 | def forward(self, centers): 89 | return self.spatial_mlp(centers) 90 | 91 | 92 | class StructuralDescriptor(nn.Module): 93 | 94 | def __init__(self): 95 | super(StructuralDescriptor, self).__init__() 96 | 97 | self.FRC = FaceRotateConvolution() 98 | self.FKC = FaceKernelCorrelation(64, 0.2) 99 | self.structural_mlp = nn.Sequential( 100 | nn.Conv1d(64 + 3, 131, 1), 101 | nn.BatchNorm1d(131), 102 | nn.ReLU(), 103 | nn.Conv1d(131, 131, 1), 104 | nn.BatchNorm1d(131), 105 | nn.ReLU(), 106 | ) 107 | 108 | def forward(self, corners, normals): 109 | structural_fea1 = self.FRC(corners) 110 | # structural_fea2 = self.FKC(normals, neighbor_index) 111 | 112 | return self.structural_mlp(torch.cat([structural_fea1, normals], 1)) 113 | 114 | 115 | class MeshConvolution(nn.Module): 116 | 117 | def __init__(self, spatial_in_channel, structural_in_channel, spatial_out_channel, structural_out_channel): 118 | super(MeshConvolution, self).__init__() 119 | 120 | self.spatial_in_channel = spatial_in_channel 121 | self.structural_in_channel = structural_in_channel 122 | self.spatial_out_channel = spatial_out_channel 123 | self.structural_out_channel = structural_out_channel 124 | 125 | self.aggregation_method = 'Concat' 126 | 127 | self.combination_mlp = nn.Sequential( 128 | nn.Conv1d(self.spatial_in_channel + 129 | self.structural_in_channel, self.spatial_out_channel, 1), 130 | nn.BatchNorm1d(self.spatial_out_channel), 131 | nn.ReLU(), 132 | ) 133 | 134 | # if self.aggregation_method == 'Concat': 135 | # self.concat_mlp = nn.Sequential( 136 | # nn.Conv2d(self.structural_in_channel * 2, 137 | # self.structural_in_channel, 1), 138 | # nn.BatchNorm2d(self.structural_in_channel), 139 | # nn.ReLU(), 140 | # ) 141 | 142 | # self.concat_mlp = nn.Sequential( 143 | # nn.Conv1d(self.structural_in_channel, 144 | # self.structural_out_channel, 1), 145 | # nn.BatchNorm1d(self.spatial_out_channel), 146 | # nn.ReLU(), 147 | # ) 148 | 149 | self.aggregation_mlp = nn.Sequential( 150 | nn.Conv1d(self.structural_in_channel + 64, 151 | self.structural_out_channel, 1), 152 | nn.BatchNorm1d(self.structural_out_channel), 153 | nn.ReLU(), 154 | ) 155 | 156 | def forward(self, spatial_fea, structural_fea, shape_norm): 157 | b, _, n = spatial_fea.size() 158 | 159 | # Combination 160 | spatial_fea = self.combination_mlp( 161 | torch.cat([spatial_fea, structural_fea], 1)) 162 | 163 | # # Aggregation 164 | # if self.aggregation_method == 'Concat': 165 | # structural_fea = torch.cat([structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 166 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 167 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 168 | # -1, -1))], 1) 169 | # structural_fea = self.concat_mlp(structural_fea) 170 | # structural_fea = torch.max(structural_fea, 3)[0] 171 | 172 | # elif self.aggregation_method == 'Max': 173 | # structural_fea = torch.cat([structural_fea.unsqueeze(3), 174 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 175 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 176 | # -1, -1))], 3) 177 | # structural_fea = torch.max(structural_fea, 3)[0] 178 | 179 | # elif self.aggregation_method == 'Average': 180 | # structural_fea = torch.cat([structural_fea.unsqueeze(3), 181 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2, 182 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel, 183 | # -1, -1))], 3) 184 | # structural_fea = torch.sum(structural_fea, dim=3) / 4 185 | 186 | structural_fea = self.aggregation_mlp( 187 | torch.cat([structural_fea, shape_norm], dim=1)) # n 64+3+64 188 | 189 | return spatial_fea, structural_fea 190 | -------------------------------------------------------------------------------- /models/layers/struct_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FaceVectorConv(nn.Module): 6 | def __init__(self, output_channel=64): 7 | super(FaceVectorConv, self).__init__() 8 | self.rotate_mlp = nn.Sequential( 9 | nn.Conv1d(6, 32, 1), 10 | nn.BatchNorm1d(32), 11 | nn.ReLU(), 12 | nn.Conv1d(32, 32, 1), 13 | nn.BatchNorm1d(32), 14 | nn.ReLU() 15 | ) 16 | self.fusion_mlp = nn.Sequential( 17 | nn.Conv1d(32, 64, 1), 18 | nn.BatchNorm1d(64), 19 | nn.ReLU(), 20 | nn.Conv1d(64, output_channel, 1), 21 | nn.BatchNorm1d(output_channel), 22 | nn.ReLU() 23 | ) 24 | 25 | def forward(self, x, opt): 26 | ''' 27 | x : batch_size *15 , 1024 28 | center ox oy oz norm 29 | 3 3 3 3 3 30 | ''' 31 | data = x.view(opt.batch_size, -1, 15).transpose(1, 2) 32 | xy = data[:, 3:9] 33 | yz = data[:, 6:12] 34 | xz = torch.cat([data[:, 3:6], data[:, 9:12]], dim=1) 35 | face_line = ( 36 | self.rotate_mlp(xy) + 37 | self.rotate_mlp(yz) + 38 | self.rotate_mlp(xz) 39 | ) / 3 # 64 , 64 , 1024 40 | return self.fusion_mlp(face_line) 41 | 42 | 43 | class PointConv(nn.Module): 44 | def __init__(self): 45 | super(PointConv, self).__init__() 46 | 47 | self.spatial_mlp = nn.Sequential( 48 | nn.Conv1d(3, 64, 1), 49 | nn.BatchNorm1d(64), 50 | nn.ReLU(), 51 | nn.Conv1d(64, 64, 1), 52 | nn.BatchNorm1d(64), 53 | nn.ReLU(), 54 | ) 55 | 56 | def forward(self, x, opt): 57 | ''' 58 | center ox oy oz norm 59 | ''' 60 | data = x.view(opt.batch_size, -1, 15).transpose(1, 2) 61 | x = data[:, :3] 62 | return self.spatial_mlp(x) 63 | 64 | 65 | class MeshMlp(nn.Module): 66 | def __init__(self, opt, output_channel=256): 67 | super(MeshMlp, self).__init__() 68 | self.opt = opt 69 | self.pc = PointConv() 70 | self.fvc = FaceVectorConv() 71 | self.mlp = nn.Sequential( 72 | nn.Conv1d(131, output_channel, 1), 73 | nn.BatchNorm1d(output_channel), 74 | nn.ReLU(), 75 | nn.Conv1d(output_channel, output_channel, 1), 76 | nn.BatchNorm1d(output_channel), 77 | nn.ReLU() 78 | ) 79 | 80 | def forward(self, x): 81 | # x: batch_size*15,1024 82 | data = x.view(self.opt.batch_size, -1, 15).transpose(1, 2) 83 | point_feature = self.pc(x, self.opt) # n 64 1024 84 | face_feature = self.fvc(x, self.opt) # n 64 1024 85 | norm = data[:, 12:] 86 | fusion_feature = torch.cat( 87 | [norm, point_feature, face_feature] # n 64+64+3 = 131 1024 88 | , dim=1 89 | ) 90 | return self.mlp(fusion_feature) # n 256 1024 91 | 92 | 93 | class NormMlp(nn.Module): 94 | def __init__(self, opt): 95 | super(NormMlp, self).__init__() 96 | self.opt = opt 97 | self.extra_mlp = nn.Sequential( 98 | nn.Conv1d(3, 64, 1), 99 | nn.BatchNorm1d(64), 100 | nn.ReLU(), 101 | nn.Conv1d(64, 64, 1), 102 | nn.BatchNorm1d(64), 103 | nn.ReLU() 104 | ) 105 | 106 | def forward(self, x): 107 | data = x.view(self.opt.batch_size, -1, 15).transpose(1, 2) 108 | norm = data[:, 12:] 109 | return self.extra_mlp(norm) # n 64 1024 110 | -------------------------------------------------------------------------------- /models/mesh_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | from . import networks 7 | import torch.optim as optim 8 | from models.optimizer import adabound 9 | from torch_geometric.utils import remove_self_loops, contains_self_loops, contains_isolated_nodes 10 | import hiddenlayer as hl 11 | import cv2 12 | 13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | OUTPUT_DIR = os.path.join(ROOT_DIR, "feature_img") 15 | print(OUTPUT_DIR) 16 | class mesh_graph: 17 | def __init__(self, opt): 18 | self.opt = opt 19 | self.cuda = opt.cuda 20 | self.is_train = opt.is_train 21 | self.device = torch.device('cuda:{}'.format( 22 | self.cuda[0]) if self.cuda else 'cpu') 23 | self.save_dir = osp.join(opt.ckpt_root, opt.name) 24 | self.optimizer = None 25 | self.loss = None 26 | 27 | # init mesh data 28 | self.nclasses = opt.nclasses 29 | 30 | # init network 31 | self.net = networks.get_net(opt) 32 | self.net.train(self.is_train) 33 | 34 | # criterion 35 | self.loss = networks.get_loss(self.opt).to(self.device) 36 | 37 | if self.is_train: 38 | # self.optimizer = adabound.AdaBound( 39 | # params=self.net.parameters(), lr=self.opt.lr, final_lr=self.opt.final_lr) 40 | self.optimizer = optim.SGD(self.net.parameters( 41 | ), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) 42 | self.scheduler = networks.get_scheduler(self.optimizer, self.opt) 43 | if not self.is_train or opt.continue_train: 44 | self.load_state(opt.last_epoch) 45 | 46 | # A History object to store metrics 47 | self.history = hl.History() 48 | # A Canvas object to draw the metrics 49 | self.canvas = hl.Canvas() 50 | 51 | def test(self): 52 | """tests model 53 | returns: number correct and total number 54 | """ 55 | with torch.no_grad(): 56 | out, fea = self.forward() 57 | # compute number of correct 58 | pred_class = torch.max(out, dim=1)[1] 59 | # print(pred_class) 60 | label_class = self.labels 61 | correct = self.get_accuracy(pred_class, label_class) 62 | return correct, len(label_class) 63 | 64 | def get_accuracy(self, pred, labels): 65 | """computes accuracy for classification / segmentation """ 66 | if self.opt.task == 'cls': 67 | correct = pred.eq(labels).sum() 68 | return correct 69 | 70 | def load_state(self, last_epoch): 71 | ''' load epoch ''' 72 | load_file = '%s_net.pth' % last_epoch 73 | load_path = osp.join(self.save_dir, load_file) 74 | net = self.net 75 | if isinstance(net, torch.nn.DataParallel): 76 | net = net.module 77 | print('loading the model from %s' % load_path) 78 | state_dict = torch.load(load_path, map_location=str(self.device)) 79 | if hasattr(state_dict, '_metadata'): 80 | del state_dict._metadata 81 | net.load_state_dict(state_dict) 82 | 83 | def set_input_data(self, data): 84 | '''set input data''' 85 | 86 | gt_label = data.y 87 | edge_index = data.edge_index 88 | if self.opt.batch_size == 1: 89 | nodes_features = data.x.unsqueeze(0) 90 | # neigbour_index = data.pos.unsqueeze(0) 91 | else: 92 | nodes_features = data.x.view(self.opt.batch_size, 1024, -1) 93 | neigbour_index = data.pos.view(self.opt.batch_size, -1, 3) 94 | 95 | self.labels = gt_label.to(self.device).long() 96 | self.edge_index = edge_index.to(self.device).long() 97 | 98 | self.neigbour_index = neigbour_index.to(self.device).long() 99 | self.centers = nodes_features[:, :, :3].transpose(1, 2) 100 | self.corners = nodes_features[:, :, 3:12].transpose(1, 2) 101 | self.normals = nodes_features[:, :, 12:].transpose(1, 2) 102 | self.x = data.x[:, -3:].to(self.device).float() 103 | 104 | def forward(self): 105 | out = self.net(self.x, self.edge_index, self.centers, 106 | self.corners, self.normals, self.neigbour_index) 107 | return out 108 | 109 | def backward(self, out, fea): 110 | self.loss_val = self.loss(out, self.labels) 111 | self.loss_val.backward() 112 | 113 | def optimize(self): 114 | ''' 115 | optimize paramater 116 | ''' 117 | self.optimizer.zero_grad() 118 | out, fea = self.forward() 119 | # print(self.net.module.concat_mlp[0].weight) 120 | self.backward(out, fea) 121 | nn.utils.clip_grad_norm_(self.net.parameters(), self.opt.grad_clip) 122 | self.optimizer.step() 123 | 124 | def save_network(self, which_epoch): 125 | ''' 126 | save network to disk 127 | ''' 128 | save_filename = '%s_net.pth' % (which_epoch) 129 | save_path = osp.join(self.save_dir, save_filename) 130 | if len(self.cuda) > 0 and torch.cuda.is_available(): 131 | torch.save(self.net.module.cpu().state_dict(), save_path) 132 | self.net.cuda(self.cuda[0]) 133 | else: 134 | torch.save(self.net.cpu().state_dict(), save_path) 135 | 136 | def log_history_and_plot(self, writer, epoch, batch): 137 | writer.history_log(epoch, batch, self.net.module.concat_mlp[0].weight) 138 | writer.draw_hist() 139 | 140 | def log_features_and_plot(self, epoch, batch): 141 | out, fea = self.forward() 142 | labels = self.labels 143 | fea_norm_numpy = fea.cpu().detach().numpy() 144 | cv_img = deprocess_image(fea_norm_numpy) 145 | im_color = cv2.applyColorMap(cv_img, cv2.COLORMAP_JET) 146 | cv2.imwrite(os.path.join(OUTPUT_DIR, "%d_epoch1.png" % epoch),im_color) 147 | # add canvas 148 | 149 | # self.history.log(epoch, image=im_color) 150 | # # self.canvas.draw_image(self.history["image"]) 151 | # self.canvas.save(os.path.join(OUTPUT_DIR, "%d_epoch.png" % epoch)) 152 | 153 | def deprocess_image(img): 154 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 155 | img = img - np.mean(img) 156 | img = img / (np.std(img) + 1e-5) 157 | img = img * 0.1 158 | img = img + 0.5 159 | img = np.clip(img, 0, 1) 160 | return np.uint8(img*255) -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.optim import lr_scheduler 5 | import torch.nn.functional as F 6 | from torch_geometric.nn import GraphConv 7 | from models.layers.struct_conv import MeshMlp, NormMlp 8 | from models.layers.mesh_net import SpatialDescriptor, StructuralDescriptor, MeshConvolution 9 | # ,SpatialDescriptor, StructuralDescriptor, MeshConvolution 10 | from models.layers.mesh_graph_conv import NormGraphConv, GINConv 11 | 12 | 13 | def init_weights(net, init_type, init_gain): 14 | def init_func(m): 15 | classname = m.__class__.__name__ 16 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 17 | if init_type == 'normal': 18 | init.normal_(m.weight.data, 0.0, init_gain) 19 | elif init_type == 'xavier': 20 | init.xavier_normal_(m.weight.data, gain=init_gain) 21 | elif init_type == 'kaiming': 22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 23 | elif init_type == 'orthogonal': 24 | init.orthogonal_(m.weight.data, gain=init_gain) 25 | else: 26 | raise NotImplementedError( 27 | 'initialization method [%s] is not implemented' % init_type) 28 | elif classname.find('BatchNorm2d') != -1: 29 | init.normal_(m.weight.data, 1.0, init_gain) 30 | init.constant_(m.bias.data, 0.0) 31 | net.apply(init_func) 32 | 33 | 34 | def init_net(net, init_type, init_gain, cuda_ids): 35 | if len(cuda_ids) > 0: 36 | assert(torch.cuda.is_available()) 37 | net.cuda(cuda_ids[0]) 38 | net = net.cuda() 39 | net = torch.nn.DataParallel(net, cuda_ids) 40 | if init_type != 'none': 41 | init_weights(net, init_type, init_gain) 42 | return net 43 | 44 | 45 | def get_net(opt): 46 | net = None 47 | if opt.arch == 'meshconv': 48 | net = MeshGraph(opt) 49 | else: 50 | raise NotImplementedError( 51 | 'model name [%s] is not implemented' % opt.arch) 52 | return init_net(net, opt.init_type, opt.init_gain, opt.cuda) 53 | 54 | 55 | def get_loss(opt): 56 | if opt.task == 'cls': 57 | loss = nn.CrossEntropyLoss() 58 | elif opt.task == 'seg': 59 | loss = nn.CrossEntropyLoss(ignore_index=-1) 60 | return loss 61 | 62 | 63 | def get_scheduler(optimizer, opt): 64 | if opt.lr_policy == 'step': 65 | scheduler = lr_scheduler.MultiStepLR( 66 | optimizer, gamma=opt.gamma, milestones=opt.milestones) 67 | else: 68 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 69 | return scheduler 70 | 71 | 72 | # class MeshGraph(nn.Module): 73 | # """Some Information about MeshGraph""" 74 | 75 | # def __init__(self, opt): 76 | # super(MeshGraph, self).__init__() 77 | # self.mesh_mlp_256 = MeshMlp(opt, 256) 78 | # self.gin_conv_256 = GINConv(self.mesh_mlp_256) 79 | 80 | # # self.graph_conv_64 = GraphConv(1024, 256) 81 | # # self.graph_conv_64 = GraphConv(256, 64) 82 | 83 | # self.classifier = nn.Sequential( 84 | # nn.Linear(256, 1024), 85 | # nn.ReLU(), 86 | # nn.Dropout(p=0.5), 87 | # nn.Linear(1024, 256), 88 | # nn.ReLU(), 89 | # nn.Dropout(p=0.5), 90 | # nn.Linear(256, 40), 91 | # ) 92 | 93 | # if opt.use_fpm: 94 | # self.mesh_mlp_64 = MeshMlp(64) 95 | # self.mesh_mlp_128 = MeshMlp(128) 96 | # self.gin_conv_64 = GINConv(self.mesh_mlp_64) 97 | # self.gin_conv_128 = GINConv(self.mesh_mlp_128) 98 | 99 | # def forward(self, nodes_features, edge_index): 100 | # x = nodes_features 101 | # edge_index = edge_index 102 | # x1 = self.gin_conv_256(x, edge_index) # 64 256 1024 103 | # # x1 = x1.view(x1.size(0), -1) 104 | # # x1 = torch.max(x1, dim=2)[0] 105 | # return self.classifier(x1) 106 | 107 | 108 | # class MeshGraph(nn.Module): 109 | # """Some Information about MeshGraph""" 110 | 111 | # def __init__(self, opt): 112 | # super(MeshGraph, self).__init__() 113 | # self.spatial_descriptor = SpatialDescriptor() 114 | # self.structural_descriptor = StructuralDescriptor() 115 | # self.mesh_conv1 = MeshConvolution(64, 131, 256, 256) 116 | # self.mesh_conv2 = MeshConvolution(256, 256, 512, 512) 117 | # self.fusion_mlp = nn.Sequential( 118 | # nn.Conv1d(1024, 1024, 1), 119 | # nn.BatchNorm1d(1024), 120 | # nn.ReLU(), 121 | # ) 122 | # self.concat_mlp = nn.Sequential( 123 | # nn.Conv1d(1792, 1024, 1), 124 | # nn.BatchNorm1d(1024), 125 | # nn.ReLU(), 126 | # ) 127 | # self.classifier = nn.Sequential( 128 | # nn.Linear(1024, 512), 129 | # nn.ReLU(), 130 | # nn.Dropout(p=0.5), 131 | # nn.Linear(512, 256), 132 | # nn.ReLU(), 133 | # nn.Dropout(p=0.5), 134 | # nn.Linear(256, 40) 135 | # ) 136 | 137 | # def forward(self, nodes_features, edge_index, centers, corners, normals, neigbour_index): 138 | # spatial_fea0 = self.spatial_descriptor(centers) 139 | # structural_fea0 = self.structural_descriptor( 140 | # corners, normals) 141 | 142 | # spatial_fea1, structural_fea1 = self.mesh_conv1( 143 | # spatial_fea0, structural_fea0) 144 | # spatial_fea2, structural_fea2 = self.mesh_conv2( 145 | # spatial_fea1, structural_fea1) 146 | 147 | # spatial_fea3 = self.fusion_mlp( 148 | # torch.cat([spatial_fea2, structural_fea2], 1)) 149 | 150 | # fea = self.concat_mlp( 151 | # torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1)) 152 | 153 | # fea = torch.max(fea, dim=2)[0] 154 | # fea = fea.reshape(fea.size(0), -1) 155 | # fea = self.classifier[:-1](fea) 156 | # cls = self.classifier[-1:](fea) 157 | # return cls 158 | 159 | 160 | class MeshGraph(nn.Module): 161 | """Some Information about MeshGraph""" 162 | 163 | def __init__(self, opt): 164 | super(MeshGraph, self).__init__() 165 | self.opt = opt 166 | self.spatial_descriptor = SpatialDescriptor() 167 | self.structural_descriptor = StructuralDescriptor() 168 | self.shape_descriptor = NormGraphConv() 169 | 170 | self.mesh_conv1 = MeshConvolution(64, 131, 256, 256) 171 | self.mesh_conv2 = MeshConvolution(256, 256, 512, 512) 172 | self.fusion_mlp = nn.Sequential( 173 | nn.Conv1d(1024, 1024, 1), 174 | nn.BatchNorm1d(1024), 175 | nn.ReLU(), 176 | ) 177 | self.concat_mlp = nn.Sequential( 178 | nn.Conv1d(1792, 1024, 1), 179 | nn.BatchNorm1d(1024), 180 | nn.ReLU(), 181 | ) 182 | self.classifier = nn.Sequential( 183 | nn.Linear(1024, 512), 184 | nn.ReLU(), 185 | nn.Dropout(p=0.5), 186 | nn.Linear(512, 256), 187 | nn.ReLU(), 188 | nn.Dropout(p=0.5), 189 | nn.Linear(256, 40) 190 | ) 191 | 192 | def forward(self, nodes_features, edge_index, centers, corners, normals, neigbour_index): 193 | shape_norm = self.shape_descriptor( 194 | nodes_features, edge_index, self.opt) # n 64 1024 195 | 196 | spatial_fea0 = self.spatial_descriptor(centers) 197 | structural_fea0 = self.structural_descriptor( 198 | corners, normals) 199 | 200 | spatial_fea1, structural_fea1 = self.mesh_conv1( 201 | spatial_fea0, structural_fea0, shape_norm) 202 | spatial_fea2, structural_fea2 = self.mesh_conv2( 203 | spatial_fea1, structural_fea1, shape_norm) 204 | 205 | spatial_fea3 = self.fusion_mlp( 206 | torch.cat([spatial_fea2, structural_fea2], 1)) 207 | 208 | fea = self.concat_mlp( 209 | torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1)) 210 | 211 | fea = torch.max(fea, dim=2)[0] 212 | fea = fea.reshape(fea.size(0), -1) 213 | fea = self.classifier[:-1](fea) 214 | cls = self.classifier[-1:](fea) 215 | return cls, fea 216 | -------------------------------------------------------------------------------- /models/optimizer/adabound.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @inproceedings{Luo2019AdaBound, 3 | author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu}, 4 | title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate}, 5 | booktitle = {Proceedings of the 7th International Conference on Learning Representations}, 6 | month = {May}, 7 | year = {2019}, 8 | address = {New Orleans, Louisiana} 9 | } 10 | ''' 11 | 12 | import math 13 | import torch 14 | from torch.optim import Optimizer 15 | 16 | 17 | class AdaBound(Optimizer): 18 | """Implements AdaBound algorithm. 19 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 20 | Arguments: 21 | params (iterable): iterable of parameters to optimize or dicts defining 22 | parameter groups 23 | lr (float, optional): Adam learning rate (default: 1e-3) 24 | betas (Tuple[float, float], optional): coefficients used for computing 25 | running averages of gradient and its square (default: (0.9, 0.999)) 26 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 27 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 28 | eps (float, optional): term added to the denominator to improve 29 | numerical stability (default: 1e-8) 30 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 31 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 32 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 33 | https://openreview.net/forum?id=Bkg3g2R9FX 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, 37 | eps=1e-8, weight_decay=0, amsbound=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | if not 0.0 <= final_lr: 47 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 48 | if not 0.0 <= gamma < 1.0: 49 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 50 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, 51 | weight_decay=weight_decay, amsbound=amsbound) 52 | super(AdaBound, self).__init__(params, defaults) 53 | 54 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) 55 | 56 | def __setstate__(self, state): 57 | super(AdaBound, self).__setstate__(state) 58 | for group in self.param_groups: 59 | group.setdefault('amsbound', False) 60 | 61 | def step(self, closure=None): 62 | """Performs a single optimization step. 63 | Arguments: 64 | closure (callable, optional): A closure that reevaluates the model 65 | and returns the loss. 66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group, base_lr in zip(self.param_groups, self.base_lrs): 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError( 78 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 79 | amsbound = group['amsbound'] 80 | 81 | state = self.state[p] 82 | 83 | # State initialization 84 | if len(state) == 0: 85 | state['step'] = 0 86 | # Exponential moving average of gradient values 87 | state['exp_avg'] = torch.zeros_like(p.data) 88 | # Exponential moving average of squared gradient values 89 | state['exp_avg_sq'] = torch.zeros_like(p.data) 90 | if amsbound: 91 | # Maintains max of all exp. moving avg. of sq. grad. values 92 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 93 | 94 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 95 | if amsbound: 96 | max_exp_avg_sq = state['max_exp_avg_sq'] 97 | beta1, beta2 = group['betas'] 98 | 99 | state['step'] += 1 100 | 101 | if group['weight_decay'] != 0: 102 | grad = grad.add(group['weight_decay'], p.data) 103 | 104 | # Decay the first and second moment running average coefficient 105 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 106 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 107 | if amsbound: 108 | # Maintains the maximum of all 2nd moment running avg. till now 109 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 110 | # Use the max. for normalizing running avg. of gradient 111 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 112 | else: 113 | denom = exp_avg_sq.sqrt().add_(group['eps']) 114 | 115 | bias_correction1 = 1 - beta1 ** state['step'] 116 | bias_correction2 = 1 - beta2 ** state['step'] 117 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 118 | 119 | # Applies bounds on actual learning rate 120 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 121 | final_lr = group['final_lr'] * group['lr'] / base_lr 122 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) 123 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) 124 | step_size = torch.full_like(denom, step_size) 125 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) 126 | 127 | p.data.add_(-step_size) 128 | 129 | return loss -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from util import util 5 | 6 | 7 | class base_options: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser( 10 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 11 | self.is_init = False 12 | 13 | def initialize(self): 14 | # dataset 15 | self.parser.add_argument( 16 | '--datasets', required=True, help='dataset to be train/test' 17 | ) 18 | self.parser.add_argument( 19 | '--task', choices={'cls', 'seg'}, default='cls', help='task for network to loader model' 20 | ) 21 | self.parser.add_argument( 22 | '--nclasses', type=int, default=40, help='num classes for classify' 23 | ) 24 | # general arg 25 | self.parser.add_argument( 26 | '--cuda', type=str, default='0', help='cuda device number e.g. 0 0,1,2, 0,2. use -1 for CPU' 27 | ) 28 | self.parser.add_argument( 29 | '--ckpt_root', type=str, default='./ckpt_root', help='model saved path' 30 | ) 31 | self.parser.add_argument( 32 | '--ckpt', type=str, default='./ckpt', help='final model saved path' 33 | ) 34 | self.parser.add_argument( 35 | '--name', type=str, default='debug', help='model saved path' 36 | ) 37 | # network parameter 38 | self.parser.add_argument( 39 | '--batch_size', type=int, default=64, help='mini batch size' 40 | ) 41 | self.parser.add_argument( 42 | '--arch', type=str, default='meshconv', help='model arch' 43 | ) 44 | self.parser.add_argument( 45 | '--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]' 46 | ) 47 | self.parser.add_argument( 48 | '--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.' 49 | ) 50 | self.parser.add_argument( 51 | '--use_fpm', action='store_true', help='use fpm model to catch feature' 52 | ) 53 | self.parser.add_argument( 54 | '--milestones', default='30,60', help='milestones for MultiStepLR' 55 | ) 56 | 57 | def parse(self): 58 | if not self.is_init: 59 | self.initialize() 60 | self.opt, _ = self.parser.parse_known_args() 61 | self.opt.is_train = self.is_train # train or test 62 | 63 | cuda_ids = self.opt.cuda.split(',') 64 | self.opt.cuda = [] 65 | self.opt.milestones = [int(x) for x in self.opt.milestones.split(',')] 66 | 67 | for cuda_id in cuda_ids: 68 | id = int(cuda_id) 69 | if id >= 0: 70 | self.opt.cuda.append(id) 71 | # set gpu id 72 | if len(self.opt.cuda) > 0: 73 | torch.cuda.set_device(self.opt.cuda[0]) 74 | 75 | args = vars(self.opt) 76 | 77 | if self.is_train: 78 | print('------------ Options -------------') 79 | for k, v in sorted(args.items()): 80 | print('%s: %s' % (str(k), str(v))) 81 | print('-------------- End ----------------') 82 | 83 | # check dir 84 | util.check_dir(os.path.join(self.opt.ckpt_root, self.opt.name)) 85 | util.check_dir(os.path.join(self.opt.ckpt, self.opt.name)) 86 | 87 | # save train options 88 | expr_dir = os.path.join(self.opt.ckpt_root, self.opt.name) 89 | file_name = os.path.join(expr_dir, 'opt.txt') 90 | 91 | with open(file_name, 'wt') as opt_file: 92 | opt_file.write('------------ Options -------------\n') 93 | for k, v in sorted(args.items()): 94 | opt_file.write('%s: %s\n' % (str(k), str(v))) 95 | opt_file.write('-------------- End ----------------\n') 96 | return self.opt 97 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import base_options 2 | 3 | 4 | class test_options(base_options): 5 | def initialize(self): 6 | base_options.initialize(self) 7 | self.parser.add_argument( 8 | '--last_epoch', default='latest', help='which epoch to load?' 9 | ) 10 | self.is_train = False 11 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import base_options 2 | 3 | 4 | class train_options(base_options): 5 | def initialize(self): 6 | base_options.initialize(self) 7 | self.parser.add_argument( 8 | '--continue_train', action='store_true', help='continue training: load the latest model' 9 | ) 10 | self.parser.add_argument( 11 | '--last_epoch', default='latest', help='which epoch to load?' 12 | ) 13 | # optimizer param 14 | self.parser.add_argument( 15 | '--lr', default=0.01, type=float, help='learning rate' 16 | ) 17 | self.parser.add_argument( 18 | '--final_lr', default=0.1, type=float, help='final learning rate' 19 | ) 20 | self.parser.add_argument( 21 | '--epoch', type=int, default=200, help='training epoch' 22 | ) 23 | self.parser.add_argument( 24 | '--frequency', type=int, default=10, help='training epoch' 25 | ) 26 | self.parser.add_argument( 27 | '--epoch_frequency', type=int, default=1, help='epoch to print' 28 | ) 29 | self.parser.add_argument( 30 | '--loop_frequency', type=int, default=100, help='iters epoch to print' 31 | ) 32 | self.parser.add_argument( 33 | '--test_frequency', type=int, default=1, help='test epoch' 34 | ) 35 | self.parser.add_argument( 36 | '--lr_policy', default='step', type=str, help='learning rate policy: step|' 37 | ) 38 | self.parser.add_argument( 39 | '--gamma', default=0.1, type=float, help='gamma for MultiStepLR' 40 | ) 41 | self.parser.add_argument( 42 | '--momentum', default=0.9, type=float 43 | ) 44 | self.parser.add_argument( 45 | '--weight_decay', default=0.0005, type=float 46 | ) 47 | self.parser.add_argument('--grad_clip', type=float, 48 | default=5, help='gradient clipping') 49 | 50 | # model 51 | self.is_train = True 52 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load 3 | from torch_geometric.utils import to_undirected, remove_self_loops 4 | # from models.layers.mesh_cpp_extension import Mesh 5 | from models.layers.mesh import Mesh 6 | import time 7 | 8 | 9 | def maybe_num_nodes(index, num_nodes=None): 10 | return index.max().item() + 1 if num_nodes is None else num_nodes 11 | 12 | 13 | def segregate_self_loops(edge_index, edge_attr=None): 14 | r"""Segregates self-loops from the graph. 15 | 16 | Args: 17 | edge_index (LongTensor): The edge indices. 18 | edge_attr (Tensor, optional): Edge weights or multi-dimensional 19 | edge features. (default: :obj:`None`) 20 | 21 | :rtype: (:class:`LongTensor`, :class:`Tensor`, :class:`LongTensor`, 22 | :class:`Tensor`) 23 | """ 24 | 25 | mask = edge_index[0] != edge_index[1] 26 | inv_mask = ~mask 27 | 28 | loop_edge_index = edge_index[:, inv_mask] 29 | loop_edge_attr = None if edge_attr is None else edge_attr[inv_mask] 30 | edge_index = edge_index[:, mask] 31 | edge_attr = None if edge_attr is None else edge_attr[mask] 32 | 33 | return edge_index, edge_attr, loop_edge_index, loop_edge_attr 34 | 35 | 36 | def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None): 37 | r"""Removes the isolated nodes from the graph given by :attr:`edge_index` 38 | with optional edge attributes :attr:`edge_attr`. 39 | In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter 40 | out isolated node features later on. 41 | Self-loops are preserved for non-isolated nodes. 42 | 43 | Args: 44 | edge_index (LongTensor): The edge indices. 45 | edge_attr (Tensor, optional): Edge weights or multi-dimensional 46 | edge features. (default: :obj:`None`) 47 | num_nodes (int, optional): The number of nodes, *i.e.* 48 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 49 | 50 | :rtype: (LongTensor, Tensor, ByteTensor) 51 | """ 52 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 53 | 54 | out = segregate_self_loops(edge_index, edge_attr) 55 | edge_index, edge_attr, loop_edge_index, loop_edge_attr = out 56 | 57 | mask = torch.zeros(num_nodes, dtype=torch.uint8, device=edge_index.device) 58 | mask[edge_index.view(-1)] = 1 59 | 60 | assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device) 61 | assoc[mask] = torch.arange(mask.sum(), device=assoc.device) 62 | edge_index = assoc[edge_index] 63 | 64 | loop_mask = torch.zeros_like(mask) 65 | loop_mask[loop_edge_index[0]] = 1 66 | loop_mask = loop_mask & mask 67 | loop_assoc = torch.full_like(assoc, -1) 68 | loop_assoc[loop_edge_index[0]] = torch.arange( 69 | loop_edge_index.size(1), device=loop_assoc.device) 70 | loop_idx = loop_assoc[loop_mask] 71 | loop_edge_index = assoc[loop_edge_index[:, loop_idx]] 72 | 73 | edge_index = torch.cat([edge_index, loop_edge_index], dim=1) 74 | 75 | if edge_attr is not None: 76 | loop_edge_attr = loop_edge_attr[loop_idx] 77 | edge_attr = torch.cat([edge_attr, loop_edge_attr], dim=0) 78 | 79 | return edge_index, edge_attr, mask 80 | 81 | class FaceToGraph(object): 82 | r"""Converts mesh faces :obj:`[3, num_faces]` to graph. 83 | 84 | Args: 85 | remove_faces (bool, optional): If set to :obj:`False`, the face tensor 86 | will not be removed. 87 | """ 88 | 89 | def __init__(self, remove_faces=True): 90 | # self.mesh_graph_cpp = load(name='meshgraph_cpp', sources=[ 91 | # 'models/layers/meshgraph.cpp']) 92 | self.remove_faces = remove_faces 93 | self.count = 0 94 | 95 | def __call__(self, data): 96 | start_time = time.time() 97 | data.num_nodes = data.x.size(0) 98 | edge_index = to_undirected(data.edge_index, data.num_nodes) 99 | # edge_index, _ = remove_self_loops(edge_index) 100 | 101 | print('%d-th mesh,size: %d' % (self.count, data.x.size(1))) 102 | 103 | data.edge_index = edge_index 104 | end_time = time.time() 105 | print('take {} s time for translate'.format(end_time-start_time)) 106 | if self.remove_faces: 107 | data.face = None 108 | self.count += 1 109 | return data 110 | 111 | def __repr__(self): 112 | return '{}()'.format(self.__class__.__name__) 113 | 114 | 115 | 116 | # class FaceToGraph(object): 117 | # r"""Converts mesh faces :obj:`[3, num_faces]` to graph. 118 | 119 | # Args: 120 | # remove_faces (bool, optional): If set to :obj:`False`, the face tensor 121 | # will not be removed. 122 | # """ 123 | 124 | # def __init__(self, remove_faces=True): 125 | # # self.mesh_graph_cpp = load(name='meshgraph_cpp', sources=[ 126 | # # 'models/layers/meshgraph.cpp']) 127 | # self.remove_faces = remove_faces 128 | # self.count = 0 129 | 130 | # def __call__(self, data): 131 | # start_time = time.time() 132 | # print('start transform') 133 | # mesh_grap = Mesh(data.pos, data.face) 134 | # # set the center ox oy oz unit_norm 135 | # data.x = mesh_grap.nodes 136 | # print(data.x) 137 | # data.num_nodes = data.x.size(0) 138 | # edge_index = to_undirected(mesh_grap.edge_index.t(), data.num_nodes) 139 | # # edge_index, _ = remove_self_loops(edge_index) 140 | 141 | # print('%d-th mesh,size: %d' % (self.count, data.x.size(0))) 142 | # # set edge_index to data 143 | # data.edge_index = edge_index 144 | # end_time = time.time() 145 | # print('take {} s time for translate'.format(end_time-start_time)) 146 | # mesh_grap = None 147 | # if self.remove_faces: 148 | # data.face = None 149 | # self.count += 1 150 | # return data 151 | 152 | # def __repr__(self): 153 | # return '{}()'.format(self.__class__.__name__) 154 | 155 | 156 | class FaceToEdge(object): 157 | r"""Converts mesh faces :obj:`[3, num_faces]` to edge indices 158 | :obj:`[2, num_edges]`. 159 | 160 | Args: 161 | remove_faces (bool, optional): If set to :obj:`False`, the face tensor 162 | will not be removed. 163 | """ 164 | 165 | def __init__(self, remove_faces=True): 166 | self.remove_faces = remove_faces 167 | self.count = 0 168 | 169 | def __call__(self, data): 170 | print(self.count) 171 | face = data.face 172 | 173 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1) 174 | edge_index = to_undirected(edge_index, num_nodes=data.num_nodes) 175 | 176 | data.edge_index = edge_index 177 | if self.remove_faces: 178 | data.face = None 179 | self.count += 1 180 | return data 181 | 182 | def __repr__(self): 183 | return '{}()'.format(self.__class__.__name__) 184 | -------------------------------------------------------------------------------- /script/human_seg/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATADIR='datasets' #location where data gets downloaded to 4 | 5 | # get data 6 | echo "downloading the data and putting it in: " $DATADIR 7 | mkdir -p $DATADIR && cd $DATADIR 8 | wget https://www.dropbox.com/s/s3n05sw0zg27fz3/human_seg.tar.gz 9 | tar -xzvf human_seg.tar.gz && rm human_seg.tar.gz -------------------------------------------------------------------------------- /script/human_seg/get_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | -------------------------------------------------------------------------------- /script/human_seg/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the test and export collapses 4 | -------------------------------------------------------------------------------- /script/human_seg/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the training 4 | -------------------------------------------------------------------------------- /script/human_seg/view.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python util/mesh_viewer.py \ 4 | --files \ 5 | ckpt/human_seg/meshes/shrec__14_0.obj -------------------------------------------------------------------------------- /script/modelnet10/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the training 4 | python train.py \ 5 | --datasets datasets/modelnet10 \ 6 | --name modelnet10 \ 7 | --batch_size 1 \ 8 | --nclasses 10 -------------------------------------------------------------------------------- /script/modelnet40_graph/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the test 4 | python test.py \ 5 | --datasets datasets/modelnet40_graph \ 6 | --name 40_graph \ 7 | --batch_size 64 \ 8 | --nclasses 40 \ 9 | --last_epoch final -------------------------------------------------------------------------------- /script/modelnet40_graph/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## run the training 4 | python train.py \ 5 | --datasets datasets/modelnet40_graph \ 6 | --name 40_graph \ 7 | --batch_size 64 \ 8 | --nclasses 40 \ 9 | --epoch 300 \ 10 | --init_type orthogonal -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from options.test_options import test_options 2 | from data.ModelNet import ModelNet 3 | from preprocess.preprocess import FaceToGraph 4 | from torch_geometric.data import DataLoader 5 | from models import create_model 6 | from util.writer import Writer 7 | 8 | 9 | def run_test(epoch=-1): 10 | print('Running Test') 11 | opt = test_options().parse() 12 | dataset = ModelNet(root=opt.datasets, name='40_graph', train=False, 13 | pre_transform=FaceToGraph(remove_faces=True)) 14 | loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False) 15 | model = create_model(opt) 16 | writer = Writer(opt) 17 | writer.reset_counter() 18 | for i, data in enumerate(loader): 19 | if data.y.size(0) % 64 != 0: 20 | continue 21 | model.set_input_data(data) 22 | ncorrect, nexamples = model.test() 23 | writer.update_counter(ncorrect, nexamples) 24 | writer.print_acc(epoch, writer.acc) 25 | return writer.acc 26 | 27 | 28 | if __name__ == '__main__': 29 | run_test() 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import train_options 3 | from data.ModelNet import ModelNet 4 | from preprocess.preprocess import FaceToGraph, FaceToEdge 5 | from torch_geometric.data import DataLoader 6 | from models import create_model 7 | from util.writer import Writer 8 | from test import run_test 9 | 10 | if __name__ == '__main__': 11 | opt = train_options().parse() 12 | 13 | # load dataset 14 | dataset = ModelNet(root=opt.datasets, name=str(opt.name), 15 | pre_transform=FaceToGraph(remove_faces=True)) 16 | print('# training meshes = %d' % len(dataset)) 17 | loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True) 18 | model = create_model(opt) 19 | writer = Writer(opt) 20 | total_steps = 0 21 | 22 | for epoch in range(1, opt.epoch): 23 | start_time = time.time() 24 | count = 0 25 | running_loss = 0.0 26 | for i, data in enumerate(loader): 27 | # break 28 | if data.y.size(0) % 64 != 0: 29 | continue 30 | total_steps += opt.batch_size 31 | count += opt.batch_size 32 | model.set_input_data(data) 33 | model.optimize() 34 | running_loss += model.loss_val 35 | if total_steps % opt.frequency == 0: 36 | loss_val = running_loss/opt.frequency 37 | writer.print_loss(epoch, count, loss_val) 38 | writer.plot_loss(epoch, count, loss_val, len(dataset)) 39 | running_loss = 0 40 | 41 | if i % opt.loop_frequency == 0: 42 | print('saving the latest model (epoch %d, total_steps %d)' % 43 | (epoch, total_steps)) 44 | model.save_network('latest') 45 | # break 46 | 47 | if epoch % opt.epoch_frequency == 0: 48 | print('saving the model at the end of epoch %d, iters %d' % 49 | (epoch, total_steps)) 50 | if (epoch-1) % 20 == 0: 51 | model.log_history_and_plot(writer, epoch, count) 52 | model.log_features_and_plot(epoch, count) 53 | model.save_network('latest') 54 | model.save_network(epoch) 55 | 56 | if epoch % opt.test_frequency == 0: 57 | acc = run_test(epoch) 58 | writer.plot_acc(acc, epoch) 59 | # break 60 | wait = input("input") 61 | writer.close() 62 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/util/__init__.py -------------------------------------------------------------------------------- /util/mesh_viewer.py: -------------------------------------------------------------------------------- 1 | import mpl_toolkits.mplot3d as a3 2 | import matplotlib.colors as colors 3 | import pylab as pl 4 | import numpy as np 5 | 6 | V = np.array 7 | 8 | 9 | def r2h(x): return colors.rgb2hex(tuple(map(lambda y: y / 255., x))) 10 | 11 | 12 | surface_color = r2h((255, 230, 205)) 13 | edge_color = r2h((90, 90, 90)) 14 | edge_colors = (r2h((15, 167, 175)), r2h((230, 81, 81)), r2h((142, 105, 252)), r2h((248, 235, 57)), 15 | r2h((51, 159, 255)), r2h((225, 117, 231)), r2h((97, 243, 185)), r2h((161, 183, 196))) 16 | 17 | 18 | def init_plot(): 19 | ax = pl.figure().add_subplot(111, projection='3d') 20 | # hide axis, thank to 21 | # https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/ 22 | ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 23 | ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 24 | ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 25 | # Get rid of the spines 26 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 27 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 28 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 29 | # Get rid of the ticks 30 | ax.set_xticks([]) 31 | ax.set_yticks([]) 32 | ax.set_zticks([]) 33 | return (ax, [np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf]) 34 | 35 | 36 | def update_lim(mesh, plot): 37 | vs = mesh[0] 38 | for i in range(3): 39 | plot[1][2 * i] = min(plot[1][2 * i], vs[:, i].min()) 40 | plot[1][2 * i + 1] = max(plot[1][2 * i], vs[:, i].max()) 41 | return plot 42 | 43 | 44 | def update_plot(mesh, plot): 45 | if plot is None: 46 | plot = init_plot() 47 | return update_lim(mesh, plot) 48 | 49 | 50 | def surfaces(mesh, plot): 51 | vs, faces, edges = mesh 52 | vtx = vs[faces] 53 | edgecolor = edge_color if not len(edges) else 'none' 54 | tri = a3.art3d.Poly3DCollection(vtx, facecolors=surface_color + '55', edgecolors=edgecolor, 55 | linewidths=.5, linestyles='dashdot') 56 | plot[0].add_collection3d(tri) 57 | return plot 58 | 59 | 60 | def segments(mesh, plot): 61 | vs, _, edges = mesh 62 | for edge_c, edge_group in enumerate(edges): 63 | for edge_idx in edge_group: 64 | edge = vs[edge_idx] 65 | line = a3.art3d.Line3DCollection( 66 | [edge], linewidths=.5, linestyles='dashdot') 67 | line.set_color(edge_colors[edge_c % len(edge_colors)]) 68 | plot[0].add_collection3d(line) 69 | return plot 70 | 71 | 72 | def plot_mesh(mesh, *whats, show=True, plot=None): 73 | for what in [update_plot] + list(whats): 74 | plot = what(mesh, plot) 75 | if show: 76 | li = max(plot[1][1], plot[1][3], plot[1][5]) 77 | plot[0].auto_scale_xyz([0, li], [0, li], [0, li]) 78 | pl.tight_layout() 79 | pl.show() 80 | return plot 81 | 82 | 83 | def parse_obje(obj_file, scale_by): 84 | vs = [] 85 | faces = [] 86 | edges = [] 87 | 88 | def add_to_edges(): 89 | if edge_c >= len(edges): 90 | for _ in range(len(edges), edge_c + 1): 91 | edges.append([]) 92 | edges[edge_c].append(edge_v) 93 | 94 | def fix_vertices(): 95 | nonlocal vs, scale_by 96 | vs = V(vs) 97 | z = vs[:, 2].copy() 98 | vs[:, 2] = vs[:, 1] 99 | vs[:, 1] = z 100 | max_range = 0 101 | for i in range(3): 102 | min_value = np.min(vs[:, i]) 103 | max_value = np.max(vs[:, i]) 104 | max_range = max(max_range, max_value - min_value) 105 | vs[:, i] -= min_value 106 | if not scale_by: 107 | scale_by = max_range 108 | vs /= scale_by 109 | 110 | with open(obj_file) as f: 111 | for line in f: 112 | line = line.strip() 113 | splitted_line = line.split() 114 | if not splitted_line: 115 | continue 116 | elif splitted_line[0] == 'v': 117 | vs.append([float(v) for v in splitted_line[1:]]) 118 | elif splitted_line[0] == 'f': 119 | faces.append([int(c) - 1 for c in splitted_line[1:]]) 120 | elif splitted_line[0] == 'e': 121 | if len(splitted_line) >= 4: 122 | edge_v = [int(c) - 1 for c in splitted_line[1:-1]] 123 | edge_c = int(splitted_line[-1]) 124 | add_to_edges() 125 | 126 | vs = V(vs) 127 | fix_vertices() 128 | faces = V(faces, dtype=int) 129 | edges = [V(c, dtype=int) for c in edges] 130 | return (vs, faces, edges), scale_by 131 | 132 | 133 | def view_meshes(*files, offset=.2): 134 | plot = None 135 | max_x = 0 136 | scale = 0 137 | for file in files: 138 | mesh, scale = parse_obje(file, scale) 139 | max_x_current = mesh[0][:, 0].max() 140 | mesh[0][:, 0] += max_x + offset 141 | plot = plot_mesh(mesh, surfaces, segments, 142 | plot=plot, show=file == files[-1]) 143 | max_x += max_x_current + offset 144 | 145 | 146 | if __name__ == '__main__': 147 | import argparse 148 | parser = argparse.ArgumentParser("view meshes") 149 | parser.add_argument('--files', nargs='+', default=['/meshgraph/ckpt_root/airplane_0124.obj', 150 | #'/meshgraph/ckpt_root/bed_0010.obj', 151 | #'/meshgraph/ckpt_root/cup_0007.obj', 152 | #'/meshgraph/ckpt_root/person_0004.obj', 153 | #'/meshgraph/ckpt_root/toilet_0006.obj', 154 | #'/meshgraph/ckpt_root/tv_stand_0004.obj', 155 | #'/meshgraph/ckpt_root/xbox_0005.obj' 156 | ], type=str, 157 | help="list of 1 or more .obj files") 158 | args = parser.parse_args() 159 | 160 | # view meshes 161 | view_meshes(*args.files) 162 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | 5 | def check_dir(dir): 6 | if not osp.exists(dir): 7 | os.makedirs(dir) 8 | 9 | 10 | -------------------------------------------------------------------------------- /util/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import numpy as np 5 | import torch 6 | import torchvision.models 7 | import torch.nn as nn 8 | from torchvision import datasets, transforms 9 | import hiddenlayer as hl 10 | 11 | try: 12 | from tensorboardX import SummaryWriter 13 | except ImportError as error: 14 | raise('tensorboardX is not available, please install it.') 15 | SummaryWriter = None 16 | 17 | 18 | class Writer: 19 | def __init__(self, opt): 20 | self.opt = opt 21 | self.name = opt.name 22 | self.save_path = os.path.join(opt.ckpt_root, opt.name) 23 | self.train_loss = os.path.join(self.save_path, 'train_loss.txt') 24 | self.test_loss = os.path.join(self.save_path, 'test_loss.txt') 25 | 26 | # set display 27 | if opt.is_train and SummaryWriter is not None: 28 | self.display = SummaryWriter() # comment=opt.name 29 | else: 30 | self.display = None 31 | 32 | self.start_logs() 33 | self.nexamples = 0 34 | self.ncorrect = 0 35 | 36 | # A History object to store metrics 37 | self.history = hl.History() 38 | 39 | # A Canvas object to draw the metrics 40 | self.canvas = hl.Canvas() 41 | 42 | def start_logs(self): 43 | ''' create log file''' 44 | if self.opt.is_train: 45 | with open(self.train_loss, 'a') as train_loss: 46 | now = time.strftime('%c') 47 | train_loss.write( 48 | '================ Training Loss (%s) ================\n' % now) 49 | else: 50 | with open(self.test_loss, 'a') as test_loss: 51 | now = time.strftime('%c') 52 | test_loss.write( 53 | '================ Test Loss (%s) ================\n' % now) 54 | 55 | def reset_counter(self): 56 | """ 57 | counts # of correct examples 58 | """ 59 | self.ncorrect = 0 60 | self.nexamples = 0 61 | 62 | def update_counter(self, ncorrect, nexamples): 63 | self.ncorrect += ncorrect 64 | self.nexamples += nexamples 65 | 66 | def print_loss(self, epoch, iters, loss): 67 | print('epoch : %d, iter : %d , loss : %.3f' % 68 | (epoch, iters, loss.item())) 69 | with open(self.train_loss, 'a') as train_loss: 70 | train_loss.write('epoch : %d, iter : %d , loss : %.3f\n' % 71 | (epoch, iters, loss.item())) 72 | 73 | def plot_loss(self, epoch, i, loss, n): 74 | train_data_iter = i + (epoch-1) * n 75 | if self.display: 76 | self.display.add_scalar( 77 | 'data/train_loss', loss.item(), train_data_iter) 78 | 79 | def plot_acc(self, acc, epoch): 80 | if self.display: 81 | self.display.add_scalar('data/test_acc', acc, epoch-1) 82 | 83 | def print_acc(self, epoch, acc): 84 | """ prints test accuracy to terminal / file """ 85 | message = 'epoch: {}, TEST ACC: [{:.5} %]\n' \ 86 | .format(epoch, acc * 100) 87 | print(message) 88 | with open(self.test_loss, "a") as log_file: 89 | log_file.write('%s\n' % message) 90 | 91 | def history_log(self, epoch, batch, weight): 92 | self.history.log((epoch, batch), global_feature_wight=weight) 93 | 94 | def draw_hist(self): 95 | self.canvas.draw_hist(self.history["global_feature_wight"]) 96 | 97 | @property 98 | def acc(self): 99 | return float(self.ncorrect) / self.nexamples 100 | 101 | def close(self): 102 | if self.display is not None: 103 | self.display.close() 104 | --------------------------------------------------------------------------------