├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── HumanSeg.py
│   ├── ModelNet.py
│   ├── load_npz.py
│   └── load_obj.py
├── docker
│   └── Dockerfile
├── images
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   └── 6.png
├── models
│   ├── __init__.py
│   ├── layers
│   │   ├── cpp_extension
│   │   │   ├── meshgraph.cpp
│   │   │   └── setup.py
│   │   ├── mesh.py
│   │   ├── mesh_cpp_extension.py
│   │   ├── mesh_graph_conv.py
│   │   ├── mesh_net.py
│   │   ├── mesh_net_with_out_neigbour.py
│   │   └── struct_conv.py
│   ├── mesh_graph.py
│   ├── networks.py
│   └── optimizer
│       └── adabound.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── preprocess
│   ├── __init__.py
│   └── preprocess.py
├── script
│   ├── human_seg
│   │   ├── get_data.sh
│   │   ├── get_pretrained.sh
│   │   ├── test.sh
│   │   ├── train.sh
│   │   └── view.sh
│   ├── modelnet10
│   │   └── train.sh
│   └── modelnet40_graph
│       ├── test.sh
│       └── train.sh
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── mesh_viewer.py
    ├── util.py
    └── writer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # project-specific files and outputs
107 | feature_img/
108 | datasets/
109 | cal_data.py
110 | runs/
111 | ckpt/
112 | ckpt_root/
113 | demo.py
114 | *.jpg
115 | *.png
116 | *.txt
117 | test_*
118 | draw_features.py
119 | !images/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Cery D
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MeshGraph in PyTorch
2 |
3 | Transforming mesh data into a mesh graph topology through the idea of the Finite Element Method. The paper is published in IEEE Access: [url](https://ieeexplore.ieee.org/document/9253518).
4 | ## Transform to Topology Graph
5 | 
6 | ## Network Structure
7 | 
8 | # Getting Started
9 |
10 | ### Installation
11 | - Clone this repo:
12 | ``` bash
13 | git clone https://github.com/JsBlueCat/MeshGraph.git
14 | cd MeshGraph
15 | ```
16 | - Install dependencies: [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) and [docker](https://docs.docker.com/get-started/)
17 |
18 | - First, build the Docker image:
19 |
20 | ```bash
21 | cd docker
22 | docker build -t your/docker:meshgraph .
23 | ```
24 |
26 | - Then run the Docker image:
26 | ```bash
27 | docker run --rm -it --runtime=nvidia --shm-size 16G -e DISPLAY=unix$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix -v /your/path/to/MeshGraph/:/meshgraph your/docker:meshgraph bash
28 | ```
29 |
30 |
31 | ### 3D Shape Classification on ModelNet40
32 | - If the automatic download from Python fails, get the dataset manually: [[ModelNet40]](https://drive.google.com/uc?export=download&confirm=HB4c&id=1o9pyskkKMxuomI5BWuLjCG2nSv5iePZz)
33 | - Put the zip file in `datasets/modelnet40_graph`, then run:
34 |
35 | ```bash
36 | cd /meshgraph
37 | sh script/modelnet40_graph/train.sh
38 | ```
39 |
40 |
41 | ### Classification Accuracy on ModelNet40
42 | | Method | Modality | Acc |
43 | | ----------------------- |:--------:|:--------:|
44 | | 3DShapeNets | volume | 77.3% |
45 | | VoxNet | volume | 83% |
46 | | O-CNN | volume | 89.9% |
47 | | MVCNN | view | 90.1% |
48 | | MLVCNN | view | 94.16% |
49 | | PointNet | point cloud | 89.2% |
50 | | MeshNet | mesh | 91% |
51 | | Ours with SAGE | mesh | 94.3% ± 0.5% |
52 | - Run the test:
53 | - If the automatic download from Python fails, get the dataset manually: [[ModelNet40]](https://drive.google.com/uc?export=download&confirm=HB4c&id=1o9pyskkKMxuomI5BWuLjCG2nSv5iePZz)
54 | - Put the zip file in `datasets/modelnet40_graph`.
55 | - Download the weight file from [[weights]](https://drive.google.com/file/d/11JOiaTOBCykCYgZKw24qcD6r1Tzz7dvu/view?usp=sharing), put it in `ckpt_root/40_graph/`, and run:
56 | ``` bash
57 | sh script/modelnet40_graph/test.sh
58 | ```
59 | - The output should look like:
60 | ``` bash
61 | root@2730e382330f:/meshvertex# sh script/modelnet40_graph/test.sh
62 | Running Test
63 | loading the model from ./ckpt_root/40_graph/final_net.pth
64 | epoch: -1, TEST ACC: [94.49 %]
65 | ```
66 | 
67 | 
68 | # Train on your Dataset
69 | ### Coming soon
70 |
71 | # Some 3D Reconstruction Results on 3D Faces
72 | 
73 |
74 | # Credit
75 |
76 | ### MeshGraphNet: An Effective 3D Polygon Mesh Recognition With Topology Reconstruction
77 | An Ping Song; Xin Yi Di; Xiao Kang Xu; Zi Heng Song
78 |
79 | **Abstract**
80 | Three-dimensional polygon mesh recognition has a significant impact on current computer graphics. However, its application to some real-life fields, such as unmanned driving and medical image processing, has been restricted due to the lack of inner-interactivity, shift-invariance, and numerical uncertainty of mesh surfaces. In this paper, an interconnected topological dual graph that extracts adjacent information from each triangular face of a polygon mesh is constructed, in order to address the above issues. On the basis of the algebraic topological graph, we propose a mesh graph neural network, called MeshGraphNet, to effectively extract features from mesh data. In this concept, the graph node-unit and correlation between every two dual graph vertexes are defined, the concept of aggregating features extracted from geodesically adjacent nodes is introduced, and a graph neural network with available and effective blocks is proposed. With these methods, MeshGraphNet performs well in 3D shape representation by avoiding the lack of inner-interactivity, shift-invariance, and the numerical uncertainty problems of mesh data. We conduct extensive 3D shape classification experiments and provide visualizations of the features extracted from the fully connected layers. The results demonstrate that our method performs better than state-of-the-art methods and improves the recognition accuracy by 4–4.5%.
81 |
82 | If you find this code useful, please consider citing our paper:
83 |
84 | ``` bibtex
85 | @article{Song2020MeshGraphNetAE,
86 | title={MeshGraphNet: An Effective 3D Polygon Mesh Recognition With Topology Reconstruction},
87 | author={An Ping Song and Xin Yi Di and X. Xu and Zi Heng Song},
88 | journal={IEEE Access},
89 | year={2020},
90 | volume={8},
91 | pages={205181-205189}
92 | }
93 | ```
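94 |
95 | ### Appendix: Dual-Graph Construction Sketch
96 |
97 | The paper turns each triangular face into a dual-graph node and connects two nodes when their faces share a vertex (see `get_connect_matrix` in `models/layers/mesh.py`). The sketch below is an editor's illustration of that idea, not the repository's exact implementation:
98 |
99 | ``` python
100 | import torch
101 |
102 | def face_dual_edges(faces, num_vertices):
103 |     """faces: (F, 3) long tensor of vertex ids -> (2, E) dual-graph edge_index."""
104 |     incident = [[] for _ in range(num_vertices)]  # faces touching each vertex
105 |     for face_id, corners in enumerate(faces.tolist()):
106 |         for v in corners:
107 |             incident[v].append(face_id)
108 |     edges = set()
109 |     for face_ids in incident:  # connect every pair of faces meeting at a vertex
110 |         for i in range(len(face_ids) - 1):
111 |             for j in range(i + 1, len(face_ids)):
112 |                 edges.add((face_ids[i], face_ids[j]))
113 |                 edges.add((face_ids[j], face_ids[i]))
114 |     return torch.tensor(sorted(edges)).t().contiguous()
115 |
116 | # two triangles sharing an edge yield two mutually connected dual nodes
117 | print(face_dual_edges(torch.tensor([[0, 1, 2], [1, 2, 3]]), 4))
118 | ```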
--------------------------------------------------------------------------------
/data/HumanSeg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import glob
4 |
5 |
6 | import torch
7 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip, extract_tar
8 | from data.load_obj import read_obj
9 |
10 |
11 | class HumanSeg(InMemoryDataset):
12 | def __init__(self,
13 | root,
14 | name='seg',
15 | train=True,
16 | transform=None,
17 | pre_transform=None,
18 | pre_filter=None):
19 | assert name in ['seg']
20 | self.name = name
21 | self.train = train
22 | super(HumanSeg, self).__init__(root, transform, pre_transform,
23 | pre_filter)
24 | path = self.processed_paths[0] if train else self.processed_paths[1]
25 | self.data, self.slices = torch.load(path)
26 |
27 | @property
28 | def raw_file_names(self):
29 | return [
30 | 'seg', 'sseg', 'test', 'train'
31 | ]
32 |
33 | @property
34 | def processed_file_names(self):
35 | return ['training.pt', 'test.pt']
36 |
37 | def download(self):
38 | path = download_url(
39 | 'https://www.dropbox.com/s/s3n05sw0zg27fz3/human_seg.tar.gz', self.root)
40 | extract_tar(path, self.root, mode='r:gz')  # human_seg.tar.gz is a tarball, not a zip
41 | os.unlink(path)
42 | folder = osp.join(self.root, 'human_seg')
43 | os.rename(folder, self.raw_dir)
44 |
45 | def process(self):
46 | torch.save(self.process_set('train'), self.processed_paths[0])
47 | torch.save(self.process_set('test'), self.processed_paths[1])
48 |
49 | def process_set(self, dataset):
50 | categories = glob.glob(osp.join(self.raw_dir, '*', ''))
51 | categories = sorted([x.split(os.sep)[-2] for x in categories])
52 | data_list = []
53 | seg_folder = osp.join(self.raw_dir, 'seg')
54 | sseg_folder = osp.join(self.raw_dir, 'sseg')
55 | data_folder = osp.join(self.raw_dir, dataset)
56 | paths = glob.glob('{}/*.obj'.format(data_folder))
57 | for path in paths:
58 | data = read_obj(path)
59 | if data is not None:  # read_obj returns None for empty meshes
60 | data_list.append(data)
61 |
62 | if self.pre_filter is not None:
63 | data_list = [d for d in data_list if self.pre_filter(d)]
64 |
65 | if self.pre_transform is not None:
66 | data_list = [self.pre_transform(d) for d in data_list]
67 |
68 | return self.collate(data_list)
69 |
70 | def __repr__(self):
71 | return '{}{}({})'.format(self.__class__.__name__, self.name, len(self))
72 |
--------------------------------------------------------------------------------
/data/ModelNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import glob
4 |
5 |
6 | import torch
7 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip
8 | from data.load_npz import read_npz
9 |
10 |
11 | class ModelNet(InMemoryDataset):
12 | r"""The ModelNet10/40 datasets from the `"3D ShapeNets: A Deep
13 | Representation for Volumetric Shapes"
14 | <https://people.csail.mit.edu/khosla/papers/cvpr2015_wu.pdf>`_ paper,
15 | containing CAD models of 10 and 40 categories, respectively.
16 |
17 | .. note::
18 |
19 | Data objects hold mesh faces instead of edge indices.
20 | To convert the mesh to a graph, use the
21 | :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
22 | To convert the mesh to a point cloud, use the
23 | :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
24 | sample a fixed number of points on the mesh faces according to their
25 | face area.
26 |
27 | Args:
28 | root (string): Root directory where the dataset should be saved.
29 | name (string, optional): The name of the dataset (:obj:`"10"` for
30 | ModelNet10, :obj:`"40"` for ModelNet40). (default: :obj:`"10"`)
31 | train (bool, optional): If :obj:`True`, loads the training dataset,
32 | otherwise the test dataset. (default: :obj:`True`)
33 | transform (callable, optional): A function/transform that takes in an
34 | :obj:`torch_geometric.data.Data` object and returns a transformed
35 | version. The data object will be transformed before every access.
36 | (default: :obj:`None`)
37 | pre_transform (callable, optional): A function/transform that takes in
38 | an :obj:`torch_geometric.data.Data` object and returns a
39 | transformed version. The data object will be transformed before
40 | being saved to disk. (default: :obj:`None`)
41 | pre_filter (callable, optional): A function that takes in an
42 | :obj:`torch_geometric.data.Data` object and returns a boolean
43 | value, indicating whether the data object should be included in the
44 | final dataset. (default: :obj:`None`)
45 | """
46 |
47 | urls = {
48 | '10':
49 | 'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip',
50 | '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip',
51 | '40_graph': 'https://drive.google.com/uc?export=download&confirm=HB4c&id=1o9pyskkKMxuomI5BWuLjCG2nSv5iePZz'
52 | }
53 |
54 | def __init__(self,
55 | root,
56 | name='10',
57 | train=True,
58 | transform=None,
59 | pre_transform=None,
60 | pre_filter=None):
61 | assert name in ['10', '40', '40_graph']
62 | self.name = name
63 | self.train = train
64 | super(ModelNet, self).__init__(root, transform, pre_transform,
65 | pre_filter)
66 | path = self.processed_paths[0] if train else self.processed_paths[1]
67 | self.data, self.slices = torch.load(path)
68 |
69 | @property
70 | def raw_file_names(self):
71 | return [
72 | 'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
73 | 'night_stand', 'sofa', 'table', 'toilet'
74 | ]
75 |
76 | @property
77 | def processed_file_names(self):
78 | return ['training.pt', 'test.pt']
79 |
80 | def download(self):
81 | path = download_url(self.urls[self.name], self.root)
82 | extract_zip(path, self.root)
83 | os.unlink(path)
84 | folder = osp.join(self.root, 'ModelNet{}'.format(self.name))
85 | os.rename(folder, self.raw_dir)
86 |
87 | def process(self):
88 | torch.save(self.process_set('train'), self.processed_paths[0])
89 | torch.save(self.process_set('test'), self.processed_paths[1])
90 |
91 | def process_set(self, dataset):
92 | categories = glob.glob(osp.join(self.raw_dir, '*', ''))
93 | categories = sorted([x.split(os.sep)[-2] for x in categories])
94 |
95 | data_list = []
96 | for target, category in enumerate(categories):
97 | folder = osp.join(self.raw_dir, category, dataset)
98 | paths = glob.glob('{}/{}_*.npz'.format(folder, category))
99 | for path in paths:
100 | data = read_npz(path, self.train)
101 | data.y = torch.tensor([target])
102 | data_list.append(data)
103 |
104 | if self.pre_filter is not None:
105 | data_list = [d for d in data_list if self.pre_filter(d)]
106 |
107 | if self.pre_transform is not None:
108 | data_list = [self.pre_transform(d) for d in data_list]
109 |
110 | return self.collate(data_list)
111 |
112 | def __repr__(self):
113 | return '{}{}({})'.format(self.__class__.__name__, self.name, len(self))
114 |
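115 | # A minimal usage sketch (editor's illustration; the root path is an
116 | # assumption). Run from the repo root, e.g. `python -m data.ModelNet`:
117 | if __name__ == '__main__':
118 |     from torch_geometric.data import DataLoader
119 |     dataset = ModelNet(root='datasets/modelnet40_graph', name='40_graph')
120 |     loader = DataLoader(dataset, batch_size=64, shuffle=True)
121 |     print(dataset, next(iter(loader)))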
--------------------------------------------------------------------------------
/data/load_npz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.data import Data
4 |
5 |
6 | def read_npz(path, train):
7 | with np.load(path) as f:
8 | return parse_npz(f, train)
9 |
10 |
11 | def parse_npz(f, train):
12 | face = f['face']
13 | neighbor_index = f['neighbor_index']
14 |
15 | # data augmentation
16 | if train:
17 | sigma, clip = 0.01, 0.05
18 | jittered_data = np.clip(
19 | sigma * np.random.randn(*face[:, :12].shape), -1 * clip, clip)
20 | face = np.concatenate((face[:, :12] + jittered_data, face[:, 12:]), 1)
21 |
22 | # fill for n < max_faces with randomly picked faces
23 | num_point = len(face)
24 | if num_point < 1024:
25 | fill_face = []
26 | fill_neighbor_index = []
27 | for i in range(1024 - num_point):
28 | index = np.random.randint(0, num_point)
29 | fill_face.append(face[index])
30 | fill_neighbor_index.append(neighbor_index[index])
31 | face = np.concatenate((face, np.array(fill_face)))
32 | neighbor_index = np.concatenate(
33 | (neighbor_index, np.array(fill_neighbor_index)))
34 |
35 | # to tensor
36 | face = torch.from_numpy(face).float()
37 | neighbor_index = torch.from_numpy(neighbor_index).long()
38 | index = torch.arange(face.size(0)).unsqueeze(dim=1).repeat(1, 3)
39 | gather_index = torch.tensor([0, 3, 1, 4, 2, 5]).repeat(face.size(0), 1)
40 | edge_index = torch.cat([neighbor_index, index], dim=1).gather(
41 | 1, gather_index).view(-1, 2).permute(1, 0)
42 |
43 |
44 | # reorganize
45 | face = face.permute(1, 0)
46 | centers, corners, normals = face[:3], face[3:12], face[12:]
47 |
48 |
49 | # # get the sod of each faces
50 | # '''
51 | # w(e) = cos(ni/||ni|| * nj/||nj||)^-1
52 | # '''
53 | # start_point, end_point = edge_index[0, :], edge_index[1, :]
54 |
55 | # print(start_point.size())
56 | # print(normals.size())
57 |
58 |
59 |
60 |
61 | corners = corners - torch.cat([centers, centers, centers], 0)
62 |
63 | features = torch.cat([centers, corners, normals], dim=0).permute(1, 0)
64 |
65 | data = Data(x=features, edge_index=edge_index, pos=neighbor_index)
66 | return data
67 |
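68 | # Note on the edge_index construction above: for face i with neighbours
69 | # (a, b, c), concatenating [a, b, c, i, i, i] and gathering with the index
70 | # [0, 3, 1, 4, 2, 5] interleaves the values into pairs (a, i), (b, i), (c, i),
71 | # so the final (2, 3n) edge_index links every face to its three neighbours:
72 | #
73 | # >>> ni = torch.tensor([[1, 1, 1], [0, 0, 0]])
74 | # >>> idx = torch.arange(2).unsqueeze(1).repeat(1, 3)
75 | # >>> g = torch.tensor([0, 3, 1, 4, 2, 5]).repeat(2, 1)
76 | # >>> torch.cat([ni, idx], 1).gather(1, g).view(-1, 2).permute(1, 0)
77 | # tensor([[1, 1, 1, 0, 0, 0],
78 | #         [0, 0, 0, 1, 1, 1]])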
--------------------------------------------------------------------------------
/data/load_obj.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 |
4 |
5 | def yield_file(in_file):
6 | f = open(in_file)
7 | buf = f.read()
8 | f.close()
9 | for b in buf.split('\n'):
10 | if b.startswith('v '):
11 | yield ['v', [float(x) for x in b.split(" ")[1:]]]
12 | elif b.startswith('f '):
13 | triangles = b.split(' ')[1:]
14 | # -1 as .obj is base 1 but the Data class expects base 0 indices
15 | yield ['f', [int(t.split("/")[0]) - 1 for t in triangles]]
16 | else:
17 | yield ['', ""]
18 |
19 |
20 | def read_obj(in_file):
21 | vertices = []
22 | faces = []
23 |
24 | for k, v in yield_file(in_file):
25 | if k == 'v':
26 | vertices.append(v)
27 | elif k == 'f':
28 | faces.append(v)
29 |
30 | if not len(faces) or not len(vertices):
31 | return None
32 |
33 | pos = torch.tensor(vertices, dtype=torch.float)
34 | face = torch.tensor(faces, dtype=torch.long).t().contiguous()
35 |
36 | data = Data(pos=pos, face=face)
37 |
38 | return data
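39 |
40 | # A minimal usage sketch (editor's illustration; the path is a placeholder):
41 | if __name__ == '__main__':
42 |     data = read_obj('path/to/mesh.obj')
43 |     print(data)  # Data(face=[3, F], pos=[V, 3]), or None for an empty file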
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG CUDA="10.2"
2 | ARG CUDNN="7"
3 |
4 | FROM nvidia/cuda:${CUDA}-cudnn${CUDNN}-devel-ubuntu16.04
5 | # FROM pytorch/pytorch:latest
6 |
7 | RUN echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections
8 |
9 | # install basics
10 | RUN apt-get update -y \
11 | && apt-get install -y apt-utils git curl ca-certificates bzip2 cmake tree htop bmon iotop g++ \
12 | && apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev wget
13 |
14 | # Install Miniconda
15 | RUN cd / && wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \
16 | && chmod +x /Miniconda3-latest-Linux-x86_64.sh \
17 | && /Miniconda3-latest-Linux-x86_64.sh -b -p /miniconda \
18 | && rm /Miniconda3-latest-Linux-x86_64.sh
19 |
20 | ENV PATH=/miniconda/bin:$PATH
21 |
22 | # Create a Python 3.7 environment
23 | RUN /miniconda/bin/conda install -y conda-build \
24 | && /miniconda/bin/conda create -y --name py37 python=3.7 \
25 | && /miniconda/bin/conda clean -ya
26 |
27 | ENV CONDA_DEFAULT_ENV=py37
28 | ENV CONDA_PREFIX=/miniconda/envs/$CONDA_DEFAULT_ENV
29 | ENV PATH=$CONDA_PREFIX/bin:$PATH
30 | ENV CONDA_AUTO_UPDATE_CONDA=false
31 |
32 | RUN conda install -y ipython
33 | RUN pip install --default-timeout=1000 numpy ninja yacs cython matplotlib opencv-python tqdm pyyaml tensorboardX hiddenlayer -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
34 |
35 | # install pytorch
36 | # RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
37 | # RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
38 | # RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
39 | # RUN conda config --set show_channel_urls yes && conda install pytorch torchvision cudatoolkit=${CUDA} -c pytorch \
40 | # && conda clean -ya
41 | # RUN conda install pytorch torchvision cudatoolkit=${CUDA} -c pytorch \
42 | # && conda clean -ya
43 |
44 | # set cuda path
45 | ENV PATH=/usr/local/cuda/bin:$PATH
46 | ENV CPATH=/usr/local/cuda/include:$CPATH
47 | ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
48 | ENV DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH
49 |
50 | RUN pip install torch-scatter==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
51 | RUN pip install torch-sparse==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
52 | RUN pip install torch-cluster==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
53 | RUN pip install torch-spline-conv==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
54 | RUN pip install torch-geometric -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
55 | RUN pip install torchvision==0.6.0 -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
56 |
57 | WORKDIR /meshgraph
58 |
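59 | # Note: torch itself is installed implicitly as a dependency of
60 | # torchvision==0.6.0 above, which pins torch==1.5.0 to match the cu102
61 | # wheels used for the torch-geometric extensions.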
--------------------------------------------------------------------------------
/images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/1.png
--------------------------------------------------------------------------------
/images/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/2.png
--------------------------------------------------------------------------------
/images/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/3.png
--------------------------------------------------------------------------------
/images/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/4.png
--------------------------------------------------------------------------------
/images/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/5.png
--------------------------------------------------------------------------------
/images/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/images/6.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | def create_model(opt):
2 | from .mesh_graph import mesh_graph
3 | model = mesh_graph(opt)
4 | return model
5 |
--------------------------------------------------------------------------------
/models/layers/cpp_extension/meshgraph.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <iostream>
3 |
4 | at::Tensor get_connect_matrix(const at::Tensor &faces, const int64_t &num_nodes)
5 | {
6 | /*
7 | Build the dual-graph connectivity: two faces are connected
8 | whenever they share a vertex. Returns an (E, 2) tensor of
9 | face-index pairs.
10 | */
11 | std::vector<std::vector<int64_t>> node_list(num_nodes, std::vector<int64_t>(0));
12 | for (int64_t i = 0; i < faces.size(0); i++)
13 | {
14 | const int64_t *col = faces[i].data_ptr<int64_t>();
15 | for (int j = 0; j < 3; j++)
16 | {
17 | node_list[col[j]].push_back(i);
18 | }
19 | }
20 | // collect edges serially; push_back and at::cat are not thread-safe
21 | std::vector<int64_t> edges;
22 | for (int64_t it = 0; it < num_nodes; it++)
23 | {
24 | for (size_t i = 0; i + 1 < node_list[it].size(); i++)
25 | {
26 | for (size_t j = i + 1; j < node_list[it].size(); j++)
27 | {
28 | // record the two face ids, not the loop indices
29 | edges.push_back(node_list[it][i]);
30 | edges.push_back(node_list[it][j]);
31 | }
32 | }
33 | }
34 | at::Tensor result = torch::tensor(edges, faces.options());
35 | return result.view({-1, 2});
36 | }
37 |
38 | std::string test()
39 | {
40 | std::cout << "hello world" << std::endl;
41 | return "hello_world";
42 | }
43 |
44 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
45 | {
46 | m.def("get_connect_matrix", &get_connect_matrix, "get_connect_matrix");
47 | m.def("test", &test, "test function");
48 | }
49 |
--------------------------------------------------------------------------------
/models/layers/cpp_extension/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, Extension
2 | from torch.utils import cpp_extension
3 |
4 |
5 | setup(name='meshgraph_cpp',
6 | ext_modules=[cpp_extension.CppExtension(
7 | 'meshgraph_cpp', ['meshgraph.cpp'])],
8 | cmdclass={'build_ext': cpp_extension.BuildExtension})
9 |
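10 | # Usage sketch: compile and install the extension with
11 | #   python setup.py install
12 | # after which `import meshgraph_cpp` exposes get_connect_matrix() and test().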
--------------------------------------------------------------------------------
/models/layers/mesh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 |
5 |
6 | class Mesh:
7 | '''
8 | Mesh graph object holding pos, edge_index and x:
9 | x represents the per-face features,
10 | edge_index is the sparse form of the adjacency matrix,
11 | pos represents the position matrix.
12 | '''
13 |
14 | def __init__(self, vertexes, faces):
15 | self.vertexes = vertexes
16 | self.faces = faces.t().long()
17 | self.edge_index = self.nodes = None
18 | self.num_nodes = self.vertexes.size(0)
19 | # find point and face indices
20 | self.sorted_point_to_face = None
21 | self.sorted_point_to_face_index_dict = None
22 | # normalize vertices
23 | self.normalize_vertices()
24 |
25 | # create graph
26 | self.create_graph()
27 |
28 | def create_graph(self):
29 | # concat center, ox, oy, oz, norm
30 | pos = self.vertexes[self.faces]
31 | point_x, point_y, point_z = pos[:, 0, :], pos[:, 1, :], pos[:, 2, :]
32 | centers = get_inner_center_vec(
33 | point_x, point_y, point_z
34 | )
35 | temp = centers.view(-1)
36 | if torch.sum(torch.isnan(temp), dim=0) > 0:
37 | raise ValueError('center contains nan')
38 | ox, oy, oz = get_three_vec(centers, point_x, point_y, point_z)
39 | norm = get_unit_norm_vec(point_x, point_y, point_z)
40 | temp = norm.view(-1)
41 | if torch.sum(torch.isnan(temp), dim=0) > 0:
42 | raise ValueError('norm contains nan')
43 | # cat the vecter
44 | self.nodes = torch.cat((centers, ox, oy, oz, norm), dim=1)
45 |
46 | is_nan = self.nodes.view(-1)
47 | if torch.sum(torch.isnan(is_nan), dim=0) > 0:
48 | raise ValueError('nodes contain nan')
49 | self.get_connect_matrix()
50 |
51 | def normalize_vertices(self):
52 | ''' move vertices to center
53 | '''
54 | center = (torch.max(self.vertexes, dim=0)[0] +
55 | torch.min(self.vertexes, dim=0)[0])/2
56 | self.vertexes -= center
57 | max_len = torch.max(self.vertexes[:, 0]**2 +
58 | self.vertexes[:, 1]**2 + self.vertexes[:, 2]**2).item()
59 | self.vertexes /= math.sqrt(max_len)
60 |
61 | def get_connect_matrix(self):
62 | node_list = [[] for _ in range(self.num_nodes)]
63 | result = []
64 | for i, v in enumerate(self.faces):
65 | v0, v1, v2 = v
66 | node_list[v0].append(i)
67 | node_list[v1].append(i)
68 | node_list[v2].append(i)
69 | for i in node_list:
70 | for p in range(0, len(i)-1):
71 | for q in range(p+1, len(i)):
72 | result.append([i[p], i[q]])  # store the face ids, not the loop indices
73 | self_loop = torch.arange(self.nodes.size(0)).repeat(2, 1).t().long()
74 | result = torch.tensor(result).long()
75 | self.edge_index = torch.cat([self_loop, result], dim=0)
76 |
77 |
78 | def get_inner_center_vec(v1, v2, v3):
79 | '''
80 | v1, v2, v3: the three vertices of each triangle, each (n, 3).
81 | Returns the centroid (v1 + v2 + v3) / 3.
82 | '''
83 | return (v1+v2+v3)/3
84 |
85 |
86 | def get_distance_vec(v1, v2):
87 | '''
88 | Euclidean distance between
89 | vector_1 and vector_2
90 | v1 : (x,y,z)
91 | '''
92 | return torch.sqrt(torch.sum((v1-v2)**2, dim=1))
93 |
94 |
95 | def get_unit_norm_vec(v1, v2, v3):
96 | '''
97 | xy X xz
98 | (y1z2-y2z1,x2z1-x1z2,x1y2-x2y1)
99 | '''
100 | xy = v2-v1
101 | xz = v3-v1
102 | x1, y1, z1 = xy[:, 0], xy[:, 1], xy[:, 2]
103 | x2, y2, z2 = xz[:, 0], xz[:, 1], xz[:, 2]
104 | norm = torch.stack((y1*z2-y2*z1, x2*z1-x1*z2, x1*y2-x2*y1), dim=1)
105 | return norm
106 |
107 |
108 | def get_three_vec(center, v1, v2, v3):
109 | '''
110 | return ox oy oz vector
111 | '''
112 | return v1-center, v2-center, v3-center
113 |
--------------------------------------------------------------------------------
/models/layers/mesh_cpp_extension.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class Mesh:
6 | '''
7 | Mesh graph object holding pos, edge_index and x:
8 | x represents the per-face features,
9 | edge_index is the sparse form of the adjacency matrix,
10 | pos represents the position matrix.
11 | '''
12 |
13 | def __init__(self, vertexes, faces, meshgraph_cpp):
14 | self.vertexes = vertexes
15 | self.faces = faces.t()
16 | self.edge_index = self.nodes = None
17 | self.num_nodes = self.vertexes.size(0)
18 | self.meshgraph_cpp = meshgraph_cpp
19 | # find point and face indices
20 | self.sorted_point_to_face = None
21 | self.sorted_point_to_face_index_dict = None
22 | # create graph
23 | self.create_graph()
24 |
25 | def create_graph(self):
26 | # concat center, ox, oy, oz, norm
27 | pos = self.vertexes[self.faces]
28 | point_x, point_y, point_z = pos[:, 0, :], pos[:, 1, :], pos[:, 2, :]
29 | centers = get_inner_center_vec(
30 | point_x, point_y, point_z
31 | )
32 | ox, oy, oz = get_three_vec(centers, point_x, point_y, point_z)
33 | norm = get_unit_norm_vec(point_x, point_y, point_z)
34 | # concatenate the vectors
35 | self.nodes = torch.cat((centers, ox, oy, oz, norm), dim=1)
36 | self.edge_index = self.meshgraph_cpp.get_connect_matrix(
37 | self.faces, self.num_nodes
38 | )
39 |
40 |
41 | def get_inner_center_vec(v1, v2, v3):
42 | '''
43 | v1, v2, v3: the three vertices of each triangle, each (n, 3).
44 | Returns the incenter, weighting each vertex by its opposite side length.
45 | '''
46 | a = get_distance_vec(v2, v3)
47 | b = get_distance_vec(v3, v1)
48 | c = get_distance_vec(v1, v2)
49 | x = torch.stack((v1[:, 0], v2[:, 0], v3[:, 0]), dim=1)
50 | y = torch.stack((v1[:, 1], v2[:, 1], v3[:, 1]), dim=1)
51 | z = torch.stack((v1[:, 2], v2[:, 2], v3[:, 2]), dim=1)
52 | dis = torch.stack((a, b, c), dim=1)
53 | return torch.stack((
54 | torch.sum((x * dis) / (a+b+c).repeat(3, 1).t(), dim=1),
55 | torch.sum((y * dis) / (a+b+c).repeat(3, 1).t(), dim=1),
56 | torch.sum((z * dis) / (a+b+c).repeat(3, 1).t(), dim=1),
57 | ), dim=1)
58 |
59 |
60 | def get_distance_vec(v1, v2):
61 | '''
62 | Euclidean distance between
63 | vector_1 and vector_2
64 | v1 : (x,y,z)
65 | '''
66 | return torch.sqrt(torch.sum((v1-v2)**2, dim=1))
67 |
68 |
69 | def get_unit_norm_vec(v1, v2, v3):
70 | '''
71 | xy X xz
72 | (y1z2-y2z1,x2z1-x1z2,x1y2-x2y1)
73 | '''
74 | xy = v2-v1
75 | xz = v3-v1
76 | x1, y1, z1 = xy[:, 0], xy[:, 1], xy[:, 2]
77 | x2, y2, z2 = xz[:, 0], xz[:, 1], xz[:, 2]
78 | norm = torch.stack((y1*z2-y2*z1, x2*z1-x1*z2, x1*y2-x2*y1), dim=1)
79 | vec_len = torch.sqrt(torch.sum(norm**2, dim=1))  # L2 norm, not a plain sum
80 | return norm / vec_len.repeat(3, 1).t()
81 |
82 |
83 | def get_three_vec(center, v1, v2, v3):
84 | '''
85 | return ox oy oz vector
86 | '''
87 | return v1-center, v2-center, v3-center
88 |
--------------------------------------------------------------------------------
/models/layers/mesh_graph_conv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from torch.nn.parameter import Parameter
6 | from torch_scatter import scatter_add
7 | from torch_geometric.nn import GINConv, SAGEConv
8 |
9 |
10 | class FaceRotateConvolution(nn.Module):
11 |
12 | def __init__(self):
13 | super(FaceRotateConvolution, self).__init__()
14 | self.rotate_mlp = nn.Sequential(
15 | nn.Conv1d(6, 32, 1),
16 | nn.BatchNorm1d(32),
17 | nn.ReLU(),
18 | nn.Conv1d(32, 32, 1),
19 | nn.BatchNorm1d(32),
20 | nn.ReLU()
21 | )
22 | self.fusion_mlp = nn.Sequential(
23 | nn.Conv1d(32, 64, 1),
24 | nn.BatchNorm1d(64),
25 | nn.ReLU(),
26 | nn.Conv1d(64, 64, 1),
27 | nn.BatchNorm1d(64),
28 | nn.ReLU()
29 | )
30 |
31 | def forward(self, corners):
32 | fea = (self.rotate_mlp(corners[:, :6]) +
33 | self.rotate_mlp(corners[:, 3:9]) +
34 | self.rotate_mlp(torch.cat([corners[:, 6:], corners[:, :3]], 1))) / 3
35 | return self.fusion_mlp(fea)
36 |
37 |
38 | class FaceKernelCorrelation(nn.Module):
39 |
40 | def __init__(self, num_kernel=64, sigma=0.2):
41 | super(FaceKernelCorrelation, self).__init__()
42 | self.num_kernel = num_kernel
43 | self.sigma = sigma
44 | self.weight_alpha = Parameter(torch.rand(1, num_kernel, 4) * np.pi)
45 | self.weight_beta = Parameter(torch.rand(1, num_kernel, 4) * 2 * np.pi)
46 | self.bn = nn.BatchNorm1d(num_kernel)
47 | self.relu = nn.ReLU()
48 |
49 | def forward(self, normals, neighbor_index):
50 |
51 | b, _, n = normals.size()
52 |
53 | center = normals.unsqueeze(
54 | 2).expand(-1, -1, self.num_kernel, -1).unsqueeze(4)
55 | neighbor = torch.gather(normals.unsqueeze(3).expand(-1, -1, -1, 3), 2,
56 | neighbor_index.unsqueeze(1).expand(-1, 3, -1, -1))
57 | neighbor = neighbor.unsqueeze(
58 | 2).expand(-1, -1, self.num_kernel, -1, -1)
59 |
60 | fea = torch.cat([center, neighbor], 4)
61 | fea = fea.unsqueeze(5).expand(-1, -1, -1, -1, -1, 4)
62 | weight = torch.cat([torch.sin(self.weight_alpha) * torch.cos(self.weight_beta),
63 | torch.sin(self.weight_alpha) *
64 | torch.sin(self.weight_beta),
65 | torch.cos(self.weight_alpha)], 0)
66 | weight = weight.unsqueeze(0).expand(b, -1, -1, -1)
67 | weight = weight.unsqueeze(3).expand(-1, -1, -1, n, -1)
68 | weight = weight.unsqueeze(4).expand(-1, -1, -1, -1, 4, -1)
69 |
70 | dist = torch.sum((fea - weight)**2, 1)
71 | fea = torch.sum(
72 | torch.sum(np.e**(dist / (-2 * self.sigma**2)), 4), 3) / 16
73 |
74 | return self.relu(self.bn(fea))
75 |
76 |
77 | class SpatialDescriptor(nn.Module):
78 |
79 | def __init__(self):
80 | super(SpatialDescriptor, self).__init__()
81 |
82 | self.spatial_mlp = nn.Sequential(
83 | nn.Conv1d(3, 64, 1),
84 | nn.BatchNorm1d(64),
85 | nn.ReLU(),
86 | nn.Conv1d(64, 64, 1),
87 | nn.BatchNorm1d(64),
88 | nn.ReLU(),
89 | )
90 |
91 | def forward(self, centers):
92 | return self.spatial_mlp(centers)
93 |
94 |
95 | class StructuralDescriptor(nn.Module):
96 |
97 | def __init__(self):
98 | super(StructuralDescriptor, self).__init__()
99 |
100 | self.FRC = FaceRotateConvolution()
101 | self.FKC = FaceKernelCorrelation(64, 0.2)
102 | self.structural_mlp = nn.Sequential(
103 | nn.Conv1d(64 + 3, 131, 1),
104 | nn.BatchNorm1d(131),
105 | nn.ReLU(),
106 | nn.Conv1d(131, 131, 1),
107 | nn.BatchNorm1d(131),
108 | nn.ReLU(),
109 | )
110 |
111 | def forward(self, corners, normals):
112 | structural_fea1 = self.FRC(corners)
113 | # structural_fea2 = self.FKC(normals, neighbor_index)
114 |
115 | return self.structural_mlp(torch.cat([structural_fea1, normals], 1))
116 |
117 |
118 | class MeshConvolution(nn.Module):
119 |
120 | def __init__(self, spatial_in_channel, structural_in_channel, spatial_out_channel, structural_out_channel):
121 | super(MeshConvolution, self).__init__()
122 |
123 | self.spatial_in_channel = spatial_in_channel
124 | self.structural_in_channel = structural_in_channel
125 | self.spatial_out_channel = spatial_out_channel
126 | self.structural_out_channel = structural_out_channel
127 |
128 | self.aggregation_method = 'Concat'
129 |
130 | self.combination_mlp = nn.Sequential(
131 | nn.Conv1d(self.spatial_in_channel +
132 | self.structural_in_channel, self.spatial_out_channel, 1),
133 | nn.BatchNorm1d(self.spatial_out_channel),
134 | nn.ReLU(),
135 | )
136 |
137 | # if self.aggregation_method == 'Concat':
138 | # self.concat_mlp = nn.Sequential(
139 | # nn.Conv2d(self.structural_in_channel * 2,
140 | # self.structural_in_channel, 1),
141 | # nn.BatchNorm2d(self.structural_in_channel),
142 | # nn.ReLU(),
143 | # )
144 |
145 | # self.concat_mlp = nn.Sequential(
146 | # nn.Conv1d(self.structural_in_channel,
147 | # self.structural_out_channel, 1),
148 | # nn.BatchNorm1d(self.spatial_out_channel),
149 | # nn.ReLU(),
150 | # )
151 |
152 | self.aggregation_mlp = nn.Sequential(
153 | nn.Conv1d(self.structural_in_channel,
154 | self.structural_out_channel, 1),
155 | nn.BatchNorm1d(self.structural_out_channel),
156 | nn.ReLU(),
157 | )
158 |
159 | def forward(self, spatial_fea, structural_fea, _):
160 | b, _, n = spatial_fea.size()
161 |
162 | # Combination
163 | spatial_fea = self.combination_mlp(
164 | torch.cat([spatial_fea, structural_fea], 1))
165 |
166 | # # Aggregation
167 | # if self.aggregation_method == 'Concat':
168 | # structural_fea = torch.cat([structural_fea.unsqueeze(3).expand(-1, -1, -1, 3),
169 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2,
170 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel,
171 | # -1, -1))], 1)
172 | # structural_fea = self.concat_mlp(structural_fea)
173 | # structural_fea = torch.max(structural_fea, 3)[0]
174 |
175 | # elif self.aggregation_method == 'Max':
176 | # structural_fea = torch.cat([structural_fea.unsqueeze(3),
177 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2,
178 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel,
179 | # -1, -1))], 3)
180 | # structural_fea = torch.max(structural_fea, 3)[0]
181 |
182 | # elif self.aggregation_method == 'Average':
183 | # structural_fea = torch.cat([structural_fea.unsqueeze(3),
184 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2,
185 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel,
186 | # -1, -1))], 3)
187 | # structural_fea = torch.sum(structural_fea, dim=3) / 4
188 |
189 | structural_fea = self.aggregation_mlp(structural_fea)
190 |
191 | return spatial_fea, structural_fea
192 |
193 |
194 | class NormGraphConv(nn.Module):
195 | def __init__(self):
196 | super(NormGraphConv, self).__init__()
197 | self.graph_conv = SAGEConv(3, 64)
198 | self.norm_mlp = nn.Sequential(
199 | nn.Conv1d(64, 64, 1),
200 | nn.BatchNorm1d(64),
201 | nn.ReLU()
202 | )
203 |
204 | def forward(self, x, edge_index, opt):
205 | result = self.graph_conv(x, edge_index)
206 | # x = x.view(opt.batch_size, -1, 3) # n 1024 3
207 | # result = None
208 | # for i in range(opt.batch_size):
209 | # temp = self.graph_conv(x[i], edge_index) # 1024 3 -> 1024 64
210 | # if i == 0:
211 | # result = temp
212 | # else:
213 | # result = torch.cat([result, temp], dim=0) # 1024 * i 64
214 | result = result.view(opt.batch_size, -1,
215 | 64).transpose(1, 2) # (batch, 64, 1024)
216 | return self.norm_mlp(result)
217 |
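218 | # A shape-check sketch (editor's illustration): SAGEConv over per-face normals.
219 | if __name__ == '__main__':
220 |     from types import SimpleNamespace
221 |     opt = SimpleNamespace(batch_size=1)
222 |     x = torch.rand(1024, 3)                    # one mesh, 1024 face normals
223 |     edge_index = torch.randint(0, 1024, (2, 4096))
224 |     conv = NormGraphConv().eval()
225 |     with torch.no_grad():
226 |         print(conv(x, edge_index, opt).shape)  # torch.Size([1, 64, 1024])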
--------------------------------------------------------------------------------
/models/layers/mesh_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 |
5 |
6 | class FaceRotateConvolution(nn.Module):
7 |
8 | def __init__(self):
9 | super(FaceRotateConvolution, self).__init__()
10 | self.rotate_mlp = nn.Sequential(
11 | nn.Conv1d(6, 32, 1),
12 | nn.BatchNorm1d(32),
13 | nn.ReLU(),
14 | nn.Conv1d(32, 32, 1),
15 | nn.BatchNorm1d(32),
16 | nn.ReLU()
17 | )
18 | self.fusion_mlp = nn.Sequential(
19 | nn.Conv1d(32, 64, 1),
20 | nn.BatchNorm1d(64),
21 | nn.ReLU(),
22 | nn.Conv1d(64, 64, 1),
23 | nn.BatchNorm1d(64),
24 | nn.ReLU()
25 | )
26 |
27 | def forward(self, corners):
28 | fea = (self.rotate_mlp(corners[:, :6]) +
29 | self.rotate_mlp(corners[:, 3:9]) +
30 | self.rotate_mlp(torch.cat([corners[:, 6:], corners[:, :3]], 1))) / 3
31 | return self.fusion_mlp(fea)
32 |
33 |
34 | class SpatialDescriptor(nn.Module):
35 |
36 | def __init__(self):
37 | super(SpatialDescriptor, self).__init__()
38 |
39 | self.spatial_mlp = nn.Sequential(
40 | nn.Conv1d(3, 64, 1),
41 | nn.BatchNorm1d(64),
42 | nn.ReLU(),
43 | nn.Conv1d(64, 64, 1),
44 | nn.BatchNorm1d(64),
45 | nn.ReLU(),
46 | )
47 |
48 | def forward(self, centers):
49 | return self.spatial_mlp(centers)
50 |
51 |
52 | class StructuralDescriptor(nn.Module):
53 |
54 | def __init__(self):
55 | super(StructuralDescriptor, self).__init__()
56 |
57 | self.FRC = FaceRotateConvolution()
58 | self.structural_mlp = nn.Sequential(
59 | nn.Conv1d(64+3+64, 131, 1),
60 | nn.BatchNorm1d(131),
61 | nn.ReLU(),
62 | nn.Conv1d(131, 131, 1),
63 | nn.BatchNorm1d(131),
64 | nn.ReLU(),
65 | )
66 |
67 | def forward(self, corners, normals, extra_norm):
68 | structural_fea1 = self.FRC(corners)
69 |
70 | return self.structural_mlp(torch.cat([structural_fea1, normals, extra_norm], 1))
71 |
72 |
73 | class MeshConvolution(nn.Module):
74 |
75 | def __init__(self, spatial_in_channel, structural_in_channel, spatial_out_channel, structural_out_channel):
76 | super(MeshConvolution, self).__init__()
77 |
78 | self.spatial_in_channel = spatial_in_channel
79 | self.structural_in_channel = structural_in_channel
80 | self.spatial_out_channel = spatial_out_channel
81 | self.structural_out_channel = structural_out_channel
82 |
83 | self.combination_mlp = nn.Sequential(
84 | nn.Conv1d(self.spatial_in_channel +
85 | self.structural_in_channel, self.spatial_out_channel, 1),
86 | nn.GroupNorm(32, self.spatial_out_channel),
87 | nn.ReLU(),
88 | )
89 |
90 | self.aggregation_mlp = nn.Sequential(
91 | nn.Conv1d(self.structural_in_channel,
92 | self.structural_out_channel, 1),
93 | nn.GroupNorm(32, self.structural_out_channel),
94 | nn.ReLU(),
95 | )
96 |
97 | def forward(self, spatial_fea, structural_fea, _):
98 | b, _, n = spatial_fea.size()
99 |
100 | # Combination
101 | spatial_fea = self.combination_mlp(
102 | torch.cat([spatial_fea, structural_fea], 1))
103 |
104 | # Aggregation
105 | structural_fea = self.aggregation_mlp(structural_fea)
106 |
107 | return spatial_fea, structural_fea
108 |
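109 | # A shape-check sketch (editor's illustration) for one MeshConvolution block:
110 | if __name__ == '__main__':
111 |     conv = MeshConvolution(64, 131, 256, 256).eval()
112 |     spatial, structural = torch.rand(2, 64, 1024), torch.rand(2, 131, 1024)
113 |     with torch.no_grad():
114 |         s1, s2 = conv(spatial, structural, None)
115 |     print(s1.shape, s2.shape)  # torch.Size([2, 256, 1024]) twice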
--------------------------------------------------------------------------------
/models/layers/mesh_net_with_out_neigbour.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn.parameter import Parameter
5 |
6 |
7 | class FaceRotateConvolution(nn.Module):
8 |
9 | def __init__(self):
10 | super(FaceRotateConvolution, self).__init__()
11 | self.rotate_mlp = nn.Sequential(
12 | nn.Conv1d(6, 32, 1),
13 | nn.BatchNorm1d(32),
14 | nn.ReLU(),
15 | nn.Conv1d(32, 32, 1),
16 | nn.BatchNorm1d(32),
17 | nn.ReLU()
18 | )
19 | self.fusion_mlp = nn.Sequential(
20 | nn.Conv1d(32, 64, 1),
21 | nn.BatchNorm1d(64),
22 | nn.ReLU(),
23 | nn.Conv1d(64, 64, 1),
24 | nn.BatchNorm1d(64),
25 | nn.ReLU()
26 | )
27 |
28 | def forward(self, corners):
29 | fea = (self.rotate_mlp(corners[:, :6]) +
30 | self.rotate_mlp(corners[:, 3:9]) +
31 | self.rotate_mlp(torch.cat([corners[:, 6:], corners[:, :3]], 1))) / 3
32 | return self.fusion_mlp(fea)
33 |
34 |
35 | class FaceKernelCorrelation(nn.Module):
36 |
37 | def __init__(self, num_kernel=64, sigma=0.2):
38 | super(FaceKernelCorrelation, self).__init__()
39 | self.num_kernel = num_kernel
40 | self.sigma = sigma
41 | self.weight_alpha = Parameter(torch.rand(1, num_kernel, 4) * np.pi)
42 | self.weight_beta = Parameter(torch.rand(1, num_kernel, 4) * 2 * np.pi)
43 | self.bn = nn.BatchNorm1d(num_kernel)
44 | self.relu = nn.ReLU()
45 |
46 | def forward(self, normals, neighbor_index):
47 |
48 | b, _, n = normals.size()
49 |
50 | center = normals.unsqueeze(
51 | 2).expand(-1, -1, self.num_kernel, -1).unsqueeze(4)
52 | neighbor = torch.gather(normals.unsqueeze(3).expand(-1, -1, -1, 3), 2,
53 | neighbor_index.unsqueeze(1).expand(-1, 3, -1, -1))
54 | neighbor = neighbor.unsqueeze(
55 | 2).expand(-1, -1, self.num_kernel, -1, -1)
56 |
57 | fea = torch.cat([center, neighbor], 4)
58 | fea = fea.unsqueeze(5).expand(-1, -1, -1, -1, -1, 4)
59 | weight = torch.cat([torch.sin(self.weight_alpha) * torch.cos(self.weight_beta),
60 | torch.sin(self.weight_alpha) *
61 | torch.sin(self.weight_beta),
62 | torch.cos(self.weight_alpha)], 0)
63 | weight = weight.unsqueeze(0).expand(b, -1, -1, -1)
64 | weight = weight.unsqueeze(3).expand(-1, -1, -1, n, -1)
65 | weight = weight.unsqueeze(4).expand(-1, -1, -1, -1, 4, -1)
66 |
67 | dist = torch.sum((fea - weight)**2, 1)
68 | fea = torch.sum(
69 | torch.sum(np.e**(dist / (-2 * self.sigma**2)), 4), 3) / 16
70 |
71 | return self.relu(self.bn(fea))
72 |
73 |
74 | class SpatialDescriptor(nn.Module):
75 |
76 | def __init__(self):
77 | super(SpatialDescriptor, self).__init__()
78 |
79 | self.spatial_mlp = nn.Sequential(
80 | nn.Conv1d(3, 64, 1),
81 | nn.BatchNorm1d(64),
82 | nn.ReLU(),
83 | nn.Conv1d(64, 64, 1),
84 | nn.BatchNorm1d(64),
85 | nn.ReLU(),
86 | )
87 |
88 | def forward(self, centers):
89 | return self.spatial_mlp(centers)
90 |
91 |
92 | class StructuralDescriptor(nn.Module):
93 |
94 | def __init__(self):
95 | super(StructuralDescriptor, self).__init__()
96 |
97 | self.FRC = FaceRotateConvolution()
98 | self.FKC = FaceKernelCorrelation(64, 0.2)
99 | self.structural_mlp = nn.Sequential(
100 | nn.Conv1d(64 + 3, 131, 1),
101 | nn.BatchNorm1d(131),
102 | nn.ReLU(),
103 | nn.Conv1d(131, 131, 1),
104 | nn.BatchNorm1d(131),
105 | nn.ReLU(),
106 | )
107 |
108 | def forward(self, corners, normals):
109 | structural_fea1 = self.FRC(corners)
110 | # structural_fea2 = self.FKC(normals, neighbor_index)
111 |
112 | return self.structural_mlp(torch.cat([structural_fea1, normals], 1))
113 |
114 |
115 | class MeshConvolution(nn.Module):
116 |
117 | def __init__(self, spatial_in_channel, structural_in_channel, spatial_out_channel, structural_out_channel):
118 | super(MeshConvolution, self).__init__()
119 |
120 | self.spatial_in_channel = spatial_in_channel
121 | self.structural_in_channel = structural_in_channel
122 | self.spatial_out_channel = spatial_out_channel
123 | self.structural_out_channel = structural_out_channel
124 |
125 | self.aggregation_method = 'Concat'
126 |
127 | self.combination_mlp = nn.Sequential(
128 | nn.Conv1d(self.spatial_in_channel +
129 | self.structural_in_channel, self.spatial_out_channel, 1),
130 | nn.BatchNorm1d(self.spatial_out_channel),
131 | nn.ReLU(),
132 | )
133 |
134 | # if self.aggregation_method == 'Concat':
135 | # self.concat_mlp = nn.Sequential(
136 | # nn.Conv2d(self.structural_in_channel * 2,
137 | # self.structural_in_channel, 1),
138 | # nn.BatchNorm2d(self.structural_in_channel),
139 | # nn.ReLU(),
140 | # )
141 |
142 | # self.concat_mlp = nn.Sequential(
143 | # nn.Conv1d(self.structural_in_channel,
144 | # self.structural_out_channel, 1),
145 | # nn.BatchNorm1d(self.spatial_out_channel),
146 | # nn.ReLU(),
147 | # )
148 |
149 | self.aggregation_mlp = nn.Sequential(
150 | nn.Conv1d(self.structural_in_channel + 64,
151 | self.structural_out_channel, 1),
152 | nn.BatchNorm1d(self.structural_out_channel),
153 | nn.ReLU(),
154 | )
155 |
156 | def forward(self, spatial_fea, structural_fea, shape_norm):
157 | b, _, n = spatial_fea.size()
158 |
159 | # Combination
160 | spatial_fea = self.combination_mlp(
161 | torch.cat([spatial_fea, structural_fea], 1))
162 |
163 | # # Aggregation
164 | # if self.aggregation_method == 'Concat':
165 | # structural_fea = torch.cat([structural_fea.unsqueeze(3).expand(-1, -1, -1, 3),
166 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2,
167 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel,
168 | # -1, -1))], 1)
169 | # structural_fea = self.concat_mlp(structural_fea)
170 | # structural_fea = torch.max(structural_fea, 3)[0]
171 |
172 | # elif self.aggregation_method == 'Max':
173 | # structural_fea = torch.cat([structural_fea.unsqueeze(3),
174 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2,
175 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel,
176 | # -1, -1))], 3)
177 | # structural_fea = torch.max(structural_fea, 3)[0]
178 |
179 | # elif self.aggregation_method == 'Average':
180 | # structural_fea = torch.cat([structural_fea.unsqueeze(3),
181 | # torch.gather(structural_fea.unsqueeze(3).expand(-1, -1, -1, 3), 2,
182 | # neighbor_index.unsqueeze(1).expand(-1, self.structural_in_channel,
183 | # -1, -1))], 3)
184 | # structural_fea = torch.sum(structural_fea, dim=3) / 4
185 |
186 | structural_fea = self.aggregation_mlp(
187 | torch.cat([structural_fea, shape_norm], dim=1)) # n 64+3+64
188 |
189 | return spatial_fea, structural_fea
190 |
--------------------------------------------------------------------------------
/models/layers/struct_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class FaceVectorConv(nn.Module):
6 | def __init__(self, output_channel=64):
7 | super(FaceVectorConv, self).__init__()
8 | self.rotate_mlp = nn.Sequential(
9 | nn.Conv1d(6, 32, 1),
10 | nn.BatchNorm1d(32),
11 | nn.ReLU(),
12 | nn.Conv1d(32, 32, 1),
13 | nn.BatchNorm1d(32),
14 | nn.ReLU()
15 | )
16 | self.fusion_mlp = nn.Sequential(
17 | nn.Conv1d(32, 64, 1),
18 | nn.BatchNorm1d(64),
19 | nn.ReLU(),
20 | nn.Conv1d(64, output_channel, 1),
21 | nn.BatchNorm1d(output_channel),
22 | nn.ReLU()
23 | )
24 |
25 | def forward(self, x, opt):
26 | '''
27 | x: (batch_size * 1024, 15) face features, laid out as
28 | center ox oy oz norm
29 | 3 3 3 3 3
30 | '''
31 | data = x.view(opt.batch_size, -1, 15).transpose(1, 2)
32 | xy = data[:, 3:9]
33 | yz = data[:, 6:12]
34 | xz = torch.cat([data[:, 3:6], data[:, 9:12]], dim=1)
35 | face_line = (
36 | self.rotate_mlp(xy) +
37 | self.rotate_mlp(yz) +
38 | self.rotate_mlp(xz)
39 | ) / 3 # 64 , 64 , 1024
40 | return self.fusion_mlp(face_line)
41 |
42 |
43 | class PointConv(nn.Module):
44 | def __init__(self):
45 | super(PointConv, self).__init__()
46 |
47 | self.spatial_mlp = nn.Sequential(
48 | nn.Conv1d(3, 64, 1),
49 | nn.BatchNorm1d(64),
50 | nn.ReLU(),
51 | nn.Conv1d(64, 64, 1),
52 | nn.BatchNorm1d(64),
53 | nn.ReLU(),
54 | )
55 |
56 | def forward(self, x, opt):
57 | '''
58 | center ox oy oz norm
59 | '''
60 | data = x.view(opt.batch_size, -1, 15).transpose(1, 2)
61 | x = data[:, :3]
62 | return self.spatial_mlp(x)
63 |
64 |
65 | class MeshMlp(nn.Module):
66 | def __init__(self, opt, output_channel=256):
67 | super(MeshMlp, self).__init__()
68 | self.opt = opt
69 | self.pc = PointConv()
70 | self.fvc = FaceVectorConv()
71 | self.mlp = nn.Sequential(
72 | nn.Conv1d(131, output_channel, 1),
73 | nn.BatchNorm1d(output_channel),
74 | nn.ReLU(),
75 | nn.Conv1d(output_channel, output_channel, 1),
76 | nn.BatchNorm1d(output_channel),
77 | nn.ReLU()
78 | )
79 |
80 | def forward(self, x):
81 | # x: (batch_size * 1024, 15)
82 | data = x.view(self.opt.batch_size, -1, 15).transpose(1, 2)
83 | point_feature = self.pc(x, self.opt) # n 64 1024
84 | face_feature = self.fvc(x, self.opt) # n 64 1024
85 | norm = data[:, 12:]
86 | fusion_feature = torch.cat(
87 | [norm, point_feature, face_feature] # n 64+64+3 = 131 1024
88 | , dim=1
89 | )
90 | return self.mlp(fusion_feature) # n 256 1024
91 |
92 |
93 | class NormMlp(nn.Module):
94 | def __init__(self, opt):
95 | super(NormMlp, self).__init__()
96 | self.opt = opt
97 | self.extra_mlp = nn.Sequential(
98 | nn.Conv1d(3, 64, 1),
99 | nn.BatchNorm1d(64),
100 | nn.ReLU(),
101 | nn.Conv1d(64, 64, 1),
102 | nn.BatchNorm1d(64),
103 | nn.ReLU()
104 | )
105 |
106 | def forward(self, x):
107 | data = x.view(self.opt.batch_size, -1, 15).transpose(1, 2)
108 | norm = data[:, 12:]
109 | return self.extra_mlp(norm) # n 64 1024
110 |
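111 | # A shape-check sketch (editor's illustration; `opt` is a stand-in namespace):
112 | if __name__ == '__main__':
113 |     from types import SimpleNamespace
114 |     opt = SimpleNamespace(batch_size=2)
115 |     x = torch.rand(2 * 1024, 15)    # (batch * faces, 15) node features
116 |     mlp = MeshMlp(opt).eval()
117 |     with torch.no_grad():
118 |         print(mlp(x).shape)         # torch.Size([2, 256, 1024])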
--------------------------------------------------------------------------------
/models/mesh_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | from . import networks
7 | import torch.optim as optim
8 | from models.optimizer import adabound
9 | from torch_geometric.utils import remove_self_loops, contains_self_loops, contains_isolated_nodes
10 | import hiddenlayer as hl
11 | import cv2
12 |
13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14 | OUTPUT_DIR = os.path.join(ROOT_DIR, "feature_img")
15 | print(OUTPUT_DIR)
16 | class mesh_graph:
17 | def __init__(self, opt):
18 | self.opt = opt
19 | self.cuda = opt.cuda
20 | self.is_train = opt.is_train
21 | self.device = torch.device('cuda:{}'.format(
22 | self.cuda[0]) if self.cuda else 'cpu')
23 | self.save_dir = osp.join(opt.ckpt_root, opt.name)
24 | self.optimizer = None
25 | self.loss = None
26 |
27 | # init mesh data
28 | self.nclasses = opt.nclasses
29 |
30 | # init network
31 | self.net = networks.get_net(opt)
32 | self.net.train(self.is_train)
33 |
34 | # criterion
35 | self.loss = networks.get_loss(self.opt).to(self.device)
36 |
37 | if self.is_train:
38 | # self.optimizer = adabound.AdaBound(
39 | # params=self.net.parameters(), lr=self.opt.lr, final_lr=self.opt.final_lr)
40 | self.optimizer = optim.SGD(self.net.parameters(
41 | ), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
42 | self.scheduler = networks.get_scheduler(self.optimizer, self.opt)
43 | if not self.is_train or opt.continue_train:
44 | self.load_state(opt.last_epoch)
45 |
46 | # A History object to store metrics
47 | self.history = hl.History()
48 | # A Canvas object to draw the metrics
49 | self.canvas = hl.Canvas()
50 |
51 | def test(self):
52 | """tests model
53 | returns: number correct and total number
54 | """
55 | with torch.no_grad():
56 | out, fea = self.forward()
57 | # compute number of correct
58 | pred_class = torch.max(out, dim=1)[1]
59 | # print(pred_class)
60 | label_class = self.labels
61 | correct = self.get_accuracy(pred_class, label_class)
62 | return correct, len(label_class)
63 |
64 | def get_accuracy(self, pred, labels):
65 | """computes accuracy for classification / segmentation """
66 | if self.opt.task == 'cls':
67 | correct = pred.eq(labels).sum()
68 | return correct
69 |
70 | def load_state(self, last_epoch):
71 | ''' load epoch '''
72 | load_file = '%s_net.pth' % last_epoch
73 | load_path = osp.join(self.save_dir, load_file)
74 | net = self.net
75 | if isinstance(net, torch.nn.DataParallel):
76 | net = net.module
77 | print('loading the model from %s' % load_path)
78 | state_dict = torch.load(load_path, map_location=str(self.device))
79 | if hasattr(state_dict, '_metadata'):
80 | del state_dict._metadata
81 | net.load_state_dict(state_dict)
82 |
83 | def set_input_data(self, data):
84 | '''set input data'''
85 |
86 | gt_label = data.y
87 | edge_index = data.edge_index
88 | if self.opt.batch_size == 1:
89 | nodes_features = data.x.unsqueeze(0)
90 | neigbour_index = data.pos.unsqueeze(0)
91 | else:
92 | nodes_features = data.x.view(self.opt.batch_size, 1024, -1)
93 | neigbour_index = data.pos.view(self.opt.batch_size, -1, 3)
94 |
95 | self.labels = gt_label.to(self.device).long()
96 | self.edge_index = edge_index.to(self.device).long()
97 |
98 | self.neigbour_index = neigbour_index.to(self.device).long()
99 | self.centers = nodes_features[:, :, :3].transpose(1, 2)
100 | self.corners = nodes_features[:, :, 3:12].transpose(1, 2)
101 | self.normals = nodes_features[:, :, 12:].transpose(1, 2)
102 | self.x = data.x[:, -3:].to(self.device).float()
103 |
104 | def forward(self):
105 | out = self.net(self.x, self.edge_index, self.centers,
106 | self.corners, self.normals, self.neigbour_index)
107 | return out
108 |
109 | def backward(self, out, fea):
110 | self.loss_val = self.loss(out, self.labels)
111 | self.loss_val.backward()
112 |
113 | def optimize(self):
114 | '''
115 | optimize paramater
116 | '''
117 | self.optimizer.zero_grad()
118 | out, fea = self.forward()
119 | # print(self.net.module.concat_mlp[0].weight)
120 | self.backward(out, fea)
121 | nn.utils.clip_grad_norm_(self.net.parameters(), self.opt.grad_clip)
122 | self.optimizer.step()
123 |
124 | def save_network(self, which_epoch):
125 | '''
126 | save network to disk
127 | '''
128 | save_filename = '%s_net.pth' % (which_epoch)
129 | save_path = osp.join(self.save_dir, save_filename)
130 | if len(self.cuda) > 0 and torch.cuda.is_available():
131 | torch.save(self.net.module.cpu().state_dict(), save_path)
132 | self.net.cuda(self.cuda[0])
133 | else:
134 | torch.save(self.net.cpu().state_dict(), save_path)
135 |
136 | def log_history_and_plot(self, writer, epoch, batch):
137 | writer.history_log(epoch, batch, self.net.module.concat_mlp[0].weight)
138 | writer.draw_hist()
139 |
140 | def log_features_and_plot(self, epoch, batch):
141 | out, fea = self.forward()
142 | labels = self.labels
143 | fea_norm_numpy = fea.cpu().detach().numpy()
144 | cv_img = deprocess_image(fea_norm_numpy)
145 | im_color = cv2.applyColorMap(cv_img, cv2.COLORMAP_JET)
146 | cv2.imwrite(os.path.join(OUTPUT_DIR, "%d_epoch1.png" % epoch), im_color)
147 | # add canvas
148 |
149 | # self.history.log(epoch, image=im_color)
150 | # # self.canvas.draw_image(self.history["image"])
151 | # self.canvas.save(os.path.join(OUTPUT_DIR, "%d_epoch.png" % epoch))
152 |
153 | def deprocess_image(img):
154 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
155 | img = img - np.mean(img)
156 | img = img / (np.std(img) + 1e-5)
157 | img = img * 0.1
158 | img = img + 0.5
159 | img = np.clip(img, 0, 1)
160 | return np.uint8(img*255)
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.optim import lr_scheduler
5 | import torch.nn.functional as F
6 | from torch_geometric.nn import GraphConv
7 | from models.layers.struct_conv import MeshMlp, NormMlp
8 | from models.layers.mesh_net import SpatialDescriptor, StructuralDescriptor, MeshConvolution
9 | # ,SpatialDescriptor, StructuralDescriptor, MeshConvolution
10 | from models.layers.mesh_graph_conv import NormGraphConv, GINConv
11 |
12 |
13 | def init_weights(net, init_type, init_gain):
14 | def init_func(m):
15 | classname = m.__class__.__name__
16 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
17 | if init_type == 'normal':
18 | init.normal_(m.weight.data, 0.0, init_gain)
19 | elif init_type == 'xavier':
20 | init.xavier_normal_(m.weight.data, gain=init_gain)
21 | elif init_type == 'kaiming':
22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
23 | elif init_type == 'orthogonal':
24 | init.orthogonal_(m.weight.data, gain=init_gain)
25 | else:
26 | raise NotImplementedError(
27 | 'initialization method [%s] is not implemented' % init_type)
28 | elif classname.find('BatchNorm2d') != -1:
29 | init.normal_(m.weight.data, 1.0, init_gain)
30 | init.constant_(m.bias.data, 0.0)
31 | net.apply(init_func)
32 |
33 |
34 | def init_net(net, init_type, init_gain, cuda_ids):
35 | if len(cuda_ids) > 0:
36 | assert(torch.cuda.is_available())
37 | net.cuda(cuda_ids[0])
38 | net = net.cuda()
39 | net = torch.nn.DataParallel(net, cuda_ids)
40 | if init_type != 'none':
41 | init_weights(net, init_type, init_gain)
42 | return net
43 |
44 |
45 | def get_net(opt):
46 | net = None
47 | if opt.arch == 'meshconv':
48 | net = MeshGraph(opt)
49 | else:
50 | raise NotImplementedError(
51 | 'model name [%s] is not implemented' % opt.arch)
52 | return init_net(net, opt.init_type, opt.init_gain, opt.cuda)
53 |
54 |
55 | def get_loss(opt):
56 | if opt.task == 'cls':
57 | loss = nn.CrossEntropyLoss()
58 | elif opt.task == 'seg':
59 | loss = nn.CrossEntropyLoss(ignore_index=-1)
60 | return loss
61 |
62 |
63 | def get_scheduler(optimizer, opt):
64 | if opt.lr_policy == 'step':
65 | scheduler = lr_scheduler.MultiStepLR(
66 | optimizer, gamma=opt.gamma, milestones=opt.milestones)
67 | else:
68 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
69 | return scheduler
70 |
71 |
72 | # class MeshGraph(nn.Module):
73 | # """Some Information about MeshGraph"""
74 |
75 | # def __init__(self, opt):
76 | # super(MeshGraph, self).__init__()
77 | # self.mesh_mlp_256 = MeshMlp(opt, 256)
78 | # self.gin_conv_256 = GINConv(self.mesh_mlp_256)
79 |
80 | # # self.graph_conv_64 = GraphConv(1024, 256)
81 | # # self.graph_conv_64 = GraphConv(256, 64)
82 |
83 | # self.classifier = nn.Sequential(
84 | # nn.Linear(256, 1024),
85 | # nn.ReLU(),
86 | # nn.Dropout(p=0.5),
87 | # nn.Linear(1024, 256),
88 | # nn.ReLU(),
89 | # nn.Dropout(p=0.5),
90 | # nn.Linear(256, 40),
91 | # )
92 |
93 | # if opt.use_fpm:
94 | # self.mesh_mlp_64 = MeshMlp(64)
95 | # self.mesh_mlp_128 = MeshMlp(128)
96 | # self.gin_conv_64 = GINConv(self.mesh_mlp_64)
97 | # self.gin_conv_128 = GINConv(self.mesh_mlp_128)
98 |
99 | # def forward(self, nodes_features, edge_index):
100 | # x = nodes_features
101 | # edge_index = edge_index
102 | # x1 = self.gin_conv_256(x, edge_index) # 64 256 1024
103 | # # x1 = x1.view(x1.size(0), -1)
104 | # # x1 = torch.max(x1, dim=2)[0]
105 | # return self.classifier(x1)
106 |
107 |
108 | # class MeshGraph(nn.Module):
109 | # """Some Information about MeshGraph"""
110 |
111 | # def __init__(self, opt):
112 | # super(MeshGraph, self).__init__()
113 | # self.spatial_descriptor = SpatialDescriptor()
114 | # self.structural_descriptor = StructuralDescriptor()
115 | # self.mesh_conv1 = MeshConvolution(64, 131, 256, 256)
116 | # self.mesh_conv2 = MeshConvolution(256, 256, 512, 512)
117 | # self.fusion_mlp = nn.Sequential(
118 | # nn.Conv1d(1024, 1024, 1),
119 | # nn.BatchNorm1d(1024),
120 | # nn.ReLU(),
121 | # )
122 | # self.concat_mlp = nn.Sequential(
123 | # nn.Conv1d(1792, 1024, 1),
124 | # nn.BatchNorm1d(1024),
125 | # nn.ReLU(),
126 | # )
127 | # self.classifier = nn.Sequential(
128 | # nn.Linear(1024, 512),
129 | # nn.ReLU(),
130 | # nn.Dropout(p=0.5),
131 | # nn.Linear(512, 256),
132 | # nn.ReLU(),
133 | # nn.Dropout(p=0.5),
134 | # nn.Linear(256, 40)
135 | # )
136 |
137 | # def forward(self, nodes_features, edge_index, centers, corners, normals, neigbour_index):
138 | # spatial_fea0 = self.spatial_descriptor(centers)
139 | # structural_fea0 = self.structural_descriptor(
140 | # corners, normals)
141 |
142 | # spatial_fea1, structural_fea1 = self.mesh_conv1(
143 | # spatial_fea0, structural_fea0)
144 | # spatial_fea2, structural_fea2 = self.mesh_conv2(
145 | # spatial_fea1, structural_fea1)
146 |
147 | # spatial_fea3 = self.fusion_mlp(
148 | # torch.cat([spatial_fea2, structural_fea2], 1))
149 |
150 | # fea = self.concat_mlp(
151 | # torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1))
152 |
153 | # fea = torch.max(fea, dim=2)[0]
154 | # fea = fea.reshape(fea.size(0), -1)
155 | # fea = self.classifier[:-1](fea)
156 | # cls = self.classifier[-1:](fea)
157 | # return cls
158 |
159 |
160 | class MeshGraph(nn.Module):
161 | """Some Information about MeshGraph"""
162 |
163 | def __init__(self, opt):
164 | super(MeshGraph, self).__init__()
165 | self.opt = opt
166 | self.spatial_descriptor = SpatialDescriptor()
167 | self.structural_descriptor = StructuralDescriptor()
168 | self.shape_descriptor = NormGraphConv()
169 |
170 | self.mesh_conv1 = MeshConvolution(64, 131, 256, 256)
171 | self.mesh_conv2 = MeshConvolution(256, 256, 512, 512)
172 | self.fusion_mlp = nn.Sequential(
173 | nn.Conv1d(1024, 1024, 1),
174 | nn.BatchNorm1d(1024),
175 | nn.ReLU(),
176 | )
177 | self.concat_mlp = nn.Sequential(
178 | nn.Conv1d(1792, 1024, 1),
179 | nn.BatchNorm1d(1024),
180 | nn.ReLU(),
181 | )
182 | self.classifier = nn.Sequential(
183 | nn.Linear(1024, 512),
184 | nn.ReLU(),
185 | nn.Dropout(p=0.5),
186 | nn.Linear(512, 256),
187 | nn.ReLU(),
188 | nn.Dropout(p=0.5),
189 | nn.Linear(256, 40)
190 | )
191 |
192 | def forward(self, nodes_features, edge_index, centers, corners, normals, neigbour_index):
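   | # Forward pass at a glance (matches the code below): a graph-based shape
   | # descriptor plus spatial/structural descriptors feed two mesh
   | # convolutions; the outputs are fused, concatenated, max-pooled over
   | # faces, and sent through the classifier, returning (logits, features).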
193 | shape_norm = self.shape_descriptor(
194 | nodes_features, edge_index, self.opt) # n 64 1024
195 |
196 | spatial_fea0 = self.spatial_descriptor(centers)
197 | structural_fea0 = self.structural_descriptor(
198 | corners, normals)
199 |
200 | spatial_fea1, structural_fea1 = self.mesh_conv1(
201 | spatial_fea0, structural_fea0, shape_norm)
202 | spatial_fea2, structural_fea2 = self.mesh_conv2(
203 | spatial_fea1, structural_fea1, shape_norm)
204 |
205 | spatial_fea3 = self.fusion_mlp(
206 | torch.cat([spatial_fea2, structural_fea2], 1))
207 |
208 | fea = self.concat_mlp(
209 | torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1))
210 |
211 | fea = torch.max(fea, dim=2)[0]
212 | fea = fea.reshape(fea.size(0), -1)
213 | fea = self.classifier[:-1](fea)
214 | cls = self.classifier[-1:](fea)
215 | return cls, fea
216 |
--------------------------------------------------------------------------------
/models/optimizer/adabound.py:
--------------------------------------------------------------------------------
1 | '''
2 | @inproceedings{Luo2019AdaBound,
3 | author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
4 | title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
5 | booktitle = {Proceedings of the 7th International Conference on Learning Representations},
6 | month = {May},
7 | year = {2019},
8 | address = {New Orleans, Louisiana}
9 | }
10 | '''
11 |
12 | import math
13 | import torch
14 | from torch.optim import Optimizer
15 |
16 |
17 | class AdaBound(Optimizer):
18 | """Implements AdaBound algorithm.
19 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
20 | Arguments:
21 | params (iterable): iterable of parameters to optimize or dicts defining
22 | parameter groups
23 | lr (float, optional): Adam learning rate (default: 1e-3)
24 | betas (Tuple[float, float], optional): coefficients used for computing
25 | running averages of gradient and its square (default: (0.9, 0.999))
26 | final_lr (float, optional): final (SGD) learning rate (default: 0.1)
27 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
28 | eps (float, optional): term added to the denominator to improve
29 | numerical stability (default: 1e-8)
30 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
31 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
32 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
33 | https://openreview.net/forum?id=Bkg3g2R9FX
34 | """
35 |
36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
37 | eps=1e-8, weight_decay=0, amsbound=False):
38 | if not 0.0 <= lr:
39 | raise ValueError("Invalid learning rate: {}".format(lr))
40 | if not 0.0 <= eps:
41 | raise ValueError("Invalid epsilon value: {}".format(eps))
42 | if not 0.0 <= betas[0] < 1.0:
43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
44 | if not 0.0 <= betas[1] < 1.0:
45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
46 | if not 0.0 <= final_lr:
47 | raise ValueError("Invalid final learning rate: {}".format(final_lr))
48 | if not 0.0 <= gamma < 1.0:
49 | raise ValueError("Invalid gamma parameter: {}".format(gamma))
50 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
51 | weight_decay=weight_decay, amsbound=amsbound)
52 | super(AdaBound, self).__init__(params, defaults)
53 |
54 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
55 |
56 | def __setstate__(self, state):
57 | super(AdaBound, self).__setstate__(state)
58 | for group in self.param_groups:
59 | group.setdefault('amsbound', False)
60 |
61 | def step(self, closure=None):
62 | """Performs a single optimization step.
63 | Arguments:
64 | closure (callable, optional): A closure that reevaluates the model
65 | and returns the loss.
66 | """
67 | loss = None
68 | if closure is not None:
69 | loss = closure()
70 |
71 | for group, base_lr in zip(self.param_groups, self.base_lrs):
72 | for p in group['params']:
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 | if grad.is_sparse:
77 | raise RuntimeError(
78 | 'AdaBound does not support sparse gradients, please consider SparseAdam instead')
79 | amsbound = group['amsbound']
80 |
81 | state = self.state[p]
82 |
83 | # State initialization
84 | if len(state) == 0:
85 | state['step'] = 0
86 | # Exponential moving average of gradient values
87 | state['exp_avg'] = torch.zeros_like(p.data)
88 | # Exponential moving average of squared gradient values
89 | state['exp_avg_sq'] = torch.zeros_like(p.data)
90 | if amsbound:
91 | # Maintains max of all exp. moving avg. of sq. grad. values
92 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
93 |
94 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
95 | if amsbound:
96 | max_exp_avg_sq = state['max_exp_avg_sq']
97 | beta1, beta2 = group['betas']
98 |
99 | state['step'] += 1
100 |
101 | if group['weight_decay'] != 0:
102 | grad = grad.add(group['weight_decay'], p.data)
103 |
104 | # Decay the first and second moment running average coefficient
105 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
106 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
107 | if amsbound:
108 | # Maintains the maximum of all 2nd moment running avg. till now
109 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
110 | # Use the max. for normalizing running avg. of gradient
111 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
112 | else:
113 | denom = exp_avg_sq.sqrt().add_(group['eps'])
114 |
115 | bias_correction1 = 1 - beta1 ** state['step']
116 | bias_correction2 = 1 - beta2 ** state['step']
117 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
118 |
119 | # Applies bounds on actual learning rate
120 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
121 | final_lr = group['final_lr'] * group['lr'] / base_lr
122 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
123 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
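   | # Both bounds tighten toward final_lr as `step` grows: the lower bound
   | # rises from ~0 and the upper bound falls from final_lr * (1 + 1/gamma),
   | # so the update transitions smoothly from Adam-like to SGD-like behaviour.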
124 | step_size = torch.full_like(denom, step_size)
125 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
126 |
127 | p.data.add_(-step_size)
128 |
129 | return loss
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from util import util
5 |
6 |
7 | class base_options:
8 | def __init__(self):
9 | self.parser = argparse.ArgumentParser(
10 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
11 | self.is_init = False
12 |
13 | def initialize(self):
14 | # dataset
15 | self.parser.add_argument(
16 | '--datasets', required=True, help='root path of the dataset to train/test on'
17 | )
18 | self.parser.add_argument(
19 | '--task', choices={'cls', 'seg'}, default='cls', help='task the network is trained for: cls (classification) or seg (segmentation)'
20 | )
21 | self.parser.add_argument(
22 | '--nclasses', type=int, default=40, help='number of classes for classification'
23 | )
24 | # general args
25 | self.parser.add_argument(
26 | '--cuda', type=str, default='0', help='cuda device ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU'
27 | )
28 | self.parser.add_argument(
29 | '--ckpt_root', type=str, default='./ckpt_root', help='model saved path'
30 | )
31 | self.parser.add_argument(
32 | '--ckpt', type=str, default='./ckpt', help='final model saved path'
33 | )
34 | self.parser.add_argument(
35 | '--name', type=str, default='debug', help='experiment name; checkpoints are saved under <ckpt_root>/<name>'
36 | )
37 | # network parameter
38 | self.parser.add_argument(
39 | '--batch_size', type=int, default=64, help='mini batch size'
40 | )
41 | self.parser.add_argument(
42 | '--arch', type=str, default='meshconv', help='model arch'
43 | )
44 | self.parser.add_argument(
45 | '--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]'
46 | )
47 | self.parser.add_argument(
48 | '--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.'
49 | )
50 | self.parser.add_argument(
51 | '--use_fpm', action='store_true', help='use the FPM module to extract features'
52 | )
53 | self.parser.add_argument(
54 | '--milestones', default='30,60', help='milestones for MultiStepLR'
55 | )
56 |
57 | def parse(self):
58 | if not self.is_init:
59 | self.initialize()
60 | self.opt, _ = self.parser.parse_known_args()
61 | self.opt.is_train = self.is_train # train or test
62 |
63 | cuda_ids = self.opt.cuda.split(',')
64 | self.opt.cuda = []
65 | self.opt.milestones = [int(x) for x in self.opt.milestones.split(',')]
66 |
67 | for cuda_id in cuda_ids:
68 | id = int(cuda_id)
69 | if id >= 0:
70 | self.opt.cuda.append(id)
71 | # set gpu id
72 | if len(self.opt.cuda) > 0:
73 | torch.cuda.set_device(self.opt.cuda[0])
74 |
75 | args = vars(self.opt)
76 |
77 | if self.is_train:
78 | print('------------ Options -------------')
79 | for k, v in sorted(args.items()):
80 | print('%s: %s' % (str(k), str(v)))
81 | print('-------------- End ----------------')
82 |
83 | # check dir
84 | util.check_dir(os.path.join(self.opt.ckpt_root, self.opt.name))
85 | util.check_dir(os.path.join(self.opt.ckpt, self.opt.name))
86 |
87 | # save train options
88 | expr_dir = os.path.join(self.opt.ckpt_root, self.opt.name)
89 | file_name = os.path.join(expr_dir, 'opt.txt')
90 |
91 | with open(file_name, 'wt') as opt_file:
92 | opt_file.write('------------ Options -------------\n')
93 | for k, v in sorted(args.items()):
94 | opt_file.write('%s: %s\n' % (str(k), str(v)))
95 | opt_file.write('-------------- End ----------------\n')
96 | return self.opt
97 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import base_options
2 |
3 |
4 | class test_options(base_options):
5 | def initialize(self):
6 | base_options.initialize(self)
7 | self.parser.add_argument(
8 | '--last_epoch', default='latest', help='which epoch to load?'
9 | )
10 | self.is_train = False
11 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import base_options
2 |
3 |
4 | class train_options(base_options):
5 | def initialize(self):
6 | base_options.initialize(self)
7 | self.parser.add_argument(
8 | '--continue_train', action='store_true', help='continue training: load the latest model'
9 | )
10 | self.parser.add_argument(
11 | '--last_epoch', default='latest', help='which epoch to load?'
12 | )
13 | # optimizer param
14 | self.parser.add_argument(
15 | '--lr', default=0.01, type=float, help='learning rate'
16 | )
17 | self.parser.add_argument(
18 | '--final_lr', default=0.1, type=float, help='final learning rate'
19 | )
20 | self.parser.add_argument(
21 | '--epoch', type=int, default=200, help='training epoch'
22 | )
23 | self.parser.add_argument(
24 | '--frequency', type=int, default=10, help='print/plot the training loss every N steps'
25 | )
26 | self.parser.add_argument(
27 | '--epoch_frequency', type=int, default=1, help='save the model every N epochs'
28 | )
29 | self.parser.add_argument(
30 | '--loop_frequency', type=int, default=100, help='save the latest model every N iterations'
31 | )
32 | self.parser.add_argument(
33 | '--test_frequency', type=int, default=1, help='run the test set every N epochs'
34 | )
35 | self.parser.add_argument(
36 | '--lr_policy', default='step', type=str, help='learning rate policy (currently only: step)'
37 | )
38 | self.parser.add_argument(
39 | '--gamma', default=0.1, type=float, help='gamma for MultiStepLR'
40 | )
41 | self.parser.add_argument(
42 | '--momentum', default=0.9, type=float
43 | )
44 | self.parser.add_argument(
45 | '--weight_decay', default=0.0005, type=float
46 | )
47 | self.parser.add_argument('--grad_clip', type=float,
48 | default=5, help='gradient clipping')
49 |
50 | # model
51 | self.is_train = True
52 |
--------------------------------------------------------------------------------
/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/preprocess/__init__.py
--------------------------------------------------------------------------------
/preprocess/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.cpp_extension import load
3 | from torch_geometric.utils import to_undirected, remove_self_loops
4 | # from models.layers.mesh_cpp_extension import Mesh
5 | from models.layers.mesh import Mesh
6 | import time
7 |
8 |
9 | def maybe_num_nodes(index, num_nodes=None):
10 | return index.max().item() + 1 if num_nodes is None else num_nodes
11 |
12 |
13 | def segregate_self_loops(edge_index, edge_attr=None):
14 | r"""Segregates self-loops from the graph.
15 |
16 | Args:
17 | edge_index (LongTensor): The edge indices.
18 | edge_attr (Tensor, optional): Edge weights or multi-dimensional
19 | edge features. (default: :obj:`None`)
20 |
21 | :rtype: (:class:`LongTensor`, :class:`Tensor`, :class:`LongTensor`,
22 | :class:`Tensor`)
23 | """
24 |
25 | mask = edge_index[0] != edge_index[1]
26 | inv_mask = ~mask
27 |
28 | loop_edge_index = edge_index[:, inv_mask]
29 | loop_edge_attr = None if edge_attr is None else edge_attr[inv_mask]
30 | edge_index = edge_index[:, mask]
31 | edge_attr = None if edge_attr is None else edge_attr[mask]
32 |
33 | return edge_index, edge_attr, loop_edge_index, loop_edge_attr
34 |
35 |
36 | def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None):
37 | r"""Removes the isolated nodes from the graph given by :attr:`edge_index`
38 | with optional edge attributes :attr:`edge_attr`.
39 | In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter
40 | out isolated node features later on.
41 | Self-loops are preserved for non-isolated nodes.
42 |
43 | Args:
44 | edge_index (LongTensor): The edge indices.
45 | edge_attr (Tensor, optional): Edge weights or multi-dimensional
46 | edge features. (default: :obj:`None`)
47 | num_nodes (int, optional): The number of nodes, *i.e.*
48 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
49 |
50 | :rtype: (LongTensor, Tensor, ByteTensor)
51 | """
52 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
53 |
54 | out = segregate_self_loops(edge_index, edge_attr)
55 | edge_index, edge_attr, loop_edge_index, loop_edge_attr = out
56 |
57 | mask = torch.zeros(num_nodes, dtype=torch.uint8, device=edge_index.device)
58 | mask[edge_index.view(-1)] = 1
59 |
60 | assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device)
61 | assoc[mask] = torch.arange(mask.sum(), device=assoc.device)
62 | edge_index = assoc[edge_index]
63 |
64 | loop_mask = torch.zeros_like(mask)
65 | loop_mask[loop_edge_index[0]] = 1
66 | loop_mask = loop_mask & mask
67 | loop_assoc = torch.full_like(assoc, -1)
68 | loop_assoc[loop_edge_index[0]] = torch.arange(
69 | loop_edge_index.size(1), device=loop_assoc.device)
70 | loop_idx = loop_assoc[loop_mask]
71 | loop_edge_index = assoc[loop_edge_index[:, loop_idx]]
72 |
73 | edge_index = torch.cat([edge_index, loop_edge_index], dim=1)
74 |
75 | if edge_attr is not None:
76 | loop_edge_attr = loop_edge_attr[loop_idx]
77 | edge_attr = torch.cat([edge_attr, loop_edge_attr], dim=0)
78 |
79 | return edge_index, edge_attr, mask
80 |
81 | class FaceToGraph(object):
82 | r"""Converts mesh faces :obj:`[3, num_faces]` to graph.
83 |
84 | Args:
85 | remove_faces (bool, optional): If set to :obj:`False`, the face tensor
86 | will not be removed.
87 | """
88 |
89 | def __init__(self, remove_faces=True):
90 | # self.mesh_graph_cpp = load(name='meshgraph_cpp', sources=[
91 | # 'models/layers/meshgraph.cpp'])
92 | self.remove_faces = remove_faces
93 | self.count = 0
94 |
95 | def __call__(self, data):
96 | start_time = time.time()
97 | data.num_nodes = data.x.size(0)
98 | edge_index = to_undirected(data.edge_index, data.num_nodes)
99 | # edge_index, _ = remove_self_loops(edge_index)
100 |
101 | print('%d-th mesh, size: %d' % (self.count, data.x.size(0)))
102 |
103 | data.edge_index = edge_index
104 | end_time = time.time()
105 | print('took {} s to transform'.format(end_time-start_time))
106 | if self.remove_faces:
107 | data.face = None
108 | self.count += 1
109 | return data
110 |
111 | def __repr__(self):
112 | return '{}()'.format(self.__class__.__name__)
113 |
114 |
115 |
116 | # class FaceToGraph(object):
117 | # r"""Converts mesh faces :obj:`[3, num_faces]` to graph.
118 |
119 | # Args:
120 | # remove_faces (bool, optional): If set to :obj:`False`, the face tensor
121 | # will not be removed.
122 | # """
123 |
124 | # def __init__(self, remove_faces=True):
125 | # # self.mesh_graph_cpp = load(name='meshgraph_cpp', sources=[
126 | # # 'models/layers/meshgraph.cpp'])
127 | # self.remove_faces = remove_faces
128 | # self.count = 0
129 |
130 | # def __call__(self, data):
131 | # start_time = time.time()
132 | # print('start transform')
133 | # mesh_grap = Mesh(data.pos, data.face)
134 | # # set the center ox oy oz unit_norm
135 | # data.x = mesh_grap.nodes
136 | # print(data.x)
137 | # data.num_nodes = data.x.size(0)
138 | # edge_index = to_undirected(mesh_grap.edge_index.t(), data.num_nodes)
139 | # # edge_index, _ = remove_self_loops(edge_index)
140 |
141 | # print('%d-th mesh,size: %d' % (self.count, data.x.size(0)))
142 | # # set edge_index to data
143 | # data.edge_index = edge_index
144 | # end_time = time.time()
145 | # print('take {} s time for translate'.format(end_time-start_time))
146 | # mesh_grap = None
147 | # if self.remove_faces:
148 | # data.face = None
149 | # self.count += 1
150 | # return data
151 |
152 | # def __repr__(self):
153 | # return '{}()'.format(self.__class__.__name__)
154 |
155 |
156 | class FaceToEdge(object):
157 | r"""Converts mesh faces :obj:`[3, num_faces]` to edge indices
158 | :obj:`[2, num_edges]`.
159 |
160 | Args:
161 | remove_faces (bool, optional): If set to :obj:`False`, the face tensor
162 | will not be removed.
163 | """
164 |
165 | def __init__(self, remove_faces=True):
166 | self.remove_faces = remove_faces
167 | self.count = 0
168 |
169 | def __call__(self, data):
170 | print(self.count)
171 | face = data.face
172 |
173 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
174 | edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
175 |
176 | data.edge_index = edge_index
177 | if self.remove_faces:
178 | data.face = None
179 | self.count += 1
180 | return data
181 |
182 | def __repr__(self):
183 | return '{}()'.format(self.__class__.__name__)
184 |
--------------------------------------------------------------------------------
/script/human_seg/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATADIR='datasets' #location where data gets downloaded to
4 |
5 | # get data
6 | echo "downloading the data and putting it in: " $DATADIR
7 | mkdir -p $DATADIR && cd $DATADIR
8 | wget https://www.dropbox.com/s/s3n05sw0zg27fz3/human_seg.tar.gz
9 | tar -xzvf human_seg.tar.gz && rm human_seg.tar.gz
--------------------------------------------------------------------------------
/script/human_seg/get_pretrained.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
--------------------------------------------------------------------------------
/script/human_seg/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the test and export collapses
4 |
--------------------------------------------------------------------------------
/script/human_seg/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the training
4 |
--------------------------------------------------------------------------------
/script/human_seg/view.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python util/mesh_viewer.py \
4 | --files \
5 | ckpt/human_seg/meshes/shrec__14_0.obj
--------------------------------------------------------------------------------
/script/modelnet10/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the training
4 | python train.py \
5 | --datasets datasets/modelnet10 \
6 | --name modelnet10 \
7 | --batch_size 1 \
8 | --nclasses 10
--------------------------------------------------------------------------------
/script/modelnet40_graph/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the test
4 | python test.py \
5 | --datasets datasets/modelnet40_graph \
6 | --name 40_graph \
7 | --batch_size 64 \
8 | --nclasses 40 \
9 | --last_epoch final
--------------------------------------------------------------------------------
/script/modelnet40_graph/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## run the training
4 | python train.py \
5 | --datasets datasets/modelnet40_graph \
6 | --name 40_graph \
7 | --batch_size 64 \
8 | --nclasses 40 \
9 | --epoch 300 \
10 | --init_type orthogonal
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from options.test_options import test_options
2 | from data.ModelNet import ModelNet
3 | from preprocess.preprocess import FaceToGraph
4 | from torch_geometric.data import DataLoader
5 | from models import create_model
6 | from util.writer import Writer
7 |
8 |
9 | def run_test(epoch=-1):
10 | print('Running Test')
11 | opt = test_options().parse()
12 | dataset = ModelNet(root=opt.datasets, name='40_graph', train=False,
13 | pre_transform=FaceToGraph(remove_faces=True))
14 | loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False)
15 | model = create_model(opt)
16 | writer = Writer(opt)
17 | writer.reset_counter()
18 | for i, data in enumerate(loader):
19 | if data.y.size(0) % opt.batch_size != 0:  # skip incomplete batches
20 | continue
21 | model.set_input_data(data)
22 | ncorrect, nexamples = model.test()
23 | writer.update_counter(ncorrect, nexamples)
24 | writer.print_acc(epoch, writer.acc)
25 | return writer.acc
26 |
27 |
28 | if __name__ == '__main__':
29 | run_test()
30 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import train_options
3 | from data.ModelNet import ModelNet
4 | from preprocess.preprocess import FaceToGraph, FaceToEdge
5 | from torch_geometric.data import DataLoader
6 | from models import create_model
7 | from util.writer import Writer
8 | from test import run_test
9 |
10 | if __name__ == '__main__':
11 | opt = train_options().parse()
12 |
13 | # load dataset
14 | dataset = ModelNet(root=opt.datasets, name=str(opt.name),
15 | pre_transform=FaceToGraph(remove_faces=True))
16 | print('# training meshes = %d' % len(dataset))
17 | loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)
18 | model = create_model(opt)
19 | writer = Writer(opt)
20 | total_steps = 0
21 |
22 | for epoch in range(1, opt.epoch + 1):  # train for exactly opt.epoch epochs
23 | start_time = time.time()
24 | count = 0
25 | running_loss = 0.0
26 | for i, data in enumerate(loader):
27 | # break
28 | if data.y.size(0) % opt.batch_size != 0:  # skip incomplete batches
29 | continue
30 | total_steps += opt.batch_size
31 | count += opt.batch_size
32 | model.set_input_data(data)
33 | model.optimize()
34 | running_loss += model.loss_val.detach()  # detach so the running sum keeps no autograd history
35 | if total_steps % opt.frequency == 0:
36 | loss_val = running_loss/opt.frequency
37 | writer.print_loss(epoch, count, loss_val)
38 | writer.plot_loss(epoch, count, loss_val, len(dataset))
39 | running_loss = 0
40 |
41 | if i % opt.loop_frequency == 0:
42 | print('saving the latest model (epoch %d, total_steps %d)' %
43 | (epoch, total_steps))
44 | model.save_network('latest')
45 | # break
46 |
47 | if epoch % opt.epoch_frequency == 0:
48 | print('saving the model at the end of epoch %d, iters %d' %
49 | (epoch, total_steps))
50 | if (epoch-1) % 20 == 0:
51 | model.log_history_and_plot(writer, epoch, count)
52 | model.log_features_and_plot(epoch, count)
53 | model.save_network('latest')
54 | model.save_network(epoch)
55 |
56 | if epoch % opt.test_frequency == 0:
57 | acc = run_test(epoch)
58 | writer.plot_acc(acc, epoch)
59 | # break
60 | # wait = input("input")
61 | writer.close()
62 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JsBlueCat/MeshGraph/c7b331f64b70a442d1042351efe4dc4231e8b289/util/__init__.py
--------------------------------------------------------------------------------
/util/mesh_viewer.py:
--------------------------------------------------------------------------------
1 | import mpl_toolkits.mplot3d as a3
2 | import matplotlib.colors as colors
3 | import pylab as pl
4 | import numpy as np
5 |
6 | V = np.array
7 |
8 |
9 | def r2h(x): return colors.rgb2hex(tuple(map(lambda y: y / 255., x)))
10 |
11 |
12 | surface_color = r2h((255, 230, 205))
13 | edge_color = r2h((90, 90, 90))
14 | edge_colors = (r2h((15, 167, 175)), r2h((230, 81, 81)), r2h((142, 105, 252)), r2h((248, 235, 57)),
15 | r2h((51, 159, 255)), r2h((225, 117, 231)), r2h((97, 243, 185)), r2h((161, 183, 196)))
16 |
17 |
18 | def init_plot():
19 | ax = pl.figure().add_subplot(111, projection='3d')
20 | # hide the axes; thanks to
21 | # https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/
22 | ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
23 | ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
24 | ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
25 | # Get rid of the spines
26 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
27 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
28 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
29 | # Get rid of the ticks
30 | ax.set_xticks([])
31 | ax.set_yticks([])
32 | ax.set_zticks([])
33 | return (ax, [np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf])
34 |
35 |
36 | def update_lim(mesh, plot):
37 | vs = mesh[0]
38 | for i in range(3):
39 | plot[1][2 * i] = min(plot[1][2 * i], vs[:, i].min())
40 | plot[1][2 * i + 1] = max(plot[1][2 * i + 1], vs[:, i].max())
41 | return plot
42 |
43 |
44 | def update_plot(mesh, plot):
45 | if plot is None:
46 | plot = init_plot()
47 | return update_lim(mesh, plot)
48 |
49 |
50 | def surfaces(mesh, plot):
51 | vs, faces, edges = mesh
52 | vtx = vs[faces]
53 | edgecolor = edge_color if not len(edges) else 'none'
54 | tri = a3.art3d.Poly3DCollection(vtx, facecolors=surface_color + '55', edgecolors=edgecolor,
55 | linewidths=.5, linestyles='dashdot')
56 | plot[0].add_collection3d(tri)
57 | return plot
58 |
59 |
60 | def segments(mesh, plot):
61 | vs, _, edges = mesh
62 | for edge_c, edge_group in enumerate(edges):
63 | for edge_idx in edge_group:
64 | edge = vs[edge_idx]
65 | line = a3.art3d.Line3DCollection(
66 | [edge], linewidths=.5, linestyles='dashdot')
67 | line.set_color(edge_colors[edge_c % len(edge_colors)])
68 | plot[0].add_collection3d(line)
69 | return plot
70 |
71 |
72 | def plot_mesh(mesh, *whats, show=True, plot=None):
73 | for what in [update_plot] + list(whats):
74 | plot = what(mesh, plot)
75 | if show:
76 | li = max(plot[1][1], plot[1][3], plot[1][5])
77 | plot[0].auto_scale_xyz([0, li], [0, li], [0, li])
78 | pl.tight_layout()
79 | pl.show()
80 | return plot
81 |
82 |
83 | def parse_obje(obj_file, scale_by):
84 | vs = []
85 | faces = []
86 | edges = []
87 |
88 | def add_to_edges():
89 | if edge_c >= len(edges):
90 | for _ in range(len(edges), edge_c + 1):
91 | edges.append([])
92 | edges[edge_c].append(edge_v)
93 |
94 | def fix_vertices():
95 | nonlocal vs, scale_by
96 | vs = V(vs)
97 | z = vs[:, 2].copy()
98 | vs[:, 2] = vs[:, 1]
99 | vs[:, 1] = z
100 | max_range = 0
101 | for i in range(3):
102 | min_value = np.min(vs[:, i])
103 | max_value = np.max(vs[:, i])
104 | max_range = max(max_range, max_value - min_value)
105 | vs[:, i] -= min_value
106 | if not scale_by:
107 | scale_by = max_range
108 | vs /= scale_by
109 |
110 | with open(obj_file) as f:
111 | for line in f:
112 | line = line.strip()
113 | splitted_line = line.split()
114 | if not splitted_line:
115 | continue
116 | elif splitted_line[0] == 'v':
117 | vs.append([float(v) for v in splitted_line[1:]])
118 | elif splitted_line[0] == 'f':
119 | faces.append([int(c) - 1 for c in splitted_line[1:]])
120 | elif splitted_line[0] == 'e':
121 | if len(splitted_line) >= 4:
122 | edge_v = [int(c) - 1 for c in splitted_line[1:-1]]
123 | edge_c = int(splitted_line[-1])
124 | add_to_edges()
125 |
126 | vs = V(vs)
127 | fix_vertices()
128 | faces = V(faces, dtype=int)
129 | edges = [V(c, dtype=int) for c in edges]
130 | return (vs, faces, edges), scale_by
131 |
132 |
133 | def view_meshes(*files, offset=.2):
134 | plot = None
135 | max_x = 0
136 | scale = 0
137 | for file in files:
138 | mesh, scale = parse_obje(file, scale)
139 | max_x_current = mesh[0][:, 0].max()
140 | mesh[0][:, 0] += max_x + offset
141 | plot = plot_mesh(mesh, surfaces, segments,
142 | plot=plot, show=file == files[-1])
143 | max_x += max_x_current + offset
144 |
145 |
146 | if __name__ == '__main__':
147 | import argparse
148 | parser = argparse.ArgumentParser("view meshes")
149 | parser.add_argument('--files', nargs='+', default=['/meshgraph/ckpt_root/airplane_0124.obj',
150 | #'/meshgraph/ckpt_root/bed_0010.obj',
151 | #'/meshgraph/ckpt_root/cup_0007.obj',
152 | #'/meshgraph/ckpt_root/person_0004.obj',
153 | #'/meshgraph/ckpt_root/toilet_0006.obj',
154 | #'/meshgraph/ckpt_root/tv_stand_0004.obj',
155 | #'/meshgraph/ckpt_root/xbox_0005.obj'
156 | ], type=str,
157 | help="list of 1 or more .obj files")
158 | args = parser.parse_args()
159 |
160 | # view meshes
161 | view_meshes(*args.files)
162 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 |
5 | def check_dir(dir):
6 | if not osp.exists(dir):
7 | os.makedirs(dir)
8 |
9 |
10 |
--------------------------------------------------------------------------------
/util/writer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 | import numpy as np
5 | import torch
6 | import torchvision.models
7 | import torch.nn as nn
8 | from torchvision import datasets, transforms
9 | import hiddenlayer as hl
10 |
11 | try:
12 | from tensorboardX import SummaryWriter
13 | except ImportError:
14 | print('tensorboardX is not available, please install it.')
15 | SummaryWriter = None
16 |
17 |
18 | class Writer:
19 | def __init__(self, opt):
20 | self.opt = opt
21 | self.name = opt.name
22 | self.save_path = os.path.join(opt.ckpt_root, opt.name)
23 | self.train_loss = os.path.join(self.save_path, 'train_loss.txt')
24 | self.test_loss = os.path.join(self.save_path, 'test_loss.txt')
25 |
26 | # set display
27 | if opt.is_train and SummaryWriter is not None:
28 | self.display = SummaryWriter() # comment=opt.name
29 | else:
30 | self.display = None
31 |
32 | self.start_logs()
33 | self.nexamples = 0
34 | self.ncorrect = 0
35 |
36 | # A History object to store metrics
37 | self.history = hl.History()
38 |
39 | # A Canvas object to draw the metrics
40 | self.canvas = hl.Canvas()
41 |
42 | def start_logs(self):
43 | ''' create log file'''
44 | if self.opt.is_train:
45 | with open(self.train_loss, 'a') as train_loss:
46 | now = time.strftime('%c')
47 | train_loss.write(
48 | '================ Training Loss (%s) ================\n' % now)
49 | else:
50 | with open(self.test_loss, 'a') as test_loss:
51 | now = time.strftime('%c')
52 | test_loss.write(
53 | '================ Test Loss (%s) ================\n' % now)
54 |
55 | def reset_counter(self):
56 | """
57 | counts # of correct examples
58 | """
59 | self.ncorrect = 0
60 | self.nexamples = 0
61 |
62 | def update_counter(self, ncorrect, nexamples):
63 | self.ncorrect += ncorrect
64 | self.nexamples += nexamples
65 |
66 | def print_loss(self, epoch, iters, loss):
67 | print('epoch : %d, iter : %d , loss : %.3f' %
68 | (epoch, iters, loss.item()))
69 | with open(self.train_loss, 'a') as train_loss:
70 | train_loss.write('epoch : %d, iter : %d , loss : %.3f\n' %
71 | (epoch, iters, loss.item()))
72 |
73 | def plot_loss(self, epoch, i, loss, n):
74 | train_data_iter = i + (epoch-1) * n
75 | if self.display:
76 | self.display.add_scalar(
77 | 'data/train_loss', loss.item(), train_data_iter)
78 |
79 | def plot_acc(self, acc, epoch):
80 | if self.display:
81 | self.display.add_scalar('data/test_acc', acc, epoch-1)
82 |
83 | def print_acc(self, epoch, acc):
84 | """ prints test accuracy to terminal / file """
85 | message = 'epoch: {}, TEST ACC: [{:.5} %]\n' \
86 | .format(epoch, acc * 100)
87 | print(message)
88 | with open(self.test_loss, "a") as log_file:
89 | log_file.write('%s\n' % message)
90 |
91 | def history_log(self, epoch, batch, weight):
92 | self.history.log((epoch, batch), global_feature_weight=weight)
93 |
94 | def draw_hist(self):
95 | self.canvas.draw_hist(self.history["global_feature_weight"])
96 |
97 | @property
98 | def acc(self):
99 | return float(self.ncorrect) / self.nexamples
100 |
101 | def close(self):
102 | if self.display is not None:
103 | self.display.close()
104 |
--------------------------------------------------------------------------------