├── .gitattributes ├── .gitignore ├── CITATION.bib ├── LICENSE ├── README.md ├── checkpoints └── .gitkeep ├── conf ├── config.yaml ├── dataset │ ├── mini.yaml │ ├── munich.yaml │ └── nuremberg.yaml └── hydra │ └── job_logging │ └── custom.yaml ├── data └── .gitkeep ├── dataset.py ├── docs └── architecture.png ├── download.py ├── environment.yml ├── network ├── __init__.py ├── convonet │ ├── ResnetBlockFC.py │ ├── __init__.py │ ├── pointnet.py │ ├── unet.py │ └── unet3d.py ├── decoder.py ├── encoder.py ├── gnn.py ├── loss.py └── polygnn.py ├── reconstruct.py ├── remap.py ├── stats.py ├── test.py ├── train.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | # Pycharm 128 | .idea 129 | 130 | # macOS 131 | .DS_Store 132 | 133 | # results 134 | outputs/ 135 | 136 | # data 137 | data/ 138 | 139 | # checkpoints 140 | checkpoints/ 141 | 142 | # wandb 143 | wandb/ 144 | 145 | # plots 146 | *.pdf 147 | 148 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @article{chen2024polygnn, 2 | title = {PolyGNN: Polyhedron-based graph neural network for 3D building reconstruction from point clouds}, 3 | journal = {ISPRS Journal of Photogrammetry and Remote Sensing}, 4 | volume = {218}, 5 | pages = {693-706}, 6 | year = {2024}, 7 | issn = {0924-2716}, 8 | doi = {https://doi.org/10.1016/j.isprsjprs.2024.09.031}, 9 | url = {https://www.sciencedirect.com/science/article/pii/S0924271624003691}, 10 | author = {Zhaiyu Chen and Yilei Shi and Liangliang Nan and Zhitong Xiong and Xiao Xiang Zhu}, 11 | } 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zhaiyu Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PolyGNN 2 | 3 | ----------- 4 | [![Paper: HTML](https://img.shields.io/badge/Paper-HTML-yellow)](https://www.sciencedirect.com/science/article/pii/S0924271624003691) 5 | [![Paper: PDF](https://img.shields.io/badge/Paper-PDF-green)](https://www.sciencedirect.com/science/article/pii/S0924271624003691/pdfft?md5=3d0d8b3b72cdd3f4c809d714b1292137&pid=1-s2.0-S0924271624003691-main.pdf) 6 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://raw.githubusercontent.com/chenzhaiyu/polygnn/main/LICENSE) 7 | 8 | PolyGNN is an implementation of the paper [*PolyGNN: Polyhedron-based Graph Neural Network for 3D Building Reconstruction from Point Clouds*](https://www.sciencedirect.com/science/article/pii/S0924271624003691). 9 | PolyGNN learns a piecewise planar occupancy function, supported by polyhedral decomposition, for efficient and scalable 3D building reconstruction. 10 | 11 |
<img src="docs/architecture.png" alt="PolyGNN architecture" /> 12 | 13 |
14 | 15 | ## 🛠️ Setup 16 | 17 | ### Repository 18 | 19 | Clone the repository: 20 | 21 | ```bash 22 | git clone https://github.com/chenzhaiyu/polygnn && cd polygnn 23 | ``` 24 | 25 | ### All-in-one installation 26 | 27 | Create a conda environment with all dependencies: 28 | 29 | ```bash 30 | conda env create -f environment.yml && conda activate polygnn 31 | ``` 32 | 33 | ### Manual installation 34 | 35 | Still easy! Create a conda environment and install [mamba](https://github.com/mamba-org/mamba) for faster parsing: 36 | ```bash 37 | conda create --name polygnn python=3.10 && conda activate polygnn 38 | conda install mamba -c conda-forge 39 | ``` 40 | 41 | Install the required dependencies: 42 | ``` 43 | mamba install pytorch torchvision sage=10.0 pytorch-cuda=11.7 pyg=2.3 pytorch-scatter pytorch-sparse pytorch-cluster torchmetrics rtree -c pyg -c pytorch -c nvidia -c conda-forge 44 | pip install abspy==0.2.6 hydra-core hydra-colorlog omegaconf trimesh tqdm wandb plyfile 45 | ``` 46 | 47 | ## 🚀 Usage 48 | 49 | ### Quick start 50 | 51 | Download the mini dataset and pretrained weights: 52 | 53 | ```python 54 | python download.py dataset=mini 55 | ``` 56 | In case you encounter issues (e.g., Google Drive limits), manually download the data and weights [here](https://drive.google.com/drive/folders/1fAwvhGtOgS8f4IldE1J4v5s0438WM24b?usp=sharing), then extract them into `./checkpoints/mini` and `./data/mini`, respectively. 57 | The mini dataset contains 200 random instances (~0.07% of the full dataset). 58 | 59 | Train PolyGNN on the mini dataset (provided for your reference and is not intended for full-scale training): 60 | ```python 61 | python train.py dataset=mini 62 | ``` 63 | The data will be automatically preprocessed the first time you initiate training. 64 | 65 | Evaluate PolyGNN with option to save predictions: 66 | ```python 67 | python test.py dataset=mini evaluate.save=true 68 | ``` 69 | 70 | Generate meshes from predictions: 71 | ```python 72 | python reconstruct.py dataset=mini reconstruct.type=mesh 73 | ``` 74 | 75 | Remap meshes to their original CRS: 76 | ```python 77 | python remap.py dataset=mini 78 | ``` 79 | 80 | Generate reconstruction statistics: 81 | ```python 82 | python stats.py dataset=mini 83 | ``` 84 | 85 | ### Available configurations 86 | 87 | ```python 88 | # check available configurations for training 89 | python train.py --cfg job 90 | 91 | # check available configurations for evaluation 92 | python test.py --cfg job 93 | ``` 94 | Alternatively, review the configuration file: `conf/config.yaml`. 95 | 96 | ### Full dataset 97 | 98 | The Munich dataset is available for download on [Zenodo](https://zenodo.org/records/14254264). Note that it requires 332 GB of storage when decompressed. Meshes for CRS remapping can be downloaded [here](https://drive.google.com/file/d/1hn11XMqyoPUnq-9WGfAwQq47uuUvcbi7/view?usp=drive_link). 99 | 100 | ### Custom data 101 | 102 | PolyGNN requires polyhedron-based graphs as input. To prepare this from your own point clouds: 103 | 1. Extract planar primitives using tools such as [Easy3D](https://github.com/LiangliangNan/Easy3D) or [GoCoPP](https://github.com/Ylannl/GoCoPP), preferably in [VertexGroup](https://abspy.readthedocs.io/en/latest/vertexgroup.html) format. 104 | 2. Build [CellComplex](https://abspy.readthedocs.io/en/latest/api.html#abspy.CellComplex) from the primitives using [abspy](https://github.com/chenzhaiyu/abspy). 
Example code: 105 | ```python 106 | from abspy import VertexGroup, CellComplex 107 | vertex_group = VertexGroup(vertex_group_path, quiet=True) 108 | cell_complex = CellComplex(vertex_group.planes, vertex_group.aabbs, 109 | vertex_group.points_grouped, build_graph=True, quiet=True) 110 | cell_complex.prioritise_planes(prioritise_verticals=True) 111 | cell_complex.construct() 112 | cell_complex.save(complex_path) 113 | ``` 114 | Alternatively, you can modify [`CityDataset`](https://github.com/chenzhaiyu/polygnn/blob/67addd77a6be1d100448e3bd7523babfa063d0dd/dataset.py#L157) or [`TestOnlyDataset`](https://github.com/chenzhaiyu/polygnn/blob/67addd77a6be1d100448e3bd7523babfa063d0dd/dataset.py#L276) to accept inputs directly from [VertexGroup](https://abspy.readthedocs.io/en/latest/vertexgroup.html) or [VertexGroupReference](https://abspy.readthedocs.io/en/latest/api.html#abspy.VertexGroupReference). 115 | 3. Structure your dataset similarly to the provided mini dataset: 116 | ```bash 117 | YOUR_DATASET_NAME 118 | └── raw 119 | ├── 03_meshes 120 | │ ├── DEBY_LOD2_104572462.obj 121 | │ ├── DEBY_LOD2_104575306.obj 122 | │ └── DEBY_LOD2_104575493.obj 123 | ├── 04_pts 124 | │ ├── DEBY_LOD2_104572462.npy 125 | │ ├── DEBY_LOD2_104575306.npy 126 | │ └── DEBY_LOD2_104575493.npy 127 | ├── 05_complexes 128 | │ ├── DEBY_LOD2_104572462.cc 129 | │ ├── DEBY_LOD2_104575306.cc 130 | │ └── DEBY_LOD2_104575493.cc 131 | ├── testset.txt 132 | └── trainset.txt 133 | ``` 134 | 4. To train or evaluate PolyGNN using your dataset, run the following commands: 135 | ```python 136 | # start training 137 | python train.py dataset=YOUR_DATASET_NAME 138 | 139 | # start evaluation 140 | python test.py dataset=YOUR_DATASET_NAME 141 | ``` 142 | For evaluation only, you can instantiate your dataset as a [`TestOnlyDataset`](https://github.com/chenzhaiyu/polygnn/blob/67addd77a6be1d100448e3bd7523babfa063d0dd/dataset.py#L276), as in [this line](https://github.com/chenzhaiyu/polygnn/blob/94ffc9e45f0721653038bd91f33f1d4eafeab7cb/test.py#L178). 143 | 144 | ## 👷 TODOs 145 | 146 | - [x] Demo with mini data and pretrained weights 147 | - [x] Short tutorial for getting started 148 | - [x] Host the full dataset 149 | 150 | ## 🎓 Citation 151 | 152 | If you use PolyGNN in a scientific work, please consider citing the paper: 153 | 154 | ```bibtex 155 | @article{chen2024polygnn, 156 | title = {PolyGNN: Polyhedron-based graph neural network for 3D building reconstruction from point clouds}, 157 | journal = {ISPRS Journal of Photogrammetry and Remote Sensing}, 158 | volume = {218}, 159 | pages = {693-706}, 160 | year = {2024}, 161 | issn = {0924-2716}, 162 | doi = {https://doi.org/10.1016/j.isprsjprs.2024.09.031}, 163 | url = {https://www.sciencedirect.com/science/article/pii/S0924271624003691}, 164 | author = {Zhaiyu Chen and Yilei Shi and Liangliang Nan and Zhitong Xiong and Xiao Xiang Zhu}, 165 | } 166 | ``` 167 | 168 | The synthetic point clouds are simulated with [pyhelios](https://github.com/chenzhaiyu/pyhelios). 169 | You might also want to check out [abspy](https://github.com/chenzhaiyu/abspy) for 3D adaptive binary space partitioning and [Points2Poly](https://github.com/chenzhaiyu/points2poly) for reconstruction with deep implicit fields. 
170 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhaiyu/polygnn/6acf63fce040368424cf2a01c978dd2fcad1d831/checkpoints/.gitkeep -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | # Default configuration for the Project. 2 | # Values shall be overriden by respective dataset configurations. 3 | 4 | # default settings 5 | defaults: 6 | - _self_ 7 | - dataset: munich 8 | - override hydra/job_logging: custom 9 | - override hydra/hydra_logging: colorlog 10 | 11 | # general path settings 12 | run_suffix: '' 13 | data_dir: './data/${dataset}${dataset_suffix}' 14 | complex_dir: '${data_dir}/raw/05_complexes' 15 | reference_dir: '${data_dir}/raw/03_meshes_global_test' 16 | output_dir: './outputs/${dataset}${run_suffix}' 17 | remap_dir: '${output_dir}/global' 18 | checkpoint_dir: './checkpoints/${dataset}${run_suffix}' 19 | checkpoint_path: '${checkpoint_dir}/model_best.pth' 20 | csv_path: '${output_dir}/evaluation.csv' 21 | 22 | # network settings 23 | encoder: ConvONet # {PointNet, PointNet2, PointCNN, DGCNN, PointTransformerConv, RandLANet, ConvONet} 24 | decoder: ConvONet # {MLP, ConvONet} 25 | gnn: TAGCN # {null, GCN, TransformerGCN, TAGCN} 26 | latent_dim_light: 256 # latent dimension (256) for plain encoder-decoder 27 | latent_dim_conv: 4096 # latent dimension (4096) for convolutional encoder-decoder 28 | use_spatial_transformer: false 29 | convonet_kwargs: 30 | unet: True 31 | unet_kwargs: 32 | depth: 4 33 | merge_mode: 'concat' 34 | start_filts: 32 35 | plane_resolution: 128 36 | plane_type: ['xz', 'xy', 'yz'] 37 | 38 | # training settings 39 | warm: false # warm start from checkpoint 40 | warm_optimizer: true # load optimizer from checkpoint if available 41 | warm_scheduler: true # load scheduler from checkpoint if available 42 | freeze_stages: [] # [encoder, decoder, gnn] 43 | gpu_ids: [5, 6, 7, 8] 44 | gpu_freeze: false 45 | weight_decay: 1e-6 46 | num_epochs: 50 47 | save_interval: 1 48 | dropout: false 49 | validate: true 50 | seed: 1117 51 | batch_size: 64 52 | num_workers: 32 53 | loss: bce # {bce, focal} 54 | lr: 1e-3 55 | scheduler: 56 | base_lr: 1e-4 57 | max_lr: ${lr} 58 | step_size_up: 4400 59 | mode: triangular2 60 | 61 | # ddp settings 62 | master_addr: localhost 63 | master_port: 12345 64 | 65 | # dataset settings (shall be overwritten) 66 | shuffle: true 67 | class_weights: [1., 1.] 
68 | sample: 69 | strategy: random # {null, fps, random, grid} 70 | transform: true # on-the-fly sampling (may introduce randomness) 71 | pre_transform: false # sampling as pre-transform 72 | duplicate: true # effective only for random sampling 73 | length: 4096 # effective only for random sampling 74 | resolutions: [0.05, 0.01, 0.005] # effective only for grid sampling and progressive training 75 | resolution: 0.01 # one of the previous 76 | ratio: 0.3 # effective only for fps sampling 77 | 78 | # evaluation settings 79 | evaluate: 80 | score: true # compute accuracy score in evaluation 81 | save: false # save prediction as numpy file 82 | seal: true # seal non-watertight model with bounding volume 83 | num_samples: 10000 84 | 85 | # reconstruction settings 86 | reconstruct: 87 | type: mesh # {cell, mesh} 88 | scale: true 89 | seal: false # seal with bounding volume 90 | translate: true 91 | offset: [0, 0, 0] # [-653200, -5478800, 0] 92 | 93 | # hydra settings 94 | hydra: 95 | run: 96 | dir: ./outputs 97 | verbose: false 98 | 99 | # wandb settings 100 | wandb: true 101 | wandb_dir: './outputs' 102 | -------------------------------------------------------------------------------- /conf/dataset/mini.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # dataset settings 4 | dataset: 'mini' 5 | num_queries: 16 6 | dataset_suffix: '' 7 | url_dataset: {'default': 'https://drive.google.com/uc?export=download&id=1Yzlw4QGnbhybPbhVrMZ27Fr8qhA51mWZ'} 8 | url_checkpoint: {'default': 'https://drive.usercontent.google.com/download?id=1IivVlSGmTUBa8TDnmDKV24VaoWZBGT1M&export=download&authuser=0&confirm=t'} -------------------------------------------------------------------------------- /conf/dataset/munich.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # dataset settings 4 | dataset: 'munich' 5 | num_queries: 16 6 | dataset_suffix: '' 7 | url_dataset: {'default': null} 8 | url_checkpoint: {'default': 'https://drive.usercontent.google.com/download?id=1IivVlSGmTUBa8TDnmDKV24VaoWZBGT1M&export=download&authuser=0&confirm=t'} -------------------------------------------------------------------------------- /conf/dataset/nuremberg.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # dataset settings 4 | dataset: 'nuremberg' 5 | num_queries: 16 6 | dataset_suffix: '' 7 | reference_dir: '${data_dir}/raw/03_meshes_global' 8 | reconstruct: 9 | offset: [-653200, -5478800, 0] -------------------------------------------------------------------------------- /conf/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.job_logging 2 | # python logging configuration for tasks 3 | version: 1 4 | formatters: 5 | simple: 6 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 7 | colorlog: 8 | '()': 'colorlog.ColoredFormatter' 9 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s' 10 | log_colors: 11 | DEBUG: purple 12 | INFO: green 13 | WARNING: yellow 14 | ERROR: red 15 | CRITICAL: red 16 | handlers: 17 | console: 18 | class: logging.StreamHandler 19 | formatter: colorlog 20 | stream: ext://sys.stdout 21 | file: 22 | class: logging.FileHandler 23 | formatter: simple 24 | # relative to the job log directory 25 | filename: ${hydra.run.dir}/${hydra.job.name}.log 26 | 
root: 27 | level: INFO 28 | handlers: [console, file] 29 | 30 | disable_existing_loggers: false -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhaiyu/polygnn/6acf63fce040368424cf2a01c978dd2fcad1d831/data/.gitkeep -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definitions of polyhedral graph data structure and datasets. 3 | """ 4 | 5 | import os 6 | import logging 7 | import glob 8 | import collections 9 | import multiprocessing 10 | from pathlib import Path 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | import torch 15 | from torch_geometric.data import Data, Dataset 16 | from abspy import VertexGroup, VertexGroupReference, CellComplex 17 | 18 | from utils import edge_index_from_dict, index_to_mask 19 | 20 | logger = logging.getLogger('dataset') 21 | 22 | 23 | class PolyGraph: 24 | """ 25 | Cell-based graph data structure. 26 | """ 27 | 28 | def __init__(self, use_reference=False, num_queries=None): 29 | self.vertex_group = None 30 | self.cell_complex = None 31 | self.vertex_group_reference = None 32 | self.use_reference = use_reference 33 | self.num_queries = num_queries 34 | 35 | def cell_adjacency(self): 36 | """ 37 | Create adjacency among cells. 38 | """ 39 | # mapping gaped adjacency indices to contiguous ones 40 | adj = self.cell_complex.graph.adj 41 | uid = list(self.cell_complex.graph.nodes) 42 | mapping = {c: i for i, c in enumerate(uid)} 43 | adj_ = collections.defaultdict(set) 44 | for key in adj: 45 | adj_[mapping[key]] = {mapping[value] for value in adj[key]} 46 | 47 | # graph edge index in COO format 48 | return edge_index_from_dict(adj_) 49 | 50 | def cell_labels(self, mesh_path): 51 | """ 52 | Labels of cells, one-hot encoding. 53 | """ 54 | labels = np.zeros(self.cell_complex.num_cells).astype(np.int64) 55 | 56 | # cells inside reference mesh 57 | cells_in_mesh = self.cell_complex.cells_in_mesh(mesh_path, engine='distance') 58 | 59 | for cell in cells_in_mesh: 60 | labels[cell] = 1 61 | 62 | return torch.tensor(labels) 63 | 64 | def data_loader(self, cloud_path, mesh_path=None, complex_path=None, vertex_group_path=None): 65 | """ 66 | Load bvg file and obj file in network readable format. 
67 | """ 68 | if complex_path is not None and os.path.exists(complex_path): 69 | # load existing cell complex 70 | import pickle 71 | with open(complex_path, 'rb') as handle: 72 | self.cell_complex = pickle.load(handle) 73 | else: 74 | # construct cell complex 75 | if not self.use_reference: 76 | # load point cloud as vertex group 77 | if vertex_group_path: 78 | self.vertex_group = VertexGroup(vertex_group_path, quiet=True) 79 | # initialise cell complex from planar primitives 80 | self.cell_complex = CellComplex(self.vertex_group.planes, self.vertex_group.aabbs, 81 | self.vertex_group.points_grouped, build_graph=True, quiet=True) 82 | else: 83 | # cannot process vertex group from points alone 84 | raise NotImplementedError 85 | else: 86 | # load mesh as vertex group reference 87 | self.vertex_group_reference = VertexGroupReference(mesh_path, quiet=True) 88 | # initialise cell complex from planar primitives 89 | self.cell_complex = CellComplex(np.array(self.vertex_group_reference.planes), 90 | np.array(self.vertex_group_reference.aabbs), 91 | build_graph=True, quiet=True) 92 | 93 | # prioritise certain planes (e.g., vertical ones) 94 | self.cell_complex.prioritise_planes(prioritise_verticals=True) 95 | 96 | try: 97 | # construct cell complex 98 | self.cell_complex.construct() 99 | except (AssertionError, IndexError) as e: 100 | logger.error(f'Error [{e}] occurred with {cloud_path}.') 101 | return 102 | 103 | # save cell complex to CC files 104 | if complex_path is not None: 105 | Path(complex_path).parent.mkdir(exist_ok=True) 106 | self.cell_complex.save(complex_path) 107 | 108 | # points 109 | if cloud_path is not None: 110 | # npy and vg may contain different point sets 111 | points = np.load(cloud_path) 112 | else: 113 | points = self.vertex_group.points 114 | 115 | # queries 116 | queries = np.array(self.cell_complex.cell_representatives(location='skeleton', num=self.num_queries)) 117 | 118 | # cell adjacency 119 | adjacency = self.cell_adjacency() 120 | 121 | # cell ground truth labels 122 | if mesh_path: 123 | labels = self.cell_labels(mesh_path) 124 | else: 125 | labels = None 126 | 127 | # construct data for pytorch geometric 128 | data = Data(x=None, edge_index=adjacency, y=labels) 129 | 130 | # store sizes 131 | len_cells = queries.shape[0] 132 | len_points = len(points) 133 | data.num_nodes = len_cells 134 | data.num_points = len_points 135 | 136 | # store points and queries 137 | data.points = torch.as_tensor(points, dtype=torch.float) 138 | data.queries = torch.as_tensor(queries, dtype=torch.float) 139 | 140 | # batch indices of points 141 | data.batch_points = torch.zeros(len_points, dtype=torch.long) 142 | 143 | # specify masks 144 | data.train_mask = index_to_mask(range(len_cells), size=len_cells) 145 | data.val_mask = index_to_mask(range(len_cells), size=len_cells) 146 | data.test_mask = index_to_mask(range(len_cells), size=len_cells) 147 | 148 | # name for reference 149 | data.name = Path(cloud_path).stem 150 | 151 | # validate data 152 | data.validate(raise_on_error=True) 153 | 154 | return data 155 | 156 | 157 | class CityDataset(Dataset): 158 | """ 159 | Base building dataset. Applies to Munich and Nuremberg. 
160 | """ 161 | 162 | def __init__(self, root, name=None, split=None, num_workers=1, num_queries=16, **kwargs): 163 | self.name = name 164 | self.split = split 165 | self.num_workers = num_workers 166 | self.cloud_suffix = '.npy' 167 | self.mesh_suffix = '.obj' 168 | self.complex_suffix = '.cc' 169 | self.num_queries = num_queries 170 | super().__init__(root, **kwargs) # this line calls download() and process() 171 | 172 | @property 173 | def raw_dir(self) -> str: 174 | return os.path.join(self.root, 'raw') 175 | 176 | @property 177 | def raw_file_names(self): 178 | with open(os.path.join(self.raw_dir, f'{self.split}set.txt'), 'r') as f: 179 | return f.read().splitlines() 180 | 181 | @property 182 | def processed_file_names(self): 183 | return [f'data_{self.split}_{i}.pt' for i in range(len(self.raw_file_names))] 184 | 185 | def download(self): 186 | pass 187 | 188 | def thread(self, kwargs): 189 | """ 190 | Process one file. 191 | """ 192 | path_save = os.path.join(self.processed_dir, f'data_{kwargs["split"]}_{kwargs["index"]}.pt') 193 | if os.path.exists(path_save): 194 | return 195 | logger.debug(f'processing {Path(kwargs["cloud"]).stem}') 196 | try: 197 | data = PolyGraph(use_reference=True, num_queries=self.num_queries).data_loader(kwargs['cloud'], 198 | kwargs['mesh'], 199 | kwargs['complex']) 200 | except (ValueError, IndexError, EOFError) as e: 201 | logger.error(f'error with file {kwargs["mesh"]}: {e}') 202 | return 203 | if self.pre_transform is not None: 204 | data = self.pre_transform(data) 205 | if data is not None: 206 | torch.save(data, path_save) 207 | 208 | def process(self): 209 | """ 210 | Start multiprocessing. 211 | """ 212 | with open(os.path.join(self.raw_dir, 'trainset.txt'), 'r') as f_train: 213 | filenames_train = f_train.read().splitlines() 214 | with open(os.path.join(self.raw_dir, 'testset.txt'), 'r') as f_test: 215 | filenames_test = f_test.read().splitlines() 216 | 217 | args = [] 218 | for i, filename_train in enumerate(filenames_train): 219 | cloud_train = os.path.join(self.raw_dir, '04_pts', filename_train + self.cloud_suffix) 220 | mesh_train = os.path.join(self.raw_dir, '03_meshes', filename_train + self.mesh_suffix) 221 | complex_train = os.path.join(self.raw_dir, '05_complexes', filename_train + self.complex_suffix) 222 | args.append( 223 | {'index': i, 'split': 'train', 'cloud': cloud_train, 'mesh': mesh_train, 'complex': complex_train}) 224 | 225 | for j, filename_test in enumerate(filenames_test): 226 | cloud_test = os.path.join(self.raw_dir, '04_pts', filename_test + self.cloud_suffix) 227 | mesh_test = os.path.join(self.raw_dir, '03_meshes', filename_test + self.mesh_suffix) 228 | complex_test = os.path.join(self.raw_dir, '05_complexes', filename_test + self.complex_suffix) 229 | args.append( 230 | {'index': j, 'split': 'test', 'cloud': cloud_test, 'mesh': mesh_test, 'complex': complex_test}) 231 | 232 | with multiprocessing.Pool( 233 | processes=self.num_workers if self.num_workers else multiprocessing.cpu_count()) as pool: 234 | # call with multiprocessing 235 | for _ in tqdm(pool.imap_unordered(self.thread, args), desc='Preparing dataset', total=len(args)): 236 | pass 237 | 238 | def len(self): 239 | return len(self.processed_file_names) 240 | 241 | def get(self, idx): 242 | data = torch.load(os.path.join(self.processed_dir, f'data_{self.split}_{idx}.pt')) 243 | # to disable UserWarning: Unable to accurately infer 'num_nodes' from the attribute set 244 | # '{'points', 'train_mask', 'val_mask', 'queries', 'test_mask', 'edge_index', 'y'}' 245 | 
data.num_nodes = len(data.y) 246 | return data 247 | 248 | 249 | class HelsinkiDataset(CityDataset): 250 | """ 251 | Helsinki dataset. 252 | """ 253 | 254 | def __init__(self, **kwargs): 255 | super().__init__(**kwargs) 256 | 257 | @property 258 | def processed_file_names(self): 259 | """ 260 | Modified processed filenames due to discontinuity. 261 | """ 262 | return [os.path.basename(filename) for filename in 263 | glob.glob(os.path.join(self.processed_dir, f'data_{self.split}_*.pt'))] 264 | 265 | def get(self, idx): 266 | """ 267 | Modified data retrieval due to discontinuity. 268 | """ 269 | data = torch.load(self.processed_paths[idx]) 270 | # to disable UserWarning: Unable to accurately infer 'num_nodes' from the attribute set 271 | # '{'points', 'train_mask', 'val_mask', 'queries', 'test_mask', 'edge_index', 'y'}' 272 | data.num_nodes = len(data.y) 273 | return data 274 | 275 | 276 | class TestOnlyDataset(CityDataset): 277 | """ 278 | Test-only dataset. 279 | """ 280 | 281 | def __init__(self, **kwargs): 282 | super().__init__(**kwargs) 283 | 284 | def process(self): 285 | """ 286 | Start multiprocessing. 287 | """ 288 | with open(os.path.join(self.raw_dir, 'testset.txt'), 'r') as f_test: 289 | filenames_test = f_test.read().splitlines() 290 | 291 | args = [] 292 | 293 | for j, filename_test in enumerate(filenames_test): 294 | cloud_test = os.path.join(self.raw_dir, '04_pts', filename_test + self.cloud_suffix) 295 | mesh_test = os.path.join(self.raw_dir, '03_meshes', filename_test + self.mesh_suffix) 296 | complex_test = os.path.join(self.raw_dir, '05_complexes', filename_test + self.complex_suffix) 297 | args.append( 298 | {'index': j, 'split': 'test', 'cloud': cloud_test, 'mesh': mesh_test, 'complex': complex_test}) 299 | 300 | with multiprocessing.Pool( 301 | processes=self.num_workers if self.num_workers else multiprocessing.cpu_count()) as pool: 302 | # call with multiprocessing 303 | for _ in tqdm(pool.imap(self.thread, args), desc='Preparing dataset', total=len(args)): 304 | pass 305 | -------------------------------------------------------------------------------- /docs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhaiyu/polygnn/6acf63fce040368424cf2a01c978dd2fcad1d831/docs/architecture.png -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download datasets and/or models from public urls. 3 | """ 4 | 5 | import os 6 | import tarfile 7 | import urllib.request 8 | 9 | from tqdm import tqdm 10 | import hydra 11 | from omegaconf import DictConfig 12 | 13 | 14 | def my_hook(t): 15 | """ 16 | Wraps tqdm instance. 17 | Don't forget to close() or __exit__() 18 | the tqdm instance once you're done with it (easiest using `with` syntax). 19 | https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py 20 | 21 | Example 22 | ------- 23 | with tqdm(...) as t: 24 | ... reporthook = my_hook(t) 25 | ... urllib.urlretrieve(..., reporthook=reporthook) 26 | """ 27 | last_b = [0] 28 | 29 | def update_to(b=1, bsize=1, tsize=None): 30 | """ 31 | b : int, optional 32 | Number of blocks transferred so far [default: 1]. 33 | bsize : int, optional 34 | Size of each block (in tqdm units) [default: 1]. 35 | tsize : int, optional 36 | Total size (in tqdm units). If [default: None] remains unchanged. 
37 | """ 38 | if tsize is not None: 39 | t.total = tsize 40 | t.update((b - last_b[0]) * bsize) 41 | last_b[0] = b 42 | 43 | return update_to 44 | 45 | 46 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 47 | def download(cfg: DictConfig): 48 | """ 49 | Download datasets and/or models from public urls. 50 | 51 | Parameters 52 | ---------- 53 | cfg: DictConfig 54 | Hydra configuration 55 | """ 56 | data_url, checkpoint_url = cfg.url_dataset['default'], cfg.url_checkpoint['default'] 57 | data_dir, checkpoint_dir = cfg.data_dir, cfg.checkpoint_dir 58 | data_file, checkpoint_file = (os.path.join(data_dir, f'{cfg.dataset}.tar.gz'), f'{cfg.checkpoint_path}') 59 | 60 | os.makedirs(data_dir, exist_ok=True) 61 | os.makedirs(checkpoint_dir, exist_ok=True) 62 | 63 | if data_url is not None: 64 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=data_file) as t: 65 | urllib.request.urlretrieve(data_url, filename=data_file, reporthook=my_hook(t), data=None) 66 | with tarfile.open(data_file, 'r:gz') as tar: 67 | tar.extractall(data_dir) 68 | os.remove(data_file) 69 | 70 | if checkpoint_url is not None: 71 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=checkpoint_file) as t: 72 | urllib.request.urlretrieve(checkpoint_url, filename=checkpoint_file, reporthook=my_hook(t), data=None) 73 | 74 | 75 | if __name__ == '__main__': 76 | download() 77 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: polygnn 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - defaults::python=3.10 10 | - pytorch 11 | - torchvision 12 | - sage=10.0 13 | - pytorch-cuda=11.7 14 | - pyg=2.3 15 | - pytorch-scatter 16 | - pytorch-sparse 17 | - pytorch-cluster 18 | - torchmetrics 19 | - rtree 20 | - pip 21 | - pip: 22 | - abspy==0.2.6 23 | - hydra-core 24 | - hydra-colorlog 25 | - omegaconf 26 | - trimesh 27 | - tqdm 28 | - wandb 29 | - plyfile 30 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import Encoder 2 | from .decoder import Decoder 3 | from .gnn import GNN 4 | from .polygnn import PolyGNN 5 | from .loss import bce_loss, focal_loss 6 | -------------------------------------------------------------------------------- /network/convonet/ResnetBlockFC.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # Resnet Blocks 5 | class ResnetBlockFC(nn.Module): 6 | ''' Fully connected ResNet Block class. 
7 | 8 | Args: 9 | size_in (int): input dimension 10 | size_out (int): output dimension 11 | size_h (int): hidden dimension 12 | ''' 13 | 14 | def __init__(self, size_in, size_out=None, size_h=None): 15 | super().__init__() 16 | # Attributes 17 | if size_out is None: 18 | size_out = size_in 19 | 20 | if size_h is None: 21 | size_h = min(size_in, size_out) 22 | 23 | self.size_in = size_in 24 | self.size_h = size_h 25 | self.size_out = size_out 26 | # Submodules 27 | self.fc_0 = nn.Linear(size_in, size_h) 28 | self.fc_1 = nn.Linear(size_h, size_out) 29 | self.actvn = nn.ReLU() 30 | 31 | if size_in == size_out: 32 | self.shortcut = None 33 | else: 34 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 35 | # Initialization 36 | nn.init.zeros_(self.fc_1.weight) 37 | 38 | def forward(self, x): 39 | net = self.fc_0(self.actvn(x)) 40 | dx = self.fc_1(self.actvn(net)) 41 | 42 | if self.shortcut is not None: 43 | x_s = self.shortcut(x) 44 | else: 45 | x_s = x 46 | 47 | return x_s + dx 48 | -------------------------------------------------------------------------------- /network/convonet/__init__.py: -------------------------------------------------------------------------------- 1 | from network.convonet import pointnet 2 | 3 | 4 | encoder_dict = { 5 | 'pointnet_local_pool': pointnet.LocalPoolPointnet, 6 | } 7 | -------------------------------------------------------------------------------- /network/convonet/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import scatter_mean, scatter_max 4 | 5 | from utils import coordinate2index, normalize_coordinate, normalize_3d_coordinate, map2local 6 | from network.convonet.ResnetBlockFC import ResnetBlockFC 7 | from network.convonet.unet import UNet 8 | from network.convonet.unet3d import UNet3D 9 | 10 | 11 | class LocalPoolPointnet(nn.Module): 12 | """ PointNet-based encoder network with ResNet blocks for each point. 13 | Number of input points are fixed. 
14 | 15 | Args: 16 | c_dim (int): dimension of latent code c 17 | dim (int): input points dimension 18 | hidden_dim (int): hidden dimension of the network 19 | scatter_type (str): feature aggregation when doing local pooling 20 | unet (bool): weather to use U-Net 21 | unet_kwargs (dict): U-Net parameters 22 | unet3d (bool): weather to use 3D U-Net 23 | unet3d_kwargs (str): 3D U-Net parameters 24 | plane_resolution (int): defined resolution for plane feature 25 | grid_resolution (int): defined resolution for grid feature 26 | plane_type (str or list): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume 27 | padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 28 | n_blocks (int): number of blocks ResNetBlockFC layers 29 | """ 30 | 31 | def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', 32 | unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, 33 | plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5): 34 | super().__init__() 35 | self.c_dim = c_dim 36 | 37 | self.fc_pos = nn.Linear(dim, 2 * hidden_dim) 38 | self.blocks = nn.ModuleList([ 39 | ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks) 40 | ]) 41 | self.fc_c = nn.Linear(hidden_dim, c_dim) 42 | 43 | self.actvn = nn.ReLU() 44 | self.hidden_dim = hidden_dim 45 | 46 | if unet: 47 | self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) 48 | else: 49 | self.unet = None 50 | 51 | if unet3d: 52 | self.unet3d = UNet3D(**unet3d_kwargs) 53 | else: 54 | self.unet3d = None 55 | 56 | self.reso_plane = plane_resolution 57 | self.reso_grid = grid_resolution 58 | self.plane_type = plane_type 59 | self.padding = padding 60 | 61 | if scatter_type == 'max': 62 | self.scatter = scatter_max 63 | elif scatter_type == 'mean': 64 | self.scatter = scatter_mean 65 | else: 66 | raise ValueError('incorrect scatter type') 67 | 68 | def generate_plane_features(self, p, c, plane='xz'): 69 | # acquire indices of features in plane 70 | xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) 71 | index = coordinate2index(xy, self.reso_plane) 72 | 73 | # scatter plane features from points 74 | fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2) 75 | c = c.permute(0, 2, 1) # B x 512 x T 76 | fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 77 | fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, 78 | self.reso_plane) # sparce matrix (B x 512 x reso x reso) 79 | 80 | # process the plane features with UNet 81 | if self.unet is not None: 82 | fea_plane = self.unet(fea_plane) 83 | 84 | return fea_plane 85 | 86 | def generate_grid_features(self, p, c): 87 | p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding) 88 | index = coordinate2index(p_nor, self.reso_grid, coord_type='3d') 89 | # scatter grid features from points 90 | fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid ** 3) 91 | c = c.permute(0, 2, 1) 92 | fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 93 | fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, 94 | self.reso_grid) # sparce matrix (B x 512 x reso x reso) 95 | 96 | if self.unet3d is not None: 97 | fea_grid = self.unet3d(fea_grid) 98 | 99 | return fea_grid 100 | 101 | def pool_local(self, xy, index, c): 102 | bs, fea_dim = c.size(0), c.size(2) 103 | keys = xy.keys() 104 | 105 | c_out = 0 106 | for key in keys: 107 | # 
scatter plane features from points 108 | if key == 'grid': 109 | fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid ** 3) 110 | else: 111 | fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2) 112 | if self.scatter == scatter_max: 113 | fea = fea[0] 114 | # gather feature back to points 115 | fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) 116 | c_out += fea 117 | return c_out.permute(0, 2, 1) 118 | 119 | def forward(self, p): 120 | batch_size, T, D = p.size() 121 | 122 | # acquire the index for each point 123 | coord = {} 124 | index = {} 125 | if 'xz' in self.plane_type: 126 | coord['xz'] = normalize_coordinate(p.clone(), plane='xz', padding=self.padding) 127 | index['xz'] = coordinate2index(coord['xz'], self.reso_plane) 128 | if 'xy' in self.plane_type: 129 | coord['xy'] = normalize_coordinate(p.clone(), plane='xy', padding=self.padding) 130 | index['xy'] = coordinate2index(coord['xy'], self.reso_plane) 131 | if 'yz' in self.plane_type: 132 | coord['yz'] = normalize_coordinate(p.clone(), plane='yz', padding=self.padding) 133 | index['yz'] = coordinate2index(coord['yz'], self.reso_plane) 134 | if 'grid' in self.plane_type: 135 | coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding) 136 | index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d') 137 | 138 | net = self.fc_pos(p) 139 | 140 | net = self.blocks[0](net) 141 | for block in self.blocks[1:]: 142 | pooled = self.pool_local(coord, index, net) 143 | net = torch.cat([net, pooled], dim=2) 144 | net = block(net) 145 | 146 | c = self.fc_c(net) 147 | 148 | fea = {} 149 | if 'grid' in self.plane_type: 150 | fea['grid'] = self.generate_grid_features(p, c) 151 | if 'xz' in self.plane_type: 152 | fea['xz'] = self.generate_plane_features(p, c, plane='xz') 153 | if 'xy' in self.plane_type: 154 | fea['xy'] = self.generate_plane_features(p, c, plane='xy') 155 | if 'yz' in self.plane_type: 156 | fea['yz'] = self.generate_plane_features(p, c, plane='yz') 157 | 158 | return fea 159 | 160 | 161 | class PatchLocalPoolPointnet(nn.Module): 162 | """ 163 | PointNet-based encoder network with ResNet blocks. 164 | First transform input points to local system based on the given voxel size. 165 | Support non-fixed number of point cloud, but need to precompute the index. 
166 | 167 | Args: 168 | c_dim (int): dimension of latent code c 169 | dim (int): input points dimension 170 | hidden_dim (int): hidden dimension of the network 171 | scatter_type (str): feature aggregation when doing local pooling 172 | unet (bool): weather to use U-Net 173 | unet_kwargs (str): U-Net parameters 174 | unet3d (bool): weather to use 3D U-Net 175 | unet3d_kwargs (str): 3D U-Net parameters 176 | plane_resolution (int): defined resolution for plane feature 177 | grid_resolution (int): defined resolution for grid feature 178 | plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume 179 | padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 180 | n_blocks (int): number of blocks ResNetBlockFC layers 181 | local_coord (bool): whether to use local coordinate 182 | pos_encoding (str): method for the positional encoding, linear|sin_cos 183 | unit_size (float): defined voxel unit size for local system 184 | """ 185 | 186 | def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', 187 | unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, 188 | plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5, 189 | local_coord=False, pos_encoding='linear', unit_size=0.1): 190 | super().__init__() 191 | self.c_dim = c_dim 192 | 193 | self.blocks = nn.ModuleList([ 194 | ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks) 195 | ]) 196 | self.fc_c = nn.Linear(hidden_dim, c_dim) 197 | 198 | self.actvn = nn.ReLU() 199 | self.hidden_dim = hidden_dim 200 | self.reso_plane = plane_resolution 201 | self.reso_grid = grid_resolution 202 | self.plane_type = plane_type 203 | self.padding = padding 204 | 205 | if unet: 206 | self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) 207 | else: 208 | self.unet = None 209 | 210 | if unet3d: 211 | self.unet3d = UNet3D(**unet3d_kwargs) 212 | else: 213 | self.unet3d = None 214 | 215 | if scatter_type == 'max': 216 | self.scatter = scatter_max 217 | elif scatter_type == 'mean': 218 | self.scatter = scatter_mean 219 | else: 220 | raise ValueError('incorrect scatter type') 221 | 222 | if local_coord: 223 | self.map2local = map2local(unit_size, pos_encoding=pos_encoding) 224 | else: 225 | self.map2local = None 226 | 227 | if pos_encoding == 'sin_cos': 228 | self.fc_pos = nn.Linear(60, 2 * hidden_dim) 229 | else: 230 | self.fc_pos = nn.Linear(dim, 2 * hidden_dim) 231 | 232 | def generate_plane_features(self, index, c): 233 | c = c.permute(0, 2, 1) 234 | # scatter plane features from points 235 | if index.max() < self.reso_plane ** 2: 236 | fea_plane = c.new_zeros(c.size(0), self.c_dim, self.reso_plane ** 2) 237 | fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2 238 | else: 239 | fea_plane = scatter_mean(c, index) # B x c_dim x reso^2 240 | if fea_plane.shape[-1] > self.reso_plane ** 2: # deal with outliers 241 | fea_plane = fea_plane[:, :, :-1] 242 | 243 | fea_plane = fea_plane.reshape(c.size(0), self.c_dim, self.reso_plane, self.reso_plane) 244 | 245 | # process the plane features with UNet 246 | if self.unet is not None: 247 | fea_plane = self.unet(fea_plane) 248 | 249 | return fea_plane 250 | 251 | def generate_grid_features(self, index, c): 252 | # scatter grid features from points 253 | c = c.permute(0, 2, 1) 254 | if index.max() < self.reso_grid ** 3: 255 | fea_grid = c.new_zeros(c.size(0), self.c_dim, self.reso_grid ** 3) 256 | fea_grid = scatter_mean(c, index, out=fea_grid) # 
B x c_dim x reso^3 257 | else: 258 | fea_grid = scatter_mean(c, index) # B x c_dim x reso^3 259 | if fea_grid.shape[-1] > self.reso_grid ** 3: # deal with outliers 260 | fea_grid = fea_grid[:, :, :-1] 261 | fea_grid = fea_grid.reshape(c.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) 262 | 263 | if self.unet3d is not None: 264 | fea_grid = self.unet3d(fea_grid) 265 | 266 | return fea_grid 267 | 268 | def pool_local(self, index, c): 269 | bs, fea_dim = c.size(0), c.size(2) 270 | keys = index.keys() 271 | 272 | c_out = 0 273 | for key in keys: 274 | # scatter plane features from points 275 | if key == 'grid': 276 | fea = self.scatter(c.permute(0, 2, 1), index[key]) 277 | else: 278 | fea = self.scatter(c.permute(0, 2, 1), index[key]) 279 | if self.scatter == scatter_max: 280 | fea = fea[0] 281 | 282 | # gather feature back to points 283 | fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) 284 | c_out += fea 285 | return c_out.permute(0, 2, 1) 286 | 287 | def forward(self, inputs): 288 | p = inputs['points'] 289 | index = inputs['index'] 290 | 291 | if self.map2local: 292 | pp = self.map2local(p) 293 | net = self.fc_pos(pp) 294 | else: 295 | net = self.fc_pos(p) 296 | 297 | net = self.blocks[0](net) 298 | for block in self.blocks[1:]: 299 | pooled = self.pool_local(index, net) 300 | net = torch.cat([net, pooled], dim=2) 301 | net = block(net) 302 | 303 | c = self.fc_c(net) 304 | 305 | fea = {} 306 | if 'grid' in self.plane_type: 307 | fea['grid'] = self.generate_grid_features(index['grid'], c) 308 | if 'xz' in self.plane_type: 309 | fea['xz'] = self.generate_plane_features(index['xz'], c) 310 | if 'xy' in self.plane_type: 311 | fea['xy'] = self.generate_plane_features(index['xy'], c) 312 | if 'yz' in self.plane_type: 313 | fea['yz'] = self.generate_plane_features(index['yz'], c) 314 | 315 | return fea 316 | -------------------------------------------------------------------------------- /network/convonet/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.nn import init 6 | 7 | 8 | def conv3x3(in_channels, out_channels, stride=1, 9 | padding=1, bias=True, groups=1): 10 | return nn.Conv2d( 11 | in_channels, 12 | out_channels, 13 | kernel_size=3, 14 | stride=stride, 15 | padding=padding, 16 | bias=bias, 17 | groups=groups) 18 | 19 | 20 | def upconv2x2(in_channels, out_channels, mode='transpose'): 21 | if mode == 'transpose': 22 | return nn.ConvTranspose2d( 23 | in_channels, 24 | out_channels, 25 | kernel_size=2, 26 | stride=2) 27 | else: 28 | # out_channels is always going to be the same 29 | # as in_channels 30 | return nn.Sequential( 31 | nn.Upsample(mode='bilinear', scale_factor=2), 32 | conv1x1(in_channels, out_channels)) 33 | 34 | 35 | def conv1x1(in_channels, out_channels, groups=1): 36 | return nn.Conv2d( 37 | in_channels, 38 | out_channels, 39 | kernel_size=1, 40 | groups=groups, 41 | stride=1) 42 | 43 | 44 | class DownConv(nn.Module): 45 | """ 46 | A helper Module that performs 2 convolutions and 1 MaxPool. 47 | A ReLU activation follows each convolution. 
48 | """ 49 | def __init__(self, in_channels, out_channels, pooling=True): 50 | super(DownConv, self).__init__() 51 | 52 | self.in_channels = in_channels 53 | self.out_channels = out_channels 54 | self.pooling = pooling 55 | 56 | self.conv1 = conv3x3(self.in_channels, self.out_channels) 57 | self.conv2 = conv3x3(self.out_channels, self.out_channels) 58 | 59 | if self.pooling: 60 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 61 | 62 | def forward(self, x): 63 | x = F.relu(self.conv1(x)) 64 | x = F.relu(self.conv2(x)) 65 | before_pool = x 66 | if self.pooling: 67 | x = self.pool(x) 68 | return x, before_pool 69 | 70 | 71 | class UpConv(nn.Module): 72 | """ 73 | A helper Module that performs 2 convolutions and 1 UpConvolution. 74 | A ReLU activation follows each convolution. 75 | """ 76 | def __init__(self, in_channels, out_channels, 77 | merge_mode='concat', up_mode='transpose'): 78 | super(UpConv, self).__init__() 79 | 80 | self.in_channels = in_channels 81 | self.out_channels = out_channels 82 | self.merge_mode = merge_mode 83 | self.up_mode = up_mode 84 | 85 | self.upconv = upconv2x2(self.in_channels, self.out_channels, 86 | mode=self.up_mode) 87 | 88 | if self.merge_mode == 'concat': 89 | self.conv1 = conv3x3( 90 | 2*self.out_channels, self.out_channels) 91 | else: 92 | # num of input channels to conv2 is same 93 | self.conv1 = conv3x3(self.out_channels, self.out_channels) 94 | self.conv2 = conv3x3(self.out_channels, self.out_channels) 95 | 96 | 97 | def forward(self, from_down, from_up): 98 | """ Forward pass 99 | Arguments: 100 | from_down: tensor from the encoder pathway 101 | from_up: upconv'd tensor from the decoder pathway 102 | """ 103 | from_up = self.upconv(from_up) 104 | if self.merge_mode == 'concat': 105 | x = torch.cat((from_up, from_down), 1) 106 | else: 107 | x = from_up + from_down 108 | x = F.relu(self.conv1(x)) 109 | x = F.relu(self.conv2(x)) 110 | return x 111 | 112 | 113 | class UNet(nn.Module): 114 | """ `UNet` class is based on https://arxiv.org/abs/1505.04597 115 | 116 | The U-Net is a convolutional encoder-decoder neural network. 117 | Contextual spatial information (from the decoding, 118 | expansive pathway) about an input tensor is merged with 119 | information representing the localization of details 120 | (from the encoding, compressive pathway). 121 | 122 | Modifications to the original paper: 123 | (1) padding is used in 3x3 convolutions to prevent loss 124 | of border pixels 125 | (2) merging outputs does not require cropping due to (1) 126 | (3) residual connections can be used by specifying 127 | UNet(merge_mode='add') 128 | (4) if non-parametric upsampling is used in the decoder 129 | pathway (specified by upmode='upsample'), then an 130 | additional 1x1 2d convolution occurs after upsampling 131 | to reduce channel dimensionality by a factor of 2. 132 | This channel halving happens with the convolution in 133 | the tranpose convolution (specified by upmode='transpose') 134 | """ 135 | 136 | def __init__(self, num_classes, in_channels=3, depth=5, 137 | start_filts=64, up_mode='transpose', 138 | merge_mode='concat', **kwargs): 139 | """ 140 | Arguments: 141 | in_channels: int, number of channels in the input tensor. 142 | Default is 3 for RGB images. 143 | depth: int, number of MaxPools in the U-Net. 144 | start_filts: int, number of convolutional filters for the 145 | first conv. 146 | up_mode: string, type of upconvolution. Choices: 'transpose' 147 | for transpose convolution or 'upsample' for nearest neighbour 148 | upsampling. 
149 | """ 150 | super(UNet, self).__init__() 151 | 152 | if up_mode in ('transpose', 'upsample'): 153 | self.up_mode = up_mode 154 | else: 155 | raise ValueError("\"{}\" is not a valid mode for " 156 | "upsampling. Only \"transpose\" and " 157 | "\"upsample\" are allowed.".format(up_mode)) 158 | 159 | if merge_mode in ('concat', 'add'): 160 | self.merge_mode = merge_mode 161 | else: 162 | raise ValueError("\"{}\" is not a valid mode for" 163 | "merging up and down paths. " 164 | "Only \"concat\" and " 165 | "\"add\" are allowed.".format(up_mode)) 166 | 167 | # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' 168 | if self.up_mode == 'upsample' and self.merge_mode == 'add': 169 | raise ValueError("up_mode \"upsample\" is incompatible " 170 | "with merge_mode \"add\" at the moment " 171 | "because it doesn't make sense to use " 172 | "nearest neighbour to reduce " 173 | "depth channels (by half).") 174 | 175 | self.num_classes = num_classes 176 | self.in_channels = in_channels 177 | self.start_filts = start_filts 178 | self.depth = depth 179 | 180 | self.down_convs = [] 181 | self.up_convs = [] 182 | 183 | # create the encoder pathway and add to a list 184 | for i in range(depth): 185 | ins = self.in_channels if i == 0 else outs 186 | outs = self.start_filts*(2**i) 187 | pooling = True if i < depth-1 else False 188 | 189 | down_conv = DownConv(ins, outs, pooling=pooling) 190 | self.down_convs.append(down_conv) 191 | 192 | # create the decoder pathway and add to a list 193 | # - careful! decoding only requires depth-1 block 194 | for i in range(depth-1): 195 | ins = outs 196 | outs = ins // 2 197 | up_conv = UpConv(ins, outs, up_mode=up_mode, 198 | merge_mode=merge_mode) 199 | self.up_convs.append(up_conv) 200 | 201 | # add the list of modules to current module 202 | self.down_convs = nn.ModuleList(self.down_convs) 203 | self.up_convs = nn.ModuleList(self.up_convs) 204 | self.conv_final = conv1x1(outs, self.num_classes) 205 | self.reset_params() 206 | 207 | @staticmethod 208 | def weight_init(m): 209 | if isinstance(m, nn.Conv2d): 210 | init.xavier_normal_(m.weight) 211 | init.constant_(m.bias, 0) 212 | 213 | def reset_params(self): 214 | for i, m in enumerate(self.modules()): 215 | self.weight_init(m) 216 | 217 | def forward(self, x): 218 | encoder_outs = [] 219 | # encoder pathway, save outputs for merging 220 | for i, module in enumerate(self.down_convs): 221 | x, before_pool = module(x) 222 | encoder_outs.append(before_pool) 223 | for i, module in enumerate(self.up_convs): 224 | before_pool = encoder_outs[-(i+2)] 225 | x = module(before_pool, x) 226 | 227 | # No softmax is used. This means you need to use 228 | # nn.CrossEntropyLoss is your training script, 229 | # as this module includes a softmax already. 
230 | x = self.conv_final(x) 231 | return x 232 | -------------------------------------------------------------------------------- /network/convonet/unet3d.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from functools import partial 5 | 6 | 7 | def number_of_features_per_level(init_channel_number, num_levels): 8 | return [init_channel_number * 2 ** k for k in range(num_levels)] 9 | 10 | 11 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): 12 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 13 | 14 | 15 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): 16 | """ 17 | Create a list of modules with together constitute a single conv layer with non-linearity 18 | and optional batchnorm/groupnorm. 19 | 20 | Args: 21 | in_channels (int): number of input channels 22 | out_channels (int): number of output channels 23 | order (string): order of things, e.g. 24 | 'cr' -> conv + ReLU 25 | 'gcr' -> groupnorm + conv + ReLU 26 | 'cl' -> conv + LeakyReLU 27 | 'ce' -> conv + ELU 28 | 'bcr' -> batchnorm + conv + ReLU 29 | num_groups (int): number of groups for the GroupNorm 30 | padding (int): add zero-padding to the input 31 | 32 | Return: 33 | list of tuple (name, module) 34 | """ 35 | assert 'c' in order, "Conv layer MUST be present" 36 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' 37 | 38 | modules = [] 39 | for i, char in enumerate(order): 40 | if char == 'r': 41 | modules.append(('ReLU', nn.ReLU(inplace=True))) 42 | elif char == 'l': 43 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) 44 | elif char == 'e': 45 | modules.append(('ELU', nn.ELU(inplace=True))) 46 | elif char == 'c': 47 | # add learnable bias only in the absence of batchnorm/groupnorm 48 | bias = not ('g' in order or 'b' in order) 49 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) 50 | elif char == 'g': 51 | is_before_conv = i < order.index('c') 52 | if is_before_conv: 53 | num_channels = in_channels 54 | else: 55 | num_channels = out_channels 56 | 57 | # use only one group if the given number of groups is greater than the number of channels 58 | if num_channels < num_groups: 59 | num_groups = 1 60 | 61 | assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' 62 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) 63 | elif char == 'b': 64 | is_before_conv = i < order.index('c') 65 | if is_before_conv: 66 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) 67 | else: 68 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) 69 | else: 70 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") 71 | 72 | return modules 73 | 74 | 75 | class SingleConv(nn.Sequential): 76 | """ 77 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order 78 | of operations can be specified via the `order` parameter 79 | 80 | Args: 81 | in_channels (int): number of input channels 82 | out_channels (int): number of output channels 83 | kernel_size (int): size of the convolving kernel 84 | order (string): determines the order of layers, e.g. 
85 | 'cr' -> conv + ReLU 86 | 'crg' -> conv + ReLU + groupnorm 87 | 'cl' -> conv + LeakyReLU 88 | 'ce' -> conv + ELU 89 | num_groups (int): number of groups for the GroupNorm 90 | """ 91 | 92 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): 93 | super(SingleConv, self).__init__() 94 | 95 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): 96 | self.add_module(name, module) 97 | 98 | 99 | class DoubleConv(nn.Sequential): 100 | """ 101 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). 102 | We use (Conv3d+ReLU+GroupNorm3d) by default. 103 | This can be changed however by providing the 'order' argument, e.g. in order 104 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'. 105 | Use padded convolutions to make sure that the output (H_out, W_out) is the same 106 | as (H_in, W_in), so that you don't have to crop in the decoder path. 107 | 108 | Args: 109 | in_channels (int): number of input channels 110 | out_channels (int): number of output channels 111 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder 112 | kernel_size (int): size of the convolving kernel 113 | order (string): determines the order of layers, e.g. 114 | 'cr' -> conv + ReLU 115 | 'crg' -> conv + ReLU + groupnorm 116 | 'cl' -> conv + LeakyReLU 117 | 'ce' -> conv + ELU 118 | num_groups (int): number of groups for the GroupNorm 119 | """ 120 | 121 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): 122 | super(DoubleConv, self).__init__() 123 | if encoder: 124 | # we're in the encoder path 125 | conv1_in_channels = in_channels 126 | conv1_out_channels = out_channels // 2 127 | if conv1_out_channels < in_channels: 128 | conv1_out_channels = in_channels 129 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels 130 | else: 131 | # we're in the decoder path, decrease the number of channels in the 1st convolution 132 | conv1_in_channels, conv1_out_channels = in_channels, out_channels 133 | conv2_in_channels, conv2_out_channels = out_channels, out_channels 134 | 135 | # conv1 136 | self.add_module('SingleConv1', 137 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) 138 | # conv2 139 | self.add_module('SingleConv2', 140 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) 141 | 142 | 143 | class ExtResNetBlock(nn.Module): 144 | """ 145 | Basic UNet block consisting of a SingleConv followed by the residual block. 146 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number 147 | of output channels is compatible with the residual block that follows. 148 | This block can be used instead of standard DoubleConv in the Encoder module. 149 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf 150 | 151 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
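    Example (illustrative only; argument values are arbitrary):
        block = ExtResNetBlock(32, 64, order='cge', num_groups=8)
        # each SingleConv expands to Conv3d -> GroupNorm -> ELU; conv3 drops the
        # non-linearity so it can be applied after the residual addition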
152 | """ 153 | 154 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): 155 | super(ExtResNetBlock, self).__init__() 156 | 157 | # first convolution 158 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 159 | # residual block 160 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 161 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 162 | n_order = order 163 | for c in 'rel': 164 | n_order = n_order.replace(c, '') 165 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, 166 | num_groups=num_groups) 167 | 168 | # create non-linearity separately 169 | if 'l' in order: 170 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) 171 | elif 'e' in order: 172 | self.non_linearity = nn.ELU(inplace=True) 173 | else: 174 | self.non_linearity = nn.ReLU(inplace=True) 175 | 176 | def forward(self, x): 177 | # apply first convolution and save the output as a residual 178 | out = self.conv1(x) 179 | residual = out 180 | 181 | # residual block 182 | out = self.conv2(out) 183 | out = self.conv3(out) 184 | 185 | out += residual 186 | out = self.non_linearity(out) 187 | 188 | return out 189 | 190 | 191 | class Encoder(nn.Module): 192 | """ 193 | A single module from the encoder path consisting of the optional max 194 | pooling layer (one may specify the MaxPool kernel_size to be different 195 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic 196 | (make sure to use complementary scale_factor in the decoder path) followed by 197 | a DoubleConv module. 198 | Args: 199 | in_channels (int): number of input channels 200 | out_channels (int): number of output channels 201 | conv_kernel_size (int): size of the convolving kernel 202 | apply_pooling (bool): if True use MaxPool3d before DoubleConv 203 | pool_kernel_size (tuple): the size of the window to take a max over 204 | pool_type (str): pooling layer: 'max' or 'avg' 205 | basic_module(nn.Module): either ResNetBlock or DoubleConv 206 | conv_layer_order (string): determines the order of layers 207 | in `DoubleConv` module. See `DoubleConv` for more info. 208 | num_groups (int): number of groups for the GroupNorm 209 | """ 210 | 211 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, 212 | pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg', 213 | num_groups=8): 214 | super(Encoder, self).__init__() 215 | assert pool_type in ['max', 'avg'] 216 | if apply_pooling: 217 | if pool_type == 'max': 218 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) 219 | else: 220 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) 221 | else: 222 | self.pooling = None 223 | 224 | self.basic_module = basic_module(in_channels, out_channels, 225 | encoder=True, 226 | kernel_size=conv_kernel_size, 227 | order=conv_layer_order, 228 | num_groups=num_groups) 229 | 230 | def forward(self, x): 231 | if self.pooling is not None: 232 | x = self.pooling(x) 233 | x = self.basic_module(x) 234 | return x 235 | 236 | 237 | class Decoder(nn.Module): 238 | """ 239 | A single module for decoder path consisting of the upsampling layer 240 | (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock). 
241 | Args: 242 | in_channels (int): number of input channels 243 | out_channels (int): number of output channels 244 | kernel_size (int): size of the convolving kernel 245 | scale_factor (tuple): used as the multiplier for the image H/W/D in 246 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation 247 | from the corresponding encoder 248 | basic_module(nn.Module): either ResNetBlock or DoubleConv 249 | conv_layer_order (string): determines the order of layers 250 | in `DoubleConv` module. See `DoubleConv` for more info. 251 | num_groups (int): number of groups for the GroupNorm 252 | """ 253 | 254 | def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv, 255 | conv_layer_order='crg', num_groups=8, mode='nearest'): 256 | super(Decoder, self).__init__() 257 | if basic_module == DoubleConv: 258 | # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining 259 | self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels, 260 | kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) 261 | # concat joining 262 | self.joining = partial(self._joining, concat=True) 263 | else: 264 | # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining 265 | self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels, 266 | kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) 267 | # sum joining 268 | self.joining = partial(self._joining, concat=False) 269 | # adapt the number of in_channels for the ExtResNetBlock 270 | in_channels = out_channels 271 | 272 | self.basic_module = basic_module(in_channels, out_channels, 273 | encoder=False, 274 | kernel_size=kernel_size, 275 | order=conv_layer_order, 276 | num_groups=num_groups) 277 | 278 | def forward(self, encoder_features, x): 279 | x = self.upsampling(encoder_features=encoder_features, x=x) 280 | x = self.joining(encoder_features, x) 281 | x = self.basic_module(x) 282 | return x 283 | 284 | @staticmethod 285 | def _joining(encoder_features, x, concat): 286 | if concat: 287 | return torch.cat((encoder_features, x), dim=1) 288 | else: 289 | return encoder_features + x 290 | 291 | 292 | class Upsampling(nn.Module): 293 | """ 294 | Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution. 295 | 296 | Args: 297 | transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation 298 | concat_joining (bool): if True uses concatenation joining between encoder and decoder features, otherwise 299 | uses summation joining (see Residual U-Net) 300 | in_channels (int): number of input channels for transposed conv 301 | out_channels (int): number of output channels for transpose conv 302 | kernel_size (int or tuple): size of the convolving kernel 303 | scale_factor (int or tuple): stride of the convolution 304 | mode (str): algorithm used for upsampling: 305 | 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. 
Default: 'nearest' 306 | """ 307 | 308 | def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3, 309 | scale_factor=(2, 2, 2), mode='nearest'): 310 | super(Upsampling, self).__init__() 311 | 312 | if transposed_conv: 313 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder 314 | # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) 315 | self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, 316 | padding=1) 317 | else: 318 | self.upsample = partial(self._interpolate, mode=mode) 319 | 320 | def forward(self, encoder_features, x): 321 | output_size = encoder_features.size()[2:] 322 | return self.upsample(x, output_size) 323 | 324 | @staticmethod 325 | def _interpolate(x, size, mode): 326 | return F.interpolate(x, size=size, mode=mode) 327 | 328 | 329 | class FinalConv(nn.Sequential): 330 | """ 331 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution 332 | which reduces the number of channels to 'out_channels'. 333 | with the number of output channels 'out_channels // 2' and 'out_channels' respectively. 334 | We use (Conv3d+ReLU+GroupNorm3d) by default. 335 | This can be change however by providing the 'order' argument, e.g. in order 336 | to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. 337 | Args: 338 | in_channels (int): number of input channels 339 | out_channels (int): number of output channels 340 | kernel_size (int): size of the convolving kernel 341 | order (string): determines the order of layers, e.g. 342 | 'cr' -> conv + ReLU 343 | 'crg' -> conv + ReLU + groupnorm 344 | num_groups (int): number of groups for the GroupNorm 345 | """ 346 | 347 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): 348 | super(FinalConv, self).__init__() 349 | 350 | # conv1 351 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) 352 | 353 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels 354 | final_conv = nn.Conv3d(in_channels, out_channels, 1) 355 | self.add_module('final_conv', final_conv) 356 | 357 | class Abstract3DUNet(nn.Module): 358 | """ 359 | Base class for standard and residual UNet. 360 | 361 | Args: 362 | in_channels (int): number of input channels 363 | out_channels (int): number of output segmentation masks; 364 | Note that that the of out_channels might correspond to either 365 | different semantic classes or to different binary segmentation mask. 366 | It's up to the user of the class to interpret the out_channels and 367 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) 368 | or BCEWithLogitsLoss (two-class) respectively) 369 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 370 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 371 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 372 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 373 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 374 | basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....) 375 | layer_order (string): determines the order of layers 376 | in `SingleConv` module. e.g. 
'crg' stands for Conv3d+ReLU+GroupNorm3d. 377 | See `SingleConv` for more info 378 | f_maps (int, tuple): if int: number of feature maps in the first conv layer of the encoder (default: 64); 379 | if tuple: number of feature maps at each level 380 | num_groups (int): number of groups for the GroupNorm 381 | num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) 382 | is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied 383 | after the final convolution; if False (regression problem) the normalization layer is skipped at the end 384 | testing (bool): if True (testing mode) the `final_activation` (if present, i.e. `is_segmentation=true`) 385 | will be applied as the last operation during the forward pass; if False the model is in training mode 386 | and the `final_activation` (even if present) won't be applied; default: False 387 | """ 388 | 389 | def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', 390 | num_groups=8, num_levels=4, is_segmentation=False, testing=False, **kwargs): 391 | super(Abstract3DUNet, self).__init__() 392 | 393 | self.testing = testing 394 | 395 | if isinstance(f_maps, int): 396 | f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) 397 | 398 | # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)` 399 | encoders = [] 400 | for i, out_feature_num in enumerate(f_maps): 401 | if i == 0: 402 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module, 403 | conv_layer_order=layer_order, num_groups=num_groups) 404 | else: 405 | # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations 406 | # currently pools with a constant kernel: (2, 2, 2) 407 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module, 408 | conv_layer_order=layer_order, num_groups=num_groups) 409 | encoders.append(encoder) 410 | 411 | self.encoders = nn.ModuleList(encoders) 412 | 413 | # create decoder path consisting of the Decoder modules. 
The length of the decoder is equal to `len(f_maps) - 1` 414 | decoders = [] 415 | reversed_f_maps = list(reversed(f_maps)) 416 | for i in range(len(reversed_f_maps) - 1): 417 | if basic_module == DoubleConv: 418 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 419 | else: 420 | in_feature_num = reversed_f_maps[i] 421 | 422 | out_feature_num = reversed_f_maps[i + 1] 423 | # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv 424 | # currently strides with a constant stride: (2, 2, 2) 425 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module, 426 | conv_layer_order=layer_order, num_groups=num_groups) 427 | decoders.append(decoder) 428 | 429 | self.decoders = nn.ModuleList(decoders) 430 | 431 | # in the last layer a 1×1 convolution reduces the number of output 432 | # channels to the number of labels 433 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 434 | 435 | if is_segmentation: 436 | # semantic segmentation problem 437 | if final_sigmoid: 438 | self.final_activation = nn.Sigmoid() 439 | else: 440 | self.final_activation = nn.Softmax(dim=1) 441 | else: 442 | # regression problem 443 | self.final_activation = None 444 | 445 | def forward(self, x): 446 | # encoder part 447 | encoders_features = [] 448 | for encoder in self.encoders: 449 | x = encoder(x) 450 | # reverse the encoder outputs to be aligned with the decoder 451 | encoders_features.insert(0, x) 452 | 453 | # remove the last encoder's output from the list 454 | # !!remember: it's the 1st in the list 455 | encoders_features = encoders_features[1:] 456 | 457 | # decoder part 458 | for decoder, encoder_features in zip(self.decoders, encoders_features): 459 | # pass the output from the corresponding encoder and the output 460 | # of the previous decoder 461 | x = decoder(encoder_features, x) 462 | 463 | x = self.final_conv(x) 464 | 465 | # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs 466 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric 467 | if self.testing and self.final_activation is not None: 468 | x = self.final_activation(x) 469 | 470 | return x 471 | 472 | 473 | class UNet3D(Abstract3DUNet): 474 | """ 475 | 3DUnet model from 476 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 477 | `. 478 | 479 | Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder 480 | """ 481 | 482 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 483 | num_groups=8, num_levels=4, is_segmentation=True, **kwargs): 484 | super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid, 485 | basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order, 486 | num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation, 487 | **kwargs) 488 | 489 | 490 | class ResidualUNet3D(Abstract3DUNet): 491 | """ 492 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 493 | Uses ExtResNetBlock as a basic building block, summation joining instead 494 | of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). 495 | Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 
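    Example (sketch; argument values and input shape are assumptions):
        model = ResidualUNet3D(in_channels=1, out_channels=2, f_maps=32)
        out = model(volume)  # volume: (N, 1, 64, 64, 64); raw logits unless testing=True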
496 | """ 497 | 498 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 499 | num_groups=8, num_levels=5, is_segmentation=True, **kwargs): 500 | super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, 501 | final_sigmoid=final_sigmoid, 502 | basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order, 503 | num_groups=num_groups, num_levels=num_levels, 504 | is_segmentation=is_segmentation, 505 | **kwargs) 506 | 507 | 508 | def get_model(config): 509 | def _model_class(class_name): 510 | m = importlib.import_module('pytorch3dunet.unet3d.model') 511 | clazz = getattr(m, class_name) 512 | return clazz 513 | 514 | assert 'model' in config, 'Could not find model configuration' 515 | model_config = config['model'] 516 | model_class = _model_class(model_config['name']) 517 | return model_class(**model_config) 518 | -------------------------------------------------------------------------------- /network/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Polyhedral feature decoders. 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.nn import MLP 8 | 9 | from network.convonet.ResnetBlockFC import ResnetBlockFC 10 | from utils import normalize_coordinate, normalize_3d_coordinate 11 | 12 | 13 | class MLPDecoder(torch.nn.Module): 14 | """ 15 | Global decoder to fuse queries and latent. Adapted from IM-Net. 16 | """ 17 | 18 | def __init__(self, latent_dim, num_queries): 19 | super().__init__() 20 | self.num_queries = num_queries 21 | self.mlp = MLP([latent_dim + self.num_queries * 3, latent_dim * 4, latent_dim * 4, latent_dim * 2, latent_dim], plain_last=False) 22 | 23 | def forward(self, latent, queries, batch): 24 | # concat queries and latent 25 | latent = latent[batch] # (num_cells_in_batch, 256) 26 | queries = queries.view(queries.shape[0], -1) # (num_cells_in_batch, num_queries_per_cell * 3) 27 | pointz = torch.cat([queries, latent], 1) # (num_cells_in_batch, num_queries_per_cell * 3 + 256) 28 | return self.mlp(pointz) # (num_cells_in_batch, num_queries_per_cell, 1) 29 | 30 | 31 | class ConvONetDecoder(torch.nn.Module): 32 | """ 33 | Convolutional occupancy decoder. Instead of conditioning on global features, on plane/volume local features. 
34 | Args: 35 | dim (int): input dimension 36 | c_dim (int): dimension of latent conditioned code c 37 | hidden_size (int): hidden size of Decoder network 38 | n_blocks (int): number of blocks ResNetBlockFC layers 39 | leaky (bool): whether to use leaky ReLUs 40 | sample_mode (str): sampling feature strategy, bi-linear | nearest 41 | padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 42 | """ 43 | 44 | def __init__(self, dim=3, c_dim=128, hidden_size=256, latent_dim=4096, num_queries=16 ,n_blocks=5, 45 | leaky=False, sample_mode='bilinear', padding=0.1): 46 | super().__init__() 47 | self.c_dim = c_dim 48 | self.num_queries = num_queries 49 | self.n_blocks = n_blocks 50 | 51 | if c_dim != 0: 52 | self.fc_c = torch.nn.ModuleList([ 53 | torch.nn.Linear(c_dim, hidden_size) for _ in range(n_blocks) 54 | ]) 55 | 56 | self.fc_p = torch.nn.Linear(dim, hidden_size) 57 | 58 | self.blocks = torch.nn.ModuleList([ 59 | ResnetBlockFC(hidden_size) for _ in range(n_blocks) 60 | ]) 61 | if latent_dim != hidden_size * num_queries: 62 | # not recommended as breaks explicit per-query feature 63 | self.fc_out = torch.nn.Linear(hidden_size * num_queries, latent_dim) 64 | else: 65 | self.fc_out = None 66 | 67 | if not leaky: 68 | self.actvn = F.relu 69 | else: 70 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 71 | 72 | self.sample_mode = sample_mode 73 | self.padding = padding 74 | 75 | def sample_plane_feature(self, p, c, plane='xz'): 76 | xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) 77 | xy = xy[:, :, None].float() 78 | vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) 79 | c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True, mode=self.sample_mode).squeeze(-1) 80 | return c 81 | 82 | def sample_grid_feature(self, p, c): 83 | p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding) # normalize to the range of (0, 1) 84 | p_nor = p_nor[:, :, None, None].float() 85 | vgrid = 2.0 * p_nor - 1.0 # normalize to (-1, 1) 86 | # bilinear interpolation 87 | c = F.grid_sample(c, vgrid, padding_mode='border', align_corners=True, mode=self.sample_mode).squeeze( 88 | -1).squeeze(-1) 89 | return c 90 | 91 | def forward(self, p, c_plane, **kwargs): 92 | if self.c_dim != 0: 93 | plane_type = list(c_plane.keys()) 94 | c = 0 95 | if 'grid' in plane_type: 96 | c += self.sample_grid_feature(p, c_plane['grid']) 97 | if 'xz' in plane_type: 98 | c += self.sample_plane_feature(p, c_plane['xz'], plane='xz') 99 | if 'xy' in plane_type: 100 | c += self.sample_plane_feature(p, c_plane['xy'], plane='xy') 101 | if 'yz' in plane_type: 102 | c += self.sample_plane_feature(p, c_plane['yz'], plane='yz') 103 | c = c.transpose(1, 2) 104 | 105 | p = p.float() 106 | net = self.fc_p(p) 107 | 108 | for i in range(self.n_blocks): 109 | if self.c_dim != 0: 110 | net = net + self.fc_c[i](c) 111 | 112 | net = self.blocks[i](net) 113 | 114 | # aggregate to cell-wise features 115 | out = net.squeeze() 116 | out = out.view(out.shape[0] // self.num_queries , -1) # (num_cells, latent_b_dim * num_queries == 4096) 117 | if self.fc_out is not None: 118 | out = self.fc_out(out) 119 | 120 | return out 121 | 122 | 123 | class Decoder(torch.nn.Module): 124 | """ 125 | Global decoder to fuse queries and latent. Adapted from IM-Net. 
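    Illustrative usage (values are assumptions; shapes follow the comments in forward):
        decoder = Decoder(backbone='MLP', latent_dim=256, num_queries=16)
        outs = decoder(latent, queries, batch)  # (num_cells_in_batch, latent_dim)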
126 | """ 127 | 128 | def __init__(self, backbone, latent_dim, num_queries): 129 | super().__init__() 130 | self.backbone = backbone 131 | self.latent_dim = latent_dim 132 | if backbone.casefold() == 'MLP'.casefold(): 133 | self.network = MLPDecoder(latent_dim=latent_dim, num_queries=num_queries) 134 | elif backbone.casefold() == 'ConvONet'.casefold(): 135 | self.network = ConvONetDecoder(latent_dim=latent_dim, num_queries=num_queries) 136 | else: 137 | raise ValueError(f'Unexpected backbone: {backbone}') 138 | 139 | def forward(self, latent, queries, batch): 140 | # latent: (num_graphs_in_batch, latent_a_dim) 141 | # queries: (num_cells_in_batch, num_queries_per_cell, 3) 142 | # batch: (num_cells_in_batch, 1) 143 | 144 | if self.backbone.casefold() == 'MLP'.casefold(): 145 | outs = self.network(latent, queries, batch) 146 | 147 | else: 148 | # occupancy network decoder 149 | outs = torch.zeros([len(batch), self.latent_dim]).to(batch.device) 150 | for i in range(batch[-1] + 1): 151 | latent_i = {'xz': latent['xz'][i].unsqueeze(0), 'xy': latent['xy'][i].unsqueeze(0), 152 | 'yz': latent['yz'][i].unsqueeze(0)} 153 | queries_i = queries[batch == i].view(-1, 3).unsqueeze(0) 154 | outs[batch == i] = self.network(queries_i, latent_i) 155 | 156 | # (num_cells_in_batch, num_features_per_cell == latent_b_dim) 157 | return outs 158 | -------------------------------------------------------------------------------- /network/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Point cloud encoders. 3 | """ 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import Sequential, Linear, ReLU, LeakyReLU 8 | from torch_geometric.nn import PointNetConv, XConv, DynamicEdgeConv 9 | from torch_geometric.nn import fps, global_mean_pool, global_max_pool, knn_graph 10 | from torch_geometric.nn import MLP, PointTransformerConv, knn, radius 11 | from torch_geometric.utils import scatter 12 | from torch_geometric.nn.aggr import MaxAggregation 13 | from torch_geometric.nn.conv import MessagePassing 14 | from torch_geometric.nn.pool.decimation import decimation_indices 15 | from torch_geometric.utils import softmax 16 | 17 | from network.convonet.pointnet import LocalPoolPointnet 18 | 19 | 20 | class PointNet(torch.nn.Module): 21 | """ 22 | PointNet with PointNetConv. 23 | """ 24 | 25 | def __init__(self, latent_dim): 26 | super().__init__() 27 | # conv layers 28 | self.conv1 = PointNetConv(MLP([3 + 3, 64, 64])) 29 | self.conv2 = PointNetConv(MLP([64 + 3, 64, 128, 1024])) 30 | self.lin = MLP([1024, latent_dim], plain_last=False) 31 | 32 | def forward(self, pos, batch): 33 | x, pos, batch = None, pos, batch 34 | edge_index = knn_graph(pos, k=16, batch=batch, loop=True) 35 | 36 | # point-wise features 37 | x = self.conv1(x, pos, edge_index) 38 | x = self.conv2(x, pos, edge_index) 39 | 40 | # instance-wise features 41 | x = global_max_pool(x, batch) 42 | x = self.lin(x) 43 | return x 44 | 45 | 46 | class SAModule(torch.nn.Module): 47 | """ 48 | SA Module for PointNet++. 
49 | """ 50 | 51 | def __init__(self, ratio, r, nn): 52 | super().__init__() 53 | self.ratio = ratio 54 | self.r = r 55 | self.conv = PointNetConv(nn, add_self_loops=False) 56 | 57 | def forward(self, x, pos, batch): 58 | idx = fps(pos, batch, ratio=self.ratio) 59 | row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64) 60 | edge_index = torch.stack([col, row], dim=0) 61 | x_dst = None if x is None else x[idx] 62 | x = self.conv((x, x_dst), (pos, pos[idx]), edge_index) 63 | pos, batch = pos[idx], batch[idx] 64 | return x, pos, batch 65 | 66 | 67 | class GlobalSAModule(torch.nn.Module): 68 | """ 69 | Global SA Module for PointNet++. 70 | """ 71 | 72 | def __init__(self, nn): 73 | super().__init__() 74 | self.nn = nn 75 | 76 | def forward(self, x, pos, batch): 77 | x = self.nn(torch.cat([x, pos], dim=1)) 78 | x = global_max_pool(x, batch) 79 | pos = pos.new_zeros((x.size(0), 3)) 80 | batch = torch.arange(x.size(0), device=batch.device) 81 | return x, pos, batch 82 | 83 | 84 | class PointNet2(torch.nn.Module): 85 | """ 86 | PointNet++ with SAModule and GlobalSAModule. 87 | """ 88 | 89 | def __init__(self, latent_dim): 90 | super().__init__() 91 | 92 | # input channels account for both `pos` and node features. 93 | self.sa1_module = SAModule(0.5, 0.2, MLP([3, 64, 64, 128])) 94 | self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256])) 95 | self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024])) 96 | self.lin = MLP([1024, latent_dim], plain_last=False) 97 | 98 | def forward(self, pos, batch): 99 | # point-wise features 100 | sa0_out = (pos, pos, batch) 101 | sa1_out = self.sa1_module(*sa0_out) 102 | sa2_out = self.sa2_module(*sa1_out) 103 | sa3_out = self.sa3_module(*sa2_out) 104 | x, _, _ = sa3_out 105 | 106 | # instance-wise features 107 | x = global_max_pool(x, batch) 108 | x = self.lin(x) 109 | return x 110 | 111 | 112 | class PointCNN(torch.nn.Module): 113 | """ 114 | PointCNN with XConv. 115 | """ 116 | 117 | def __init__(self, latent_dim): 118 | super().__init__() 119 | 120 | self.conv1 = XConv(3, 32, dim=3, kernel_size=8, dilation=2) 121 | self.conv2 = XConv(32, 128, dim=3, kernel_size=12, dilation=2) 122 | self.conv3 = XConv(128, 256, dim=3, kernel_size=16, dilation=1) 123 | self.conv4 = XConv(256, 256, dim=3, kernel_size=16, dilation=2) 124 | self.conv5 = XConv(256, 256, dim=3, kernel_size=16, dilation=3) 125 | self.conv6 = XConv(256, 690, dim=3, kernel_size=16, dilation=4) 126 | self.lin = MLP([690, latent_dim], plain_last=False) 127 | 128 | def forward(self, pos, batch): 129 | # point-wise features 130 | x, pos, batch = None, pos, batch 131 | x = self.conv1(x, pos, batch) 132 | idx = fps(pos, batch, ratio=0.375) 133 | x, pos, batch = x[idx], pos[idx], batch[idx] 134 | 135 | x = self.conv2(x, pos, batch) 136 | idx = fps(pos, batch, ratio=0.325) 137 | x, pos, batch = x[idx], pos[idx], batch[idx] 138 | 139 | x = self.conv3(x, pos, batch) 140 | x = self.conv4(x, pos, batch) 141 | x = self.conv5(x, pos, batch) 142 | x = self.conv6(x, pos, batch) 143 | 144 | # instance-wise features 145 | x = global_max_pool(x, batch) 146 | x = self.lin(x) 147 | return x 148 | 149 | 150 | class DGCNN(torch.nn.Module): 151 | """ 152 | DGCNN with DynamicEdgeConv. 
153 | """ 154 | 155 | def __init__(self, latent_dim): 156 | super().__init__() 157 | 158 | self.conv1 = DynamicEdgeConv(Sequential(Linear(3 * 2, 64, bias=True), LeakyReLU(negative_slope=0.02)), k=20) 159 | self.conv2 = DynamicEdgeConv(Sequential(Linear(64 * 2, 64, bias=True), LeakyReLU(negative_slope=0.02)), k=20) 160 | self.conv3 = DynamicEdgeConv(Sequential(Linear(64 * 2, 128, bias=True), LeakyReLU(negative_slope=0.02)), k=20) 161 | self.conv4 = DynamicEdgeConv(Sequential(Linear(128 * 2, 256, bias=True), LeakyReLU(negative_slope=0.02)), k=20) 162 | self.lin = MLP([512, latent_dim], plain_last=False) 163 | 164 | def forward(self, pos, batch): 165 | x, batch = pos, batch 166 | 167 | # point-wise features 168 | x1 = self.conv1(x, batch) 169 | x2 = self.conv2(x1, batch) 170 | x3 = self.conv3(x2, batch) 171 | x4 = self.conv4(x3, batch) 172 | x = torch.concat([x1, x2, x3, x4], dim=-1) 173 | 174 | # instance-wise features 175 | x = global_max_pool(x, batch) 176 | x = self.lin(x) 177 | return x 178 | 179 | 180 | class SpatialTransformer(torch.nn.Module): 181 | """ 182 | Optional spatial transform block in DGCNN to align a point set to a canonical space. 183 | """ 184 | 185 | def __init__(self, k=16): 186 | super().__init__() 187 | # to estimate the 3 x 3 matrix, a tensor concatenating the coordinates of each point 188 | # and the coordinate differences between its k neighboring points is used. 189 | self.k = k 190 | self.mlp = Sequential(Linear(k * 6, 1024), ReLU(), Linear(1024, 256), ReLU(), Linear(256, 9)) 191 | 192 | def forward(self, pos, batch): 193 | # pos: (num_batch_points, 3) 194 | edge_index = knn_graph(pos, k=self.k, batch=batch, loop=True).to(batch.device) 195 | neighbours = pos[edge_index[0].reshape([-1, 16])] # (num_batch_points, k, 3) 196 | pos = torch.unsqueeze(pos, dim=1) # (num_batch_points, k, 6) 197 | # concatenating the coordinates of each point and the coordinate differences between its k neighboring points 198 | x = torch.concat([pos.repeat(1, self.k, 1), pos - neighbours], dim=2) # (num_batch_points, k, 6) 199 | x = x.reshape([x.shape[0], -1]) # (num_batch_points, k * 6) 200 | x = self.mlp(x) # (num_batch_points, 9) 201 | x = global_mean_pool(x, batch) # (batch_size, 9) 202 | x = x[batch].reshape([-1, 3, 3]) # (num_batch_points, 3, 3) 203 | x = torch.squeeze(torch.bmm(pos, x)) # (num_batch_points, 3) 204 | return x 205 | 206 | 207 | class TransformerBlock(torch.nn.Module): 208 | """ 209 | Transformer block for PointTransformer. 210 | """ 211 | 212 | def __init__(self, in_channels, out_channels): 213 | super().__init__() 214 | self.lin_in = Linear(in_channels, in_channels) 215 | self.lin_out = Linear(out_channels, out_channels) 216 | 217 | self.pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False) 218 | self.attn_nn = MLP([out_channels, 64, out_channels], norm=None, plain_last=False) 219 | self.transformer = PointTransformerConv(in_channels, out_channels, pos_nn=self.pos_nn, attn_nn=self.attn_nn) 220 | 221 | def forward(self, x, pos, edge_index): 222 | x = self.lin_in(x).relu() 223 | x = self.transformer(x, pos, edge_index) 224 | x = self.lin_out(x).relu() 225 | return x 226 | 227 | 228 | class TransitionDown(torch.nn.Module): 229 | """ 230 | TransitionDown for PointTransformer. 231 | Samples the input point cloud by a ratio percentage to reduce 232 | cardinality and uses an MLP to augment features dimensionality. 
233 | """ 234 | 235 | def __init__(self, in_channels, out_channels, ratio=0.25, k=16): 236 | super().__init__() 237 | self.k = k 238 | self.ratio = ratio 239 | self.mlp = MLP([in_channels, out_channels], plain_last=False) 240 | 241 | def forward(self, x, pos, batch): 242 | # FPS sampling 243 | id_clusters = fps(pos, ratio=self.ratio, batch=batch) 244 | 245 | # compute k-nearest points for each cluster 246 | sub_batch = batch[id_clusters] if batch is not None else None 247 | 248 | # beware of self loop 249 | id_k_neighbor = knn(pos, pos[id_clusters], k=self.k, batch_x=batch, batch_y=sub_batch) 250 | 251 | # transformation of features through a simple MLP 252 | x = self.mlp(x) 253 | 254 | # Max pool onto each cluster the features from knn in points 255 | x_out = scatter(x[id_k_neighbor[1]], id_k_neighbor[0], dim=0, dim_size=id_clusters.size(0), reduce='max') 256 | 257 | # keep only the clusters and their max-pooled features 258 | sub_pos, out = pos[id_clusters], x_out 259 | return out, sub_pos, sub_batch 260 | 261 | 262 | class PointTransformer(torch.nn.Module): 263 | """ 264 | PointTransformer. 265 | """ 266 | 267 | def __init__(self, latent_dim, k=16): 268 | super().__init__() 269 | self.k = k 270 | 271 | # dummy feature is created if there is none given 272 | in_channels = 1 273 | 274 | # hidden channels 275 | dim_model = [32, 64, 128, 256, 512] 276 | 277 | # first block 278 | self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False) 279 | self.transformer_input = TransformerBlock(in_channels=dim_model[0], out_channels=dim_model[0]) 280 | 281 | # backbone layers 282 | self.transformers_down = torch.nn.ModuleList() 283 | self.transition_down = torch.nn.ModuleList() 284 | 285 | for i in range(len(dim_model) - 1): 286 | # Add Transition Down block followed by a Transformer block 287 | self.transition_down.append( 288 | TransitionDown(in_channels=dim_model[i], out_channels=dim_model[i + 1], k=self.k)) 289 | self.transformers_down.append( 290 | TransformerBlock(in_channels=dim_model[i + 1], out_channels=dim_model[i + 1])) 291 | self.lin = MLP([dim_model[-1], latent_dim], plain_last=False) 292 | 293 | def forward(self, pos, batch=None, x=None): 294 | # add dummy features in case there is none 295 | if x is None: 296 | x = torch.ones((pos.shape[0], 1), device=pos.get_device()) 297 | 298 | # first block 299 | x = self.mlp_input(x) 300 | edge_index = knn_graph(pos, k=self.k, batch=batch) 301 | x = self.transformer_input(x, pos, edge_index) 302 | 303 | # backbone 304 | for i in range(len(self.transformers_down)): 305 | x, pos, batch = self.transition_down[i](x, pos, batch=batch) 306 | 307 | edge_index = knn_graph(pos, k=self.k, batch=batch) 308 | x = self.transformers_down[i](x, pos, edge_index) 309 | 310 | # GlobalAveragePooling 311 | x = global_mean_pool(x, batch) 312 | 313 | # MLP blocks 314 | out = self.lin(x) 315 | return out 316 | 317 | 318 | class SharedMLP(MLP): 319 | """ 320 | 321 | """ 322 | 323 | def __init__(self, *args, **kwargs): 324 | # BN + Act always active even at last layer. 325 | kwargs['plain_last'] = False 326 | # LeakyRelu with 0.2 slope by default. 
327 | kwargs['act'] = kwargs.get('act', 'LeakyReLU') 328 | kwargs['act_kwargs'] = kwargs.get('act_kwargs', {'negative_slope': 0.2}) 329 | # BatchNorm with 1 - 0.99 = 0.01 momentum 330 | # and 1e-6 eps by defaut (tensorflow momentum != pytorch momentum) 331 | kwargs['norm_kwargs'] = kwargs.get('norm_kwargs', {'momentum': 0.01, 'eps': 1e-6}) 332 | super().__init__(*args, **kwargs) 333 | 334 | 335 | class LocalFeatureAggregation(MessagePassing): 336 | """Positional encoding of points in a neighborhood.""" 337 | 338 | def __init__(self, channels): 339 | super().__init__(aggr='add') 340 | self.mlp_encoder = SharedMLP([10, channels // 2]) 341 | self.mlp_attention = SharedMLP([channels, channels], bias=False, 342 | act=None, norm=None) 343 | self.mlp_post_attention = SharedMLP([channels, channels]) 344 | 345 | def forward(self, edge_index, x, pos): 346 | out = self.propagate(edge_index, x=x, pos=pos) # N, d_out 347 | out = self.mlp_post_attention(out) # N, d_out 348 | return out 349 | 350 | def message(self, x_j, pos_i, pos_j, index): 351 | """Local Spatial Encoding (locSE) and attentive pooling of features. 352 | Args: 353 | x_j (Tensor): neighbors features (K,d) 354 | pos_i (Tensor): centroid position (repeated) (K,3) 355 | pos_j (Tensor): neighbors positions (K,3) 356 | index (Tensor): index of centroid positions 357 | (e.g. [0,...,0,1,...,1,...,N,...,N]) 358 | returns: 359 | (Tensor): locSE weighted by feature attention scores. 360 | """ 361 | # Encode local neighborhood structural information 362 | pos_diff = pos_j - pos_i 363 | distance = torch.sqrt((pos_diff * pos_diff).sum(1, keepdim=True)) 364 | relative_infos = torch.cat([pos_i, pos_j, pos_diff, distance], dim=1) # N * K, d 365 | local_spatial_encoding = self.mlp_encoder(relative_infos) # N * K, d 366 | local_features = torch.cat([x_j, local_spatial_encoding], dim=1) # N * K, 2d 367 | 368 | # Attention will weight the different features of x 369 | # along the neighborhood dimension. 370 | att_features = self.mlp_attention(local_features) # N * K, d_out 371 | att_scores = softmax(att_features, index=index) # N * K, d_out 372 | 373 | return att_scores * local_features # N * K, d_out 374 | 375 | 376 | class DilatedResidualBlock(torch.nn.Module): 377 | """ 378 | Dilated residual block for RandLANet. 379 | """ 380 | 381 | def __init__(self, num_neighbors, d_in: int, d_out: int): 382 | super().__init__() 383 | self.num_neighbors = num_neighbors 384 | self.d_in = d_in 385 | self.d_out = d_out 386 | 387 | # MLP on input 388 | self.mlp1 = SharedMLP([d_in, d_out // 8]) 389 | # MLP on input, and the result is summed with the output of mlp2 390 | self.shortcut = SharedMLP([d_in, d_out], act=None) 391 | # MLP on output 392 | self.mlp2 = SharedMLP([d_out // 2, d_out], act=None) 393 | 394 | self.lfa1 = LocalFeatureAggregation(d_out // 4) 395 | self.lfa2 = LocalFeatureAggregation(d_out // 2) 396 | self.lrelu = torch.nn.LeakyReLU(**{'negative_slope': 0.2}) 397 | 398 | def forward(self, x, pos, batch): 399 | edge_index = knn_graph(pos, self.num_neighbors, batch=batch, loop=True) 400 | 401 | shortcut_of_x = self.shortcut(x) # N, d_out 402 | x = self.mlp1(x) # N, d_out//8 403 | x = self.lfa1(edge_index, x, pos) # N, d_out//2 404 | x = self.lfa2(edge_index, x, pos) # N, d_out//2 405 | x = self.mlp2(x) # N, d_out 406 | x = self.lrelu(x + shortcut_of_x) # N, d_out 407 | 408 | return x, pos, batch 409 | 410 | 411 | def decimate(tensors, ptr: Tensor, decimation_factor: int): 412 | """ 413 | Decimates each element of the given tuple of tensors for RandLANet. 
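    A thin wrapper around torch_geometric's decimation_indices: roughly
    1/decimation_factor of the points of each sample survive, and the updated
    ptr vector is returned alongside the decimated tensors.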
414 | """ 415 | idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) 416 | tensors_decim = tuple(tensor[idx_decim] for tensor in tensors) 417 | return tensors_decim, ptr_decim 418 | 419 | 420 | class RandLANet(torch.nn.Module): 421 | """ 422 | An adaptation of RandLA-Net for point cloud encoding, which was not addressed in the paper: 423 | RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds. 424 | """ 425 | 426 | def __init__(self, latent_dim, decimation: int = 4, num_neighbors: int = 16): 427 | super().__init__() 428 | self.decimation = decimation 429 | self.fc0 = Linear(in_features=3, out_features=8) 430 | # 2 DilatedResidualBlock converges better than 4 on ModelNet 431 | self.block1 = DilatedResidualBlock(num_neighbors, 8, 32) 432 | self.block2 = DilatedResidualBlock(num_neighbors, 32, 128) 433 | self.block3 = DilatedResidualBlock(num_neighbors, 128, 512) 434 | self.mlp1 = SharedMLP([512, 512]) 435 | self.max_agg = MaxAggregation() 436 | self.mlp2 = Linear(512, latent_dim) 437 | 438 | def forward(self, pos, batch): 439 | x = pos 440 | ptr = torch.where(batch[1:] - batch[:-1])[0] + 1 # use ptr elsewhere 441 | ptr = torch.cat([torch.tensor([0]).cuda(), ptr]) 442 | ptr = torch.cat([ptr, torch.tensor([len(batch)]).cuda()]) 443 | b1 = self.block1(self.fc0(x), pos, batch) 444 | b1_decimated, ptr1 = decimate(b1, ptr, self.decimation) 445 | 446 | b2 = self.block2(*b1_decimated) 447 | b2_decimated, ptr2 = decimate(b2, ptr1, self.decimation) 448 | 449 | b3 = self.block3(*b2_decimated) 450 | b3_decimated, _ = decimate(b3, ptr2, self.decimation) 451 | 452 | x = self.mlp1(b3_decimated[0]) 453 | x = self.max_agg(x, b3_decimated[2]) 454 | x = self.mlp2(x) 455 | 456 | return x 457 | 458 | 459 | class Encoder(torch.nn.Module): 460 | """ 461 | Point cloud encoder. 462 | """ 463 | def __init__(self, backbone, latent_dim, use_spatial_transformer=False, convonet_kwargs=None): 464 | super().__init__() 465 | self.backbone = backbone 466 | 467 | # spatial transformer placeholder 468 | self.use_spatial_transformer = False 469 | if use_spatial_transformer: 470 | self.use_spatial_transformer = True 471 | self.spatial_transformer = SpatialTransformer() 472 | 473 | backbone_mapping = { 474 | 'pointnet': PointNet, 475 | 'pointnet2': PointNet2, 476 | 'pointcnn': PointCNN, 477 | 'dgcnn': DGCNN, 478 | 'randlanet': RandLANet, 479 | 'pointtransformer': PointTransformer, 480 | 'convonet': lambda d: LocalPoolPointnet(**convonet_kwargs)} 481 | self.backbone_key = backbone.casefold() 482 | if self.backbone_key in backbone_mapping: 483 | self.network = backbone_mapping[self.backbone_key](latent_dim) 484 | else: 485 | raise ValueError(f'Unexpected backbone: {self.backbone_key}') 486 | 487 | def forward(self, points, batch_points): 488 | # points: (total_num_points, 3) 489 | # batch_points: (total_num_points) 490 | 491 | # spatial transformation 492 | if self.use_spatial_transformer: 493 | points = self.spatial_transformer(points, batch_points) 494 | 495 | if self.backbone_key == 'convonet': 496 | # reshape to split it into a single tensor 497 | points_split = points.view(batch_points[-1] + 1, -1, 3) 498 | return self.network(points_split) 499 | else: 500 | return self.network(points, batch_points) 501 | -------------------------------------------------------------------------------- /network/gnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Graph neural networks. 
3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.nn import GCNConv, TransformerConv, TAGConv, MLP 8 | 9 | 10 | class GCN(torch.nn.Module): 11 | """ 12 | Graph convolution network from the "Semi-supervised 13 | Classification with Graph Convolutional Networks" paper. 14 | """ 15 | 16 | def __init__(self, num_features, num_classes, dropout=False): 17 | super().__init__() 18 | self.dropout = dropout 19 | self.conv1 = GCNConv(num_features, 16) 20 | self.conv2 = GCNConv(16, 16) 21 | self.conv3 = GCNConv(16, 16) 22 | self.mlp = MLP([16, 64, 16, num_classes], act='leaky_relu', act_kwargs={'negative_slope': 0.02}) 23 | 24 | def forward(self, x, edge_index): 25 | # GCN layers 26 | x = self.conv1(x, edge_index) 27 | x = F.relu(x) 28 | x = F.dropout(x, training=self.training) if self.dropout else x 29 | x = self.conv2(x, edge_index) 30 | x = F.relu(x) 31 | x = F.dropout(x, training=self.training) if self.dropout else x 32 | x = self.conv3(x, edge_index) 33 | 34 | # MLP layers 35 | x = self.mlp(x) # (num_batch_cells, 2) 36 | 37 | # return the log softmax value 38 | return F.log_softmax(x, dim=1) 39 | 40 | 41 | class TransformerGCN(torch.nn.Module): 42 | """ 43 | Graph transformer operator from the "Masked Label Prediction: 44 | Unified Message Passing Model for Semi-Supervised Classification" paper. 45 | """ 46 | 47 | def __init__(self, num_features, num_classes, dropout=False): 48 | super().__init__() 49 | self.dropout = dropout 50 | self.conv1 = TransformerConv(num_features, 16) 51 | self.conv2 = TransformerConv(16, 16) 52 | self.conv3 = TransformerConv(16, 16) 53 | self.mlp = MLP([16, 64, 16, num_classes], act='leaky_relu', act_kwargs={'negative_slope': 0.02}) 54 | 55 | def forward(self, x, edge_index): 56 | # Graph transformer layers 57 | x = self.conv1(x, edge_index) 58 | x = F.relu(x) 59 | x = F.dropout(x, training=self.training) if self.dropout else x 60 | x = self.conv2(x, edge_index) 61 | x = F.relu(x) 62 | x = F.dropout(x, training=self.training) if self.dropout else x 63 | x = self.conv3(x, edge_index) 64 | 65 | # MLP layers 66 | x = self.mlp(x) # (num_batch_cells, 2) 67 | 68 | # return the log softmax value 69 | return F.log_softmax(x, dim=1) 70 | 71 | 72 | class TAGCN(torch.nn.Module): 73 | """ 74 | Topology adaptive graph convolutional network operator from 75 | the "Topology Adaptive Graph Convolutional Networks" paper. 
76 | """ 77 | 78 | def __init__(self, num_features, num_classes, dropout=False): 79 | super().__init__() 80 | self.dropout = dropout 81 | self.conv1 = TAGConv(num_features, 16) 82 | self.conv2 = TAGConv(16, 16) 83 | self.conv3 = TAGConv(16, 16) 84 | self.mlp = MLP([16, 64, 16, num_classes], act='leaky_relu', act_kwargs={'negative_slope': 0.02}) 85 | 86 | def forward(self, x, edge_index): 87 | # Topology Adaptive Graph Convolutional layers 88 | x = self.conv1(x, edge_index) 89 | x = F.leaky_relu(x, negative_slope=0.02) 90 | x = F.dropout(x, training=self.training) if self.dropout else x 91 | x = self.conv2(x, edge_index) 92 | x = F.leaky_relu(x, negative_slope=0.02) 93 | x = F.dropout(x, training=self.training) if self.dropout else x 94 | x = self.conv3(x, edge_index) 95 | 96 | # MLP layers 97 | x = self.mlp(x) # (num_batch_cells, 2) 98 | 99 | # return the log softmax value 100 | return F.log_softmax(x, dim=1) 101 | 102 | 103 | class GNN(torch.nn.Module): 104 | def __init__(self, backend, num_features, num_classes, **kwargs): 105 | super().__init__() 106 | self.backend = backend 107 | 108 | if backend.casefold() == 'GCN'.casefold(): 109 | self.network = GCN(num_features, num_classes, *kwargs) 110 | elif backend.casefold() == 'TransformerGCN'.casefold(): 111 | self.network = TransformerGCN(num_features, num_classes, *kwargs) 112 | elif backend.casefold() == 'TAGCN'.casefold(): 113 | self.network = TAGCN(num_features, num_classes, *kwargs) 114 | else: 115 | raise ValueError(f'Expected backend GCN, GraphTransformer or TAGCN, got {backend} instead') 116 | 117 | def forward(self, x, data): 118 | x = torch.squeeze(x) # (num_cells_in_batch, num_queries_per_cell) 119 | return self.network(x, data) 120 | -------------------------------------------------------------------------------- /network/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions. 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def bce_loss(inputs, targets, weight=(1., 1.)): 10 | """ 11 | Binary cross entropy loss for PolyGNN prediction. 12 | """ 13 | pred = inputs.argmax(dim=1) 14 | total = len(pred) 15 | correct = torch.sum((pred.clone().detach() == targets).long()) 16 | accuracy = correct / total 17 | loss = F.nll_loss(inputs, targets, weight=torch.tensor(weight, device=inputs.device)) 18 | ratio = torch.sum(pred) / total 19 | return loss, accuracy, ratio, total, correct 20 | 21 | 22 | def focal_loss( 23 | inputs: torch.Tensor, 24 | targets: torch.Tensor, 25 | alpha: float = 0.25, 26 | gamma: float = 2, 27 | reduction: str = "mean", 28 | ) -> (torch.Tensor, torch.Tensor, torch.Tensor): 29 | """ 30 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 31 | Source from https://pytorch.org/vision/main/_modules/torchvision/ops/focal_loss.html. 32 | 33 | Args: 34 | inputs (Tensor): A float tensor of arbitrary shape. 35 | The predictions for each example. 36 | targets (Tensor): A float tensor with the same shape as inputs. Stores the binary 37 | classification label for each element in inputs 38 | (0 for the negative class and 1 for the positive class). 39 | alpha (float): Weighting factor in range (0,1) to balance 40 | positive vs negative examples or -1 for ignore. Default: ``0.25``. 41 | gamma (float): Exponent of the modulating factor (1 - p_t) to 42 | balance easy vs hard examples. Default: ``2``. 
43 | reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` 44 | ``'none'``: No reduction will be applied to the output. 45 | ``'mean'``: The output will be averaged. 46 | ``'sum'``: The output will be summed. Default: ``'none'``. 47 | Returns: 48 | Loss tensors loss, accuracy and ratio. 49 | """ 50 | ce_loss = F.nll_loss(inputs, targets, reduction="none") 51 | p_t = torch.exp(-ce_loss) 52 | p_t = torch.clamp(p_t, max=1.0) # clip to avoid NaN 53 | loss = ce_loss * ((1 - p_t) ** gamma) 54 | 55 | if alpha >= 0: 56 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 57 | loss = alpha_t * loss 58 | 59 | # Check reduction option and return loss accordingly 60 | if reduction == "none": 61 | pass 62 | elif reduction == "mean": 63 | loss = loss.mean() 64 | elif reduction == "sum": 65 | loss = loss.sum() 66 | else: 67 | raise ValueError( 68 | f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" 69 | ) 70 | 71 | pred = inputs.argmax(dim=1) 72 | total = len(pred) 73 | correct = torch.sum((pred.clone().detach() == targets).long()) 74 | accuracy = correct / total 75 | ratio = torch.sum(pred) / total 76 | return loss, accuracy, ratio, total, correct 77 | -------------------------------------------------------------------------------- /network/polygnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | PolyGNN architecture. 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from network import Encoder, Decoder, GNN 9 | 10 | 11 | class PolyGNN(torch.nn.Module): 12 | """ 13 | PolyGNN. 14 | """ 15 | def __init__(self, cfg): 16 | super().__init__() 17 | self.gnn = cfg.gnn 18 | self.dataset_suffix = cfg.dataset_suffix 19 | if cfg.sample.strategy == 'grid': 20 | self.points_suffix = f'_{cfg.sample.resolution}' 21 | elif cfg.sample.strategy == 'random': 22 | self.points_suffix = f'_{cfg.sample.length}' 23 | else: 24 | self.points_suffix = '' 25 | 26 | # encoder 27 | self.encoder = Encoder(backbone=cfg.encoder, latent_dim=cfg.latent_dim_light, 28 | use_spatial_transformer=cfg.use_spatial_transformer, convonet_kwargs=cfg.convonet_kwargs) 29 | 30 | # decoder 31 | latent_dim = cfg.latent_dim_light if cfg.decoder.casefold() == 'MLP'.casefold() else cfg.latent_dim_conv 32 | self.decoder = Decoder(backbone=cfg.decoder, latent_dim=latent_dim, num_queries=cfg.num_queries) 33 | 34 | # GNN 35 | if cfg.gnn is not None: 36 | self.gnn = GNN(backend=cfg.gnn, num_features=latent_dim, num_classes=2, dropout=cfg.dropout) 37 | else: 38 | self.gnn = None 39 | 40 | def forward(self, data): 41 | x = self.encoder(data[f'points{self.points_suffix}'], data[f'batch_points{self.points_suffix}']) 42 | # x: {(num_graphs_in_batch, 128, 128, 128) * 3} from ConvOEncoder 43 | # x: (num_graphs_in_batch, 256) from other encoders 44 | 45 | x = self.decoder(x, data.queries, data.batch) 46 | # x: (num_cells_in_batch, latent_dim) 47 | 48 | if self.gnn: 49 | x = self.gnn(x, data.edge_index) 50 | # x: (num_cells_in_batch, 2) 51 | else: 52 | # cell-wise voting from query points 53 | x = torch.mean(x, dim=1) 54 | # x: (num_cells_in_batch, 1) 55 | x = self.lin(x) 56 | # x: (num_cells_in_batch, 2) 57 | x = F.log_softmax(x, dim=1) 58 | # x: (num_cells_in_batch, 2) 59 | return x 60 | -------------------------------------------------------------------------------- /reconstruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conversion from polyhedral labels to tangible mesh file. 
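Per-cell predictions (*.npy) are paired with pickled cell complexes (*.cc); the
selected cells are either exported directly or their outer surface is extracted
via the adjacency graph.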
3 | """ 4 | 5 | import pickle 6 | import glob 7 | import multiprocessing 8 | from pathlib import Path 9 | 10 | from tqdm import tqdm 11 | import numpy as np 12 | import hydra 13 | from omegaconf import DictConfig 14 | 15 | from abspy import AdjacencyGraph 16 | from utils import attach_to_log 17 | 18 | 19 | def reconstruct_from_numpy(args): 20 | """ 21 | Reconstruct from numpy and cell complex. 22 | args[0]: numpy_filepath 23 | args[1]: complex_filepath 24 | args[2]: mesh_filepath 25 | args[3]: reconstruction type ('cell' or 'mesh') 26 | args[4]: seal by excluding boundary cells 27 | """ 28 | with open(args[1], 'rb') as handle: 29 | cell_complex = pickle.load(handle) 30 | 31 | pred = np.load(args[0]) 32 | 33 | # exclude boundary cells 34 | if args[4]: 35 | cells_boundary = cell_complex.cells_boundary() 36 | pred[cells_boundary] = 0 37 | 38 | indices_cells = np.where(pred)[0] 39 | 40 | if args[3] == 'mesh': 41 | adjacency_graph = AdjacencyGraph(cell_complex.graph, quiet=True) 42 | adjacency_graph.reachable = adjacency_graph.to_uids(indices_cells) 43 | adjacency_graph.non_reachable = np.setdiff1d(adjacency_graph.uid, adjacency_graph.reachable).tolist() 44 | adjacency_graph.save_surface_obj(args[2], cells=cell_complex.cells) 45 | 46 | elif args[3] == 'cell': 47 | if len(indices_cells) > 0: 48 | cell_complex.save_obj(args[2], indices_cells=indices_cells, use_mtl=True) 49 | else: 50 | raise ValueError(f'Unexpected reconstruction type: {args[3]}') 51 | 52 | 53 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 54 | def multi_reconstruct_from_numpy(cfg: DictConfig): 55 | """ 56 | Reconstruct from numpy and cell complex with multiprocessing. 57 | """ 58 | # initialize logging 59 | logger = attach_to_log() 60 | 61 | # numpy filenames 62 | filenames_numpy = glob.glob(f'{cfg.output_dir}' + '/*.npy') 63 | args = [] 64 | for filename_numpy in filenames_numpy: 65 | stem = Path(filename_numpy).stem 66 | filename_complex = Path(cfg.complex_dir) / (stem + '.cc') 67 | filename_output = Path(filename_numpy).with_suffix('.obj') 68 | if not filename_output.exists(): 69 | args.append((filename_numpy, filename_complex, filename_output, cfg.reconstruct.type, cfg.reconstruct.seal)) 70 | 71 | logger.info('Start reconstruction from numpy and cell complex') 72 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool: 73 | # call with multiprocessing 74 | for _ in tqdm(pool.imap_unordered(reconstruct_from_numpy, args), desc='reconstruction', total=len(args)): 75 | pass 76 | 77 | # exit with a message 78 | logger.info('Done reconstruction from numpy and cell complex') 79 | 80 | 81 | if __name__ == '__main__': 82 | multi_reconstruct_from_numpy() 83 | -------------------------------------------------------------------------------- /remap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Remap normalized instances to global CRS. 3 | """ 4 | 5 | import glob 6 | from pathlib import Path 7 | import multiprocessing 8 | import logging 9 | 10 | import hydra 11 | from omegaconf import DictConfig 12 | from tqdm import tqdm 13 | 14 | from utils import reverse_normalise_mesh, reverse_normalise_cloud, normalise_mesh 15 | 16 | 17 | logger = logging.getLogger("trimesh") 18 | logger.setLevel(logging.WARNING) 19 | 20 | 21 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 22 | def normalize_meshes(cfg: DictConfig): 23 | """ 24 | Normalize meshes. 
25 | 26 | cfg: DictConfig 27 | Hydra configuration 28 | """ 29 | args = [] 30 | input_filenames = glob.glob(f'{cfg.output_dir}/*.obj') 31 | output_dir = Path(cfg.output_dir) / 'normalized' 32 | output_dir.mkdir(exist_ok=True) 33 | for input_filename in input_filenames: 34 | base_filename = Path(input_filename).name 35 | reference_filename = (Path(cfg.reference_dir) / base_filename).with_suffix('.obj') 36 | output_filename = output_dir / base_filename 37 | args.append((input_filename, reference_filename, output_filename, 'scene', cfg.reconstruct.offset, True, False)) 38 | print('start processing') 39 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool: 40 | # call with multiprocessing 41 | for _ in tqdm(pool.imap_unordered(normalise_mesh, args), desc='Normalizing meshes', total=len(args)): 42 | pass 43 | 44 | 45 | # normalize clouds as meshes 46 | normalize_clouds = normalize_meshes 47 | 48 | 49 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 50 | def remap_meshes(cfg: DictConfig): 51 | """ 52 | Remap normalized buildings to global CRS. 53 | 54 | Parameters 55 | ---------- 56 | cfg: DictConfig 57 | Hydra configuration 58 | """ 59 | args = [] 60 | input_filenames = glob.glob(cfg.output_dir + '/*.obj') 61 | output_dir = Path(cfg.output_dir) / 'global' 62 | output_dir.mkdir(exist_ok=True) 63 | for input_filename in input_filenames: 64 | base_filename = Path(input_filename).name 65 | reference_filename = Path(cfg.reference_dir) / base_filename 66 | output_filename = output_dir / base_filename 67 | args.append((input_filename, reference_filename, output_filename, 'scene', cfg.reconstruct.offset, cfg.reconstruct.scale, cfg.reconstruct.translate)) 68 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool: 69 | # call with multiprocessing 70 | for _ in tqdm(pool.imap_unordered(reverse_normalise_mesh, args), desc='Remapping meshes', total=len(args)): 71 | pass 72 | 73 | 74 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 75 | def remap_clouds(cfg: DictConfig): 76 | """ 77 | Remap normalized point clouds to global CRS. 78 | 79 | Parameters 80 | ---------- 81 | cfg: DictConfig 82 | Hydra configuration 83 | """ 84 | args = [] 85 | input_filenames = glob.glob(f'{cfg.data_dir}/raw/test_cloud_normalised_ply/*.ply') 86 | output_dir = Path(cfg.output_dir) / 'global_clouds' 87 | output_dir.mkdir(exist_ok=True) 88 | for input_filename in input_filenames: 89 | base_filename = Path(input_filename).name 90 | reference_filename = (Path(cfg.reference_dir) / base_filename).with_suffix('.obj') 91 | output_filename = output_dir / base_filename 92 | args.append((input_filename, reference_filename, output_filename, 'scene', cfg.reconstruct.offset)) 93 | print('start processing') 94 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool: 95 | # call with multiprocessing 96 | for _ in tqdm(pool.imap_unordered(reverse_normalise_cloud, args), desc='Remapping clouds', total=len(args)): 97 | pass 98 | 99 | 100 | if __name__ == '__main__': 101 | remap_meshes() 102 | -------------------------------------------------------------------------------- /stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate statistics for mesh reconstruction results. 
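For illustration, a minimal sketch of the seven-element tuple that remap_meshes above hands to reverse_normalise_mesh from utils.py; the paths and offset below are placeholders.

from utils import reverse_normalise_mesh

args = (
    'outputs/building_0001.obj',           # normalised reconstruction
    'reference/building_0001.obj',         # reference mesh defining the original scale and position
    'outputs/global/building_0001.obj',    # remapped output in the global CRS
    'scene',                               # force loading type ('mesh' or 'scene')
    [690000.0, 5330000.0, 500.0],          # hypothetical coordinate offset (x, y, z)
    True,                                  # cfg.reconstruct.scale: undo scaling
    True,                                  # cfg.reconstruct.translate: undo translation
)
reverse_normalise_mesh(args)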
3 | Copied from https://github.com/ErlerPhilipp/points2surf/blob/master/source/base/evaluation.py. 4 | """ 5 | 6 | import os 7 | import subprocess 8 | import multiprocessing 9 | 10 | import numpy as np 11 | import hydra 12 | from omegaconf import DictConfig 13 | 14 | 15 | def make_dir_for_file(file): 16 | """ 17 | Make dir for file. 18 | """ 19 | file_dir = os.path.dirname(file) 20 | if file_dir != '': 21 | if not os.path.exists(file_dir): 22 | try: 23 | os.makedirs(os.path.dirname(file)) 24 | except OSError as exc: # Guard against race condition 25 | raise 26 | 27 | 28 | def mp_worker(call): 29 | """ 30 | Small function that starts a new thread with a system call. Used for thread pooling. 31 | """ 32 | call = call.split(' ') 33 | verbose = call[-1] == '--verbose' 34 | if verbose: 35 | call = call[:-1] 36 | subprocess.run(call) 37 | else: 38 | # subprocess.run(call, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # suppress outputs 39 | subprocess.run(call, stdout=subprocess.DEVNULL) 40 | 41 | 42 | def start_process_pool(worker_function, parameters, num_processes, timeout=None): 43 | 44 | if len(parameters) > 0: 45 | if num_processes <= 1: 46 | print('Running loop for {} with {} calls on {} workers'.format( 47 | str(worker_function), len(parameters), num_processes)) 48 | results = [] 49 | for c in parameters: 50 | results.append(worker_function(*c)) 51 | return results 52 | print('Running loop for {} with {} calls on {} subprocess workers'.format( 53 | str(worker_function), len(parameters), num_processes)) 54 | with multiprocessing.Pool(processes=num_processes, maxtasksperchild=1) as pool: 55 | results = pool.starmap(worker_function, parameters) 56 | return results 57 | else: 58 | return None 59 | 60 | 61 | def _chamfer_distance_single_file(file_in, file_ref, samples_per_model, num_processes=1): 62 | # http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf 63 | 64 | import trimesh 65 | import trimesh.sample 66 | import sys 67 | import scipy.spatial as spatial 68 | 69 | def sample_mesh(mesh_file, num_samples): 70 | try: 71 | mesh = trimesh.load(mesh_file) 72 | except: 73 | return np.zeros((0, 3)) 74 | samples, face_indices = trimesh.sample.sample_surface_even(mesh, num_samples) 75 | return samples 76 | 77 | try: 78 | new_mesh_samples = sample_mesh(file_in, samples_per_model) 79 | ref_mesh_samples = sample_mesh(file_ref, samples_per_model) 80 | except AttributeError: 81 | # unable to sample 82 | return file_in, file_ref, -1.0 83 | 84 | if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0: 85 | return file_in, file_ref, -1.0 86 | 87 | leaf_size = 100 88 | sys.setrecursionlimit(int(max(1000, round(new_mesh_samples.shape[0] / leaf_size)))) 89 | kdtree_new_mesh_samples = spatial.cKDTree(new_mesh_samples, leaf_size) 90 | kdtree_ref_mesh_samples = spatial.cKDTree(ref_mesh_samples, leaf_size) 91 | 92 | ref_new_dist, corr_new_ids = kdtree_new_mesh_samples.query(ref_mesh_samples, 1, workers=num_processes) 93 | new_ref_dist, corr_ref_ids = kdtree_ref_mesh_samples.query(new_mesh_samples, 1, workers=num_processes) 94 | 95 | ref_new_dist_sum = np.sum(ref_new_dist) 96 | new_ref_dist_sum = np.sum(new_ref_dist) 97 | chamfer_dist = ref_new_dist_sum + new_ref_dist_sum 98 | 99 | return file_in, file_ref, chamfer_dist 100 | 101 | 102 | def _hausdorff_distance_directed_single_file(file_in, file_ref, samples_per_model): 103 | import scipy.spatial as spatial 104 | import trimesh 105 | import 
trimesh.sample 106 | 107 | def sample_mesh(mesh_file, num_samples): 108 | try: 109 | mesh = trimesh.load(mesh_file) 110 | except: 111 | return np.zeros((0, 3)) 112 | samples, face_indices = trimesh.sample.sample_surface_even(mesh, num_samples) 113 | return samples 114 | 115 | try: 116 | new_mesh_samples = sample_mesh(file_in, samples_per_model) 117 | ref_mesh_samples = sample_mesh(file_ref, samples_per_model) 118 | except AttributeError: 119 | # unable to sample 120 | return file_in, file_ref, -1.0 121 | 122 | if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0: 123 | return file_in, file_ref, -1.0 124 | 125 | dist, _, _ = spatial.distance.directed_hausdorff(new_mesh_samples, ref_mesh_samples) 126 | return file_in, file_ref, dist 127 | 128 | 129 | def _hausdorff_distance_single_file(file_in, file_ref, samples_per_model): 130 | import scipy.spatial as spatial 131 | import trimesh 132 | import trimesh.sample 133 | 134 | def sample_mesh(mesh_file, num_samples): 135 | try: 136 | mesh = trimesh.load(mesh_file) 137 | except: 138 | return np.zeros((0, 3)) 139 | samples, face_indices = trimesh.sample.sample_surface_even(mesh, num_samples) 140 | return samples 141 | 142 | try: 143 | new_mesh_samples = sample_mesh(file_in, samples_per_model) 144 | ref_mesh_samples = sample_mesh(file_ref, samples_per_model) 145 | except AttributeError: 146 | # unable to sample 147 | return file_in, file_ref, -1.0, -1.0, -1.0 148 | 149 | if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0: 150 | return file_in, file_ref, -1.0, -1.0, -1.0 151 | 152 | dist_new_ref, _, _ = spatial.distance.directed_hausdorff(new_mesh_samples, ref_mesh_samples) 153 | dist_ref_new, _, _ = spatial.distance.directed_hausdorff(ref_mesh_samples, new_mesh_samples) 154 | dist = max(dist_new_ref, dist_ref_new) 155 | return file_in, file_ref, dist_new_ref, dist_ref_new, dist 156 | 157 | 158 | def _scale_single_file(file_ref): 159 | import trimesh 160 | if not file_ref.endswith('.obj'): 161 | file_ref = file_ref + '.obj' 162 | mesh = trimesh.load(file_ref) 163 | extents = mesh.extents 164 | scale = extents.max() 165 | return scale 166 | 167 | 168 | def mesh_comparison(new_meshes_dir_abs, ref_meshes_dir_abs, 169 | num_processes, report_name, samples_per_model=10000, dataset_file_abs=None): 170 | if not os.path.isdir(new_meshes_dir_abs): 171 | print('Warning: dir to check doesn\'t exist'.format(new_meshes_dir_abs)) 172 | return 173 | 174 | new_mesh_files = [f for f in os.listdir(new_meshes_dir_abs) 175 | if os.path.isfile(os.path.join(new_meshes_dir_abs, f))] 176 | ref_mesh_files = [f for f in os.listdir(ref_meshes_dir_abs) 177 | if os.path.isfile(os.path.join(ref_meshes_dir_abs, f))] 178 | 179 | if dataset_file_abs is None: 180 | mesh_files_to_compare_set = set(ref_mesh_files) # set for efficient search 181 | else: 182 | if not os.path.isfile(dataset_file_abs): 183 | raise ValueError('File does not exist: {}'.format(dataset_file_abs)) 184 | with open(dataset_file_abs) as f: 185 | mesh_files_to_compare_set = f.readlines() 186 | mesh_files_to_compare_set = [f.replace('\n', '') + '.ply' for f in mesh_files_to_compare_set] 187 | mesh_files_to_compare_set = [f.split('.')[0] for f in mesh_files_to_compare_set] 188 | mesh_files_to_compare_set = set(mesh_files_to_compare_set) 189 | 190 | # # skip if everything is unchanged 191 | # new_mesh_files_abs = [os.path.join(new_meshes_dir_abs, f) for f in new_mesh_files] 192 | # ref_mesh_files_abs = [os.path.join(ref_meshes_dir_abs, f) for f in ref_mesh_files] 193 | # if not 
utils_files.call_necessary(new_mesh_files_abs + ref_mesh_files_abs, report_name): 194 | # return 195 | 196 | def ref_mesh_for_new_mesh(new_mesh_file: str, all_ref_meshes: list) -> list: 197 | stem_new_mesh_file = new_mesh_file.split('.')[0] 198 | ref_files = list(set([f for f in all_ref_meshes if f.split('.')[0] == stem_new_mesh_file])) 199 | return ref_files 200 | 201 | call_params = [] 202 | for fi, new_mesh_file in enumerate(new_mesh_files): 203 | if new_mesh_file.split('.')[0] in mesh_files_to_compare_set: 204 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, new_mesh_file) 205 | ref_mesh_files_matching = ref_mesh_for_new_mesh(new_mesh_file, ref_mesh_files) 206 | if len(ref_mesh_files_matching) > 0: 207 | ref_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_mesh_files_matching[0]) 208 | call_params.append((new_mesh_file_abs, ref_mesh_file_abs, samples_per_model)) 209 | if len(call_params) == 0: 210 | raise ValueError('Results are empty!') 211 | results_hausdorff = start_process_pool(_hausdorff_distance_single_file, call_params, num_processes) 212 | results = [(r[0], r[1], str(r[2]), str(r[3]), str(r[4])) for r in results_hausdorff] 213 | 214 | call_params = [] 215 | for fi, new_mesh_file in enumerate(new_mesh_files): 216 | if new_mesh_file.split('.')[0] in mesh_files_to_compare_set: 217 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, new_mesh_file) 218 | ref_mesh_files_matching = ref_mesh_for_new_mesh(new_mesh_file, ref_mesh_files) 219 | if len(ref_mesh_files_matching) > 0: 220 | ref_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_mesh_files_matching[0]) 221 | call_params.append((new_mesh_file_abs, ref_mesh_file_abs, samples_per_model)) 222 | results_chamfer = start_process_pool(_chamfer_distance_single_file, call_params, num_processes) 223 | results = [r + (str(results_chamfer[ri][2]),) for ri, r in enumerate(results)] 224 | 225 | # no reference but reconstruction 226 | for fi, new_mesh_file in enumerate(new_mesh_files): 227 | if new_mesh_file.split('.')[0] not in mesh_files_to_compare_set: 228 | if dataset_file_abs is None: 229 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, new_mesh_file) 230 | ref_mesh_files_matching = ref_mesh_for_new_mesh(new_mesh_file, ref_mesh_files) 231 | if len(ref_mesh_files_matching) > 0: 232 | reference_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_mesh_files_matching[0]) 233 | results.append((new_mesh_file_abs, reference_mesh_file_abs, str(-2), str(-2), str(-2), str(-2))) 234 | else: 235 | mesh_files_to_compare_set.remove(new_mesh_file.split('.')[0]) 236 | 237 | # no reconstruction but reference 238 | for ref_without_new_mesh in mesh_files_to_compare_set: 239 | new_mesh_file_abs = os.path.join(new_meshes_dir_abs, ref_without_new_mesh) 240 | reference_mesh_file_abs = os.path.join(ref_meshes_dir_abs, ref_without_new_mesh) 241 | results.append((new_mesh_file_abs, reference_mesh_file_abs, str(-1), str(-1), str(-1), str(-1))) 242 | 243 | # append scale to each row 244 | call_params = [] 245 | for fi, row in enumerate(results): 246 | ref_file_abs = row[1] 247 | call_params.append([ref_file_abs]) 248 | results_scale = start_process_pool(_scale_single_file, call_params, num_processes) 249 | results = [r + (str(results_scale[ri]),) for ri, r in enumerate(results)] 250 | 251 | # sort by file name 252 | results = sorted(results, key=lambda x: x[0]) 253 | 254 | make_dir_for_file(report_name) 255 | csv_lines = ['in mesh,ref mesh,Hausdorff dist new-ref,Hausdorff dist ref-new,Hausdorff dist,' 256 | 'Chamfer dist(-1: no input; -2: no 
reference),Scale'] 257 | csv_lines += [','.join(item) for item in results] 258 | # csv_lines += ['=AVERAGE(E2:E41)'] 259 | csv_lines_str = '\n'.join(csv_lines) 260 | with open(report_name, "w") as text_file: 261 | text_file.write(csv_lines_str) 262 | 263 | 264 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 265 | def generate_stats(cfg: DictConfig): 266 | """ 267 | Evaluate Hausdorff distance between reconstructed and GT models. 268 | 269 | Parameters 270 | ---------- 271 | cfg: DictConfig 272 | Hydra configuration 273 | """ 274 | 275 | csv_file = os.path.join(cfg.csv_path) 276 | mesh_comparison( 277 | new_meshes_dir_abs=cfg.remap_dir, 278 | ref_meshes_dir_abs=cfg.reference_dir, 279 | num_processes=cfg.num_workers, 280 | report_name=csv_file, 281 | samples_per_model=cfg.evaluate.num_samples, 282 | dataset_file_abs=os.path.join(cfg.data_dir, 'raw/testset.txt')) 283 | 284 | 285 | if __name__ == '__main__': 286 | generate_stats() 287 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation of PolyGNN. 3 | """ 4 | 5 | import pickle 6 | from pathlib import Path 7 | import multiprocessing 8 | 9 | import numpy as np 10 | import torch 11 | import hydra 12 | from omegaconf import DictConfig 13 | from tqdm import tqdm 14 | from torch.nn.parallel import DistributedDataParallel 15 | from torch_geometric.loader import DataLoader 16 | import torch.multiprocessing as mp 17 | from torchmetrics.classification import BinaryAccuracy 18 | import torch.distributed as dist 19 | from torch_geometric import compile 20 | 21 | from network import PolyGNN 22 | from dataset import CityDataset, TestOnlyDataset 23 | from utils import init_device, Sampler, set_seed, attach_to_log, setup_runner 24 | 25 | 26 | class PredictionSaver: 27 | """ 28 | Asynchronous prediction saving. 29 | """ 30 | def __init__(self, processes): 31 | self.pool = multiprocessing.Pool(processes=processes) 32 | 33 | @staticmethod 34 | def save(args): 35 | pred, name, cfg = args 36 | indices_cells = np.where(pred.cpu().numpy())[0] 37 | if len(indices_cells) > 0: 38 | complex_path = f'{cfg.complex_dir}/{name}.cc' 39 | with open(complex_path, 'rb') as handle: 40 | cell_complex = pickle.load(handle) 41 | output_path = f'{cfg.output_dir}/{name}.npy' 42 | output = np.zeros([cell_complex.num_cells], dtype=int) 43 | output[indices_cells] = 1 44 | 45 | if cfg.evaluate.seal: 46 | cells_boundary = cell_complex.cells_boundary() 47 | output[cells_boundary] = 0 48 | np.save(output_path, output) 49 | 50 | 51 | def run_eval(rank, world_size, dataset_test, cfg): 52 | """ 53 | Runner function for distributed inference of PolyGNN. 
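For illustration, a toy sketch of the start_process_pool helper from stats.py above: with num_processes <= 1 it loops sequentially, otherwise it fans the argument tuples out over a multiprocessing.Pool via starmap, preserving input order. The worker below is a stand-in for the distance functions actually passed by mesh_comparison.

from stats import start_process_pool

def add(a, b):
    # stand-in worker; real callers pass e.g. _hausdorff_distance_single_file
    return a + b

if __name__ == '__main__':
    results = start_process_pool(add, [(1, 2), (3, 4), (5, 6)], num_processes=2)
    print(results)  # [3, 7, 11]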
54 | """ 55 | # set up runner 56 | setup_runner(rank, world_size, cfg.master_addr, cfg.master_port) 57 | 58 | # limit number of threads 59 | torch.set_num_threads(cfg.num_workers // world_size) 60 | 61 | # initialize logging 62 | logger = attach_to_log(filepath='./outputs/test.log') 63 | 64 | # indicate device 65 | logger.debug(f"Device activated: " + f"CUDA: {cfg.gpu_ids[rank]}") 66 | 67 | # initialize metric 68 | metric = BinaryAccuracy() 69 | 70 | # split test indices into `world_size` many chunks 71 | eval_indices = torch.arange(len(dataset_test)) 72 | eval_indices = eval_indices.split(len(eval_indices) // world_size)[rank] 73 | dataloader_test = DataLoader(dataset_test[eval_indices], batch_size=cfg.batch_size // world_size, 74 | shuffle=cfg.shuffle, num_workers=cfg.num_workers // world_size, 75 | pin_memory=True, prefetch_factor=8) 76 | 77 | # initialize model 78 | model = PolyGNN(cfg) 79 | model.metric = metric 80 | model = model.to(rank) 81 | 82 | # distributed parallelization 83 | model = DistributedDataParallel(model, device_ids=[rank]) 84 | 85 | # compile model for better performance 86 | compile(model, dynamic=True, fullgraph=True) 87 | 88 | # load from checkpoint 89 | map_location = f'cuda:{rank}' 90 | if rank == 0: 91 | logger.info(f'Resuming from {cfg.checkpoint_path}') 92 | state = torch.load(cfg.checkpoint_path, map_location=map_location) 93 | state_dict = state['state_dict'] 94 | model.load_state_dict(state_dict, strict=False) 95 | 96 | # specify data attributes 97 | if cfg.sample.strategy == 'grid': 98 | points_suffix = f'_{cfg.sample.resolution}' 99 | elif cfg.sample.strategy == 'random': 100 | points_suffix = f'_{cfg.sample.length}' 101 | else: 102 | points_suffix = '' 103 | 104 | # start inference 105 | model.eval() 106 | pbar = tqdm(dataloader_test, desc=f'eval', disable=rank != 0) 107 | 108 | # initialize PredictionSaver instance 109 | prediction_saver = PredictionSaver(processes=cfg.num_workers // world_size) 110 | 111 | with torch.no_grad(): 112 | for batch in pbar: 113 | batch = batch.to(rank, f'points{points_suffix}', f'batch_points{points_suffix}', 'queries', 114 | 'edge_index', 'batch', 'y') 115 | outs = model(batch) 116 | outs = outs.argmax(dim=1) 117 | targets = batch.y 118 | 119 | # metric on current batch 120 | _accuracy = metric(outs, targets) 121 | pbar.set_postfix_str('acc={:.2f}'.format(_accuracy)) 122 | 123 | # save prediction as numpy file 124 | if cfg.evaluate.save: 125 | Path(cfg.output_dir).mkdir(exist_ok=True) 126 | _, boundary_indices = torch.unique(batch.batch, return_counts=True) 127 | 128 | preds = torch.split(outs, split_size_or_sections=boundary_indices.tolist(), dim=0) 129 | names = batch.name 130 | 131 | # asynchronous file saving 132 | prediction_saver.pool.map(prediction_saver.save, zip(preds, names, [cfg] * len(preds))) 133 | 134 | # metric on all batches and all accelerators using custom accumulation 135 | accuracy = metric.compute() 136 | 137 | if rank == 0: 138 | logger.info(f"Evaluation accuracy: {accuracy}") 139 | 140 | # reset internal state such that metric ready for new data 141 | metric.reset() 142 | 143 | dist.barrier() 144 | 145 | dist.destroy_process_group() 146 | 147 | 148 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 149 | def test(cfg: DictConfig): 150 | """ 151 | Test PolyGNN for reconstruction. 
152 | 153 | Parameters 154 | ---------- 155 | cfg: DictConfig 156 | Hydra configuration 157 | """ 158 | 159 | # initialize logger 160 | logger = attach_to_log(filepath='./outputs/test.log') 161 | 162 | # initialize device 163 | init_device(cfg.gpu_ids, register_freeze=cfg.gpu_freeze) 164 | logger.info(f"Device initialized: " + f"CUDA: {cfg.gpu_ids}") 165 | 166 | # fix randomness 167 | set_seed(cfg.seed) 168 | logger.info(f"Random seed set to {cfg.seed}") 169 | 170 | # initialize data sampler 171 | sampler = Sampler(strategy=cfg.sample.strategy, length=cfg.sample.length, ratio=cfg.sample.ratio, 172 | resolutions=cfg.sample.resolutions, duplicate=cfg.sample.duplicate, seed=cfg.seed) 173 | transform = sampler.sample if cfg.sample.transform else None 174 | pre_transform = sampler.sample if cfg.sample.pre_transform else None 175 | 176 | # initialize dataset 177 | if cfg.dataset in {'munich', 'munich_perturb', 'munich_subsample', 'munich_truncate', 'munich_haswall', 'campus_ldbv'}: 178 | dataset = TestOnlyDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir, 179 | split='test', num_workers=cfg.num_workers) 180 | else: 181 | dataset = CityDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir, 182 | split='test', num_workers=cfg.num_workers) 183 | 184 | world_size = len(cfg.gpu_ids) 185 | mp.spawn(run_eval, args=(world_size, dataset, cfg), nprocs=world_size, join=True) 186 | 187 | 188 | if __name__ == '__main__': 189 | test() 190 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Supervised training of PolyGNN. 3 | """ 4 | 5 | import os 6 | from pathlib import Path 7 | 8 | import wandb 9 | import hydra 10 | from omegaconf import DictConfig 11 | from tqdm import tqdm 12 | import torch 13 | import torch.multiprocessing as mp 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel 16 | from torch_geometric.loader import DataLoader 17 | from torch_geometric import compile 18 | from torchmetrics.classification import BinaryAccuracy 19 | 20 | from network import PolyGNN, focal_loss, bce_loss 21 | from dataset import CityDataset 22 | from utils import init_device, Sampler, set_seed, attach_to_log, setup_runner 23 | 24 | 25 | def run_train(rank, world_size, dataset_train, dataset_test, cfg): 26 | """ 27 | Runner function for distributed training of PolyGNN. 
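A worked sketch of how the evaluation loop above regroups flat per-cell predictions into per-building chunks: torch.unique with return_counts yields the number of cells per graph in the batch, and torch.split cuts the prediction vector accordingly. The values are made up.

import torch

batch_vector = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])    # cell-to-graph assignment (data.batch)
outs = torch.tensor([1, 0, 1, 0, 0, 1, 1, 0, 1])            # hypothetical per-cell predictions

_, counts = torch.unique(batch_vector, return_counts=True)  # tensor([3, 2, 4])
preds = torch.split(outs, split_size_or_sections=counts.tolist(), dim=0)
# preds[0] -> cells of building 0, preds[1] -> building 1, preds[2] -> building 2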
28 | """ 29 | # set up runner 30 | setup_runner(rank, world_size, cfg.master_addr, cfg.master_port) 31 | 32 | # limit number of threads 33 | torch.set_num_threads(cfg.num_workers // world_size) 34 | 35 | # initialize logging 36 | logger = attach_to_log(filepath='./outputs/train.log') 37 | if rank == 0: 38 | logger.info(f'Training PolyGNN on {cfg.dataset}') 39 | wandb_mode = 'online' if cfg.wandb else 'disabled' 40 | wandb.init(mode=wandb_mode, project='polygnn', entity='zhaiyu', 41 | name=cfg.dataset+cfg.run_suffix, dir=cfg.wandb_dir) 42 | wandb.save('./outputs/.hydra/*') 43 | 44 | # indicate device 45 | logger.debug(f"Device activated: " + f"CUDA: {cfg.gpu_ids[rank]}") 46 | 47 | # split training indices into `world_size` many chunks 48 | train_indices = torch.arange(len(dataset_train)) 49 | train_indices = train_indices.split(len(train_indices) // world_size)[rank] 50 | eval_indices = torch.arange(len(dataset_test)) 51 | eval_indices = eval_indices.split(len(eval_indices) // world_size)[rank] 52 | 53 | # setup dataloaders 54 | dataloader_train = DataLoader(dataset_train[train_indices], batch_size=cfg.batch_size // world_size, 55 | shuffle=cfg.shuffle, num_workers=cfg.num_workers // world_size, 56 | pin_memory=True, prefetch_factor=8) 57 | dataloader_test = DataLoader(dataset_test[eval_indices], batch_size=cfg.batch_size // world_size, 58 | shuffle=cfg.shuffle, num_workers=cfg.num_workers // world_size, 59 | pin_memory=True, prefetch_factor=8) 60 | 61 | # initialize model 62 | model = PolyGNN(cfg).to(rank) 63 | 64 | # distributed parallelization 65 | model = DistributedDataParallel(model, device_ids=[rank]) 66 | 67 | # compile model for better performance 68 | compile(model, dynamic=True, fullgraph=True) 69 | 70 | # freeze certain layers for fine-tuning 71 | if cfg.warm: 72 | for stage in cfg.freeze_stages: 73 | logger.info(f'Freezing stage: {stage}') 74 | for parameter in getattr(model, stage).parameters(): 75 | parameter.requires_grad = False 76 | 77 | # initialize optimizer and scheduler 78 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) 79 | scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg.scheduler.base_lr, max_lr=cfg.scheduler.max_lr, 80 | step_size_up=cfg.scheduler.step_size_up, mode=cfg.scheduler.mode, 81 | cycle_momentum=False) 82 | # initialize metrics 83 | metric = BinaryAccuracy().to(rank) 84 | 85 | # warm start from checkpoint if available 86 | if cfg.warm: 87 | if rank == 0: 88 | logger.info(f'Resuming from {cfg.checkpoint_path}') 89 | map_location = f'cuda:{rank}' 90 | state = torch.load(cfg.checkpoint_path, map_location=map_location) 91 | state_dict = state['state_dict'] 92 | model.load_state_dict(state_dict, strict=False) 93 | if cfg.warm_optimizer: 94 | try: 95 | optimizer.load_state_dict(state['optimizer']) 96 | if rank == 0: 97 | logger.info(f'Optimizer loaded from checkpoint') 98 | except (KeyError, ValueError) as error: 99 | if rank == 0: 100 | logger.warning(f'Optimizer not loaded from checkpoint: {error}') 101 | 102 | if cfg.warm_scheduler: 103 | try: 104 | scheduler.load_state_dict(state['scheduler']) 105 | if rank == 0: 106 | logger.info(f'Scheduler loaded from checkpoint') 107 | except (KeyError, ValueError) as error: 108 | if rank == 0: 109 | logger.warning(f'Scheduler not loaded from checkpoint: {error}') 110 | 111 | best_accuracy = state['accuracy'] 112 | if state['epoch'] > cfg.num_epochs: 113 | if rank == 0: 114 | logger.info(f'Expected epoch reached from checkpoint') 115 | return 116 | 
epoch_generator = range(state['epoch'] + 1, cfg.num_epochs) 117 | else: 118 | best_accuracy = 0 119 | epoch_generator = range(cfg.num_epochs) 120 | 121 | # initialize loss function 122 | if cfg.loss == 'focal': 123 | loss_func = focal_loss 124 | elif cfg.loss == 'bce': 125 | loss_func = bce_loss 126 | else: 127 | raise ValueError(f'Unexpected loss function: {cfg.loss}') 128 | 129 | # specify data attributes 130 | if cfg.sample.strategy == 'grid': 131 | points_suffix = f'_{cfg.sample.resolution}' 132 | elif cfg.sample.strategy == 'random': 133 | points_suffix = f'_{cfg.sample.length}' 134 | else: 135 | points_suffix = '' 136 | 137 | # start training 138 | for i in epoch_generator: 139 | model.train() 140 | pbar = tqdm(dataloader_train, desc=f'epoch {i}', disable=rank != 0) 141 | 142 | if rank == 0: 143 | wandb.log({"epoch": i}) 144 | 145 | for batch in pbar: 146 | optimizer.zero_grad() 147 | batch = batch.to(rank, f'points{points_suffix}', f'batch_points{points_suffix}', 'queries', 'edge_index', 148 | 'batch', 'y') 149 | outs = model(batch) 150 | targets = batch.y 151 | loss, accuracy, ratio, _, _ = loss_func(outs, targets) 152 | 153 | if rank == 0: 154 | wandb.log({"loss": loss}) 155 | wandb.log({"train_accuracy": accuracy}) 156 | wandb.log({"ratio:": ratio}) 157 | wandb.log({"learning_rate": optimizer.param_groups[0]['lr']}) 158 | pbar.set_postfix_str('loss={:.2f}, acc={:.2f}, ratio={:.2f}'.format(loss, accuracy, ratio)) 159 | 160 | loss.backward() 161 | optimizer.step() 162 | scheduler.step() 163 | 164 | dist.barrier() 165 | 166 | # validate and save checkpoint with DDP 167 | if cfg.validate and i % cfg.save_interval == 0: 168 | model.metric = metric 169 | model = model.to(rank) 170 | model.eval() 171 | 172 | pbar = tqdm(dataloader_test, desc=f'eval', disable=rank != 0) 173 | with torch.no_grad(): 174 | for batch in pbar: 175 | batch = batch.to(rank, f'points{points_suffix}', f'batch_points{points_suffix}', 'queries', 176 | 'edge_index', 'batch', 'y') 177 | outs = model(batch) 178 | outs = outs.argmax(dim=1) 179 | targets = batch.y 180 | 181 | # metric on current batch 182 | accuracy = metric(outs, targets) 183 | if rank == 0: 184 | pbar.set_postfix_str('acc={:.2f}'.format(accuracy)) 185 | 186 | # metrics on all batches and all accelerators using custom accumulation 187 | accuracy = metric.compute() 188 | 189 | dist.barrier() 190 | 191 | if rank == 0: 192 | logger.info(f'Evaluation accuracy: {accuracy:.4f}') 193 | wandb.log({"eval_accuracy": accuracy}) 194 | checkpoint_path = os.path.join(cfg.checkpoint_dir, f'model_epoch{i}.pth') 195 | logger.info(f'Saving checkpoint to {checkpoint_path}.') 196 | Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True) 197 | state = { 198 | 'epoch': i, 199 | 'state_dict': model.state_dict(), 200 | 'optimizer': optimizer.state_dict(), 201 | 'scheduler': scheduler.state_dict(), 202 | 'accuracy': accuracy, 203 | } 204 | # Cannot pickle 'WeakMethod' object when saving state_dict for CyclicLr 205 | # https://github.com/pytorch/pytorch/pull/91400 206 | torch.save(state, checkpoint_path) 207 | if accuracy > best_accuracy: 208 | logger.info(f'Saving checkpoint to {cfg.checkpoint_path}.') 209 | torch.save(state, cfg.checkpoint_path) 210 | wandb.save(cfg.checkpoint_path) 211 | best_accuracy = accuracy 212 | 213 | # reset internal state such that metric ready for new data 214 | metric.reset() 215 | 216 | dist.barrier() 217 | 218 | dist.destroy_process_group() 219 | 220 | 221 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 222 | def 
train(cfg: DictConfig): 223 | """ 224 | Train PolyGNN. 225 | 226 | Parameters 227 | ---------- 228 | cfg: DictConfig 229 | Hydra configuration 230 | """ 231 | logger = attach_to_log(filepath='./outputs/train.log') 232 | 233 | # initialize device 234 | init_device(cfg.gpu_ids, register_freeze=cfg.gpu_freeze) 235 | logger.info(f"Device initialized: " + f"CUDA: {cfg.gpu_ids}") 236 | 237 | # fix randomness 238 | set_seed(cfg.seed) 239 | logger.info(f"Random seed set to {cfg.seed}") 240 | 241 | # initialize data sampler 242 | sampler = Sampler(strategy=cfg.sample.strategy, length=cfg.sample.length, ratio=cfg.sample.ratio, 243 | resolutions=cfg.sample.resolutions, duplicate=cfg.sample.duplicate, seed=cfg.seed) 244 | transform = sampler.sample if cfg.sample.transform else None 245 | pre_transform = sampler.sample if cfg.sample.pre_transform else None 246 | 247 | # initialize dataset 248 | dataset_train = CityDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir, 249 | split='train', num_workers=cfg.num_workers, num_queries=cfg.num_queries) 250 | dataset_test = CityDataset(pre_transform=pre_transform, transform=transform, root=cfg.data_dir, 251 | split='test', num_workers=cfg.num_workers, num_queries=cfg.num_queries) 252 | 253 | world_size = len(cfg.gpu_ids) 254 | mp.spawn(run_train, args=(world_size, dataset_train, dataset_test, cfg), nprocs=world_size, join=True) 255 | 256 | 257 | if __name__ == '__main__': 258 | train() 259 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions. 3 | """ 4 | 5 | import os 6 | import time 7 | import glob 8 | from pathlib import Path 9 | import atexit 10 | from itertools import repeat 11 | import math 12 | import logging 13 | import random 14 | import pickle 15 | import csv 16 | import multiprocessing 17 | 18 | from tqdm import tqdm 19 | import numpy as np 20 | import torch 21 | import torch.distributed as dist 22 | import hydra 23 | from omegaconf import DictConfig 24 | import torch_geometric as pyg 25 | import trimesh 26 | from plyfile import PlyData 27 | 28 | from abspy import VertexGroup, CellComplex 29 | 30 | 31 | def setup_runner(rank, world_size, master_addr, master_port): 32 | """ 33 | Set up runner for distributed parallelization. 34 | """ 35 | # initialize torch.distributed 36 | os.environ['MASTER_ADDR'] = str(master_addr) 37 | os.environ['MASTER_PORT'] = str(master_port) 38 | dist.init_process_group('nccl', rank=rank, world_size=world_size) 39 | 40 | 41 | def attach_to_log(level=logging.INFO, 42 | filepath=None, 43 | colors=True, 44 | capture_warnings=True): 45 | """ 46 | Attach a stream handler to all loggers. 
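For reference, a minimal sketch of the checkpoint layout that run_train above writes and that the warm-start branch reads back; the path is hypothetical.

import torch

state = torch.load('checkpoints/munich/model_best.pth', map_location='cpu')  # hypothetical path
# keys written by run_train: 'epoch', 'state_dict', 'optimizer', 'scheduler', 'accuracy'
print(state['epoch'], state['accuracy'])

# the state_dict comes from a DistributedDataParallel-wrapped model, so parameter
# names carry the 'module.' prefix; train.py and test.py therefore load it back
# into a DDP-wrapped model (with strict=False)
print(next(iter(state['state_dict'])))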
47 | 48 | Parameters 49 | ------------ 50 | level : enum (int) 51 | Logging level, like logging.INFO 52 | colors : bool 53 | If True try to use colorlog formatter 54 | capture_warnings: bool 55 | If True capture warnings 56 | filepath: None or str 57 | path to save the logfile 58 | 59 | Returns 60 | ------- 61 | logger: Logger object 62 | Logger attached with a stream handler 63 | """ 64 | 65 | # make sure we log warnings from the warnings module 66 | logging.captureWarnings(capture_warnings) 67 | 68 | # create a basic formatter 69 | formatter_file = logging.Formatter( 70 | "[%(asctime)s] %(levelname)-7s (%(filename)s:%(lineno)3s) %(message)s", 71 | "%Y-%m-%d %H:%M:%S") 72 | if colors: 73 | try: 74 | from colorlog import ColoredFormatter 75 | formatter_stream = ColoredFormatter( 76 | ("%(log_color)s%(levelname)-8s%(reset)s " + 77 | "%(filename)17s:%(lineno)-4s %(blue)4s%(message)s"), 78 | datefmt=None, 79 | reset=True, 80 | log_colors={'DEBUG': 'cyan', 81 | 'INFO': 'green', 82 | 'WARNING': 'yellow', 83 | 'ERROR': 'red', 84 | 'CRITICAL': 'red'}) 85 | except ImportError: 86 | formatter_stream = formatter_file 87 | else: 88 | formatter_stream = formatter_file 89 | 90 | # if no handler was passed use a StreamHandler 91 | logger = logging.getLogger() 92 | logger.setLevel(level) 93 | 94 | if not any([isinstance(handler, logging.StreamHandler) for handler in logger.handlers]): 95 | stream_handler = logging.StreamHandler() 96 | stream_handler.setFormatter(formatter_stream) 97 | logger.addHandler(stream_handler) 98 | 99 | if filepath and not any([isinstance(handler, logging.FileHandler) for handler in logger.handlers]): 100 | file_handler = logging.FileHandler(filepath) 101 | file_handler.setFormatter(formatter_file) 102 | logger.addHandler(file_handler) 103 | 104 | # set nicer numpy print options 105 | np.set_printoptions(precision=5, suppress=True) 106 | 107 | return logger 108 | 109 | 110 | def edge_index_from_dict(graph_dict): 111 | """ 112 | Convert adjacency dict to edge index. 113 | 114 | Parameters 115 | ---------- 116 | graph_dict: dict 117 | Adjacency dict 118 | 119 | Returns 120 | ------- 121 | as_tensor: torch.Tensor 122 | Edge index 123 | """ 124 | row, col = [], [] 125 | for key, value in graph_dict.items(): 126 | row += repeat(key, len(value)) 127 | col += value 128 | edge_index = torch.tensor([row, col], dtype=torch.long) 129 | return edge_index 130 | 131 | 132 | def index_to_mask(index, size): 133 | """ 134 | Convert index to binary mask. 135 | 136 | Parameters 137 | ---------- 138 | index: range Object 139 | Index of 1s 140 | size: int 141 | Size of mask 142 | 143 | Returns 144 | ------- 145 | as_tensor: torch.Tensor 146 | Binary mask 147 | """ 148 | mask = torch.zeros((size,), dtype=torch.bool) 149 | mask[index] = 1 150 | return mask 151 | 152 | 153 | def freeze_vram(cuda_devices, timeout=500): 154 | """ 155 | Freeze VRAM for a short time at program exit. For debugging. 
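A tiny worked sketch of edge_index_from_dict above, with a made-up adjacency dict for three cells.

from utils import edge_index_from_dict

graph_dict = {0: [1, 2], 1: [0], 2: [0]}     # hypothetical cell adjacency
edge_index = edge_index_from_dict(graph_dict)
# tensor([[0, 0, 1, 2],
#         [1, 2, 0, 0]])   shape (2, num_edges): source row on top, target row below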
156 | 157 | Parameters 158 | ---------- 159 | cuda_devices: list of int 160 | Indices of CUDA devices 161 | timeout: int 162 | Timeout seconds 163 | """ 164 | torch.cuda.empty_cache() 165 | devices_info = os.popen( 166 | '"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader') \ 167 | .read().strip().split("\n") 168 | for i, device in enumerate(cuda_devices): 169 | total, used = devices_info[int(device)].split(',') 170 | total = int(total) 171 | used = int(used) 172 | max_mem = int(total * 0.90) 173 | block_mem = max_mem - used 174 | if block_mem > 0: 175 | x = torch.FloatTensor(256, 1024, block_mem).to(torch.device(f'cuda:{i}')) 176 | del x 177 | for _ in tqdm(range(timeout), desc='VRAM freezing'): 178 | time.sleep(1) 179 | 180 | 181 | def init_device(gpu_ids, register_freeze=False): 182 | """ 183 | Init devices. 184 | 185 | Parameters 186 | ---------- 187 | gpu_ids: list of int 188 | GPU indices to use 189 | register_freeze: bool 190 | Register GPU memory freeze if set True 191 | """ 192 | # set multiprocessing sharing strategy 193 | torch.multiprocessing.set_sharing_strategy('file_system') 194 | 195 | # does not work for DP after import torch with PyTorch 2.0, but works for DDP nevertheless 196 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 197 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids)[1:-1] 198 | 199 | # raise soft limit from 1024 to 4096 for open files to address RuntimeError: received 0 items of ancdata 200 | import resource 201 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 202 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 203 | 204 | if register_freeze: 205 | atexit.register(freeze_vram, gpu_ids) 206 | 207 | 208 | class Sampler: 209 | """ 210 | Sampler to sample points from a point cloud. 211 | """ 212 | 213 | def __init__(self, strategy, length, ratio, resolutions, duplicate, seed=None): 214 | self.strategy = strategy 215 | self.length = length 216 | self.ratio = ratio 217 | self.resolutions = resolutions 218 | self.duplicate = duplicate 219 | 220 | # seed once in initialization 221 | self.seed = seed 222 | 223 | def sample(self, data): 224 | with torch.no_grad(): 225 | if self.seed is not None: 226 | set_seed(self.seed) 227 | if self.strategy is None: 228 | return data 229 | if self.strategy == 'fps': 230 | return self.farthest_sample(data) 231 | elif self.strategy == 'random': 232 | return self.random_sample(data) 233 | elif self.strategy == 'grid': 234 | return self.grid_sample(data) 235 | else: 236 | raise ValueError(f'Unexpected sampling strategy={self.strategy}.') 237 | 238 | def random_sample(self, data): 239 | """ 240 | Random uniform sampling. 241 | """ 242 | # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/fixed_points.html#FixedPoints 243 | if not self.duplicate: 244 | choice = torch.randperm(data.num_points)[:self.length] 245 | else: 246 | choice = torch.cat( 247 | [torch.randperm(data.num_points) for _ in range(math.ceil(self.length / data.num_points))], 248 | dim=0)[:self.length] 249 | data[f'batch_points_{self.length}'] = data.batch_points[choice] 250 | data[f'points_{self.length}'] = data.points[choice] 251 | return data 252 | 253 | def grid_sample(self, data): 254 | """ 255 | Sampling points into fixed-sized voxels. 256 | Each cluster returned is the cluster barycenter. 
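For illustration, a minimal sketch of using the Sampler above with the 'random' strategy on a single hypothetical building; it assumes a torch_geometric Data object exposing the points, batch_points and num_points attributes that the dataset provides.

import torch
from torch_geometric.data import Data
from utils import Sampler

data = Data()
data.points = torch.rand(5000, 3)                        # hypothetical input cloud
data.batch_points = torch.zeros(5000, dtype=torch.long)  # single graph
data.num_points = 5000

sampler = Sampler(strategy='random', length=2048, ratio=None, resolutions=None, duplicate=True, seed=42)
data = sampler.sample(data)
print(data['points_2048'].shape)   # torch.Size([2048, 3])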
257 | """ 258 | # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/grid_sampling.html#GridSampling 259 | for size in self.resolutions: 260 | c = pyg.nn.voxel_grid(data.points, size, data.batch_points, None, None) 261 | _, perm = pyg.nn.pool.consecutive.consecutive_cluster(c) 262 | data[f'batch_points_{size}'] = data.batch_points[perm] 263 | data[f'points_{size}'] = data.points[perm] 264 | return data 265 | 266 | def farthest_sample(self, data): 267 | """ 268 | Farthest sampling which iteratively samples the most distant point with regard to the rest points. Inplace. 269 | """ 270 | # https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.fps 271 | perm = pyg.nn.pool.fps(data.points, data.batch_points, ratio=self.ratio) 272 | data[f'batch_points_fps'] = data.batch_points[perm] 273 | data[f'points_fps'] = data.points[perm] 274 | return data 275 | 276 | 277 | def reverse_translation_and_scale(mesh): 278 | """ 279 | Translation and scale for reverse normalisation of mesh. 280 | """ 281 | bounds = mesh.extents 282 | if bounds.min() == 0.0: 283 | return 284 | 285 | # translate to origin 286 | translation = - (mesh.bounds[0] + mesh.bounds[1]) * 0.5 287 | translation = trimesh.transformations.translation_matrix(direction=-translation) 288 | 289 | # scale to unit cube 290 | scale = bounds.max() 291 | scale_trafo = trimesh.transformations.scale_matrix(factor=scale) 292 | 293 | return translation, scale_trafo 294 | 295 | 296 | def normalise_mesh(args): 297 | """ 298 | Normalize mesh (or point cloud). 299 | First translation, then scaling, if any. 300 | 301 | Parameters 302 | ---------- 303 | args[0]: input_path: str or Path 304 | Path to input mesh (results from reconstruction, normalised) 305 | args[1]: reference_path: str or Path 306 | Path to reference mesh (used to determine the transformation) 307 | args[2]: output_path: str or Path 308 | Path to output mesh (final reconstruction, with reversed normalisation) 309 | args[3]: force: str 310 | Force loading type ('mesh' or 'scene') 311 | args[4]: offset: list 312 | Coordinate offset (x, y, z) 313 | args[5]: scaling: bool 314 | Switch on scaling if set True 315 | args[6]: translation: bool 316 | Switch on translation if set True 317 | """ 318 | input_path, reference_path, output_path, force, offset, is_scaling, is_translation = args 319 | 320 | reference_mesh = trimesh.load(reference_path) 321 | translation, scale_trafo = reverse_translation_and_scale(reference_mesh) 322 | if offset is not None: 323 | translation[0][-1] = translation[0][-1] + offset[0] 324 | translation[1][-1] = translation[1][-1] + offset[1] 325 | translation[2][-1] = translation[2][-1] + offset[2] 326 | 327 | # trimesh built-in transform would result in an issue of missing triangles 328 | with open(input_path, 'r') as fin: 329 | lines = fin.readlines() 330 | lines_ = [] 331 | for line in lines: 332 | if line.startswith('v'): 333 | vertices = np.array(line.split()[1:], dtype=float) 334 | if is_translation is True: 335 | vertices = vertices - translation[:3, 3] 336 | if is_scaling is True: 337 | vertices = vertices / scale_trafo[0][0] 338 | line_ = f'v {vertices[0]} {vertices[1]} {vertices[2]}\n' 339 | else: 340 | line_ = line 341 | lines_.append(line_) 342 | with open(output_path, 'w') as fout: 343 | fout.writelines(lines_) 344 | 345 | 346 | def reverse_normalise_mesh(args): 347 | """ 348 | Reverse normalisation for reconstructed mesh. 349 | First scaling, then translation, if any. 
350 | 351 | Parameters 352 | ---------- 353 | args[0]: input_path: str or Path 354 | Path to input mesh (results from reconstruction, normalised) 355 | args[1]: reference_path: str or Path 356 | Path to reference mesh (used to determine the transformation) 357 | args[2]: output_path: str or Path 358 | Path to output mesh (final reconstruction, with reversed normalisation) 359 | args[3]: force: str 360 | Force loading type ('mesh' or 'scene') 361 | args[4]: offset: list 362 | Coordinate offset (x, y, z) 363 | args[5]: scaling: bool 364 | Switch on scaling if set True 365 | args[6]: translation: bool 366 | Switch on translation if set True 367 | """ 368 | input_path, reference_path, output_path, force, offset, is_scaling, is_translation = args 369 | 370 | reference_mesh = trimesh.load(reference_path) 371 | translation, scale_trafo = reverse_translation_and_scale(reference_mesh) 372 | if offset is not None: 373 | translation[0][-1] = translation[0][-1] + offset[0] 374 | translation[1][-1] = translation[1][-1] + offset[1] 375 | translation[2][-1] = translation[2][-1] + offset[2] 376 | 377 | # trimesh built-in transform would result in an issue of missing triangles 378 | with open(input_path, 'r') as fin: 379 | lines = fin.readlines() 380 | lines_ = [] 381 | for line in lines: 382 | if line.startswith('v'): 383 | vertices = np.array(line.split()[1:], dtype=float) 384 | if is_scaling is True: 385 | vertices = vertices * scale_trafo[0][0] 386 | if is_translation is True: 387 | vertices = vertices + translation[:3, 3] 388 | line_ = f'v {vertices[0]} {vertices[1]} {vertices[2]}\n' 389 | else: 390 | line_ = line 391 | lines_.append(line_) 392 | with open(output_path, 'w') as fout: 393 | fout.writelines(lines_) 394 | 395 | 396 | def reverse_normalise_cloud(args): 397 | """ 398 | Reverse normalisation for normalised point cloud. 399 | 400 | Parameters 401 | ---------- 402 | args[0]: input_path: str or Path 403 | Path to input point cloud (normalised) 404 | args[1]: reference_path: str or Path 405 | Path to reference mesh (used to determine the transformation) 406 | args[2]: output_path: str or Path 407 | Path to output point cloud (with reversed normalisation) 408 | args[3]: force: str 409 | Force loading type ('mesh' or 'scene') 410 | args[4]: offset: list 411 | Coordinate offset (x, y, z) 412 | """ 413 | input_path, reference_path, output_path, force, offset = args 414 | plydata = PlyData.read(input_path)['vertex'] 415 | points = np.array([plydata['x'], plydata['y'], plydata['z']]).T 416 | cloud = trimesh.PointCloud(points) 417 | reference_mesh = trimesh.load(reference_path) 418 | translation, scale_trafo = reverse_translation_and_scale(reference_mesh) 419 | if offset is not None: 420 | translation[0][-1] = translation[0][-1] + offset[0] 421 | translation[1][-1] = translation[1][-1] + offset[1] 422 | translation[2][-1] = translation[2][-1] + offset[2] 423 | 424 | cloud.apply_transform(scale_trafo) 425 | cloud.apply_transform(translation) 426 | cloud.export(str(output_path)) 427 | 428 | 429 | def coerce(data): 430 | """ 431 | Coercion for legacy data. 
432 | """ 433 | data.points = torch.as_tensor(data.points, dtype=torch.float) 434 | data.queries = torch.as_tensor(data.queries, dtype=torch.float) 435 | 436 | if not hasattr(data, 'num_points'): 437 | data.num_points = len(data.points) 438 | if not hasattr(data, 'batch_points'): 439 | data.batch_points = torch.zeros(len(data.points), dtype=torch.long) 440 | return data 441 | 442 | 443 | def set_seed(seed: int) -> None: 444 | """ 445 | Set singular seed to fix randomness. 446 | May need to be repeatedly invoked (at least for np.random). 447 | """ 448 | np.random.seed(seed) 449 | random.seed(seed) 450 | torch.manual_seed(seed) 451 | torch.cuda.manual_seed(seed) 452 | # When running on the CuDNN backend, two further options must be set 453 | torch.backends.cudnn.deterministic = True 454 | torch.backends.cudnn.benchmark = False 455 | # Set a fixed value for the hash seed 456 | os.environ["PYTHONHASHSEED"] = str(seed) 457 | 458 | 459 | def append_labels(args): 460 | """ 461 | Append occupancy labels to existing cell complex file. 462 | 463 | Parameters 464 | ---------- 465 | args[0]: input_cc: str or Path 466 | Path to input CC file 467 | args[1]: input_manifold: str or Path 468 | Path to input manifold mesh file 469 | args[2]: output_cc: str or Path 470 | Path to output CC file 471 | """ 472 | with open(args[0], 'rb') as handle: 473 | cell_complex = pickle.load(handle) 474 | cells_in_mesh = cell_complex.cells_in_mesh(args[1]) 475 | # one-hot encoding 476 | labels = [0] * cell_complex.num_cells 477 | for i in cells_in_mesh: 478 | labels[i] = 1 479 | cell_complex.labels = labels 480 | cell_complex.save(args[2]) 481 | 482 | 483 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 484 | def multi_append_labels(cfg: DictConfig): 485 | # initialize logging 486 | logger = logging.getLogger('Labels') 487 | 488 | filenames = glob.glob('data/munich_16star/raw/05_complexes/*.cc') 489 | args = [] 490 | for filename_input in filenames: 491 | stem = Path(filename_input).stem 492 | filename_output = Path(filename_input).with_suffix('.cc.new') 493 | filename_manifold = Path(filename_input).parent.parent / '03_meshes_manifold' / (stem + '.obj') 494 | if not filename_output.exists(): 495 | args.append((filename_input, filename_manifold, filename_output)) 496 | 497 | logger.info('Start complex labeling') 498 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool: 499 | # call with multiprocessing 500 | for _ in tqdm(pool.imap_unordered(append_labels, args), desc='Appending labels', total=len(args)): 501 | pass 502 | 503 | # exit with a message 504 | logger.info('Done complex labeling') 505 | 506 | 507 | def append_samples(args): 508 | """ 509 | Append multi-grid-size samples to existing data file. 510 | 511 | Parameters 512 | ---------- 513 | args[0]: input_path: str or Path 514 | Path to input torch file 515 | args[1]: output_path: str or Path 516 | Path to output torch file 517 | args[2]: sample_func: func 518 | Function to sample 519 | """ 520 | data = torch.load(args[0]) 521 | data = coerce(data) 522 | data = args[2](data) 523 | torch.save(data, args[1]) 524 | 525 | 526 | def append_queries(args): 527 | """ 528 | Append queries to existing data files. 
529 | 530 | Parameters 531 | ---------- 532 | args[0]: input_path: str or Path 533 | Path to input torch file 534 | args[1]: complex_dir: str or Path 535 | Dir to complexes 536 | args[2]: output_path: str or Path 537 | Path to output torch file 538 | """ 539 | data = torch.load(args[0]) 540 | with open(os.path.join(args[1], data.name + '.cc'), 'rb') as handle: 541 | cell_complex = pickle.load(handle) 542 | 543 | queries_random = np.array(cell_complex.cell_representatives(location='random_t', num=16)) 544 | queries_boundary = np.array(cell_complex.cell_representatives(location='boundary', num=16)) 545 | queries_skeleton = np.array(cell_complex.cell_representatives(location='skeleton', num=16)) 546 | data.queries_random = torch.as_tensor(queries_random, dtype=torch.float) 547 | data.queries_boundary = torch.as_tensor(queries_boundary, dtype=torch.float) 548 | data.queries_skeleton = torch.as_tensor(queries_skeleton, dtype=torch.float) 549 | 550 | torch.save(data, args[2]) 551 | 552 | 553 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 554 | def multi_append_queries(cfg: DictConfig): 555 | # initialize logging 556 | logger = logging.getLogger('Querying') 557 | 558 | filenames = glob.glob(f'{cfg.data_dir}/processed/*[0-9].pt') 559 | args = [] 560 | for filename_input in filenames: 561 | filename_output = Path(filename_input).with_suffix('.pt.queries_appended') 562 | if not filename_output.exists(): 563 | args.append((filename_input, cfg.complex_dir, filename_output)) 564 | 565 | logger.info('Start polyhedra sampling') 566 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool: 567 | # call with multiprocessing 568 | for _ in tqdm(pool.imap_unordered(append_queries, args), desc='Appending queries', total=len(args)): 569 | pass 570 | 571 | # exit with a message 572 | logger.info('Done polyhedra sampling') 573 | 574 | 575 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 576 | def multi_append_samples(cfg: DictConfig): 577 | # initialize logging 578 | logger = logging.getLogger('Sampling') 579 | 580 | filenames = glob.glob(f'{cfg.data_dir}/processed/*[0-9].pt') 581 | sampler = Sampler(strategy=cfg.sample.strategy, length=cfg.sample.length, ratio=cfg.sample.ratio, 582 | resolutions=cfg.sample.resolutions, duplicate=cfg.sample.duplicate, seed=cfg.seed) 583 | args = [] 584 | for filename_input in filenames: 585 | filename_output = Path(filename_input).with_suffix('.pt.samples_appended') 586 | if not filename_output.exists(): 587 | args.append((filename_input, filename_output, sampler.sample)) 588 | 589 | logger.info('Start point cloud sampling') 590 | with multiprocessing.Pool(processes=cfg.num_workers if cfg.num_workers else multiprocessing.cpu_count()) as pool: 591 | # call with multiprocessing 592 | for _ in tqdm(pool.imap_unordered(append_samples, args), desc='Appending samples', total=len(args)): 593 | pass 594 | 595 | # exit with a message 596 | logger.info('Done point cloud sampling') 597 | 598 | 599 | def count_facets(mesh_path): 600 | """ 601 | Count the number of facets given a mesh. 602 | """ 603 | mesh = trimesh.load(mesh_path) 604 | faces_extracted = np.concatenate(mesh.facets) 605 | faces_left = np.setdiff1d(np.arange(len(mesh.faces)), faces_extracted) 606 | num_facets = len(mesh.facets) + len(faces_left) 607 | return num_facets 608 | 609 | 610 | def dict_count_facets(mesh_dir): 611 | """ 612 | Count the number of facets given a directory of meshes. 
613 | """ 614 | filenames = glob.glob(f'{mesh_dir}/*.obj') 615 | facet_dict = {} 616 | for filename_input in filenames: 617 | stem = Path(filename_input).stem 618 | num_facets = count_facets(filename_input) 619 | facet_dict[stem] = num_facets 620 | 621 | # sorted by facet number 622 | print({k: v for k, v in sorted(facet_dict.items(), key=lambda item: item[1])}) 623 | 624 | 625 | def append_scale_to_csv(input_csv, output_csv): 626 | """ 627 | Append scale into an existing csv file. 628 | Note that scale has been implemented in stats.py. 629 | """ 630 | rows = [] 631 | with open(input_csv, 'r', newline='') as input_csvfile: 632 | reader = csv.reader(input_csvfile) 633 | next(reader) # skip header 634 | for r in reader: 635 | filename_input = r[1] 636 | row = r 637 | if not filename_input.endswith('.obj'): 638 | filename_input = filename_input + '.obj' 639 | mesh = trimesh.load(filename_input) 640 | extents = mesh.extents 641 | scale = extents.max() 642 | row.append(scale) 643 | rows.append(row) 644 | 645 | with open(output_csv, 'w') as output_csvfile: 646 | writer = csv.writer(output_csvfile, lineterminator='\n') 647 | writer.writerows(rows) 648 | 649 | 650 | @hydra.main(config_path='./conf', config_name='config', version_base='1.2') 651 | def calculate(cfg): 652 | """ 653 | Calculate number of parameters. 654 | """ 655 | from network import PolyGNN 656 | 657 | # initialize model 658 | model = PolyGNN(cfg) 659 | 660 | # calculate params 661 | total_params = sum( 662 | param.numel() for param in model.parameters() 663 | ) 664 | 665 | trainable_params = sum( 666 | p.numel() for p in model.parameters() if p.requires_grad 667 | ) 668 | 669 | print(f'total_params: {total_params}') 670 | print(f'trainable_params: {trainable_params}') 671 | 672 | 673 | def merge_vg2cc(args): 674 | """ 675 | Merge vg from RANSAC and vg from City3D, then generate complexes. 676 | """ 677 | vg_ransac_path, vg_city3d_path, cc_output_path = args 678 | epsilon = 0.0001 679 | vertex_group_ransac = VertexGroup(filepath=vg_ransac_path, refit=True, global_group=False, quiet=True) 680 | vertex_group_city3d = VertexGroup(filepath=vg_city3d_path, refit=False, global_group=True, quiet=True) 681 | additional_planes = [p for p in vertex_group_city3d.planes if -epsilon < p[2] < epsilon or 682 | (-epsilon < p[0] < epsilon and -epsilon < p[1] < epsilon and 1 - epsilon < p[2] < 1 + epsilon)] 683 | 684 | if len(vertex_group_ransac.planes) == 0: 685 | return 686 | cell_complex = CellComplex(vertex_group_ransac.planes, vertex_group_ransac.bounds, 687 | vertex_group_ransac.points_grouped, 688 | build_graph=True, additional_planes=additional_planes, 689 | initial_bound=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], quiet=True) 690 | cell_complex.refine_planes() # additional planes are not refined 691 | cell_complex.prioritise_planes() 692 | cell_complex.construct() 693 | 694 | cell_complex.save_obj(filepath=Path(cc_output_path).with_suffix('.obj')) 695 | cell_complex.save(filepath=cc_output_path) 696 | 697 | 698 | def multi_merge_vg2cc(vg_ransac_dir, vg_city3d_dir, cc_output_dir): 699 | """ 700 | Merge vertex groups from RANSAC and from City3D, then generate complexes, with multiprocessing. 
701 | """ 702 | args = [] 703 | num_workers = 42 704 | vg_filenames_ransac = glob.glob(vg_ransac_dir + '/*.vg') 705 | 706 | for vg_filename_ransac in vg_filenames_ransac: 707 | stem = Path(vg_filename_ransac).stem 708 | vg_filenames_city3d = (Path(vg_city3d_dir) / stem).with_suffix('.vg') 709 | if not vg_filenames_city3d.exists(): 710 | continue 711 | cc_filenames_output = (Path(cc_output_dir) / stem).with_suffix('.cc') 712 | if cc_filenames_output.exists(): 713 | continue 714 | args.append([vg_filename_ransac, vg_filenames_city3d, cc_filenames_output]) 715 | 716 | with multiprocessing.Pool(processes=num_workers) as pool: 717 | # call with multiprocessing 718 | for _ in tqdm(pool.imap_unordered(merge_vg2cc, args), desc='Creating complexes from vertex groups', 719 | total=len(args)): 720 | pass 721 | 722 | 723 | def vg2cc(args): 724 | """ 725 | Create cell complex from vertex group. 726 | """ 727 | vg_path, cc_path = args 728 | # print(vg_path) 729 | 730 | vertex_group = VertexGroup(filepath=vg_path, refit=False, global_group=True) 731 | 732 | cell_complex = CellComplex(vertex_group.planes, vertex_group.bounds, vertex_group.points_grouped, 733 | build_graph=True, additional_planes=None, 734 | initial_bound=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]]) 735 | cell_complex.refine_planes(theta=5 * 3.1416 / 180, epsilon=0.002) 736 | cell_complex.construct() 737 | 738 | cell_complex.save_obj(filepath=Path(cc_path).with_suffix('.obj')) 739 | cell_complex.save(filepath=cc_path) 740 | 741 | 742 | def multi_vg2cc(vg_dir, cc_dir): 743 | """ 744 | Create cell complexes from vertex groups with multiprocessing. 745 | """ 746 | args = [] 747 | num_workers = 38 748 | vg_filenames = glob.glob(vg_dir + '/*.vg') 749 | 750 | for vg_filename in vg_filenames: 751 | stem = Path(vg_filename).stem 752 | cc_filename = (Path(cc_dir) / stem).with_suffix('.cc') 753 | args.append([vg_filename, cc_filename]) 754 | 755 | with multiprocessing.Pool(processes=num_workers) as pool: 756 | # call with multiprocessing 757 | for _ in tqdm(pool.imap_unordered(vg2cc, args), desc='Creating complexes from vertex groups', total=len(args)): 758 | pass 759 | 760 | 761 | def coordinate2index(x, reso, coord_type='2d'): 762 | """ Generate grid index of points 763 | 764 | Args: 765 | x (tensor): points (normalized to [0, 1]) 766 | reso (int): defined resolution 767 | coord_type (str): coordinate type 768 | """ 769 | x = (x * reso).long() 770 | if coord_type == '2d': # plane 771 | index = x[:, :, 0] + reso * x[:, :, 1] # [B, N, 1] 772 | index = index[:, None, :] # [B, 1, N] 773 | return index 774 | 775 | 776 | def normalize_coordinate(p, padding=0, plane='xz', scale=1.0): 777 | """ Normalize coordinate to [0, 1] for unit cube experiments 778 | 779 | Args: 780 | p (tensor): point 781 | padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 782 | plane (str): plane feature type, ['xz', 'xy', 'yz'] 783 | scale: normalize scale 784 | """ 785 | if plane == 'xz': 786 | xy = p[:, :, [0, 2]] 787 | elif plane == 'xy': 788 | xy = p[:, :, [0, 1]] 789 | else: 790 | xy = p[:, :, [1, 2]] 791 | 792 | xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) 793 | xy_new = xy_new + 0.5 # range (0, 1) 794 | 795 | # f there are outliers out of the range 796 | if xy_new.max() >= 1: 797 | xy_new[xy_new >= 1] = 1 - 10e-6 798 | if xy_new.min() < 0: 799 | xy_new[xy_new < 0] = 0.0 800 | return xy_new 801 | 802 | 803 | def normalize_3d_coordinate(p, padding=0): 804 | """ Normalize coordinate to [0, 1] for unit cube experiments. 
805 |     Corresponds to our 3D model
806 | 
807 |     Args:
808 |         p (tensor): point
809 |         padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
810 |     """
811 |     p_nor = p / (1 + padding + 10e-4) # (-0.5, 0.5)
812 |     p_nor = p_nor + 0.5 # range (0, 1)
813 |     # if there are outliers out of the range
814 |     if p_nor.max() >= 1:
815 |         p_nor[p_nor >= 1] = 1 - 10e-4
816 |     if p_nor.min() < 0:
817 |         p_nor[p_nor < 0] = 0.0
818 |     return p_nor
819 | 
820 | 
821 | class map2local(object):
822 |     """
823 |     Map points to local coordinates within each voxel of size ``s``, then apply the positional encoding.
824 | 
825 |     Args:
826 |         s (float): the defined voxel size
827 |         pos_encoding (str): method for the positional encoding, linear|sin_cos
828 |     """
829 | 
830 |     def __init__(self, s, pos_encoding='linear'):
831 |         super().__init__()
832 |         self.s = s
833 |         self.pe = positional_encoding(basis_function=pos_encoding)
834 | 
835 |     def __call__(self, p):
836 |         p = torch.remainder(p, self.s) / self.s # always positive
837 |         # p = torch.fmod(p, self.s) / self.s # same sign as input p!
838 |         p = self.pe(p)
839 |         return p
840 | 
841 | 
842 | class positional_encoding(object):
843 |     """ Positional Encoding (presented in NeRF)
844 | 
845 |     Args:
846 |         basis_function (str): basis function
847 |     """
848 | 
849 |     def __init__(self, basis_function='sin_cos'):
850 |         super().__init__()
851 |         self.func = basis_function
852 | 
853 |         L = 10
854 |         freq_bands = 2. ** (np.linspace(0, L - 1, L))
855 |         self.freq_bands = freq_bands * math.pi
856 | 
857 |     def __call__(self, p):
858 |         if self.func == 'sin_cos':
859 |             out = []
860 |             p = 2.0 * p - 1.0 # change to the range [-1, 1]
861 |             for freq in self.freq_bands:
862 |                 out.append(torch.sin(freq * p))
863 |                 out.append(torch.cos(freq * p))
864 |             p = torch.cat(out, dim=2)
865 |         return p
866 | 
867 | 
868 | def make_3d_grid(bb_min, bb_max, shape):
869 |     """
870 |     Makes a 3D grid.
871 | 
872 |     Args:
873 |         bb_min (tuple): bounding box minimum
874 |         bb_max (tuple): bounding box maximum
875 |         shape (tuple): output shape
876 |     """
877 |     size = shape[0] * shape[1] * shape[2]
878 | 
879 |     pxs = torch.linspace(bb_min[0], bb_max[0], int(shape[0]))
880 |     pys = torch.linspace(bb_min[1], bb_max[1], int(shape[1]))
881 |     pzs = torch.linspace(bb_min[2], bb_max[2], int(shape[2]))
882 | 
883 |     pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
884 |     pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
885 |     pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
886 |     p = torch.stack([pxs, pys, pzs], dim=1)
887 | 
888 |     return p
889 | 
890 | 
891 | if __name__ == '__main__':
892 |     pass
893 | 
--------------------------------------------------------------------------------
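Two small usage sketches for the helpers above, with arbitrary values: positional_encoding in 'sin_cos' mode expands each input dimension into 2 * L = 20 sin/cos features (so a 3D point becomes 60 channels), and make_3d_grid returns a flat (X*Y*Z, 3) tensor of grid coordinates.

import torch
from utils import positional_encoding, make_3d_grid

pe = positional_encoding(basis_function='sin_cos')
p = torch.rand(2, 5, 3)                      # (batch, num_points, 3), expected in [0, 1]
print(pe(p).shape)                           # torch.Size([2, 5, 60]) = 2 * 10 bands * 3 dims

grid = make_3d_grid((-0.5,) * 3, (0.5,) * 3, (32, 32, 32))
print(grid.shape)                            # torch.Size([32768, 3])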