├── earthparserdataset
│   ├── utils
│   │   ├── __init__.py
│   │   ├── color.py
│   │   └── labels.py
│   ├── __init__.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── grid_sampling.py
│   │   └── base.py
│   ├── shapenetsem.py
│   ├── base.py
│   └── earthparserdataset.py
├── media
│   └── earthparserdataset.png
├── configs
│   └── data
│       ├── lidar.yaml
│       ├── default.yaml
│       ├── urban.yaml
│       ├── crop_field.yaml
│       ├── forest.yaml
│       ├── greenhouse.yaml
│       ├── windturbine.yaml
│       ├── marina.yaml
│       ├── power_plant.yaml
│       ├── shapenet_planes.yaml
│       └── earthparserdataset.yaml
├── LICENSE
├── README.md
├── .gitignore
└── notebooks
    ├── demo_shapenet.ipynb
    ├── demo_2D.ipynb
    └── demo_3D.ipynb
/earthparserdataset/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /media/earthparserdataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/romainloiseau/EarthParserDataset/HEAD/media/earthparserdataset.png -------------------------------------------------------------------------------- /earthparserdataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .earthparserdataset import LidarHDDataModule 2 | from .shapenetsem import ShapeNetSemDataModule -------------------------------------------------------------------------------- /earthparserdataset/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ZCrop, MaxPoints, RandomRotate, RandomFlip, RandomScale, RandomCropSubTileVal, RandomCropSubTileTrain, CenterCrop 2 | from .grid_sampling import GridSampling -------------------------------------------------------------------------------- /configs/data/lidar.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 64 2 | n_features: 8 3 | N_scene: 0 4 | 5 | items_per_epoch: 6 | train: 32768 7 | val: 128 8 | 9 | distance: xyzk 10 | 11 | DEBUG: True 12 | 13 | defaults: 14 | - default -------------------------------------------------------------------------------- /earthparserdataset/utils/color.py: -------------------------------------------------------------------------------- 1 | import kornia.color as col 2 | 3 | def rgb_to_lab(x): 4 | return col.rgb_to_lab(x.T.unsqueeze(-1)).squeeze().T / 127. 5 | 6 | def lab_to_rgb(x): 7 | return col.lab_to_rgb(127. * x.T.unsqueeze(-1)).squeeze().T -------------------------------------------------------------------------------- /configs/data/default.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "."
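# relative paths are resolved with hydra's to_absolute_path in BaseDataModule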
2 | batch_size: 64 3 | num_workers: 8 4 | 5 | n_features: 3 6 | 7 | min_z: 0 8 | 9 | distance: xyz 10 | 11 | ignore_index_0: False 12 | 13 | modality: 3D # 2D or 3D 14 | 15 | image: 16 | res: 32 17 | n_dim: 12 -------------------------------------------------------------------------------- /configs/data/urban.yaml: -------------------------------------------------------------------------------- 1 | name: "urban" 2 | 3 | class_names: ["Unlabeled", "Ground", "Vegetation", "Building"] 4 | 5 | learning_map: 6 | 1 : 0 7 | 2 : 1 8 | 3 : 2 9 | 4 : 2 10 | 5 : 2 11 | 6 : 3 12 | 9 : 0 13 | 17: 0 14 | 64: 0 15 | 65: 0 16 | 66: 0 17 | 160: 0 18 | 161: 0 19 | 162: 0 20 | learning_map_inv: 21 | 0 : 1 22 | 1 : 2 23 | 2 : 4 24 | 3 : 6 25 | 26 | defaults: 27 | - earthparserdataset -------------------------------------------------------------------------------- /configs/data/crop_field.yaml: -------------------------------------------------------------------------------- 1 | name: "crop_field" 2 | 3 | max_xy: 25.6 4 | 5 | class_names: ["Unlabeled", "Ground", "Vegetation"] 6 | 7 | learning_map: 8 | 1 : 0 9 | 2 : 1 10 | 3 : 2 11 | 4 : 2 12 | 5 : 2 13 | 6 : 0 14 | 9 : 0 15 | 17: 0 16 | 64: 0 17 | 65: 0 18 | 66: 0 19 | 160: 0 20 | 161: 0 21 | 162: 0 22 | learning_map_inv: 23 | 0 : 1 24 | 1 : 2 25 | 2 : 4 26 | 27 | defaults: 28 | - earthparserdataset -------------------------------------------------------------------------------- /configs/data/forest.yaml: -------------------------------------------------------------------------------- 1 | name: "forest" 2 | 3 | max_z: 34.8 4 | max_xy: 25.6 5 | 6 | class_names: ["Unlabeled", "Ground", "Vegetation"] 7 | 8 | learning_map: 9 | 1 : 0 10 | 2 : 1 11 | 3 : 2 12 | 4 : 2 13 | 5 : 2 14 | 6 : 0 15 | 9 : 0 16 | 17: 0 17 | 64: 0 18 | 65: 0 19 | 66: 0 20 | 160: 0 21 | 161: 0 22 | 162: 0 23 | learning_map_inv: 24 | 0 : 1 25 | 1 : 2 26 | 2 : 4 27 | 28 | defaults: 29 | - earthparserdataset -------------------------------------------------------------------------------- /configs/data/greenhouse.yaml: -------------------------------------------------------------------------------- 1 | name: "greenhouse" 2 | 3 | max_xy: 38.4 4 | max_z: 25.6 5 | 6 | subtile_max_xy: -1 7 | 8 | class_names: ["Unlabeled", "Ground", "Vegetation", "Building"] 9 | 10 | learning_map: 11 | 1 : 0 12 | 2 : 1 13 | 3 : 2 14 | 4 : 2 15 | 5 : 2 16 | 6 : 3 17 | 9 : 0 18 | 17: 0 19 | 64: 0 20 | 65: 0 21 | 66: 0 22 | 160: 0 23 | 161: 0 24 | 162: 0 25 | learning_map_inv: 26 | 0 : 1 27 | 1 : 2 28 | 2 : 4 29 | 3 : 6 30 | 31 | defaults: 32 | - earthparserdataset -------------------------------------------------------------------------------- /configs/data/windturbine.yaml: -------------------------------------------------------------------------------- 1 | name: "windturbine" 2 | 3 | pre_transform_grid_sample: 1. 4 | N_max: 10000 5 | 6 | subtile_max_xy: 1000. 
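# tiles wider than this XY extent are split into sub-tiles during preprocessing; -1 disables the split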
7 | 8 | max_xy: 204.8 9 | max_z: 102.4 10 | 11 | class_names: ["Unlabeled"] 12 | 13 | ignore_index_0: False 14 | 15 | learning_map: 16 | 1 : 0 17 | 2 : 0 18 | 3 : 0 19 | 4 : 0 20 | 5 : 0 21 | 6 : 0 22 | 9 : 0 23 | 17: 0 24 | 64: 0 25 | 65: 0 26 | 66: 0 27 | 160: 0 28 | 161: 0 29 | 162: 0 30 | learning_map_inv: 31 | 0 : 1 32 | 33 | defaults: 34 | - earthparserdataset -------------------------------------------------------------------------------- /configs/data/marina.yaml: -------------------------------------------------------------------------------- 1 | name: "marina" 2 | 3 | max_xy: 12.8 4 | pre_transform_grid_sample: 0.1 5 | 6 | N_max: 2000 7 | 8 | subtile_max_xy: -1 9 | 10 | class_names: ["Unlabeled", "Boats", "Bridge"] 11 | 12 | ignore_index_0: True 13 | 14 | learning_map: 15 | 1 : 1 16 | 2 : 0 17 | 3 : 0 18 | 4 : 0 19 | 5 : 0 20 | 6 : 0 21 | 9 : 0 22 | 17: 2 23 | 64: 0 24 | 65: 0 25 | 66: 0 26 | 160: 0 27 | 161: 0 28 | 162: 0 29 | learning_map_inv: 30 | 0 : 1 31 | 1 : 2 32 | 2 : 17 33 | 34 | defaults: 35 | - earthparserdataset -------------------------------------------------------------------------------- /configs/data/power_plant.yaml: -------------------------------------------------------------------------------- 1 | name: "power_plant" 2 | 3 | max_z: 34.8 4 | subtile_max_xy: -1 5 | 6 | class_names: ["Unlabeled", "Ground", "Vegetation", "Building", "Lasting above"] #, "Pylon"] 7 | 8 | learning_map: 9 | 1 : 0 10 | 2 : 1 11 | 3 : 2 12 | 4 : 2 13 | 5 : 2 14 | 6 : 3 15 | 9 : 0 16 | 17: 0 17 | 64: 4 18 | 65: 0 19 | 66: 0 20 | 160: 0 21 | 161: 0 22 | 162: 4 23 | learning_map_inv: 24 | 0 : 1 25 | 1 : 2 26 | 2 : 4 27 | 3 : 6 28 | 4 : 64 29 | #5 : 162 30 | 31 | defaults: 32 | - earthparserdataset -------------------------------------------------------------------------------- /configs/data/shapenet_planes.yaml: -------------------------------------------------------------------------------- 1 | name: snsem 2 | 3 | data_dir: "../../Datasets_PanSeg/ShapeNetSem/shapenetcore_partanno_segmentation_benchmark_v0_normal" 4 | _target_: earthparserdataset.ShapeNetSemDataModule 5 | 6 | input_dim: 3 7 | 8 | max_xy: 3.2 9 | max_z: 3.2 10 | 11 | n_max: 0 12 | 13 | rotate_z: False 14 | 15 | num_workers: 16 16 | 17 | classes: [Airplane] # Airplane, Bag, Cap, Car, Chair, Earphone, Guitar, Knife, Lamp, Laptop, Motorbike, Mug, Pistol, Rocket, Skateboard, Table 18 | 19 | class_names: [a, b, c, d] #["Airplane", "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"] 20 | 21 | color_map: # rgb 22 | 0: [33, 158, 188] 23 | 1: [2, 48, 71] 24 | 2: [255, 183, 3] 25 | 3: [251, 133, 0] 26 | 27 | defaults: 28 | - default -------------------------------------------------------------------------------- /earthparserdataset/utils/labels.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | import numpy as np 6 | 7 | def apply_learning_map(segmentation, f): 8 | mapped = copy.deepcopy(segmentation) 9 | for k, v in f.items(): 10 | mapped[segmentation == k] = v 11 | return mapped 12 | 13 | def from_sem_to_color(segmentation, f): 14 | color = torch.zeros(list(segmentation.size())+[3]).int() 15 | color[..., 0] = 255 16 | for k, v in f.items(): 17 | color[segmentation == k] = torch.tensor(v).int() 18 | return color 19 | 20 | def from_inst_to_color(instance): 21 | max_inst_id = 100000 22 | inst_color = np.random.uniform(low=0.0, 23 | high=1.0, 24 | size=(max_inst_id, 
3)) 25 | inst_color[0] = np.full((3), 0.1) 26 | inst_color = torch.tensor(255*inst_color).int() 27 | 28 | color = torch.zeros(list(instance.size())+[3]).int() 29 | for k in torch.unique(instance.flatten()): 30 | color[instance == k] = inst_color[k] 31 | return color -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Romain Loiseau 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Earth Parser Dataset Toolbox 4 | 5 |
6 | 7 | This repository contains helper scripts to open, visualize, and process point clouds from the Earth Parser dataset. 8 | 9 | ![earth parser dataset](media/earthparserdataset.png) 10 | 11 | ## Usage 12 | 13 | - **Download** 14 | 15 | The dataset can be downloaded from [Zenodo](https://zenodo.org/record/7820686). 16 | 17 | - **Data loading** 18 | 19 | This repository contains: 20 | 21 | ```markdown 22 | 📦EarthParserDataset 23 | ┣ 📂configs # hydra config files 24 | ┣ 📂earthparserdataset # PyTorch Lightning datamodules 25 | ┣ 📂notebooks # some illustrative notebooks 26 | ``` 27 | 28 | - **Visualization and Usage** 29 | 30 | See our notebooks in `/notebooks` for examples of data manipulation and several visualization functions. 31 | 32 | ## Citation 33 | 34 | If you use this dataset and/or this API in your work, please cite our [paper](https://imagine.enpc.fr/~loiseaur/learnable-earth-parser). 35 | 36 | ```bibtex 37 | @misc{loiseau2023learnable, 38 | title={Learnable Earth Parser: Discovering 3D Prototypes in Aerial Scans}, 39 | author={Romain Loiseau and Elliot Vincent and Mathieu Aubry and Loic Landrieu}, 40 | year={2023}, 41 | eprint={2304.09704}, 42 | archivePrefix={arXiv}, 43 | primaryClass={cs.CV} 44 | } 45 | ``` -------------------------------------------------------------------------------- /configs/data/earthparserdataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: earthparserdataset.LidarHDDataModule 2 | data_dir: "../../Datasets_PanSeg/earthparserdataset/areas" 3 | 4 | N_scene: 0 5 | 6 | pre_transform_grid_sample: 0.3 7 | N_max: 10000 8 | 9 | max_xy: 25.6 10 | max_z: 25.6 11 | min_z: 0. 12 | 13 | subtile_max_xy: 250. 14 | 15 | random_jitter: .01 16 | random_scales: 17 | - 0.95 18 | - 1.05 19 | 20 | num_workers: 16 21 | 22 | input_dim: 3 23 | n_features: 7 24 | 25 | n_max: 0 26 | 27 | raw_class_names: 28 | 1 : "Unlabeled" 29 | 2 : "Ground" 30 | 3 : "Low vegetation" 31 | 4 : "Medium vegetation" 32 | 5 : "High vegetation" 33 | 6 : "Building" 34 | 9 : "Water" 35 | 17: "Bridge" 36 | 64: "Lasting above" 37 | 65: "Artifacts" 38 | 66: "Virtual Points" 39 | 160: "Aerial" 40 | 161: "Wind turbine" 41 | 162: "Pylon" 42 | 43 | class_names: ["Unlabeled", "Ground", "Low vegetation", "Medium vegetation", "High vegetation", "Building", "Water", "Bridge", "Lasting above", "Artifacts", "Virtual Points"] 44 | 45 | ignore_index_0: True 46 | 47 | color_map: # rgb 48 | 1 : [0, 0, 0] 49 | 2 : [128, 128, 128] 50 | 3 : [0, 255, 0] 51 | 4 : [0, 210, 0] 52 | 5 : [0, 175, 0] 53 | 6 : [0, 200, 255] 54 | 9 : [0, 0, 255] 55 | 17: [90, 58, 34] 56 | 64: [255, 125, 0] 57 | 65: [255, 0, 0] 58 | 66: [255, 0, 0] 59 | 160: [255, 0, 255] 60 | 161: [255, 0, 255] 61 | 162: [255, 0, 255] 62 | learning_map: 63 | 1 : 0 64 | 2 : 1 65 | 3 : 2 66 | 4 : 3 67 | 5 : 4 68 | 6 : 5 69 | 9 : 6 70 | 17: 7 71 | 64: 8 72 | 65: 9 73 | 66: 10 74 | 160: 11 75 | 161: 12 76 | 162: 13 77 | learning_map_inv: 78 | 0 : 1 79 | 1 : 2 80 | 2 : 3 81 | 3 : 4 82 | 4 : 5 83 | 5 : 6 84 | 6 : 9 85 | 7 : 17 86 | 8 : 64 87 | 9 : 65 88 | 10: 66 89 | 11: 160 90 | 12: 161 91 | 13: 162 92 | 93 | defaults: 94 | - lidar -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 |
develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /notebooks/demo_shapenet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "!HYDRA_FULL_ERROR=1\n", 12 | "\n", 13 | "%load_ext autoreload\n", 14 | "%autoreload 2\n", 15 | "import warnings\n", 16 | "warnings.filterwarnings(\"ignore\")\n", 17 | "\n", 18 | "from plotly.offline import init_notebook_mode\n", 19 | "init_notebook_mode(connected = True)\n", 20 | "\n", 21 | "import sys, os\n", 22 | "sys.path.append(\"../\")\n", 23 | "\n", 24 | "import hydra\n", 25 | "import pytorch_lightning as pl\n", 26 | "import logging\n", 27 | "\n", 28 | "pl.utilities.distributed.log.setLevel(logging.ERROR)\n", 29 | "\n", 30 | "hydra.initialize(config_path=\"../configs\")\n", 31 | "cfg = hydra.compose(overrides=[\"+data=shapenet_planes\"])\n", 32 | "cfg.data.data_dir = os.path.join(\"../\", cfg.data.data_dir)\n", 33 | "\n", 34 | "import numpy as np" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "scrolled": false 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | 
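"# instantiate the datamodule described by the composed hydra config\n",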
"datamodule = hydra.utils.instantiate(cfg.data)\n", 46 | "datamodule.setup()\n", 47 | "\n", 48 | "datamodule" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "scrolled": false 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "for _ in range(8):\n", 60 | " data = datamodule.train_dataset[np.random.randint(len(datamodule.train_dataset))]\n", 61 | " datamodule.show(data)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [] 70 | } 71 | ], 72 | "metadata": { 73 | "interpreter": { 74 | "hash": "f66f361b60a0f669a83bf49483502faf91b64305c31842951638f8f87d6b5230" 75 | }, 76 | "kernelspec": { 77 | "display_name": "Python 3 (ipykernel)", 78 | "language": "python", 79 | "name": "python3" 80 | }, 81 | "language_info": { 82 | "codemirror_mode": { 83 | "name": "ipython", 84 | "version": 3 85 | }, 86 | "file_extension": ".py", 87 | "mimetype": "text/x-python", 88 | "name": "python", 89 | "nbconvert_exporter": "python", 90 | "pygments_lexer": "ipython3", 91 | "version": "3.9.13" 92 | } 93 | }, 94 | "nbformat": 4, 95 | "nbformat_minor": 2 96 | } 97 | -------------------------------------------------------------------------------- /notebooks/demo_2D.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "!HYDRA_FULL_ERROR=1\n", 12 | "\n", 13 | "%load_ext autoreload\n", 14 | "%autoreload 2\n", 15 | "import warnings\n", 16 | "warnings.filterwarnings(\"ignore\")\n", 17 | "\n", 18 | "from plotly.offline import init_notebook_mode\n", 19 | "init_notebook_mode(connected = True)\n", 20 | "\n", 21 | "import sys, os\n", 22 | "sys.path.append(\"../\")\n", 23 | "\n", 24 | "import hydra\n", 25 | "import pytorch_lightning as pl\n", 26 | "import logging\n", 27 | "\n", 28 | "pl.utilities.distributed.log.setLevel(logging.ERROR)\n", 29 | "\n", 30 | "hydra.initialize(config_path=\"../configs\")\n", 31 | "cfg = hydra.compose(overrides=[\"+data=power_plant\", \"data.modality=2D\"])\n", 32 | "cfg.data.data_dir = os.path.join(\"../\", cfg.data.data_dir)\n", 33 | "\n", 34 | "import numpy as np" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "scrolled": false 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "datamodule = hydra.utils.instantiate(cfg.data)\n", 46 | "datamodule.setup()\n", 47 | "\n", 48 | "datamodule.describe()" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "scrolled": false 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "for _ in range(2):\n", 60 | " data = datamodule.train_dataset[np.random.randint(len(datamodule.train_dataset))]\n", 61 | " datamodule.show(data)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [] 70 | } 71 | ], 72 | "metadata": { 73 | "interpreter": { 74 | "hash": "f66f361b60a0f669a83bf49483502faf91b64305c31842951638f8f87d6b5230" 75 | }, 76 | "kernelspec": { 77 | "display_name": "Python 3 (ipykernel)", 78 | "language": "python", 79 | "name": "python3" 80 | }, 81 | "language_info": { 82 | "codemirror_mode": { 83 | "name": "ipython", 84 | "version": 3 85 | }, 86 | "file_extension": ".py", 87 | "mimetype": "text/x-python", 88 | "name": "python", 89 | "nbconvert_exporter": 
"python", 90 | "pygments_lexer": "ipython3", 91 | "version": "3.9.13" 92 | } 93 | }, 94 | "nbformat": 4, 95 | "nbformat_minor": 2 96 | } 97 | -------------------------------------------------------------------------------- /earthparserdataset/transforms/grid_sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/grid_sampling.html 3 | """ 4 | 5 | import re 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | from torch_scatter import scatter_add, scatter_mean, scatter_max 12 | 13 | import torch_geometric 14 | from torch_geometric.data import Data 15 | from torch_geometric.transforms import BaseTransform 16 | 17 | 18 | class GridSampling(BaseTransform): 19 | r"""Clusters points into voxels with size :attr:`size`. 20 | Each cluster returned is a new point based on the mean of all points 21 | inside the given cluster. 22 | 23 | Args: 24 | size (float or [float] or Tensor): Size of a voxel (in each dimension). 25 | start (float or [float] or Tensor, optional): Start coordinates of the 26 | grid (in each dimension). If set to :obj:`None`, will be set to the 27 | minimum coordinates found in :obj:`data.pos`. 28 | (default: :obj:`None`) 29 | end (float or [float] or Tensor, optional): End coordinates of the grid 30 | (in each dimension). If set to :obj:`None`, will be set to the 31 | maximum coordinates found in :obj:`data.pos`. 32 | (default: :obj:`None`) 33 | """ 34 | def __init__(self, size: Union[float, List[float], Tensor], 35 | start: Optional[Union[float, List[float], Tensor]] = None, 36 | end: Optional[Union[float, List[float], Tensor]] = None): 37 | self.size = size 38 | self.start = start 39 | self.end = end 40 | 41 | def __call__(self, data: Data) -> Data: 42 | num_nodes = data.num_nodes 43 | 44 | batch = data.get('batch', None) 45 | 46 | c = torch_geometric.nn.voxel_grid(data.pos, self.size, batch, 47 | self.start, self.end) 48 | c, perm = torch_geometric.nn.pool.consecutive.consecutive_cluster(c) 49 | 50 | for key, item in data: 51 | if bool(re.search('edge', key)): 52 | raise ValueError( 53 | 'GridSampling does not support coarsening of edges') 54 | 55 | if torch.is_tensor(item) and item.size(0) == num_nodes: 56 | if key in ['y', 'point_y', 'label', 'point_inst']: 57 | item = F.one_hot(item) 58 | item = scatter_add(item, c, dim=0) 59 | data[key] = item.argmax(dim=-1) 60 | elif key == 'batch': 61 | data[key] = item[perm] 62 | elif data[key].dtype == torch.uint8: 63 | data[key] = (255. 
* scatter_mean(item / 255., c, dim=0)).to(torch.uint8) 64 | else: 65 | data[key] = scatter_mean(item, c, dim=0) 66 | 67 | return data 68 | 69 | def __repr__(self) -> str: 70 | return f'{self.__class__.__name__}(size={self.size})' 71 | -------------------------------------------------------------------------------- /notebooks/demo_3D.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "!HYDRA_FULL_ERROR=1\n", 12 | "\n", 13 | "%load_ext autoreload\n", 14 | "%autoreload 2\n", 15 | "import warnings\n", 16 | "warnings.filterwarnings(\"ignore\")\n", 17 | "\n", 18 | "from plotly.offline import init_notebook_mode\n", 19 | "init_notebook_mode(connected = True)\n", 20 | "import copy\n", 21 | "import sys, os\n", 22 | "sys.path.append(\"../\")\n", 23 | "\n", 24 | "import hydra\n", 25 | "import pytorch_lightning as pl\n", 26 | "import logging\n", 27 | "\n", 28 | "pl.utilities.distributed.log.setLevel(logging.ERROR)\n", 29 | "\n", 30 | "hydra.initialize(config_path=\"../configs\")\n", 31 | "cfg = hydra.compose(overrides=[\"+data=marina\"])\n", 32 | "cfg.data.data_dir = os.path.join(\"../\", cfg.data.data_dir)\n", 33 | "\n", 34 | "import numpy as np" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "scrolled": false 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "datamodule = hydra.utils.instantiate(cfg.data)\n", 46 | "datamodule.setup()\n", 47 | "\n", 48 | "datamodule.describe()" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "print(len(datamodule.train_dataset.slices[\"pos\"]) if datamodule.train_dataset.slices is not None else \"NO SLICES\")\n", 58 | "idx = 0\n", 59 | "\n", 60 | "data_to_print = copy.deepcopy(datamodule.train_dataset.data)\n", 61 | " \n", 62 | "if datamodule.train_dataset.slices is not None:\n", 63 | " for k in datamodule.train_dataset.slices.keys():\n", 64 | " data_to_print[k] = data_to_print[k][datamodule.train_dataset.slices[k][idx]:datamodule.train_dataset.slices[k][idx + 1]]\n", 65 | " \n", 66 | "datamodule.show(data_to_print, voxelize=2.*data_to_print.size(0) / 1000000)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": { 73 | "scrolled": false 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "for _ in range(4):\n", 78 | " data = datamodule.train_dataset[np.random.randint(len(datamodule.train_dataset))]\n", 79 | " datamodule.show(data)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [] 88 | } 89 | ], 90 | "metadata": { 91 | "interpreter": { 92 | "hash": "f66f361b60a0f669a83bf49483502faf91b64305c31842951638f8f87d6b5230" 93 | }, 94 | "kernelspec": { 95 | "display_name": "Python 3 (ipykernel)", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.9.13" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 2 114 | } 115 | -------------------------------------------------------------------------------- 
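For reference, the `GridSampling` transform above majority-votes integer labels (one-hot sum, then argmax per voxel) and averages continuous attributes. Below is a minimal sketch of applying it on its own, assuming `torch`, `torch_geometric` and `torch_scatter` are installed and this repository is on the Python path; the toy cloud is illustrative and not part of the repository:

```python
import torch
from torch_geometric.data import Data

from earthparserdataset.transforms import GridSampling

# toy cloud: 1000 random points in a 10 m cube with 4 fake classes
cloud = Data(pos=10. * torch.rand(1000, 3),
             point_y=torch.randint(0, 4, (1000,)))
n_before = cloud.num_nodes

# 0.5 m voxels: positions are averaged, point_y is majority-voted per voxel
voxelized = GridSampling(0.5)(cloud)
print(n_before, "->", voxelized.num_nodes, "points")
```

Note that the transform modifies the `Data` object in place and returns it, which is why the point count is read before the call.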
/earthparserdataset/shapenetsem.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os.path as osp 3 | 4 | from torch_geometric.io import read_txt_array 5 | import torch 6 | import torch.nn.functional as F 7 | import torch_geometric.transforms as T 8 | from torch_geometric.data import Data, InMemoryDataset 9 | from .utils.labels import from_sem_to_color 10 | from .base import BaseDataModule 11 | from tqdm.auto import tqdm 12 | import json 13 | 14 | class MeanRemove(T.BaseTransform): 15 | def __init__(self) -> None: 16 | super().__init__() 17 | 18 | def __call__(self, data): 19 | data.pos -= data.pos.mean(0) 20 | return data 21 | 22 | class ShapeNetSemDataset(InMemoryDataset): 23 | CAT2ID = { 24 | "Airplane" : "02691156", 25 | "Bag" : "02773838", 26 | "Cap" : "02954340", 27 | "Car" : "02958343", 28 | "Chair" : "03001627", 29 | "Earphone" : "03261776", 30 | "Guitar" : "03467517", 31 | "Knife" : "03624134", 32 | "Lamp" : "03636649", 33 | "Laptop" : "03642806", 34 | "Motorbike" : "03790512", 35 | "Mug" : "03797390", 36 | "Pistol" : "03948459", 37 | "Rocket" : "04099429", 38 | "Skateboard" : "04225987", 39 | "Table" : "04379243", 40 | } 41 | 42 | def __init__(self, options, mode): 43 | self.options = copy.deepcopy(options) 44 | self.mode = mode 45 | 46 | transform = [ 47 | MeanRemove() 48 | ] 49 | 50 | if mode == "train": 51 | transform += [ 52 | T.RandomTranslate(0.01), 53 | T.RandomFlip(axis=1), 54 | ] 55 | 56 | if self.options.rotate_z: 57 | transform.append(T.RandomRotate(180, axis=2)) 58 | 59 | super().__init__(options.data_dir, T.Compose(transform), None, None) 60 | self.data, self.slices = torch.load(self.processed_paths[0]) 61 | 62 | self.data.point_y -= self.data.point_y.min() 63 | 64 | self.N_point_max = int((self.slices["pos"][1:] - self.slices["pos"][:-1]).max().item()) 65 | 66 | @property 67 | def raw_file_names(self): 68 | 69 | files = [] 70 | modes = ["train", "val", "test"] if self.mode in ["train", "test"] else ["val"] 71 | for mode in modes: 72 | with open(osp.join(self.root, 'train_test_split', f'shuffled_{mode}_file_list.json'), 'r') as f: 73 | files += json.load(f) 74 | 75 | select_ids = [self.CAT2ID[cat] for cat in self.options.classes] 76 | file_names = [] 77 | for file in files: 78 | if file.split('/')[1] in select_ids: 79 | file_names.append(f"{osp.join(*file.split('/')[1:])}.txt") 80 | 81 | return file_names 82 | 83 | @property 84 | def processed_file_names(self): 85 | return [f"{'_'.join(self.options.classes)}.pt"] 86 | 87 | @property 88 | def processed_dir(self) -> str: 89 | return osp.join(self.root, 'processed', self.mode) 90 | 91 | def process(self): 92 | # Read data into huge `Data` list. 
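# (each raw file is an N x 7 text array: xyz, surface normal, part label; only xyz and the label are kept below)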
93 | data_list = [] 94 | for file in tqdm(self.raw_file_names): 95 | data = read_txt_array(osp.join(self.root, file)) 96 | 97 | data_list.append(Data(pos=data[:, :3], point_y=data[:, -1])) 98 | 99 | if self.pre_filter is not None: 100 | data_list = [data for data in data_list if self.pre_filter(data)] 101 | 102 | if self.pre_transform is not None: 103 | data_list = [self.pre_transform(data) for data in data_list] 104 | 105 | data, slices = self.collate(data_list) 106 | data.pos = data.pos[:, [0, 2, 1]] 107 | torch.save((data, slices), self.processed_paths[0]) 108 | 109 | def len(self): 110 | return len(self.raw_file_names) 111 | 112 | def __getitem__(self, idx): 113 | data = super().__getitem__(idx) 114 | data.features = data.pos.clone() 115 | 116 | data.pos *= .5 * self.options.max_xy / ((data.pos**2).sum(-1)**.5).max() 117 | data.pos += .5 * self.options.max_xy 118 | 119 | data.pos_lenght = torch.tensor(data.pos.size(0)) 120 | 121 | data.pos_padded = F.pad(data.pos, (0, 0, 0, self.N_point_max - data.pos.shape[0])).unsqueeze(0) 122 | 123 | return data 124 | 125 | class ShapeNetSemDataModule(BaseDataModule): 126 | _DatasetSplit_ = { 127 | "train": ShapeNetSemDataset, 128 | "val": ShapeNetSemDataset, 129 | "test": ShapeNetSemDataset 130 | } 131 | 132 | def get_feature_names(self): 133 | return ["x", "y", "z"] 134 | 135 | def from_labels_to_color(self, labels): 136 | return from_sem_to_color( 137 | labels, 138 | self.myhparams.color_map 139 | ) -------------------------------------------------------------------------------- /earthparserdataset/transforms/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch_geometric.transforms as T 8 | from torch_geometric.transforms import BaseTransform 9 | 10 | 11 | class RandomCropSubTileTrain(T.BaseTransform): 12 | def __init__(self, max_xy) -> None: 13 | self.max_xy = max_xy 14 | super().__init__() 15 | 16 | def __call__(self, data): 17 | xy = data.pos[np.random.randint(data.pos.size(0)), :2] 18 | while (xy > data.pos_maxi - self.max_xy / 2.).any() or (xy < data.pos[:, :2].min(0)[0] + self.max_xy / 2.).any(): 19 | xy = data.pos[np.random.randint(data.pos.size(0)), :2] 20 | 21 | keep = np.abs(data.pos[:, :2] - xy).max(-1)[0] <= self.max_xy / 2. 22 | 23 | del data.pos_maxi 24 | 25 | for k in data.keys: 26 | setattr(data, k, getattr(data, k)[keep]) 27 | 28 | data.pos[..., :2] -= xy 29 | return data 30 | 31 | 32 | class CenterCrop(T.BaseTransform): 33 | def __init__(self, max_xy) -> None: 34 | self.max_xy = max_xy 35 | super().__init__() 36 | 37 | def __call__(self, data): 38 | 39 | keep = np.abs(data.pos[:, :2]).max(-1)[0] <= self.max_xy / 2. 40 | 41 | for k in data.keys: 42 | setattr(data, k, getattr(data, k)[keep]) 43 | 44 | data.pos[..., :2] += self.max_xy / 2. 45 | return data 46 | 47 | 48 | class RandomCropSubTileVal(T.BaseTransform): 49 | def __init__(self, max_xy) -> None: 50 | self.max_xy = max_xy 51 | super().__init__() 52 | 53 | def __call__(self, data): 54 | xy = data.pos[np.random.randint(data.pos.size(0)), :2] 55 | while (xy > data.pos_maxi - self.max_xy / 2.).any() or (xy < data.pos[:, :2].min(0)[0] + self.max_xy / 2.).any(): 56 | xy = data.pos[np.random.randint(data.pos.size(0)), :2] 57 | 58 | keep = np.abs(data.pos[:, :2] - xy).max(-1)[0] <= self.max_xy / 2. 
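# keep only the points inside an axis-aligned max_xy x max_xy window centred on the sampled point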
59 | 60 | del data.pos_maxi 61 | 62 | for k in data.keys: 63 | setattr(data, k, getattr(data, k)[keep]) 64 | 65 | data.pos[..., :2] -= xy - self.max_xy / 2. 66 | return data 67 | 68 | 69 | class ZCrop(T.BaseTransform): 70 | def __init__(self, max_z) -> None: 71 | self.max_z = max_z 72 | super().__init__() 73 | 74 | def __call__(self, data): 75 | data.pos[:, -1] -= data.pos[:, -1].min() 76 | keep = data.pos[:, -1] < self.max_z 77 | 78 | for k in data.keys: 79 | setattr(data, k, getattr(data, k)[keep]) 80 | 81 | return data 82 | 83 | 84 | class MaxPoints(T.BaseTransform): 85 | def __init__(self, max_points) -> None: 86 | self.max_points = max_points 87 | super().__init__() 88 | 89 | def __call__(self, data): 90 | 91 | if data.pos.shape[0] > self.max_points: 92 | keep = np.random.choice( 93 | data.pos.shape[0], self.max_points, replace=True) 94 | for k in data.keys: 95 | setattr(data, k, getattr(data, k)[keep]) 96 | return data 97 | 98 | 99 | class RandomRotate(T.BaseTransform): 100 | r"""Rotates node positions around a specific axis by a randomly sampled 101 | factor within a given interval. 102 | 103 | Args: 104 | degrees (tuple or float): Rotation interval from which the rotation 105 | angle is sampled. If :obj:`degrees` is a number instead of a 106 | tuple, the interval is given by :math:`[-\mathrm{degrees}, 107 | \mathrm{degrees}]`. 108 | axis (int, optional): The rotation axis. (default: :obj:`0`) 109 | """ 110 | 111 | def __init__(self, degrees, axis=0): 112 | if isinstance(degrees, numbers.Number): 113 | degrees = (-abs(degrees), abs(degrees)) 114 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 115 | self.degrees = degrees 116 | self.axis = axis 117 | 118 | def __call__(self, data): 119 | degree = math.pi * random.uniform(*self.degrees) / 180.0 120 | sin, cos = math.sin(degree), math.cos(degree) 121 | 122 | if data.pos.size(-1) == 2: 123 | matrix = [[cos, sin], [-sin, cos]] 124 | else: 125 | if self.axis == 0: 126 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] 127 | elif self.axis == 1: 128 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] 129 | else: 130 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] 131 | 132 | matrix = torch.tensor(matrix).to(data.pos.device, data.pos.dtype).t() 133 | 134 | data.pos = data.pos @ matrix 135 | 136 | return data 137 | 138 | def __repr__(self) -> str: 139 | return (f'{self.__class__.__name__}({self.degrees}, ' 140 | f'axis={self.axis})') 141 | 142 | 143 | class RandomScale(BaseTransform): 144 | r"""Scales node positions by a randomly sampled factor :math:`s` within a 145 | given interval, *e.g.*, resulting in the transformation matrix 146 | 147 | .. math:: 148 | \begin{bmatrix} 149 | s & 0 & 0 \\ 150 | 0 & s & 0 \\ 151 | 0 & 0 & s \\ 152 | \end{bmatrix} 153 | 154 | for three-dimensional positions. 155 | 156 | Args: 157 | scales (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then scale 158 | is randomly sampled from the range 159 | :math:`a \leq \mathrm{scale} \leq b`. 
160 | """ 161 | 162 | def __init__(self, scales): 163 | assert isinstance(scales, (tuple, list)) and len(scales) == 2 164 | self.scales = scales 165 | 166 | def __call__(self, data): 167 | scale = random.uniform(*self.scales) 168 | data.pos = data.pos * scale 169 | data.pos_maxi = data.pos_maxi * scale 170 | 171 | return data 172 | 173 | def __repr__(self) -> str: 174 | return f'{self.__class__.__name__}({self.scales})' 175 | 176 | 177 | class RandomFlip(BaseTransform): 178 | """Flips node positions along a given axis randomly with a given 179 | probability. 180 | 181 | Args: 182 | axis (int): The axis along the position of nodes being flipped. 183 | p (float, optional): Probability that node positions will be flipped. 184 | (default: :obj:`0.5`) 185 | """ 186 | 187 | def __init__(self, axis, max_xy, p=0.5): 188 | self.axis = axis 189 | self.p = p 190 | self.max_xy = max_xy 191 | 192 | def __call__(self, data): 193 | if random.random() < self.p: 194 | data.pos[..., self.axis] = - data.pos[..., self.axis] 195 | 196 | return data 197 | 198 | def __repr__(self) -> str: 199 | return f'{self.__class__.__name__}(axis={self.axis}, p={self.p})' 200 | -------------------------------------------------------------------------------- /earthparserdataset/base.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | 3 | import copy 4 | import torch 5 | import plotly.graph_objects as go 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | from torch_geometric.loader import DataLoader 9 | from matplotlib import cm 10 | import torch_scatter 11 | import torch.nn.functional as F 12 | import matplotlib.pyplot as plt 13 | 14 | from hydra.utils import to_absolute_path 15 | 16 | 17 | class BaseDataModule(pl.LightningDataModule): 18 | FEATURE2NAME = {"y": "Ground truth", "y_pred": "Prediction", "inst": "Ground truth instance", 19 | "inst_pred": "Predicted instance", "i": "Intensity", "rgb": "RGB", "infrared": "Infrared", "xyz": "Position"} 20 | 21 | def __init__(self, *args, **kwargs): 22 | super().__init__() 23 | 24 | self.myhparams = SimpleNamespace(**kwargs) 25 | self.myhparams.data_dir = to_absolute_path(self.myhparams.data_dir) 26 | 27 | def __repr__(self) -> str: 28 | 29 | out = f"DataModule:\t{self.__class__.__name__}" 30 | 31 | for split in ["train", "val", "test"]: 32 | if hasattr(self, f"{split}_dataset"): 33 | out += f"\n\t{getattr(self, f'{split}_dataset')}" 34 | 35 | return out 36 | 37 | def setup(self, stage: str = None): 38 | if stage in (None, 'fit'): 39 | self.train_dataset = self._DatasetSplit_["train"](self.myhparams, "train") 40 | self.val_dataset = self._DatasetSplit_["val"](self.myhparams, "val") 41 | elif stage in (None, 'validate'): 42 | self.val_dataset = self._DatasetSplit_["val"](self.myhparams, "val") 43 | elif stage in (None, 'train'): 44 | self.train_dataset = self._DatasetSplit_["train"](self.myhparams, "train") 45 | 46 | if stage in (None, 'test'): 47 | self.train_dataset = self._DatasetSplit_["train"](self.myhparams, "train") 48 | self.test_dataset = self._DatasetSplit_["test"](self.myhparams, "test") 49 | 50 | def train_dataloader(self): 51 | return DataLoader( 52 | self.train_dataset, 53 | batch_size=self.myhparams.batch_size, 54 | shuffle=True, 55 | num_workers=self.myhparams.num_workers, 56 | persistent_workers=True 57 | ) 58 | 59 | def val_dataloader(self): 60 | return DataLoader( 61 | self.val_dataset, 62 | batch_size=1, 63 | shuffle=False, 64 | num_workers=self.myhparams.num_workers, 65 | 
persistent_workers=True 66 | ) 67 | 68 | def test_dataloader(self): 69 | return DataLoader( 70 | self.test_dataset, 71 | batch_size=1, 72 | shuffle=False, 73 | num_workers=self.myhparams.num_workers 74 | ) 75 | 76 | def get_label_from_raw_feature(self, c): 77 | return self.FEATURE2NAME[c] 78 | 79 | def get_color_from_item(self, item, c): 80 | if c == "y" and hasattr(item, "point_y"): 81 | color = self.from_labels_to_color(item.point_y.squeeze()).numpy() 82 | elif c == "y_pred" and hasattr(item, "point_y_pred"): 83 | try: 84 | color = self.from_labels_to_color( 85 | item.point_y_pred.squeeze()).numpy() 86 | except Exception: 87 | color = cm.get_cmap("tab20")(item.point_y_pred.squeeze( 88 | ).cpu().numpy() / item.point_y_pred.max().item())[:, :-1] 89 | color = (255*color).astype(np.uint8) 90 | elif c in ["inst_pred", "inst"] and hasattr(item, f"point_{c}"): 91 | color = cm.get_cmap("tab20")(np.random.permutation(getattr(item, f"point_{c}").max().item( 92 | )+1)[getattr(item, f"point_{c}").squeeze().cpu().numpy()] / getattr(item, f"point_{c}").max().item())[:, :-1] 93 | color = (255*color).astype(np.uint8) 94 | elif c == "i" and hasattr(item, "intensity"): 95 | color = item.intensity.squeeze().numpy() 96 | color = 0.01 + 0.98*(color - color.min()) / \ 97 | (color.max() - color.min()) 98 | color = cm.get_cmap("viridis")(color)[:, :-1] 99 | color = (255*color).astype(np.uint8) 100 | elif c == "rgb" and hasattr(item, "rgb"): 101 | color = 0.01 + 0.98*item.rgb / 255. 102 | else: 103 | color = item.pos.squeeze().numpy() 104 | color = 0.01 + 0.98*(color - color.min()) / \ 105 | (color.max() - color.min()) 106 | color = (255*color).astype(np.uint8) 107 | return color 108 | 109 | def show(self, item, voxelize=0, color="y;xyz;rgb;i", ps=5): 110 | if self.myhparams.modality == "3D": 111 | self.show3D(item, voxelize, color, ps) 112 | elif self.myhparams.modality == "2D": 113 | self.show2D(item) 114 | else: 115 | raise NotImplementedError(f"Modality {self.myhparams.modality} not implemented") 116 | 117 | def show2D(self, item): 118 | fig, ax = plt.subplots(1, 4, figsize=(20, 5)) 119 | 120 | ax[0].imshow(item[..., 1] / 255.) 121 | ax[0].set_title("z max projection") 122 | ax[0].set_aspect("equal") 123 | ax[0].axis("off") 124 | 125 | ax[1].imshow(item[..., 3] / 255.) 126 | ax[1].set_title("intensity max projection") 127 | ax[1].set_aspect("equal") 128 | ax[1].axis("off") 129 | 130 | ax[2].imshow(item[..., [4,5,6]] / 255.)
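# channels 4-6 hold the min-projected RGB; from3Dto2Ditem stacks min/max projections of z, intensity, rgb and labels into the 12 channels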
131 | ax[2].set_title("rgb") 132 | ax[2].set_aspect("equal") 133 | ax[2].axis("off") 134 | 135 | ax[3].imshow(self.from_labels_to_color(item[..., -1])) 136 | ax[3].set_title("labels") 137 | ax[3].set_aspect("equal") 138 | ax[3].axis("off") 139 | 140 | plt.show() 141 | 142 | def show3D(self, item, voxelize=0, color="y;xyz;rgb;i", ps=5): 143 | 144 | thisitem = copy.deepcopy(item) 145 | 146 | if voxelize: 147 | choice = torch.unique( 148 | (thisitem.pos / voxelize).int(), return_inverse=True, dim=0)[1] 149 | for attr in ["point_y", "point_y_pred", "point_inst", "point_inst_pred"]: 150 | if hasattr(thisitem, attr): 151 | setattr(thisitem, attr, torch_scatter.scatter_sum(F.one_hot( 152 | getattr(thisitem, attr).squeeze().long()), choice, 0).argmax(-1).unsqueeze(0)) 153 | for attr in ["pos", "features"]: 154 | if hasattr(thisitem, attr): 155 | setattr(thisitem, attr, torch_scatter.scatter_mean( 156 | getattr(thisitem, attr), choice, 0)) 157 | for attr in ["intensity", "rgbi", "rgb"]: 158 | if hasattr(thisitem, attr): 159 | setattr(thisitem, attr, torch_scatter.scatter_max( 160 | getattr(thisitem, attr), choice, 0)[0]) 161 | 162 | datadtype = torch.float16 163 | margin = int(0.02 * 600) 164 | layout = go.Layout( 165 | width=1000, 166 | height=600, 167 | margin=dict(l=margin, r=margin, b=margin, t=4*margin), 168 | uirevision=True, 169 | showlegend=False 170 | ) 171 | fig = go.Figure( 172 | layout=layout, 173 | data=go.Scatter3d( 174 | x=thisitem.pos[:, 0].to(datadtype), y=thisitem.pos[:, 1].to(datadtype), z=thisitem.pos[:, 2].to(datadtype), 175 | mode='markers', 176 | marker=dict(size=ps, color=self.get_color_from_item( 177 | thisitem, color.split(";")[0])), 178 | ) 179 | ) 180 | updatemenus = [ 181 | dict( 182 | buttons=list([ 183 | dict( 184 | args=[{"marker": dict(size=ps, color=self.get_color_from_item(thisitem, c))}, [ 185 | 0]], 186 | label=self.get_label_from_raw_feature(c), 187 | method="restyle" 188 | ) for c in color.split(";") 189 | ]), 190 | direction="down", 191 | pad={"r": 10, "t": 10}, 192 | showactive=True, 193 | x=0.05, 194 | xanchor="left", 195 | y=0.88, 196 | yanchor="top" 197 | ), 198 | ] 199 | 200 | fig.update_layout(updatemenus=updatemenus) 201 | fig.update_layout( 202 | scene_aspectmode='data', 203 | ) 204 | 205 | fig.show() 206 | 207 | del thisitem -------------------------------------------------------------------------------- /earthparserdataset/earthparserdataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import os.path as osp 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch_geometric.transforms as T 10 | import torch_scatter 11 | from numpy.lib.recfunctions import structured_to_unstructured 12 | from torch_geometric.data import Data, InMemoryDataset 13 | from tqdm.auto import tqdm 14 | 15 | from .base import BaseDataModule 16 | from .transforms import (CenterCrop, GridSampling, MaxPoints, 17 | RandomCropSubTileTrain, RandomCropSubTileVal, 18 | RandomFlip, RandomRotate, RandomScale, ZCrop) 19 | from .utils import color as color 20 | from .utils.labels import apply_learning_map, from_sem_to_color 21 | 22 | 23 | class LidarHDSplit(InMemoryDataset): 24 | def __init__(self, options, mode): 25 | self.options = copy.deepcopy(options) 26 | self.options.data_dir = osp.join(self.options.data_dir, self.options.name) 27 | self.mode = mode 28 | 29 | self.feature_normalizer = torch.tensor( 30 | [[self.options.max_xy, 
self.options.max_xy, self.options.max_z]]) 31 | 32 | super().__init__( 33 | self.options.data_dir, 34 | transform=self.get_transform(), 35 | pre_transform=self.get_pre_transform(), 36 | pre_filter=None 37 | ) 38 | 39 | self.data, self.slices = torch.load(self.processed_paths[0]) 40 | self.data.point_y = apply_learning_map( 41 | self.data.point_y, self.options.learning_map) 42 | 43 | if mode in ["train", "val"]: 44 | self.items_per_epoch = int(self.options.items_per_epoch[mode] / (len(self.slices["pos"]) - 1 if self.slices is not None else 1)) 45 | else: 46 | self.prepare_test_dataset() 47 | 48 | if mode == "val": 49 | self.load_all_items() 50 | 51 | def prepare_test_dataset(self): 52 | self.items_per_epoch = [] 53 | self.tiles_unique_selection = {} 54 | self.tiles_min_z = {} 55 | self.from_idx_to_tile = {} 56 | idx = 0 57 | for i in range(self.__superlen__()): 58 | unique_i, inverse_i = torch.unique((self.__getsuperitem__( 59 | i).pos[:, :2] / self.options.max_xy).int(), dim=0, return_inverse=True) 60 | self.items_per_epoch.append(unique_i.shape[0]) 61 | self.tiles_unique_selection[i] = unique_i 62 | self.tiles_min_z[i] = torch_scatter.scatter_min( 63 | self.__getsuperitem__(i).pos[:, 2], inverse_i, dim=0)[0] 64 | for j in range(unique_i.shape[0]): 65 | self.from_idx_to_tile[idx] = (i, unique_i[j]) 66 | idx += 1 67 | 68 | def load_all_items(self): 69 | self.items = [super(LidarHDSplit, self).__getitem__(int(i / self.items_per_epoch)) for i in range(len(self))] 70 | del self.data 71 | 72 | def get_pre_transform(self): 73 | pre_transform = [] 74 | 75 | if self.mode == "test": 76 | pre_transform.append( 77 | GridSampling( 78 | self.options.pre_transform_grid_sample / 79 | 4. if not "windturbine" in self.options.name else self.options.pre_transform_grid_sample 80 | ) 81 | ) 82 | else: 83 | pre_transform.append(GridSampling( 84 | self.options.pre_transform_grid_sample)) 85 | pre_transform = T.Compose(pre_transform) 86 | return pre_transform 87 | 88 | def get_transform(self): 89 | if self.mode == "train": 90 | transform = [ 91 | RandomScale(tuple(self.options.random_scales)), 92 | RandomCropSubTileTrain(2.**.5 * self.options.max_xy), 93 | RandomRotate(degrees=180., axis=2), 94 | RandomFlip(0, self.options.max_xy), 95 | CenterCrop(self.options.max_xy), 96 | ZCrop(self.options.max_z), 97 | T.RandomTranslate(self.options.random_jitter) 98 | ] 99 | elif self.mode == "val": 100 | transform = [ 101 | RandomCropSubTileVal(self.options.max_xy), 102 | ZCrop(self.options.max_z) 103 | ] 104 | elif self.mode == "test": 105 | transform = [] 106 | else: 107 | raise NotImplementedError( 108 | f"Mode {self.mode} not implemented. 
Should be in ['train', 'val', 'test']") 109 | 110 | if self.mode in ["train", "val"]: 111 | if self.options.N_scene != 0: 112 | transform.append(T.FixedPoints(self.options.N_scene)) 113 | else: 114 | transform.append(MaxPoints(self.options.N_max)) 115 | 116 | transform = T.Compose(transform) 117 | return transform 118 | 119 | def __repr__(self) -> str: 120 | return f"{self.__class__.__name__}({self.mode})" 121 | 122 | @property 123 | def raw_file_names(self): 124 | laz_dirs = [] 125 | for tile in os.listdir(osp.join(self.options.data_dir, "tiles")): 126 | tile_dir = osp.join(self.options.data_dir, "tiles", tile) 127 | if osp.isdir(tile_dir): 128 | for subtile in os.listdir(tile_dir): 129 | if subtile.split(".")[-1] in ["las", "laz"]: 130 | laz_dirs.append(osp.join(tile_dir, subtile)) 131 | if tile_dir.split(".")[-1] in ["las", "laz"]: 132 | laz_dirs.append(tile_dir) 133 | 134 | return laz_dirs 135 | 136 | @property 137 | def processed_file_names(self): 138 | return [f'grid{1000*self.options.pre_transform_grid_sample:.0f}mm_data.pt'] 139 | 140 | @property 141 | def processed_dir(self) -> str: 142 | return osp.join(self.root, 'processed', self.mode if self.mode != "val" else "train") 143 | 144 | def process(self): 145 | data_list = [] 146 | 147 | import pdal 148 | for laz in tqdm(self.raw_file_names): 149 | list_pipeline = [pdal.Reader(laz)] 150 | pipeline = pdal.Pipeline() 151 | for p in list_pipeline: 152 | pipeline |= p 153 | count = pipeline.execute() 154 | arrays = pipeline.arrays 155 | 156 | label = structured_to_unstructured(arrays[0][["Classification"]]) 157 | for l in np.unique(label): 158 | if l not in self.options.raw_class_names.keys(): 159 | label[label == l] = 1 160 | 161 | xyz = structured_to_unstructured(arrays[0][["X", "Y", "Z"]]) 162 | intensity = structured_to_unstructured(arrays[0][["Intensity"]]) 163 | 164 | if "Red" in arrays[0].dtype.names: 165 | rgb = structured_to_unstructured( 166 | arrays[0][["Red", "Green", "Blue"]]) 167 | 168 | if rgb.max() >= 256: 169 | rgb = (rgb / 2**8).astype(np.uint8) 170 | else: 171 | rgb = rgb.astype(np.uint8) 172 | else: 173 | rgb = np.zeros_like(xyz, dtype=np.uint8) 174 | 175 | intensity = intensity.clip(np.percentile(intensity.flatten(), .1), np.percentile(intensity.flatten(), 99.9)) 176 | intensity = (intensity / intensity.max()).astype(np.float32) 177 | 178 | label = label.astype(np.int64) 179 | 180 | xyz -= xyz.min(0) 181 | xyz = xyz.astype(np.float32) 182 | maxi = xyz.max(0) 183 | 184 | if self.options.subtile_max_xy > 0: 185 | n_split = 1 + \ 186 | (maxi / ((1 + (self.mode == "test")) * self.options.subtile_max_xy)).astype(np.int32) 187 | else: 188 | n_split = [1, 1] 189 | 190 | for i in range(n_split[0]): 191 | for j in range(n_split[1]): 192 | keep = np.logical_and( 193 | np.logical_and( 194 | xyz[:, 0] >= i * maxi[0] / n_split[0], 195 | xyz[:, 0] < (i + 1) * maxi[0] / n_split[0] 196 | ), 197 | np.logical_and( 198 | xyz[:, 1] >= j * maxi[1] / n_split[1], 199 | xyz[:, 1] < (j + 1) * maxi[1] / n_split[1] 200 | ) 201 | ) 202 | 203 | data_list.append(Data( 204 | pos=torch.from_numpy(xyz[keep] - xyz[keep].min(0)), 205 | intensity=torch.from_numpy(intensity[keep]), 206 | rgb=torch.from_numpy(rgb[keep]), 207 | point_y=torch.from_numpy(label[keep]), 208 | )) 209 | 210 | if self.mode in ["train", "val"]: 211 | data_list[-1].pos_maxi = data_list[-1].pos[:, :2].max(0)[0] 212 | 213 | if self.pre_filter is not None: 214 | data_list = [data for data in data_list if self.pre_filter(data)] 215 | 216 | if self.pre_transform is not None: 217 | 
data_list = [self.pre_transform(data) for data in data_list] 218 | 219 | data, slices = self.collate(data_list) 220 | 221 | torch.save((data, slices), self.processed_paths[0]) 222 | 223 | def __superlen__(self) -> int: 224 | return super().__len__() 225 | 226 | def __len__(self) -> int: 227 | if self.mode != "test": 228 | return self.items_per_epoch * self.__superlen__() 229 | else: 230 | return sum(self.items_per_epoch) 231 | 232 | def __getsuperitem__(self, idx): 233 | return super().__getitem__(idx) 234 | 235 | def __getitem__(self, idx): 236 | if self.mode == "train": 237 | item = self.__getsuperitem__(int(idx / self.items_per_epoch)) 238 | elif self.mode == "val": 239 | item = self.items[idx] 240 | else: 241 | this_idx, tile = self.from_idx_to_tile[idx] 242 | item = self.__getsuperitem__(this_idx) 243 | keep = (item.pos[:, :2] / self.options.max_xy).int() 244 | keep = torch.logical_and( 245 | keep[:, 0] == tile[0], keep[:, 1] == tile[1]) 246 | for k in item.keys: 247 | setattr(item, k, getattr(item, k)[keep]) 248 | 249 | item.pos[:, :2] -= self.options.max_xy * tile 250 | item.pos[:, -1] -= item.pos[:, -1].min() 251 | 252 | keep = item.pos[:, -1] < self.options.max_z 253 | for k in item.keys: 254 | setattr(item, k, getattr(item, k)[keep]) 255 | 256 | item.pos_lenght = torch.tensor(item.pos.size(0)) 257 | 258 | if self.options.modality == "3D": 259 | pad = self.options.N_max - item.pos.size(0) if self.mode != "test" else 0 260 | if self.options.distance == "xyz": 261 | item.pos_padded = F.pad(item.pos, (0, 0, 0, pad), mode="constant", value=0).unsqueeze(0) 262 | elif self.options.distance == "xyzk": 263 | item.pos_padded = F.pad(torch.cat([item.pos, item.intensity], -1), (0, 0, 0, pad), mode="constant", value=0).unsqueeze(0) 264 | else: 265 | raise NotImplementedError( 266 | f"LiDAR-HD can't produce {self.options.distance}") 267 | 268 | item.features = 2 * \ 269 | torch.cat([item.rgb / 255., item.pos / 270 | self.feature_normalizer, item.intensity], -1) - 1 271 | #del item.intensity 272 | 273 | if self.options.modality == "2D": 274 | item = self.from3Dto2Ditem(item) 275 | 276 | return item 277 | 278 | def from3Dto2Ditem(self, item): 279 | del item.pos_lenght 280 | res, n_dim = self.options.image.res, self.options.image.n_dim 281 | intensity = item.intensity.squeeze() 282 | rgb = item.rgb.float() 283 | labels = item.point_y.float().squeeze() 284 | xy, z = item.pos[:, :2], item.pos[:, 2] 285 | xy = torch.clamp(torch.floor( 286 | xy / (self.options.max_xy / (res-0.001))), 0, res - 1) 287 | xy = (xy[:, 0] * res + xy[:, 1]).long() 288 | features = [] 289 | for values, init in zip([z, intensity, rgb, labels], [0, 0, 0, -1]): 290 | for mode in ['min', 'max']: 291 | features.append(gather_values( 292 | xy, z, values, mode=mode, res=res, init=init)) 293 | item = torch.cat(features, dim=-1) 294 | item = item.reshape(res, res, n_dim) 295 | return item 296 | 297 | def gather_values(xy, z, values, mode='max', res=32, init=0): 298 | img = (torch.ones( 299 | res*res, values.shape[1] if 1 < len(values.shape) else 1) * init).squeeze() 300 | z = z.sort(descending=mode == 'max') 301 | xy, values = xy[z[1]], values[z[1]] 302 | unique, inverse = torch.unique(xy, sorted=True, return_inverse=True) 303 | perm = torch.arange(inverse.size( 304 | 0), dtype=inverse.dtype, device=inverse.device) 305 | inverse, perm = inverse.flip([0]), perm.flip([0]) 306 | perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm) 307 | img.index_put_((xy[perm],), values[perm], accumulate=False) 308 | img = 
img.reshape(res, res, -1) 309 | return img 310 | 311 | 312 | class LidarHDDataModule(BaseDataModule): 313 | _DatasetSplit_ = { 314 | "train": LidarHDSplit, 315 | "val": LidarHDSplit, 316 | "test": LidarHDSplit 317 | } 318 | 319 | def from_labels_to_color(self, labels): 320 | return from_sem_to_color(apply_learning_map(labels, self.myhparams.learning_map_inv), self.myhparams.color_map) 321 | 322 | def get_feature_names(self): 323 | return ["red", "green", "blue", "x", "y", "z", "intensity"] 324 | 325 | def describe(self): 326 | print(self) 327 | 328 | for split in ["train", "val", "test"]: 329 | if hasattr(self, f"{split}_dataset") and hasattr(getattr(self, f"{split}_dataset"), "data"): 330 | print(f"{split} data\t", getattr( 331 | self, f"{split}_dataset").data) 332 | if hasattr(getattr(self, f"{split}_dataset").data, "point_y"): 333 | for c, n in zip(*np.unique(getattr(self, f"{split}_dataset").data.point_y.flatten().numpy(), return_counts=True)): 334 | print( 335 | f"class {self.myhparams.raw_class_names[self.myhparams.learning_map_inv[int(c)]]} ({c}) {(20 - len(self.myhparams.raw_class_names[self.myhparams.learning_map_inv[int(c)]]))*' '} \thas {n} \tpoints") 336 | 337 | if hasattr(self, "val_dataset"): 338 | lens = [item.pos.size(0) for item in self.val_dataset.items] 339 | plt.hist(lens) 340 | plt.title( 341 | f"size of val dataset items between {min(lens)} and {max(lens)}") 342 | plt.show() 343 | --------------------------------------------------------------------------------
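For quick reference outside the notebooks, here is a minimal sketch of the intended entry point, mirroring `notebooks/demo_3D.ipynb`; it assumes it is run from the repository root and that `data.data_dir` points at the extracted Zenodo archive (the path below is a placeholder):

```python
import hydra

# compose the "marina" scene config; it chains earthparserdataset.yaml -> lidar.yaml -> default.yaml
hydra.initialize(config_path="configs")
cfg = hydra.compose(overrides=["+data=marina"])
cfg.data.data_dir = "path/to/earthparserdataset/areas"  # placeholder, adjust to your download

datamodule = hydra.utils.instantiate(cfg.data)  # builds a LidarHDDataModule
datamodule.setup()
datamodule.describe()

# interactive plotly viewer with a dropdown to color by labels, position, RGB or intensity
datamodule.show(datamodule.train_dataset[0])
```

On first use, `setup()` triggers the PDAL-based preprocessing in `LidarHDSplit.process`, which caches the voxelized tiles under a `processed/` directory next to the raw data; subsequent runs load the cached tensors directly.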