├── .gitattributes
├── .gitignore
├── CITATION.bib
├── LICENSE
├── README.md
├── checkpoints
│   └── .gitkeep
├── conf
│   ├── config.yaml
│   ├── dataset
│   │   ├── mini.yaml
│   │   ├── munich.yaml
│   │   └── nuremberg.yaml
│   └── hydra
│       └── job_logging
│           └── custom.yaml
├── data
│   └── .gitkeep
├── dataset.py
├── docs
│   └── architecture.png
├── download.py
├── environment.yml
├── network
│   ├── __init__.py
│   ├── convonet
│   │   ├── ResnetBlockFC.py
│   │   ├── __init__.py
│   │   ├── pointnet.py
│   │   ├── unet.py
│   │   └── unet3d.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── gnn.py
│   ├── loss.py
│   └── polygnn.py
├── reconstruct.py
├── remap.py
├── stats.py
├── test.py
├── train.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | # Pycharm
128 | .idea
129 |
130 | # macOS
131 | .DS_Store
132 |
133 | # results
134 | outputs/
135 |
136 | # data
137 | data/
138 |
139 | # checkpoints
140 | checkpoints/
141 |
142 | # wandb
143 | wandb/
144 |
145 | # plots
146 | *.pdf
147 |
148 |
--------------------------------------------------------------------------------
/CITATION.bib:
--------------------------------------------------------------------------------
1 | @article{chen2024polygnn,
2 | title = {PolyGNN: Polyhedron-based graph neural network for 3D building reconstruction from point clouds},
3 | journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
4 | volume = {218},
5 | pages = {693-706},
6 | year = {2024},
7 | issn = {0924-2716},
8 | doi = {https://doi.org/10.1016/j.isprsjprs.2024.09.031},
9 | url = {https://www.sciencedirect.com/science/article/pii/S0924271624003691},
10 | author = {Zhaiyu Chen and Yilei Shi and Liangliang Nan and Zhitong Xiong and Xiao Xiang Zhu},
11 | }
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Zhaiyu Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PolyGNN
2 |
3 | -----------
4 | [](https://www.sciencedirect.com/science/article/pii/S0924271624003691)
5 | [](https://www.sciencedirect.com/science/article/pii/S0924271624003691/pdfft?md5=3d0d8b3b72cdd3f4c809d714b1292137&pid=1-s2.0-S0924271624003691-main.pdf)
6 | [](https://raw.githubusercontent.com/chenzhaiyu/polygnn/main/LICENSE)
7 |
8 | PolyGNN is an implementation of the paper [*PolyGNN: Polyhedron-based Graph Neural Network for 3D Building Reconstruction from Point Clouds*](https://www.sciencedirect.com/science/article/pii/S0924271624003691).
9 | PolyGNN learns a piecewise planar occupancy function, supported by polyhedral decomposition, for efficient and scalable 3D building reconstruction.
10 |
11 |
12 | ![PolyGNN architecture](docs/architecture.png)
13 |
14 |
15 | ## 🛠️ Setup
16 |
17 | ### Repository
18 |
19 | Clone the repository:
20 |
21 | ```bash
22 | git clone https://github.com/chenzhaiyu/polygnn && cd polygnn
23 | ```
24 |
25 | ### All-in-one installation
26 |
27 | Create a conda environment with all dependencies:
28 |
29 | ```bash
30 | conda env create -f environment.yml && conda activate polygnn
31 | ```
32 |
33 | ### Manual installation
34 |
35 | Still easy! Create a conda environment and install [mamba](https://github.com/mamba-org/mamba) for faster dependency resolution:
36 | ```bash
37 | conda create --name polygnn python=3.10 && conda activate polygnn
38 | conda install mamba -c conda-forge
39 | ```
40 |
41 | Install the required dependencies:
42 | ```bash
43 | mamba install pytorch torchvision sage=10.0 pytorch-cuda=11.7 pyg=2.3 pytorch-scatter pytorch-sparse pytorch-cluster torchmetrics rtree -c pyg -c pytorch -c nvidia -c conda-forge
44 | pip install abspy==0.2.6 hydra-core hydra-colorlog omegaconf trimesh tqdm wandb plyfile
45 | ```
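
To verify the installation, a quick sanity check can help (optional; a minimal snippet, not part of the repository):

```python
# confirm the core dependencies import correctly and CUDA is visible
import torch
import torch_geometric
import abspy  # noqa: F401

print(torch.__version__, torch_geometric.__version__, torch.cuda.is_available())
```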
46 |
47 | ## 🚀 Usage
48 |
49 | ### Quick start
50 |
51 | Download the mini dataset and pretrained weights:
52 |
53 | ```bash
54 | python download.py dataset=mini
55 | ```
56 | In case you encounter issues (e.g., Google Drive limits), manually download the data and weights [here](https://drive.google.com/drive/folders/1fAwvhGtOgS8f4IldE1J4v5s0438WM24b?usp=sharing), then extract them into `./data/mini` and `./checkpoints/mini`, respectively.
57 | The mini dataset contains 200 random instances (~0.07% of the full dataset).
58 |
59 | Train PolyGNN on the mini dataset (provided for reference; not intended for full-scale training):
60 | ```bash
61 | python train.py dataset=mini
62 | ```
63 | The data will be automatically preprocessed the first time you initiate training.
64 |
65 | Evaluate PolyGNN with the option to save predictions:
66 | ```bash
67 | python test.py dataset=mini evaluate.save=true
68 | ```
69 |
70 | Generate meshes from predictions:
71 | ```bash
72 | python reconstruct.py dataset=mini reconstruct.type=mesh
73 | ```
74 |
75 | Remap meshes to their original CRS:
76 | ```bash
77 | python remap.py dataset=mini
78 | ```
79 |
80 | Generate reconstruction statistics:
81 | ```bash
82 | python stats.py dataset=mini
83 | ```
84 |
85 | ### Available configurations
86 |
87 | ```bash
88 | # check available configurations for training
89 | python train.py --cfg job
90 |
91 | # check available configurations for evaluation
92 | python test.py --cfg job
93 | ```
94 | Alternatively, review the configuration file: `conf/config.yaml`.
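
You can also compose the configuration programmatically with Hydra's compose API, e.g., to inspect resolved values before launching a run (a minimal sketch; the overrides shown are examples, not required settings):

```python
from hydra import compose, initialize

# compose the project config with command-line-style overrides
with initialize(config_path='conf', version_base='1.2'):
    cfg = compose(config_name='config', overrides=['dataset=mini', 'batch_size=16'])
    print(cfg.checkpoint_path)  # e.g., ./checkpoints/mini/model_best.pth
```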
95 |
96 | ### Full dataset
97 |
98 | The Munich dataset is available for download on [Zenodo](https://zenodo.org/records/14254264). Note that it requires 332 GB of storage when decompressed. Meshes for CRS remapping can be downloaded [here](https://drive.google.com/file/d/1hn11XMqyoPUnq-9WGfAwQq47uuUvcbi7/view?usp=drive_link).
99 |
100 | ### Custom data
101 |
102 | PolyGNN requires polyhedron-based graphs as input. To prepare such graphs from your own point clouds:
103 | 1. Extract planar primitives using tools such as [Easy3D](https://github.com/LiangliangNan/Easy3D) or [GoCoPP](https://github.com/Ylannl/GoCoPP), preferably in [VertexGroup](https://abspy.readthedocs.io/en/latest/vertexgroup.html) format.
104 | 2. Build [CellComplex](https://abspy.readthedocs.io/en/latest/api.html#abspy.CellComplex) from the primitives using [abspy](https://github.com/chenzhaiyu/abspy). Example code:
105 | ```python
106 | from abspy import VertexGroup, CellComplex
107 | vertex_group = VertexGroup(vertex_group_path, quiet=True)
108 | cell_complex = CellComplex(vertex_group.planes, vertex_group.aabbs,
109 | vertex_group.points_grouped, build_graph=True, quiet=True)
110 | cell_complex.prioritise_planes(prioritise_verticals=True)
111 | cell_complex.construct()
112 | cell_complex.save(complex_path)
113 | ```
114 | Alternatively, you can modify [`CityDataset`](https://github.com/chenzhaiyu/polygnn/blob/67addd77a6be1d100448e3bd7523babfa063d0dd/dataset.py#L157) or [`TestOnlyDataset`](https://github.com/chenzhaiyu/polygnn/blob/67addd77a6be1d100448e3bd7523babfa063d0dd/dataset.py#L276) to accept inputs directly from [VertexGroup](https://abspy.readthedocs.io/en/latest/vertexgroup.html) or [VertexGroupReference](https://abspy.readthedocs.io/en/latest/api.html#abspy.VertexGroupReference).
115 | 3. Structure your dataset similarly to the provided mini dataset:
116 | ```bash
117 | YOUR_DATASET_NAME
118 | └── raw
119 | ├── 03_meshes
120 | │ ├── DEBY_LOD2_104572462.obj
121 | │ ├── DEBY_LOD2_104575306.obj
122 | │ └── DEBY_LOD2_104575493.obj
123 | ├── 04_pts
124 | │ ├── DEBY_LOD2_104572462.npy
125 | │ ├── DEBY_LOD2_104575306.npy
126 | │ └── DEBY_LOD2_104575493.npy
127 | ├── 05_complexes
128 | │ ├── DEBY_LOD2_104572462.cc
129 | │ ├── DEBY_LOD2_104575306.cc
130 | │ └── DEBY_LOD2_104575493.cc
131 | ├── testset.txt
132 | └── trainset.txt
133 | ```
134 | 4. To train or evaluate PolyGNN using your dataset, run the following commands:
135 | ```bash
136 | # start training
137 | python train.py dataset=YOUR_DATASET_NAME
138 |
139 | # start evaluation
140 | python test.py dataset=YOUR_DATASET_NAME
141 | ```
142 | For evaluation only, you can instantiate your dataset as a [`TestOnlyDataset`](https://github.com/chenzhaiyu/polygnn/blob/67addd77a6be1d100448e3bd7523babfa063d0dd/dataset.py#L276), as in [this line](https://github.com/chenzhaiyu/polygnn/blob/94ffc9e45f0721653038bd91f33f1d4eafeab7cb/test.py#L178).
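
For reference, below is a minimal sketch of instantiating such a dataset directly in Python (the constructor arguments follow `CityDataset.__init__` in `dataset.py`; the paths and worker count are placeholders):

```python
from dataset import TestOnlyDataset

# preprocessing runs automatically the first time the dataset is instantiated
dataset = TestOnlyDataset(root='./data/YOUR_DATASET_NAME', name='YOUR_DATASET_NAME',
                          split='test', num_workers=8, num_queries=16)
data = dataset[0]  # a torch_geometric Data object with points, queries, edge_index, and labels
```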
143 |
144 | ## 👷 TODOs
145 |
146 | - [x] Demo with mini data and pretrained weights
147 | - [x] Short tutorial for getting started
148 | - [x] Host the full dataset
149 |
150 | ## 🎓 Citation
151 |
152 | If you use PolyGNN in a scientific work, please consider citing the paper:
153 |
154 | ```bibtex
155 | @article{chen2024polygnn,
156 | title = {PolyGNN: Polyhedron-based graph neural network for 3D building reconstruction from point clouds},
157 | journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
158 | volume = {218},
159 | pages = {693-706},
160 | year = {2024},
161 | issn = {0924-2716},
162 | doi = {https://doi.org/10.1016/j.isprsjprs.2024.09.031},
163 | url = {https://www.sciencedirect.com/science/article/pii/S0924271624003691},
164 | author = {Zhaiyu Chen and Yilei Shi and Liangliang Nan and Zhitong Xiong and Xiao Xiang Zhu},
165 | }
166 | ```
167 |
168 | The synthetic point clouds are simulated with [pyhelios](https://github.com/chenzhaiyu/pyhelios).
169 | You might also want to check out [abspy](https://github.com/chenzhaiyu/abspy) for 3D adaptive binary space partitioning and [Points2Poly](https://github.com/chenzhaiyu/points2poly) for reconstruction with deep implicit fields.
170 |
--------------------------------------------------------------------------------
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenzhaiyu/polygnn/6acf63fce040368424cf2a01c978dd2fcad1d831/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/conf/config.yaml:
--------------------------------------------------------------------------------
1 | # Default configuration for the project.
2 | # Values shall be overridden by the respective dataset configurations.
3 |
4 | # default settings
5 | defaults:
6 | - _self_
7 | - dataset: munich
8 | - override hydra/job_logging: custom
9 | - override hydra/hydra_logging: colorlog
10 |
11 | # general path settings
12 | run_suffix: ''
13 | data_dir: './data/${dataset}${dataset_suffix}'
14 | complex_dir: '${data_dir}/raw/05_complexes'
15 | reference_dir: '${data_dir}/raw/03_meshes_global_test'
16 | output_dir: './outputs/${dataset}${run_suffix}'
17 | remap_dir: '${output_dir}/global'
18 | checkpoint_dir: './checkpoints/${dataset}${run_suffix}'
19 | checkpoint_path: '${checkpoint_dir}/model_best.pth'
20 | csv_path: '${output_dir}/evaluation.csv'
21 |
22 | # network settings
23 | encoder: ConvONet # {PointNet, PointNet2, PointCNN, DGCNN, PointTransformerConv, RandLANet, ConvONet}
24 | decoder: ConvONet # {MLP, ConvONet}
25 | gnn: TAGCN # {null, GCN, TransformerGCN, TAGCN}
26 | latent_dim_light: 256 # latent dimension (256) for plain encoder-decoder
27 | latent_dim_conv: 4096 # latent dimension (4096) for convolutional encoder-decoder
28 | use_spatial_transformer: false
29 | convonet_kwargs:
30 | unet: True
31 | unet_kwargs:
32 | depth: 4
33 | merge_mode: 'concat'
34 | start_filts: 32
35 | plane_resolution: 128
36 | plane_type: ['xz', 'xy', 'yz']
37 |
38 | # training settings
39 | warm: false # warm start from checkpoint
40 | warm_optimizer: true # load optimizer from checkpoint if available
41 | warm_scheduler: true # load scheduler from checkpoint if available
42 | freeze_stages: [] # [encoder, decoder, gnn]
43 | gpu_ids: [5, 6, 7, 8]
44 | gpu_freeze: false
45 | weight_decay: 1e-6
46 | num_epochs: 50
47 | save_interval: 1
48 | dropout: false
49 | validate: true
50 | seed: 1117
51 | batch_size: 64
52 | num_workers: 32
53 | loss: bce # {bce, focal}
54 | lr: 1e-3
55 | scheduler:
56 | base_lr: 1e-4
57 | max_lr: ${lr}
58 | step_size_up: 4400
59 | mode: triangular2
60 |
61 | # ddp settings
62 | master_addr: localhost
63 | master_port: 12345
64 |
65 | # dataset settings (shall be overwritten)
66 | shuffle: true
67 | class_weights: [1., 1.]
68 | sample:
69 | strategy: random # {null, fps, random, grid}
70 | transform: true # on-the-fly sampling (may introduce randomness)
71 | pre_transform: false # sampling as pre-transform
72 | duplicate: true # effective only for random sampling
73 | length: 4096 # effective only for random sampling
74 | resolutions: [0.05, 0.01, 0.005] # effective only for grid sampling and progressive training
75 | resolution: 0.01 # one of the previous
76 | ratio: 0.3 # effective only for fps sampling
77 |
78 | # evaluation settings
79 | evaluate:
80 | score: true # compute accuracy score in evaluation
81 | save: false # save prediction as numpy file
82 | seal: true # seal non-watertight model with bounding volume
83 | num_samples: 10000
84 |
85 | # reconstruction settings
86 | reconstruct:
87 | type: mesh # {cell, mesh}
88 | scale: true
89 | seal: false # seal with bounding volume
90 | translate: true
91 | offset: [0, 0, 0] # [-653200, -5478800, 0]
92 |
93 | # hydra settings
94 | hydra:
95 | run:
96 | dir: ./outputs
97 | verbose: false
98 |
99 | # wandb settings
100 | wandb: true
101 | wandb_dir: './outputs'
102 |
--------------------------------------------------------------------------------
/conf/dataset/mini.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # dataset settings
4 | dataset: 'mini'
5 | num_queries: 16
6 | dataset_suffix: ''
7 | url_dataset: {'default': 'https://drive.google.com/uc?export=download&id=1Yzlw4QGnbhybPbhVrMZ27Fr8qhA51mWZ'}
8 | url_checkpoint: {'default': 'https://drive.usercontent.google.com/download?id=1IivVlSGmTUBa8TDnmDKV24VaoWZBGT1M&export=download&authuser=0&confirm=t'}
--------------------------------------------------------------------------------
/conf/dataset/munich.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # dataset settings
4 | dataset: 'munich'
5 | num_queries: 16
6 | dataset_suffix: ''
7 | url_dataset: {'default': null}
8 | url_checkpoint: {'default': 'https://drive.usercontent.google.com/download?id=1IivVlSGmTUBa8TDnmDKV24VaoWZBGT1M&export=download&authuser=0&confirm=t'}
--------------------------------------------------------------------------------
/conf/dataset/nuremberg.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # dataset settings
4 | dataset: 'nuremberg'
5 | num_queries: 16
6 | dataset_suffix: ''
7 | reference_dir: '${data_dir}/raw/03_meshes_global'
8 | reconstruct:
9 | offset: [-653200, -5478800, 0]
--------------------------------------------------------------------------------
/conf/hydra/job_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | # @package hydra.job_logging
2 | # python logging configuration for tasks
3 | version: 1
4 | formatters:
5 | simple:
6 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
7 | colorlog:
8 | '()': 'colorlog.ColoredFormatter'
9 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10 | log_colors:
11 | DEBUG: purple
12 | INFO: green
13 | WARNING: yellow
14 | ERROR: red
15 | CRITICAL: red
16 | handlers:
17 | console:
18 | class: logging.StreamHandler
19 | formatter: colorlog
20 | stream: ext://sys.stdout
21 | file:
22 | class: logging.FileHandler
23 | formatter: simple
24 | # relative to the job log directory
25 | filename: ${hydra.run.dir}/${hydra.job.name}.log
26 | root:
27 | level: INFO
28 | handlers: [console, file]
29 |
30 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenzhaiyu/polygnn/6acf63fce040368424cf2a01c978dd2fcad1d831/data/.gitkeep
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Definitions of polyhedral graph data structure and datasets.
3 | """
4 |
5 | import os
6 | import logging
7 | import glob
8 | import collections
9 | import multiprocessing
10 | from pathlib import Path
11 |
12 | from tqdm import tqdm
13 | import numpy as np
14 | import torch
15 | from torch_geometric.data import Data, Dataset
16 | from abspy import VertexGroup, VertexGroupReference, CellComplex
17 |
18 | from utils import edge_index_from_dict, index_to_mask
19 |
20 | logger = logging.getLogger('dataset')
21 |
22 |
23 | class PolyGraph:
24 | """
25 | Cell-based graph data structure.
26 | """
27 |
28 | def __init__(self, use_reference=False, num_queries=None):
29 | self.vertex_group = None
30 | self.cell_complex = None
31 | self.vertex_group_reference = None
32 | self.use_reference = use_reference
33 | self.num_queries = num_queries
34 |
35 | def cell_adjacency(self):
36 | """
37 | Create adjacency among cells.
38 | """
39 |         # mapping gapped adjacency indices to contiguous ones
40 | adj = self.cell_complex.graph.adj
41 | uid = list(self.cell_complex.graph.nodes)
42 | mapping = {c: i for i, c in enumerate(uid)}
43 | adj_ = collections.defaultdict(set)
44 | for key in adj:
45 | adj_[mapping[key]] = {mapping[value] for value in adj[key]}
46 |
47 | # graph edge index in COO format
48 | return edge_index_from_dict(adj_)
49 |
50 | def cell_labels(self, mesh_path):
51 | """
52 |         Binary labels of cells: 1 for cells inside the reference mesh, 0 otherwise.
53 | """
54 | labels = np.zeros(self.cell_complex.num_cells).astype(np.int64)
55 |
56 | # cells inside reference mesh
57 | cells_in_mesh = self.cell_complex.cells_in_mesh(mesh_path, engine='distance')
58 |
59 | for cell in cells_in_mesh:
60 | labels[cell] = 1
61 |
62 | return torch.tensor(labels)
63 |
64 | def data_loader(self, cloud_path, mesh_path=None, complex_path=None, vertex_group_path=None):
65 | """
66 |         Load the bvg and obj files in a network-readable format.
67 | """
68 | if complex_path is not None and os.path.exists(complex_path):
69 | # load existing cell complex
70 | import pickle
71 | with open(complex_path, 'rb') as handle:
72 | self.cell_complex = pickle.load(handle)
73 | else:
74 | # construct cell complex
75 | if not self.use_reference:
76 | # load point cloud as vertex group
77 | if vertex_group_path:
78 | self.vertex_group = VertexGroup(vertex_group_path, quiet=True)
79 | # initialise cell complex from planar primitives
80 | self.cell_complex = CellComplex(self.vertex_group.planes, self.vertex_group.aabbs,
81 | self.vertex_group.points_grouped, build_graph=True, quiet=True)
82 | else:
83 | # cannot process vertex group from points alone
84 | raise NotImplementedError
85 | else:
86 | # load mesh as vertex group reference
87 | self.vertex_group_reference = VertexGroupReference(mesh_path, quiet=True)
88 | # initialise cell complex from planar primitives
89 | self.cell_complex = CellComplex(np.array(self.vertex_group_reference.planes),
90 | np.array(self.vertex_group_reference.aabbs),
91 | build_graph=True, quiet=True)
92 |
93 | # prioritise certain planes (e.g., vertical ones)
94 | self.cell_complex.prioritise_planes(prioritise_verticals=True)
95 |
96 | try:
97 | # construct cell complex
98 | self.cell_complex.construct()
99 | except (AssertionError, IndexError) as e:
100 | logger.error(f'Error [{e}] occurred with {cloud_path}.')
101 | return
102 |
103 | # save cell complex to CC files
104 | if complex_path is not None:
105 | Path(complex_path).parent.mkdir(exist_ok=True)
106 | self.cell_complex.save(complex_path)
107 |
108 | # points
109 | if cloud_path is not None:
110 | # npy and vg may contain different point sets
111 | points = np.load(cloud_path)
112 | else:
113 | points = self.vertex_group.points
114 |
115 | # queries
116 | queries = np.array(self.cell_complex.cell_representatives(location='skeleton', num=self.num_queries))
117 |
118 | # cell adjacency
119 | adjacency = self.cell_adjacency()
120 |
121 | # cell ground truth labels
122 | if mesh_path:
123 | labels = self.cell_labels(mesh_path)
124 | else:
125 | labels = None
126 |
127 | # construct data for pytorch geometric
128 | data = Data(x=None, edge_index=adjacency, y=labels)
129 |
130 | # store sizes
131 | len_cells = queries.shape[0]
132 | len_points = len(points)
133 | data.num_nodes = len_cells
134 | data.num_points = len_points
135 |
136 | # store points and queries
137 | data.points = torch.as_tensor(points, dtype=torch.float)
138 | data.queries = torch.as_tensor(queries, dtype=torch.float)
139 |
140 | # batch indices of points
141 | data.batch_points = torch.zeros(len_points, dtype=torch.long)
142 |
143 | # specify masks
144 | data.train_mask = index_to_mask(range(len_cells), size=len_cells)
145 | data.val_mask = index_to_mask(range(len_cells), size=len_cells)
146 | data.test_mask = index_to_mask(range(len_cells), size=len_cells)
147 |
148 | # name for reference
149 | data.name = Path(cloud_path).stem
150 |
151 | # validate data
152 | data.validate(raise_on_error=True)
153 |
154 | return data
155 |
156 |
157 | class CityDataset(Dataset):
158 | """
159 | Base building dataset. Applies to Munich and Nuremberg.
160 | """
161 |
162 | def __init__(self, root, name=None, split=None, num_workers=1, num_queries=16, **kwargs):
163 | self.name = name
164 | self.split = split
165 | self.num_workers = num_workers
166 | self.cloud_suffix = '.npy'
167 | self.mesh_suffix = '.obj'
168 | self.complex_suffix = '.cc'
169 | self.num_queries = num_queries
170 | super().__init__(root, **kwargs) # this line calls download() and process()
171 |
172 | @property
173 | def raw_dir(self) -> str:
174 | return os.path.join(self.root, 'raw')
175 |
176 | @property
177 | def raw_file_names(self):
178 | with open(os.path.join(self.raw_dir, f'{self.split}set.txt'), 'r') as f:
179 | return f.read().splitlines()
180 |
181 | @property
182 | def processed_file_names(self):
183 | return [f'data_{self.split}_{i}.pt' for i in range(len(self.raw_file_names))]
184 |
185 | def download(self):
186 | pass
187 |
188 | def thread(self, kwargs):
189 | """
190 | Process one file.
191 | """
192 | path_save = os.path.join(self.processed_dir, f'data_{kwargs["split"]}_{kwargs["index"]}.pt')
193 | if os.path.exists(path_save):
194 | return
195 | logger.debug(f'processing {Path(kwargs["cloud"]).stem}')
196 | try:
197 | data = PolyGraph(use_reference=True, num_queries=self.num_queries).data_loader(kwargs['cloud'],
198 | kwargs['mesh'],
199 | kwargs['complex'])
200 | except (ValueError, IndexError, EOFError) as e:
201 | logger.error(f'error with file {kwargs["mesh"]}: {e}')
202 | return
203 | if self.pre_transform is not None:
204 | data = self.pre_transform(data)
205 | if data is not None:
206 | torch.save(data, path_save)
207 |
208 | def process(self):
209 | """
210 | Start multiprocessing.
211 | """
212 | with open(os.path.join(self.raw_dir, 'trainset.txt'), 'r') as f_train:
213 | filenames_train = f_train.read().splitlines()
214 | with open(os.path.join(self.raw_dir, 'testset.txt'), 'r') as f_test:
215 | filenames_test = f_test.read().splitlines()
216 |
217 | args = []
218 | for i, filename_train in enumerate(filenames_train):
219 | cloud_train = os.path.join(self.raw_dir, '04_pts', filename_train + self.cloud_suffix)
220 | mesh_train = os.path.join(self.raw_dir, '03_meshes', filename_train + self.mesh_suffix)
221 | complex_train = os.path.join(self.raw_dir, '05_complexes', filename_train + self.complex_suffix)
222 | args.append(
223 | {'index': i, 'split': 'train', 'cloud': cloud_train, 'mesh': mesh_train, 'complex': complex_train})
224 |
225 | for j, filename_test in enumerate(filenames_test):
226 | cloud_test = os.path.join(self.raw_dir, '04_pts', filename_test + self.cloud_suffix)
227 | mesh_test = os.path.join(self.raw_dir, '03_meshes', filename_test + self.mesh_suffix)
228 | complex_test = os.path.join(self.raw_dir, '05_complexes', filename_test + self.complex_suffix)
229 | args.append(
230 | {'index': j, 'split': 'test', 'cloud': cloud_test, 'mesh': mesh_test, 'complex': complex_test})
231 |
232 | with multiprocessing.Pool(
233 | processes=self.num_workers if self.num_workers else multiprocessing.cpu_count()) as pool:
234 | # call with multiprocessing
235 | for _ in tqdm(pool.imap_unordered(self.thread, args), desc='Preparing dataset', total=len(args)):
236 | pass
237 |
238 | def len(self):
239 | return len(self.processed_file_names)
240 |
241 | def get(self, idx):
242 | data = torch.load(os.path.join(self.processed_dir, f'data_{self.split}_{idx}.pt'))
243 | # to disable UserWarning: Unable to accurately infer 'num_nodes' from the attribute set
244 | # '{'points', 'train_mask', 'val_mask', 'queries', 'test_mask', 'edge_index', 'y'}'
245 | data.num_nodes = len(data.y)
246 | return data
247 |
248 |
249 | class HelsinkiDataset(CityDataset):
250 | """
251 | Helsinki dataset.
252 | """
253 |
254 | def __init__(self, **kwargs):
255 | super().__init__(**kwargs)
256 |
257 | @property
258 | def processed_file_names(self):
259 | """
260 |         Modified processed filenames to handle discontinuous file indices.
261 | """
262 | return [os.path.basename(filename) for filename in
263 | glob.glob(os.path.join(self.processed_dir, f'data_{self.split}_*.pt'))]
264 |
265 | def get(self, idx):
266 | """
267 |         Modified data retrieval to handle discontinuous file indices.
268 | """
269 | data = torch.load(self.processed_paths[idx])
270 | # to disable UserWarning: Unable to accurately infer 'num_nodes' from the attribute set
271 | # '{'points', 'train_mask', 'val_mask', 'queries', 'test_mask', 'edge_index', 'y'}'
272 | data.num_nodes = len(data.y)
273 | return data
274 |
275 |
276 | class TestOnlyDataset(CityDataset):
277 | """
278 | Test-only dataset.
279 | """
280 |
281 | def __init__(self, **kwargs):
282 | super().__init__(**kwargs)
283 |
284 | def process(self):
285 | """
286 | Start multiprocessing.
287 | """
288 | with open(os.path.join(self.raw_dir, 'testset.txt'), 'r') as f_test:
289 | filenames_test = f_test.read().splitlines()
290 |
291 | args = []
292 |
293 | for j, filename_test in enumerate(filenames_test):
294 | cloud_test = os.path.join(self.raw_dir, '04_pts', filename_test + self.cloud_suffix)
295 | mesh_test = os.path.join(self.raw_dir, '03_meshes', filename_test + self.mesh_suffix)
296 | complex_test = os.path.join(self.raw_dir, '05_complexes', filename_test + self.complex_suffix)
297 | args.append(
298 | {'index': j, 'split': 'test', 'cloud': cloud_test, 'mesh': mesh_test, 'complex': complex_test})
299 |
300 | with multiprocessing.Pool(
301 | processes=self.num_workers if self.num_workers else multiprocessing.cpu_count()) as pool:
302 | # call with multiprocessing
303 | for _ in tqdm(pool.imap(self.thread, args), desc='Preparing dataset', total=len(args)):
304 | pass
305 |
--------------------------------------------------------------------------------
/docs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenzhaiyu/polygnn/6acf63fce040368424cf2a01c978dd2fcad1d831/docs/architecture.png
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | """
2 | Download datasets and/or models from public urls.
3 | """
4 |
5 | import os
6 | import tarfile
7 | import urllib.request
8 |
9 | from tqdm import tqdm
10 | import hydra
11 | from omegaconf import DictConfig
12 |
13 |
14 | def my_hook(t):
15 | """
16 | Wraps tqdm instance.
17 | Don't forget to close() or __exit__()
18 | the tqdm instance once you're done with it (easiest using `with` syntax).
19 | https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
20 |
21 | Example
22 | -------
23 | with tqdm(...) as t:
24 | ... reporthook = my_hook(t)
25 |     ... urllib.request.urlretrieve(..., reporthook=reporthook)
26 | """
27 | last_b = [0]
28 |
29 | def update_to(b=1, bsize=1, tsize=None):
30 | """
31 | b : int, optional
32 | Number of blocks transferred so far [default: 1].
33 | bsize : int, optional
34 | Size of each block (in tqdm units) [default: 1].
35 | tsize : int, optional
36 | Total size (in tqdm units). If [default: None] remains unchanged.
37 | """
38 | if tsize is not None:
39 | t.total = tsize
40 | t.update((b - last_b[0]) * bsize)
41 | last_b[0] = b
42 |
43 | return update_to
44 |
45 |
46 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
47 | def download(cfg: DictConfig):
48 | """
49 | Download datasets and/or models from public urls.
50 |
51 | Parameters
52 | ----------
53 | cfg: DictConfig
54 | Hydra configuration
55 | """
56 | data_url, checkpoint_url = cfg.url_dataset['default'], cfg.url_checkpoint['default']
57 | data_dir, checkpoint_dir = cfg.data_dir, cfg.checkpoint_dir
58 | data_file, checkpoint_file = (os.path.join(data_dir, f'{cfg.dataset}.tar.gz'), f'{cfg.checkpoint_path}')
59 |
60 | os.makedirs(data_dir, exist_ok=True)
61 | os.makedirs(checkpoint_dir, exist_ok=True)
62 |
63 | if data_url is not None:
64 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=data_file) as t:
65 | urllib.request.urlretrieve(data_url, filename=data_file, reporthook=my_hook(t), data=None)
66 | with tarfile.open(data_file, 'r:gz') as tar:
67 | tar.extractall(data_dir)
68 | os.remove(data_file)
69 |
70 | if checkpoint_url is not None:
71 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=checkpoint_file) as t:
72 | urllib.request.urlretrieve(checkpoint_url, filename=checkpoint_file, reporthook=my_hook(t), data=None)
73 |
74 |
75 | if __name__ == '__main__':
76 | download()
77 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: polygnn
2 | channels:
3 | - pyg
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - defaults::python=3.10
10 | - pytorch
11 | - torchvision
12 | - sage=10.0
13 | - pytorch-cuda=11.7
14 | - pyg=2.3
15 | - pytorch-scatter
16 | - pytorch-sparse
17 | - pytorch-cluster
18 | - torchmetrics
19 | - rtree
20 | - pip
21 | - pip:
22 | - abspy==0.2.6
23 | - hydra-core
24 | - hydra-colorlog
25 | - omegaconf
26 | - trimesh
27 | - tqdm
28 | - wandb
29 | - plyfile
30 |
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import Encoder
2 | from .decoder import Decoder
3 | from .gnn import GNN
4 | from .polygnn import PolyGNN
5 | from .loss import bce_loss, focal_loss
6 |
--------------------------------------------------------------------------------
/network/convonet/ResnetBlockFC.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | # Resnet Blocks
5 | class ResnetBlockFC(nn.Module):
6 | ''' Fully connected ResNet Block class.
7 |
8 | Args:
9 | size_in (int): input dimension
10 | size_out (int): output dimension
11 | size_h (int): hidden dimension
12 | '''
13 |
14 | def __init__(self, size_in, size_out=None, size_h=None):
15 | super().__init__()
16 | # Attributes
17 | if size_out is None:
18 | size_out = size_in
19 |
20 | if size_h is None:
21 | size_h = min(size_in, size_out)
22 |
23 | self.size_in = size_in
24 | self.size_h = size_h
25 | self.size_out = size_out
26 | # Submodules
27 | self.fc_0 = nn.Linear(size_in, size_h)
28 | self.fc_1 = nn.Linear(size_h, size_out)
29 | self.actvn = nn.ReLU()
30 |
31 | if size_in == size_out:
32 | self.shortcut = None
33 | else:
34 | self.shortcut = nn.Linear(size_in, size_out, bias=False)
35 | # Initialization
36 | nn.init.zeros_(self.fc_1.weight)
37 |
38 | def forward(self, x):
39 | net = self.fc_0(self.actvn(x))
40 | dx = self.fc_1(self.actvn(net))
41 |
42 | if self.shortcut is not None:
43 | x_s = self.shortcut(x)
44 | else:
45 | x_s = x
46 |
47 | return x_s + dx
48 |
--------------------------------------------------------------------------------
/network/convonet/__init__.py:
--------------------------------------------------------------------------------
1 | from network.convonet import pointnet
2 |
3 |
4 | encoder_dict = {
5 | 'pointnet_local_pool': pointnet.LocalPoolPointnet,
6 | }
7 |
--------------------------------------------------------------------------------
/network/convonet/pointnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_scatter import scatter_mean, scatter_max
4 |
5 | from utils import coordinate2index, normalize_coordinate, normalize_3d_coordinate, map2local
6 | from network.convonet.ResnetBlockFC import ResnetBlockFC
7 | from network.convonet.unet import UNet
8 | from network.convonet.unet3d import UNet3D
9 |
10 |
11 | class LocalPoolPointnet(nn.Module):
12 | """ PointNet-based encoder network with ResNet blocks for each point.
13 |     The number of input points is fixed.
14 |
15 | Args:
16 | c_dim (int): dimension of latent code c
17 | dim (int): input points dimension
18 | hidden_dim (int): hidden dimension of the network
19 | scatter_type (str): feature aggregation when doing local pooling
20 |         unet (bool): whether to use U-Net
21 |         unet_kwargs (dict): U-Net parameters
22 |         unet3d (bool): whether to use 3D U-Net
23 |         unet3d_kwargs (dict): 3D U-Net parameters
24 | plane_resolution (int): defined resolution for plane feature
25 | grid_resolution (int): defined resolution for grid feature
26 | plane_type (str or list): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
27 |         padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
28 |         n_blocks (int): number of ResnetBlockFC blocks
29 | """
30 |
31 | def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
32 | unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
33 | plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5):
34 | super().__init__()
35 | self.c_dim = c_dim
36 |
37 | self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
38 | self.blocks = nn.ModuleList([
39 | ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
40 | ])
41 | self.fc_c = nn.Linear(hidden_dim, c_dim)
42 |
43 | self.actvn = nn.ReLU()
44 | self.hidden_dim = hidden_dim
45 |
46 | if unet:
47 | self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
48 | else:
49 | self.unet = None
50 |
51 | if unet3d:
52 | self.unet3d = UNet3D(**unet3d_kwargs)
53 | else:
54 | self.unet3d = None
55 |
56 | self.reso_plane = plane_resolution
57 | self.reso_grid = grid_resolution
58 | self.plane_type = plane_type
59 | self.padding = padding
60 |
61 | if scatter_type == 'max':
62 | self.scatter = scatter_max
63 | elif scatter_type == 'mean':
64 | self.scatter = scatter_mean
65 | else:
66 | raise ValueError('incorrect scatter type')
67 |
68 | def generate_plane_features(self, p, c, plane='xz'):
69 | # acquire indices of features in plane
70 | xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
71 | index = coordinate2index(xy, self.reso_plane)
72 |
73 | # scatter plane features from points
74 | fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2)
75 | c = c.permute(0, 2, 1) # B x 512 x T
76 | fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
77 | fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane,
78 |                                       self.reso_plane)  # sparse matrix (B x 512 x reso x reso)
79 |
80 | # process the plane features with UNet
81 | if self.unet is not None:
82 | fea_plane = self.unet(fea_plane)
83 |
84 | return fea_plane
85 |
86 | def generate_grid_features(self, p, c):
87 | p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding)
88 | index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
89 | # scatter grid features from points
90 | fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid ** 3)
91 | c = c.permute(0, 2, 1)
92 | fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
93 | fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid,
94 |                                     self.reso_grid)  # sparse matrix (B x 512 x reso x reso x reso)
95 |
96 | if self.unet3d is not None:
97 | fea_grid = self.unet3d(fea_grid)
98 |
99 | return fea_grid
100 |
101 | def pool_local(self, xy, index, c):
102 | bs, fea_dim = c.size(0), c.size(2)
103 | keys = xy.keys()
104 |
105 | c_out = 0
106 | for key in keys:
107 | # scatter plane features from points
108 | if key == 'grid':
109 | fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid ** 3)
110 | else:
111 | fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2)
112 | if self.scatter == scatter_max:
113 | fea = fea[0]
114 | # gather feature back to points
115 | fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
116 | c_out += fea
117 | return c_out.permute(0, 2, 1)
118 |
119 | def forward(self, p):
120 | batch_size, T, D = p.size()
121 |
122 | # acquire the index for each point
123 | coord = {}
124 | index = {}
125 | if 'xz' in self.plane_type:
126 | coord['xz'] = normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
127 | index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
128 | if 'xy' in self.plane_type:
129 | coord['xy'] = normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
130 | index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
131 | if 'yz' in self.plane_type:
132 | coord['yz'] = normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
133 | index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
134 | if 'grid' in self.plane_type:
135 | coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding)
136 | index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
137 |
138 | net = self.fc_pos(p)
139 |
140 | net = self.blocks[0](net)
141 | for block in self.blocks[1:]:
142 | pooled = self.pool_local(coord, index, net)
143 | net = torch.cat([net, pooled], dim=2)
144 | net = block(net)
145 |
146 | c = self.fc_c(net)
147 |
148 | fea = {}
149 | if 'grid' in self.plane_type:
150 | fea['grid'] = self.generate_grid_features(p, c)
151 | if 'xz' in self.plane_type:
152 | fea['xz'] = self.generate_plane_features(p, c, plane='xz')
153 | if 'xy' in self.plane_type:
154 | fea['xy'] = self.generate_plane_features(p, c, plane='xy')
155 | if 'yz' in self.plane_type:
156 | fea['yz'] = self.generate_plane_features(p, c, plane='yz')
157 |
158 | return fea
159 |
160 |
161 | class PatchLocalPoolPointnet(nn.Module):
162 | """
163 | PointNet-based encoder network with ResNet blocks.
164 |     First transforms input points to a local system based on the given voxel size.
165 |     Supports a non-fixed number of points, but the index needs to be precomputed.
166 |
167 | Args:
168 | c_dim (int): dimension of latent code c
169 | dim (int): input points dimension
170 | hidden_dim (int): hidden dimension of the network
171 | scatter_type (str): feature aggregation when doing local pooling
172 |         unet (bool): whether to use U-Net
173 |         unet_kwargs (dict): U-Net parameters
174 |         unet3d (bool): whether to use 3D U-Net
175 |         unet3d_kwargs (dict): 3D U-Net parameters
176 | plane_resolution (int): defined resolution for plane feature
177 | grid_resolution (int): defined resolution for grid feature
178 |         plane_type (str or list): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
179 |         padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
180 |         n_blocks (int): number of ResnetBlockFC blocks
181 | local_coord (bool): whether to use local coordinate
182 | pos_encoding (str): method for the positional encoding, linear|sin_cos
183 | unit_size (float): defined voxel unit size for local system
184 | """
185 |
186 | def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
187 | unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
188 | plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
189 | local_coord=False, pos_encoding='linear', unit_size=0.1):
190 | super().__init__()
191 | self.c_dim = c_dim
192 |
193 | self.blocks = nn.ModuleList([
194 | ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
195 | ])
196 | self.fc_c = nn.Linear(hidden_dim, c_dim)
197 |
198 | self.actvn = nn.ReLU()
199 | self.hidden_dim = hidden_dim
200 | self.reso_plane = plane_resolution
201 | self.reso_grid = grid_resolution
202 | self.plane_type = plane_type
203 | self.padding = padding
204 |
205 | if unet:
206 | self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
207 | else:
208 | self.unet = None
209 |
210 | if unet3d:
211 | self.unet3d = UNet3D(**unet3d_kwargs)
212 | else:
213 | self.unet3d = None
214 |
215 | if scatter_type == 'max':
216 | self.scatter = scatter_max
217 | elif scatter_type == 'mean':
218 | self.scatter = scatter_mean
219 | else:
220 | raise ValueError('incorrect scatter type')
221 |
222 | if local_coord:
223 | self.map2local = map2local(unit_size, pos_encoding=pos_encoding)
224 | else:
225 | self.map2local = None
226 |
227 | if pos_encoding == 'sin_cos':
228 | self.fc_pos = nn.Linear(60, 2 * hidden_dim)
229 | else:
230 | self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
231 |
232 | def generate_plane_features(self, index, c):
233 | c = c.permute(0, 2, 1)
234 | # scatter plane features from points
235 | if index.max() < self.reso_plane ** 2:
236 | fea_plane = c.new_zeros(c.size(0), self.c_dim, self.reso_plane ** 2)
237 | fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2
238 | else:
239 | fea_plane = scatter_mean(c, index) # B x c_dim x reso^2
240 | if fea_plane.shape[-1] > self.reso_plane ** 2: # deal with outliers
241 | fea_plane = fea_plane[:, :, :-1]
242 |
243 | fea_plane = fea_plane.reshape(c.size(0), self.c_dim, self.reso_plane, self.reso_plane)
244 |
245 | # process the plane features with UNet
246 | if self.unet is not None:
247 | fea_plane = self.unet(fea_plane)
248 |
249 | return fea_plane
250 |
251 | def generate_grid_features(self, index, c):
252 | # scatter grid features from points
253 | c = c.permute(0, 2, 1)
254 | if index.max() < self.reso_grid ** 3:
255 | fea_grid = c.new_zeros(c.size(0), self.c_dim, self.reso_grid ** 3)
256 | fea_grid = scatter_mean(c, index, out=fea_grid) # B x c_dim x reso^3
257 | else:
258 | fea_grid = scatter_mean(c, index) # B x c_dim x reso^3
259 | if fea_grid.shape[-1] > self.reso_grid ** 3: # deal with outliers
260 | fea_grid = fea_grid[:, :, :-1]
261 | fea_grid = fea_grid.reshape(c.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid)
262 |
263 | if self.unet3d is not None:
264 | fea_grid = self.unet3d(fea_grid)
265 |
266 | return fea_grid
267 |
268 | def pool_local(self, index, c):
269 | bs, fea_dim = c.size(0), c.size(2)
270 | keys = index.keys()
271 |
272 | c_out = 0
273 | for key in keys:
274 | # scatter plane features from points
275 | if key == 'grid':
276 | fea = self.scatter(c.permute(0, 2, 1), index[key])
277 | else:
278 | fea = self.scatter(c.permute(0, 2, 1), index[key])
279 | if self.scatter == scatter_max:
280 | fea = fea[0]
281 |
282 | # gather feature back to points
283 | fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
284 | c_out += fea
285 | return c_out.permute(0, 2, 1)
286 |
287 | def forward(self, inputs):
288 | p = inputs['points']
289 | index = inputs['index']
290 |
291 | if self.map2local:
292 | pp = self.map2local(p)
293 | net = self.fc_pos(pp)
294 | else:
295 | net = self.fc_pos(p)
296 |
297 | net = self.blocks[0](net)
298 | for block in self.blocks[1:]:
299 | pooled = self.pool_local(index, net)
300 | net = torch.cat([net, pooled], dim=2)
301 | net = block(net)
302 |
303 | c = self.fc_c(net)
304 |
305 | fea = {}
306 | if 'grid' in self.plane_type:
307 | fea['grid'] = self.generate_grid_features(index['grid'], c)
308 | if 'xz' in self.plane_type:
309 | fea['xz'] = self.generate_plane_features(index['xz'], c)
310 | if 'xy' in self.plane_type:
311 | fea['xy'] = self.generate_plane_features(index['xy'], c)
312 | if 'yz' in self.plane_type:
313 | fea['yz'] = self.generate_plane_features(index['yz'], c)
314 |
315 | return fea
316 |
--------------------------------------------------------------------------------
/network/convonet/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.nn import init
6 |
7 |
8 | def conv3x3(in_channels, out_channels, stride=1,
9 | padding=1, bias=True, groups=1):
10 | return nn.Conv2d(
11 | in_channels,
12 | out_channels,
13 | kernel_size=3,
14 | stride=stride,
15 | padding=padding,
16 | bias=bias,
17 | groups=groups)
18 |
19 |
20 | def upconv2x2(in_channels, out_channels, mode='transpose'):
21 | if mode == 'transpose':
22 | return nn.ConvTranspose2d(
23 | in_channels,
24 | out_channels,
25 | kernel_size=2,
26 | stride=2)
27 | else:
28 | # out_channels is always going to be the same
29 | # as in_channels
30 | return nn.Sequential(
31 | nn.Upsample(mode='bilinear', scale_factor=2),
32 | conv1x1(in_channels, out_channels))
33 |
34 |
35 | def conv1x1(in_channels, out_channels, groups=1):
36 | return nn.Conv2d(
37 | in_channels,
38 | out_channels,
39 | kernel_size=1,
40 | groups=groups,
41 | stride=1)
42 |
43 |
44 | class DownConv(nn.Module):
45 | """
46 | A helper Module that performs 2 convolutions and 1 MaxPool.
47 | A ReLU activation follows each convolution.
48 | """
49 | def __init__(self, in_channels, out_channels, pooling=True):
50 | super(DownConv, self).__init__()
51 |
52 | self.in_channels = in_channels
53 | self.out_channels = out_channels
54 | self.pooling = pooling
55 |
56 | self.conv1 = conv3x3(self.in_channels, self.out_channels)
57 | self.conv2 = conv3x3(self.out_channels, self.out_channels)
58 |
59 | if self.pooling:
60 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
61 |
62 | def forward(self, x):
63 | x = F.relu(self.conv1(x))
64 | x = F.relu(self.conv2(x))
65 | before_pool = x
66 | if self.pooling:
67 | x = self.pool(x)
68 | return x, before_pool
69 |
70 |
71 | class UpConv(nn.Module):
72 | """
73 | A helper Module that performs 2 convolutions and 1 UpConvolution.
74 | A ReLU activation follows each convolution.
75 | """
76 | def __init__(self, in_channels, out_channels,
77 | merge_mode='concat', up_mode='transpose'):
78 | super(UpConv, self).__init__()
79 |
80 | self.in_channels = in_channels
81 | self.out_channels = out_channels
82 | self.merge_mode = merge_mode
83 | self.up_mode = up_mode
84 |
85 | self.upconv = upconv2x2(self.in_channels, self.out_channels,
86 | mode=self.up_mode)
87 |
88 | if self.merge_mode == 'concat':
89 | self.conv1 = conv3x3(
90 | 2*self.out_channels, self.out_channels)
91 | else:
92 |             # number of input channels to conv1 stays the same
93 | self.conv1 = conv3x3(self.out_channels, self.out_channels)
94 | self.conv2 = conv3x3(self.out_channels, self.out_channels)
95 |
96 |
97 | def forward(self, from_down, from_up):
98 | """ Forward pass
99 | Arguments:
100 | from_down: tensor from the encoder pathway
101 | from_up: upconv'd tensor from the decoder pathway
102 | """
103 | from_up = self.upconv(from_up)
104 | if self.merge_mode == 'concat':
105 | x = torch.cat((from_up, from_down), 1)
106 | else:
107 | x = from_up + from_down
108 | x = F.relu(self.conv1(x))
109 | x = F.relu(self.conv2(x))
110 | return x
111 |
112 |
113 | class UNet(nn.Module):
114 | """ `UNet` class is based on https://arxiv.org/abs/1505.04597
115 |
116 | The U-Net is a convolutional encoder-decoder neural network.
117 | Contextual spatial information (from the decoding,
118 | expansive pathway) about an input tensor is merged with
119 | information representing the localization of details
120 | (from the encoding, compressive pathway).
121 |
122 | Modifications to the original paper:
123 | (1) padding is used in 3x3 convolutions to prevent loss
124 | of border pixels
125 | (2) merging outputs does not require cropping due to (1)
126 | (3) residual connections can be used by specifying
127 | UNet(merge_mode='add')
128 | (4) if non-parametric upsampling is used in the decoder
129 |     pathway (specified by up_mode='upsample'), then an
130 | additional 1x1 2d convolution occurs after upsampling
131 | to reduce channel dimensionality by a factor of 2.
132 | This channel halving happens with the convolution in
133 |     the transpose convolution (specified by up_mode='transpose')
134 | """
135 |
136 | def __init__(self, num_classes, in_channels=3, depth=5,
137 | start_filts=64, up_mode='transpose',
138 | merge_mode='concat', **kwargs):
139 | """
140 | Arguments:
141 | in_channels: int, number of channels in the input tensor.
142 | Default is 3 for RGB images.
143 | depth: int, number of MaxPools in the U-Net.
144 | start_filts: int, number of convolutional filters for the
145 | first conv.
146 | up_mode: string, type of upconvolution. Choices: 'transpose'
147 | for transpose convolution or 'upsample' for nearest neighbour
148 | upsampling.
149 | """
150 | super(UNet, self).__init__()
151 |
152 | if up_mode in ('transpose', 'upsample'):
153 | self.up_mode = up_mode
154 | else:
155 | raise ValueError("\"{}\" is not a valid mode for "
156 | "upsampling. Only \"transpose\" and "
157 | "\"upsample\" are allowed.".format(up_mode))
158 |
159 | if merge_mode in ('concat', 'add'):
160 | self.merge_mode = merge_mode
161 | else:
162 |             raise ValueError("\"{}\" is not a valid mode for "
163 |                              "merging up and down paths. "
164 |                              "Only \"concat\" and "
165 |                              "\"add\" are allowed.".format(merge_mode))
166 |
167 | # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
168 | if self.up_mode == 'upsample' and self.merge_mode == 'add':
169 | raise ValueError("up_mode \"upsample\" is incompatible "
170 | "with merge_mode \"add\" at the moment "
171 | "because it doesn't make sense to use "
172 | "nearest neighbour to reduce "
173 | "depth channels (by half).")
174 |
175 | self.num_classes = num_classes
176 | self.in_channels = in_channels
177 | self.start_filts = start_filts
178 | self.depth = depth
179 |
180 | self.down_convs = []
181 | self.up_convs = []
182 |
183 | # create the encoder pathway and add to a list
184 | for i in range(depth):
185 | ins = self.in_channels if i == 0 else outs
186 | outs = self.start_filts*(2**i)
187 | pooling = True if i < depth-1 else False
188 |
189 | down_conv = DownConv(ins, outs, pooling=pooling)
190 | self.down_convs.append(down_conv)
191 |
192 | # create the decoder pathway and add to a list
193 | # - careful! decoding only requires depth-1 block
194 | for i in range(depth-1):
195 | ins = outs
196 | outs = ins // 2
197 | up_conv = UpConv(ins, outs, up_mode=up_mode,
198 | merge_mode=merge_mode)
199 | self.up_convs.append(up_conv)
200 |
201 | # add the list of modules to current module
202 | self.down_convs = nn.ModuleList(self.down_convs)
203 | self.up_convs = nn.ModuleList(self.up_convs)
204 | self.conv_final = conv1x1(outs, self.num_classes)
205 | self.reset_params()
206 |
207 | @staticmethod
208 | def weight_init(m):
209 | if isinstance(m, nn.Conv2d):
210 | init.xavier_normal_(m.weight)
211 | init.constant_(m.bias, 0)
212 |
213 | def reset_params(self):
214 | for i, m in enumerate(self.modules()):
215 | self.weight_init(m)
216 |
217 | def forward(self, x):
218 | encoder_outs = []
219 | # encoder pathway, save outputs for merging
220 | for i, module in enumerate(self.down_convs):
221 | x, before_pool = module(x)
222 | encoder_outs.append(before_pool)
223 | for i, module in enumerate(self.up_convs):
224 | before_pool = encoder_outs[-(i+2)]
225 | x = module(before_pool, x)
226 |
227 |         # No softmax is used, so you need to use
228 |         # nn.CrossEntropyLoss in your training script,
229 |         # as that loss already includes a softmax.
230 | x = self.conv_final(x)
231 | return x
232 |
--------------------------------------------------------------------------------
/network/convonet/unet3d.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from functools import partial
5 |
6 |
7 | def number_of_features_per_level(init_channel_number, num_levels):
8 | return [init_channel_number * 2 ** k for k in range(num_levels)]
9 |
10 |
11 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
12 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
13 |
14 |
15 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
16 | """
17 |     Create a list of modules which together constitute a single conv layer with non-linearity
18 | and optional batchnorm/groupnorm.
19 |
20 | Args:
21 | in_channels (int): number of input channels
22 | out_channels (int): number of output channels
23 | order (string): order of things, e.g.
24 | 'cr' -> conv + ReLU
25 | 'gcr' -> groupnorm + conv + ReLU
26 | 'cl' -> conv + LeakyReLU
27 | 'ce' -> conv + ELU
28 | 'bcr' -> batchnorm + conv + ReLU
29 | num_groups (int): number of groups for the GroupNorm
30 | padding (int): add zero-padding to the input
31 |
32 | Return:
33 | list of tuple (name, module)
34 | """
35 | assert 'c' in order, "Conv layer MUST be present"
36 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
37 |
38 | modules = []
39 | for i, char in enumerate(order):
40 | if char == 'r':
41 | modules.append(('ReLU', nn.ReLU(inplace=True)))
42 | elif char == 'l':
43 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
44 | elif char == 'e':
45 | modules.append(('ELU', nn.ELU(inplace=True)))
46 | elif char == 'c':
47 | # add learnable bias only in the absence of batchnorm/groupnorm
48 | bias = not ('g' in order or 'b' in order)
49 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
50 | elif char == 'g':
51 | is_before_conv = i < order.index('c')
52 | if is_before_conv:
53 | num_channels = in_channels
54 | else:
55 | num_channels = out_channels
56 |
57 | # use only one group if the given number of groups is greater than the number of channels
58 | if num_channels < num_groups:
59 | num_groups = 1
60 |
61 | assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
62 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
63 | elif char == 'b':
64 | is_before_conv = i < order.index('c')
65 | if is_before_conv:
66 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
67 | else:
68 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
69 | else:
70 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
71 |
72 | return modules
73 |
74 |
75 | class SingleConv(nn.Sequential):
76 | """
77 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
78 | of operations can be specified via the `order` parameter
79 |
80 | Args:
81 | in_channels (int): number of input channels
82 | out_channels (int): number of output channels
83 | kernel_size (int): size of the convolving kernel
84 | order (string): determines the order of layers, e.g.
85 | 'cr' -> conv + ReLU
86 | 'crg' -> conv + ReLU + groupnorm
87 | 'cl' -> conv + LeakyReLU
88 | 'ce' -> conv + ELU
89 | num_groups (int): number of groups for the GroupNorm
90 | """
91 |
92 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
93 | super(SingleConv, self).__init__()
94 |
95 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
96 | self.add_module(name, module)
97 |
98 |
99 | class DoubleConv(nn.Sequential):
100 | """
101 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
102 | We use (Conv3d+ReLU+GroupNorm3d) by default.
103 |     This can be changed, however, by providing the 'order' argument, e.g. in order
104 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
105 |     Use padded convolutions to make sure that the output (D_out, H_out, W_out) is the same
106 |     as (D_in, H_in, W_in), so that you don't have to crop in the decoder path.
107 |
108 | Args:
109 | in_channels (int): number of input channels
110 | out_channels (int): number of output channels
111 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
112 | kernel_size (int): size of the convolving kernel
113 | order (string): determines the order of layers, e.g.
114 | 'cr' -> conv + ReLU
115 | 'crg' -> conv + ReLU + groupnorm
116 | 'cl' -> conv + LeakyReLU
117 | 'ce' -> conv + ELU
118 | num_groups (int): number of groups for the GroupNorm
119 | """
120 |
121 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
122 | super(DoubleConv, self).__init__()
123 | if encoder:
124 | # we're in the encoder path
125 | conv1_in_channels = in_channels
126 | conv1_out_channels = out_channels // 2
127 | if conv1_out_channels < in_channels:
128 | conv1_out_channels = in_channels
129 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
130 | else:
131 | # we're in the decoder path, decrease the number of channels in the 1st convolution
132 | conv1_in_channels, conv1_out_channels = in_channels, out_channels
133 | conv2_in_channels, conv2_out_channels = out_channels, out_channels
134 |
135 | # conv1
136 | self.add_module('SingleConv1',
137 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
138 | # conv2
139 | self.add_module('SingleConv2',
140 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
141 |
142 |
143 | class ExtResNetBlock(nn.Module):
144 | """
145 | Basic UNet block consisting of a SingleConv followed by the residual block.
146 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
147 | of output channels is compatible with the residual block that follows.
148 | This block can be used instead of standard DoubleConv in the Encoder module.
149 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf
150 |
151 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
152 | """
153 |
154 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
155 | super(ExtResNetBlock, self).__init__()
156 |
157 | # first convolution
158 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
159 | # residual block
160 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
161 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
162 | n_order = order
163 | for c in 'rel':
164 | n_order = n_order.replace(c, '')
165 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
166 | num_groups=num_groups)
167 |
168 | # create non-linearity separately
169 | if 'l' in order:
170 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
171 | elif 'e' in order:
172 | self.non_linearity = nn.ELU(inplace=True)
173 | else:
174 | self.non_linearity = nn.ReLU(inplace=True)
175 |
176 | def forward(self, x):
177 | # apply first convolution and save the output as a residual
178 | out = self.conv1(x)
179 | residual = out
180 |
181 | # residual block
182 | out = self.conv2(out)
183 | out = self.conv3(out)
184 |
185 | out += residual
186 | out = self.non_linearity(out)
187 |
188 | return out
189 |
190 |
191 | class Encoder(nn.Module):
192 | """
193 | A single module from the encoder path consisting of the optional max
194 | pooling layer (one may specify the MaxPool kernel_size to be different
195 |     from the standard (2,2,2), e.g. if the volumetric data is anisotropic;
196 |     make sure to use a complementary scale_factor in the decoder path), followed by
197 | a DoubleConv module.
198 | Args:
199 | in_channels (int): number of input channels
200 | out_channels (int): number of output channels
201 | conv_kernel_size (int): size of the convolving kernel
202 | apply_pooling (bool): if True use MaxPool3d before DoubleConv
203 | pool_kernel_size (tuple): the size of the window to take a max over
204 | pool_type (str): pooling layer: 'max' or 'avg'
205 | basic_module(nn.Module): either ResNetBlock or DoubleConv
206 | conv_layer_order (string): determines the order of layers
207 | in `DoubleConv` module. See `DoubleConv` for more info.
208 | num_groups (int): number of groups for the GroupNorm
209 | """
210 |
211 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
212 | pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
213 | num_groups=8):
214 | super(Encoder, self).__init__()
215 | assert pool_type in ['max', 'avg']
216 | if apply_pooling:
217 | if pool_type == 'max':
218 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
219 | else:
220 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
221 | else:
222 | self.pooling = None
223 |
224 | self.basic_module = basic_module(in_channels, out_channels,
225 | encoder=True,
226 | kernel_size=conv_kernel_size,
227 | order=conv_layer_order,
228 | num_groups=num_groups)
229 |
230 | def forward(self, x):
231 | if self.pooling is not None:
232 | x = self.pooling(x)
233 | x = self.basic_module(x)
234 | return x
235 |
236 |
237 | class Decoder(nn.Module):
238 | """
239 | A single module for decoder path consisting of the upsampling layer
240 | (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
241 | Args:
242 | in_channels (int): number of input channels
243 | out_channels (int): number of output channels
244 | kernel_size (int): size of the convolving kernel
245 | scale_factor (tuple): used as the multiplier for the image H/W/D in
246 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
247 | from the corresponding encoder
248 | basic_module(nn.Module): either ResNetBlock or DoubleConv
249 | conv_layer_order (string): determines the order of layers
250 | in `DoubleConv` module. See `DoubleConv` for more info.
251 | num_groups (int): number of groups for the GroupNorm
252 | """
253 |
254 | def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
255 | conv_layer_order='crg', num_groups=8, mode='nearest'):
256 | super(Decoder, self).__init__()
257 | if basic_module == DoubleConv:
258 | # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining
259 | self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels,
260 | kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
261 | # concat joining
262 | self.joining = partial(self._joining, concat=True)
263 | else:
264 | # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining
265 | self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels,
266 | kernel_size=kernel_size, scale_factor=scale_factor, mode=mode)
267 | # sum joining
268 | self.joining = partial(self._joining, concat=False)
269 | # adapt the number of in_channels for the ExtResNetBlock
270 | in_channels = out_channels
271 |
272 | self.basic_module = basic_module(in_channels, out_channels,
273 | encoder=False,
274 | kernel_size=kernel_size,
275 | order=conv_layer_order,
276 | num_groups=num_groups)
277 |
278 | def forward(self, encoder_features, x):
279 | x = self.upsampling(encoder_features=encoder_features, x=x)
280 | x = self.joining(encoder_features, x)
281 | x = self.basic_module(x)
282 | return x
283 |
284 | @staticmethod
285 | def _joining(encoder_features, x, concat):
286 | if concat:
287 | return torch.cat((encoder_features, x), dim=1)
288 | else:
289 | return encoder_features + x
290 |
291 |
292 | class Upsampling(nn.Module):
293 | """
294 | Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution.
295 |
296 | Args:
297 | transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation
300 | in_channels (int): number of input channels for transposed conv
301 | out_channels (int): number of output channels for transpose conv
302 | kernel_size (int or tuple): size of the convolving kernel
303 | scale_factor (int or tuple): stride of the convolution
304 | mode (str): algorithm used for upsampling:
305 | 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
306 | """
307 |
308 | def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3,
309 | scale_factor=(2, 2, 2), mode='nearest'):
310 | super(Upsampling, self).__init__()
311 |
312 | if transposed_conv:
313 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder
314 | # (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
315 | self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
316 | padding=1)
317 | else:
318 | self.upsample = partial(self._interpolate, mode=mode)
319 |
320 | def forward(self, encoder_features, x):
321 | output_size = encoder_features.size()[2:]
322 | return self.upsample(x, output_size)
323 |
324 | @staticmethod
325 | def _interpolate(x, size, mode):
326 | return F.interpolate(x, size=size, mode=mode)
327 |
328 |
329 | class FinalConv(nn.Sequential):
330 | """
331 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
332 |     which reduces the number of channels to 'out_channels'
333 |     (the two convolutions have 'in_channels' and 'out_channels' output channels, respectively).
334 |     We use (Conv3d+ReLU+GroupNorm3d) by default.
335 |     This can be changed, however, by providing the 'order' argument, e.g. in order
336 |     to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
337 | Args:
338 | in_channels (int): number of input channels
339 | out_channels (int): number of output channels
340 | kernel_size (int): size of the convolving kernel
341 | order (string): determines the order of layers, e.g.
342 | 'cr' -> conv + ReLU
343 | 'crg' -> conv + ReLU + groupnorm
344 | num_groups (int): number of groups for the GroupNorm
345 | """
346 |
347 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
348 | super(FinalConv, self).__init__()
349 |
350 | # conv1
351 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
352 |
353 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels
354 | final_conv = nn.Conv3d(in_channels, out_channels, 1)
355 | self.add_module('final_conv', final_conv)
356 |
357 | class Abstract3DUNet(nn.Module):
358 | """
359 | Base class for standard and residual UNet.
360 |
361 | Args:
362 | in_channels (int): number of input channels
363 | out_channels (int): number of output segmentation masks;
364 |             Note that out_channels might correspond to either
365 |             different semantic classes or to different binary segmentation masks.
366 | It's up to the user of the class to interpret the out_channels and
367 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
368 | or BCEWithLogitsLoss (two-class) respectively)
369 |         f_maps (int, tuple): if int: number of feature maps in the first conv layer of the encoder (default: 64); the number
370 |             of feature maps at level k is then f_maps * 2^k, k=0,...,num_levels-1; if tuple: number of feature maps at each level
371 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
372 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
373 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
374 | basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....)
375 | layer_order (string): determines the order of layers
376 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
377 | See `SingleConv` for more info
380 | num_groups (int): number of groups for the GroupNorm
381 | num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
382 | is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied
383 | after the final convolution; if False (regression problem) the normalization layer is skipped at the end
384 | testing (bool): if True (testing mode) the `final_activation` (if present, i.e. `is_segmentation=true`)
385 | will be applied as the last operation during the forward pass; if False the model is in training mode
386 | and the `final_activation` (even if present) won't be applied; default: False
387 | """
388 |
389 | def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
390 | num_groups=8, num_levels=4, is_segmentation=False, testing=False, **kwargs):
391 | super(Abstract3DUNet, self).__init__()
392 |
393 | self.testing = testing
394 |
395 | if isinstance(f_maps, int):
396 | f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
397 |
398 | # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
399 | encoders = []
400 | for i, out_feature_num in enumerate(f_maps):
401 | if i == 0:
402 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module,
403 | conv_layer_order=layer_order, num_groups=num_groups)
404 | else:
405 | # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations
406 | # currently pools with a constant kernel: (2, 2, 2)
407 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module,
408 | conv_layer_order=layer_order, num_groups=num_groups)
409 | encoders.append(encoder)
410 |
411 | self.encoders = nn.ModuleList(encoders)
412 |
413 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
414 | decoders = []
415 | reversed_f_maps = list(reversed(f_maps))
416 | for i in range(len(reversed_f_maps) - 1):
417 | if basic_module == DoubleConv:
418 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
419 | else:
420 | in_feature_num = reversed_f_maps[i]
421 |
422 | out_feature_num = reversed_f_maps[i + 1]
423 | # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv
424 | # currently strides with a constant stride: (2, 2, 2)
425 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module,
426 | conv_layer_order=layer_order, num_groups=num_groups)
427 | decoders.append(decoder)
428 |
429 | self.decoders = nn.ModuleList(decoders)
430 |
431 | # in the last layer a 1×1 convolution reduces the number of output
432 | # channels to the number of labels
433 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
434 |
435 | if is_segmentation:
436 | # semantic segmentation problem
437 | if final_sigmoid:
438 | self.final_activation = nn.Sigmoid()
439 | else:
440 | self.final_activation = nn.Softmax(dim=1)
441 | else:
442 | # regression problem
443 | self.final_activation = None
444 |
445 | def forward(self, x):
446 | # encoder part
447 | encoders_features = []
448 | for encoder in self.encoders:
449 | x = encoder(x)
450 | # reverse the encoder outputs to be aligned with the decoder
451 | encoders_features.insert(0, x)
452 |
453 | # remove the last encoder's output from the list
454 | # !!remember: it's the 1st in the list
455 | encoders_features = encoders_features[1:]
456 |
457 | # decoder part
458 | for decoder, encoder_features in zip(self.decoders, encoders_features):
459 | # pass the output from the corresponding encoder and the output
460 | # of the previous decoder
461 | x = decoder(encoder_features, x)
462 |
463 | x = self.final_conv(x)
464 |
465 | # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs
466 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
467 | if self.testing and self.final_activation is not None:
468 | x = self.final_activation(x)
469 |
470 | return x
471 |
472 |
473 | class UNet3D(Abstract3DUNet):
474 | """
475 | 3DUnet model from
476 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
477 |     <https://arxiv.org/abs/1606.06650>`_.
478 |
479 | Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
480 | """
481 |
482 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
483 | num_groups=8, num_levels=4, is_segmentation=True, **kwargs):
484 | super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid,
485 | basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order,
486 | num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation,
487 | **kwargs)
488 |
489 |
490 | class ResidualUNet3D(Abstract3DUNet):
491 | """
492 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
493 | Uses ExtResNetBlock as a basic building block, summation joining instead
494 | of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
495 | Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
496 | """
497 |
498 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
499 | num_groups=8, num_levels=5, is_segmentation=True, **kwargs):
500 | super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels,
501 | final_sigmoid=final_sigmoid,
502 | basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order,
503 | num_groups=num_groups, num_levels=num_levels,
504 | is_segmentation=is_segmentation,
505 | **kwargs)
506 |
507 |
508 | def get_model(config):
509 | def _model_class(class_name):
510 | m = importlib.import_module('pytorch3dunet.unet3d.model')
511 | clazz = getattr(m, class_name)
512 | return clazz
513 |
514 | assert 'model' in config, 'Could not find model configuration'
515 | model_config = config['model']
516 | model_class = _model_class(model_config['name'])
517 | return model_class(**model_config)
518 |
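A usage sketch for `UNet3D` (hypothetical channel counts and volume size; with `is_segmentation=False` the model returns raw outputs, following the regression branch above):

    import torch
    from network.convonet.unet3d import UNet3D

    model = UNet3D(in_channels=1, out_channels=2, f_maps=32, num_levels=3,
                   is_segmentation=False)
    volume = torch.randn(2, 1, 32, 32, 32)  # (batch, channels, D, H, W); sides divisible by 2**(num_levels - 1)
    out = model(volume)                     # (2, 2, 32, 32, 32)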
--------------------------------------------------------------------------------
/network/decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Polyhedral feature decoders.
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch_geometric.nn import MLP
8 |
9 | from network.convonet.ResnetBlockFC import ResnetBlockFC
10 | from utils import normalize_coordinate, normalize_3d_coordinate
11 |
12 |
13 | class MLPDecoder(torch.nn.Module):
14 | """
15 | Global decoder to fuse queries and latent. Adapted from IM-Net.
16 | """
17 |
18 | def __init__(self, latent_dim, num_queries):
19 | super().__init__()
20 | self.num_queries = num_queries
21 | self.mlp = MLP([latent_dim + self.num_queries * 3, latent_dim * 4, latent_dim * 4, latent_dim * 2, latent_dim], plain_last=False)
22 |
23 | def forward(self, latent, queries, batch):
24 | # concat queries and latent
25 | latent = latent[batch] # (num_cells_in_batch, 256)
26 | queries = queries.view(queries.shape[0], -1) # (num_cells_in_batch, num_queries_per_cell * 3)
27 | pointz = torch.cat([queries, latent], 1) # (num_cells_in_batch, num_queries_per_cell * 3 + 256)
28 |         return self.mlp(pointz)  # (num_cells_in_batch, latent_dim)
29 |
30 |
31 | class ConvONetDecoder(torch.nn.Module):
32 | """
33 |     Convolutional occupancy decoder. Instead of conditioning on global features, it conditions on plane/volume local features.
34 | Args:
35 | dim (int): input dimension
36 | c_dim (int): dimension of latent conditioned code c
37 | hidden_size (int): hidden size of Decoder network
38 |         n_blocks (int): number of ResnetBlockFC blocks
39 | leaky (bool): whether to use leaky ReLUs
40 | sample_mode (str): sampling feature strategy, bi-linear | nearest
41 | padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
42 | """
43 |
44 |     def __init__(self, dim=3, c_dim=128, hidden_size=256, latent_dim=4096, num_queries=16, n_blocks=5,
45 | leaky=False, sample_mode='bilinear', padding=0.1):
46 | super().__init__()
47 | self.c_dim = c_dim
48 | self.num_queries = num_queries
49 | self.n_blocks = n_blocks
50 |
51 | if c_dim != 0:
52 | self.fc_c = torch.nn.ModuleList([
53 | torch.nn.Linear(c_dim, hidden_size) for _ in range(n_blocks)
54 | ])
55 |
56 | self.fc_p = torch.nn.Linear(dim, hidden_size)
57 |
58 | self.blocks = torch.nn.ModuleList([
59 | ResnetBlockFC(hidden_size) for _ in range(n_blocks)
60 | ])
61 | if latent_dim != hidden_size * num_queries:
62 |             # not recommended, as it breaks the explicit per-query features
63 | self.fc_out = torch.nn.Linear(hidden_size * num_queries, latent_dim)
64 | else:
65 | self.fc_out = None
66 |
67 | if not leaky:
68 | self.actvn = F.relu
69 | else:
70 | self.actvn = lambda x: F.leaky_relu(x, 0.2)
71 |
72 | self.sample_mode = sample_mode
73 | self.padding = padding
74 |
75 | def sample_plane_feature(self, p, c, plane='xz'):
76 | xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
77 | xy = xy[:, :, None].float()
78 | vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
79 | c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True, mode=self.sample_mode).squeeze(-1)
80 | return c
81 |
82 | def sample_grid_feature(self, p, c):
83 | p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding) # normalize to the range of (0, 1)
84 | p_nor = p_nor[:, :, None, None].float()
85 | vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1)
86 | # bilinear interpolation
87 | c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True, mode=self.sample_mode).squeeze(
88 | -1).squeeze(-1)
89 | return c
90 |
91 | def forward(self, p, c_plane, **kwargs):
92 | if self.c_dim != 0:
93 | plane_type = list(c_plane.keys())
94 | c = 0
95 | if 'grid' in plane_type:
96 | c += self.sample_grid_feature(p, c_plane['grid'])
97 | if 'xz' in plane_type:
98 | c += self.sample_plane_feature(p, c_plane['xz'], plane='xz')
99 | if 'xy' in plane_type:
100 | c += self.sample_plane_feature(p, c_plane['xy'], plane='xy')
101 | if 'yz' in plane_type:
102 | c += self.sample_plane_feature(p, c_plane['yz'], plane='yz')
103 | c = c.transpose(1, 2)
104 |
105 | p = p.float()
106 | net = self.fc_p(p)
107 |
108 | for i in range(self.n_blocks):
109 | if self.c_dim != 0:
110 | net = net + self.fc_c[i](c)
111 |
112 | net = self.blocks[i](net)
113 |
114 | # aggregate to cell-wise features
115 | out = net.squeeze()
116 |         out = out.view(out.shape[0] // self.num_queries, -1)  # (num_cells, latent_b_dim * num_queries == 4096)
117 | if self.fc_out is not None:
118 | out = self.fc_out(out)
119 |
120 | return out
121 |
122 |
123 | class Decoder(torch.nn.Module):
124 | """
125 |     Wrapper decoder that fuses queries and latent features, dispatching to MLPDecoder or ConvONetDecoder.
126 | """
127 |
128 | def __init__(self, backbone, latent_dim, num_queries):
129 | super().__init__()
130 | self.backbone = backbone
131 | self.latent_dim = latent_dim
132 | if backbone.casefold() == 'MLP'.casefold():
133 | self.network = MLPDecoder(latent_dim=latent_dim, num_queries=num_queries)
134 | elif backbone.casefold() == 'ConvONet'.casefold():
135 | self.network = ConvONetDecoder(latent_dim=latent_dim, num_queries=num_queries)
136 | else:
137 | raise ValueError(f'Unexpected backbone: {backbone}')
138 |
139 | def forward(self, latent, queries, batch):
140 | # latent: (num_graphs_in_batch, latent_a_dim)
141 | # queries: (num_cells_in_batch, num_queries_per_cell, 3)
142 | # batch: (num_cells_in_batch, 1)
143 |
144 | if self.backbone.casefold() == 'MLP'.casefold():
145 | outs = self.network(latent, queries, batch)
146 |
147 | else:
148 | # occupancy network decoder
149 | outs = torch.zeros([len(batch), self.latent_dim]).to(batch.device)
150 | for i in range(batch[-1] + 1):
151 | latent_i = {'xz': latent['xz'][i].unsqueeze(0), 'xy': latent['xy'][i].unsqueeze(0),
152 | 'yz': latent['yz'][i].unsqueeze(0)}
153 | queries_i = queries[batch == i].view(-1, 3).unsqueeze(0)
154 | outs[batch == i] = self.network(queries_i, latent_i)
155 |
156 | # (num_cells_in_batch, num_features_per_cell == latent_b_dim)
157 | return outs
158 |
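A shape-level sketch of the MLP decoding path (hypothetical sizes, following the tensor-shape comments in `Decoder.forward` above):

    import torch
    from network.decoder import Decoder

    latent_dim, num_queries = 256, 16
    decoder = Decoder(backbone='mlp', latent_dim=latent_dim, num_queries=num_queries)
    latent = torch.randn(4, latent_dim)       # one latent vector per graph in the batch
    queries = torch.rand(10, num_queries, 3)  # query points per cell
    batch = torch.randint(0, 4, (10,))        # maps each cell to its graph
    out = decoder(latent, queries, batch)     # (10, latent_dim)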
--------------------------------------------------------------------------------
/network/encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Point cloud encoders.
3 | """
4 |
5 | import torch
6 | from torch import Tensor
7 | from torch.nn import Sequential, Linear, ReLU, LeakyReLU
8 | from torch_geometric.nn import PointNetConv, XConv, DynamicEdgeConv
9 | from torch_geometric.nn import fps, global_mean_pool, global_max_pool, knn_graph
10 | from torch_geometric.nn import MLP, PointTransformerConv, knn, radius
11 | from torch_geometric.utils import scatter
12 | from torch_geometric.nn.aggr import MaxAggregation
13 | from torch_geometric.nn.conv import MessagePassing
14 | from torch_geometric.nn.pool.decimation import decimation_indices
15 | from torch_geometric.utils import softmax
16 |
17 | from network.convonet.pointnet import LocalPoolPointnet
18 |
19 |
20 | class PointNet(torch.nn.Module):
21 | """
22 | PointNet with PointNetConv.
23 | """
24 |
25 | def __init__(self, latent_dim):
26 | super().__init__()
27 | # conv layers
28 | self.conv1 = PointNetConv(MLP([3 + 3, 64, 64]))
29 | self.conv2 = PointNetConv(MLP([64 + 3, 64, 128, 1024]))
30 | self.lin = MLP([1024, latent_dim], plain_last=False)
31 |
32 | def forward(self, pos, batch):
33 |         x = None  # no input features; only positions are used
34 | edge_index = knn_graph(pos, k=16, batch=batch, loop=True)
35 |
36 | # point-wise features
37 | x = self.conv1(x, pos, edge_index)
38 | x = self.conv2(x, pos, edge_index)
39 |
40 | # instance-wise features
41 | x = global_max_pool(x, batch)
42 | x = self.lin(x)
43 | return x
44 |
45 |
46 | class SAModule(torch.nn.Module):
47 | """
48 | SA Module for PointNet++.
49 | """
50 |
51 | def __init__(self, ratio, r, nn):
52 | super().__init__()
53 | self.ratio = ratio
54 | self.r = r
55 | self.conv = PointNetConv(nn, add_self_loops=False)
56 |
57 | def forward(self, x, pos, batch):
58 | idx = fps(pos, batch, ratio=self.ratio)
59 | row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64)
60 | edge_index = torch.stack([col, row], dim=0)
61 | x_dst = None if x is None else x[idx]
62 | x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
63 | pos, batch = pos[idx], batch[idx]
64 | return x, pos, batch
65 |
66 |
67 | class GlobalSAModule(torch.nn.Module):
68 | """
69 | Global SA Module for PointNet++.
70 | """
71 |
72 | def __init__(self, nn):
73 | super().__init__()
74 | self.nn = nn
75 |
76 | def forward(self, x, pos, batch):
77 | x = self.nn(torch.cat([x, pos], dim=1))
78 | x = global_max_pool(x, batch)
79 | pos = pos.new_zeros((x.size(0), 3))
80 | batch = torch.arange(x.size(0), device=batch.device)
81 | return x, pos, batch
82 |
83 |
84 | class PointNet2(torch.nn.Module):
85 | """
86 | PointNet++ with SAModule and GlobalSAModule.
87 | """
88 |
89 | def __init__(self, latent_dim):
90 | super().__init__()
91 |
92 | # input channels account for both `pos` and node features.
93 | self.sa1_module = SAModule(0.5, 0.2, MLP([3, 64, 64, 128]))
94 | self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
95 | self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))
96 | self.lin = MLP([1024, latent_dim], plain_last=False)
97 |
98 | def forward(self, pos, batch):
99 | # point-wise features
100 | sa0_out = (pos, pos, batch)
101 | sa1_out = self.sa1_module(*sa0_out)
102 | sa2_out = self.sa2_module(*sa1_out)
103 | sa3_out = self.sa3_module(*sa2_out)
104 |         x, _, batch = sa3_out  # GlobalSAModule re-indexes batch to one entry per graph
105 |
106 | # instance-wise features
107 | x = global_max_pool(x, batch)
108 | x = self.lin(x)
109 | return x
110 |
111 |
112 | class PointCNN(torch.nn.Module):
113 | """
114 | PointCNN with XConv.
115 | """
116 |
117 | def __init__(self, latent_dim):
118 | super().__init__()
119 |
120 | self.conv1 = XConv(3, 32, dim=3, kernel_size=8, dilation=2)
121 | self.conv2 = XConv(32, 128, dim=3, kernel_size=12, dilation=2)
122 | self.conv3 = XConv(128, 256, dim=3, kernel_size=16, dilation=1)
123 | self.conv4 = XConv(256, 256, dim=3, kernel_size=16, dilation=2)
124 | self.conv5 = XConv(256, 256, dim=3, kernel_size=16, dilation=3)
125 | self.conv6 = XConv(256, 690, dim=3, kernel_size=16, dilation=4)
126 | self.lin = MLP([690, latent_dim], plain_last=False)
127 |
128 | def forward(self, pos, batch):
129 | # point-wise features
130 | x, pos, batch = None, pos, batch
131 | x = self.conv1(x, pos, batch)
132 | idx = fps(pos, batch, ratio=0.375)
133 | x, pos, batch = x[idx], pos[idx], batch[idx]
134 |
135 | x = self.conv2(x, pos, batch)
136 | idx = fps(pos, batch, ratio=0.325)
137 | x, pos, batch = x[idx], pos[idx], batch[idx]
138 |
139 | x = self.conv3(x, pos, batch)
140 | x = self.conv4(x, pos, batch)
141 | x = self.conv5(x, pos, batch)
142 | x = self.conv6(x, pos, batch)
143 |
144 | # instance-wise features
145 | x = global_max_pool(x, batch)
146 | x = self.lin(x)
147 | return x
148 |
149 |
150 | class DGCNN(torch.nn.Module):
151 | """
152 | DGCNN with DynamicEdgeConv.
153 | """
154 |
155 | def __init__(self, latent_dim):
156 | super().__init__()
157 |
158 | self.conv1 = DynamicEdgeConv(Sequential(Linear(3 * 2, 64, bias=True), LeakyReLU(negative_slope=0.02)), k=20)
159 | self.conv2 = DynamicEdgeConv(Sequential(Linear(64 * 2, 64, bias=True), LeakyReLU(negative_slope=0.02)), k=20)
160 | self.conv3 = DynamicEdgeConv(Sequential(Linear(64 * 2, 128, bias=True), LeakyReLU(negative_slope=0.02)), k=20)
161 | self.conv4 = DynamicEdgeConv(Sequential(Linear(128 * 2, 256, bias=True), LeakyReLU(negative_slope=0.02)), k=20)
162 | self.lin = MLP([512, latent_dim], plain_last=False)
163 |
164 | def forward(self, pos, batch):
165 |         x = pos  # use point positions as initial features
166 |
167 | # point-wise features
168 | x1 = self.conv1(x, batch)
169 | x2 = self.conv2(x1, batch)
170 | x3 = self.conv3(x2, batch)
171 | x4 = self.conv4(x3, batch)
172 | x = torch.concat([x1, x2, x3, x4], dim=-1)
173 |
174 | # instance-wise features
175 | x = global_max_pool(x, batch)
176 | x = self.lin(x)
177 | return x
178 |
179 |
180 | class SpatialTransformer(torch.nn.Module):
181 | """
182 | Optional spatial transform block in DGCNN to align a point set to a canonical space.
183 | """
184 |
185 | def __init__(self, k=16):
186 | super().__init__()
187 | # to estimate the 3 x 3 matrix, a tensor concatenating the coordinates of each point
188 | # and the coordinate differences between its k neighboring points is used.
189 | self.k = k
190 | self.mlp = Sequential(Linear(k * 6, 1024), ReLU(), Linear(1024, 256), ReLU(), Linear(256, 9))
191 |
192 | def forward(self, pos, batch):
193 | # pos: (num_batch_points, 3)
194 | edge_index = knn_graph(pos, k=self.k, batch=batch, loop=True).to(batch.device)
195 |         neighbours = pos[edge_index[0].reshape([-1, self.k])]  # (num_batch_points, k, 3)
196 |         pos = torch.unsqueeze(pos, dim=1)  # (num_batch_points, 1, 3)
197 | # concatenating the coordinates of each point and the coordinate differences between its k neighboring points
198 | x = torch.concat([pos.repeat(1, self.k, 1), pos - neighbours], dim=2) # (num_batch_points, k, 6)
199 | x = x.reshape([x.shape[0], -1]) # (num_batch_points, k * 6)
200 | x = self.mlp(x) # (num_batch_points, 9)
201 | x = global_mean_pool(x, batch) # (batch_size, 9)
202 | x = x[batch].reshape([-1, 3, 3]) # (num_batch_points, 3, 3)
203 | x = torch.squeeze(torch.bmm(pos, x)) # (num_batch_points, 3)
204 | return x
205 |
206 |
207 | class TransformerBlock(torch.nn.Module):
208 | """
209 | Transformer block for PointTransformer.
210 | """
211 |
212 | def __init__(self, in_channels, out_channels):
213 | super().__init__()
214 | self.lin_in = Linear(in_channels, in_channels)
215 | self.lin_out = Linear(out_channels, out_channels)
216 |
217 | self.pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False)
218 | self.attn_nn = MLP([out_channels, 64, out_channels], norm=None, plain_last=False)
219 | self.transformer = PointTransformerConv(in_channels, out_channels, pos_nn=self.pos_nn, attn_nn=self.attn_nn)
220 |
221 | def forward(self, x, pos, edge_index):
222 | x = self.lin_in(x).relu()
223 | x = self.transformer(x, pos, edge_index)
224 | x = self.lin_out(x).relu()
225 | return x
226 |
227 |
228 | class TransitionDown(torch.nn.Module):
229 | """
230 | TransitionDown for PointTransformer.
231 |     Samples the input point cloud by a given ratio to reduce
232 |     cardinality, and uses an MLP to increase the feature dimensionality.
233 | """
234 |
235 | def __init__(self, in_channels, out_channels, ratio=0.25, k=16):
236 | super().__init__()
237 | self.k = k
238 | self.ratio = ratio
239 | self.mlp = MLP([in_channels, out_channels], plain_last=False)
240 |
241 | def forward(self, x, pos, batch):
242 | # FPS sampling
243 | id_clusters = fps(pos, ratio=self.ratio, batch=batch)
244 |
245 | # compute k-nearest points for each cluster
246 | sub_batch = batch[id_clusters] if batch is not None else None
247 |
248 | # beware of self loop
249 | id_k_neighbor = knn(pos, pos[id_clusters], k=self.k, batch_x=batch, batch_y=sub_batch)
250 |
251 | # transformation of features through a simple MLP
252 | x = self.mlp(x)
253 |
254 | # Max pool onto each cluster the features from knn in points
255 | x_out = scatter(x[id_k_neighbor[1]], id_k_neighbor[0], dim=0, dim_size=id_clusters.size(0), reduce='max')
256 |
257 | # keep only the clusters and their max-pooled features
258 | sub_pos, out = pos[id_clusters], x_out
259 | return out, sub_pos, sub_batch
260 |
261 |
262 | class PointTransformer(torch.nn.Module):
263 | """
264 | PointTransformer.
265 | """
266 |
267 | def __init__(self, latent_dim, k=16):
268 | super().__init__()
269 | self.k = k
270 |
271 | # dummy feature is created if there is none given
272 | in_channels = 1
273 |
274 | # hidden channels
275 | dim_model = [32, 64, 128, 256, 512]
276 |
277 | # first block
278 | self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)
279 | self.transformer_input = TransformerBlock(in_channels=dim_model[0], out_channels=dim_model[0])
280 |
281 | # backbone layers
282 | self.transformers_down = torch.nn.ModuleList()
283 | self.transition_down = torch.nn.ModuleList()
284 |
285 | for i in range(len(dim_model) - 1):
286 | # Add Transition Down block followed by a Transformer block
287 | self.transition_down.append(
288 | TransitionDown(in_channels=dim_model[i], out_channels=dim_model[i + 1], k=self.k))
289 | self.transformers_down.append(
290 | TransformerBlock(in_channels=dim_model[i + 1], out_channels=dim_model[i + 1]))
291 | self.lin = MLP([dim_model[-1], latent_dim], plain_last=False)
292 |
293 | def forward(self, pos, batch=None, x=None):
294 | # add dummy features in case there is none
295 | if x is None:
296 |             x = torch.ones((pos.shape[0], 1), device=pos.device)  # use .device so CPU tensors work too
297 |
298 | # first block
299 | x = self.mlp_input(x)
300 | edge_index = knn_graph(pos, k=self.k, batch=batch)
301 | x = self.transformer_input(x, pos, edge_index)
302 |
303 | # backbone
304 | for i in range(len(self.transformers_down)):
305 | x, pos, batch = self.transition_down[i](x, pos, batch=batch)
306 |
307 | edge_index = knn_graph(pos, k=self.k, batch=batch)
308 | x = self.transformers_down[i](x, pos, edge_index)
309 |
310 | # GlobalAveragePooling
311 | x = global_mean_pool(x, batch)
312 |
313 | # MLP blocks
314 | out = self.lin(x)
315 | return out
316 |
317 |
318 | class SharedMLP(MLP):
319 | """
320 |
321 | """
322 |
323 | def __init__(self, *args, **kwargs):
324 | # BN + Act always active even at last layer.
325 | kwargs['plain_last'] = False
326 | # LeakyRelu with 0.2 slope by default.
327 | kwargs['act'] = kwargs.get('act', 'LeakyReLU')
328 | kwargs['act_kwargs'] = kwargs.get('act_kwargs', {'negative_slope': 0.2})
329 | # BatchNorm with 1 - 0.99 = 0.01 momentum
330 | # and 1e-6 eps by defaut (tensorflow momentum != pytorch momentum)
331 | kwargs['norm_kwargs'] = kwargs.get('norm_kwargs', {'momentum': 0.01, 'eps': 1e-6})
332 | super().__init__(*args, **kwargs)
333 |
334 |
335 | class LocalFeatureAggregation(MessagePassing):
336 | """Positional encoding of points in a neighborhood."""
337 |
338 | def __init__(self, channels):
339 | super().__init__(aggr='add')
340 | self.mlp_encoder = SharedMLP([10, channels // 2])
341 | self.mlp_attention = SharedMLP([channels, channels], bias=False,
342 | act=None, norm=None)
343 | self.mlp_post_attention = SharedMLP([channels, channels])
344 |
345 | def forward(self, edge_index, x, pos):
346 | out = self.propagate(edge_index, x=x, pos=pos) # N, d_out
347 | out = self.mlp_post_attention(out) # N, d_out
348 | return out
349 |
350 | def message(self, x_j, pos_i, pos_j, index):
351 | """Local Spatial Encoding (locSE) and attentive pooling of features.
352 | Args:
353 | x_j (Tensor): neighbors features (K,d)
354 | pos_i (Tensor): centroid position (repeated) (K,3)
355 | pos_j (Tensor): neighbors positions (K,3)
356 | index (Tensor): index of centroid positions
357 | (e.g. [0,...,0,1,...,1,...,N,...,N])
358 | returns:
359 | (Tensor): locSE weighted by feature attention scores.
360 | """
361 | # Encode local neighborhood structural information
362 | pos_diff = pos_j - pos_i
363 | distance = torch.sqrt((pos_diff * pos_diff).sum(1, keepdim=True))
364 |         relative_infos = torch.cat([pos_i, pos_j, pos_diff, distance], dim=1)  # N * K, 10
365 | local_spatial_encoding = self.mlp_encoder(relative_infos) # N * K, d
366 | local_features = torch.cat([x_j, local_spatial_encoding], dim=1) # N * K, 2d
367 |
368 | # Attention will weight the different features of x
369 | # along the neighborhood dimension.
370 | att_features = self.mlp_attention(local_features) # N * K, d_out
371 | att_scores = softmax(att_features, index=index) # N * K, d_out
372 |
373 | return att_scores * local_features # N * K, d_out
374 |
375 |
376 | class DilatedResidualBlock(torch.nn.Module):
377 | """
378 | Dilated residual block for RandLANet.
379 | """
380 |
381 | def __init__(self, num_neighbors, d_in: int, d_out: int):
382 | super().__init__()
383 | self.num_neighbors = num_neighbors
384 | self.d_in = d_in
385 | self.d_out = d_out
386 |
387 | # MLP on input
388 | self.mlp1 = SharedMLP([d_in, d_out // 8])
389 | # MLP on input, and the result is summed with the output of mlp2
390 | self.shortcut = SharedMLP([d_in, d_out], act=None)
391 | # MLP on output
392 | self.mlp2 = SharedMLP([d_out // 2, d_out], act=None)
393 |
394 | self.lfa1 = LocalFeatureAggregation(d_out // 4)
395 | self.lfa2 = LocalFeatureAggregation(d_out // 2)
396 | self.lrelu = torch.nn.LeakyReLU(**{'negative_slope': 0.2})
397 |
398 | def forward(self, x, pos, batch):
399 | edge_index = knn_graph(pos, self.num_neighbors, batch=batch, loop=True)
400 |
401 | shortcut_of_x = self.shortcut(x) # N, d_out
402 | x = self.mlp1(x) # N, d_out//8
403 |         x = self.lfa1(edge_index, x, pos)  # N, d_out//4
404 | x = self.lfa2(edge_index, x, pos) # N, d_out//2
405 | x = self.mlp2(x) # N, d_out
406 | x = self.lrelu(x + shortcut_of_x) # N, d_out
407 |
408 | return x, pos, batch
409 |
410 |
411 | def decimate(tensors, ptr: Tensor, decimation_factor: int):
412 | """
413 | Decimates each element of the given tuple of tensors for RandLANet.
414 | """
415 | idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor)
416 | tensors_decim = tuple(tensor[idx_decim] for tensor in tensors)
417 | return tensors_decim, ptr_decim
418 |
419 |
420 | class RandLANet(torch.nn.Module):
421 | """
422 |     An adaptation of RandLA-Net for whole-cloud encoding, a use case not addressed in the original paper:
423 |     RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds.
424 | """
425 |
426 | def __init__(self, latent_dim, decimation: int = 4, num_neighbors: int = 16):
427 | super().__init__()
428 | self.decimation = decimation
429 | self.fc0 = Linear(in_features=3, out_features=8)
430 | # 2 DilatedResidualBlock converges better than 4 on ModelNet
431 | self.block1 = DilatedResidualBlock(num_neighbors, 8, 32)
432 | self.block2 = DilatedResidualBlock(num_neighbors, 32, 128)
433 | self.block3 = DilatedResidualBlock(num_neighbors, 128, 512)
434 | self.mlp1 = SharedMLP([512, 512])
435 | self.max_agg = MaxAggregation()
436 | self.mlp2 = Linear(512, latent_dim)
437 |
438 | def forward(self, pos, batch):
439 | x = pos
440 |         ptr = torch.where(batch[1:] - batch[:-1])[0] + 1  # graph boundaries in the batch vector
441 |         ptr = torch.cat([torch.tensor([0], device=batch.device), ptr])
442 |         ptr = torch.cat([ptr, torch.tensor([len(batch)], device=batch.device)])
443 | b1 = self.block1(self.fc0(x), pos, batch)
444 | b1_decimated, ptr1 = decimate(b1, ptr, self.decimation)
445 |
446 | b2 = self.block2(*b1_decimated)
447 | b2_decimated, ptr2 = decimate(b2, ptr1, self.decimation)
448 |
449 | b3 = self.block3(*b2_decimated)
450 | b3_decimated, _ = decimate(b3, ptr2, self.decimation)
451 |
452 | x = self.mlp1(b3_decimated[0])
453 | x = self.max_agg(x, b3_decimated[2])
454 | x = self.mlp2(x)
455 |
456 | return x
457 |
458 |
459 | class Encoder(torch.nn.Module):
460 | """
461 | Point cloud encoder.
462 | """
463 | def __init__(self, backbone, latent_dim, use_spatial_transformer=False, convonet_kwargs=None):
464 | super().__init__()
465 | self.backbone = backbone
466 |
467 | # spatial transformer placeholder
468 | self.use_spatial_transformer = False
469 | if use_spatial_transformer:
470 | self.use_spatial_transformer = True
471 | self.spatial_transformer = SpatialTransformer()
472 |
473 | backbone_mapping = {
474 | 'pointnet': PointNet,
475 | 'pointnet2': PointNet2,
476 | 'pointcnn': PointCNN,
477 | 'dgcnn': DGCNN,
478 | 'randlanet': RandLANet,
479 | 'pointtransformer': PointTransformer,
480 | 'convonet': lambda d: LocalPoolPointnet(**convonet_kwargs)}
481 | self.backbone_key = backbone.casefold()
482 | if self.backbone_key in backbone_mapping:
483 | self.network = backbone_mapping[self.backbone_key](latent_dim)
484 | else:
485 | raise ValueError(f'Unexpected backbone: {self.backbone_key}')
486 |
487 | def forward(self, points, batch_points):
488 | # points: (total_num_points, 3)
489 | # batch_points: (total_num_points)
490 |
491 | # spatial transformation
492 | if self.use_spatial_transformer:
493 | points = self.spatial_transformer(points, batch_points)
494 |
495 | if self.backbone_key == 'convonet':
496 |             # reshape the flat points into one tensor per graph: (num_graphs_in_batch, num_points, 3)
497 | points_split = points.view(batch_points[-1] + 1, -1, 3)
498 | return self.network(points_split)
499 | else:
500 | return self.network(points, batch_points)
501 |
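A usage sketch for the `Encoder` wrapper (hypothetical point counts; `batch_points` follows the PyG convention of one graph index per point):

    import torch
    from network.encoder import Encoder

    encoder = Encoder(backbone='dgcnn', latent_dim=256)
    points = torch.rand(2048, 3)                            # (total_num_points, 3)
    batch_points = torch.arange(2).repeat_interleave(1024)  # two clouds of 1024 points each
    latent = encoder(points, batch_points)                  # (2, 256)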
--------------------------------------------------------------------------------
/network/gnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Graph neural networks.
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch_geometric.nn import GCNConv, TransformerConv, TAGConv, MLP
8 |
9 |
10 | class GCN(torch.nn.Module):
11 | """
12 | Graph convolution network from the "Semi-supervised
13 | Classification with Graph Convolutional Networks" paper.
14 | """
15 |
16 | def __init__(self, num_features, num_classes, dropout=False):
17 | super().__init__()
18 | self.dropout = dropout
19 | self.conv1 = GCNConv(num_features, 16)
20 | self.conv2 = GCNConv(16, 16)
21 | self.conv3 = GCNConv(16, 16)
22 | self.mlp = MLP([16, 64, 16, num_classes], act='leaky_relu', act_kwargs={'negative_slope': 0.02})
23 |
24 | def forward(self, x, edge_index):
25 | # GCN layers
26 | x = self.conv1(x, edge_index)
27 | x = F.relu(x)
28 | x = F.dropout(x, training=self.training) if self.dropout else x
29 | x = self.conv2(x, edge_index)
30 | x = F.relu(x)
31 | x = F.dropout(x, training=self.training) if self.dropout else x
32 | x = self.conv3(x, edge_index)
33 |
34 | # MLP layers
35 | x = self.mlp(x) # (num_batch_cells, 2)
36 |
37 | # return the log softmax value
38 | return F.log_softmax(x, dim=1)
39 |
40 |
41 | class TransformerGCN(torch.nn.Module):
42 | """
43 | Graph transformer operator from the "Masked Label Prediction:
44 | Unified Message Passing Model for Semi-Supervised Classification" paper.
45 | """
46 |
47 | def __init__(self, num_features, num_classes, dropout=False):
48 | super().__init__()
49 | self.dropout = dropout
50 | self.conv1 = TransformerConv(num_features, 16)
51 | self.conv2 = TransformerConv(16, 16)
52 | self.conv3 = TransformerConv(16, 16)
53 | self.mlp = MLP([16, 64, 16, num_classes], act='leaky_relu', act_kwargs={'negative_slope': 0.02})
54 |
55 | def forward(self, x, edge_index):
56 | # Graph transformer layers
57 | x = self.conv1(x, edge_index)
58 | x = F.relu(x)
59 | x = F.dropout(x, training=self.training) if self.dropout else x
60 | x = self.conv2(x, edge_index)
61 | x = F.relu(x)
62 | x = F.dropout(x, training=self.training) if self.dropout else x
63 | x = self.conv3(x, edge_index)
64 |
65 | # MLP layers
66 | x = self.mlp(x) # (num_batch_cells, 2)
67 |
68 | # return the log softmax value
69 | return F.log_softmax(x, dim=1)
70 |
71 |
72 | class TAGCN(torch.nn.Module):
73 | """
74 | Topology adaptive graph convolutional network operator from
75 | the "Topology Adaptive Graph Convolutional Networks" paper.
76 | """
77 |
78 | def __init__(self, num_features, num_classes, dropout=False):
79 | super().__init__()
80 | self.dropout = dropout
81 | self.conv1 = TAGConv(num_features, 16)
82 | self.conv2 = TAGConv(16, 16)
83 | self.conv3 = TAGConv(16, 16)
84 | self.mlp = MLP([16, 64, 16, num_classes], act='leaky_relu', act_kwargs={'negative_slope': 0.02})
85 |
86 | def forward(self, x, edge_index):
87 | # Topology Adaptive Graph Convolutional layers
88 | x = self.conv1(x, edge_index)
89 | x = F.leaky_relu(x, negative_slope=0.02)
90 | x = F.dropout(x, training=self.training) if self.dropout else x
91 | x = self.conv2(x, edge_index)
92 | x = F.leaky_relu(x, negative_slope=0.02)
93 | x = F.dropout(x, training=self.training) if self.dropout else x
94 | x = self.conv3(x, edge_index)
95 |
96 | # MLP layers
97 | x = self.mlp(x) # (num_batch_cells, 2)
98 |
99 | # return the log softmax value
100 | return F.log_softmax(x, dim=1)
101 |
102 |
103 | class GNN(torch.nn.Module):
104 | def __init__(self, backend, num_features, num_classes, **kwargs):
105 | super().__init__()
106 | self.backend = backend
107 |
108 |         if backend.casefold() == 'GCN'.casefold():
109 |             self.network = GCN(num_features, num_classes, **kwargs)
110 |         elif backend.casefold() == 'TransformerGCN'.casefold():
111 |             self.network = TransformerGCN(num_features, num_classes, **kwargs)
112 |         elif backend.casefold() == 'TAGCN'.casefold():
113 |             self.network = TAGCN(num_features, num_classes, **kwargs)
114 |         else:
115 |             raise ValueError(f'Expected backend GCN, TransformerGCN or TAGCN, got {backend} instead')
116 |
117 | def forward(self, x, data):
118 | x = torch.squeeze(x) # (num_cells_in_batch, num_queries_per_cell)
119 | return self.network(x, data)
120 |
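A usage sketch for the `GNN` head (hypothetical sizes; `x` holds one feature vector per cell and `edge_index` encodes cell adjacency):

    import torch
    from network.gnn import GNN

    gnn = GNN(backend='TAGCN', num_features=256, num_classes=2)
    x = torch.randn(10, 256)                           # one feature vector per cell
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # (source, target) cell-adjacency edges
    log_probs = gnn(x, edge_index)                     # (10, 2) log-softmax scores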
--------------------------------------------------------------------------------
/network/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Loss functions.
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def bce_loss(inputs, targets, weight=(1., 1.)):
10 | """
11 |     Binary cross-entropy loss for PolyGNN prediction. Expects log-probability inputs (uses F.nll_loss internally).
12 | """
13 | pred = inputs.argmax(dim=1)
14 | total = len(pred)
15 | correct = torch.sum((pred.clone().detach() == targets).long())
16 | accuracy = correct / total
17 | loss = F.nll_loss(inputs, targets, weight=torch.tensor(weight, device=inputs.device))
18 | ratio = torch.sum(pred) / total
19 | return loss, accuracy, ratio, total, correct
20 |
21 |
22 | def focal_loss(
23 | inputs: torch.Tensor,
24 | targets: torch.Tensor,
25 | alpha: float = 0.25,
26 | gamma: float = 2,
27 | reduction: str = "mean",
28 | ) -> (torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor):
29 | """
30 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
31 | Source from https://pytorch.org/vision/main/_modules/torchvision/ops/focal_loss.html.
32 |
33 | Args:
34 |         inputs (Tensor): A float tensor of log-probabilities (e.g. from
35 |                 log_softmax), with one row of class scores per example.
36 |         targets (Tensor): A long tensor of class indices, one per example in
37 |                 inputs (0 for the negative class and 1 for the
38 |                 positive class).
39 | alpha (float): Weighting factor in range (0,1) to balance
40 | positive vs negative examples or -1 for ignore. Default: ``0.25``.
41 | gamma (float): Exponent of the modulating factor (1 - p_t) to
42 | balance easy vs hard examples. Default: ``2``.
43 | reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
44 | ``'none'``: No reduction will be applied to the output.
45 | ``'mean'``: The output will be averaged.
46 |                 ``'sum'``: The output will be summed. Default: ``'mean'``.
47 | Returns:
48 |         Tensors loss, accuracy and ratio, plus the counters total and correct.
49 | """
50 | ce_loss = F.nll_loss(inputs, targets, reduction="none")
51 | p_t = torch.exp(-ce_loss)
52 | p_t = torch.clamp(p_t, max=1.0) # clip to avoid NaN
53 | loss = ce_loss * ((1 - p_t) ** gamma)
54 |
55 | if alpha >= 0:
56 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
57 | loss = alpha_t * loss
58 |
59 | # Check reduction option and return loss accordingly
60 | if reduction == "none":
61 | pass
62 | elif reduction == "mean":
63 | loss = loss.mean()
64 | elif reduction == "sum":
65 | loss = loss.sum()
66 | else:
67 | raise ValueError(
68 | f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
69 | )
70 |
71 | pred = inputs.argmax(dim=1)
72 | total = len(pred)
73 | correct = torch.sum((pred.clone().detach() == targets).long())
74 | accuracy = correct / total
75 | ratio = torch.sum(pred) / total
76 | return loss, accuracy, ratio, total, correct
77 |
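Both losses call F.nll_loss internally, so `inputs` must be log-probabilities rather than raw logits. A minimal sketch with hypothetical sizes:

    import torch
    import torch.nn.functional as F
    from network.loss import focal_loss

    logits = torch.randn(8, 2)
    inputs = F.log_softmax(logits, dim=1)  # the losses above expect log-probabilities
    targets = torch.randint(0, 2, (8,))
    loss, accuracy, ratio, total, correct = focal_loss(inputs, targets, alpha=0.25, gamma=2)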
--------------------------------------------------------------------------------
/network/polygnn.py:
--------------------------------------------------------------------------------
1 | """
2 | PolyGNN architecture.
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from network import Encoder, Decoder, GNN
9 |
10 |
11 | class PolyGNN(torch.nn.Module):
12 | """
13 | PolyGNN.
14 | """
15 | def __init__(self, cfg):
16 | super().__init__()
17 |         # the GNN module itself is constructed below from cfg.gnn
18 | self.dataset_suffix = cfg.dataset_suffix
19 | if cfg.sample.strategy == 'grid':
20 | self.points_suffix = f'_{cfg.sample.resolution}'
21 | elif cfg.sample.strategy == 'random':
22 | self.points_suffix = f'_{cfg.sample.length}'
23 | else:
24 | self.points_suffix = ''
25 |
26 | # encoder
27 | self.encoder = Encoder(backbone=cfg.encoder, latent_dim=cfg.latent_dim_light,
28 | use_spatial_transformer=cfg.use_spatial_transformer, convonet_kwargs=cfg.convonet_kwargs)
29 |
30 | # decoder
31 | latent_dim = cfg.latent_dim_light if cfg.decoder.casefold() == 'MLP'.casefold() else cfg.latent_dim_conv
32 | self.decoder = Decoder(backbone=cfg.decoder, latent_dim=latent_dim, num_queries=cfg.num_queries)
33 |
34 | # GNN
35 | if cfg.gnn is not None:
36 | self.gnn = GNN(backend=cfg.gnn, num_features=latent_dim, num_classes=2, dropout=cfg.dropout)
37 | else:
38 |             self.gnn, self.lin = None, torch.nn.Linear(1, 2)  # linear head for the voting branch in forward
39 |
40 | def forward(self, data):
41 | x = self.encoder(data[f'points{self.points_suffix}'], data[f'batch_points{self.points_suffix}'])
42 |         # x: {(num_graphs_in_batch, 128, 128, 128) * 3} from the ConvONet encoder
43 | # x: (num_graphs_in_batch, 256) from other encoders
44 |
45 | x = self.decoder(x, data.queries, data.batch)
46 | # x: (num_cells_in_batch, latent_dim)
47 |
48 | if self.gnn:
49 | x = self.gnn(x, data.edge_index)
50 | # x: (num_cells_in_batch, 2)
51 | else:
52 | # cell-wise voting from query points
53 |             x = torch.mean(x, dim=1, keepdim=True)
54 | # x: (num_cells_in_batch, 1)
55 | x = self.lin(x)
56 | # x: (num_cells_in_batch, 2)
57 | x = F.log_softmax(x, dim=1)
58 | # x: (num_cells_in_batch, 2)
59 | return x
60 |
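The constructor reads its hyper-parameters from a Hydra/OmegaConf config. A hedged sketch of the fields it touches (illustrative values, not the repository defaults):

    from omegaconf import OmegaConf
    from network.polygnn import PolyGNN

    cfg = OmegaConf.create({
        'gnn': 'TAGCN', 'dataset_suffix': '', 'dropout': False,
        'sample': {'strategy': 'random', 'length': 4096},
        'encoder': 'dgcnn', 'decoder': 'mlp', 'num_queries': 16,
        'latent_dim_light': 256, 'latent_dim_conv': 4096,
        'use_spatial_transformer': False, 'convonet_kwargs': {},
    })
    model = PolyGNN(cfg)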
--------------------------------------------------------------------------------
/reconstruct.py:
--------------------------------------------------------------------------------
1 | """
2 | Conversion from polyhedral labels to tangible mesh files.
3 | """
4 |
5 | import pickle
6 | import glob
7 | import multiprocessing
8 | from pathlib import Path
9 |
10 | from tqdm import tqdm
11 | import numpy as np
12 | import hydra
13 | from omegaconf import DictConfig
14 |
15 | from abspy import AdjacencyGraph
16 | from utils import attach_to_log
17 |
18 |
19 | def reconstruct_from_numpy(args):
20 | """
21 | Reconstruct from numpy and cell complex.
22 | args[0]: numpy_filepath
23 | args[1]: complex_filepath
24 | args[2]: mesh_filepath
25 | args[3]: reconstruction type ('cell' or 'mesh')
26 | args[4]: seal by excluding boundary cells
27 | """
28 | with open(args[1], 'rb') as handle:
29 | cell_complex = pickle.load(handle)
30 |
31 | pred = np.load(args[0])
32 |
33 | # exclude boundary cells
34 | if args[4]:
35 | cells_boundary = cell_complex.cells_boundary()
36 | pred[cells_boundary] = 0
37 |
38 | indices_cells = np.where(pred)[0]
39 |
40 | if args[3] == 'mesh':
41 | adjacency_graph = AdjacencyGraph(cell_complex.graph, quiet=True)
42 | adjacency_graph.reachable = adjacency_graph.to_uids(indices_cells)
43 | adjacency_graph.non_reachable = np.setdiff1d(adjacency_graph.uid, adjacency_graph.reachable).tolist()
44 | adjacency_graph.save_surface_obj(args[2], cells=cell_complex.cells)
45 |
46 | elif args[3] == 'cell':
47 | if len(indices_cells) > 0:
48 | cell_complex.save_obj(args[2], indices_cells=indices_cells, use_mtl=True)
49 | else:
50 | raise ValueError(f'Unexpected reconstruction type: {args[3]}')
51 |
52 |
53 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
54 | def multi_reconstruct_from_numpy(cfg: DictConfig):
55 | """
56 | Reconstruct from numpy and cell complex with multiprocessing.
57 | """
58 | # initialize logging
59 | logger = attach_to_log()
60 |
61 | # numpy filenames
62 | filenames_numpy = glob.glob(f'{cfg.output_dir}/*.npy')
63 | args = []
64 | for filename_numpy in filenames_numpy:
65 | stem = Path(filename_numpy).stem
66 | filename_complex = Path(cfg.complex_dir) / (stem + '.cc')
67 | filename_output = Path(filename_numpy).with_suffix('.obj')
68 | if not filename_output.exists():
69 | args.append((filename_numpy, filename_complex, filename_output, cfg.reconstruct.type, cfg.reconstruct.seal))
70 |
71 | logger.info('Start reconstruction from numpy and cell complex')
72 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool:
73 | # call with multiprocessing
74 | for _ in tqdm(pool.imap_unordered(reconstruct_from_numpy, args), desc='reconstruction', total=len(args)):
75 | pass
76 |
77 | # exit with a message
78 | logger.info('Done reconstruction from numpy and cell complex')
79 |
80 |
81 | if __name__ == '__main__':
82 | multi_reconstruct_from_numpy()
83 |
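A single building can also be reconstructed without Hydra or the process pool by calling reconstruct_from_numpy directly; the paths below are hypothetical.

    from reconstruct import reconstruct_from_numpy

    reconstruct_from_numpy((
        './outputs/building_001.npy',        # per-cell occupancy prediction
        './data/complexes/building_001.cc',  # pickled cell complex
        './outputs/building_001.obj',        # mesh to write
        'mesh',                              # 'mesh' extracts the surface; 'cell' dumps cells
        True,                                # seal by excluding boundary cells
    ))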
--------------------------------------------------------------------------------
/remap.py:
--------------------------------------------------------------------------------
1 | """
2 | Remap normalized instances to global CRS.
3 | """
4 |
5 | import glob
6 | from pathlib import Path
7 | import multiprocessing
8 | import logging
9 |
10 | import hydra
11 | from omegaconf import DictConfig
12 | from tqdm import tqdm
13 |
14 | from utils import reverse_normalise_mesh, reverse_normalise_cloud, normalise_mesh
15 |
16 |
17 | logger = logging.getLogger("trimesh")
18 | logger.setLevel(logging.WARNING)
19 |
20 |
21 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
22 | def normalize_meshes(cfg: DictConfig):
23 | """
24 | Normalize meshes.
25 |
26 | cfg: DictConfig
27 | Hydra configuration
28 | """
29 | args = []
30 | input_filenames = glob.glob(f'{cfg.output_dir}/*.obj')
31 | output_dir = Path(cfg.output_dir) / 'normalized'
32 | output_dir.mkdir(exist_ok=True)
33 | for input_filename in input_filenames:
34 | base_filename = Path(input_filename).name
35 | reference_filename = (Path(cfg.reference_dir) / base_filename).with_suffix('.obj')
36 | output_filename = output_dir / base_filename
37 | args.append((input_filename, reference_filename, output_filename, 'scene', cfg.reconstruct.offset, True, False))
38 | print('start processing')
39 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool:
40 | # call with multiprocessing
41 | for _ in tqdm(pool.imap_unordered(normalise_mesh, args), desc='Normalizing meshes', total=len(args)):
42 | pass
43 |
44 |
45 | # normalize clouds as meshes
46 | normalize_clouds = normalize_meshes
47 |
48 |
49 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
50 | def remap_meshes(cfg: DictConfig):
51 | """
52 | Remap normalized buildings to global CRS.
53 |
54 | Parameters
55 | ----------
56 | cfg: DictConfig
57 | Hydra configuration
58 | """
59 | args = []
60 | input_filenames = glob.glob(cfg.output_dir + '/*.obj')
61 | output_dir = Path(cfg.output_dir) / 'global'
62 | output_dir.mkdir(exist_ok=True)
63 | for input_filename in input_filenames:
64 | base_filename = Path(input_filename).name
65 | reference_filename = Path(cfg.reference_dir) / base_filename
66 | output_filename = output_dir / base_filename
67 | args.append((input_filename, reference_filename, output_filename, 'scene', cfg.reconstruct.offset, cfg.reconstruct.scale, cfg.reconstruct.translate))
68 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool:
69 | # call with multiprocessing
70 | for _ in tqdm(pool.imap_unordered(reverse_normalise_mesh, args), desc='Remapping meshes', total=len(args)):
71 | pass
72 |
73 |
74 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
75 | def remap_clouds(cfg: DictConfig):
76 | """
77 | Remap normalized point clouds to global CRS.
78 |
79 | Parameters
80 | ----------
81 | cfg: DictConfig
82 | Hydra configuration
83 | """
84 | args = []
85 | input_filenames = glob.glob(f'{cfg.data_dir}/raw/test_cloud_normalised_ply/*.ply')
86 | output_dir = Path(cfg.output_dir) / 'global_clouds'
87 | output_dir.mkdir(exist_ok=True)
88 | for input_filename in input_filenames:
89 | base_filename = Path(input_filename).name
90 | reference_filename = (Path(cfg.reference_dir) / base_filename).with_suffix('.obj')
91 | output_filename = output_dir / base_filename
92 | args.append((input_filename, reference_filename, output_filename, 'scene', cfg.reconstruct.offset))
93 | print('start processing')
94 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool:
95 | # call with multiprocessing
96 | for _ in tqdm(pool.imap_unordered(reverse_normalise_cloud, args), desc='Remapping clouds', total=len(args)):
97 | pass
98 |
99 |
100 | if __name__ == '__main__':
101 | remap_meshes()
102 |
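For reference, a hedged sketch of a single remap call outside the pool; the argument tuple mirrors the one built in remap_meshes, with hypothetical paths and offset.

    from utils import reverse_normalise_mesh

    reverse_normalise_mesh((
        './outputs/building_001.obj',         # normalised reconstruction
        './data/reference/building_001.obj',  # reference mesh defining the transform
        './outputs/global/building_001.obj',  # remapped output
        'scene',                              # force loading type
        [690000.0, 5330000.0, 0.0],           # hypothetical CRS offset (x, y, z)
        True,                                 # apply scaling
        True,                                 # apply translation
    ))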
--------------------------------------------------------------------------------
/stats.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate statistics for mesh reconstruction results.
3 | Copied from https://github.com/ErlerPhilipp/points2surf/blob/master/source/base/evaluation.py.
4 | """
5 |
6 | import os
7 | import subprocess
8 | import multiprocessing
9 |
10 | import numpy as np
11 | import hydra
12 | from omegaconf import DictConfig
13 |
14 |
15 | def make_dir_for_file(file):
16 | """
17 | Make dir for file.
18 | """
19 | file_dir = os.path.dirname(file)
20 | if file_dir != '':
21 | if not os.path.exists(file_dir):
22 | try:
23 | os.makedirs(file_dir)
24 | except FileExistsError:
25 | pass  # guard against race condition: dir created concurrently
26 |
27 |
28 | def mp_worker(call):
29 | """
30 | Small helper that runs a system call in a subprocess. Used for process pooling.
31 | """
32 | call = call.split(' ')
33 | verbose = call[-1] == '--verbose'
34 | if verbose:
35 | call = call[:-1]
36 | subprocess.run(call)
37 | else:
38 | # subprocess.run(call, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # suppress outputs
39 | subprocess.run(call, stdout=subprocess.DEVNULL)
40 |
41 |
42 | def start_process_pool(worker_function, parameters, num_processes, timeout=None):
43 |
44 | if len(parameters) > 0:
45 | if num_processes <= 1:
46 | print('Running loop for {} with {} calls on {} workers'.format(
47 | str(worker_function), len(parameters), num_processes))
48 | results = []
49 | for c in parameters:
50 | results.append(worker_function(*c))
51 | return results
52 | print('Running loop for {} with {} calls on {} subprocess workers'.format(
53 | str(worker_function), len(parameters), num_processes))
54 | with multiprocessing.Pool(processes=num_processes, maxtasksperchild=1) as pool:
55 | results = pool.starmap(worker_function, parameters)
56 | return results
57 | else:
58 | return None
59 |
60 |
61 | def _chamfer_distance_single_file(file_in, file_ref, samples_per_model, num_processes=1):
62 | # http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf
63 |
64 | import trimesh
65 | import trimesh.sample
66 | import sys
67 | import scipy.spatial as spatial
68 |
69 | def sample_mesh(mesh_file, num_samples):
70 | try:
71 | mesh = trimesh.load(mesh_file)
72 | except Exception:  # loading failed; treat as empty sample set
73 | return np.zeros((0, 3))
74 | samples, face_indices = trimesh.sample.sample_surface_even(mesh, num_samples)
75 | return samples
76 |
77 | try:
78 | new_mesh_samples = sample_mesh(file_in, samples_per_model)
79 | ref_mesh_samples = sample_mesh(file_ref, samples_per_model)
80 | except AttributeError:
81 | # unable to sample
82 | return file_in, file_ref, -1.0
83 |
84 | if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0:
85 | return file_in, file_ref, -1.0
86 |
87 | leaf_size = 100
88 | sys.setrecursionlimit(int(max(1000, round(new_mesh_samples.shape[0] / leaf_size))))
89 | kdtree_new_mesh_samples = spatial.cKDTree(new_mesh_samples, leaf_size)
90 | kdtree_ref_mesh_samples = spatial.cKDTree(ref_mesh_samples, leaf_size)
91 |
92 | ref_new_dist, corr_new_ids = kdtree_new_mesh_samples.query(ref_mesh_samples, 1, workers=num_processes)
93 | new_ref_dist, corr_ref_ids = kdtree_ref_mesh_samples.query(new_mesh_samples, 1, workers=num_processes)
94 |
95 | ref_new_dist_sum = np.sum(ref_new_dist)
96 | new_ref_dist_sum = np.sum(new_ref_dist)
97 | chamfer_dist = ref_new_dist_sum + new_ref_dist_sum
98 |
99 | return file_in, file_ref, chamfer_dist
100 |
101 |
102 | def _hausdorff_distance_directed_single_file(file_in, file_ref, samples_per_model):
103 | import scipy.spatial as spatial
104 | import trimesh
105 | import trimesh.sample
106 |
107 | def sample_mesh(mesh_file, num_samples):
108 | try:
109 | mesh = trimesh.load(mesh_file)
110 | except Exception:  # loading failed; treat as empty sample set
111 | return np.zeros((0, 3))
112 | samples, face_indices = trimesh.sample.sample_surface_even(mesh, num_samples)
113 | return samples
114 |
115 | try:
116 | new_mesh_samples = sample_mesh(file_in, samples_per_model)
117 | ref_mesh_samples = sample_mesh(file_ref, samples_per_model)
118 | except AttributeError:
119 | # unable to sample
120 | return file_in, file_ref, -1.0
121 |
122 | if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0:
123 | return file_in, file_ref, -1.0
124 |
125 | dist, _, _ = spatial.distance.directed_hausdorff(new_mesh_samples, ref_mesh_samples)
126 | return file_in, file_ref, dist
127 |
128 |
129 | def _hausdorff_distance_single_file(file_in, file_ref, samples_per_model):
130 | import scipy.spatial as spatial
131 | import trimesh
132 | import trimesh.sample
133 |
134 | def sample_mesh(mesh_file, num_samples):
135 | try:
136 | mesh = trimesh.load(mesh_file)
137 | except Exception:  # loading failed; treat as empty sample set
138 | return np.zeros((0, 3))
139 | samples, face_indices = trimesh.sample.sample_surface_even(mesh, num_samples)
140 | return samples
141 |
142 | try:
143 | new_mesh_samples = sample_mesh(file_in, samples_per_model)
144 | ref_mesh_samples = sample_mesh(file_ref, samples_per_model)
145 | except AttributeError:
146 | # unable to sample
147 | return file_in, file_ref, -1.0, -1.0, -1.0
148 |
149 | if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0:
150 | return file_in, file_ref, -1.0, -1.0, -1.0
151 |
152 | dist_new_ref, _, _ = spatial.distance.directed_hausdorff(new_mesh_samples, ref_mesh_samples)
153 | dist_ref_new, _, _ = spatial.distance.directed_hausdorff(ref_mesh_samples, new_mesh_samples)
154 | dist = max(dist_new_ref, dist_ref_new)
155 | return file_in, file_ref, dist_new_ref, dist_ref_new, dist
156 |
157 |
158 | def _scale_single_file(file_ref):
159 | import trimesh
160 | if not file_ref.endswith('.obj'):
161 | file_ref = file_ref + '.obj'
162 | mesh = trimesh.load(file_ref)
163 | extents = mesh.extents
164 | scale = extents.max()
165 | return scale
166 |
167 |
168 | def mesh_comparison(new_meshes_dir_abs, ref_meshes_dir_abs,
169 | num_processes, report_name, samples_per_model=10000, dataset_file_abs=None):
170 | if not os.path.isdir(new_meshes_dir_abs):
171 | print('Warning: dir to check does not exist: {}'.format(new_meshes_dir_abs))
172 | return
173 |
174 | new_mesh_files = [f for f in os.listdir(new_meshes_dir_abs)
175 | if os.path.isfile(os.path.join(new_meshes_dir_abs, f))]
176 | ref_mesh_files = [f for f in os.listdir(ref_meshes_dir_abs)
177 | if os.path.isfile(os.path.join(ref_meshes_dir_abs, f))]
178 |
179 | if dataset_file_abs is None:
180 | mesh_files_to_compare_set = set(ref_mesh_files) # set for efficient search
181 | else:
182 | if not os.path.isfile(dataset_file_abs):
183 | raise ValueError('File does not exist: {}'.format(dataset_file_abs))
184 | with open(dataset_file_abs) as f:
185 | mesh_files_to_compare_set = f.readlines()
186 | mesh_files_to_compare_set = [f.replace('\n', '') + '.ply' for f in mesh_files_to_compare_set]
187 | mesh_files_to_compare_set = [f.split('.')[0] for f in mesh_files_to_compare_set]
188 | mesh_files_to_compare_set = set(mesh_files_to_compare_set)
189 |
190 | # # skip if everything is unchanged
191 | # new_mesh_files_abs = [os.path.join(new_meshes_dir_abs, f) for f in new_mesh_files]
192 | # ref_mesh_files_abs = [os.path.join(ref_meshes_dir_abs, f) for f in ref_mesh_files]
193 | # if not utils_files.call_necessary(new_mesh_files_abs + ref_mesh_files_abs, report_name):
194 | # return
195 |
196 | def ref_mesh_for_new_mesh(new_mesh_file: str, all_ref_meshes: list) -> list:
197 | stem_new_mesh_file = new_mesh_file.split('.')[0]
198 | ref_files = list(set([f for f in all_ref_meshes if f.split('.')[0] == stem_new_mesh_file]))
199 | return ref_files
200 |
201 | call_params = []
202 | for fi, new_mesh_file in enumerate(new_mesh_files):
203 | if new_mesh_file.split('.')[0] in mesh_files_to_compare_set:
204 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, new_mesh_file)
205 | ref_mesh_files_matching = ref_mesh_for_new_mesh(new_mesh_file, ref_mesh_files)
206 | if len(ref_mesh_files_matching) > 0:
207 | ref_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_mesh_files_matching[0])
208 | call_params.append((new_mesh_file_abs, ref_mesh_file_abs, samples_per_model))
209 | if len(call_params) == 0:
210 | raise ValueError('Results are empty!')
211 | results_hausdorff = start_process_pool(_hausdorff_distance_single_file, call_params, num_processes)
212 | results = [(r[0], r[1], str(r[2]), str(r[3]), str(r[4])) for r in results_hausdorff]
213 |
214 | call_params = []
215 | for fi, new_mesh_file in enumerate(new_mesh_files):
216 | if new_mesh_file.split('.')[0] in mesh_files_to_compare_set:
217 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, new_mesh_file)
218 | ref_mesh_files_matching = ref_mesh_for_new_mesh(new_mesh_file, ref_mesh_files)
219 | if len(ref_mesh_files_matching) > 0:
220 | ref_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_mesh_files_matching[0])
221 | call_params.append((new_mesh_file_abs, ref_mesh_file_abs, samples_per_model))
222 | results_chamfer = start_process_pool(_chamfer_distance_single_file, call_params, num_processes)
223 | results = [r + (str(results_chamfer[ri][2]),) for ri, r in enumerate(results)]
224 |
225 | # no reference but reconstruction
226 | for fi, new_mesh_file in enumerate(new_mesh_files):
227 | if new_mesh_file.split('.')[0] not in mesh_files_to_compare_set:
228 | if dataset_file_abs is None:
229 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, new_mesh_file)
230 | ref_mesh_files_matching = ref_mesh_for_new_mesh(new_mesh_file, ref_mesh_files)
231 | if len(ref_mesh_files_matching) > 0:
232 | reference_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_mesh_files_matching[0])
233 | results.append((new_mesh_file_abs, reference_mesh_file_abs, str(-2), str(-2), str(-2), str(-2)))
234 | else:
235 | mesh_files_to_compare_set.remove(new_mesh_file.split('.')[0])
236 |
237 | # no reconstruction but reference
238 | for ref_without_new_mesh in mesh_files_to_compare_set:
239 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, ref_without_new_mesh)
240 | reference_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_without_new_mesh)
241 | results.append((new_mesh_file_abs, reference_mesh_file_abs, str(-1), str(-1), str(-1), str(-1)))
242 |
243 | # append scale to each row
244 | call_params = []
245 | for fi, row in enumerate(results):
246 | ref_file_abs = row[1]
247 | call_params.append([ref_file_abs])
248 | results_scale = start_process_pool(_scale_single_file, call_params, num_processes)
249 | results = [r + (str(results_scale[ri]),) for ri, r in enumerate(results)]
250 |
251 | # sort by file name
252 | results = sorted(results, key=lambda x: x[0])
253 |
254 | make_dir_for_file(report_name)
255 | csv_lines = ['in mesh,ref mesh,Hausdorff dist new-ref,Hausdorff dist ref-new,Hausdorff dist,'
256 | 'Chamfer dist(-1: no input; -2: no reference),Scale']
257 | csv_lines += [','.join(item) for item in results]
258 | # csv_lines += ['=AVERAGE(E2:E41)']
259 | csv_lines_str = '\n'.join(csv_lines)
260 | with open(report_name, "w") as text_file:
261 | text_file.write(csv_lines_str)
262 |
263 |
264 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
265 | def generate_stats(cfg: DictConfig):
266 | """
267 | Evaluate Hausdorff distance between reconstructed and GT models.
268 |
269 | Parameters
270 | ----------
271 | cfg: DictConfig
272 | Hydra configuration
273 | """
274 |
275 | csv_file = os.path.join(cfg.csv_path)
276 | mesh_comparison(
277 | new_meshes_dir_abs=cfg.remap_dir,
278 | ref_meshes_dir_abs=cfg.reference_dir,
279 | num_processes=cfg.num_workers,
280 | report_name=csv_file,
281 | samples_per_model=cfg.evaluate.num_samples,
282 | dataset_file_abs=os.path.join(cfg.data_dir, 'raw/testset.txt'))
283 |
284 |
285 | if __name__ == '__main__':
286 | generate_stats()
287 |
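To spot-check a single pair outside the report pipeline, _chamfer_distance_single_file can be called directly (paths hypothetical). Note that the returned value is the sum of nearest-neighbour distances over both directions, not a normalised mean.

    from stats import _chamfer_distance_single_file

    file_in, file_ref, chamfer = _chamfer_distance_single_file(
        './outputs/global/building_001.obj',
        './data/reference/building_001.obj',
        samples_per_model=10000)
    print(f'{file_in} vs {file_ref}: chamfer={chamfer}')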
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation of PolyGNN.
3 | """
4 |
5 | import pickle
6 | from pathlib import Path
7 | import multiprocessing
8 |
9 | import numpy as np
10 | import torch
11 | import hydra
12 | from omegaconf import DictConfig
13 | from tqdm import tqdm
14 | from torch.nn.parallel import DistributedDataParallel
15 | from torch_geometric.loader import DataLoader
16 | import torch.multiprocessing as mp
17 | from torchmetrics.classification import BinaryAccuracy
18 | import torch.distributed as dist
19 | from torch_geometric import compile
20 |
21 | from network import PolyGNN
22 | from dataset import CityDataset, TestOnlyDataset
23 | from utils import init_device, Sampler, set_seed, attach_to_log, setup_runner
24 |
25 |
26 | class PredictionSaver:
27 | """
28 | Parallel prediction saving with a multiprocessing pool.
29 | """
30 | def __init__(self, processes):
31 | self.pool = multiprocessing.Pool(processes=processes)
32 |
33 | @staticmethod
34 | def save(args):
35 | pred, name, cfg = args
36 | indices_cells = np.where(pred.cpu().numpy())[0]
37 | if len(indices_cells) > 0:
38 | complex_path = f'{cfg.complex_dir}/{name}.cc'
39 | with open(complex_path, 'rb') as handle:
40 | cell_complex = pickle.load(handle)
41 | output_path = f'{cfg.output_dir}/{name}.npy'
42 | output = np.zeros([cell_complex.num_cells], dtype=int)
43 | output[indices_cells] = 1
44 |
45 | if cfg.evaluate.seal:
46 | cells_boundary = cell_complex.cells_boundary()
47 | output[cells_boundary] = 0
48 | np.save(output_path, output)
49 |
50 |
51 | def run_eval(rank, world_size, dataset_test, cfg):
52 | """
53 | Runner function for distributed inference of PolyGNN.
54 | """
55 | # set up runner
56 | setup_runner(rank, world_size, cfg.master_addr, cfg.master_port)
57 |
58 | # limit number of threads
59 | torch.set_num_threads(cfg.num_workers // world_size)
60 |
61 | # initialize logging
62 | logger = attach_to_log(filepath='./outputs/test.log')
63 |
64 | # indicate device
65 | logger.debug(f"Device activated: " + f"CUDA: {cfg.gpu_ids[rank]}")
66 |
67 | # initialize metric
68 | metric = BinaryAccuracy()
69 |
70 | # split test indices into `world_size` many chunks
71 | eval_indices = torch.arange(len(dataset_test))
72 | eval_indices = eval_indices.split(len(eval_indices) // world_size)[rank]
73 | dataloader_test = DataLoader(dataset_test[eval_indices], batch_size=cfg.batch_size // world_size,
74 | shuffle=cfg.shuffle, num_workers=cfg.num_workers // world_size,
75 | pin_memory=True, prefetch_factor=8)
76 |
77 | # initialize model
78 | model = PolyGNN(cfg)
79 | model.metric = metric
80 | model = model.to(rank)
81 |
82 | # distributed parallelization
83 | model = DistributedDataParallel(model, device_ids=[rank])
84 |
85 | # compile model for better performance
86 | compile(model, dynamic=True, fullgraph=True)
87 |
88 | # load from checkpoint
89 | map_location = f'cuda:{rank}'
90 | if rank == 0:
91 | logger.info(f'Resuming from {cfg.checkpoint_path}')
92 | state = torch.load(cfg.checkpoint_path, map_location=map_location)
93 | state_dict = state['state_dict']
94 | model.load_state_dict(state_dict, strict=False)
95 |
96 | # specify data attributes
97 | if cfg.sample.strategy == 'grid':
98 | points_suffix = f'_{cfg.sample.resolution}'
99 | elif cfg.sample.strategy == 'random':
100 | points_suffix = f'_{cfg.sample.length}'
101 | else:
102 | points_suffix = ''
103 |
104 | # start inference
105 | model.eval()
106 | pbar = tqdm(dataloader_test, desc='eval', disable=rank != 0)
107 |
108 | # initialize PredictionSaver instance
109 | prediction_saver = PredictionSaver(processes=cfg.num_workers // world_size)
110 |
111 | with torch.no_grad():
112 | for batch in pbar:
113 | batch = batch.to(rank, f'points{points_suffix}', f'batch_points{points_suffix}', 'queries',
114 | 'edge_index', 'batch', 'y')
115 | outs = model(batch)
116 | outs = outs.argmax(dim=1)
117 | targets = batch.y
118 |
119 | # metric on current batch
120 | _accuracy = metric(outs, targets)
121 | pbar.set_postfix_str('acc={:.2f}'.format(_accuracy))
122 |
123 | # save prediction as numpy file
124 | if cfg.evaluate.save:
125 | Path(cfg.output_dir).mkdir(exist_ok=True)
126 | _, cells_per_graph = torch.unique(batch.batch, return_counts=True)
127 |
128 | preds = torch.split(outs, split_size_or_sections=cells_per_graph.tolist(), dim=0)
129 | names = batch.name
130 |
131 | # save predictions in parallel across the worker pool
132 | prediction_saver.pool.map(prediction_saver.save, zip(preds, names, [cfg] * len(preds)))
133 |
134 | # metric on all batches and all accelerators using custom accumulation
135 | accuracy = metric.compute()
136 |
137 | if rank == 0:
138 | logger.info(f"Evaluation accuracy: {accuracy}")
139 |
140 | # reset internal state such that metric ready for new data
141 | metric.reset()
142 |
143 | dist.barrier()
144 |
145 | dist.destroy_process_group()
146 |
147 |
148 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
149 | def test(cfg: DictConfig):
150 | """
151 | Test PolyGNN for reconstruction.
152 |
153 | Parameters
154 | ----------
155 | cfg: DictConfig
156 | Hydra configuration
157 | """
158 |
159 | # initialize logger
160 | logger = attach_to_log(filepath='./outputs/test.log')
161 |
162 | # initialize device
163 | init_device(cfg.gpu_ids, register_freeze=cfg.gpu_freeze)
164 | logger.info(f"Device initialized: " + f"CUDA: {cfg.gpu_ids}")
165 |
166 | # fix randomness
167 | set_seed(cfg.seed)
168 | logger.info(f"Random seed set to {cfg.seed}")
169 |
170 | # initialize data sampler
171 | sampler = Sampler(strategy=cfg.sample.strategy, length=cfg.sample.length, ratio=cfg.sample.ratio,
172 | resolutions=cfg.sample.resolutions, duplicate=cfg.sample.duplicate, seed=cfg.seed)
173 | transform = sampler.sample if cfg.sample.transform else None
174 | pre_transform = sampler.sample if cfg.sample.pre_transform else None
175 |
176 | # initialize dataset
177 | if cfg.dataset in {'munich', 'munich_perturb', 'munich_subsample', 'munich_truncate', 'munich_haswall', 'campus_ldbv'}:
178 | dataset = TestOnlyDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir,
179 | split='test', num_workers=cfg.num_workers)
180 | else:
181 | dataset = CityDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir,
182 | split='test', num_workers=cfg.num_workers)
183 |
184 | world_size = len(cfg.gpu_ids)
185 | mp.spawn(run_eval, args=(world_size, dataset, cfg), nprocs=world_size, join=True)
186 |
187 |
188 | if __name__ == '__main__':
189 | test()
190 |
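The per-building split of the batched predictions relies on the torch.unique/torch.split pair above; a toy example with made-up assignments:

    import torch

    outs = torch.tensor([1, 0, 1, 1, 0, 1])     # argmax over all cells in the batch
    batch = torch.tensor([0, 0, 0, 1, 1, 1])    # cell-to-building assignment
    _, counts = torch.unique(batch, return_counts=True)
    preds = torch.split(outs, counts.tolist())  # one occupancy vector per building
    # preds == (tensor([1, 0, 1]), tensor([0, 1, 1]))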
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Supervised training of PolyGNN.
3 | """
4 |
5 | import os
6 | from pathlib import Path
7 |
8 | import wandb
9 | import hydra
10 | from omegaconf import DictConfig
11 | from tqdm import tqdm
12 | import torch
13 | import torch.multiprocessing as mp
14 | import torch.distributed as dist
15 | from torch.nn.parallel import DistributedDataParallel
16 | from torch_geometric.loader import DataLoader
17 | from torch_geometric import compile
18 | from torchmetrics.classification import BinaryAccuracy
19 |
20 | from network import PolyGNN, focal_loss, bce_loss
21 | from dataset import CityDataset
22 | from utils import init_device, Sampler, set_seed, attach_to_log, setup_runner
23 |
24 |
25 | def run_train(rank, world_size, dataset_train, dataset_test, cfg):
26 | """
27 | Runner function for distributed training of PolyGNN.
28 | """
29 | # set up runner
30 | setup_runner(rank, world_size, cfg.master_addr, cfg.master_port)
31 |
32 | # limit number of threads
33 | torch.set_num_threads(cfg.num_workers // world_size)
34 |
35 | # initialize logging
36 | logger = attach_to_log(filepath='./outputs/train.log')
37 | if rank == 0:
38 | logger.info(f'Training PolyGNN on {cfg.dataset}')
39 | wandb_mode = 'online' if cfg.wandb else 'disabled'
40 | wandb.init(mode=wandb_mode, project='polygnn', entity='zhaiyu',
41 | name=cfg.dataset+cfg.run_suffix, dir=cfg.wandb_dir)
42 | wandb.save('./outputs/.hydra/*')
43 |
44 | # indicate device
45 | logger.debug(f"Device activated: " + f"CUDA: {cfg.gpu_ids[rank]}")
46 |
47 | # split training indices into `world_size` many chunks
48 | train_indices = torch.arange(len(dataset_train))
49 | train_indices = train_indices.split(len(train_indices) // world_size)[rank]
50 | eval_indices = torch.arange(len(dataset_test))
51 | eval_indices = eval_indices.split(len(eval_indices) // world_size)[rank]
52 |
53 | # setup dataloaders
54 | dataloader_train = DataLoader(dataset_train[train_indices], batch_size=cfg.batch_size // world_size,
55 | shuffle=cfg.shuffle, num_workers=cfg.num_workers // world_size,
56 | pin_memory=True, prefetch_factor=8)
57 | dataloader_test = DataLoader(dataset_test[eval_indices], batch_size=cfg.batch_size // world_size,
58 | shuffle=cfg.shuffle, num_workers=cfg.num_workers // world_size,
59 | pin_memory=True, prefetch_factor=8)
60 |
61 | # initialize model
62 | model = PolyGNN(cfg).to(rank)
63 |
64 | # distributed parallelization
65 | model = DistributedDataParallel(model, device_ids=[rank])
66 |
67 | # compile model for better performance
68 | compile(model, dynamic=True, fullgraph=True)
69 |
70 | # freeze certain layers for fine-tuning
71 | if cfg.warm:
72 | for stage in cfg.freeze_stages:
73 | logger.info(f'Freezing stage: {stage}')
74 | for parameter in getattr(model, stage).parameters():
75 | parameter.requires_grad = False
76 |
77 | # initialize optimizer and scheduler
78 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
79 | scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg.scheduler.base_lr, max_lr=cfg.scheduler.max_lr,
80 | step_size_up=cfg.scheduler.step_size_up, mode=cfg.scheduler.mode,
81 | cycle_momentum=False)
82 | # initialize metrics
83 | metric = BinaryAccuracy().to(rank)
84 |
85 | # warm start from checkpoint if available
86 | if cfg.warm:
87 | if rank == 0:
88 | logger.info(f'Resuming from {cfg.checkpoint_path}')
89 | map_location = f'cuda:{rank}'
90 | state = torch.load(cfg.checkpoint_path, map_location=map_location)
91 | state_dict = state['state_dict']
92 | model.load_state_dict(state_dict, strict=False)
93 | if cfg.warm_optimizer:
94 | try:
95 | optimizer.load_state_dict(state['optimizer'])
96 | if rank == 0:
97 | logger.info('Optimizer loaded from checkpoint')
98 | except (KeyError, ValueError) as error:
99 | if rank == 0:
100 | logger.warning(f'Optimizer not loaded from checkpoint: {error}')
101 |
102 | if cfg.warm_scheduler:
103 | try:
104 | scheduler.load_state_dict(state['scheduler'])
105 | if rank == 0:
106 | logger.info('Scheduler loaded from checkpoint')
107 | except (KeyError, ValueError) as error:
108 | if rank == 0:
109 | logger.warning(f'Scheduler not loaded from checkpoint: {error}')
110 |
111 | best_accuracy = state['accuracy']
112 | if state['epoch'] > cfg.num_epochs:
113 | if rank == 0:
114 | logger.info('Checkpoint epoch already beyond num_epochs; nothing left to train')
115 | return
116 | epoch_generator = range(state['epoch'] + 1, cfg.num_epochs)
117 | else:
118 | best_accuracy = 0
119 | epoch_generator = range(cfg.num_epochs)
120 |
121 | # initialize loss function
122 | if cfg.loss == 'focal':
123 | loss_func = focal_loss
124 | elif cfg.loss == 'bce':
125 | loss_func = bce_loss
126 | else:
127 | raise ValueError(f'Unexpected loss function: {cfg.loss}')
128 |
129 | # specify data attributes
130 | if cfg.sample.strategy == 'grid':
131 | points_suffix = f'_{cfg.sample.resolution}'
132 | elif cfg.sample.strategy == 'random':
133 | points_suffix = f'_{cfg.sample.length}'
134 | else:
135 | points_suffix = ''
136 |
137 | # start training
138 | for i in epoch_generator:
139 | model.train()
140 | pbar = tqdm(dataloader_train, desc=f'epoch {i}', disable=rank != 0)
141 |
142 | if rank == 0:
143 | wandb.log({"epoch": i})
144 |
145 | for batch in pbar:
146 | optimizer.zero_grad()
147 | batch = batch.to(rank, f'points{points_suffix}', f'batch_points{points_suffix}', 'queries', 'edge_index',
148 | 'batch', 'y')
149 | outs = model(batch)
150 | targets = batch.y
151 | loss, accuracy, ratio, _, _ = loss_func(outs, targets)
152 |
153 | if rank == 0:
154 | wandb.log({"loss": loss})
155 | wandb.log({"train_accuracy": accuracy})
156 | wandb.log({"ratio:": ratio})
157 | wandb.log({"learning_rate": optimizer.param_groups[0]['lr']})
158 | pbar.set_postfix_str('loss={:.2f}, acc={:.2f}, ratio={:.2f}'.format(loss, accuracy, ratio))
159 |
160 | loss.backward()
161 | optimizer.step()
162 | scheduler.step()
163 |
164 | dist.barrier()
165 |
166 | # validate and save checkpoint with DDP
167 | if cfg.validate and i % cfg.save_interval == 0:
168 | model.metric = metric
169 | model = model.to(rank)
170 | model.eval()
171 |
172 | pbar = tqdm(dataloader_test, desc='eval', disable=rank != 0)
173 | with torch.no_grad():
174 | for batch in pbar:
175 | batch = batch.to(rank, f'points{points_suffix}', f'batch_points{points_suffix}', 'queries',
176 | 'edge_index', 'batch', 'y')
177 | outs = model(batch)
178 | outs = outs.argmax(dim=1)
179 | targets = batch.y
180 |
181 | # metric on current batch
182 | accuracy = metric(outs, targets)
183 | if rank == 0:
184 | pbar.set_postfix_str('acc={:.2f}'.format(accuracy))
185 |
186 | # metrics on all batches and all accelerators using custom accumulation
187 | accuracy = metric.compute()
188 |
189 | dist.barrier()
190 |
191 | if rank == 0:
192 | logger.info(f'Evaluation accuracy: {accuracy:.4f}')
193 | wandb.log({"eval_accuracy": accuracy})
194 | checkpoint_path = os.path.join(cfg.checkpoint_dir, f'model_epoch{i}.pth')
195 | logger.info(f'Saving checkpoint to {checkpoint_path}.')
196 | Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
197 | state = {
198 | 'epoch': i,
199 | 'state_dict': model.state_dict(),
200 | 'optimizer': optimizer.state_dict(),
201 | 'scheduler': scheduler.state_dict(),
202 | 'accuracy': accuracy,
203 | }
204 | # Cannot pickle 'WeakMethod' object when saving state_dict for CyclicLr
205 | # https://github.com/pytorch/pytorch/pull/91400
206 | torch.save(state, checkpoint_path)
207 | if accuracy > best_accuracy:
208 | logger.info(f'Saving checkpoint to {cfg.checkpoint_path}.')
209 | torch.save(state, cfg.checkpoint_path)
210 | wandb.save(cfg.checkpoint_path)
211 | best_accuracy = accuracy
212 |
213 | # reset internal state such that metric ready for new data
214 | metric.reset()
215 |
216 | dist.barrier()
217 |
218 | dist.destroy_process_group()
219 |
220 |
221 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
222 | def train(cfg: DictConfig):
223 | """
224 | Train PolyGNN.
225 |
226 | Parameters
227 | ----------
228 | cfg: DictConfig
229 | Hydra configuration
230 | """
231 | logger = attach_to_log(filepath='./outputs/train.log')
232 |
233 | # initialize device
234 | init_device(cfg.gpu_ids, register_freeze=cfg.gpu_freeze)
235 | logger.info(f"Device initialized: " + f"CUDA: {cfg.gpu_ids}")
236 |
237 | # fix randomness
238 | set_seed(cfg.seed)
239 | logger.info(f"Random seed set to {cfg.seed}")
240 |
241 | # initialize data sampler
242 | sampler = Sampler(strategy=cfg.sample.strategy, length=cfg.sample.length, ratio=cfg.sample.ratio,
243 | resolutions=cfg.sample.resolutions, duplicate=cfg.sample.duplicate, seed=cfg.seed)
244 | transform = sampler.sample if cfg.sample.transform else None
245 | pre_transform = sampler.sample if cfg.sample.pre_transform else None
246 |
247 | # initialize dataset
248 | dataset_train = CityDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir,
249 | split='train', num_workers=cfg.num_workers, num_queries=cfg.num_queries)
250 | dataset_test = CityDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir,
251 | split='test', num_workers=cfg.num_workers, num_queries=cfg.num_queries)
252 |
253 | world_size = len(cfg.gpu_ids)
254 | mp.spawn(run_train, args=(world_size, dataset_train, dataset_test, cfg), nprocs=world_size, join=True)
255 |
256 |
257 | if __name__ == '__main__':
258 | train()
259 |
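Checkpoints store the dictionary assembled above; a hedged sketch for inspecting one offline (the path is hypothetical):

    import torch

    state = torch.load('./checkpoints/munich_best.pth', map_location='cpu')
    print(state['epoch'], state['accuracy'])
    weights = state['state_dict']  # keys carry the DDP 'module.' prefix,
                                   # hence strict=False when resuming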
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions.
3 | """
4 |
5 | import os
6 | import time
7 | import glob
8 | from pathlib import Path
9 | import atexit
10 | from itertools import repeat
11 | import math
12 | import logging
13 | import random
14 | import pickle
15 | import csv
16 | import multiprocessing
17 |
18 | from tqdm import tqdm
19 | import numpy as np
20 | import torch
21 | import torch.distributed as dist
22 | import hydra
23 | from omegaconf import DictConfig
24 | import torch_geometric as pyg
25 | import trimesh
26 | from plyfile import PlyData
27 |
28 | from abspy import VertexGroup, CellComplex
29 |
30 |
31 | def setup_runner(rank, world_size, master_addr, master_port):
32 | """
33 | Set up runner for distributed parallelization.
34 | """
35 | # initialize torch.distributed
36 | os.environ['MASTER_ADDR'] = str(master_addr)
37 | os.environ['MASTER_PORT'] = str(master_port)
38 | dist.init_process_group('nccl', rank=rank, world_size=world_size)
39 |
40 |
41 | def attach_to_log(level=logging.INFO,
42 | filepath=None,
43 | colors=True,
44 | capture_warnings=True):
45 | """
46 | Attach a stream handler to all loggers.
47 |
48 | Parameters
49 | ------------
50 | level : enum (int)
51 | Logging level, like logging.INFO
52 | colors : bool
53 | If True try to use colorlog formatter
54 | capture_warnings: bool
55 | If True capture warnings
56 | filepath: None or str
57 | path to save the logfile
58 |
59 | Returns
60 | -------
61 | logger: Logger object
62 | Logger attached with a stream handler
63 | """
64 |
65 | # make sure we log warnings from the warnings module
66 | logging.captureWarnings(capture_warnings)
67 |
68 | # create a basic formatter
69 | formatter_file = logging.Formatter(
70 | "[%(asctime)s] %(levelname)-7s (%(filename)s:%(lineno)3s) %(message)s",
71 | "%Y-%m-%d %H:%M:%S")
72 | if colors:
73 | try:
74 | from colorlog import ColoredFormatter
75 | formatter_stream = ColoredFormatter(
76 | ("%(log_color)s%(levelname)-8s%(reset)s " +
77 | "%(filename)17s:%(lineno)-4s %(blue)4s%(message)s"),
78 | datefmt=None,
79 | reset=True,
80 | log_colors={'DEBUG': 'cyan',
81 | 'INFO': 'green',
82 | 'WARNING': 'yellow',
83 | 'ERROR': 'red',
84 | 'CRITICAL': 'red'})
85 | except ImportError:
86 | formatter_stream = formatter_file
87 | else:
88 | formatter_stream = formatter_file
89 |
90 | # if no handler was passed use a StreamHandler
91 | logger = logging.getLogger()
92 | logger.setLevel(level)
93 |
94 | if not any([isinstance(handler, logging.StreamHandler) for handler in logger.handlers]):
95 | stream_handler = logging.StreamHandler()
96 | stream_handler.setFormatter(formatter_stream)
97 | logger.addHandler(stream_handler)
98 |
99 | if filepath and not any([isinstance(handler, logging.FileHandler) for handler in logger.handlers]):
100 | file_handler = logging.FileHandler(filepath)
101 | file_handler.setFormatter(formatter_file)
102 | logger.addHandler(file_handler)
103 |
104 | # set nicer numpy print options
105 | np.set_printoptions(precision=5, suppress=True)
106 |
107 | return logger
108 |
109 |
110 | def edge_index_from_dict(graph_dict):
111 | """
112 | Convert adjacency dict to edge index.
113 |
114 | Parameters
115 | ----------
116 | graph_dict: dict
117 | Adjacency dict
118 |
119 | Returns
120 | -------
121 | edge_index: torch.Tensor
122 | Edge index
123 | """
124 | row, col = [], []
125 | for key, value in graph_dict.items():
126 | row += repeat(key, len(value))
127 | col += value
128 | edge_index = torch.tensor([row, col], dtype=torch.long)
129 | return edge_index
130 |
131 |
132 | def index_to_mask(index, size):
133 | """
134 | Convert index to binary mask.
135 |
136 | Parameters
137 | ----------
138 | index: range Object
139 | Index of 1s
140 | size: int
141 | Size of mask
142 |
143 | Returns
144 | -------
145 | mask: torch.Tensor
146 | Binary mask
147 | """
148 | mask = torch.zeros((size,), dtype=torch.bool)
149 | mask[index] = 1
150 | return mask
151 |
152 |
153 | def freeze_vram(cuda_devices, timeout=500):
154 | """
155 | Freeze VRAM for a short time at program exit. For debugging.
156 |
157 | Parameters
158 | ----------
159 | cuda_devices: list of int
160 | Indices of CUDA devices
161 | timeout: int
162 | Timeout seconds
163 | """
164 | torch.cuda.empty_cache()
165 | devices_info = os.popen(
166 | '"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader') \
167 | .read().strip().split("\n")
168 | for i, device in enumerate(cuda_devices):
169 | total, used = devices_info[int(device)].split(',')
170 | total = int(total)
171 | used = int(used)
172 | max_mem = int(total * 0.90)
173 | block_mem = max_mem - used
174 | if block_mem > 0:
175 | x = torch.FloatTensor(256, 1024, block_mem).to(torch.device(f'cuda:{i}'))
176 | del x
177 | for _ in tqdm(range(timeout), desc='VRAM freezing'):
178 | time.sleep(1)
179 |
180 |
181 | def init_device(gpu_ids, register_freeze=False):
182 | """
183 | Init devices.
184 |
185 | Parameters
186 | ----------
187 | gpu_ids: list of int
188 | GPU indices to use
189 | register_freeze: bool
190 | Register GPU memory freeze if set True
191 | """
192 | # set multiprocessing sharing strategy
193 | torch.multiprocessing.set_sharing_strategy('file_system')
194 |
195 | # setting these after importing torch has no effect for DP with PyTorch 2.0, but DDP still works
196 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
197 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids)[1:-1]
198 |
199 | # raise soft limit from 1024 to 4096 for open files to address RuntimeError: received 0 items of ancdata
200 | import resource
201 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
202 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
203 |
204 | if register_freeze:
205 | atexit.register(freeze_vram, gpu_ids)
206 |
207 |
208 | class Sampler:
209 | """
210 | Sampler to sample points from a point cloud.
211 | """
212 |
213 | def __init__(self, strategy, length, ratio, resolutions, duplicate, seed=None):
214 | self.strategy = strategy
215 | self.length = length
216 | self.ratio = ratio
217 | self.resolutions = resolutions
218 | self.duplicate = duplicate
219 |
220 | # seed once in initialization
221 | self.seed = seed
222 |
223 | def sample(self, data):
224 | with torch.no_grad():
225 | if self.seed is not None:
226 | set_seed(self.seed)
227 | if self.strategy is None:
228 | return data
229 | if self.strategy == 'fps':
230 | return self.farthest_sample(data)
231 | elif self.strategy == 'random':
232 | return self.random_sample(data)
233 | elif self.strategy == 'grid':
234 | return self.grid_sample(data)
235 | else:
236 | raise ValueError(f'Unexpected sampling strategy={self.strategy}.')
237 |
238 | def random_sample(self, data):
239 | """
240 | Random uniform sampling.
241 | """
242 | # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/fixed_points.html#FixedPoints
243 | if not self.duplicate:
244 | choice = torch.randperm(data.num_points)[:self.length]
245 | else:
246 | choice = torch.cat(
247 | [torch.randperm(data.num_points) for _ in range(math.ceil(self.length / data.num_points))],
248 | dim=0)[:self.length]
249 | data[f'batch_points_{self.length}'] = data.batch_points[choice]
250 | data[f'points_{self.length}'] = data.points[choice]
251 | return data
252 |
253 | def grid_sample(self, data):
254 | """
255 | Sampling points into fixed-sized voxels.
256 | Each cluster returned is the cluster barycenter.
257 | """
258 | # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/grid_sampling.html#GridSampling
259 | for size in self.resolutions:
260 | c = pyg.nn.voxel_grid(data.points, size, data.batch_points, None, None)
261 | _, perm = pyg.nn.pool.consecutive.consecutive_cluster(c)
262 | data[f'batch_points_{size}'] = data.batch_points[perm]
263 | data[f'points_{size}'] = data.points[perm]
264 | return data
265 |
266 | def farthest_sample(self, data):
267 | """
268 | Farthest-point sampling, which iteratively picks the point most distant from those already selected. In-place.
269 | """
270 | # https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.fps
271 | perm = pyg.nn.pool.fps(data.points, data.batch_points, ratio=self.ratio)
272 | data['batch_points_fps'] = data.batch_points[perm]
273 | data['points_fps'] = data.points[perm]
274 | return data
275 |
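# Usage sketch (illustrative only; the resolutions are hypothetical): grid sampling
# adds the resolution-suffixed attributes that PolyGNN reads via points_suffix.
#   sampler = Sampler(strategy='grid', length=None, ratio=None,
#                     resolutions=[0.01, 0.02], duplicate=False, seed=42)
#   data = sampler.sample(data)  # adds points_0.01 / batch_points_0.01, etc.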
276 |
277 | def reverse_translation_and_scale(mesh):
278 | """
279 | Translation and scale for reverse normalisation of mesh.
280 | """
281 | bounds = mesh.extents
282 | if bounds.min() == 0.0:
283 | return
284 |
285 | # translate to origin
286 | translation = - (mesh.bounds[0] + mesh.bounds[1]) * 0.5
287 | translation = trimesh.transformations.translation_matrix(direction=-translation)
288 |
289 | # scale to unit cube
290 | scale = bounds.max()
291 | scale_trafo = trimesh.transformations.scale_matrix(factor=scale)
292 |
293 | return translation, scale_trafo
294 |
295 |
296 | def normalise_mesh(args):
297 | """
298 | Normalize mesh (or point cloud).
299 | First translation, then scaling, if any.
300 |
301 | Parameters
302 | ----------
303 | args[0]: input_path: str or Path
304 | Path to input mesh (to be normalised)
305 | args[1]: reference_path: str or Path
306 | Path to reference mesh (used to determine the transformation)
307 | args[2]: output_path: str or Path
308 | Path to output mesh (normalised)
309 | args[3]: force: str
310 | Force loading type ('mesh' or 'scene')
311 | args[4]: offset: list
312 | Coordinate offset (x, y, z)
313 | args[5]: scaling: bool
314 | Switch on scaling if set True
315 | args[6]: translation: bool
316 | Switch on translation if set True
317 | """
318 | input_path, reference_path, output_path, force, offset, is_scaling, is_translation = args
319 |
320 | reference_mesh = trimesh.load(reference_path)
321 | translation, scale_trafo = reverse_translation_and_scale(reference_mesh)
322 | if offset is not None:
323 | translation[0][-1] = translation[0][-1] + offset[0]
324 | translation[1][-1] = translation[1][-1] + offset[1]
325 | translation[2][-1] = translation[2][-1] + offset[2]
326 |
327 | # trimesh built-in transform would result in an issue of missing triangles
328 | with open(input_path, 'r') as fin:
329 | lines = fin.readlines()
330 | lines_ = []
331 | for line in lines:
332 | if line.startswith('v '):  # vertex positions only, not 'vt'/'vn'
333 | vertices = np.array(line.split()[1:], dtype=float)
334 | if is_translation is True:
335 | vertices = vertices - translation[:3, 3]
336 | if is_scaling is True:
337 | vertices = vertices / scale_trafo[0][0]
338 | line_ = f'v {vertices[0]} {vertices[1]} {vertices[2]}\n'
339 | else:
340 | line_ = line
341 | lines_.append(line_)
342 | with open(output_path, 'w') as fout:
343 | fout.writelines(lines_)
344 |
345 |
346 | def reverse_normalise_mesh(args):
347 | """
348 | Reverse normalisation for reconstructed mesh.
349 | First scaling, then translation, if any.
350 |
351 | Parameters
352 | ----------
353 | args[0]: input_path: str or Path
354 | Path to input mesh (results from reconstruction, normalised)
355 | args[1]: reference_path: str or Path
356 | Path to reference mesh (used to determine the transformation)
357 | args[2]: output_path: str or Path
358 | Path to output mesh (final reconstruction, with reversed normalisation)
359 | args[3]: force: str
360 | Force loading type ('mesh' or 'scene')
361 | args[4]: offset: list
362 | Coordinate offset (x, y, z)
363 | args[5]: scaling: bool
364 | Switch on scaling if set True
365 | args[6]: translation: bool
366 | Switch on translation if set True
367 | """
368 | input_path, reference_path, output_path, force, offset, is_scaling, is_translation = args
369 |
370 | reference_mesh = trimesh.load(reference_path)
371 | translation, scale_trafo = reverse_translation_and_scale(reference_mesh)
372 | if offset is not None:
373 | translation[0][-1] = translation[0][-1] + offset[0]
374 | translation[1][-1] = translation[1][-1] + offset[1]
375 | translation[2][-1] = translation[2][-1] + offset[2]
376 |
377 | # trimesh built-in transform would result in an issue of missing triangles
378 | with open(input_path, 'r') as fin:
379 | lines = fin.readlines()
380 | lines_ = []
381 | for line in lines:
382 | if line.startswith('v '):  # vertex positions only, not 'vt'/'vn'
383 | vertices = np.array(line.split()[1:], dtype=float)
384 | if is_scaling is True:
385 | vertices = vertices * scale_trafo[0][0]
386 | if is_translation is True:
387 | vertices = vertices + translation[:3, 3]
388 | line_ = f'v {vertices[0]} {vertices[1]} {vertices[2]}\n'
389 | else:
390 | line_ = line
391 | lines_.append(line_)
392 | with open(output_path, 'w') as fout:
393 | fout.writelines(lines_)
394 |
395 |
396 | def reverse_normalise_cloud(args):
397 | """
398 | Reverse normalisation for normalised point cloud.
399 |
400 | Parameters
401 | ----------
402 | args[0]: input_path: str or Path
403 | Path to input point cloud (normalised)
404 | args[1]: reference_path: str or Path
405 | Path to reference mesh (used to determine the transformation)
406 | args[2]: output_path: str or Path
407 | Path to output point cloud (with reversed normalisation)
408 | args[3]: force: str
409 | Force loading type ('mesh' or 'scene')
410 | args[4]: offset: list
411 | Coordinate offset (x, y, z)
412 | """
413 | input_path, reference_path, output_path, force, offset = args
414 | plydata = PlyData.read(input_path)['vertex']
415 | points = np.array([plydata['x'], plydata['y'], plydata['z']]).T
416 | cloud = trimesh.PointCloud(points)
417 | reference_mesh = trimesh.load(reference_path)
418 | translation, scale_trafo = reverse_translation_and_scale(reference_mesh)
419 | if offset is not None:
420 | translation[0][-1] = translation[0][-1] + offset[0]
421 | translation[1][-1] = translation[1][-1] + offset[1]
422 | translation[2][-1] = translation[2][-1] + offset[2]
423 |
424 | cloud.apply_transform(scale_trafo)
425 | cloud.apply_transform(translation)
426 | cloud.export(str(output_path))
427 |
428 |
429 | def coerce(data):
430 | """
431 | Coercion for legacy data.
432 | """
433 | data.points = torch.as_tensor(data.points, dtype=torch.float)
434 | data.queries = torch.as_tensor(data.queries, dtype=torch.float)
435 |
436 | if not hasattr(data, 'num_points'):
437 | data.num_points = len(data.points)
438 | if not hasattr(data, 'batch_points'):
439 | data.batch_points = torch.zeros(len(data.points), dtype=torch.long)
440 | return data
441 |
442 |
443 | def set_seed(seed: int) -> None:
444 | """
445 | Set a single global seed to fix randomness.
446 | May need to be repeatedly invoked (at least for np.random).
447 | """
448 | np.random.seed(seed)
449 | random.seed(seed)
450 | torch.manual_seed(seed)
451 | torch.cuda.manual_seed(seed)
452 | # When running on the CuDNN backend, two further options must be set
453 | torch.backends.cudnn.deterministic = True
454 | torch.backends.cudnn.benchmark = False
455 | # Set a fixed value for the hash seed
456 | os.environ["PYTHONHASHSEED"] = str(seed)
457 |
458 |
459 | def append_labels(args):
460 | """
461 | Append occupancy labels to existing cell complex file.
462 |
463 | Parameters
464 | ----------
465 | args[0]: input_cc: str or Path
466 | Path to input CC file
467 | args[1]: input_manifold: str or Path
468 | Path to input manifold mesh file
469 | args[2]: output_cc: str or Path
470 | Path to output CC file
471 | """
472 | with open(args[0], 'rb') as handle:
473 | cell_complex = pickle.load(handle)
474 | cells_in_mesh = cell_complex.cells_in_mesh(args[1])
475 | # binary occupancy indicators
476 | labels = [0] * cell_complex.num_cells
477 | for i in cells_in_mesh:
478 | labels[i] = 1
479 | cell_complex.labels = labels
480 | cell_complex.save(args[2])
481 |
482 |
483 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
484 | def multi_append_labels(cfg: DictConfig):
485 | # initialize logging
486 | logger = logging.getLogger('Labels')
487 |
488 | filenames = glob.glob('data/munich_16star/raw/05_complexes/*.cc')
489 | args = []
490 | for filename_input in filenames:
491 | stem = Path(filename_input).stem
492 | filename_output = Path(filename_input).with_suffix('.cc.new')
493 | filename_manifold = Path(filename_input).parent.parent / '03_meshes_manifold' / (stem + '.obj')
494 | if not filename_output.exists():
495 | args.append((filename_input, filename_manifold, filename_output))
496 |
497 | logger.info('Start complex labeling')
498 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool:
499 | # call with multiprocessing
500 | for _ in tqdm(pool.imap_unordered(append_labels, args), desc='Appending labels', total=len(args)):
501 | pass
502 |
503 | # exit with a message
504 | logger.info('Done complex labeling')
505 |
506 |
507 | def append_samples(args):
508 | """
509 | Append multi-grid-size samples to existing data file.
510 |
511 | Parameters
512 | ----------
513 | args[0]: input_path: str or Path
514 | Path to input torch file
515 | args[1]: output_path: str or Path
516 | Path to output torch file
517 | args[2]: sample_func: func
518 | Function to sample
519 | """
520 | data = torch.load(args[0])
521 | data = coerce(data)
522 | data = args[2](data)
523 | torch.save(data, args[1])
524 |
525 |
526 | def append_queries(args):
527 | """
528 | Append queries to existing data files.
529 |
530 | Parameters
531 | ----------
532 | args[0]: input_path: str or Path
533 | Path to input torch file
534 | args[1]: complex_dir: str or Path
535 | Dir to complexes
536 | args[2]: output_path: str or Path
537 | Path to output torch file
538 | """
539 | data = torch.load(args[0])
540 | with open(os.path.join(args[1], data.name + '.cc'), 'rb') as handle:
541 | cell_complex = pickle.load(handle)
542 |
543 | queries_random = np.array(cell_complex.cell_representatives(location='random_t', num=16))
544 | queries_boundary = np.array(cell_complex.cell_representatives(location='boundary', num=16))
545 | queries_skeleton = np.array(cell_complex.cell_representatives(location='skeleton', num=16))
546 | data.queries_random = torch.as_tensor(queries_random, dtype=torch.float)
547 | data.queries_boundary = torch.as_tensor(queries_boundary, dtype=torch.float)
548 | data.queries_skeleton = torch.as_tensor(queries_skeleton, dtype=torch.float)
549 |
550 | torch.save(data, args[2])
551 |
552 |
553 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
554 | def multi_append_queries(cfg: DictConfig):
555 | # initialize logging
556 | logger = logging.getLogger('Querying')
557 |
558 | filenames = glob.glob(f'{cfg.data_dir}/processed/*[0-9].pt')
559 | args = []
560 | for filename_input in filenames:
561 | filename_output = Path(filename_input).with_suffix('.pt.queries_appended')
562 | if not filename_output.exists():
563 | args.append((filename_input, cfg.complex_dir, filename_output))
564 |
565 | logger.info('Start polyhedra sampling')
566 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool:
567 | # call with multiprocessing
568 | for _ in tqdm(pool.imap_unordered(append_queries, args), desc='Appending queries', total=len(args)):
569 | pass
570 |
571 | # exit with a message
572 | logger.info('Done polyhedra sampling')
573 |
574 |
575 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
576 | def multi_append_samples(cfg: DictConfig):
577 | # initialize logging
578 | logger = logging.getLogger('Sampling')
579 |
580 | filenames = glob.glob(f'{cfg.data_dir}/processed/*[0-9].pt')
581 | sampler = Sampler(strategy=cfg.sample.strategy, length=cfg.sample.length, ratio=cfg.sample.ratio,
582 | resolutions=cfg.sample.resolutions, duplicate=cfg.sample.duplicate, seed=cfg.seed)
583 | args = []
584 | for filename_input in filenames:
585 | filename_output = Path(filename_input).with_suffix('.pt.samples_appended')
586 | if not filename_output.exists():
587 | args.append((filename_input, filename_output, sampler.sample))
588 |
589 | logger.info('Start point cloud sampling')
590 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool:
591 | # call with multiprocessing
592 | for _ in tqdm(pool.imap_unordered(append_samples, args), desc='Appending samples', total=len(args)):
593 | pass
594 |
595 | # exit with a message
596 | logger.info('Done point cloud sampling')
597 |
598 |
599 | def count_facets(mesh_path):
600 | """
601 | Count the number of facets given a mesh.
602 | """
603 | mesh = trimesh.load(mesh_path)
604 | faces_extracted = np.concatenate(mesh.facets)
605 | faces_left = np.setdiff1d(np.arange(len(mesh.faces)), faces_extracted)
606 | num_facets = len(mesh.facets) + len(faces_left)
607 | return num_facets
608 |
609 |
610 | def dict_count_facets(mesh_dir):
611 | """
612 | Count the number of facets given a directory of meshes.
613 | """
614 | filenames = glob.glob(f'{mesh_dir}/*.obj')
615 | facet_dict = {}
616 | for filename_input in filenames:
617 | stem = Path(filename_input).stem
618 | num_facets = count_facets(filename_input)
619 | facet_dict[stem] = num_facets
620 |
621 | # sorted by facet number
622 | print({k: v for k, v in sorted(facet_dict.items(), key=lambda item: item[1])})
623 |
624 |
625 | def append_scale_to_csv(input_csv, output_csv):
626 | """
627 | Append scale into an existing csv file.
628 | Note that scale has been implemented in stats.py.
629 | """
630 | rows = []
631 | with open(input_csv, 'r', newline='') as input_csvfile:
632 | reader = csv.reader(input_csvfile)
633 | header = next(reader) + ['Scale']  # keep the header; name the appended column
634 | for r in reader:
635 | filename_input = r[1]
636 | row = r
637 | if not filename_input.endswith('.obj'):
638 | filename_input = filename_input + '.obj'
639 | mesh = trimesh.load(filename_input)
640 | extents = mesh.extents
641 | scale = extents.max()
642 | row.append(scale)
643 | rows.append(row)
644 |
645 | with open(output_csv, 'w') as output_csvfile:
646 | writer = csv.writer(output_csvfile, lineterminator='\n')
647 | writer.writerows([header] + rows)
648 |
649 |
650 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2')
651 | def calculate(cfg):
652 | """
653 | Calculate number of parameters.
654 | """
655 | from network import PolyGNN
656 |
657 | # initialize model
658 | model = PolyGNN(cfg)
659 |
660 | # calculate params
661 | total_params = sum(
662 | param.numel() for param in model.parameters()
663 | )
664 |
665 | trainable_params = sum(
666 | p.numel() for p in model.parameters() if p.requires_grad
667 | )
668 |
669 | print(f'total_params: {total_params}')
670 | print(f'trainable_params: {trainable_params}')
671 |
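# Editor's sketch: the same numel() bookkeeping on a toy module, to make the
# total/trainable distinction concrete (assumes torch.nn is importable).
def _example_param_count():
    import torch.nn as nn
    toy = nn.Linear(16, 4)          # 16 * 4 weights + 4 biases = 68 parameters
    toy.bias.requires_grad = False  # freeze the bias
    total = sum(p.numel() for p in toy.parameters())                         # 68
    trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)  # 64
    return total, trainable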
672 |
673 | def merge_vg2cc(args):
674 | """
675 |     Merge vertex groups from RANSAC and from City3D, then generate cell complexes.
676 | """
677 | vg_ransac_path, vg_city3d_path, cc_output_path = args
678 | epsilon = 0.0001
679 | vertex_group_ransac = VertexGroup(filepath=vg_ransac_path, refit=True, global_group=False, quiet=True)
680 | vertex_group_city3d = VertexGroup(filepath=vg_city3d_path, refit=False, global_group=True, quiet=True)
681 |     additional_planes = [p for p in vertex_group_city3d.planes if -epsilon < p[2] < epsilon or  # near-vertical planes (n_z ~ 0)
682 |                          (-epsilon < p[0] < epsilon and -epsilon < p[1] < epsilon and 1 - epsilon < p[2] < 1 + epsilon)]  # near-horizontal planes (n ~ (0, 0, 1))
683 |
684 | if len(vertex_group_ransac.planes) == 0:
685 | return
686 | cell_complex = CellComplex(vertex_group_ransac.planes, vertex_group_ransac.bounds,
687 | vertex_group_ransac.points_grouped,
688 | build_graph=True, additional_planes=additional_planes,
689 | initial_bound=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], quiet=True)
690 | cell_complex.refine_planes() # additional planes are not refined
691 | cell_complex.prioritise_planes()
692 | cell_complex.construct()
693 |
694 | cell_complex.save_obj(filepath=Path(cc_output_path).with_suffix('.obj'))
695 | cell_complex.save(filepath=cc_output_path)
696 |
697 |
698 | def multi_merge_vg2cc(vg_ransac_dir, vg_city3d_dir, cc_output_dir):
699 | """
700 | Merge vertex groups from RANSAC and from City3D, then generate complexes, with multiprocessing.
701 | """
702 | args = []
703 |     num_workers = 42  # hard-coded worker count; adjust to the available cores
704 | vg_filenames_ransac = glob.glob(vg_ransac_dir + '/*.vg')
705 |
706 |     for vg_filename_ransac in vg_filenames_ransac:
707 |         stem = Path(vg_filename_ransac).stem
708 |         vg_filename_city3d = (Path(vg_city3d_dir) / stem).with_suffix('.vg')
709 |         if not vg_filename_city3d.exists():
710 |             continue
711 |         cc_filename_output = (Path(cc_output_dir) / stem).with_suffix('.cc')
712 |         if cc_filename_output.exists():
713 |             continue
714 |         args.append([vg_filename_ransac, vg_filename_city3d, cc_filename_output])
715 |
716 | with multiprocessing.Pool(processes=num_workers) as pool:
717 | # call with multiprocessing
718 | for _ in tqdm(pool.imap_unordered(merge_vg2cc, args), desc='Creating complexes from vertex groups',
719 | total=len(args)):
720 | pass
721 |
722 |
723 | def vg2cc(args):
724 | """
725 | Create cell complex from vertex group.
726 | """
727 | vg_path, cc_path = args
728 | # print(vg_path)
729 |
730 | vertex_group = VertexGroup(filepath=vg_path, refit=False, global_group=True)
731 |
732 | cell_complex = CellComplex(vertex_group.planes, vertex_group.bounds, vertex_group.points_grouped,
733 | build_graph=True, additional_planes=None,
734 | initial_bound=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]])
735 |     cell_complex.refine_planes(theta=5 * 3.1416 / 180, epsilon=0.002)  # theta: 5 degrees in radians
736 | cell_complex.construct()
737 |
738 | cell_complex.save_obj(filepath=Path(cc_path).with_suffix('.obj'))
739 | cell_complex.save(filepath=cc_path)
740 |
741 |
742 | def multi_vg2cc(vg_dir, cc_dir):
743 | """
744 | Create cell complexes from vertex groups with multiprocessing.
745 | """
746 | args = []
747 |     num_workers = 38  # hard-coded worker count; adjust to the available cores
748 | vg_filenames = glob.glob(vg_dir + '/*.vg')
749 |
750 | for vg_filename in vg_filenames:
751 | stem = Path(vg_filename).stem
752 | cc_filename = (Path(cc_dir) / stem).with_suffix('.cc')
753 | args.append([vg_filename, cc_filename])
754 |
755 | with multiprocessing.Pool(processes=num_workers) as pool:
756 | # call with multiprocessing
757 | for _ in tqdm(pool.imap_unordered(vg2cc, args), desc='Creating complexes from vertex groups', total=len(args)):
758 | pass
759 |
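# Editor's usage note (directory names hypothetical): every 'name.vg' under
# data/vg yields 'name.cc' plus an inspectable 'name.obj' under data/cc,
# distributed over the worker pool:
# multi_vg2cc('data/vg', 'data/cc')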
760 |
761 | def coordinate2index(x, reso, coord_type='2d'):
762 | """ Generate grid index of points
763 |
764 | Args:
765 | x (tensor): points (normalized to [0, 1])
766 | reso (int): defined resolution
767 | coord_type (str): coordinate type
768 | """
769 | x = (x * reso).long()
770 | if coord_type == '2d': # plane
771 |         index = x[:, :, 0] + reso * x[:, :, 1]  # [B, N]
772 | index = index[:, None, :] # [B, 1, N]
773 | return index
774 |
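# Editor's worked example: with reso=4, a normalized point (x, y) = (0.3, 0.7)
# lands in cell (1, 2), which flattens to index 1 + 4 * 2 = 9. Note that only
# the '2d' branch is implemented above; other coord_type values return None.
def _example_coordinate2index():
    x = torch.tensor([[[0.3, 0.7]]])    # [B=1, N=1, 2]
    return coordinate2index(x, reso=4)  # tensor([[[9]]]), shape [1, 1, 1]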
775 |
776 | def normalize_coordinate(p, padding=0, plane='xz', scale=1.0):
777 | """ Normalize coordinate to [0, 1] for unit cube experiments
778 |
779 | Args:
780 | p (tensor): point
781 | padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
782 | plane (str): plane feature type, ['xz', 'xy', 'yz']
783 | scale: normalize scale
784 | """
785 | if plane == 'xz':
786 | xy = p[:, :, [0, 2]]
787 | elif plane == 'xy':
788 | xy = p[:, :, [0, 1]]
789 | else:
790 | xy = p[:, :, [1, 2]]
791 |
792 |     xy_new = xy / (1 + padding + 1e-5)  # range (-0.5, 0.5)
793 |     xy_new = xy_new + 0.5  # range (0, 1)
794 |
795 |     # if there are outliers out of the range, clamp them back in
796 |     if xy_new.max() >= 1:
797 |         xy_new[xy_new >= 1] = 1 - 1e-5
798 |     if xy_new.min() < 0:
799 |         xy_new[xy_new < 0] = 0.0
800 | return xy_new
801 |
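# Editor's worked example: with zero padding the selected plane coordinates
# are shifted from [-0.5, 0.5] to [0, 1], up to the small epsilon guard.
def _example_normalize_coordinate():
    p = torch.tensor([[[-0.5, 0.0, 0.25]]])   # [B=1, N=1, 3]
    xy = normalize_coordinate(p, plane='xz')  # keeps dims [0, 2]
    return xy                                 # ~ tensor([[[0.0, 0.75]]])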
802 |
803 | def normalize_3d_coordinate(p, padding=0):
804 | """ Normalize coordinate to [0, 1] for unit cube experiments.
805 | Corresponds to our 3D model
806 |
807 | Args:
808 | p (tensor): point
809 |         padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
810 | """
811 |     p_nor = p / (1 + padding + 1e-3)  # range (-0.5, 0.5)
812 |     p_nor = p_nor + 0.5  # range (0, 1)
813 |     # if there are outliers out of the range, clamp them back in
814 |     if p_nor.max() >= 1:
815 |         p_nor[p_nor >= 1] = 1 - 1e-3
816 |     if p_nor.min() < 0:
817 |         p_nor[p_nor < 0] = 0.0
818 | return p_nor
819 |
820 |
821 | class map2local(object):
822 | """
823 |     Map coordinates into the local system of each voxel (of size s), then positionally encode them
824 |
825 | Args:
826 | s (float): the defined voxel size
827 | pos_encoding (str): method for the positional encoding, linear|sin_cos
828 | """
829 |
830 | def __init__(self, s, pos_encoding='linear'):
831 | super().__init__()
832 | self.s = s
833 | self.pe = positional_encoding(basis_function=pos_encoding)
834 |
835 | def __call__(self, p):
836 | p = torch.remainder(p, self.s) / self.s # always positive
837 | # p = torch.fmod(p, self.s) / self.s # same sign as input p!
838 | p = self.pe(p)
839 | return p
840 |
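# Editor's sketch: coordinates are wrapped into a voxel of size s, so only the
# offset relative to the voxel grid survives; with the default 'linear' basis
# the subsequent encoding is a pass-through.
def _example_map2local():
    m2l = map2local(s=0.1)
    p = torch.tensor([[[0.23, -0.07, 0.58]]])
    return m2l(p)  # remainder / s -> ~ tensor([[[0.3, 0.3, 0.8]]])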
841 |
842 | class positional_encoding(object):
843 | """ Positional Encoding (presented in NeRF)
844 |
845 | Args:
846 | basis_function (str): basis function
847 | """
848 |
849 | def __init__(self, basis_function='sin_cos'):
850 | super().__init__()
851 | self.func = basis_function
852 |
853 |         L = 10  # number of frequency bands
854 | freq_bands = 2. ** (np.linspace(0, L - 1, L))
855 | self.freq_bands = freq_bands * math.pi
856 |
857 | def __call__(self, p):
858 | if self.func == 'sin_cos':
859 | out = []
860 |             p = 2.0 * p - 1.0  # change to the range [-1, 1]
861 | for freq in self.freq_bands:
862 | out.append(torch.sin(freq * p))
863 | out.append(torch.cos(freq * p))
864 | p = torch.cat(out, dim=2)
865 |         return p  # basis functions other than 'sin_cos' pass the input through unchanged
866 |
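# Editor's sketch of the output size: each of the L = 10 frequency bands
# contributes one sin and one cos per input channel, so a 3-channel point
# expands to 2 * 10 * 3 = 60 channels.
def _example_positional_encoding():
    pe = positional_encoding(basis_function='sin_cos')
    p = torch.rand(2, 5, 3)  # [B=2, N=5, 3] in [0, 1)
    return pe(p)             # shape [2, 5, 60]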
867 |
868 | def make_3d_grid(bb_min, bb_max, shape):
869 | """
870 | Makes a 3D grid.
871 |
872 | Args:
873 | bb_min (tuple): bounding box minimum
874 | bb_max (tuple): bounding box maximum
875 | shape (tuple): output shape
876 | """
877 | size = shape[0] * shape[1] * shape[2]
878 |
879 | pxs = torch.linspace(bb_min[0], bb_max[0], int(shape[0]))
880 | pys = torch.linspace(bb_min[1], bb_max[1], int(shape[1]))
881 | pzs = torch.linspace(bb_min[2], bb_max[2], int(shape[2]))
882 |
883 | pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
884 | pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
885 | pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
886 | p = torch.stack([pxs, pys, pzs], dim=1)
887 |
888 | return p
889 |
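# Editor's worked example: a 2 x 2 x 2 grid over the unit cube yields its
# eight corners, ordered with z varying fastest and x slowest.
def _example_make_3d_grid():
    p = make_3d_grid((-0.5,) * 3, (0.5,) * 3, (2, 2, 2))  # shape [8, 3]
    return p  # p[0] = [-0.5, -0.5, -0.5], p[1] = [-0.5, -0.5, 0.5], ...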
890 |
891 | if __name__ == '__main__':
892 | pass
893 |
--------------------------------------------------------------------------------