├── earthparserdataset
│   ├── utils
│   │   ├── __init__.py
│   │   ├── color.py
│   │   └── labels.py
│   ├── __init__.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── grid_sampling.py
│   │   └── base.py
│   ├── shapenetsem.py
│   ├── base.py
│   └── earthparserdataset.py
├── media
│   └── earthparserdataset.png
├── configs
│   └── data
│       ├── lidar.yaml
│       ├── default.yaml
│       ├── urban.yaml
│       ├── crop_field.yaml
│       ├── forest.yaml
│       ├── greenhouse.yaml
│       ├── windturbine.yaml
│       ├── marina.yaml
│       ├── power_plant.yaml
│       ├── shapenet_planes.yaml
│       └── earthparserdataset.yaml
├── LICENSE
├── README.md
├── .gitignore
└── notebooks
    ├── demo_shapenet.ipynb
    ├── demo_2D.ipynb
    └── demo_3D.ipynb
/earthparserdataset/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/earthparserdataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/romainloiseau/EarthParserDataset/HEAD/media/earthparserdataset.png
--------------------------------------------------------------------------------
/earthparserdataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .earthparserdataset import LidarHDDataModule
2 | from .shapenetsem import ShapeNetSemDataModule
--------------------------------------------------------------------------------
/earthparserdataset/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ZCrop, MaxPoints, RandomRotate, RandomFlip, RandomScale, RandomCropSubTileVal, RandomCropSubTileTrain, CenterCrop
2 | from .grid_sampling import GridSampling
--------------------------------------------------------------------------------
/configs/data/lidar.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 64
2 | n_features: 8
3 | N_scene: 0
4 |
5 | items_per_epoch:
6 | train: 32768
7 | val: 128
8 |
9 | distance: xyzk # xyz coordinates + lidar intensity
10 |
11 | DEBUG: True
12 |
13 | defaults:
14 | - default
--------------------------------------------------------------------------------
/earthparserdataset/utils/color.py:
--------------------------------------------------------------------------------
1 | import kornia.color as col
2 |
3 | def rgb_to_lab(x):
4 | return col.rgb_to_lab(x.T.unsqueeze(-1)).squeeze().T / 127.
5 |
6 | def lab_to_rgb(x):
7 | return col.lab_to_rgb(127. * x.T.unsqueeze(-1)).squeeze().T
--------------------------------------------------------------------------------
/configs/data/default.yaml:
--------------------------------------------------------------------------------
1 | data_dir: "."
2 | batch_size: 64
3 | num_workers: 8
4 |
5 | n_features: 3
6 |
7 | min_z: 0
8 |
9 | distance: xyz
10 |
11 | ignore_index_0: False
12 |
13 | modality: 3D # 2D or 3D
14 |
15 | image:
16 | res: 32
17 | n_dim: 12
--------------------------------------------------------------------------------
/configs/data/urban.yaml:
--------------------------------------------------------------------------------
1 | name: "urban"
2 |
3 | class_names: ["Unlabeled", "Ground", "Vegetation", "Building"]
4 |
5 | learning_map:
6 | 1 : 0
7 | 2 : 1
8 | 3 : 2
9 | 4 : 2
10 | 5 : 2
11 | 6 : 3
12 | 9 : 0
13 | 17: 0
14 | 64: 0
15 | 65: 0
16 | 66: 0
17 | 160: 0
18 | 161: 0
19 | 162: 0
20 | learning_map_inv:
21 | 0 : 1
22 | 1 : 2
23 | 2 : 4
24 | 3 : 6
25 |
26 | defaults:
27 | - earthparserdataset
--------------------------------------------------------------------------------
/configs/data/crop_field.yaml:
--------------------------------------------------------------------------------
1 | name: "crop_field"
2 |
3 | max_xy: 25.6
4 |
5 | class_names: ["Unlabeled", "Ground", "Vegetation"]
6 |
7 | learning_map:
8 | 1 : 0
9 | 2 : 1
10 | 3 : 2
11 | 4 : 2
12 | 5 : 2
13 | 6 : 0
14 | 9 : 0
15 | 17: 0
16 | 64: 0
17 | 65: 0
18 | 66: 0
19 | 160: 0
20 | 161: 0
21 | 162: 0
22 | learning_map_inv:
23 | 0 : 1
24 | 1 : 2
25 | 2 : 4
26 |
27 | defaults:
28 | - earthparserdataset
--------------------------------------------------------------------------------
/configs/data/forest.yaml:
--------------------------------------------------------------------------------
1 | name: "forest"
2 |
3 | max_z: 34.8
4 | max_xy: 25.6
5 |
6 | class_names: ["Unlabeled", "Ground", "Vegetation"]
7 |
8 | learning_map:
9 | 1 : 0
10 | 2 : 1
11 | 3 : 2
12 | 4 : 2
13 | 5 : 2
14 | 6 : 0
15 | 9 : 0
16 | 17: 0
17 | 64: 0
18 | 65: 0
19 | 66: 0
20 | 160: 0
21 | 161: 0
22 | 162: 0
23 | learning_map_inv:
24 | 0 : 1
25 | 1 : 2
26 | 2 : 4
27 |
28 | defaults:
29 | - earthparserdataset
--------------------------------------------------------------------------------
/configs/data/greenhouse.yaml:
--------------------------------------------------------------------------------
1 | name: "greenhouse"
2 |
3 | max_xy: 38.4
4 | max_z: 25.6
5 |
6 | subtile_max_xy: -1
7 |
8 | class_names: ["Unlabeled", "Ground", "Vegetation", "Building"]
9 |
10 | learning_map:
11 | 1 : 0
12 | 2 : 1
13 | 3 : 2
14 | 4 : 2
15 | 5 : 2
16 | 6 : 3
17 | 9 : 0
18 | 17: 0
19 | 64: 0
20 | 65: 0
21 | 66: 0
22 | 160: 0
23 | 161: 0
24 | 162: 0
25 | learning_map_inv:
26 | 0 : 1
27 | 1 : 2
28 | 2 : 4
29 | 3 : 6
30 |
31 | defaults:
32 | - earthparserdataset
--------------------------------------------------------------------------------
/configs/data/windturbine.yaml:
--------------------------------------------------------------------------------
1 | name: "windturbine"
2 |
3 | pre_transform_grid_sample: 1.
4 | N_max: 10000
5 |
6 | subtile_max_xy: 1000.
7 |
8 | max_xy: 204.8
9 | max_z: 102.4
10 |
11 | class_names: ["Unlabeled"]
12 |
13 | ignore_index_0: False
14 |
15 | learning_map:
16 | 1 : 0
17 | 2 : 0
18 | 3 : 0
19 | 4 : 0
20 | 5 : 0
21 | 6 : 0
22 | 9 : 0
23 | 17: 0
24 | 64: 0
25 | 65: 0
26 | 66: 0
27 | 160: 0
28 | 161: 0
29 | 162: 0
30 | learning_map_inv:
31 | 0 : 1
32 |
33 | defaults:
34 | - earthparserdataset
--------------------------------------------------------------------------------
/configs/data/marina.yaml:
--------------------------------------------------------------------------------
1 | name: "marina"
2 |
3 | max_xy: 12.8
4 | pre_transform_grid_sample: 0.1
5 |
6 | N_max: 2000
7 |
8 | subtile_max_xy: -1
9 |
10 | class_names: ["Unlabeled", "Boats", "Bridge"]
11 |
12 | ignore_index_0: True
13 |
14 | learning_map:
15 | 1 : 1
16 | 2 : 0
17 | 3 : 0
18 | 4 : 0
19 | 5 : 0
20 | 6 : 0
21 | 9 : 0
22 | 17: 2
23 | 64: 0
24 | 65: 0
25 | 66: 0
26 | 160: 0
27 | 161: 0
28 | 162: 0
29 | learning_map_inv:
30 | 0 : 1
31 | 1 : 2
32 | 2 : 17
33 |
34 | defaults:
35 | - earthparserdataset
--------------------------------------------------------------------------------
/configs/data/power_plant.yaml:
--------------------------------------------------------------------------------
1 | name: "power_plant"
2 |
3 | max_z: 34.8
4 | subtile_max_xy: -1
5 |
6 | class_names: ["Unlabeled", "Ground", "Vegetation", "Building", "Lasting above"] #, "Pylon"]
7 |
8 | learning_map:
9 | 1 : 0
10 | 2 : 1
11 | 3 : 2
12 | 4 : 2
13 | 5 : 2
14 | 6 : 3
15 | 9 : 0
16 | 17: 0
17 | 64: 4
18 | 65: 0
19 | 66: 0
20 | 160: 0
21 | 161: 0
22 | 162: 4
23 | learning_map_inv:
24 | 0 : 1
25 | 1 : 2
26 | 2 : 4
27 | 3 : 6
28 | 4 : 64
29 | #5 : 162
30 |
31 | defaults:
32 | - earthparserdataset
--------------------------------------------------------------------------------
/configs/data/shapenet_planes.yaml:
--------------------------------------------------------------------------------
1 | name: snsem
2 |
3 | data_dir: "../../Datasets_PanSeg/ShapeNetSem/shapenetcore_partanno_segmentation_benchmark_v0_normal"
4 | _target_: earthparserdataset.ShapeNetSemDataModule
5 |
6 | input_dim: 3
7 |
8 | max_xy: 3.2
9 | max_z: 3.2
10 |
11 | n_max: 0
12 |
13 | rotate_z: False
14 |
15 | num_workers: 16
16 |
17 | classes: [Airplane] # Airplane, Bag, Cap, Car, Chair, Earphone, Guitar, Knife, Lamp, Laptop, Motorbike, Mug, Pistol, Rocket, Skateboard, Table
18 |
19 | class_names: [a, b, c, d] #["Airplane", "Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"]
20 |
21 | color_map: # rgb
22 | 0: [33, 158, 188]
23 | 1: [2, 48, 71]
24 | 2: [255, 183, 3]
25 | 3: [251, 133, 0]
26 |
27 | defaults:
28 | - default
--------------------------------------------------------------------------------
/earthparserdataset/utils/labels.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 |
5 | import numpy as np
6 |
7 | def apply_learning_map(segmentation, f):
8 | mapped = copy.deepcopy(segmentation)
9 | for k, v in f.items():
10 | mapped[segmentation == k] = v
11 | return mapped
12 |
13 | def from_sem_to_color(segmentation, f):
14 | color = torch.zeros(list(segmentation.size())+[3]).int()
15 | color[..., 0] = 255
16 | for k, v in f.items():
17 | color[segmentation == k] = torch.tensor(v).int()
18 | return color
19 |
20 | def from_inst_to_color(instance):
21 | max_inst_id = 100000
22 | inst_color = np.random.uniform(low=0.0,
23 | high=1.0,
24 | size=(max_inst_id, 3))
25 | inst_color[0] = np.full((3), 0.1)
26 | inst_color = torch.tensor(255*inst_color).int()
27 |
28 | color = torch.zeros(list(instance.size())+[3]).int()
29 | for k in torch.unique(instance.flatten()):
30 | color[instance == k] = inst_color[k]
31 | return color
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Romain Loiseau
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Earth Parser Dataset Toolbox
4 |
5 |
6 |
7 | This repository contains helper scripts to open, visualize, and process point clouds from the Earth Parser dataset.
8 |
9 | 
10 |
11 | ## Usage
12 |
13 | - **Download**
14 |
15 | The dataset can be downloaded from [Zenodo](https://zenodo.org/record/7820686).
16 |
17 | - **Data loading**
18 |
19 | This repository contains:
20 |
21 | ```markdown
22 | 📦EarthParserDataset
23 | ┣ 📂configs # Hydra config files
24 | ┣ 📂earthparserdataset # PyTorch Lightning datamodules
25 | ┣ 📂notebooks # illustrative notebooks
26 | ```
27 |
28 | - **Visualization and Usage**
29 |
30 | See our notebooks in `/notebooks` for examples of data manipulation and several visualization functions.
31 |
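32 | For instance, a minimal loading sketch following `notebooks/demo_3D.ipynb` (the `marina` scene and the relative paths are the demo's assumptions; any scene config from `configs/data` works the same way):
33 |
34 | ```python
35 | import os, sys
36 | sys.path.append("../")  # make the earthparserdataset package importable
37 |
38 | import hydra
39 |
40 | # compose the Hydra config for one scene and resolve the data directory
41 | hydra.initialize(config_path="../configs")
42 | cfg = hydra.compose(overrides=["+data=marina"])
43 | cfg.data.data_dir = os.path.join("../", cfg.data.data_dir)
44 |
45 | # instantiate the PyTorch Lightning datamodule and visualize one training crop
46 | datamodule = hydra.utils.instantiate(cfg.data)
47 | datamodule.setup()
48 | datamodule.show(datamodule.train_dataset[0])
49 | ```
50 |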
51 | ## Citation
52 |
53 | If you use this dataset and/or this API in your work, please cite our [paper](https://imagine.enpc.fr/~loiseaur/learnable-earth-parser).
54 |
55 | ```bibtex
56 | @misc{loiseau2023learnable,
57 |       title={Learnable Earth Parser: Discovering 3D Prototypes in Aerial Scans},
58 |       author={Romain Loiseau and Elliot Vincent and Mathieu Aubry and Loic Landrieu},
59 |       year={2023},
60 |       eprint={2304.09704},
61 |       archivePrefix={arXiv},
62 |       primaryClass={cs.CV}
63 | }
64 | ```
--------------------------------------------------------------------------------
/configs/data/earthparserdataset.yaml:
--------------------------------------------------------------------------------
1 | _target_: earthparserdataset.LidarHDDataModule
2 | data_dir: "../../Datasets_PanSeg/earthparserdataset/areas"
3 |
4 | N_scene: 0
5 |
6 | pre_transform_grid_sample: 0.3
7 | N_max: 10000
8 |
9 | max_xy: 25.6
10 | max_z: 25.6
11 | min_z: 0.
12 |
13 | subtile_max_xy: 250.
14 |
15 | random_jitter: .01
16 | random_scales:
17 | - 0.95
18 | - 1.05
19 |
20 | num_workers: 16
21 |
22 | input_dim: 3
23 | n_features: 7
24 |
25 | n_max: 0
26 |
27 | raw_class_names:
28 | 1 : "Unlabeled"
29 | 2 : "Ground"
30 | 3 : "Low vegetation"
31 | 4 : "Medium vegetation"
32 | 5 : "High vegetation"
33 | 6 : "Building"
34 | 9 : "Water"
35 | 17: "Bridge"
36 | 64: "Lasting above"
37 | 65: "Artifacts"
38 | 66: "Virtual Points"
39 | 160: "Aerial"
40 | 161: "Wind turbine"
41 | 162: "Pylon"
42 |
43 | class_names: ["Unlabeled", "Ground", "Low vegetation", "Medium vegetation", "High vegetation", "Building", "Water", "Bridge", "Lasting above", "Artifacts", "Virtual Points"]
44 |
45 | ignore_index_0: True
46 |
47 | color_map: # rgb
48 | 1 : [0, 0, 0]
49 | 2 : [128, 128, 128]
50 | 3 : [0, 255, 0]
51 | 4 : [0, 210, 0]
52 | 5 : [0, 175, 0]
53 | 6 : [0, 200, 255]
54 | 9 : [0, 0, 255]
55 | 17: [90, 58, 34]
56 | 64: [255, 125, 0]
57 | 65: [255, 0, 0]
58 | 66: [255, 0, 0]
59 | 160: [255, 0, 255]
60 | 161: [255, 0, 255]
61 | 162: [255, 0, 255]
62 | learning_map:
63 | 1 : 0
64 | 2 : 1
65 | 3 : 2
66 | 4 : 3
67 | 5 : 4
68 | 6 : 5
69 | 9 : 6
70 | 17: 7
71 | 64: 8
72 | 65: 9
73 | 66: 10
74 | 160: 11
75 | 161: 12
76 | 162: 13
77 | learning_map_inv:
78 | 0 : 1
79 | 1 : 2
80 | 2 : 3
81 | 3 : 4
82 | 4 : 5
83 | 5 : 6
84 | 6 : 9
85 | 7 : 17
86 | 8 : 64
87 | 9 : 65
88 | 10: 66
89 | 11: 160
90 | 12: 161
91 | 13: 162
92 |
93 | defaults:
94 | - lidar
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.html
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/notebooks/demo_shapenet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "scrolled": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%env HYDRA_FULL_ERROR=1\n",
12 | "\n",
13 | "%load_ext autoreload\n",
14 | "%autoreload 2\n",
15 | "import warnings\n",
16 | "warnings.filterwarnings(\"ignore\")\n",
17 | "\n",
18 | "from plotly.offline import init_notebook_mode\n",
19 | "init_notebook_mode(connected = True)\n",
20 | "\n",
21 | "import sys, os\n",
22 | "sys.path.append(\"../\")\n",
23 | "\n",
24 | "import hydra\n",
25 | "import pytorch_lightning as pl\n",
26 | "import logging\n",
27 | "\n",
28 | "pl.utilities.distributed.log.setLevel(logging.ERROR)\n",
29 | "\n",
30 | "hydra.initialize(config_path=\"../configs\")\n",
31 | "cfg = hydra.compose(overrides=[\"+data=shapenet_planes\"])\n",
32 | "cfg.data.data_dir = os.path.join(\"../\", cfg.data.data_dir)\n",
33 | "\n",
34 | "import numpy as np"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "scrolled": false
42 | },
43 | "outputs": [],
44 | "source": [
45 | "datamodule = hydra.utils.instantiate(cfg.data)\n",
46 | "datamodule.setup()\n",
47 | "\n",
48 | "datamodule"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {
55 | "scrolled": false
56 | },
57 | "outputs": [],
58 | "source": [
59 | "for _ in range(8):\n",
60 | " data = datamodule.train_dataset[np.random.randint(len(datamodule.train_dataset))]\n",
61 | " datamodule.show(data)"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": []
70 | }
71 | ],
72 | "metadata": {
73 | "interpreter": {
74 | "hash": "f66f361b60a0f669a83bf49483502faf91b64305c31842951638f8f87d6b5230"
75 | },
76 | "kernelspec": {
77 | "display_name": "Python 3 (ipykernel)",
78 | "language": "python",
79 | "name": "python3"
80 | },
81 | "language_info": {
82 | "codemirror_mode": {
83 | "name": "ipython",
84 | "version": 3
85 | },
86 | "file_extension": ".py",
87 | "mimetype": "text/x-python",
88 | "name": "python",
89 | "nbconvert_exporter": "python",
90 | "pygments_lexer": "ipython3",
91 | "version": "3.9.13"
92 | }
93 | },
94 | "nbformat": 4,
95 | "nbformat_minor": 2
96 | }
97 |
--------------------------------------------------------------------------------
/notebooks/demo_2D.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "scrolled": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%env HYDRA_FULL_ERROR=1\n",
12 | "\n",
13 | "%load_ext autoreload\n",
14 | "%autoreload 2\n",
15 | "import warnings\n",
16 | "warnings.filterwarnings(\"ignore\")\n",
17 | "\n",
18 | "from plotly.offline import init_notebook_mode\n",
19 | "init_notebook_mode(connected = True)\n",
20 | "\n",
21 | "import sys, os\n",
22 | "sys.path.append(\"../\")\n",
23 | "\n",
24 | "import hydra\n",
25 | "import pytorch_lightning as pl\n",
26 | "import logging\n",
27 | "\n",
28 | "pl.utilities.distributed.log.setLevel(logging.ERROR)\n",
29 | "\n",
30 | "hydra.initialize(config_path=\"../configs\")\n",
31 | "cfg = hydra.compose(overrides=[\"+data=power_plant\", \"data.modality=2D\"])\n",
32 | "cfg.data.data_dir = os.path.join(\"../\", cfg.data.data_dir)\n",
33 | "\n",
34 | "import numpy as np"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "scrolled": false
42 | },
43 | "outputs": [],
44 | "source": [
45 | "datamodule = hydra.utils.instantiate(cfg.data)\n",
46 | "datamodule.setup()\n",
47 | "\n",
48 | "datamodule.describe()"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {
55 | "scrolled": false
56 | },
57 | "outputs": [],
58 | "source": [
59 | "for _ in range(2):\n",
60 | " data = datamodule.train_dataset[np.random.randint(len(datamodule.train_dataset))]\n",
61 | " datamodule.show(data)"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": []
70 | }
71 | ],
72 | "metadata": {
73 | "interpreter": {
74 | "hash": "f66f361b60a0f669a83bf49483502faf91b64305c31842951638f8f87d6b5230"
75 | },
76 | "kernelspec": {
77 | "display_name": "Python 3 (ipykernel)",
78 | "language": "python",
79 | "name": "python3"
80 | },
81 | "language_info": {
82 | "codemirror_mode": {
83 | "name": "ipython",
84 | "version": 3
85 | },
86 | "file_extension": ".py",
87 | "mimetype": "text/x-python",
88 | "name": "python",
89 | "nbconvert_exporter": "python",
90 | "pygments_lexer": "ipython3",
91 | "version": "3.9.13"
92 | }
93 | },
94 | "nbformat": 4,
95 | "nbformat_minor": 2
96 | }
97 |
--------------------------------------------------------------------------------
/earthparserdataset/transforms/grid_sampling.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/grid_sampling.html
3 | """
4 |
5 | import re
6 | from typing import List, Optional, Union
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import Tensor
11 | from torch_scatter import scatter_add, scatter_mean, scatter_max
12 |
13 | import torch_geometric
14 | from torch_geometric.data import Data
15 | from torch_geometric.transforms import BaseTransform
16 |
17 |
18 | class GridSampling(BaseTransform):
19 | r"""Clusters points into voxels with size :attr:`size`.
20 | Each cluster returned is a new point based on the mean of all points
21 | inside the given cluster.
22 |
23 | Args:
24 | size (float or [float] or Tensor): Size of a voxel (in each dimension).
25 | start (float or [float] or Tensor, optional): Start coordinates of the
26 | grid (in each dimension). If set to :obj:`None`, will be set to the
27 | minimum coordinates found in :obj:`data.pos`.
28 | (default: :obj:`None`)
29 | end (float or [float] or Tensor, optional): End coordinates of the grid
30 | (in each dimension). If set to :obj:`None`, will be set to the
31 | maximum coordinates found in :obj:`data.pos`.
32 | (default: :obj:`None`)
33 | """
34 | def __init__(self, size: Union[float, List[float], Tensor],
35 | start: Optional[Union[float, List[float], Tensor]] = None,
36 | end: Optional[Union[float, List[float], Tensor]] = None):
37 | self.size = size
38 | self.start = start
39 | self.end = end
40 |
41 | def __call__(self, data: Data) -> Data:
42 | num_nodes = data.num_nodes
43 |
44 | batch = data.get('batch', None)
45 |
46 | c = torch_geometric.nn.voxel_grid(data.pos, self.size, batch,
47 | self.start, self.end)
48 | c, perm = torch_geometric.nn.pool.consecutive.consecutive_cluster(c)
49 |
50 | for key, item in data:
51 | if bool(re.search('edge', key)):
52 | raise ValueError(
53 | 'GridSampling does not support coarsening of edges')
54 |
55 | if torch.is_tensor(item) and item.size(0) == num_nodes:
56 | if key in ['y', 'point_y', 'label', 'point_inst']:
57 | item = F.one_hot(item)
58 | item = scatter_add(item, c, dim=0)
59 | data[key] = item.argmax(dim=-1)
60 | elif key == 'batch':
61 | data[key] = item[perm]
62 | elif data[key].dtype == torch.uint8:
63 | data[key] = (255. * scatter_mean(item / 255., c, dim=0)).to(torch.uint8)
64 | else:
65 | data[key] = scatter_mean(item, c, dim=0)
66 |
67 | return data
68 |
69 | def __repr__(self) -> str:
70 | return f'{self.__class__.__name__}(size={self.size})'
71 |
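72 |
73 | if __name__ == "__main__":
74 |     # Minimal usage sketch on a synthetic cloud (sizes are illustrative, not
75 |     # taken from the dataset): pooling a 10 m cube into 0.5 m voxels keeps one
76 |     # averaged point per occupied voxel, i.e. at most 20**3 points.
77 |     data = Data(pos=10. * torch.rand(100000, 3))
78 |     n_before = data.num_nodes
79 |     data = GridSampling(0.5)(data)
80 |     print(n_before, "->", data.num_nodes)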
--------------------------------------------------------------------------------
/notebooks/demo_3D.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "scrolled": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%env HYDRA_FULL_ERROR=1\n",
12 | "\n",
13 | "%load_ext autoreload\n",
14 | "%autoreload 2\n",
15 | "import warnings\n",
16 | "warnings.filterwarnings(\"ignore\")\n",
17 | "\n",
18 | "from plotly.offline import init_notebook_mode\n",
19 | "init_notebook_mode(connected = True)\n",
20 | "import copy\n",
21 | "import sys, os\n",
22 | "sys.path.append(\"../\")\n",
23 | "\n",
24 | "import hydra\n",
25 | "import pytorch_lightning as pl\n",
26 | "import logging\n",
27 | "\n",
28 | "pl.utilities.distributed.log.setLevel(logging.ERROR)\n",
29 | "\n",
30 | "hydra.initialize(config_path=\"../configs\")\n",
31 | "cfg = hydra.compose(overrides=[\"+data=marina\"])\n",
32 | "cfg.data.data_dir = os.path.join(\"../\", cfg.data.data_dir)\n",
33 | "\n",
34 | "import numpy as np"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "scrolled": false
42 | },
43 | "outputs": [],
44 | "source": [
45 | "datamodule = hydra.utils.instantiate(cfg.data)\n",
46 | "datamodule.setup()\n",
47 | "\n",
48 | "datamodule.describe()"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "print(len(datamodule.train_dataset.slices[\"pos\"]) if datamodule.train_dataset.slices is not None else \"NO SLICES\")\n",
58 | "idx = 0\n",
59 | "\n",
60 | "data_to_print = copy.deepcopy(datamodule.train_dataset.data)\n",
61 | " \n",
62 | "if datamodule.train_dataset.slices is not None:\n",
63 | " for k in datamodule.train_dataset.slices.keys():\n",
64 | " data_to_print[k] = data_to_print[k][datamodule.train_dataset.slices[k][idx]:datamodule.train_dataset.slices[k][idx + 1]]\n",
65 | " \n",
66 | "datamodule.show(data_to_print, voxelize=2.*data_to_print.size(0) / 1000000)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {
73 | "scrolled": false
74 | },
75 | "outputs": [],
76 | "source": [
77 | "for _ in range(4):\n",
78 | " data = datamodule.train_dataset[np.random.randint(len(datamodule.train_dataset))]\n",
79 | " datamodule.show(data)"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": []
88 | }
89 | ],
90 | "metadata": {
91 | "interpreter": {
92 | "hash": "f66f361b60a0f669a83bf49483502faf91b64305c31842951638f8f87d6b5230"
93 | },
94 | "kernelspec": {
95 | "display_name": "Python 3 (ipykernel)",
96 | "language": "python",
97 | "name": "python3"
98 | },
99 | "language_info": {
100 | "codemirror_mode": {
101 | "name": "ipython",
102 | "version": 3
103 | },
104 | "file_extension": ".py",
105 | "mimetype": "text/x-python",
106 | "name": "python",
107 | "nbconvert_exporter": "python",
108 | "pygments_lexer": "ipython3",
109 | "version": "3.9.13"
110 | }
111 | },
112 | "nbformat": 4,
113 | "nbformat_minor": 2
114 | }
115 |
--------------------------------------------------------------------------------
/earthparserdataset/shapenetsem.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os.path as osp
3 |
4 | from torch_geometric.io import read_txt_array
5 | import torch
6 | import torch.nn.functional as F
7 | import torch_geometric.transforms as T
8 | from torch_geometric.data import Data, InMemoryDataset
9 | from .utils.labels import from_sem_to_color
10 | from .base import BaseDataModule
11 | from tqdm.auto import tqdm
12 | import json
13 |
14 | class MeanRemove(T.BaseTransform):
15 | def __init__(self) -> None:
16 | super().__init__()
17 |
18 | def __call__(self, data):
19 | data.pos -= data.pos.mean(0)
20 | return data
21 |
22 | class ShapeNetSemDataset(InMemoryDataset):
23 | CAT2ID = {
24 | "Airplane" : "02691156",
25 | "Bag" : "02773838",
26 | "Cap" : "02954340",
27 | "Car" : "02958343",
28 | "Chair" : "03001627",
29 | "Earphone" : "03261776",
30 | "Guitar" : "03467517",
31 | "Knife" : "03624134",
32 | "Lamp" : "03636649",
33 | "Laptop" : "03642806",
34 | "Motorbike" : "03790512",
35 | "Mug" : "03797390",
36 | "Pistol" : "03948459",
37 | "Rocket" : "04099429",
38 | "Skateboard" : "04225987",
39 | "Table" : "04379243",
40 | }
41 |
42 | def __init__(self, options, mode):
43 | self.options = copy.deepcopy(options)
44 | self.mode = mode
45 |
46 | transform = [
47 | MeanRemove()
48 | ]
49 |
50 | if mode == "train":
51 | transform += [
52 | T.RandomTranslate(0.01),
53 | T.RandomFlip(axis=1),
54 | ]
55 |
56 | if self.options.rotate_z:
57 | transform.append(T.RandomRotate(180, axis=2))
58 |
59 | super().__init__(options.data_dir, T.Compose(transform), None, None)
60 | self.data, self.slices = torch.load(self.processed_paths[0])
61 |
62 | self.data.point_y -= self.data.point_y.min()
63 |
64 | self.N_point_max = int((self.slices["pos"][1:] - self.slices["pos"][:-1]).max().item())
65 |
66 | @property
67 | def raw_file_names(self):
68 |
69 | files = []
70 | modes = ["train", "val", "test"] if self.mode in ["train", "test"] else ["val"]
71 | for mode in modes:
72 | with open(osp.join(self.root, 'train_test_split', f'shuffled_{mode}_file_list.json'), 'r') as f:
73 | files += json.load(f)
74 |
75 | select_ids = [self.CAT2ID[cat] for cat in self.options.classes]
76 | file_names = []
77 | for file in files:
78 | if file.split('/')[1] in select_ids:
79 | file_names.append(f"{osp.join(*file.split('/')[1:])}.txt")
80 |
81 | return file_names
82 |
83 | @property
84 | def processed_file_names(self):
85 | return [f"{'_'.join(self.options.classes)}.pt"]
86 |
87 | @property
88 | def processed_dir(self) -> str:
89 | return osp.join(self.root, 'processed', self.mode)
90 |
91 | def process(self):
92 | # Read data into huge `Data` list.
93 | data_list = []
94 | for file in tqdm(self.raw_file_names):
95 | data = read_txt_array(osp.join(self.root, file))
96 |
97 |             data_list.append(Data(pos=data[:, :3], point_y=data[:, -1].long()))  # labels are the last column
98 |
99 | if self.pre_filter is not None:
100 | data_list = [data for data in data_list if self.pre_filter(data)]
101 |
102 | if self.pre_transform is not None:
103 | data_list = [self.pre_transform(data) for data in data_list]
104 |
105 | data, slices = self.collate(data_list)
106 |         data.pos = data.pos[:, [0, 2, 1]]  # swap y and z so the vertical axis comes last
107 | torch.save((data, slices), self.processed_paths[0])
108 |
109 | def len(self):
110 | return len(self.raw_file_names)
111 |
112 | def __getitem__(self, idx):
113 | data = super().__getitem__(idx)
114 | data.features = data.pos.clone()
115 |
116 | data.pos *= .5 * self.options.max_xy / ((data.pos**2).sum(-1)**.5).max()
117 | data.pos += .5 * self.options.max_xy
118 |
119 | data.pos_lenght = torch.tensor(data.pos.size(0))
120 |
121 | data.pos_padded = F.pad(data.pos, (0, 0, 0, self.N_point_max - data.pos.shape[0])).unsqueeze(0)
122 |
123 | return data
124 |
125 | class ShapeNetSemDataModule(BaseDataModule):
126 | _DatasetSplit_ = {
127 | "train": ShapeNetSemDataset,
128 | "val": ShapeNetSemDataset,
129 | "test": ShapeNetSemDataset
130 | }
131 |
132 | def get_feature_names(self):
133 | return ["x", "y", "z"]
134 |
135 | def from_labels_to_color(self, labels):
136 | return from_sem_to_color(
137 | labels,
138 | self.myhparams.color_map
139 | )
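140 |
141 |
142 | if __name__ == "__main__":
143 |     # Usage sketch (run with `python -m earthparserdataset.shapenetsem` because of
144 |     # the relative imports); assumes the ShapeNet part-segmentation benchmark is
145 |     # available at the path from configs/data/shapenet_planes.yaml.
146 |     dm = ShapeNetSemDataModule(
147 |         data_dir="../../Datasets_PanSeg/ShapeNetSem/shapenetcore_partanno_segmentation_benchmark_v0_normal",
148 |         batch_size=4, num_workers=0, classes=["Airplane"], rotate_z=False, max_xy=3.2,
149 |     )
150 |     dm.setup("fit")
151 |     print(dm)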
--------------------------------------------------------------------------------
/earthparserdataset/transforms/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch_geometric.transforms as T
8 | from torch_geometric.transforms import BaseTransform
9 |
10 |
11 | class RandomCropSubTileTrain(T.BaseTransform):
12 | def __init__(self, max_xy) -> None:
13 | self.max_xy = max_xy
14 | super().__init__()
15 |
16 | def __call__(self, data):
17 |         xy = data.pos[np.random.randint(data.pos.size(0)), :2]  # candidate crop center
18 |         while (xy > data.pos_maxi - self.max_xy / 2.).any() or (xy < data.pos[:, :2].min(0)[0] + self.max_xy / 2.).any():
19 |             xy = data.pos[np.random.randint(data.pos.size(0)), :2]  # redraw until the window fits in the tile
20 |
21 |         keep = np.abs(data.pos[:, :2] - xy).max(-1)[0] <= self.max_xy / 2.  # points inside the square window
22 |
23 | del data.pos_maxi
24 |
25 | for k in data.keys:
26 | setattr(data, k, getattr(data, k)[keep])
27 |
28 | data.pos[..., :2] -= xy
29 | return data
30 |
31 |
32 | class CenterCrop(T.BaseTransform):
33 | def __init__(self, max_xy) -> None:
34 | self.max_xy = max_xy
35 | super().__init__()
36 |
37 | def __call__(self, data):
38 |
39 | keep = np.abs(data.pos[:, :2]).max(-1)[0] <= self.max_xy / 2.
40 |
41 | for k in data.keys:
42 | setattr(data, k, getattr(data, k)[keep])
43 |
44 | data.pos[..., :2] += self.max_xy / 2.
45 | return data
46 |
47 |
48 | class RandomCropSubTileVal(T.BaseTransform):
49 | def __init__(self, max_xy) -> None:
50 | self.max_xy = max_xy
51 | super().__init__()
52 |
53 | def __call__(self, data):
54 | xy = data.pos[np.random.randint(data.pos.size(0)), :2]
55 | while (xy > data.pos_maxi - self.max_xy / 2.).any() or (xy < data.pos[:, :2].min(0)[0] + self.max_xy / 2.).any():
56 | xy = data.pos[np.random.randint(data.pos.size(0)), :2]
57 |
58 | keep = np.abs(data.pos[:, :2] - xy).max(-1)[0] <= self.max_xy / 2.
59 |
60 | del data.pos_maxi
61 |
62 | for k in data.keys:
63 | setattr(data, k, getattr(data, k)[keep])
64 |
65 | data.pos[..., :2] -= xy - self.max_xy / 2.
66 | return data
67 |
68 |
69 | class ZCrop(T.BaseTransform):
70 | def __init__(self, max_z) -> None:
71 | self.max_z = max_z
72 | super().__init__()
73 |
74 | def __call__(self, data):
75 | data.pos[:, -1] -= data.pos[:, -1].min()
76 | keep = data.pos[:, -1] < self.max_z
77 |
78 | for k in data.keys:
79 | setattr(data, k, getattr(data, k)[keep])
80 |
81 | return data
82 |
83 |
84 | class MaxPoints(T.BaseTransform):
85 | def __init__(self, max_points) -> None:
86 | self.max_points = max_points
87 | super().__init__()
88 |
89 | def __call__(self, data):
90 |
91 | if data.pos.shape[0] > self.max_points:
92 |             keep = np.random.choice(
93 |                 data.pos.shape[0], self.max_points, replace=False)  # subsample without duplicate points
94 | for k in data.keys:
95 | setattr(data, k, getattr(data, k)[keep])
96 | return data
97 |
98 |
99 | class RandomRotate(T.BaseTransform):
100 | r"""Rotates node positions around a specific axis by a randomly sampled
101 | factor within a given interval.
102 |
103 | Args:
104 | degrees (tuple or float): Rotation interval from which the rotation
105 | angle is sampled. If :obj:`degrees` is a number instead of a
106 | tuple, the interval is given by :math:`[-\mathrm{degrees},
107 | \mathrm{degrees}]`.
108 | axis (int, optional): The rotation axis. (default: :obj:`0`)
109 | """
110 |
111 | def __init__(self, degrees, axis=0):
112 | if isinstance(degrees, numbers.Number):
113 | degrees = (-abs(degrees), abs(degrees))
114 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
115 | self.degrees = degrees
116 | self.axis = axis
117 |
118 | def __call__(self, data):
119 | degree = math.pi * random.uniform(*self.degrees) / 180.0
120 | sin, cos = math.sin(degree), math.cos(degree)
121 |
122 | if data.pos.size(-1) == 2:
123 | matrix = [[cos, sin], [-sin, cos]]
124 | else:
125 | if self.axis == 0:
126 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
127 | elif self.axis == 1:
128 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
129 | else:
130 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]
131 |
132 | matrix = torch.tensor(matrix).to(data.pos.device, data.pos.dtype).t()
133 |
134 | data.pos = data.pos @ matrix
135 |
136 | return data
137 |
138 | def __repr__(self) -> str:
139 | return (f'{self.__class__.__name__}({self.degrees}, '
140 | f'axis={self.axis})')
141 |
142 |
143 | class RandomScale(BaseTransform):
144 | r"""Scales node positions by a randomly sampled factor :math:`s` within a
145 | given interval, *e.g.*, resulting in the transformation matrix
146 |
147 | .. math::
148 | \begin{bmatrix}
149 | s & 0 & 0 \\
150 | 0 & s & 0 \\
151 | 0 & 0 & s \\
152 | \end{bmatrix}
153 |
154 | for three-dimensional positions.
155 |
156 | Args:
157 | scales (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then scale
158 | is randomly sampled from the range
159 | :math:`a \leq \mathrm{scale} \leq b`.
160 | """
161 |
162 | def __init__(self, scales):
163 | assert isinstance(scales, (tuple, list)) and len(scales) == 2
164 | self.scales = scales
165 |
166 | def __call__(self, data):
167 | scale = random.uniform(*self.scales)
168 | data.pos = data.pos * scale
169 | data.pos_maxi = data.pos_maxi * scale
170 |
171 | return data
172 |
173 | def __repr__(self) -> str:
174 | return f'{self.__class__.__name__}({self.scales})'
175 |
176 |
177 | class RandomFlip(BaseTransform):
178 | """Flips node positions along a given axis randomly with a given
179 | probability.
180 |
181 | Args:
182 | axis (int): The axis along the position of nodes being flipped.
183 | p (float, optional): Probability that node positions will be flipped.
184 | (default: :obj:`0.5`)
185 | """
186 |
187 | def __init__(self, axis, max_xy, p=0.5):
188 | self.axis = axis
189 | self.p = p
190 |         self.max_xy = max_xy  # currently unused
191 |
192 | def __call__(self, data):
193 | if random.random() < self.p:
194 | data.pos[..., self.axis] = - data.pos[..., self.axis]
195 |
196 | return data
197 |
198 | def __repr__(self) -> str:
199 | return f'{self.__class__.__name__}(axis={self.axis}, p={self.p})'
200 |
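201 |
202 | if __name__ == "__main__":
203 |     # Minimal usage sketch on a synthetic tile (window sizes are illustrative):
204 |     # center-crop a 25.6 m square window, keep at most 25.6 m of height, then
205 |     # cap the cloud at 2000 points.
206 |     from torch_geometric.data import Data
207 |     data = Data(pos=50. * torch.rand(10000, 3) - 25.)
208 |     crop = T.Compose([CenterCrop(25.6), ZCrop(25.6), MaxPoints(2000)])
209 |     data = crop(data)
210 |     print(data.pos.shape, data.pos.min(0)[0], data.pos.max(0)[0])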
--------------------------------------------------------------------------------
/earthparserdataset/base.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 |
3 | import copy
4 | import torch
5 | import plotly.graph_objects as go
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | from torch_geometric.loader import DataLoader
9 | from matplotlib import cm
10 | import torch_scatter
11 | import torch.nn.functional as F
12 | import matplotlib.pyplot as plt
13 |
14 | from hydra.utils import to_absolute_path
15 |
16 |
17 | class BaseDataModule(pl.LightningDataModule):
18 | FEATURE2NAME = {"y": "Ground truth", "y_pred": "Prediction", "inst": "Ground truth instance",
19 | "inst_pred": "Predicted instance", "i": "Intensity", "rgb": "RGB", "infrared": "Infrared", "xyz": "Position"}
20 |
21 | def __init__(self, *args, **kwargs):
22 | super().__init__()
23 |
24 | self.myhparams = SimpleNamespace(**kwargs)
25 | self.myhparams.data_dir = to_absolute_path(self.myhparams.data_dir)
26 |
27 | def __repr__(self) -> str:
28 |
29 | out = f"DataModule:\t{self.__class__.__name__}"
30 |
31 | for split in ["train", "val", "test"]:
32 | if hasattr(self, f"{split}_dataset"):
33 | out += f"\n\t{getattr(self, f'{split}_dataset')}"
34 |
35 | return out
36 |
37 | def setup(self, stage: str = None):
38 | if stage in (None, 'fit'):
39 | self.train_dataset = self._DatasetSplit_["train"](self.myhparams, "train")
40 | self.val_dataset = self._DatasetSplit_["val"](self.myhparams, "val")
41 | elif stage in (None, 'validate'):
42 | self.val_dataset = self._DatasetSplit_["val"](self.myhparams, "val")
43 | elif stage in (None, 'train'):
44 | self.train_dataset = self._DatasetSplit_["train"](self.myhparams, "train")
45 |
46 | if stage in (None, 'test'):
47 | self.train_dataset = self._DatasetSplit_["train"](self.myhparams, "train")
48 | self.test_dataset = self._DatasetSplit_["test"](self.myhparams, "test")
49 |
50 | def train_dataloader(self):
51 | return DataLoader(
52 | self.train_dataset,
53 | batch_size=self.myhparams.batch_size,
54 | shuffle=True,
55 | num_workers=self.myhparams.num_workers,
56 | persistent_workers=True
57 | )
58 |
59 | def val_dataloader(self):
60 | return DataLoader(
61 | self.val_dataset,
62 | batch_size=1,
63 | shuffle=False,
64 | num_workers=self.myhparams.num_workers,
65 | persistent_workers=True
66 | )
67 |
68 | def test_dataloader(self):
69 | return DataLoader(
70 | self.test_dataset,
71 | batch_size=1,
72 | shuffle=False,
73 | num_workers=self.myhparams.num_workers
74 | )
75 |
76 | def get_label_from_raw_feature(self, c):
77 | return self.FEATURE2NAME[c]
78 |
79 | def get_color_from_item(self, item, c):
80 | if c == "y" and hasattr(item, "point_y"):
81 | color = self.from_labels_to_color(item.point_y.squeeze()).numpy()
82 | elif c == "y_pred" and hasattr(item, "point_y_pred"):
83 | try:
84 | color = self.from_labels_to_color(
85 | item.point_y_pred.squeeze()).numpy()
86 |             except Exception:
87 | color = cm.get_cmap("tab20")(item.point_y_pred.squeeze(
88 | ).cpu().numpy() / item.point_y_pred.max().item())[:, :-1]
89 | color = (255*color).astype(np.uint8)
90 | elif c in ["inst_pred", "inst"] and hasattr(item, f"point_{c}"):
91 | color = cm.get_cmap("tab20")(np.random.permutation(getattr(item, f"point_{c}").max().item(
92 | )+1)[getattr(item, f"point_{c}").squeeze().cpu().numpy()] / getattr(item, f"point_{c}").max().item())[:, :-1]
93 | color = (255*color).astype(np.uint8)
94 | elif c == "i" and hasattr(item, "intensity"):
95 | color = item.intensity.squeeze().numpy()
96 | color = 0.01 + 0.98*(color - color.min()) / \
97 | (color.max() - color.min())
98 | color = cm.get_cmap("viridis")(color)[:, :-1]
99 | color = (255*color).astype(np.uint8)
100 | elif c == "rgb" and hasattr(item, "rgb"):
101 | color = 0.01 + 0.98*item.rgb / 255.
102 | else:
103 | color = item.pos.squeeze().numpy()
104 | color = 0.01 + 0.98*(color - color.min()) / \
105 | (color.max() - color.min())
106 | color = (255*color).astype(np.uint8)
107 | return color
108 |
109 | def show(self, item, voxelize=0, color="y;xyz;rgb;i", ps=5):
110 | if self.myhparams.modality == "3D":
111 | self.show3D(item, voxelize, color, ps)
112 | elif self.myhparams.modality == "2D":
113 | self.show2D(item)
114 | else:
115 |             raise NotImplementedError(f"Modality {self.myhparams.modality} not implemented")
116 |
117 | def show2D(self, item):
118 | fig, ax = plt.subplots(1, 4, figsize=(20, 5))
119 |
120 | ax[0].imshow(item[..., 1] / 255.)
121 | ax[0].set_title("z max projection")
122 | ax[0].set_aspect("equal")
123 | ax[0].axis("off")
124 |
125 | ax[1].imshow(item[..., 3] / 255.)
126 | ax[1].set_title("intensity max projection")
127 | ax[1].set_aspect("equal")
128 | ax[1].axis("off")
129 |
130 | ax[2].imshow(item[..., [4,5,6]] / 255.)
131 | ax[2].set_title("rgb")
132 | ax[2].set_aspect("equal")
133 | ax[2].axis("off")
134 |
135 | ax[3].imshow(self.from_labels_to_color(item[..., -1]))
136 | ax[3].set_title("labels")
137 | ax[3].set_aspect("equal")
138 | ax[3].axis("off")
139 |
140 | plt.show()
141 |
142 | def show3D(self, item, voxelize=0, color="y;xyz;rgb;i", ps=5):
143 |
144 | thisitem = copy.deepcopy(item)
145 |
146 | if voxelize:
147 | choice = torch.unique(
148 | (thisitem.pos / voxelize).int(), return_inverse=True, dim=0)[1]
149 | for attr in ["point_y", "point_y_pred", "point_inst", "point_inst_pred"]:
150 | if hasattr(thisitem, attr):
151 | setattr(thisitem, attr, torch_scatter.scatter_sum(F.one_hot(
152 | getattr(thisitem, attr).squeeze().long()), choice, 0).argmax(-1).unsqueeze(0))
153 | for attr in ["pos", "features"]:
154 | if hasattr(thisitem, attr):
155 | setattr(thisitem, attr, torch_scatter.scatter_mean(
156 | getattr(thisitem, attr), choice, 0))
157 | for attr in ["intensity", "rgbi", "rgb"]:
158 | if hasattr(thisitem, attr):
159 | setattr(thisitem, attr, torch_scatter.scatter_max(
160 | getattr(thisitem, attr), choice, 0)[0])
161 |
162 | datadtype = torch.float16
163 | margin = int(0.02 * 600)
164 | layout = go.Layout(
165 | width=1000,
166 | height=600,
167 | margin=dict(l=margin, r=margin, b=margin, t=4*margin),
168 | uirevision=True,
169 | showlegend=False
170 | )
171 | fig = go.Figure(
172 | layout=layout,
173 | data=go.Scatter3d(
174 | x=thisitem.pos[:, 0].to(datadtype), y=thisitem.pos[:, 1].to(datadtype), z=thisitem.pos[:, 2].to(datadtype),
175 | mode='markers',
176 | marker=dict(size=ps, color=self.get_color_from_item(
177 | thisitem, color.split(";")[0])),
178 | )
179 | )
180 | updatemenus = [
181 | dict(
182 | buttons=list([
183 | dict(
184 | args=[{"marker": dict(size=ps, color=self.get_color_from_item(thisitem, c))}, [
185 | 0]],
186 | label=self.get_label_from_raw_feature(c),
187 | method="restyle"
188 | ) for c in color.split(";")
189 | ]),
190 | direction="down",
191 | pad={"r": 10, "t": 10},
192 | showactive=True,
193 | x=0.05,
194 | xanchor="left",
195 | y=0.88,
196 | yanchor="top"
197 | ),
198 | ]
199 |
200 | fig.update_layout(updatemenus=updatemenus)
201 | fig.update_layout(
202 | scene_aspectmode='data',
203 | )
204 |
205 | fig.show()
206 |
207 | del thisitem
--------------------------------------------------------------------------------
/earthparserdataset/earthparserdataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import os.path as osp
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import torch_geometric.transforms as T
10 | import torch_scatter
11 | from numpy.lib.recfunctions import structured_to_unstructured
12 | from torch_geometric.data import Data, InMemoryDataset
13 | from tqdm.auto import tqdm
14 |
15 | from .base import BaseDataModule
16 | from .transforms import (CenterCrop, GridSampling, MaxPoints,
17 | RandomCropSubTileTrain, RandomCropSubTileVal,
18 | RandomFlip, RandomRotate, RandomScale, ZCrop)
19 | from .utils import color as color
20 | from .utils.labels import apply_learning_map, from_sem_to_color
21 |
22 |
23 | class LidarHDSplit(InMemoryDataset):
24 | def __init__(self, options, mode):
25 | self.options = copy.deepcopy(options)
26 | self.options.data_dir = osp.join(self.options.data_dir, self.options.name)
27 | self.mode = mode
28 |
29 | self.feature_normalizer = torch.tensor(
30 | [[self.options.max_xy, self.options.max_xy, self.options.max_z]])
31 |
32 | super().__init__(
33 | self.options.data_dir,
34 | transform=self.get_transform(),
35 | pre_transform=self.get_pre_transform(),
36 | pre_filter=None
37 | )
38 |
39 | self.data, self.slices = torch.load(self.processed_paths[0])
40 | self.data.point_y = apply_learning_map(
41 | self.data.point_y, self.options.learning_map)
42 |
43 | if mode in ["train", "val"]:
44 | self.items_per_epoch = int(self.options.items_per_epoch[mode] / (len(self.slices["pos"]) - 1 if self.slices is not None else 1))
45 | else:
46 | self.prepare_test_dataset()
47 |
48 | if mode == "val":
49 | self.load_all_items()
50 |
51 | def prepare_test_dataset(self):
52 | self.items_per_epoch = []
53 | self.tiles_unique_selection = {}
54 | self.tiles_min_z = {}
55 | self.from_idx_to_tile = {}
56 | idx = 0
57 | for i in range(self.__superlen__()):
58 | unique_i, inverse_i = torch.unique((self.__getsuperitem__(
59 | i).pos[:, :2] / self.options.max_xy).int(), dim=0, return_inverse=True)
60 | self.items_per_epoch.append(unique_i.shape[0])
61 | self.tiles_unique_selection[i] = unique_i
62 | self.tiles_min_z[i] = torch_scatter.scatter_min(
63 | self.__getsuperitem__(i).pos[:, 2], inverse_i, dim=0)[0]
64 | for j in range(unique_i.shape[0]):
65 | self.from_idx_to_tile[idx] = (i, unique_i[j])
66 | idx += 1
67 |
68 | def load_all_items(self):
69 | self.items = [super(LidarHDSplit, self).__getitem__(int(i / self.items_per_epoch)) for i in range(len(self))]
70 | del self.data
71 |
72 | def get_pre_transform(self):
73 | pre_transform = []
74 |
75 | if self.mode == "test":
76 | pre_transform.append(
77 | GridSampling(
78 |                     # test tiles use a 4x finer grid, except for the windturbine scene
79 |                     self.options.pre_transform_grid_sample if "windturbine" in self.options.name else self.options.pre_transform_grid_sample / 4.
80 | )
81 | )
82 | else:
83 | pre_transform.append(GridSampling(
84 | self.options.pre_transform_grid_sample))
85 | pre_transform = T.Compose(pre_transform)
86 | return pre_transform
87 |
88 | def get_transform(self):
89 | if self.mode == "train":
90 | transform = [
91 | RandomScale(tuple(self.options.random_scales)),
92 | RandomCropSubTileTrain(2.**.5 * self.options.max_xy),
93 | RandomRotate(degrees=180., axis=2),
94 | RandomFlip(0, self.options.max_xy),
95 | CenterCrop(self.options.max_xy),
96 | ZCrop(self.options.max_z),
97 | T.RandomTranslate(self.options.random_jitter)
98 | ]
99 | elif self.mode == "val":
100 | transform = [
101 | RandomCropSubTileVal(self.options.max_xy),
102 | ZCrop(self.options.max_z)
103 | ]
104 | elif self.mode == "test":
105 | transform = []
106 | else:
107 | raise NotImplementedError(
108 | f"Mode {self.mode} not implemented. Should be in ['train', 'val', 'test']")
109 |
110 | if self.mode in ["train", "val"]:
111 | if self.options.N_scene != 0:
112 | transform.append(T.FixedPoints(self.options.N_scene))
113 | else:
114 | transform.append(MaxPoints(self.options.N_max))
115 |
116 | transform = T.Compose(transform)
117 | return transform
118 |
119 | def __repr__(self) -> str:
120 | return f"{self.__class__.__name__}({self.mode})"
121 |
122 | @property
123 | def raw_file_names(self):
124 | laz_dirs = []
125 | for tile in os.listdir(osp.join(self.options.data_dir, "tiles")):
126 | tile_dir = osp.join(self.options.data_dir, "tiles", tile)
127 | if osp.isdir(tile_dir):
128 | for subtile in os.listdir(tile_dir):
129 | if subtile.split(".")[-1] in ["las", "laz"]:
130 | laz_dirs.append(osp.join(tile_dir, subtile))
131 | if tile_dir.split(".")[-1] in ["las", "laz"]:
132 | laz_dirs.append(tile_dir)
133 |
134 | return laz_dirs
135 |
136 | @property
137 | def processed_file_names(self):
138 | return [f'grid{1000*self.options.pre_transform_grid_sample:.0f}mm_data.pt']
139 |
140 | @property
141 | def processed_dir(self) -> str:
142 | return osp.join(self.root, 'processed', self.mode if self.mode != "val" else "train")
143 |
144 | def process(self):
145 | data_list = []
146 |
147 | import pdal
148 | for laz in tqdm(self.raw_file_names):
149 | list_pipeline = [pdal.Reader(laz)]
150 | pipeline = pdal.Pipeline()
151 | for p in list_pipeline:
152 | pipeline |= p
153 | count = pipeline.execute()
154 | arrays = pipeline.arrays
155 |
156 | label = structured_to_unstructured(arrays[0][["Classification"]])
157 | for l in np.unique(label):
158 | if l not in self.options.raw_class_names.keys():
159 | label[label == l] = 1
160 |
161 | xyz = structured_to_unstructured(arrays[0][["X", "Y", "Z"]])
162 | intensity = structured_to_unstructured(arrays[0][["Intensity"]])
163 |
164 | if "Red" in arrays[0].dtype.names:
165 | rgb = structured_to_unstructured(
166 | arrays[0][["Red", "Green", "Blue"]])
167 |
168 | if rgb.max() >= 256:
169 | rgb = (rgb / 2**8).astype(np.uint8)
170 | else:
171 | rgb = rgb.astype(np.uint8)
172 | else:
173 | rgb = np.zeros_like(xyz, dtype=np.uint8)
174 |
175 | intensity = intensity.clip(np.percentile(intensity.flatten(), .1), np.percentile(intensity.flatten(), 99.9))
176 | intensity = (intensity / intensity.max()).astype(np.float32)
177 |
178 | label = label.astype(np.int64)
179 |
180 | xyz -= xyz.min(0)
181 | xyz = xyz.astype(np.float32)
182 | maxi = xyz.max(0)
183 |
184 | if self.options.subtile_max_xy > 0:
185 | n_split = 1 + \
186 | (maxi / ((1 + (self.mode == "test")) * self.options.subtile_max_xy)).astype(np.int32)
187 | else:
188 | n_split = [1, 1]
189 |
190 | for i in range(n_split[0]):
191 | for j in range(n_split[1]):
192 | keep = np.logical_and(
193 | np.logical_and(
194 | xyz[:, 0] >= i * maxi[0] / n_split[0],
195 | xyz[:, 0] < (i + 1) * maxi[0] / n_split[0]
196 | ),
197 | np.logical_and(
198 | xyz[:, 1] >= j * maxi[1] / n_split[1],
199 | xyz[:, 1] < (j + 1) * maxi[1] / n_split[1]
200 | )
201 | )
202 |
203 | data_list.append(Data(
204 | pos=torch.from_numpy(xyz[keep] - xyz[keep].min(0)),
205 | intensity=torch.from_numpy(intensity[keep]),
206 | rgb=torch.from_numpy(rgb[keep]),
207 | point_y=torch.from_numpy(label[keep]),
208 | ))
209 |
210 | if self.mode in ["train", "val"]:
211 | data_list[-1].pos_maxi = data_list[-1].pos[:, :2].max(0)[0]
212 |
213 | if self.pre_filter is not None:
214 | data_list = [data for data in data_list if self.pre_filter(data)]
215 |
216 | if self.pre_transform is not None:
217 | data_list = [self.pre_transform(data) for data in data_list]
218 |
219 | data, slices = self.collate(data_list)
220 |
221 | torch.save((data, slices), self.processed_paths[0])
222 |
223 | def __superlen__(self) -> int:
224 | return super().__len__()
225 |
226 | def __len__(self) -> int:
227 | if self.mode != "test":
228 | return self.items_per_epoch * self.__superlen__()
229 | else:
230 | return sum(self.items_per_epoch)
231 |
232 | def __getsuperitem__(self, idx):
233 | return super().__getitem__(idx)
234 |
235 | def __getitem__(self, idx):
236 | if self.mode == "train":
237 | item = self.__getsuperitem__(int(idx / self.items_per_epoch))
238 | elif self.mode == "val":
239 | item = self.items[idx]
240 | else:
241 | this_idx, tile = self.from_idx_to_tile[idx]
242 | item = self.__getsuperitem__(this_idx)
243 | keep = (item.pos[:, :2] / self.options.max_xy).int()
244 | keep = torch.logical_and(
245 | keep[:, 0] == tile[0], keep[:, 1] == tile[1])
246 | for k in item.keys:
247 | setattr(item, k, getattr(item, k)[keep])
248 |
249 | item.pos[:, :2] -= self.options.max_xy * tile
250 | item.pos[:, -1] -= item.pos[:, -1].min()
251 |
252 | keep = item.pos[:, -1] < self.options.max_z
253 | for k in item.keys:
254 | setattr(item, k, getattr(item, k)[keep])
255 |
256 | item.pos_lenght = torch.tensor(item.pos.size(0))
257 |
258 | if self.options.modality == "3D":
259 | pad = self.options.N_max - item.pos.size(0) if self.mode != "test" else 0
260 | if self.options.distance == "xyz":
261 | item.pos_padded = F.pad(item.pos, (0, 0, 0, pad), mode="constant", value=0).unsqueeze(0)
262 | elif self.options.distance == "xyzk":
263 | item.pos_padded = F.pad(torch.cat([item.pos, item.intensity], -1), (0, 0, 0, pad), mode="constant", value=0).unsqueeze(0)
264 | else:
265 | raise NotImplementedError(
266 | f"LiDAR-HD can't produce {self.options.distance}")
267 |
268 | item.features = 2 * \
269 | torch.cat([item.rgb / 255., item.pos /
270 | self.feature_normalizer, item.intensity], -1) - 1
271 | #del item.intensity
272 |
273 | if self.options.modality == "2D":
274 | item = self.from3Dto2Ditem(item)
275 |
276 | return item
277 |
278 | def from3Dto2Ditem(self, item):
279 | del item.pos_lenght
280 | res, n_dim = self.options.image.res, self.options.image.n_dim
281 | intensity = item.intensity.squeeze()
282 | rgb = item.rgb.float()
283 | labels = item.point_y.float().squeeze()
284 | xy, z = item.pos[:, :2], item.pos[:, 2]
285 | xy = torch.clamp(torch.floor(
286 | xy / (self.options.max_xy / (res-0.001))), 0, res - 1)
287 | xy = (xy[:, 0] * res + xy[:, 1]).long()
288 | features = []
289 | for values, init in zip([z, intensity, rgb, labels], [0, 0, 0, -1]):
290 | for mode in ['min', 'max']:
291 | features.append(gather_values(
292 | xy, z, values, mode=mode, res=res, init=init))
293 | item = torch.cat(features, dim=-1)
294 | item = item.reshape(res, res, n_dim)
295 | return item
296 |
297 | def gather_values(xy, z, values, mode='max', res=32, init=0):
298 | img = (torch.ones(
299 | res*res, values.shape[1] if 1 < len(values.shape) else 1) * init).squeeze()
300 | z = z.sort(descending=mode == 'max')
301 | xy, values = xy[z[1]], values[z[1]]
302 | unique, inverse = torch.unique(xy, sorted=True, return_inverse=True)
303 | perm = torch.arange(inverse.size(
304 | 0), dtype=inverse.dtype, device=inverse.device)
305 | inverse, perm = inverse.flip([0]), perm.flip([0])
306 | perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
307 | img.index_put_((xy[perm],), values[perm], accumulate=False)
308 | img = img.reshape(res, res, -1)
309 | return img
310 |
311 |
312 | class LidarHDDataModule(BaseDataModule):
313 | _DatasetSplit_ = {
314 | "train": LidarHDSplit,
315 | "val": LidarHDSplit,
316 | "test": LidarHDSplit
317 | }
318 |
319 | def from_labels_to_color(self, labels):
320 | return from_sem_to_color(apply_learning_map(labels, self.myhparams.learning_map_inv), self.myhparams.color_map)
321 |
322 | def get_feature_names(self):
323 | return ["red", "green", "blue", "x", "y", "z", "intensity"]
324 |
325 | def describe(self):
326 | print(self)
327 |
328 | for split in ["train", "val", "test"]:
329 | if hasattr(self, f"{split}_dataset") and hasattr(getattr(self, f"{split}_dataset"), "data"):
330 | print(f"{split} data\t", getattr(
331 | self, f"{split}_dataset").data)
332 | if hasattr(getattr(self, f"{split}_dataset").data, "point_y"):
333 | for c, n in zip(*np.unique(getattr(self, f"{split}_dataset").data.point_y.flatten().numpy(), return_counts=True)):
334 | print(
335 | f"class {self.myhparams.raw_class_names[self.myhparams.learning_map_inv[int(c)]]} ({c}) {(20 - len(self.myhparams.raw_class_names[self.myhparams.learning_map_inv[int(c)]]))*' '} \thas {n} \tpoints")
336 |
337 | if hasattr(self, "val_dataset"):
338 | lens = [item.pos.size(0) for item in self.val_dataset.items]
339 | plt.hist(lens)
340 | plt.title(
341 | f"size of val dataset items between {min(lens)} and {max(lens)}")
342 | plt.show()
343 |
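344 |
345 | if __name__ == "__main__":
346 |     # Minimal sketch of `gather_values` on synthetic points (run with
347 |     # `python -m earthparserdataset.earthparserdataset` because of the relative
348 |     # imports): rasterize the per-cell maximum height on a 4 x 4 grid; empty
349 |     # cells keep the init value 0.
350 |     pos = 4. * torch.rand(50, 3)
351 |     cell = (pos[:, 0].long() * 4 + pos[:, 1].long())  # flattened xy cell index
352 |     print(gather_values(cell, pos[:, 2], pos[:, 2], mode="max", res=4).squeeze(-1))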
--------------------------------------------------------------------------------