├── .gitignore
├── LICENSE
├── README.md
├── assets
├── optimal_matching_network.png
├── pix2poly_overall_bg_white.png
├── sfo7.png
└── vertex_sequence_detector.png
├── config.py
├── data_preprocess
├── README.md
├── download_mass_roads_dataset.sh
├── inria_to_coco.py
├── mass_roads_clip_shapefile.py
├── mass_roads_clip_tile_vectors.py
├── mass_roads_tiles_to_patches.py
├── mass_roads_world_to_pixel_coords.py
├── spacenet_convert_16bit_to_8bit.py
├── spacenet_to_coco.py
├── spacenet_world_to_pixel_coords.py
└── whu_buildings_to_coco.py
├── datasets
├── __init__.py
├── dataset_inria_coco.py
├── dataset_mass_roads.py
├── dataset_spacenet_coco.py
└── dataset_whu_buildings_coco.py
├── ddp_utils.py
├── engine.py
├── eval
├── __init__.py
├── hisup_eval_utils
│ └── metrics
│ │ ├── angle_eval.py
│ │ ├── cIoU.py
│ │ └── polis.py
└── topdig_eval_utils
│ └── metrics
│ └── topdig_metrics.py
├── evaluate_mass_roads_predictions.py
├── evaluation.py
├── models
└── model.py
├── postprocess_coco_parts.py
├── predict_inria_coco_val_set.py
├── predict_mass_roads_test_set.py
├── predict_spacenet_coco_val_set.py
├── predict_whu_buildings_coco_test_set.py
├── pyrightconfig.json
├── tokenizer.py
├── train_ddp.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Directories
2 | __pycache__
3 | data
4 | runs
5 | .vscode
6 | preds*
7 | crowdai_visualizations
8 | scratch
9 |
10 | # Files
11 | *.jpg
12 | *.png
13 | *.tar.xz
14 | *.gif
15 | *.txt
16 |
17 | # Files ignore
18 | !assets/sfo7.png
19 | !assets/pix2poly_overall_bg_white.png
20 | !assets/vertex_sequence_detector.png
21 | !assets/optimal_matching_network.png
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Yeshwanth Kumar Adimoolam
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pix2Poly: A Sequence Prediction Method for End-to-end Polygonal Building Footprint Extraction from Remote Sensing Imagery
8 |
9 |
10 | [[Project Webpage](https://yeshwanth95.github.io/Pix2Poly)] [[Paper](https://arxiv.org/abs/2412.07899)] [[Video]()]
11 |
12 | ## UPDATES:
13 |
14 | 1. 05.06.2025 - Pretrained checkpoints for Pix2Poly on the various datasets used in the paper are released. See [pretrained checkpoints](#pretrained-checkpoints).
15 | 2. 21.05.2025 - As reported by the authors of the [$P^3$ dataset](https://arxiv.org/abs/2505.15379), Pix2Poly achieves state-of-the-art results for multimodal building vectorization from image and LiDAR data sources.
16 |
17 | ### Abstract:
18 |
19 | Extraction of building footprint polygons from remotely sensed data is essential for several urban understanding tasks such as reconstruction, navigation, and mapping. Despite significant progress in the area, extracting accurate polygonal building footprints remains an open problem. In this paper, we introduce Pix2Poly, an attention-based end-to-end trainable and differentiable deep neural network capable of directly generating explicit high-quality building footprints in a ring graph format. Pix2Poly employs a generative encoder-decoder transformer to produce a sequence of graph vertex tokens whose connectivity information is learned by an optimal matching network. Compared to previous graph learning methods, ours is a truly end-to-end trainable approach that extracts high-quality building footprints and road networks without requiring complicated, computationally intensive raster loss functions and intricate training pipelines. Upon evaluating Pix2Poly on several complex and challenging datasets, we report that Pix2Poly outperforms state-of-the-art methods in several vector shape quality metrics while being an entirely explicit method.
20 |
21 | ### Method
22 |
23 | ![Overview of the Pix2Poly architecture](assets/pix2poly_overall_bg_white.png)
24 |
31 | __Overview of the Pix2Poly architecture:__ The Pix2Poly architecture consists of three major components: (i) the Discrete Sequence Tokenizer, (ii) the Vertex Sequence Detector, and (iii) the Optimal Matching Network. The Discrete Sequence Tokenizer converts the continuous building corner coordinates into discrete corner coordinate tokens, which form the ground truth for training Pix2Poly. The Vertex Sequence Detector is an encoder-decoder transformer that predicts a sequence of discrete building corner coordinate tokens. The Optimal Matching Network takes the predicted corner coordinate tokens and the per-corner features from the vertex sequence detector and predicts an N × N permutation matrix that contains the connectivity information between every possible corner pair. Together, the predicted building corners and the permutation matrix are used to recover the final building polygons.
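
As a rough illustration of this last step (this is not the repository's own post-processing code), the sketch below assembles closed rings by following successor indices in a predicted permutation matrix. It assumes row *i* of the matrix marks the corner that follows corner *i*, and that padded or unused corners are matched to themselves:

```python
import numpy as np

def perm_to_polygons(corners: np.ndarray, perm: np.ndarray):
    """corners: (N, 2) predicted vertex coordinates; perm: (N, N) 0/1 permutation matrix."""
    successor = perm.argmax(axis=1)                      # row i -> index of the corner following corner i
    visited = np.zeros(len(corners), dtype=bool)
    polygons = []
    for start in range(len(corners)):
        if visited[start] or successor[start] == start:  # skip padding self-loops
            continue
        ring, v = [], start
        while not visited[v]:                            # walk the cycle starting at 'start'
            visited[v] = True
            ring.append(corners[v])
            v = successor[v]
        if len(ring) >= 3:                               # a valid ring needs at least 3 corners
            polygons.append(np.asarray(ring))
    return polygons
```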
32 |
33 | ## Installation
34 |
35 | Pix2Poly was developed with `python=3.11`, `pytorch=2.1.2`, `pytorch-cuda=11.8`, `timm=0.9.12`, `transformers=4.32.1`
36 |
37 | Create a conda environment with the following specification:
38 |
39 | ```
40 | Conda requirements:
41 | channels:
42 | - defaults
43 | dependencies:
44 | - torchvision=0.16.2
45 | - pytorch=2.1.2
46 | - pytorch-cuda=11.8
47 | - torchaudio=2.1.2
48 | - timm=0.9.12
49 | - transformers=4.32.1
50 | - pycocotools=2.0.6
51 | - torchmetrics=1.2.1
52 | - tensorboard=2.15.1
53 | - pip:
54 | - albumentations==1.3.1
55 | - imageio==2.33.1
56 | - matplotlib-inline==0.1.6
57 | - opencv-python-headless==4.8.1.78
58 | - scikit-image==0.22.0
59 | - scikit-learn==1.3.2
60 | - scipy==1.11.4
61 | - shapely==2.0.4
62 | ```
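
For example, assuming the specification above is saved to a file named `environment.yml` (with a `name:` field added, e.g. `pix2poly_env`, the environment name assumed by the data preprocessing instructions):

```shell
conda env create -f environment.yml
conda activate pix2poly_env
```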
63 |
64 |
65 | ## Datasets preparation
66 |
67 | See [datasets preprocessing](data_preprocess) for instructions on preparing the various datasets for training/inference.
68 |
69 | ## Pretrained Checkpoints
70 |
71 | Pretrained checkpoints for the various datasets used in the paper are available for download at the following links: [Google Drive](https://drive.google.com/file/d/1oEs2n81nMAzdY4G9bdrji13pOKk6MOET/view?usp=sharing) | [MEGA](https://mega.nz/file/ExQEBDxY#faK1yNaQ8KYvPGuxJY1snvFi7TfbF1kOx4mvhmUSb4s)
72 |
73 | Download the zip file, extract it, and place the individual run folders in the `runs` directory at the root of the project.
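
For example (the archive name below is illustrative; keep the run folder names from inside the archive unchanged):

```shell
mkdir -p runs
unzip pix2poly_checkpoints.zip -d runs/
```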
74 |
75 | ## Configurations
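
All dataset, model, and training settings live in the `CFG` class in `config.py`: the dataset choice (`DATASET`), the maximum number of vertices per image (`N_VERTICES`), input/patch sizes, batch size, learning rate, number of epochs, and the experiment name that determines the subfolder used under `runs/`. Edit this file before launching training or prediction; for instance, switching to the WHU Buildings dataset only requires setting `DATASET = "whu_buildings_224_coco"`.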
76 |
77 | ## Training
78 |
79 | Start training with the following command:
80 |
81 | ```
82 | torchrun --nproc_per_node=<num_gpus> train_ddp.py
83 | ```
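
For example, on a single node with 4 GPUs:

```shell
torchrun --nproc_per_node=4 train_ddp.py
```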
84 |
85 | ## Prediction
86 |
87 | ### (i) INRIA Dataset
88 | To generate predictions for the INRIA dataset, run the following:
89 | ```shell
90 | python predict_inria_coco_val_set.py -d inria_coco_224_negAug \
91 | -e <experiment_name> \
92 | -c <checkpoint_name> \
93 | -o <output_dir>
94 | python postprocess_coco_parts.py # change input and output paths in L006 to L010.
95 | ```
96 |
97 | ### (ii) Spacenet 2 Dataset
98 | To generate predictions for the Spacenet 2 dataset, run the following:
99 | ```shell
100 | python predict_spacenet_coco_val_set.py -d spacenet_coco \
101 | -e <experiment_name> \
102 | -c <checkpoint_name> \
103 | -o <output_dir>
104 | python postprocess_coco_parts.py # change input and output paths in L006 to L010.
105 | ```
106 |
107 | ### (iii) WHU Buildings Dataset
108 | To generate predictions for the WHU Buildings dataset, run the following:
109 | ```shell
110 | python predict_whu_buildings_coco_test_set.py -d whu_buildings_224_coco \
111 | -e <experiment_name> \
112 | -c <checkpoint_name> \
113 | -o <output_dir>
114 | python postprocess_coco_parts.py # change input and output paths in L006 to L010.
115 | ```
116 | ### (iv) Massachusetts Roads Dataset
117 | To generate predictions for the Massachusetts Roads dataset, run the following:
118 | ```shell
119 | python predict_mass_roads_test_set.py -e <experiment_name> \
120 | -c <checkpoint_name> \
121 | -s <split> \ # 'test' or 'val'
122 | --img_size 224 \
123 | --input_size 224 \
124 | --batch_size 24 # modify according to resources
125 | ```
126 |
127 | ## Evaluation (buildings datasets)
128 |
129 | Once predictions are made, metrics can be computed for the predicted files as follows:
130 |
131 | ```bash
132 | python evaluation.py --gt-file path/to/ground/truth/annotation.json --dt-file path/to/prediction.json --eval-type <metric_type>
133 | ```
134 |
135 | where `metric_type` can be one of the following: `ciou`, `angle`, `polis`, `topdig`.
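
For example, to compute the C-IoU metric for INRIA validation predictions (the file paths below are illustrative):

```bash
python evaluation.py --gt-file data/inria_coco_224_negAug/val/annotation.json \
                     --dt-file preds/inria_val_predictions.json \
                     --eval-type ciou
```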
136 |
137 | ## Evaluation (Massachusetts Roads Dataset)
138 |
139 | Once raster predictions are made for the Massachusetts Roads dataset, metrics can be computed for the predicted files as follows:
140 |
141 | ```bash
142 | python evaluate_mass_roads_predictions.py --gt-dir data/mass_roads_1500/test/map --dt-dir path/to/predicted/raster/masks/folder
143 | ```
144 |
145 | ## Citation
146 |
147 | If you find our work useful, please consider citing:
148 | ```bibtex
149 | @misc{adimoolam2024pix2poly,
150 | title={Pix2Poly: A Sequence Prediction Method for End-to-end Polygonal Building Footprint Extraction from Remote Sensing Imagery},
151 | author={Yeshwanth Kumar Adimoolam and Charalambos Poullis and Melinos Averkiou},
152 | year={2024},
153 | eprint={2412.07899},
154 | archivePrefix={arXiv},
155 | primaryClass={cs.CV},
156 | url={https://arxiv.org/abs/2412.07899},
157 | }
158 | ```
159 |
160 | ## Acknowledgements
161 |
162 | This repository benefits from the following open-source work. We thank the authors for their great work.
163 |
164 | 1. [Pix2Seq - official repo](https://github.com/google-research/pix2seq)
165 | 2. [Pix2Seq - unofficial repo](https://github.com/moein-shariatnia/Pix2Seq)
166 | 3. [Frame Field Learning](https://github.com/Lydorn/Polygonization-by-Frame-Field-Learning)
167 | 4. [PolyWorld](https://github.com/zorzi-s/PolyWorldPretrainedNetwork)
168 | 5. [HiSup](https://github.com/SarahwXU/HiSup)
169 |
--------------------------------------------------------------------------------
/assets/optimal_matching_network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/optimal_matching_network.png
--------------------------------------------------------------------------------
/assets/pix2poly_overall_bg_white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/pix2poly_overall_bg_white.png
--------------------------------------------------------------------------------
/assets/sfo7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/sfo7.png
--------------------------------------------------------------------------------
/assets/vertex_sequence_detector.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/vertex_sequence_detector.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CFG:
4 | IMG_PATH = ''
5 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 |
7 | """
8 | supported datasets are:
9 | - inria_coco_224_negAug
10 | - spacenet_coco
11 | - whu_buildings_224_coco
12 | - mass_roads_224
13 | """
14 | DATASET = f"inria_coco_224_negAug"
15 | if "coco" in DATASET:
16 | TRAIN_DATASET_DIR = f"./data/{DATASET}/train"
17 | VAL_DATASET_DIR = f"./data/{DATASET}/val"
18 | TEST_IMAGES_DIR = f"./data/{DATASET}/val/images"
19 | elif "mass_roads" in DATASET:
20 | TRAIN_DATASET_DIR = f"./data/{DATASET}/train"
21 | VAL_DATASET_DIR = f"./data/{DATASET}/valid"
22 | TEST_IMAGES_DIR = f"./data/{DATASET}/test/images"
23 |
24 |
25 | TRAIN_DDP = True
26 | NUM_WORKERS = 2
27 | PIN_MEMORY = True
28 | LOAD_MODEL = False
29 |
30 | if "inria" in DATASET:
31 | N_VERTICES = 192 # maximum number of vertices per image in dataset.
32 | elif "spacenet" in DATASET:
33 | N_VERTICES = 192 # maximum number of vertices per image in dataset.
34 | elif "whu_buildings" in DATASET:
35 | N_VERTICES = 144 # maximum number of vertices per image in dataset.
36 | elif "mass_roads" in DATASET:
37 | N_VERTICES = 192 # maximum number of vertices per image in dataset.
38 |
39 | SINKHORN_ITERATIONS = 100
40 | MAX_LEN = (N_VERTICES*2) + 2
41 | if "inria" in DATASET:
42 | IMG_SIZE = 224
43 | elif "spacenet" in DATASET:
44 | IMG_SIZE = 224
45 | elif "whu_buildings" in DATASET:
46 | IMG_SIZE = 224
47 | elif "mass_roads" in DATASET:
48 | IMG_SIZE = 224
49 | INPUT_SIZE = 224
50 | PATCH_SIZE = 8
51 | INPUT_HEIGHT = INPUT_SIZE
52 | INPUT_WIDTH = INPUT_SIZE
53 | NUM_BINS = INPUT_HEIGHT*1
54 | LABEL_SMOOTHING = 0.0
55 | vertex_loss_weight = 1.0
56 | perm_loss_weight = 10.0
57 | SHUFFLE_TOKENS = False # order gt vertex tokens randomly every time
58 |
59 | BATCH_SIZE = 24 # batch size per gpu; effective batch size = BATCH_SIZE * NUM_GPUs
60 | START_EPOCH = 0
61 | NUM_EPOCHS = 500
62 | MILESTONE = 0
63 | SAVE_BEST = True
64 | SAVE_LATEST = True
65 | SAVE_EVERY = 10
66 | VAL_EVERY = 1
67 |
68 | MODEL_NAME = f'vit_small_patch{PATCH_SIZE}_{INPUT_SIZE}_dino'
69 | NUM_PATCHES = int((INPUT_SIZE // PATCH_SIZE) ** 2)
70 |
71 | LR = 4e-4
72 | WEIGHT_DECAY = 1e-4
73 |
74 | generation_steps = (N_VERTICES * 2) + 1 # sequence length during prediction. Should not be more than max_len
75 | run_eval = False
76 |
77 | # EXPERIMENT_NAME = f"debug_run_Pix2Poly224_Bins{NUM_BINS}_fullRotateAugs_permLossWeight{perm_loss_weight}_LR{LR}__{NUM_EPOCHS}epochs"
78 | EXPERIMENT_NAME = f"train_Pix2Poly_{DATASET}_run1_{MODEL_NAME}_AffineRotaugs0.8_LinearWarmupLRS_{vertex_loss_weight}xVertexLoss_{perm_loss_weight}xPermLoss__2xScoreNet_initialLR_{LR}_bs_{BATCH_SIZE}_Nv_{N_VERTICES}_Nbins{NUM_BINS}_{NUM_EPOCHS}epochs"
79 |
80 | if "debug" in EXPERIMENT_NAME:
81 | BATCH_SIZE = 10
82 | NUM_WORKERS = 0
83 | SAVE_BEST = False
84 | SAVE_LATEST = False
85 | SAVE_EVERY = NUM_EPOCHS
86 | VAL_EVERY = 50
87 |
88 | if LOAD_MODEL:
89 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/latest.pth" # full path to checkpoint to be loaded if LOAD_MODEL=True
90 | else:
91 | CHECKPOINT_PATH = ""
92 |
93 |
--------------------------------------------------------------------------------
/data_preprocess/README.md:
--------------------------------------------------------------------------------
1 | # Datasets preparation
2 |
3 | ## INRIA Dataset
4 |
5 | 1. Download the [INRIA Aerial Image Labeling Dataset](https://project.inria.fr/aerialimagelabeling/).
6 | 2. Extract and place the aerial image tiles in the `data` directory as follows:
7 | ```
8 | data/inria_raw/
9 | ├── test/
10 | │ └── images/
11 | └── train/
12 | ├── gt/
13 | └── images/
14 | ```
15 | 3. Set path to raw INRIA train tiles and gts in L255 & L256 in `inria_to_coco.py`
16 | 4. Run the following command to prepare the INRIA dataset's train and validation splits in MS COCO format. The first 5 tiles of each city are kept as validation split as per the official recommendation.
17 | ```shell
18 | # with pix2poly_env
19 | python inria_to_coco.py
20 | ```
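
For training, `config.py` expects the converted dataset in a COCO-style layout like the following (each split contains an `images/` folder and an `annotation.json` file, as read by `datasets/dataset_inria_coco.py`):

```
data/inria_coco_224_negAug/
├── train/
│   ├── images/
│   └── annotation.json
└── val/
    ├── images/
    └── annotation.json
```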
21 | ---
22 |
23 |
24 | ## SpaceNet 2 Building Detection v2 Dataset (Vegas Subset)
25 |
26 | NOTE: We only use the Vegas subset for all our experiments in the paper.
27 |
28 | 1. Download the [Spacenet 2 Building Detection v2 Dataset](https://spacenet.ai/spacenet-buildings-dataset-v2/).
29 | 2. Extract and place the satellite image tiles for the Vegas subset in the `data` folder in the following directory structure:
30 | ```
31 | data/AOI_2_Vegas_Train/
32 | ├── geojson/
33 | │   └── buildings/
34 | └── RGB-PanSharpen/
37 | ```
38 | 3. Convert the pansharpened RGB image tiles from 16-bit to 8-bit using the following command:
39 | ```shell
40 | # with gdal_env
41 | python spacenet_convert_16bit_to_8bit.py
42 | ```
43 | 4. Convert geojson annotations from world space coordinates to pixel space coordinates using the following command:
44 | ```shell
45 | # with gdal_env
46 | python spacenet_world_to_pixel_coords.py
47 | ```
48 | 5. Set path to raw SpaceNet dataset's tiles and gts in L202 & L203 in `spacenet_to_coco.py`
49 | 6. Run the following command to prepare the SpaceNet dataset's train and validation splits in MS COCO format. The first 15% of tiles are kept as the validation split.
50 | ```shell
51 | # with pix2poly_env
52 | python spacenet_to_coco.py
53 | ```
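
Note that `spacenet_to_coco.py` processes one split at a time (controlled by the `split` variable in its main block); running it once with `split = 'train'` and once with `split = 'val'` yields the layout expected by `config.py`:

```
data/spacenet_coco/
├── train/
│   ├── images/
│   └── annotation.json
└── val/
    ├── images/
    └── annotation.json
```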
54 | ---
55 |
56 |
57 | ## WHU Buildings Dataset
58 |
59 |
60 | 1. Download the 0.2 meter split of the [WHU Buildings Aerial Imagery Dataset](http://gpcv.whu.edu.cn/data/building_dataset.html).
61 | 2. Extract and place the aerial image tiles (512x512) in the `data` folder in the following directory structure:
62 | ```
63 | data/WHU_aerial_0.2/
64 | ├── test/
65 | │ ├── image/
66 | │ └── label/
67 | ├── train/
68 | │ ├── image/
69 | │ └── label/
70 | └── val/
71 | ├── image/
72 | └── label/
73 | ```
74 | 3. Set path to raw WHU Buildings tiles and gts (512x512) in L263 & L264 in `whu_buildings_to_coco.py`
75 | 4. Run the following command to prepare the WHU Buildings dataset's train, validation and test splits in MS COCO format.
76 | ```shell
77 | # with pix2poly_env
78 | python whu_buildings_to_coco.py
79 | ```
80 | ---
81 |
82 |
83 |
84 | ## Massachusetts Roads Dataset
85 |
86 | 1. Download the [Massachusetts Roads Dataset](https://www.cs.toronto.edu/~vmnih/data/) using the following command:
87 | ```shell
88 | ./download_mass_roads_dataset.sh
89 | ```
90 | 2. Download and extract the roads vector shapefile for the dataset [here](https://www.cs.toronto.edu/~vmnih/data/mass_roads/massachusetts_roads_shape.zip). Use QGIS or any preferred tool to convert the SHP file to a geojson.
91 | 3. From this vector roads geojson, generate vector annotations for each tile in the dataset by clipping the geojson to the corresponding raster extents:
92 | ```shell
93 | # using gdal_env
94 | python mass_roads_clip_shapefile.py
95 | ```
96 | 4. This results in the following directory structure containing the 1500x1500 tiles of the Massachusetts Roads Dataset:
97 | ```
98 | data/mass_roads_1500/
99 | ├── test/
100 | │ ├── map/
101 | │ ├── sat/
102 | │ └── shape/
103 | ├── train/
104 | │ ├── map/
105 | │ ├── sat/
106 | │ └── shape/
107 | └── valid/
108 | ├── map/
109 | ├── sat/
110 | └── shape/
111 | ```
112 | 5. Split the 1500x1500 tiles into 224x224 overlapping patches with the following command:
113 | ```shell
114 | # using gdal_env
115 | python mass_roads_tiles_to_patches.py
116 | ```
117 | 6. Generate vector annotation files for the patches as follows:
118 | ```shell
119 | # using gdal_env
120 | python mass_roads_clip_tile_vectors.py
121 | python mass_roads_world_to_pixel_coords.py
122 | ```
123 | 7. This produces the processed 224x224 patches of the Massachusetts Roads Dataset used for training Pix2Poly, arranged in the following directory structure:
124 | ```
125 | data/mass_roads_224/
126 | ├── test/
127 | │ ├── map/
128 | │ ├── pixel_annotations/
129 | │ ├── sat/
130 | │ └── shape/
131 | ├── train/
132 | │ ├── map/
133 | │ ├── pixel_annotations/
134 | │ ├── sat/
135 | │ └── shape/
136 | └── valid/
137 | ├── map/
138 | ├── pixel_annotations/
139 | ├── sat/
140 | └── shape/
141 | ```
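
Two implementation details worth noting: the patches produced in step 5 are 224x224 with a 34-pixel overlap (`int(round(0.15 * 224))` in `mass_roads_tiles_to_patches.py`), and the pixel-space annotations in step 6 are derived from each patch's geotransform as `col = (x - xOrigin) / pixelWidth` and `row = (yOrigin - y) / pixelHeight` (see `mass_roads_world_to_pixel_coords.py`). Remember to set the `split` variable in these scripts for each of the `train`, `valid`, and `test` splits.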
142 | ---
143 |
--------------------------------------------------------------------------------
/data_preprocess/download_mass_roads_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p ../data
4 | cd ../data
5 |
6 | # Download train split
7 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/train/sat/
8 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/train/map/
9 |
10 | # Download valid split
11 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/valid/sat/
12 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/valid/map/
13 |
14 | # Download test split
15 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/test/sat/
16 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/test/map/
17 |
18 | mv www.cs.toronto.edu/\~vmnih/data/mass_roads .
19 | mv mass_roads mass_roads_1500
20 | rm -r www.cs.toronto.edu
21 |
--------------------------------------------------------------------------------
/data_preprocess/mass_roads_clip_shapefile.py:
--------------------------------------------------------------------------------
1 | from osgeo import gdal
2 | import os
3 | import subprocess
4 | from tqdm import tqdm
5 | from multiprocessing import Pool
6 |
7 |
8 | def clip_shapefile(paths_dict):
9 | # get the extent of the raster
10 | raster_path = paths_dict['raster_path']
11 | save_path = paths_dict['save_path']
12 | vector_path = paths_dict['vector_path']
13 | src = gdal.Open(raster_path)
14 | ulx, xres, xskew, uly, yskew, yres = src.GetGeoTransform()
15 | sizeX = src.RasterXSize * xres
16 | sizeY = src.RasterYSize * yres
17 | lrx = ulx + sizeX
18 | lry = uly + sizeY
19 | src = None
20 |
21 | # format the extent coords
22 | extent = f"{ulx} {lry} {lrx} {uly}"
23 | # print(extent)
24 |
25 | # make clip command with ogr2ogr
26 | cmd = f"ogr2ogr {save_path} {vector_path} -clipsrc {extent}"
27 |
28 | # call the command
29 | subprocess.call(cmd, shell=True)
30 | return 0
31 |
32 |
33 | def main():
34 | split = "train" # "train" or "valid" or "test"
35 | data_root = "../data/mass_roads_1500"
36 | vector_path = os.path.join(data_root, "massachusetts_roads_shape.geojson")
37 | save_dir = os.path.join(data_root, split, "shape")
38 | os.makedirs(save_dir, exist_ok=True)
39 |
40 | rasters_dir = os.path.join(data_root, split, "sat")
41 | rasters = os.listdir(rasters_dir)
42 |
43 | param_inputs = []
44 | for raster in rasters:
45 | raster_path = os.path.join(rasters_dir, raster)
46 | save_path = os.path.join(save_dir, raster.split('.')[0]+".geojson")
47 | param_inputs.append(
48 | {
49 | 'raster_path': raster_path,
50 | 'vector_path': vector_path,
51 | 'save_path': save_path,
52 | }
53 | )
54 | # clip_shapefile(raster_path, vector_path, save_path)
55 |
56 | with Pool() as p:
57 | _ = list(tqdm(p.imap(clip_shapefile, param_inputs), total=len(param_inputs)))
58 |
59 |
60 | if __name__ == "__main__":
61 | main()
62 |
63 |
--------------------------------------------------------------------------------
/data_preprocess/mass_roads_clip_tile_vectors.py:
--------------------------------------------------------------------------------
1 | from osgeo import gdal
2 | import os
3 | import subprocess
4 | from tqdm import tqdm
5 | from multiprocessing import Pool
6 |
7 |
8 | def clip_shapefile(paths_dict):
9 | # get the extent of the raster
10 | raster_path = paths_dict['raster_path']
11 | save_path = paths_dict['save_path']
12 | vector_path = paths_dict['vector_path']
13 | src = gdal.Open(raster_path)
14 | ulx, xres, xskew, uly, yskew, yres = src.GetGeoTransform()
15 | sizeX = src.RasterXSize * xres
16 | sizeY = src.RasterYSize * yres
17 | lrx = ulx + sizeX
18 | lry = uly + sizeY
19 | src = None
20 |
21 | # format the extent coords
22 | extent = f"{ulx} {lry} {lrx} {uly}"
23 | # print(extent)
24 |
25 | # make clip command with ogr2ogr
26 | cmd = f"ogr2ogr {save_path} {vector_path} -clipsrc {extent}"
27 |
28 | # call the command
29 | subprocess.call(cmd, shell=True)
30 | return 0
31 |
32 |
33 | def main():
34 | split = "train" # "train" or "valid" or "test"
35 | data_root = f"../data/mass_roads_1500"
36 | vector_dir = os.path.join(data_root, split, "shape")
37 | save_dir = f'../data/mass_roads_224/{split}/shape'
38 | os.makedirs(save_dir, exist_ok=True)
39 |
40 | rasters_dir = f'../data/mass_roads_224/{split}/images/'
41 | rasters = os.listdir(rasters_dir)
42 |
43 | param_inputs = []
44 | for raster in rasters:
45 | raster_path = os.path.join(rasters_dir, raster)
46 | save_path = os.path.join(save_dir, raster.split('.')[0]+".geojson")
47 | raster_info = raster.split('_')
48 | vec_desc = f"{raster_info[0]}_{raster_info[1]}.geojson"
49 | vector_path = os.path.join(vector_dir, vec_desc)
50 | param_inputs.append(
51 | {
52 | 'raster_path': raster_path,
53 | 'vector_path':vector_path,
54 | 'save_path': save_path
55 | }
56 | )
57 |
58 |
59 | # clip_shapefile(raster_path, vector_path, save_path)
60 | with Pool() as p:
61 | _ = list(tqdm(p.imap(clip_shapefile, param_inputs), total=len(param_inputs)))
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
66 |
--------------------------------------------------------------------------------
/data_preprocess/mass_roads_tiles_to_patches.py:
--------------------------------------------------------------------------------
1 | # gdal_retile.py -targetDir gdal_retile/valid/images/ -ps 224 224 -overlap 34 valid/sat/*.tiff
2 |
3 | import os
4 | import subprocess
5 | from tqdm import tqdm
6 |
7 |
8 | def main():
9 | split = 'train'
10 | data_root = f"../data/mass_roads_1500"
11 | images_dir = os.path.join(data_root, split, "sat")
12 | masks_dir = os.path.join(data_root, split, "map")
13 |
14 | out_imgs_dir = f"../data/mass_roads_224/{split}/images"
15 | out_mask_dir = f"../data/mass_roads_224/{split}/mask"
16 | os.makedirs(out_imgs_dir, exist_ok=True)
17 | os.makedirs(out_mask_dir, exist_ok=True)
18 |
19 | patch_size = 224
20 | ph = pw = patch_size
21 | overlap = int(round(0.15*patch_size))
22 |
23 | images = os.listdir(images_dir)
24 | for img in tqdm(images):
25 | in_path = os.path.join(images_dir, img)
26 | cmd = f"gdal_retile.py -targetDir {out_imgs_dir} -ps {ph} {pw} -overlap {overlap} {in_path}"
27 | subprocess.call(cmd, shell=True)
28 |
29 | masks = os.listdir(masks_dir)
30 | for mask in tqdm(masks):
31 | in_path = os.path.join(masks_dir, mask)
32 | cmd = f"gdal_retile.py -targetDir {out_mask_dir} -ps {ph} {pw} -overlap {overlap} {in_path}"
33 | subprocess.call(cmd, shell=True)
34 |
35 |
36 | if __name__ == "__main__":
37 | main()
38 |
--------------------------------------------------------------------------------
/data_preprocess/mass_roads_world_to_pixel_coords.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from tqdm import tqdm
4 | from osgeo import gdal # use gdal env for this script.
5 |
6 |
7 | def main():
8 | split = "train" # "train" or "valid" or "test"
9 | # data_root = f"../data/mass_roads_1500"
10 | data_root = f"../data/mass_roads_224"
11 | # geoimages_dir = os.path.join(data_root, split, "sat")
12 | geoimages_dir = os.path.join(data_root, split, "images")
13 | # shapefiles_dir = os.path.join(data_root, split, "shape")
14 | shapefiles_dir = os.path.join(data_root, split, "shape")
15 |
16 | save_shapefiles_dir = os.path.join(data_root, split, "pixel_annotations")
17 | os.makedirs(save_shapefiles_dir, exist_ok=True)
18 |
19 | geoimages = os.listdir(geoimages_dir)
20 | shapefiles = os.listdir(shapefiles_dir)
21 |
22 | for i in tqdm(range(len(geoimages))):
23 | geo_im = geoimages[i]
24 | # im_desc = geo_im.split('_')[-1].split('.')[0]
25 | im_desc = geo_im.split('.')[0]
26 | # shp = [sh for sh in shapefiles if f"{im_desc}.geojson" in sh]
27 | # assert len(shp) == 1
28 | # shp = shp[0]
29 | shp = f"{im_desc}.geojson"
30 | # shp = shapefiles[0]
31 |
32 | driver = gdal.GetDriverByName('GTiff')
33 | dataset = gdal.Open(os.path.join(geoimages_dir, geo_im))
34 | band = dataset.GetRasterBand(1)
35 | cols = dataset.RasterXSize
36 | rows = dataset.RasterYSize
37 | transform = dataset.GetGeoTransform()
38 | xOrigin = transform[0]
39 | yOrigin = transform[3]
40 | pixelWidth = transform[1]
41 | pixelHeight = -transform[5]
42 | data = band.ReadAsArray(0, 0, cols, rows)
43 |
44 | with open(os.path.join(shapefiles_dir, shp), 'r') as f:
45 | geo_shp = json.load(f)
46 |
47 | pixel_shp = {}
48 | pixel_shp['type'] = geo_shp['type']
49 | pixel_shp['features'] = []
50 |
51 | for feature in geo_shp['features']:
52 | out_feat = {
53 | 'type': 'Feature',
54 | 'geometry': {
55 | 'type': 'MultiLineString',
56 | 'coordinates': []
57 | }
58 | }
59 | if feature['geometry']['type'] == "MultiLineString":
60 | feature_coords = feature['geometry']['coordinates']
61 | for geo_coords in feature_coords:
62 | points_list = [(gc[0], gc[1]) for gc in geo_coords]
63 | coords_list = []
64 | for point in points_list:
65 | col = (point[0] - xOrigin) / pixelWidth
66 | col = col if col < cols else cols
67 | row = (yOrigin - point[1]) / pixelHeight
68 | row = row if row < rows else rows
69 | # 'row' has negative sign to be compatible with qgis visualization. Must not be used for compatibility in image space.
70 | coords_list.append([col, row])
71 | out_feat['geometry']['coordinates'].append(coords_list)
72 | pixel_shp['features'].append(out_feat)
73 | else:
74 | geo_coords = feature['geometry']['coordinates']
75 | points_list = [(gc[0], gc[1]) for gc in geo_coords]
76 | coords_list = []
77 | for point in points_list:
78 | col = (point[0] - xOrigin) / pixelWidth
79 | col = col if col < cols else cols
80 | row = (yOrigin - point[1]) / pixelHeight
81 | row = row if row < rows else rows
82 | # 'row' has negative sign to be compatible with qgis visualization. Must not be used for compatibility in image space.
83 | # coords_list.append([col, -row])
84 | coords_list.append([col, row])
85 | out_feat['geometry']['coordinates'].append(coords_list)
86 | pixel_shp['features'].append(out_feat)
87 |
88 | with open(os.path.join(save_shapefiles_dir, shp), 'w') as o:
89 | json.dump(pixel_shp, o)
90 |
91 |
92 | if __name__ == "__main__":
93 | main()
94 |
95 |
--------------------------------------------------------------------------------
/data_preprocess/spacenet_convert_16bit_to_8bit.py:
--------------------------------------------------------------------------------
1 | import os
2 | from osgeo import gdal
3 | from tqdm import tqdm
4 |
5 |
6 | def convert_16bit_to_8bit(in_path, out_path, out_format='GTiff'):
7 | translate_options = gdal.TranslateOptions(format=out_format,
8 | outputType=gdal.GDT_Byte,
9 | scaleParams=[''],
10 | # scaleParams=[min_val, max_val],
11 | )
12 | gdal.Translate(destName=out_path, srcDS=in_path, options=translate_options)
13 |
14 |
15 | def main():
16 | src_dir = f"../data/AOI_2_Vegas_Train/RGB-PanSharpen/"
17 | dest_dir = f"../data/AOI_2_Vegas_Train/RGB_8bit/train/images"
18 | os.makedirs(dest_dir, exist_ok=True)
19 |
20 | src_ims = os.listdir(src_dir)
21 | for im in tqdm(src_ims):
22 | src_im = os.path.join(src_dir, im)
23 | dest_im = os.path.join(dest_dir, im)
24 | convert_16bit_to_8bit(src_im, dest_im)
25 |
26 |
27 | if __name__ == "__main__":
28 | main()
29 |
30 |
--------------------------------------------------------------------------------
/data_preprocess/spacenet_to_coco.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/SarahwXU/HiSup/blob/main/tools/inria_to_coco.py
2 | # Transform Spacenet 2 dataset (image and geojson pairs) to COCO format
3 | #
4 | # The first 15% images are kept as validation set
5 |
6 | import os
7 | import numpy as np
8 | from skimage import io
9 | import json
10 | from tqdm import tqdm
11 | from shapely.geometry import Polygon, mapping
12 | from shapely.ops import transform as poly_transform
13 | from shapely.ops import unary_union
14 | from shapely.geometry import box
15 | from skimage.measure import label as ski_label
16 | from skimage.measure import regionprops
17 | import cv2
18 | import math
19 | import shapely
20 |
21 |
22 | def clip_by_bound(poly, im_h, im_w):
23 | """
24 | Bound poly coordinates by image shape
25 | """
26 | p_x = poly[:, 0]
27 | p_y = poly[:, 1]
28 | p_x = np.clip(p_x, 0.0, im_w-1)
29 | p_y = np.clip(p_y, 0.0, im_h-1)
30 | return np.concatenate((p_x[:, np.newaxis], p_y[:, np.newaxis]), axis=1)
31 |
32 |
33 | def crop2patch(im_p, p_h, p_w, p_overlap):
34 | """
35 | Get coordinates of upper-left point for image patch
36 | return: patch_list [X_upper-left, Y_upper-left, patch_width, patch_height]
37 | """
38 | im_h, im_w, _ = im_p
39 | x = np.arange(0, im_w-p_w, p_w-p_overlap)
40 | x = np.append(x, im_w-p_w)
41 | y = np.arange(0, im_h-p_h, p_h-p_overlap)
42 | y = np.append(y, im_h-p_h)
43 | X, Y = np.meshgrid(x, y)
44 | patch_list = [[i, j, p_w, p_h] for i, j in zip(X.flatten(), Y.flatten())]
45 | return patch_list
46 |
47 |
48 | def rotate_image(image, angle):
49 | """
50 | Rotates an OpenCV 2 / NumPy image about its centre by the given angle
51 | (in degrees). The returned image will be large enough to hold the entire
52 | new image, with a black background
53 | """
54 |
55 | # Get the image size
56 | # No, that's not an error - NumPy stores image matrices backwards
57 | image_size = (image.shape[1], image.shape[0])
58 | image_center = tuple(np.array(image_size) / 2)
59 |
60 | # Convert the OpenCV 3x2 rotation matrix to 3x3
61 | rot_mat = np.vstack(
62 | [cv2.getRotationMatrix2D(image_center, angle, 1.0), [0, 0, 1]]
63 | )
64 |
65 | rot_mat_notranslate = np.matrix(rot_mat[0:2, 0:2])
66 |
67 | # Shorthand for below calcs
68 | image_w2 = image_size[0] * 0.5
69 | image_h2 = image_size[1] * 0.5
70 |
71 | # Obtain the rotated coordinates of the image corners
72 | rotated_coords = [
73 | (np.array([-image_w2, image_h2]) * rot_mat_notranslate).A[0],
74 | (np.array([ image_w2, image_h2]) * rot_mat_notranslate).A[0],
75 | (np.array([-image_w2, -image_h2]) * rot_mat_notranslate).A[0],
76 | (np.array([ image_w2, -image_h2]) * rot_mat_notranslate).A[0]
77 | ]
78 |
79 | # Find the size of the new image
80 | x_coords = [pt[0] for pt in rotated_coords]
81 | x_pos = [x for x in x_coords if x > 0]
82 | x_neg = [x for x in x_coords if x < 0]
83 |
84 | y_coords = [pt[1] for pt in rotated_coords]
85 | y_pos = [y for y in y_coords if y > 0]
86 | y_neg = [y for y in y_coords if y < 0]
87 |
88 | right_bound = max(x_pos)
89 | left_bound = min(x_neg)
90 | top_bound = max(y_pos)
91 | bot_bound = min(y_neg)
92 |
93 | new_w = int(abs(right_bound - left_bound))
94 | new_h = int(abs(top_bound - bot_bound))
95 |
96 | # We require a translation matrix to keep the image centred
97 | trans_mat = np.matrix([
98 | [1, 0, int(new_w * 0.5 - image_w2)],
99 | [0, 1, int(new_h * 0.5 - image_h2)],
100 | [0, 0, 1]
101 | ])
102 |
103 | # Compute the transform for the combined rotation and translation
104 | affine_mat = (np.matrix(trans_mat) * np.matrix(rot_mat))[0:2, :]
105 |
106 | # Apply the transform
107 | result = cv2.warpAffine(
108 | image,
109 | affine_mat,
110 | (new_w, new_h),
111 | flags=cv2.INTER_LINEAR
112 | )
113 |
114 | return result
115 |
116 |
117 | def largest_rotated_rect(w, h, angle):
118 | """
119 | Given a rectangle of size wxh that has been rotated by 'angle' (in
120 | radians), computes the width and height of the largest possible
121 | axis-aligned rectangle within the rotated rectangle.
122 |
123 | Original JS code by 'Andri' and Magnus Hoff from Stack Overflow
124 |
125 | Converted to Python by Aaron Snoswell
126 | """
127 |
128 | quadrant = int(math.floor(angle / (math.pi / 2))) & 3
129 | sign_alpha = angle if ((quadrant & 1) == 0) else math.pi - angle
130 | alpha = (sign_alpha % math.pi + math.pi) % math.pi
131 |
132 | bb_w = w * math.cos(alpha) + h * math.sin(alpha)
133 | bb_h = w * math.sin(alpha) + h * math.cos(alpha)
134 |
135 | gamma = math.atan2(bb_w, bb_w) if (w < h) else math.atan2(bb_w, bb_w)
136 |
137 | delta = math.pi - alpha - gamma
138 |
139 | length = h if (w < h) else w
140 |
141 | d = length * math.cos(alpha)
142 | a = d * math.sin(alpha) / math.sin(delta)
143 |
144 | y = a * math.cos(gamma)
145 | x = y * math.tan(gamma)
146 |
147 | return (
148 | bb_w - 2 * x,
149 | bb_h - 2 * y
150 | )
151 |
152 |
153 | def crop_around_center(image, width, height):
154 | """
155 | Given a NumPy / OpenCV 2 image, crops it to the given width and height,
156 | around its centre point
157 | """
158 |
159 | image_size = (image.shape[1], image.shape[0])
160 | image_center = (int(image_size[0] * 0.5), int(image_size[1] * 0.5))
161 |
162 | if(width > image_size[0]):
163 | width = image_size[0]
164 |
165 | if(height > image_size[1]):
166 | height = image_size[1]
167 |
168 | x1 = int(image_center[0] - width * 0.5)
169 | x2 = int(image_center[0] + width * 0.5)
170 | y1 = int(image_center[1] - height * 0.5)
171 | y2 = int(image_center[1] + height * 0.5)
172 |
173 | return image[y1:y2, x1:x2]
174 |
175 |
176 | def rotate_crop(im, gt, crop_size, angle):
177 | h, w = im.shape[0:2]
178 | im_rotated = rotate_image(im, angle)
179 | gt_rotated = rotate_image(gt, angle)
180 | if largest_rotated_rect(w, h, math.radians(angle))[0] >= crop_size:
181 | im_cropped = crop_around_center(im_rotated, crop_size, crop_size)
182 | gt_cropped = crop_around_center(gt_rotated, crop_size, crop_size)
183 | else:
184 | # print('error')
185 | im_cropped = crop_around_center(im, crop_size, crop_size)
186 | gt_cropped = crop_around_center(gt, crop_size, crop_size)
187 | return im_cropped, gt_cropped
188 |
189 |
190 | def lt_crop(im, gt, crop_size):
191 | im_cropped = im[0:crop_size, 0:crop_size, :]
192 | gt_cropped = gt[0:crop_size, 0:crop_size]
193 | return im_cropped, gt_cropped
194 |
195 |
196 | # for polygon vflip
197 | def reflection():
198 | return lambda x, y: (x, -y)
199 |
200 |
201 | if __name__ == '__main__':
202 | input_image_path = f"../data/AOI_2_Vegas_Train/RGB_8bit/train/images"
203 | input_annos_path = f"../data/AOI_2_Vegas_Train/pixel_geojson"
204 |
205 | save_path = f"../data/spacenet_coco/"
206 |
207 | all_images = os.listdir(input_image_path)
208 | val_count = int(0.15 * len(all_images))
209 | print(f"No. of val images: {val_count}")
210 | val_images = all_images[0:val_count]
211 | train_images = all_images[val_count:]
212 |
213 | train_set = set(train_images)
214 | val_set = set(val_images)
215 | if len(train_set.intersection(val_set)) > 0 or len(val_set.intersection(train_set)):
216 | raise RuntimeError()
217 |
218 | split = 'train'
219 |
220 | if split == 'train':
221 | query_images = train_images
222 | elif split == 'val':
223 | query_images = val_images
224 | else:
225 | raise Exception(f'"{split}" is an incorrect split choice. Split choice must be either "train" or "val".')
226 |
227 | output_im_train = os.path.join(save_path, split, 'images')
228 | if not os.path.exists(output_im_train):
229 | os.makedirs(output_im_train)
230 |
231 | # patch_width = 725
232 | # patch_height = 725
233 | # patch_overlap = 300
234 | # patch_size = 512
235 | # rotation_list = [22.5, 45, 67.5]
236 |
237 | patch_width = 224
238 | patch_height = 224
239 | patch_overlap = 34 # ~15% of patch size
240 | patch_size = 224
241 | rotation_list = []
242 |
243 | # main dict for annotation file
244 | output_data_train = {
245 | 'info': {'district': 'SpaceNetv2', 'description': 'building footprints', 'contributor': 'cyens'},
246 | 'categories': [{'id': 100, 'name': 'building'}],
247 | 'images': [],
248 | 'annotations': [],
249 | }
250 |
251 | train_ob_id = 0
252 | train_im_id = 0
253 | # read in data with npy format
254 | input_label = os.listdir(input_annos_path)
255 | for g_id, label in enumerate(tqdm(input_label)):
256 | # read data
257 | # label_info = [''.join(list(g)) for k, g in groupby(label, key=lambda x: x.isdigit())]
258 | label_info = label.split('_')
259 |
260 | label_name = label_info[-1].split('.')[0]
261 | im_name = [im for im in all_images if label_name+".tif" in im]
262 | assert len(im_name) == 1
263 | im_name = im_name[0]
264 | image_data = io.imread(os.path.join(input_image_path, im_name))
265 | with open(os.path.join(input_annos_path, label), 'r') as f:
266 | anno_data = json.load(f)
267 | im_h, im_w, _ = image_data.shape
268 |
269 | tile_polygons = []
270 | for poly in anno_data['features']:
271 | poly = poly['geometry']['coordinates']
272 | assert len(poly) == 1
273 | poly = np.array(poly[0])
274 | poly = Polygon(poly)
275 | poly = poly_transform(reflection(), poly)
276 | tile_polygons.append(poly)
277 | tile_polygons = shapely.geometry.MultiPolygon(tile_polygons)
278 | tile_polygons = unary_union(tile_polygons)
279 | # tile_polygons = poly_transform(reflection(), tile_polygons)
280 |
281 | if im_name in query_images:
282 | # for training/val set, split image to 224x224
283 | patch_list = crop2patch(image_data.shape, patch_width, patch_height, patch_overlap)
284 | for pid, pa in enumerate(patch_list):
285 | x_ul, y_ul, pw, ph = pa
286 | # bbox_s = box(y_ul, y_ul+patch_height, x_ul, x_ul+patch_width)
287 | bbox_s = box(x_ul, y_ul, x_ul+patch_width, y_ul+patch_height)
288 |
289 | p_gt = tile_polygons.intersection(bbox_s)
290 | # print(type(p_gt))
291 | if isinstance(p_gt, Polygon):
292 | p_gt = shapely.geometry.MultiPolygon([p_gt])
293 | else:
294 | p_gt = shapely.geometry.MultiPolygon(p_gt)
295 | p_im = image_data[y_ul:y_ul+patch_height, x_ul:x_ul+patch_width, :]
296 | p_gts = []
297 | p_ims = []
298 | p_im_rd, _ = lt_crop(p_im, p_im[0], patch_size)
299 | p_gts.append(p_gt)
300 | p_ims.append(p_im_rd)
301 | # for angle in rotation_list:
302 | # rot_im, _ = rotate_crop(p_im, p_im, patch_size, angle)
303 | # # p_gts.append(rot_gt)
304 | # p_ims.append(rot_im)
305 | for p_im, p_gt in zip(p_ims, p_gts):
306 | if len(p_gt.geoms) > 0:
307 | p_polygons = p_gt.geoms
308 | for poly in p_polygons:
309 | # poly = poly['geometry']['coordinates']
310 | # assert len(poly) == 1
311 | poly = np.asarray(poly.exterior.coords)
312 | poly -= np.array([x_ul, y_ul])
313 | poly = Polygon(poly)
314 | # poly = poly_transform(reflection(), poly)
315 | p_area = round(poly.area, 2)
316 | if p_area > 0:
317 | p_bbox = [poly.bounds[0], poly.bounds[1], poly.bounds[2]-poly.bounds[0], poly.bounds[3]-poly.bounds[1]]
318 | if p_bbox[2] > 0 and p_bbox[3] > 0:
319 | p_seg = []
320 | coor_list = mapping(poly)['coordinates']
321 | assert len(coor_list) == 1
322 | # import code; code.interact(local=locals())
323 | for part_poly in coor_list:
324 | p_seg.append(np.asarray(part_poly).ravel().tolist())
325 | anno_info = {
326 | 'id': train_ob_id,
327 | 'image_id': train_im_id,
328 | 'segmentation': p_seg,
329 | 'area': p_area,
330 | 'bbox': p_bbox,
331 | 'category_id': 100,
332 | 'iscrowd': 0
333 | }
334 | output_data_train['annotations'].append(anno_info)
335 | train_ob_id += 1
336 | else: # for including negative samples.
337 | anno_info = {
338 | 'id': train_ob_id,
339 | 'image_id': train_im_id,
340 | 'segmentation': [],
341 | 'area': 0.,
342 | 'bbox': [],
343 | 'category_id': 100,
344 | 'iscrowd': 1
345 | }
346 | output_data_train['annotations'].append(anno_info)
347 | train_ob_id += 1
348 | # get patch info
349 | p_name = label_name + '-' + str(train_im_id) + '.tif'
350 | patch_info = {'id': train_im_id, 'file_name': p_name, 'width': patch_size, 'height': patch_size}
351 | output_data_train['images'].append(patch_info)
352 | # save patch image
353 | io.imsave(os.path.join(output_im_train, p_name), p_im)
354 | train_im_id += 1
355 |
356 | if not os.path.exists(os.path.join(save_path, split)):
357 | os.makedirs(os.path.join(save_path, split))
358 | with open(os.path.join(save_path, split, 'annotation.json'), 'w') as f_json:
359 | json.dump(output_data_train, f_json)
360 |
--------------------------------------------------------------------------------
/data_preprocess/spacenet_world_to_pixel_coords.py:
--------------------------------------------------------------------------------
1 | import code
2 | import os
3 | import json
4 | from tqdm import tqdm
5 | from osgeo import gdal # use gdal env for this script.
6 |
7 |
8 | def main():
9 | ## VEGAS SUBSET
10 | # NOTE: Set path to spacenet dataset images and geojson annotations here.
11 | spacenet_dataset_root = f"../data/AOI_2_Vegas_Train/"
12 | geoimages_dir = os.path.join(spacenet_dataset_root, 'RGB_8bit', 'train', 'images')
13 | shapefiles_dir = os.path.join(spacenet_dataset_root, 'geojson', 'buildings')
14 |
15 | save_shapefiles_dir = os.path.join(spacenet_dataset_root, 'pixel_geojson')
16 | os.makedirs(save_shapefiles_dir, exist_ok=True)
17 |
18 | geoimages = os.listdir(geoimages_dir)
19 | shapefiles = os.listdir(shapefiles_dir)
20 |
21 | for i in tqdm(range(len(geoimages))):
22 | geo_im = geoimages[i]
23 | im_desc = geo_im.split('_')[-1].split('.')[0]
24 | shp = [sh for sh in shapefiles if f"{im_desc}.geojson" in sh]
25 | assert len(shp) == 1
26 | shp = shp[0]
27 |
28 | driver = gdal.GetDriverByName('GTiff')
29 | dataset = gdal.Open(os.path.join(geoimages_dir, geo_im))
30 | band = dataset.GetRasterBand(1)
31 | cols = dataset.RasterXSize
32 | rows = dataset.RasterYSize
33 | transform = dataset.GetGeoTransform()
34 | xOrigin = transform[0]
35 | yOrigin = transform[3]
36 | pixelWidth = transform[1]
37 | pixelHeight = -transform[5]
38 | data = band.ReadAsArray(0, 0, cols, rows)
39 |
40 | with open(os.path.join(shapefiles_dir, shp), 'r') as f:
41 | geo_shp = json.load(f)
42 |
43 | pixel_shp = {}
44 | pixel_shp['type'] = geo_shp['type']
45 | pixel_shp['features'] = []
46 |
47 | for feature in geo_shp['features']:
48 | out_feat = {
49 | 'type': 'Feature',
50 | 'geometry': {
51 | 'type': 'Polygon',
52 | 'coordinates': []
53 | }
54 | }
55 | if feature['geometry']['type'] == "Polygon":
56 | geo_coords = feature['geometry']['coordinates'][0]
57 | points_list = [(gc[0], gc[1]) for gc in geo_coords]
58 | coords_list = []
59 | for point in points_list:
60 | col = (point[0] - xOrigin) / pixelWidth
61 | col = col if col < cols else cols
62 | row = (yOrigin - point[1]) / pixelHeight
63 | row = row if row < rows else rows
64 | # 'row' has negative sign to be compatible with qgis visualization. Must be removed for compatibility in image space.
65 | coords_list.append([col, -row])
66 | out_feat['geometry']['coordinates'].append(coords_list)
67 | pixel_shp['features'].append(out_feat)
68 |
69 | with open(os.path.join(save_shapefiles_dir, shp), 'w') as o:
70 | json.dump(pixel_shp, o)
71 |
72 |
73 | if __name__ == "__main__":
74 | main()
75 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/dataset_inria_coco.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import os
4 | from os import path as osp
5 | from pycocotools.coco import COCO
6 | from config import CFG
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torch.nn.utils.rnn import pad_sequence
11 |
12 |
13 | class InriaCocoDataset(Dataset):
14 | def __init__(self, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False):
15 | image_dir = osp.join(dataset_dir, "images")
16 | self.image_dir = image_dir
17 | self.annotations_path = osp.join(dataset_dir, "annotation.json")
18 | self.transform = transform
19 | self.tokenizer = tokenizer
20 | self.shuffle_tokens = shuffle_tokens
21 | # self.images = os.listdir(self.image_dir)
22 | self.coco = COCO(self.annotations_path)
23 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
24 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
25 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images if im.split('-')[0] not in ['kitsap4', 'kitsap5']]
26 |
27 | def __len__(self):
28 | return len(self.image_ids)
29 |
30 | def annToMask(self):
31 | return
32 |
33 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray):
34 | Nv = old_perm.shape[0]
35 | padd_idxs = np.arange(len(shuffle_idxs), Nv)
36 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0)
37 |
38 | transform_arr = torch.zeros_like(old_perm)
39 | for i in range(len(shuffle_idxs)):
40 | transform_arr[i, shuffle_idxs[i]] = 1.
41 |
42 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices
43 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T)
44 |
45 | return new_perm
46 |
47 | def __getitem__(self, index):
48 | n_vertices = CFG.N_VERTICES
49 | img_id = self.image_ids[index]
50 | img = self.coco.loadImgs(img_id)[0]
51 | img_path = osp.join(self.image_dir, img["file_name"])
52 | ann_ids = self.coco.getAnnIds(imgIds=img['id'])
53 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image.
54 |
55 | image = np.array(Image.open(img_path).convert("RGB"))
56 |
57 | mask = np.zeros((img['width'], img['height']))
58 | corner_coords = []
59 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32)
60 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32)
61 | for ins in annotations:
62 | segmentations = ins['segmentation']
63 | for i, segm in enumerate(segmentations):
64 | segm = np.array(segm).reshape(-1, 2)
65 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1)
66 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1)
67 | points = segm[:-1]
68 | corner_coords.extend(points.tolist())
69 | mask += self.coco.annToMask(ins)
70 | mask = mask / 255. if mask.max() == 255 else mask
71 | mask = np.clip(mask, 0, 1)
72 |
73 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32)
74 |
75 | if len(corner_coords) > 0.:
76 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1.
77 |
78 | ############# START: Generate gt permutation matrix. #############
79 | v_count = 0
80 | for ins in annotations:
81 | segmentations = ins['segmentation']
82 | for idx, segm in enumerate(segmentations):
83 | segm = np.array(segm).reshape(-1, 2)
84 | points = segm[:-1]
85 | for i in range(len(points)):
86 | j = (i + 1) % len(points)
87 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1:
88 | break
89 | perm_matrix[v_count+i, v_count+j] = 1.
90 | v_count += len(points)
91 |
92 | for i in range(v_count, n_vertices):
93 | perm_matrix[i, i] = 1.
94 |
95 | # Workaround for open contours:
96 | for i in range(n_vertices):
97 | row = perm_matrix[i, :]
98 | col = perm_matrix[:, i]
99 | if np.sum(row) == 0 or np.sum(col) == 0:
100 | perm_matrix[i, i] = 1.
101 | perm_matrix = torch.from_numpy(perm_matrix)
102 | ############# END: Generate gt permutation matrix. #############
103 |
104 | masks = [mask, corner_mask]
105 |
106 | if len(corner_coords) > CFG.N_VERTICES:
107 | corner_coords = corner_coords[:CFG.N_VERTICES]
108 |
109 | if self.transform is not None:
110 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist())
111 | image = augmentations['image']
112 | mask = augmentations['masks'][0]
113 | corner_mask = augmentations['masks'][1]
114 | corner_coords = np.array(augmentations['keypoints'])
115 |
116 | if self.tokenizer is not None:
117 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens)
118 | coords_seqs = torch.LongTensor(coords_seqs)
119 | if self.shuffle_tokens:
120 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs)
121 | else:
122 | coords_seqs = corner_coords
123 |
124 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix
125 |
126 |
127 | class InriaCocoDataset_val(Dataset):
128 | def __init__(self, cfg, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False):
129 | self.CFG = cfg
130 | image_dir = osp.join(dataset_dir, "images")
131 | self.image_dir = image_dir
132 | self.annotations_path = osp.join(dataset_dir, "annotation.json")
133 | self.transform = transform
134 | self.tokenizer = tokenizer
135 | self.shuffle_tokens = shuffle_tokens
136 | # self.images = os.listdir(self.image_dir)
137 | self.coco = COCO(self.annotations_path)
138 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
139 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
140 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images]
141 |
142 | def __len__(self):
143 | return len(self.image_ids)
144 |
145 | def annToMask(self):
146 | return
147 |
148 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray):
149 | Nv = old_perm.shape[0]
150 | padd_idxs = np.arange(len(shuffle_idxs), Nv)
151 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0)
152 |
153 | transform_arr = torch.zeros_like(old_perm)
154 | for i in range(len(shuffle_idxs)):
155 | transform_arr[i, shuffle_idxs[i]] = 1.
156 |
157 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices
158 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T)
159 |
160 | return new_perm
161 |
162 | def __getitem__(self, index):
163 | n_vertices = self.CFG.N_VERTICES
164 | img_id = self.image_ids[index]
165 | img = self.coco.loadImgs(img_id)[0]
166 | img_path = osp.join(self.image_dir, img["file_name"])
167 | ann_ids = self.coco.getAnnIds(imgIds=img['id'])
168 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image.
169 |
170 | image = np.array(Image.open(img_path).convert("RGB"))
171 |
172 | mask = np.zeros((img['width'], img['height']))
173 | corner_coords = []
174 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32)
175 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32)
176 | for ins in annotations:
177 | segmentations = ins['segmentation']
178 | for i, segm in enumerate(segmentations):
179 | segm = np.array(segm).reshape(-1, 2)
180 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1)
181 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1)
182 | points = segm[:-1]
183 | corner_coords.extend(points.tolist())
184 | mask += self.coco.annToMask(ins)
185 | mask = mask / 255. if mask.max() == 255 else mask
186 | mask = np.clip(mask, 0, 1)
187 |
188 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32)
189 |
190 | if len(corner_coords) > 0.:
191 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1.
192 |
193 | ############# START: Generate gt permutation matrix. #############
194 | v_count = 0
195 | for ins in annotations:
196 | segmentations = ins['segmentation']
197 | for idx, segm in enumerate(segmentations):
198 | segm = np.array(segm).reshape(-1, 2)
199 | points = segm[:-1]
200 | for i in range(len(points)):
201 | j = (i + 1) % len(points)
202 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1:
203 | break
204 | perm_matrix[v_count+i, v_count+j] = 1.
205 | v_count += len(points)
206 |
207 | for i in range(v_count, n_vertices):
208 | perm_matrix[i, i] = 1.
209 |
210 | # Workaround for open contours:
211 | for i in range(n_vertices):
212 | row = perm_matrix[i, :]
213 | col = perm_matrix[:, i]
214 | if np.sum(row) == 0 or np.sum(col) == 0:
215 | perm_matrix[i, i] = 1.
216 | perm_matrix = torch.from_numpy(perm_matrix)
217 | ############# END: Generate gt permutation matrix. #############
218 |
219 | masks = [mask, corner_mask]
220 |
221 | if len(corner_coords) > self.CFG.N_VERTICES:
222 | corner_coords = corner_coords[:self.CFG.N_VERTICES]
223 |
224 | if self.transform is not None:
225 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist())
226 | image = augmentations['image']
227 | mask = augmentations['masks'][0]
228 | corner_mask = augmentations['masks'][1]
229 | corner_coords = np.array(augmentations['keypoints'])
230 |
231 | if self.tokenizer is not None:
232 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens)
233 | coords_seqs = torch.LongTensor(coords_seqs)
234 | if self.shuffle_tokens:
235 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs)
236 | else:
237 | coords_seqs = corner_coords
238 |
239 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix, torch.tensor([img['id']])
240 |
241 |
242 | def collate_fn(batch, max_len, pad_idx):
243 |     """
244 |     Pad each vertex sequence in the batch with `pad_idx`; if `max_len` is
245 |     given, pad all sequences to exactly that length.
246 |     """
247 |
248 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch = [], [], [], [], []
249 | for image, mask, c_mask, seq, perm_mat in batch:
250 | image_batch.append(image)
251 | mask_batch.append(mask)
252 | coords_mask_batch.append(c_mask)
253 | coords_seq_batch.append(seq)
254 | perm_matrix_batch.append(perm_mat)
255 |
256 | coords_seq_batch = pad_sequence(
257 | coords_seq_batch,
258 | padding_value=pad_idx,
259 | batch_first=True
260 | )
261 |
262 | if max_len:
263 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long()
264 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1)
265 |
266 | image_batch = torch.stack(image_batch)
267 | mask_batch = torch.stack(mask_batch)
268 | coords_mask_batch = torch.stack(coords_mask_batch)
269 | perm_matrix_batch = torch.stack(perm_matrix_batch)
270 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch
271 |
272 |
273 | class InriaCocoDatasetTest(Dataset):
274 | def __init__(self, image_dir, transform=None):
275 | self.image_dir = image_dir
276 | self.transform = transform
277 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
278 |
279 | def __getitem__(self, index):
280 | img_path = osp.join(self.image_dir, self.images[index])
281 | image = np.array(Image.open(img_path).convert("RGB"))
282 |
283 | if self.transform is not None:
284 | image = self.transform(image=image)['image']
285 |
286 | image = torch.FloatTensor(image)
287 | return image
288 |
289 | def __len__(self):
290 | return len(self.images)
291 |
--------------------------------------------------------------------------------
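The `shuffle_perm_matrix_by_indices` helpers in the dataset classes above re-index the ground-truth adjacency ("permutation") matrix with the standard `P @ A @ P.T` trick referenced in the linked math.stackexchange answer. Below is a minimal, self-contained sketch of that re-indexing; the 3-vertex contour and the shuffle order are made-up illustrative values, not outputs of the actual tokenizer.

```python
import numpy as np
import torch

# Toy adjacency matrix for a 3-vertex contour 0 -> 1 -> 2 -> 0, padded to
# 4 vertices, where the padded vertex keeps a self-connection.
old_perm = torch.tensor([
    [0., 1., 0., 0.],
    [0., 0., 1., 0.],
    [1., 0., 0., 0.],
    [0., 0., 0., 1.],
])

# Hypothetical shuffle order: old vertex 2 comes first in the new ordering,
# then 0, then 1.
shuffle_idxs = np.array([2, 0, 1])

# Pad the shuffle indices for the unused vertices, exactly as the helper does.
Nv = old_perm.shape[0]
shuffle_idxs = np.concatenate([shuffle_idxs, np.arange(len(shuffle_idxs), Nv)])

# Build the permutation matrix P with P[i, shuffle_idxs[i]] = 1 ...
transform_arr = torch.zeros_like(old_perm)
for i in range(len(shuffle_idxs)):
    transform_arr[i, shuffle_idxs[i]] = 1.

# ... and re-index the adjacency matrix: new_perm[i, j] = old_perm[s[i], s[j]].
new_perm = transform_arr @ old_perm @ transform_arr.T
print(new_perm)
```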
/datasets/dataset_spacenet_coco.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import os
4 | from os import path as osp
5 | from pycocotools.coco import COCO
6 | from config import CFG
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class SpacenetCocoDataset(Dataset):
13 | def __init__(self, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False):
14 | image_dir = osp.join(dataset_dir, "images")
15 | self.image_dir = image_dir
16 | self.annotations_path = osp.join(dataset_dir, "annotation.json")
17 | self.transform = transform
18 | self.tokenizer = tokenizer
19 | self.shuffle_tokens = shuffle_tokens
20 | # self.images = os.listdir(self.image_dir)
21 | self.coco = COCO(self.annotations_path)
22 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
23 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
24 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images]
25 |
26 | def __len__(self):
27 | return len(self.image_ids)
28 |
29 | def annToMask(self):
30 | return
31 |
32 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray):
33 | Nv = old_perm.shape[0]
34 | padd_idxs = np.arange(len(shuffle_idxs), Nv)
35 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0)
36 |
37 | transform_arr = torch.zeros_like(old_perm)
38 | for i in range(len(shuffle_idxs)):
39 | transform_arr[i, shuffle_idxs[i]] = 1.
40 |
41 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices
42 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T)
43 |
44 | return new_perm
45 |
46 | def __getitem__(self, index):
47 | n_vertices = CFG.N_VERTICES
48 | img_id = self.image_ids[index]
49 | img = self.coco.loadImgs(img_id)[0]
50 | img_path = osp.join(self.image_dir, img["file_name"])
51 | ann_ids = self.coco.getAnnIds(imgIds=img['id'])
52 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image.
53 |
54 | image = np.array(Image.open(img_path).convert("RGB"))
55 |
56 | mask = np.zeros((img['width'], img['height']))
57 | corner_coords = []
58 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32)
59 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32)
60 | for ins in annotations:
61 | segmentations = ins['segmentation']
62 | for i, segm in enumerate(segmentations):
63 | segm = np.array(segm).reshape(-1, 2)
64 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1)
65 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1)
66 | points = segm[:-1]
67 | corner_coords.extend(points.tolist())
68 | mask += self.coco.annToMask(ins)
69 | mask = mask / 255. if mask.max() == 255 else mask
70 | mask = np.clip(mask, 0, 1)
71 |
72 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32)
73 |
74 | if len(corner_coords) > 0.:
75 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1.
76 |
77 | ############# START: Generate gt permutation matrix. #############
78 | v_count = 0
79 | for ins in annotations:
80 | segmentations = ins['segmentation']
81 | for idx, segm in enumerate(segmentations):
82 | segm = np.array(segm).reshape(-1, 2)
83 | points = segm[:-1]
84 | for i in range(len(points)):
85 | j = (i + 1) % len(points)
86 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1:
87 | break
88 | perm_matrix[v_count+i, v_count+j] = 1.
89 | v_count += len(points)
90 |
91 | for i in range(v_count, n_vertices):
92 | perm_matrix[i, i] = 1.
93 |
94 | # Workaround for open contours:
95 | for i in range(n_vertices):
96 | row = perm_matrix[i, :]
97 | col = perm_matrix[:, i]
98 | if np.sum(row) == 0 or np.sum(col) == 0:
99 | perm_matrix[i, i] = 1.
100 | perm_matrix = torch.from_numpy(perm_matrix)
101 | ############# END: Generate gt permutation matrix. #############
102 |
103 | masks = [mask, corner_mask]
104 |
105 | if len(corner_coords) > CFG.N_VERTICES:
106 | corner_coords = corner_coords[:CFG.N_VERTICES]
107 |
108 | if self.transform is not None:
109 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist())
110 | image = augmentations['image']
111 | mask = augmentations['masks'][0]
112 | corner_mask = augmentations['masks'][1]
113 | corner_coords = np.array(augmentations['keypoints'])
114 |
115 | if self.tokenizer is not None:
116 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens)
117 | coords_seqs = torch.LongTensor(coords_seqs)
118 | if self.shuffle_tokens:
119 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs)
120 | else:
121 | coords_seqs = corner_coords
122 |
123 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix
124 |
125 |
126 | class SpacenetCocoDataset_val(Dataset):
127 | def __init__(self, cfg, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False):
128 | self.CFG = cfg
129 | image_dir = osp.join(dataset_dir, "images")
130 | self.image_dir = image_dir
131 | self.annotations_path = osp.join(dataset_dir, "annotation.json")
132 | self.transform = transform
133 | self.tokenizer = tokenizer
134 | self.shuffle_tokens = shuffle_tokens
135 | # self.images = os.listdir(self.image_dir)
136 | self.coco = COCO(self.annotations_path)
137 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
138 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
139 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images]
140 |
141 | def __len__(self):
142 | return len(self.image_ids)
143 |
144 | def annToMask(self):
145 | return
146 |
147 | def __getitem__(self, index):
148 | n_vertices = self.CFG.N_VERTICES
149 | img_id = self.image_ids[index]
150 | img = self.coco.loadImgs(img_id)[0]
151 | img_path = osp.join(self.image_dir, img["file_name"])
152 | ann_ids = self.coco.getAnnIds(imgIds=img['id'])
153 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image.
154 |
155 | image = np.array(Image.open(img_path).convert("RGB"))
156 |
157 | mask = np.zeros((img['width'], img['height']))
158 | corner_coords = []
159 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32)
160 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32)
161 | for ins in annotations:
162 | segmentations = ins['segmentation']
163 | for i, segm in enumerate(segmentations):
164 | segm = np.array(segm).reshape(-1, 2)
165 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1)
166 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1)
167 | points = segm[:-1]
168 | corner_coords.extend(points.tolist())
169 | mask += self.coco.annToMask(ins)
170 | mask = mask / 255. if mask.max() == 255 else mask
171 | mask = np.clip(mask, 0, 1)
172 |
173 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32)
174 |
175 | if len(corner_coords) > 0.:
176 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1.
177 |
178 | ############# START: Generate gt permutation matrix. #############
179 | v_count = 0
180 | for ins in annotations:
181 | segmentations = ins['segmentation']
182 | for idx, segm in enumerate(segmentations):
183 | segm = np.array(segm).reshape(-1, 2)
184 | points = segm[:-1]
185 | for i in range(len(points)):
186 | j = (i + 1) % len(points)
187 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1:
188 | break
189 | perm_matrix[v_count+i, v_count+j] = 1.
190 | v_count += len(points)
191 |
192 | for i in range(v_count, n_vertices):
193 | perm_matrix[i, i] = 1.
194 |
195 | # Workaround for open contours:
196 | for i in range(n_vertices):
197 | row = perm_matrix[i, :]
198 | col = perm_matrix[:, i]
199 | if np.sum(row) == 0 or np.sum(col) == 0:
200 | perm_matrix[i, i] = 1.
201 | perm_matrix = torch.from_numpy(perm_matrix)
202 | ############# END: Generate gt permutation matrix. #############
203 |
204 | masks = [mask, corner_mask]
205 |
206 | if len(corner_coords) > self.CFG.N_VERTICES:
207 | corner_coords = corner_coords[:self.CFG.N_VERTICES]
208 |
209 | if self.transform is not None:
210 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist())
211 | image = augmentations['image']
212 | mask = augmentations['masks'][0]
213 | corner_mask = augmentations['masks'][1]
214 | corner_coords = np.array(augmentations['keypoints'])
215 |
216 | if self.tokenizer is not None:
217 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens)
218 | coords_seqs = torch.LongTensor(coords_seqs)
219 | perm_matrix = torch.cat((perm_matrix[rand_idxs], perm_matrix[len(rand_idxs):]))
220 | else:
221 | coords_seqs = corner_coords
222 |
223 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix, torch.tensor([img['id']])
224 |
225 |
226 | class SpacenetCocoDatasetTest(Dataset):
227 | def __init__(self, image_dir, transform=None):
228 | self.image_dir = image_dir
229 | self.transform = transform
230 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
231 |
232 | def __getitem__(self, index):
233 | img_path = osp.join(self.image_dir, self.images[index])
234 | image = np.array(Image.open(img_path).convert("RGB"))
235 |
236 | if self.transform is not None:
237 | image = self.transform(image=image)['image']
238 |
239 | image = torch.FloatTensor(image)
240 | return image
241 |
242 | def __len__(self):
243 | return len(self.images)
244 |
--------------------------------------------------------------------------------
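The `__getitem__` methods above all convert COCO-style polygon segmentations into integer corner coordinates and a corner mask: drop the duplicated closing vertex, round, and flip `(x, y)` to `(row, col)` before indexing the mask. A minimal sketch of that conversion, with a made-up segmentation and image size:

```python
import numpy as np

# Hypothetical COCO-style segmentation: flat [x0, y0, x1, y1, ...] with the
# first vertex repeated at the end, as produced by the preprocessing scripts.
segm = [10.0, 20.0, 60.0, 20.0, 60.0, 70.0, 10.0, 20.0]
width = height = 100

pts = np.array(segm).reshape(-1, 2)           # (x, y) pairs
pts[:, 0] = np.clip(pts[:, 0], 0, width - 1)
pts[:, 1] = np.clip(pts[:, 1], 0, height - 1)
points = pts[:-1]                             # drop the duplicated closing vertex

# Same convention as __getitem__: round, then flip (x, y) -> (row, col).
corner_coords = np.flip(np.round(points, 0), axis=-1).astype(np.int32)

corner_mask = np.zeros((width, height), dtype=np.float32)
corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1.0
print(corner_coords)  # rows are (y, x) pairs: (20, 10), (20, 60), (70, 60)
```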
/datasets/dataset_whu_buildings_coco.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | from functools import partial
4 | import os
5 | from os import path as osp
6 | from pycocotools.coco import COCO
7 | from pycocotools import mask as cocomask
8 | from config import CFG
9 |
10 | import torch
11 | from torch.utils.data import Dataset
12 | import albumentations as A
13 | from torch.nn.utils.rnn import pad_sequence
14 |
15 |
16 | class WHUBuildingsCocoDataset(Dataset):
17 | def __init__(self, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False):
18 | image_dir = osp.join(dataset_dir, "images")
19 | self.image_dir = image_dir
20 | self.annotations_path = osp.join(dataset_dir, "annotation.json")
21 | self.transform = transform
22 | self.tokenizer = tokenizer
23 | self.shuffle_tokens = shuffle_tokens
24 | # self.images = os.listdir(self.image_dir)
25 | self.coco = COCO(self.annotations_path)
26 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
27 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
28 | # image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images]
29 | image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
30 | if "train" in dataset_dir:
31 | # remove images with more than 144 vertices from training.
32 | self.image_ids = [im for im in image_ids if int(im) not in [16608, 36020, 36021]]
33 | else:
34 | self.image_ids = image_ids
35 |
36 | def __len__(self):
37 | return len(self.image_ids)
38 |
39 | def annToMask(self):
40 | return
41 |
42 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray):
43 | Nv = old_perm.shape[0]
44 | padd_idxs = np.arange(len(shuffle_idxs), Nv)
45 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0)
46 |
47 | transform_arr = torch.zeros_like(old_perm)
48 | for i in range(len(shuffle_idxs)):
49 | transform_arr[i, shuffle_idxs[i]] = 1.
50 |
51 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices
52 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T)
53 | # new_perm = torch.zeros_like(old_perm)
54 |
55 | # # generate new perm matrix based on shuffling indices.
56 | # for i in range(len(shuffle_idxs)):
57 | # new_i = shuffle_idxs[i]
58 | # new_j = shuffle_idxs[old_perm[i].nonzero().item()]
59 | # new_perm[new_i, new_j] = 1.
60 |
61 | # # Add self connections of unconnected vertices.
62 | # for i in range(Nv):
63 | # row = new_perm[i, :]
64 | # col = new_perm[:, i]
65 | # if torch.sum(row) == 0 or torch.sum(col) == 0:
66 | # new_perm[i, i] = 1.
67 |
68 | return new_perm
69 |
70 | def __getitem__(self, index):
71 | n_vertices = CFG.N_VERTICES
72 | img_id = self.image_ids[index]
73 | img = self.coco.loadImgs(img_id)[0]
74 | img_path = osp.join(self.image_dir, img["file_name"])
75 | ann_ids = self.coco.getAnnIds(imgIds=img['id'])
76 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image.
77 |
78 | image = np.array(Image.open(img_path).convert("RGB"))
79 |
80 | mask = np.zeros((img['width'], img['height']))
81 | corner_coords = []
82 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32)
83 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32)
84 | for ins in annotations:
85 | segmentations = ins['segmentation']
86 | for i, segm in enumerate(segmentations):
87 | segm = np.array(segm).reshape(-1, 2)
88 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1)
89 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1)
90 | points = segm[:-1]
91 | corner_coords.extend(points.tolist())
92 | mask += self.coco.annToMask(ins)
93 | mask = mask / 255. if mask.max() == 255 else mask
94 | mask = np.clip(mask, 0, 1)
95 |
96 | # corner_coords = np.clip(np.array(corner_coords), 0, 299)
97 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32)
98 |
99 | if len(corner_coords) > 0.:
100 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1.
101 | # corner_coords = (corner_coords / img['width']) * CFG.INPUT_WIDTH
102 |
103 | ############# START: Generate gt permutation matrix. #############
104 | v_count = 0
105 | for ins in annotations:
106 | segmentations = ins['segmentation']
107 | for idx, segm in enumerate(segmentations):
108 | segm = np.array(segm).reshape(-1, 2)
109 | points = segm[:-1]
110 | for i in range(len(points)):
111 | j = (i + 1) % len(points)
112 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1:
113 | break
114 | perm_matrix[v_count+i, v_count+j] = 1.
115 | v_count += len(points)
116 |
117 | for i in range(v_count, n_vertices):
118 | perm_matrix[i, i] = 1.
119 |
120 | # Workaround for open contours:
121 | for i in range(n_vertices):
122 | row = perm_matrix[i, :]
123 | col = perm_matrix[:, i]
124 | if np.sum(row) == 0 or np.sum(col) == 0:
125 | perm_matrix[i, i] = 1.
126 | perm_matrix = torch.from_numpy(perm_matrix)
127 | ############# END: Generate gt permutation matrix. #############
128 |
129 | masks = [mask, corner_mask]
130 |
131 | if len(corner_coords) > CFG.N_VERTICES:
132 | corner_coords = corner_coords[:CFG.N_VERTICES]
133 |
134 | if self.transform is not None:
135 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist())
136 | image = augmentations['image']
137 | mask = augmentations['masks'][0]
138 | corner_mask = augmentations['masks'][1]
139 | corner_coords = np.array(augmentations['keypoints'])
140 |
141 | if self.tokenizer is not None:
142 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens)
143 | coords_seqs = torch.LongTensor(coords_seqs)
144 | # perm_matrix = torch.cat((perm_matrix[rand_idxs], perm_matrix[len(rand_idxs):]))
145 | if self.shuffle_tokens:
146 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs)
147 | else:
148 | coords_seqs = corner_coords
149 |
150 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix
151 |
152 |
153 | class WHUBuildingsCocoDataset_val(Dataset):
154 | def __init__(self, cfg, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False):
155 | self.CFG = cfg
156 | image_dir = osp.join(dataset_dir, "images")
157 | self.image_dir = image_dir
158 | self.annotations_path = osp.join(dataset_dir, "annotation.json")
159 | self.transform = transform
160 | self.tokenizer = tokenizer
161 | self.shuffle_tokens = shuffle_tokens
162 | # self.images = os.listdir(self.image_dir)
163 | self.coco = COCO(self.annotations_path)
164 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
165 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
166 | # self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images]
167 | image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds())
168 | if "train" in dataset_dir:
169 | # remove images with more than 144 vertices from training.
170 | self.image_ids = [im for im in image_ids if int(im) not in [16608, 36020, 36021]]
171 | else:
172 | self.image_ids = image_ids
173 |
174 | def __len__(self):
175 | return len(self.image_ids)
176 |
177 | def annToMask(self):
178 | return
179 |
180 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray):
181 | Nv = old_perm.shape[0]
182 | padd_idxs = np.arange(len(shuffle_idxs), Nv)
183 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0)
184 |
185 | transform_arr = torch.zeros_like(old_perm)
186 | for i in range(len(shuffle_idxs)):
187 | transform_arr[i, shuffle_idxs[i]] = 1.
188 |
189 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices
190 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T)
191 |
192 | return new_perm
193 |
194 | def __getitem__(self, index):
195 | n_vertices = self.CFG.N_VERTICES
196 | img_id = self.image_ids[index]
197 | img = self.coco.loadImgs(img_id)[0]
198 | img_path = osp.join(self.image_dir, img["file_name"])
199 | ann_ids = self.coco.getAnnIds(imgIds=img['id'])
200 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image.
201 |
202 | image = np.array(Image.open(img_path).convert("RGB"))
203 |
204 | mask = np.zeros((img['width'], img['height']))
205 | corner_coords = []
206 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32)
207 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32)
208 | for ins in annotations:
209 | segmentations = ins['segmentation']
210 | for i, segm in enumerate(segmentations):
211 | segm = np.array(segm).reshape(-1, 2)
212 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1)
213 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1)
214 | points = segm[:-1]
215 | corner_coords.extend(points.tolist())
216 | mask += self.coco.annToMask(ins)
217 | mask = mask / 255. if mask.max() == 255 else mask
218 | mask = np.clip(mask, 0, 1)
219 |
220 | # corner_coords = np.clip(np.array(corner_coords), 0, 299)
221 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32)
222 |
223 | if len(corner_coords) > 0.:
224 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1.
225 | # corner_coords = (corner_coords / img['width']) * CFG.INPUT_WIDTH
226 |
227 | ############# START: Generate gt permutation matrix. #############
228 | v_count = 0
229 | for ins in annotations:
230 | segmentations = ins['segmentation']
231 | for idx, segm in enumerate(segmentations):
232 | segm = np.array(segm).reshape(-1, 2)
233 | points = segm[:-1]
234 | for i in range(len(points)):
235 | j = (i + 1) % len(points)
236 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1:
237 | break
238 | perm_matrix[v_count+i, v_count+j] = 1.
239 | v_count += len(points)
240 |
241 | for i in range(v_count, n_vertices):
242 | perm_matrix[i, i] = 1.
243 |
244 | # Workaround for open contours:
245 | for i in range(n_vertices):
246 | row = perm_matrix[i, :]
247 | col = perm_matrix[:, i]
248 | if np.sum(row) == 0 or np.sum(col) == 0:
249 | perm_matrix[i, i] = 1.
250 | perm_matrix = torch.from_numpy(perm_matrix)
251 | ############# END: Generate gt permutation matrix. #############
252 |
253 | masks = [mask, corner_mask]
254 |
255 | if len(corner_coords) > self.CFG.N_VERTICES:
256 | corner_coords = corner_coords[:self.CFG.N_VERTICES]
257 |
258 | if self.transform is not None:
259 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist())
260 | image = augmentations['image']
261 | mask = augmentations['masks'][0]
262 | corner_mask = augmentations['masks'][1]
263 | corner_coords = np.array(augmentations['keypoints'])
264 |
265 | if self.tokenizer is not None:
266 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens)
267 | coords_seqs = torch.LongTensor(coords_seqs)
268 | # perm_matrix = torch.cat((perm_matrix[rand_idxs], perm_matrix[len(rand_idxs):]))
269 | if self.shuffle_tokens:
270 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs)
271 | else:
272 | coords_seqs = corner_coords
273 |
274 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix, torch.tensor([img['id']])
275 |
276 |
277 | def whu_buildings_collate_fn(batch, max_len, pad_idx):
278 |     """
279 |     Pad each vertex sequence in the batch with `pad_idx`; if `max_len` is
280 |     given, pad all sequences to exactly that length.
281 |     """
282 |
283 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch = [], [], [], [], []
284 | for image, mask, c_mask, seq, perm_mat in batch:
285 | image_batch.append(image)
286 | mask_batch.append(mask)
287 | coords_mask_batch.append(c_mask)
288 | coords_seq_batch.append(seq)
289 | perm_matrix_batch.append(perm_mat)
290 |
291 | coords_seq_batch = pad_sequence(
292 | coords_seq_batch,
293 | padding_value=pad_idx,
294 | batch_first=True
295 | )
296 |
297 | if max_len:
298 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long()
299 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1)
300 |
301 | image_batch = torch.stack(image_batch)
302 | mask_batch = torch.stack(mask_batch)
303 | coords_mask_batch = torch.stack(coords_mask_batch)
304 | perm_matrix_batch = torch.stack(perm_matrix_batch)
305 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch
306 |
307 |
308 | class WHUBuildingsCocoDatasetTest(Dataset):
309 | def __init__(self, image_dir, transform=None):
310 | self.image_dir = image_dir
311 | self.transform = transform
312 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))]
313 |
314 | def __getitem__(self, index):
315 | img_path = osp.join(self.image_dir, self.images[index])
316 | image = np.array(Image.open(img_path).convert("RGB"))
317 |
318 | if self.transform is not None:
319 | image = self.transform(image=image)['image']
320 |
321 | image = torch.FloatTensor(image)
322 | return image
323 |
324 | def __len__(self):
325 | return len(self.images)
326 |
--------------------------------------------------------------------------------
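The ground-truth permutation matrix built in the `__getitem__` methods above encodes, for each vertex, the index of the next vertex along its contour; padded slots and open contours fall back to a self-connection. The following sketch reproduces that loop for two toy polygons (vertex counts and `n_vertices` are illustrative):

```python
import numpy as np

# Two toy polygons with 3 and 4 vertices, padded to n_vertices = 8.
n_vertices = 8
polygons = [np.zeros((3, 2)), np.zeros((4, 2))]  # only the vertex counts matter here

perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32)
v_count = 0
for points in polygons:
    for i in range(len(points)):
        j = (i + 1) % len(points)  # next vertex along the contour (cyclic)
        if v_count + i > n_vertices - 1 or v_count + j > n_vertices - 1:
            break
        perm_matrix[v_count + i, v_count + j] = 1.0
    v_count += len(points)

# Padded (unused) vertices point to themselves.
for i in range(v_count, n_vertices):
    perm_matrix[i, i] = 1.0
print(perm_matrix.astype(int))
```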
/ddp_utils.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from torch import distributed as dist
3 | from torch.utils.data import DataLoader
4 | from torch.utils.data.distributed import DistributedSampler
5 | from torch.nn.utils.rnn import pad_sequence
6 | import torch
7 |
8 | from datasets.dataset_inria_coco import InriaCocoDataset, InriaCocoDatasetTest
9 | from datasets.dataset_spacenet_coco import SpacenetCocoDataset, SpacenetCocoDatasetTest
10 | from datasets.dataset_whu_buildings_coco import WHUBuildingsCocoDataset, WHUBuildingsCocoDatasetTest
11 | from datasets.dataset_mass_roads import MassRoadsDataset, MassRoadsDatasetTest
12 |
13 |
14 | def is_dist_avail_and_initialized():
15 | if not dist.is_available():
16 | return False
17 | if not dist.is_initialized():
18 | return False
19 | return True
20 |
21 |
22 | def get_rank():
23 | if not is_dist_avail_and_initialized():
24 | return 0
25 | return dist.get_rank()
26 |
27 |
28 | def is_main_process():
29 | return get_rank() == 0
30 |
31 |
32 | def collate_fn(batch, max_len, pad_idx):
33 |     """
34 |     Pad each vertex sequence in the batch with `pad_idx`; if `max_len` is
35 |     given, pad all sequences to exactly that length.
36 |     """
37 |
38 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch = [], [], [], [], []
39 | for image, mask, c_mask, seq, perm_mat in batch:
40 | image_batch.append(image)
41 | mask_batch.append(mask)
42 | coords_mask_batch.append(c_mask)
43 | coords_seq_batch.append(seq)
44 | perm_matrix_batch.append(perm_mat)
45 |
46 | coords_seq_batch = pad_sequence(
47 | coords_seq_batch,
48 | padding_value=pad_idx,
49 | batch_first=True
50 | )
51 |
52 | if max_len:
53 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long()
54 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1)
55 |
56 | image_batch = torch.stack(image_batch)
57 | mask_batch = torch.stack(mask_batch)
58 | coords_mask_batch = torch.stack(coords_mask_batch)
59 | perm_matrix_batch = torch.stack(perm_matrix_batch)
60 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch
61 |
62 |
63 | def get_inria_loaders(
64 | train_dataset_dir,
65 | val_dataset_dir,
66 | test_images_dir,
67 | tokenizer,
68 | max_len,
69 | pad_idx,
70 | shuffle_tokens,
71 | batch_size,
72 | train_transform,
73 | val_transform,
74 | num_workers=2,
75 | pin_memory=True
76 | ):
77 |
78 | train_ds = InriaCocoDataset(
79 | dataset_dir=train_dataset_dir,
80 | transform=train_transform,
81 | tokenizer=tokenizer,
82 | shuffle_tokens=shuffle_tokens
83 | )
84 |
85 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True)
86 |
87 | train_loader = DataLoader(
88 | train_ds,
89 | batch_size=batch_size,
90 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
91 | sampler=train_sampler,
92 | num_workers=num_workers,
93 | pin_memory=pin_memory,
94 | drop_last=True
95 | )
96 |
97 | valid_ds = InriaCocoDataset(
98 | dataset_dir=val_dataset_dir,
99 | transform=val_transform,
100 | tokenizer=tokenizer,
101 | shuffle_tokens=shuffle_tokens
102 | )
103 |
104 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False)
105 |
106 | valid_loader = DataLoader(
107 | valid_ds,
108 | batch_size=batch_size,
109 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
110 | sampler=valid_sampler,
111 | num_workers=0,
112 | pin_memory=True,
113 | )
114 |
115 | test_ds = InriaCocoDatasetTest(
116 | image_dir=test_images_dir,
117 | transform=val_transform
118 | )
119 |
120 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False)
121 |
122 | test_loader = DataLoader(
123 | test_ds,
124 | batch_size=batch_size,
125 | sampler=test_sampler,
126 | num_workers=num_workers,
127 | pin_memory=pin_memory,
128 | )
129 |
130 | return train_loader, valid_loader, test_loader
131 |
132 |
133 | def get_spacenet_loaders(
134 | train_dataset_dir,
135 | val_dataset_dir,
136 | test_images_dir,
137 | tokenizer,
138 | max_len,
139 | pad_idx,
140 | shuffle_tokens,
141 | batch_size,
142 | train_transform,
143 | val_transform,
144 | num_workers=2,
145 | pin_memory=True
146 | ):
147 |
148 | train_ds = SpacenetCocoDataset(
149 | dataset_dir=train_dataset_dir,
150 | transform=train_transform,
151 | tokenizer=tokenizer,
152 | shuffle_tokens=shuffle_tokens
153 | )
154 |
155 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True)
156 |
157 | train_loader = DataLoader(
158 | train_ds,
159 | batch_size=batch_size,
160 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
161 | sampler=train_sampler,
162 | num_workers=num_workers,
163 | pin_memory=pin_memory,
164 | drop_last=True
165 | )
166 |
167 | valid_ds = SpacenetCocoDataset(
168 | dataset_dir=val_dataset_dir,
169 | transform=val_transform,
170 | tokenizer=tokenizer,
171 | shuffle_tokens=shuffle_tokens
172 | )
173 |
174 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False)
175 |
176 | valid_loader = DataLoader(
177 | valid_ds,
178 | batch_size=batch_size,
179 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
180 | sampler=valid_sampler,
181 | num_workers=0,
182 | pin_memory=True,
183 | )
184 |
185 | test_ds = SpacenetCocoDatasetTest(
186 | image_dir=test_images_dir,
187 | transform=val_transform
188 | )
189 |
190 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False)
191 |
192 | test_loader = DataLoader(
193 | test_ds,
194 | batch_size=batch_size,
195 | sampler=test_sampler,
196 | num_workers=num_workers,
197 | pin_memory=pin_memory,
198 | )
199 |
200 | return train_loader, valid_loader, test_loader
201 |
202 |
203 | def get_whu_buildings_loaders(
204 | train_dataset_dir,
205 | val_dataset_dir,
206 | test_images_dir,
207 | tokenizer,
208 | max_len,
209 | pad_idx,
210 | shuffle_tokens,
211 | batch_size,
212 | train_transform,
213 | val_transform,
214 | num_workers=2,
215 | pin_memory=True
216 | ):
217 |
218 | train_ds = WHUBuildingsCocoDataset(
219 | dataset_dir=train_dataset_dir,
220 | transform=train_transform,
221 | tokenizer=tokenizer,
222 | shuffle_tokens=shuffle_tokens
223 | )
224 |
225 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True)
226 |
227 | train_loader = DataLoader(
228 | train_ds,
229 | batch_size=batch_size,
230 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
231 | sampler=train_sampler,
232 | num_workers=num_workers,
233 | pin_memory=pin_memory,
234 | drop_last=True
235 | )
236 |
237 | valid_ds = WHUBuildingsCocoDataset(
238 | dataset_dir=val_dataset_dir,
239 | transform=val_transform,
240 | tokenizer=tokenizer,
241 | shuffle_tokens=shuffle_tokens
242 | )
243 |
244 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False)
245 |
246 | valid_loader = DataLoader(
247 | valid_ds,
248 | batch_size=batch_size,
249 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
250 | sampler=valid_sampler,
251 | num_workers=0,
252 | pin_memory=True,
253 | )
254 |
255 | test_ds = WHUBuildingsCocoDatasetTest(
256 | image_dir=test_images_dir,
257 | transform=val_transform
258 | )
259 |
260 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False)
261 |
262 | test_loader = DataLoader(
263 | test_ds,
264 | batch_size=batch_size,
265 | sampler=test_sampler,
266 | num_workers=num_workers,
267 | pin_memory=pin_memory,
268 | )
269 |
270 | return train_loader, valid_loader, test_loader
271 |
272 |
273 | def get_mass_roads_loaders(
274 | train_dataset_dir,
275 | val_dataset_dir,
276 | test_images_dir,
277 | tokenizer,
278 | max_len,
279 | pad_idx,
280 | shuffle_tokens,
281 | batch_size,
282 | train_transform,
283 | val_transform,
284 | num_workers=2,
285 | pin_memory=True
286 | ):
287 |
288 | train_ds = MassRoadsDataset(
289 | dataset_dir=train_dataset_dir,
290 | transform=train_transform,
291 | tokenizer=tokenizer,
292 | shuffle_tokens=shuffle_tokens
293 | )
294 |
295 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True)
296 |
297 | train_loader = DataLoader(
298 | train_ds,
299 | batch_size=batch_size,
300 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
301 | sampler=train_sampler,
302 | num_workers=num_workers,
303 | pin_memory=pin_memory,
304 | drop_last=True
305 | )
306 |
307 | valid_ds = MassRoadsDataset(
308 | dataset_dir=val_dataset_dir,
309 | transform=val_transform,
310 | tokenizer=tokenizer,
311 | shuffle_tokens=shuffle_tokens
312 | )
313 |
314 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False)
315 |
316 | valid_loader = DataLoader(
317 | valid_ds,
318 | batch_size=batch_size,
319 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
320 | sampler=valid_sampler,
321 | num_workers=0,
322 | pin_memory=True,
323 | )
324 |
325 | test_ds = MassRoadsDatasetTest(
326 | image_dir=test_images_dir,
327 | transform=val_transform
328 | )
329 |
330 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False)
331 |
332 | test_loader = DataLoader(
333 | test_ds,
334 | batch_size=batch_size,
335 | sampler=test_sampler,
336 | num_workers=num_workers,
337 | pin_memory=pin_memory,
338 | )
339 |
340 | return train_loader, valid_loader, test_loader
341 |
342 |
--------------------------------------------------------------------------------
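All of the loader factories above hand `partial(collate_fn, max_len=..., pad_idx=...)` to `DataLoader`, so the loader can call the collate with just the batch. A minimal sketch of that pattern, reduced to the sequence-padding part; `max_len=10` and `pad_idx=0` are illustrative values, not the configured ones:

```python
from functools import partial

import torch
from torch.nn.utils.rnn import pad_sequence


def collate(batch, max_len, pad_idx):
    # Reduced to the sequence-padding part of the collate_fn defined above.
    seqs = pad_sequence(batch, padding_value=pad_idx, batch_first=True)
    if max_len:
        pad = torch.full((seqs.size(0), max_len - seqs.size(1)), pad_idx, dtype=torch.long)
        seqs = torch.cat([seqs, pad], dim=1)
    return seqs


# partial(...) fixes max_len/pad_idx so a DataLoader only needs to pass the batch.
collate_fixed = partial(collate, max_len=10, pad_idx=0)
print(collate_fixed([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]).shape)  # torch.Size([2, 10])
```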
/engine.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 |
4 | from utils import (
5 | AverageMeter,
6 | get_lr,
7 | save_checkpoint,
8 | save_single_predictions_as_images
9 | )
10 | from config import CFG
11 |
12 | from ddp_utils import is_main_process
13 |
14 |
15 | def train_one_epoch(epoch, iter_idx, model, train_loader, optimizer, lr_scheduler, vertex_loss_fn, perm_loss_fn, writer):
16 | model.train()
17 | vertex_loss_fn.train()
18 | perm_loss_fn.train()
19 |
20 | loss_meter = AverageMeter()
21 | vertex_loss_meter = AverageMeter()
22 | perm_loss_meter = AverageMeter()
23 |
24 | loader = train_loader
25 | if is_main_process():
26 | loader = tqdm(train_loader, total=len(train_loader))
27 |
28 | # prof = torch.profiler.profile(
29 | # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
30 | # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"runs/{CFG.EXPERIMENT_NAME}/logs/profiler"),
31 | # record_shapes=True,
32 | # with_stack=True
33 | # )
34 | # prof.start()
35 | for x, y_mask, y_corner_mask, y, y_perm in loader:
36 | x = x.to(CFG.DEVICE, non_blocking=True)
37 | y = y.to(CFG.DEVICE, non_blocking=True)
38 | y_perm = y_perm.to(CFG.DEVICE, non_blocking=True)
39 |
40 | y_input = y[:, :-1]
41 | y_expected = y[:, 1:]
42 |
43 | preds, perm_mat = model(x, y_input)
44 |
45 | if epoch < CFG.MILESTONE:
46 | vertex_loss_weight = CFG.vertex_loss_weight
47 | perm_loss_weight = 0.0
48 | else:
49 | vertex_loss_weight = CFG.vertex_loss_weight
50 | perm_loss_weight = CFG.perm_loss_weight
51 |
52 | vertex_loss = vertex_loss_weight*vertex_loss_fn(preds.reshape(-1, preds.shape[-1]), y_expected.reshape(-1))
53 | perm_loss = perm_loss_weight*perm_loss_fn(perm_mat, y_perm)
54 |
55 | loss = vertex_loss + perm_loss
56 |
57 | optimizer.zero_grad(set_to_none=True)
58 | loss.backward()
59 | # nn.utils.clip_grad_norm_(model.module.parameters(), max_norm=0.1)
60 | optimizer.step()
61 |
62 | if lr_scheduler is not None:
63 | lr_scheduler.step()
64 |
65 | loss_meter.update(loss.item(), x.size(0))
66 | vertex_loss_meter.update(vertex_loss.item(), x.size(0))
67 | perm_loss_meter.update(perm_loss.item(), x.size(0))
68 |
69 | lr = get_lr(optimizer)
70 | if is_main_process():
71 | loader.set_postfix(train_loss=loss_meter.avg, lr=f"{lr:.5f}")
72 | writer.add_scalar('Running_logs/Train_Loss', loss_meter.avg, iter_idx)
73 | writer.add_scalar('Running_logs/LR', lr, iter_idx)
74 | # writer.add_image(f"Running_logs/input_images", torchvision.utils.make_grid(x), iter_idx)
75 | # writer.add_graph(model, input_to_model=(x, y_input))
76 |
77 | iter_idx += 1
78 | # prof.step()
79 | # prof.stop()
80 | print(f"Total train loss: {loss_meter.avg}\n\n")
81 | loss_dict = {
82 | 'total_loss': loss_meter.avg,
83 | 'vertex_loss': vertex_loss_meter.avg,
84 | 'perm_loss': perm_loss_meter.avg,
85 | }
86 |
87 | return loss_dict, iter_idx
88 |
89 |
90 | def valid_one_epoch(epoch, model, valid_loader, vertex_loss_fn, perm_loss_fn):
91 | print(f"\nValidating...")
92 | model.eval()
93 | vertex_loss_fn.eval()
94 | perm_loss_fn.eval()
95 |
96 | loss_meter = AverageMeter()
97 | vertex_loss_meter = AverageMeter()
98 | perm_loss_meter = AverageMeter()
99 |
100 | loader = valid_loader
101 | if is_main_process():
102 | loader = tqdm(valid_loader, total=len(valid_loader))
103 |
104 | with torch.no_grad():
105 | for x, y_mask, y_corner_mask, y, y_perm in loader:
106 | x = x.to(CFG.DEVICE, non_blocking=True)
107 | y = y.to(CFG.DEVICE, non_blocking=True)
108 | y_perm = y_perm.to(CFG.DEVICE, non_blocking=True)
109 |
110 | y_input = y[:, :-1]
111 | y_expected = y[:, 1:]
112 |
113 | preds, perm_mat = model(x, y_input)
114 |
115 | if epoch < CFG.MILESTONE:
116 | vertex_loss_weight = CFG.vertex_loss_weight
117 | perm_loss_weight = 0.0
118 | else:
119 | vertex_loss_weight = CFG.vertex_loss_weight
120 | perm_loss_weight = CFG.perm_loss_weight
121 | vertex_loss = vertex_loss_weight*vertex_loss_fn(preds.reshape(-1, preds.shape[-1]), y_expected.reshape(-1))
122 | perm_loss = perm_loss_weight*perm_loss_fn(perm_mat, y_perm)
123 |
124 | loss = vertex_loss + perm_loss
125 |
126 | loss_meter.update(loss.item(), x.size(0))
127 | vertex_loss_meter.update(vertex_loss.item(), x.size(0))
128 | perm_loss_meter.update(perm_loss.item(), x.size(0))
129 |
130 | loss_dict = {
131 | 'total_loss': loss_meter.avg,
132 | 'vertex_loss': vertex_loss_meter.avg,
133 | 'perm_loss': perm_loss_meter.avg,
134 | }
135 |
136 | return loss_dict
137 |
138 |
139 | def train_eval(
140 | model,
141 | train_loader,
142 | valid_loader,
143 | test_loader,
144 | tokenizer,
145 | vertex_loss_fn,
146 | perm_loss_fn,
147 | optimizer,
148 | lr_scheduler,
149 | step,
150 | writer
151 | ):
152 | best_loss = float('inf')
153 | best_metric = float('-inf')
154 |
155 | iter_idx=CFG.START_EPOCH * len(train_loader)
156 | epoch_iterator = range(CFG.START_EPOCH, CFG.NUM_EPOCHS)
157 | if is_main_process():
158 | epoch_iterator = tqdm(epoch_iterator)
159 | for epoch in epoch_iterator:
160 | if is_main_process():
161 | print(f"\n\nEPOCH: {epoch + 1}\n\n")
162 |
163 | if CFG.TRAIN_DDP:
164 | train_loader.sampler.set_epoch(epoch)
165 | valid_loader.sampler.set_epoch(epoch)
166 | test_loader.sampler.set_epoch(epoch)
167 |
168 | train_loss_dict, iter_idx = train_one_epoch(
169 | epoch,
170 | iter_idx,
171 | model,
172 | train_loader,
173 | optimizer,
174 | lr_scheduler if step=='batch' else None,
175 | vertex_loss_fn,
176 | perm_loss_fn,
177 | writer
178 | )
179 | if is_main_process():
180 | writer.add_scalar('Train_Losses/Total_Loss', train_loss_dict['total_loss'], epoch)
181 | writer.add_scalar('Train_Losses/Vertex_Loss', train_loss_dict['vertex_loss'], epoch)
182 | writer.add_scalar('Train_Losses/Perm_Loss', train_loss_dict['perm_loss'], epoch)
183 |
184 | valid_loss_dict = valid_one_epoch(
185 | epoch,
186 | model,
187 | valid_loader,
188 | vertex_loss_fn,
189 | perm_loss_fn,
190 | ) # TODO: add eval metrics to validation function?
191 | if is_main_process():
192 | print(f"Valid loss: {valid_loss_dict['total_loss']:.3f}\n\n")
193 |
194 | if step=='epoch':
195 | pass
196 |
197 | # Save best validation loss epoch.
198 | if valid_loss_dict['total_loss'] < best_loss and CFG.SAVE_BEST and is_main_process():
199 | best_loss = valid_loss_dict['total_loss']
200 | checkpoint = {
201 | "state_dict": model.module.state_dict(),
202 | "optimizer": optimizer.state_dict(),
203 | "scheduler": lr_scheduler.state_dict(),
204 | "epochs_run": epoch,
205 | "loss": train_loss_dict["total_loss"]
206 | }
207 | save_checkpoint(
208 | checkpoint,
209 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/",
210 | filename="best_valid_loss.pth"
211 | )
212 | # torch.save(model.state_dict(), 'best_valid_loss.pth')
213 | print(f"Saved best val loss model.")
214 |
215 | # Save latest checkpoint every epoch.
216 | if CFG.SAVE_LATEST and is_main_process():
217 | checkpoint = {
218 | "state_dict": model.module.state_dict(),
219 | "optimizer": optimizer.state_dict(),
220 | "scheduler": lr_scheduler.state_dict(),
221 | "epochs_run": epoch,
222 | "loss": train_loss_dict["total_loss"]
223 | }
224 | save_checkpoint(
225 | checkpoint,
226 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/",
227 | filename="latest.pth"
228 | )
229 |
230 | if (epoch + 1) % CFG.SAVE_EVERY == 0 and is_main_process():
231 | checkpoint = {
232 | "state_dict": model.module.state_dict(),
233 | "optimizer": optimizer.state_dict(),
234 | "scheduler": lr_scheduler.state_dict(),
235 | "epochs_run": epoch,
236 | "loss": train_loss_dict["total_loss"]
237 | }
238 | save_checkpoint(
239 | checkpoint,
240 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/",
241 | filename=f"epoch_{epoch}.pth"
242 | )
243 |
244 | if is_main_process():
245 | writer.add_scalar('Val_Losses/Total_Loss', valid_loss_dict['total_loss'], epoch)
246 | writer.add_scalar('Val_Losses/Vertex_Loss', valid_loss_dict['vertex_loss'], epoch)
247 | writer.add_scalar('Val_Losses/Perm_Loss', valid_loss_dict['perm_loss'], epoch)
248 |
249 | # output examples to a folder
250 | if (epoch + 1) % CFG.VAL_EVERY == 0 and is_main_process():
251 | val_metrics_dict = save_single_predictions_as_images(
252 | test_loader,
253 | model,
254 | tokenizer,
255 | epoch,
256 | writer,
257 | folder=f"runs/{CFG.EXPERIMENT_NAME}/runtime_outputs/",
258 | device=CFG.DEVICE
259 | )
260 |             for metric, value in val_metrics_dict.items():
261 | print(f"{metric}: {value}")
262 |
263 | # Save best single batch validation metric epoch.
264 | if val_metrics_dict["miou"] > best_metric and CFG.SAVE_BEST and is_main_process():
265 | best_metric = val_metrics_dict["miou"]
266 | checkpoint = {
267 | "state_dict": model.module.state_dict(),
268 | "optimizer": optimizer.state_dict(),
269 | "scheduler": lr_scheduler.state_dict(),
270 | "epochs_run": epoch,
271 | "loss": train_loss_dict["total_loss"]
272 | }
273 | save_checkpoint(
274 | checkpoint,
275 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/",
276 | filename="best_valid_metric.pth"
277 | )
278 | print(f"Saved best val metric model.")
279 |
280 |
--------------------------------------------------------------------------------
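`train_one_epoch` and `valid_one_epoch` above train the sequence decoder with teacher forcing: the decoder consumes `y[:, :-1]` and is supervised with `y[:, 1:]`, and before `CFG.MILESTONE` only the vertex loss is active while the permutation loss weight is held at zero. A minimal sketch of the input/target shift; the token values are made up for illustration (e.g. BOS, coordinate tokens, EOS):

```python
import torch

# Teacher-forcing shift used in train_one_epoch / valid_one_epoch.
y = torch.tensor([[101, 7, 12, 30, 44, 102]])
y_input, y_expected = y[:, :-1], y[:, 1:]
print(y_input)     # tensor([[101,   7,  12,  30,  44]])
print(y_expected)  # tensor([[  7,  12,  30,  44, 102]])
```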
/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/eval/__init__.py
--------------------------------------------------------------------------------
/eval/hisup_eval_utils/metrics/cIoU.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the code from https://github.com/zorzi-s/PolyWorldPretrainedNetwork.
3 | @article{zorzi2021polyworld,
4 | title={PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images},
5 | author={Zorzi, Stefano and Bazrafkan, Shabab and Habenschuss, Stefan and Fraundorfer, Friedrich},
6 | journal={arXiv preprint arXiv:2111.15491},
7 | year={2021}
8 | }
9 | """
10 |
11 | from pycocotools.coco import COCO
12 | from pycocotools import mask as cocomask
13 | import numpy as np
14 | import json
15 | import argparse
16 | from tqdm import tqdm
17 |
18 | def calc_IoU(a, b):
19 | i = np.logical_and(a, b)
20 | u = np.logical_or(a, b)
21 | I = np.sum(i)
22 | U = np.sum(u)
23 |
24 | iou = I/(U + 1e-9)
25 |
26 | is_void = U == 0
27 | if is_void:
28 | return 1.0
29 | else:
30 | return iou
31 |
32 | def compute_IoU_cIoU(input_json, gti_annotations):
33 | # Ground truth annotations
34 | coco_gti = COCO(gti_annotations)
35 |
36 | # Predictions annotations
37 | submission_file = json.loads(open(input_json).read())
38 | coco = COCO(gti_annotations)
39 | coco = coco.loadRes(submission_file)
40 |
41 |
42 | image_ids = coco.getImgIds(catIds=coco.getCatIds())
43 | bar = tqdm(image_ids)
44 |
45 | list_iou = []
46 | list_ciou = []
47 | pss = []
48 | rel_difs = []
49 | n_ratios = []
50 | for image_id in bar:
51 |
52 | img = coco.loadImgs(image_id)[0]
53 |
54 | annotation_ids = coco.getAnnIds(imgIds=img['id'])
55 | annotations = coco.loadAnns(annotation_ids)
56 | N = 0
57 | for _idx, annotation in enumerate(annotations):
58 | try:
59 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width'])
60 | except Exception:
61 | import ipdb; ipdb.set_trace()
62 | m = cocomask.decode(rle)
63 | if _idx == 0:
64 | mask = m.reshape((img['height'], img['width']))
65 | N = len(annotation['segmentation'][0]) // 2
66 | else:
67 | mask = mask + m.reshape((img['height'], img['width']))
68 | N = N + len(annotation['segmentation'][0]) // 2
69 |
70 | mask = mask != 0
71 |
72 |
73 | annotation_ids = coco_gti.getAnnIds(imgIds=img['id'])
74 | annotations = coco_gti.loadAnns(annotation_ids)
75 | N_GT = 0
76 | for _idx, annotation in enumerate(annotations):
77 | if any(annotation['segmentation']):
78 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width'])
79 | m = cocomask.decode(rle)
80 | else:
81 | annotation['segmentation'] = [[]]
82 | m = np.zeros((img['height'], img['width']))
83 | if m.ndim > 2:
84 | m = np.clip(0, 1, m.sum(axis=-1))
85 | if _idx == 0:
86 | mask_gti = m.reshape((img['height'], img['width']))
87 | N_GT = len(annotation['segmentation'][0]) // 2
88 | else:
89 | mask_gti = mask_gti + m.reshape((img['height'], img['width']))
90 | N_GT = N_GT + len(annotation['segmentation'][0]) // 2
91 |
92 | mask_gti = mask_gti != 0
93 |
94 | ps = 1 - np.abs(N - N_GT) / (N + N_GT + 1e-9)
95 | rel_dif = np.abs(N - N_GT) / (N + N_GT + 1e-9)
96 | iou = calc_IoU(mask, mask_gti)
97 | list_iou.append(iou)
98 | list_ciou.append(iou * ps)
99 | pss.append(ps)
100 | rel_difs.append(rel_dif)
101 | if N_GT > 0:
102 | nr = N / N_GT
103 | n_ratios.append(nr)
104 |
105 | bar.set_description("iou: %2.4f, c-iou: %2.4f, ps:%2.4f, rd:%2.4f" % (np.mean(list_iou), np.mean(list_ciou), np.mean(pss), np.mean(rel_difs)))
106 | bar.refresh()
107 |
108 | print("Done!")
109 | print("Mean IoU: ", np.mean(list_iou))
110 | print("Mean C-IoU: ", np.mean(list_ciou))
111 | print("Mean N-Relative Difference: ", np.mean(rel_difs))
112 | print("Mean N-Ratio: ", np.mean(n_ratios))
113 |
114 |
115 |
116 | if __name__ == "__main__":
117 | parser = argparse.ArgumentParser()
118 | parser.add_argument("--gt-file", default="")
119 | parser.add_argument("--dt-file", default="")
120 | args = parser.parse_args()
121 |
122 | gt_file = args.gt_file
123 | dt_file = args.dt_file
124 | compute_IoU_cIoU(input_json=dt_file,
125 | gti_annotations=gt_file)
126 |
--------------------------------------------------------------------------------
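`compute_IoU_cIoU` above weights the mask IoU by a vertex-count term `ps = 1 - |N - N_GT| / (N + N_GT)`, so over- or under-segmented polygons are penalised even when the raster overlap is good. A small worked example with toy masks and illustrative vertex counts (not taken from any real prediction):

```python
import numpy as np

# Toy 4x4 masks: the prediction and ground truth share 1 pixel out of a
# 3-pixel union, so IoU = 1/3.
pred = np.zeros((4, 4), dtype=bool)
pred[0, 0] = pred[0, 1] = True
gt = np.zeros((4, 4), dtype=bool)
gt[0, 1] = gt[0, 2] = True

iou = np.logical_and(pred, gt).sum() / (np.logical_or(pred, gt).sum() + 1e-9)

# Vertex-count term: N predicted vs. N_GT ground-truth polygon vertices
# (counts are illustrative; compute_IoU_cIoU derives them from the COCO polygons).
N, N_GT = 5, 4
ps = 1 - abs(N - N_GT) / (N + N_GT + 1e-9)

print(round(iou, 4), round(ps, 4), round(iou * ps, 4))  # 0.3333 0.8889 0.2963
```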
/eval/hisup_eval_utils/metrics/polis.py:
--------------------------------------------------------------------------------
1 | """
2 | The code is adopted from https://github.com/spgriffin/polis
3 | """
4 |
5 | import numpy as np
6 |
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | from pycocotools import mask as maskUtils
10 | from shapely import geometry
11 | from shapely.geometry import Polygon
12 |
13 |
14 | def bounding_box(points):
15 |     """Returns the axis-aligned bounding box of the point sequence
16 |     as [x_min, y_min, width, height].
17 |     Here, we traverse the collection of points only once,
18 |     to find the min and max for x and y.
19 |     """
20 | bot_left_x, bot_left_y = float('inf'), float('inf')
21 | top_right_x, top_right_y = float('-inf'), float('-inf')
22 | for x, y in points:
23 | bot_left_x = min(bot_left_x, x)
24 | bot_left_y = min(bot_left_y, y)
25 | top_right_x = max(top_right_x, x)
26 | top_right_y = max(top_right_y, y)
27 |
28 | return [bot_left_x, bot_left_y, top_right_x - bot_left_x, top_right_y - bot_left_y]
29 |
30 | def compare_polys(poly_a, poly_b):
31 | """Compares two polygons via the "polis" distance metric.
32 | See "A Metric for Polygon Comparison and Building Extraction
33 | Evaluation" by J. Avbelj, et al.
34 | Input:
35 | poly_a: A Shapely polygon.
36 | poly_b: Another Shapely polygon.
37 | Returns:
38 | The "polis" distance between these two polygons.
39 | """
40 | bndry_a, bndry_b = poly_a.exterior, poly_b.exterior
41 | dist = polis(bndry_a.coords, bndry_b)
42 | dist += polis(bndry_b.coords, bndry_a)
43 | return dist
44 |
45 |
46 | def polis(coords, bndry):
47 | """Computes one side of the "polis" metric.
48 | Input:
49 |         coords: A Shapely coordinate sequence (presumably the vertices
50 | of a polygon).
51 | bndry: A Shapely linestring (presumably the boundary of
52 | another polygon).
53 |
54 | Returns:
55 | The "polis" metric for this pair. You usually compute this in
56 | both directions to preserve symmetry.
57 | """
58 | sum = 0.0
59 | for pt in (geometry.Point(c) for c in coords[:-1]): # Skip the last point (same as first)
60 | sum += bndry.distance(pt)
61 | return sum/float(2*len(coords))
62 |
63 |
64 | class PolisEval():
65 |
66 | def __init__(self, cocoGt=None, cocoDt=None):
67 | self.cocoGt = cocoGt
68 | self.cocoDt = cocoDt
69 | self.evalImgs = defaultdict(list)
70 | self.eval = {}
71 | self._gts = defaultdict(list)
72 | self._dts = defaultdict(list)
73 | self.stats = []
74 | self.imgIds = list(sorted(self.cocoGt.imgs.keys()))
75 |
76 | def _prepare(self):
77 | gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=self.imgIds))
78 | dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=self.imgIds))
79 | self._gts = defaultdict(list) # gt for evaluation
80 | self._dts = defaultdict(list) # dt for evaluation
81 | for gt in gts:
82 | self._gts[gt['image_id']].append(gt)
83 | for dt in dts:
84 | self._dts[dt['image_id']].append(dt)
85 | self.evalImgs = defaultdict(list) # per-image per-category evaluation results
86 | self.eval = {} # accumulated evaluation results
87 |
88 | def evaluateImg(self, imgId):
89 | gts = self._gts[imgId]
90 | dts = self._dts[imgId]
91 |
92 | if len(gts) == 0 or len(dts) == 0:
93 | return 0
94 |
95 | for gt in gts:
96 | if not any(gt['segmentation']):
97 | gt['segmentation'] = [[]]
98 | gt_bboxs = [bounding_box(np.array(gt['segmentation'][0]).reshape(-1,2)) for gt in gts]
99 | dt_bboxs = [bounding_box(np.array(dt['segmentation'][0]).reshape(-1,2)) for dt in dts]
100 | gt_polygons = [np.array(gt['segmentation'][0]).reshape(-1,2) for gt in gts]
101 | dt_polygons = [np.array(dt['segmentation'][0]).reshape(-1,2) for dt in dts]
102 |
103 | # IoU match
104 | iscrowd = [0] * len(gt_bboxs)
105 | # ious = maskUtils.iou(gt_bboxs, dt_bboxs, iscrowd)
106 | ious = maskUtils.iou(dt_bboxs, gt_bboxs, iscrowd)
107 |
108 | # compute polis
109 | img_polis_avg = 0
110 | num_sample = 0
111 | for i, gt_poly in enumerate(gt_polygons):
112 | matched_idx = np.argmax(ious[:,i])
113 | iou = ious[matched_idx, i]
114 | if iou > 0.5: # iouThres:
115 | polis = compare_polys(Polygon(gt_poly), Polygon(dt_polygons[matched_idx]))
116 | img_polis_avg += polis
117 | num_sample += 1
118 |
119 | if num_sample == 0:
120 | return 0
121 | else:
122 | return img_polis_avg / num_sample
123 |
124 |
125 | def evaluate(self):
126 | self._prepare()
127 | polis_tot = 0
128 |
129 | num_valid_imgs = 0
130 | for imgId in tqdm(self.imgIds):
131 | img_polis_avg = self.evaluateImg(imgId)
132 |
133 | if img_polis_avg == 0:
134 | continue
135 | else:
136 | polis_tot += img_polis_avg
137 | num_valid_imgs += 1
138 |
139 | polis_avg = polis_tot / num_valid_imgs
140 |
141 | print('average polis: %f' % (polis_avg))
142 |
143 | return polis_avg
144 |
145 |
--------------------------------------------------------------------------------
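A minimal usage sketch of the symmetric PoLiS distance implemented above, applied directly to two Shapely polygons; the coordinates are made up (two nearly identical unit squares with one predicted vertex offset by 0.1):

```python
from shapely.geometry import Polygon

from eval.hisup_eval_utils.metrics.polis import compare_polys

gt_poly = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
dt_poly = Polygon([(0, 0), (1, 0), (1, 1.1), (0, 1)])

# PoLiS averages point-to-boundary distances in both directions; identical
# polygons give 0, larger values mean the outlines disagree more.
print(compare_polys(gt_poly, dt_poly))
```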
/eval/topdig_eval_utils/metrics/topdig_metrics.py:
--------------------------------------------------------------------------------
1 | from pycocotools.coco import COCO
2 | from pycocotools import mask as cocomask
3 | import numpy as np
4 | import json
5 | import argparse
6 | from tqdm import tqdm
7 |
8 | import torch
9 | import cv2
10 |
11 | from eval.hisup_eval_utils.metrics.cIoU import calc_IoU
12 | from torchmetrics.functional.classification import binary_accuracy, binary_f1_score
13 |
14 |
15 | def calc_f1score(mask: np.ndarray, mask_gti: np.ndarray):
16 | union = np.logical_or(mask, mask_gti)
17 | U = np.sum(union)
18 | is_void = U == 0
19 | mask = torch.from_numpy(mask)
20 | mask_gti = torch.from_numpy(mask_gti)
21 |
22 | if is_void:
23 | return 1.0
24 | else:
25 | return binary_f1_score(preds=mask, target=mask_gti)
26 |
27 |
28 | def calc_acc(mask: np.ndarray, mask_gti: np.ndarray):
29 | union = np.logical_or(mask, mask_gti)
30 | U = np.sum(union)
31 | is_void = U == 0
32 | mask = torch.from_numpy(mask)
33 | mask_gti = torch.from_numpy(mask_gti)
34 |
35 | if is_void:
36 | return 1.0
37 | else:
38 | return binary_accuracy(preds=mask, target=mask_gti)
39 |
40 |
41 | def compute_mask_metrics(input_json, gti_annotations):
42 | # Ground truth annotations
43 | coco_gti = COCO(gti_annotations)
44 |
45 | # Predictions annotations
46 | submission_file = json.loads(open(input_json).read())
47 | coco = COCO(gti_annotations)
48 | coco = coco.loadRes(submission_file)
49 |
50 |
51 | image_ids = coco.getImgIds(catIds=coco.getCatIds())
52 | bar = tqdm(image_ids)
53 |
54 | buffer_thickness = 5 # dilation factor same as that used in TopDiG.
55 |
56 | list_acc = []
57 | list_f1 = []
58 | list_iou = []
59 |
60 | list_acc_topo = []
61 | list_f1_topo = []
62 | list_iou_topo = []
63 |
64 | for image_id in bar:
65 |
66 | img = coco.loadImgs(image_id)[0]
67 |
68 | # Predictions
69 | annotation_ids = coco.getAnnIds(imgIds=img['id'])
70 | annotations = coco.loadAnns(annotation_ids)
71 | topo_mask = np.zeros((img['height'], img['width']))
72 | poly_lines = []
73 | for _idx, annotation in enumerate(annotations):
74 | try:
75 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width'])
76 | except Exception:
77 | import ipdb; ipdb.set_trace()
78 | m = cocomask.decode(rle)
79 | if _idx == 0:
80 | mask = m.reshape((img['height'], img['width']))
81 | else:
82 | mask = mask + m.reshape((img['height'], img['width']))
83 | for ann in annotation['segmentation']:
84 | ann = np.array(ann).reshape(-1, 2)
85 | ann = np.round(ann).astype(np.int32)
86 | poly_lines.append(ann)
87 | cv2.polylines(topo_mask, poly_lines, isClosed=True, color=1., thickness=buffer_thickness)
88 |
89 | mask = mask != 0
90 | topo_mask = (topo_mask != 0).astype(np.float32)
91 |
92 |
93 | # Ground truth
94 | annotation_ids = coco_gti.getAnnIds(imgIds=img['id'])
95 | annotations = coco_gti.loadAnns(annotation_ids)
96 | topo_mask_gt = np.zeros((img['height'], img['width']))
97 | poly_lines_gt = []
98 | for _idx, annotation in enumerate(annotations):
99 | if any(annotation['segmentation']):
100 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width'])
101 | m = cocomask.decode(rle)
102 | else:
103 | annotation['segmentation'] = [[]]
104 | m = np.zeros((img['height'], img['width']))
105 | if m.ndim > 2:
106 | m = np.clip(0, 1, m.sum(axis=-1))
107 | if _idx == 0:
108 | mask_gti = m.reshape((img['height'], img['width']))
109 | else:
110 | mask_gti = mask_gti + m.reshape((img['height'], img['width']))
111 | for ann in annotation['segmentation']:
112 | ann = np.array(ann).reshape(-1, 2)
113 | ann = np.round(ann).astype(np.int32)
114 | poly_lines_gt.append(ann)
115 | cv2.polylines(topo_mask_gt, poly_lines_gt, isClosed=True, color=1., thickness=buffer_thickness)
116 |
117 | mask_gti = mask_gti != 0
118 | topo_mask_gt = (topo_mask_gt != 0).astype(np.float32)
119 |
120 |
121 | pacc = calc_acc(mask, mask_gti)
122 | list_acc.append(pacc)
123 | f1score = calc_f1score(mask, mask_gti)
124 | list_f1.append(f1score)
125 | iou = calc_IoU(mask, mask_gti)
126 | list_iou.append(iou)
127 |
128 | pacc_topo = calc_acc(topo_mask, topo_mask_gt)
129 | list_acc_topo.append(pacc_topo)
130 | f1score_topo = calc_f1score(topo_mask, topo_mask_gt)
131 | list_f1_topo.append(f1score_topo)
132 | iou_topo = calc_IoU(topo_mask, topo_mask_gt)
133 | list_iou_topo.append(iou_topo)
134 |
135 | bar.set_description("iou: %2.4f, p-acc: %2.4f, f1:%2.4f, iou-topo: %2.4f, p-acc-topo: %2.4f, f1-topo:%2.4f " % (np.mean(list_iou), np.mean(list_acc), np.mean(list_f1), np.mean(list_iou_topo), np.mean(list_acc_topo), np.mean(list_f1_topo)))
136 | bar.refresh()
137 |
138 | print("Done!")
139 | print("Mean IoU: ", np.mean(list_iou))
140 | print("Mean P-Acc: ", np.mean(list_acc))
141 | print("Mean F1-Score: ", np.mean(list_f1))
142 | print("Mean IoU-Topo: ", np.mean(list_iou_topo))
143 | print("Mean P-Acc-Topo: ", np.mean(list_acc_topo))
144 | print("Mean F1-Score-Topo: ", np.mean(list_f1_topo))
145 |
146 |
147 | if __name__ == "__main__":
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument("--gt-file", default="")
150 | parser.add_argument("--dt-file", default="")
151 | args = parser.parse_args()
152 |
153 | gt_file = args.gt_file
154 | dt_file = args.dt_file
155 | compute_mask_metrics(input_json=dt_file,
156 | gti_annotations=gt_file)
157 |
--------------------------------------------------------------------------------
/evaluate_mass_roads_predictions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch
7 | import cv2
8 |
9 | from eval.hisup_eval_utils.metrics.cIoU import calc_IoU
10 | from torchmetrics.functional.classification import binary_accuracy, binary_f1_score
11 |
12 | def calc_f1score(mask: np.ndarray, mask_gti: np.ndarray):
13 | mask = torch.from_numpy(mask)
14 | mask_gti = torch.from_numpy(mask_gti)
15 | return binary_f1_score(preds=mask, target=mask_gti)
16 |
17 |
18 | def calc_acc(mask: np.ndarray, mask_gti: np.ndarray):
19 | mask = torch.from_numpy(mask)
20 | mask_gti = torch.from_numpy(mask_gti)
21 | return binary_accuracy(preds=mask, target=mask_gti)
22 |
23 |
24 | def compute_mask_metrics(predictions_dir, gt_dir):
25 | # Ground truth annotations
26 | # gt_masks = os.listdir(gt_dir)
27 |
28 | # Predictions annotations
29 | pred_masks = os.listdir(predictions_dir)
30 |
31 |
32 | images = pred_masks
33 | bar = tqdm(images)
34 |
35 | list_acc_topo = []
36 | list_f1_topo = []
37 | list_iou_topo = []
38 |
39 | for image_id in bar:
40 |
41 | # img = cv2.imread(os.path.join(predictions_dir, image_id))
42 |
43 | # Predictions
44 | topo_mask = cv2.imread(os.path.join(predictions_dir, image_id))
45 | topo_mask = (topo_mask != 0).astype(np.float32)
46 |
47 | # Ground truth
48 | topo_mask_gt = cv2.imread(os.path.join(gt_dir, f"{image_id.split('.')[0]}.tif"))
49 | topo_mask_gt = (topo_mask_gt != 0).astype(np.float32)
50 |
51 | # Standard Torchmetrics Implementation
52 | pacc_orig = calc_acc(topo_mask, topo_mask_gt)
53 | list_acc_topo.append(pacc_orig)
54 | iou_orig = calc_IoU(topo_mask, topo_mask_gt)
55 | list_iou_topo.append(iou_orig)
56 | f1score_orig = calc_f1score(topo_mask, topo_mask_gt)
57 | list_f1_topo.append(f1score_orig)
58 |
59 | bar.set_description("iou-topo: %2.4f, p-acc-topo: %2.4f, f1-topo:%2.4f " % (np.mean(list_iou_topo), np.mean(list_acc_topo), np.mean(list_f1_topo)))
60 | bar.refresh()
61 |
62 | print("Done!")
63 | print("############## TOPO METRICS ############")
64 | print("Mean IoU-Topo: ", np.mean(list_iou_topo))
65 | print("Mean P-Acc-Topo: ", np.mean(list_acc_topo))
66 | print("Mean F1-Score-Topo: ", np.mean(list_f1_topo))
67 |
68 |
69 | if __name__ == "__main__":
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument("--gt-dir", default="")
72 | parser.add_argument("--dt-dir", default="")
73 | args = parser.parse_args()
74 |
75 | gt_dir = args.gt_dir
76 | dt_dir = args.dt_dir
77 | compute_mask_metrics(predictions_dir=dt_dir,
78 | gt_dir=gt_dir)
79 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | # Borrowed from https://github.com/SarahwXU/HiSup/blob/main/tools/evaluation.py
2 |
3 | import argparse
4 |
5 | from multiprocess import Pool
6 | from pycocotools.coco import COCO
7 | from pycocotools.cocoeval import COCOeval
8 | from eval.hisup_eval_utils.metrics.polis import PolisEval
9 | from eval.hisup_eval_utils.metrics.angle_eval import ContourEval
10 | from eval.hisup_eval_utils.metrics.cIoU import compute_IoU_cIoU
11 | from eval.topdig_eval_utils.metrics.topdig_metrics import compute_mask_metrics
12 |
13 |
14 | def coco_eval(annFile, resFile):
15 | type=1
16 | annType = ['bbox', 'segm']
17 | print('Running demo for *%s* results.' % (annType[type]))
18 |
19 | cocoGt = COCO(annFile)
20 | cocoDt = cocoGt.loadRes(resFile)
21 |
22 | imgIds = cocoGt.getImgIds()
23 | imgIds = imgIds[:]
24 |
25 | cocoEval = COCOeval(cocoGt, cocoDt, annType[type])
26 | cocoEval.params.imgIds = imgIds
27 | cocoEval.params.catIds = [100]
28 | cocoEval.evaluate()
29 | cocoEval.accumulate()
30 | cocoEval.summarize()
31 | return cocoEval.stats
32 |
33 | def polis_eval(annFile, resFile):
34 | gt_coco = COCO(annFile)
35 | dt_coco = gt_coco.loadRes(resFile)
36 | polisEval = PolisEval(gt_coco, dt_coco)
37 | polisEval.evaluate()
38 |
39 | def max_angle_error_eval(annFile, resFile):
40 | gt_coco = COCO(annFile)
41 | dt_coco = gt_coco.loadRes(resFile)
42 | contour_eval = ContourEval(gt_coco, dt_coco)
43 | pool = Pool(processes=20)
44 | max_angle_diffs = contour_eval.evaluate(pool=pool)
45 | print('Mean max tangent angle error(MTA): ', max_angle_diffs.mean())
46 |
47 | if __name__ == "__main__":
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument("--gt-file", default="")
50 | parser.add_argument("--dt-file", default="")
51 | parser.add_argument("--eval-type", default="coco_iou", choices=["coco_iou", "polis", "angle", "ciou", "topdig"])
52 | args = parser.parse_args()
53 |
54 | eval_type = args.eval_type
55 | gt_file = args.gt_file
56 | dt_file = args.dt_file
57 | if eval_type == 'coco_iou':
58 | coco_eval(gt_file, dt_file)
59 | elif eval_type == 'polis':
60 | polis_eval(gt_file, dt_file)
61 | elif eval_type == 'angle':
62 | max_angle_error_eval(gt_file, dt_file)
63 | elif eval_type == 'ciou':
64 | compute_IoU_cIoU(dt_file, gt_file)
65 | elif eval_type == 'topdig':
66 | compute_mask_metrics(dt_file, gt_file)
67 | else:
68 | raise RuntimeError('please choose a correct type from \
69 | ["coco_iou", "polis", "angle", "ciou", "topdig"]')
70 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from timm.models.layers import trunc_normal_
6 |
7 | import os
8 | import sys
9 | sys.path.insert(1, os.getcwd())
10 |
11 | from config import CFG
12 | from utils import (
13 | create_mask,
14 | )
15 |
16 |
17 | # Borrowed from https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/superglue.py#L143
18 | def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
19 | """ Perform Sinkhorn Normalization in Log-space for stability"""
20 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
21 | for _ in range(iters):
22 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
23 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
24 | return Z + u.unsqueeze(2) + v.unsqueeze(1)
25 |
26 | # Borrowed from https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/superglue.py#L152
27 | def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
28 | """ Perform Differentiable Optimal Transport in Log-space for stability"""
29 | b, m, n = scores.shape
30 | one = scores.new_tensor(1)
31 | ms, ns = (m*one).to(scores), (n*one).to(scores)
32 |
33 | bins0 = alpha.expand(b, m, 1)
34 | bins1 = alpha.expand(b, 1, n)
35 | alpha = alpha.expand(b, 1, 1)
36 |
37 | couplings = torch.cat([torch.cat([scores, bins0], -1),
38 | torch.cat([bins1, alpha], -1)], 1)
39 |
40 | norm = - (ms + ns).log()
41 | log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
42 | log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
43 | log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
44 |
45 | Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
46 | Z = Z - norm # multiply probabilities by M+N
47 | return Z
48 |
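# Illustrative sketch (toy values, not repo defaults): how the two Sinkhorn
# helpers above fit together. The model calls log_optimal_transport from
# EncoderDecoder.forward with iters=CFG.SINKHORN_ITERATIONS.
#
#   scores = torch.randn(2, 4, 4)                    # (batch, m, n) pairwise scores
#   bin_score = torch.nn.Parameter(torch.tensor(1.))
#   Z = log_optimal_transport(scores, bin_score, iters=100)
#   # Z has shape (2, 5, 5): one extra "dustbin" row and column for unmatched
#   # entries; the model keeps only Z[:, :4, :4] before applying a softmax.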
49 |
50 | class ScoreNet(nn.Module):
51 | def __init__(self, n_vertices, in_channels=512):
52 | super().__init__()
53 | self.n_vertices = n_vertices
54 | self.in_channels = in_channels
55 | self.relu = nn.ReLU(inplace=True)
56 | self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, padding=0, bias=True)
57 | self.bn1 = nn.BatchNorm2d(256)
58 | self.conv2 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True)
59 | self.bn2 = nn.BatchNorm2d(128)
60 | self.conv3 = nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=True)
61 | self.bn3 = nn.BatchNorm2d(64)
62 | self.conv4 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=True)
63 |
64 | def forward(self, feats):
65 | feats = feats[:, 1:]
66 | feats = feats.unsqueeze(2)
67 | feats = feats.view(feats.size(0), feats.size(1)//2, 2, feats.size(3))
68 | feats = torch.mean(feats, dim=2)
69 |
70 | x = torch.transpose(feats, 1, 2)
71 | x = x.unsqueeze(-1)
72 | x = x.repeat(1, 1, 1, self.n_vertices)
73 | t = torch.transpose(x, 2, 3)
74 | x = torch.cat((x, t), dim=1)
75 |
76 | x = self.conv1(x)
77 | x = self.bn1(x)
78 | x = self.relu(x)
79 |
80 | x = self.conv2(x)
81 | x = self.bn2(x)
82 | x = self.relu(x)
83 |
84 | x = self.conv3(x)
85 | x = self.bn3(x)
86 | x = self.relu(x)
87 |
88 | x = self.conv4(x)
89 |
90 | return x[:, 0]
91 |
92 |
93 | class Encoder(nn.Module):
94 | def __init__(self, model_name='deit3_small_patch16_384_in21ft1k', pretrained=False, out_dim=256) -> None:
95 | super().__init__()
96 | self.model = timm.create_model(
97 | model_name=model_name,
98 | num_classes=0,
99 | global_pool='',
100 | pretrained=pretrained
101 | )
102 | self.bottleneck = nn.AdaptiveAvgPool1d(out_dim)
103 |
104 | def forward(self, x):
105 | features = self.model(x)
106 | return self.bottleneck(features[:, 1:])
107 |
108 |
109 | class Decoder(nn.Module):
110 | def __init__(self, cfg, vocab_size, encoder_len, dim, num_heads, num_layers):
111 | super().__init__()
112 | self.cfg = cfg
113 | self.dim = dim
114 |
115 | self.embedding = nn.Embedding(vocab_size, dim)
116 | self.decoder_pos_embed = nn.Parameter(torch.randn(1, self.cfg.MAX_LEN-1, dim) * .02)
117 | self.decoder_pos_drop = nn.Dropout(p=0.05)
118 |
119 | decoder_layer = nn.TransformerDecoderLayer(d_model=dim, nhead=num_heads)
120 | self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_layers)
121 | self.output = nn.Linear(dim, vocab_size)
122 |
123 | self.encoder_pos_embed = nn.Parameter(torch.randn(1, encoder_len, dim) * .02)
124 | self.encoder_pos_drop = nn.Dropout(p=0.05)
125 |
126 | self.init_weights()
127 |
128 | def init_weights(self):
129 | for name, p in self.named_parameters():
130 | if 'encoder_pos_embed' in name or 'decoder_pos_embed' in name:
131 | print(f"Skipping initialization of pos embed layers...")
132 | continue
133 | if p.dim() > 1:
134 | nn.init.xavier_uniform_(p)
135 |
136 | trunc_normal_(self.encoder_pos_embed, std=.02)
137 | trunc_normal_(self.decoder_pos_embed, std=.02)
138 |
139 | def forward(self, encoder_out, tgt):
140 | """
141 | encoder_out shape: (N, L, D)
142 | tgt shape: (N, L)
143 | """
144 |
145 | tgt_mask, tgt_padding_mask = create_mask(tgt, self.cfg.PAD_IDX)
146 | tgt_embedding = self.embedding(tgt)
147 | tgt_embedding = self.decoder_pos_drop(
148 | tgt_embedding + self.decoder_pos_embed
149 | )
150 |
151 | encoder_out = self.encoder_pos_drop(
152 | encoder_out + self.encoder_pos_embed
153 | )
154 |
155 | encoder_out = encoder_out.transpose(0, 1)
156 | tgt_embedding = tgt_embedding.transpose(0, 1)
157 |
158 | preds = self.decoder(
159 | memory=encoder_out,
160 | tgt=tgt_embedding,
161 | tgt_mask=tgt_mask,
162 | tgt_key_padding_mask=tgt_padding_mask
163 | )
164 |
165 | preds = preds.transpose(0, 1)
166 | return self.output(preds), preds
167 |
168 | def predict(self, encoder_out, tgt):
169 | length = tgt.size(1)
170 | padding = torch.ones((tgt.size(0), self.cfg.MAX_LEN-length-1), device=tgt.device).fill_(self.cfg.PAD_IDX).long()
171 | tgt = torch.cat([tgt, padding], dim=1)
172 | tgt_mask, tgt_padding_mask = create_mask(tgt, self.cfg.PAD_IDX)
173 | tgt_embedding = self.embedding(tgt)
174 | tgt_embedding = self.decoder_pos_drop(
175 | tgt_embedding + self.decoder_pos_embed
176 | )
177 |
178 | encoder_out = self.encoder_pos_drop(
179 | encoder_out + self.encoder_pos_embed
180 | )
181 |
182 | encoder_out = encoder_out.transpose(0, 1)
183 | tgt_embedding = tgt_embedding.transpose(0, 1)
184 |
185 | preds = self.decoder(
186 | memory=encoder_out,
187 | tgt=tgt_embedding,
188 | tgt_mask=tgt_mask,
189 |             tgt_key_padding_mask=tgt_padding_mask
190 | )
191 |
192 | preds = preds.transpose(0, 1)
193 | return self.output(preds)[:, length-1, :], preds
194 |
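# Illustrative sketch of how the two entry points above are called (variable
# names such as `encoder_out` and `partial_seq` are placeholders, not repo
# identifiers):
#
#   logits, feats = decoder(encoder_out, tgt)            # teacher forcing
#   # logits: (N, MAX_LEN-1, vocab_size), feats: (N, MAX_LEN-1, dim)
#
#   next_logits, feats = decoder.predict(encoder_out, partial_seq)
#   # partial_seq holds the tokens generated so far (starting from BOS);
#   # next_logits: (N, vocab_size) -- scores for the next token only.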
195 |
196 | class EncoderDecoder(nn.Module):
197 | def __init__(self, cfg, encoder, decoder):
198 | super().__init__()
199 | self.cfg = cfg
200 | self.encoder = encoder
201 | self.decoder = decoder
202 | self.scorenet1 = ScoreNet(self.cfg.N_VERTICES)
203 | self.scorenet2 = ScoreNet(self.cfg.N_VERTICES)
204 | bin_score = torch.nn.Parameter(torch.tensor(1.))
205 | self.register_parameter('bin_score', bin_score)
206 |
207 | def forward(self, image, tgt):
208 | encoder_out = self.encoder(image)
209 | preds, feats = self.decoder(encoder_out, tgt)
210 | perm_mat1 = self.scorenet1(feats)
211 | perm_mat2 = self.scorenet2(feats)
212 | perm_mat = perm_mat1 + torch.transpose(perm_mat2, 1, 2)
213 |
214 | perm_mat = log_optimal_transport(perm_mat, self.bin_score, self.cfg.SINKHORN_ITERATIONS)[:, :perm_mat.shape[1], :perm_mat.shape[2]]
215 | perm_mat = F.softmax(perm_mat, dim=-1) # NOTE: perhaps try gumbel softmax here?
216 | # perm_mat = F.gumbel_softmax(perm_mat, tau=1.0, hard=False)
217 |
218 | return preds, perm_mat
219 |
220 | def predict(self, image, tgt):
221 | encoder_out = self.encoder(image)
222 | preds, feats = self.decoder.predict(encoder_out, tgt)
223 | return preds, feats
224 |
225 |
226 | if __name__ == "__main__":
227 | # run this script as main for debugging.
228 | from tokenizer import Tokenizer
229 | from torch.nn.utils.rnn import pad_sequence
230 | import numpy as np
231 | import torch
232 | from torch import nn
233 |
234 | image = torch.randn(1, 3, CFG.INPUT_HEIGHT, CFG.INPUT_WIDTH).to('cuda')
235 |
236 | n_vertices = 192
237 | gt_coords = np.random.randint(size=(n_vertices, 2), low=0, high=CFG.IMG_SIZE).astype(np.float32)
238 | # in dataset
239 | tokenizer = Tokenizer(num_classes=1, num_bins=CFG.NUM_BINS, width=CFG.IMG_SIZE, height=CFG.IMG_SIZE, max_len=CFG.MAX_LEN)
240 | gt_seqs, rand_idxs = tokenizer(gt_coords)
241 | # in dataloader collate
242 | gt_seqs = [torch.LongTensor(gt_seqs)]
243 | gt_seqs = pad_sequence(gt_seqs, batch_first=True, padding_value=tokenizer.PAD_code)
244 | pad = torch.ones(gt_seqs.size(0), CFG.MAX_LEN - gt_seqs.size(1)).fill_(tokenizer.PAD_code).long()
245 | gt_seqs = torch.cat([gt_seqs, pad], dim=1).to('cuda')
246 | # in train fn
247 | gt_seqs_input = gt_seqs[:, :-1]
248 | gt_seqs_expected = gt_seqs[:, 1:]
249 | CFG.PAD_IDX = tokenizer.PAD_code
250 |
251 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=False, out_dim=256)
252 |     decoder = Decoder(cfg=CFG, vocab_size=tokenizer.vocab_size, encoder_len=CFG.NUM_PATCHES, dim=256, num_heads=8, num_layers=6)
253 |     model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder).to('cuda')
254 | vertex_loss_fn = nn.CrossEntropyLoss(ignore_index=CFG.PAD_IDX)
255 |
256 | # Forward pass during training.
257 |     preds_f, perm_mat = model(image, gt_seqs_input)
258 | loss = vertex_loss_fn(preds_f.reshape(-1, preds_f.shape[-1]), gt_seqs_expected.reshape(-1))
259 |
260 | # Sequence generation during prediction.
261 | batch_preds = torch.ones(image.size(0), 1).fill_(tokenizer.BOS_code).long().to(CFG.DEVICE)
262 | batch_feats = torch.ones(image.size(0), 1).fill_(tokenizer.BOS_code).long().to(CFG.DEVICE)
263 | sample = lambda preds: torch.softmax(preds, dim=-1).argmax(dim=-1).view(-1, 1)
264 |
265 | out_coords = []
266 | out_confs = []
267 |
268 | confs = []
269 | with torch.no_grad():
270 | for i in range(1 + n_vertices*2):
271 | try:
272 | print(i)
273 | preds_p, feats_p = model.predict(image, batch_preds)
274 | # print(preds_p.shape, feats_p.shape)
275 | if i % 2 == 0:
276 | confs_ = torch.softmax(preds_p, dim=-1).sort(axis=-1, descending=True)[0][:, 0].cpu()
277 | confs.append(confs_)
278 | preds_p = sample(preds_p)
279 | batch_preds = torch.cat([batch_preds, preds_p], dim=1)
280 |             except Exception as e:
281 |                 print(f"Error at iteration {i}: {e}")
282 |     perm_pred = model.scorenet1(feats_p) + torch.transpose(model.scorenet2(feats_p), 1, 2)  # combine both ScoreNet heads as in forward()
283 |
284 | # Postprocessing.
285 | EOS_idxs = (batch_preds == tokenizer.EOS_code).float().argmax(dim=-1)
286 | invalid_idxs = ((EOS_idxs - 1) % 2 != 0).nonzero().view(-1) # sanity check
287 | EOS_idxs[invalid_idxs] = 0
288 |
289 | all_coords = []
290 | all_confs = []
291 | for i, EOS_idx in enumerate(EOS_idxs.tolist()):
292 | if EOS_idx == 0:
293 | all_coords.append(None)
294 | all_confs.append(None)
295 | continue
296 | coords = tokenizer.decode(batch_preds[i, :EOS_idx+1])
297 |         vertex_confs = [round(confs[j][i].item(), 3) for j in range(len(coords))]
298 |
299 |         all_coords.append(coords)
300 |         all_confs.append(vertex_confs)
301 |
302 | out_coords.extend(all_coords)
303 |     out_confs.extend(all_confs)
304 |
305 | print(f"preds_f shape: {preds_f.shape}")
306 | print(f"preds_f grad: {preds_f.requires_grad}")
307 | print(f"preds_f min: {preds_f.min()}, max: {preds_f.max()}")
308 |
309 | print(f"perm_mat shape: {perm_mat.shape}")
310 | print(f"perm_mat grad: {perm_mat.requires_grad}")
311 |     print(f"perm_mat min: {perm_mat.min()}, max: {perm_mat.max()}")
312 |
313 | print(f"batch_preds shape: {batch_preds.shape}")
314 | print(f"batch_preds grad: {batch_preds.requires_grad}")
315 | print(f"batch_preds min: {batch_preds.min()}, max: {batch_preds.max()}")
316 |
317 | print(f"loss : {loss}")
318 | print(f"loss grad: {loss.requires_grad}")
319 |
320 |
--------------------------------------------------------------------------------
/postprocess_coco_parts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | from tqdm import tqdm
4 | import shapely
5 |
6 | parts = [
7 | "runs/CYENS_CLUSTER_train_Pix2Poly_AUGSRUNS_inria_coco_224_negAug_run1_vit_small_patch8_224_dino_AffineRotaugs0.8_LinearWarmupLRS_1.0xVertexLoss_10.0xPermLoss__2xScoreNet_initialLR_0.0004_bs_24_Nv_192_Nbins224_500epochs/predictions_inria_coco_224_negAug_val_images_epoch_499.json"
8 | ]
9 | out_path = f"runs/CYENS_CLUSTER_train_Pix2Poly_AUGSRUNS_inria_coco_224_negAug_run1_vit_small_patch8_224_dino_AffineRotaugs0.8_LinearWarmupLRS_1.0xVertexLoss_10.0xPermLoss__2xScoreNet_initialLR_0.0004_bs_24_Nv_192_Nbins224_500epochs/combined_predictions_inria_coco_224_negAug_val_images_epoch_499.json"
10 |
11 |
12 | ################################################################
13 | combined = []
14 | part_lengths = []
15 | for i, part in enumerate(parts):
16 | print(f"PART {i}")
17 | with open(part) as f:
18 | pred_part = json.loads(f.read())
19 | part_lengths.append(len(pred_part))
20 | for ins in tqdm(pred_part):
21 | assert len(ins['segmentation']) == 1
22 | segm = ins['segmentation'][0]
23 | segm = np.array(segm).reshape(-1, 2)
24 | # segm = np.flip(segm, axis=0) # invert order of vertices cw -> ccw or ccw -> cw
25 | if segm.shape[0] > 2 and shapely.Polygon(segm).area > 50.:
26 | segm = np.concatenate((segm, segm[0, None]), axis=0)
27 | segm = segm.reshape(-1).tolist()
28 | ins['segmentation'] = [segm]
29 | combined.append(ins)
30 |
31 | with open(out_path, "w") as fp:
32 | fp.write(json.dumps(combined))
33 | ################################################################
34 |
35 |
--------------------------------------------------------------------------------
/predict_inria_coco_val_set.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | from tqdm import tqdm
5 | import numpy as np
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | import argparse
9 |
10 | from functools import partial
11 | import torch
12 | from torch import nn
13 | from torch.utils.data.distributed import DistributedSampler
14 | from torchvision.utils import make_grid
15 | import albumentations as A
16 | from albumentations.pytorch import ToTensorV2
17 |
18 | from test_config import CFG
19 | from tokenizer import Tokenizer
20 | from utils import (
21 | seed_everything,
22 | load_checkpoint,
23 | test_generate,
24 | postprocess,
25 | permutations_to_polygons,
26 | )
27 | from models.model import (
28 | Encoder,
29 | Decoder,
30 | EncoderDecoder
31 | )
32 |
33 | from torch.utils.data import DataLoader
34 | from datasets.dataset_inria_coco import InriaCocoDataset_val
35 | from torch.nn.utils.rnn import pad_sequence
36 | from torchmetrics.classification import BinaryJaccardIndex, BinaryAccuracy
37 | import torch.multiprocessing
38 | torch.multiprocessing.set_sharing_strategy("file_system")
39 |
40 |
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("-d", "--dataset", help="Dataset to use for evaluation.")
43 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.")
44 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.")
45 | parser.add_argument("-o", "--output_dir", help="Name of output subdirectory to store part predictions.")
46 | args = parser.parse_args()
47 |
48 |
49 | torch.backends.cuda.matmul.allow_tf32 = True
50 | torch.backends.cudnn.allow_tf32 = True
51 |
52 | DATASET = f"{args.dataset}"
53 | VAL_DATASET_DIR = f"./data/{DATASET}/val"
54 | # PART_DESC = "val_images"
55 | PART_DESC = f"{args.output_dir}"
56 |
57 | EXPERIMENT_NAME = os.path.basename(os.path.realpath(args.experiment_path))
58 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{args.checkpoint_name}.pth"
59 | BATCH_SIZE = 24
60 |
61 |
62 | def bounding_box_from_points(points):
63 | points = np.array(points).flatten()
64 |     even_locations = np.arange(points.shape[0] // 2) * 2
65 | odd_locations = even_locations + 1
66 | X = np.take(points, even_locations.tolist())
67 | Y = np.take(points, odd_locations.tolist())
68 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()]
69 | bbox = [int(b) for b in bbox]
70 | return bbox
71 |
72 |
73 | def single_annotation(image_id, poly):
74 | _result = {}
75 | _result["image_id"] = int(image_id)
76 | _result["category_id"] = 100
77 | _result["score"] = 1
78 | _result["segmentation"] = poly
79 | _result["bbox"] = bounding_box_from_points(_result["segmentation"])
80 | return _result
81 |
82 |
83 | def collate_fn(batch, max_len, pad_idx):
84 | """
85 | if max_len:
86 | the sequences will all be padded to that length.
87 | """
88 |
89 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch = [], [], [], [], [], []
90 | for image, mask, c_mask, seq, perm_mat, idx in batch:
91 | image_batch.append(image)
92 | mask_batch.append(mask)
93 | coords_mask_batch.append(c_mask)
94 | coords_seq_batch.append(seq)
95 | perm_matrix_batch.append(perm_mat)
96 | idx_batch.append(idx)
97 |
98 | coords_seq_batch = pad_sequence(
99 | coords_seq_batch,
100 | padding_value=pad_idx,
101 | batch_first=True
102 | )
103 |
104 | if max_len:
105 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long()
106 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1)
107 |
108 | image_batch = torch.stack(image_batch)
109 | mask_batch = torch.stack(mask_batch)
110 | coords_mask_batch = torch.stack(coords_mask_batch)
111 | perm_matrix_batch = torch.stack(perm_matrix_batch)
112 | idx_batch = torch.stack(idx_batch)
113 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch
114 |
115 |
116 | def main():
117 | seed_everything(42)
118 |
119 | valid_transforms = A.Compose(
120 | [
121 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH),
122 | A.Normalize(
123 | mean=[0.0, 0.0, 0.0],
124 | std=[1.0, 1.0, 1.0],
125 | max_pixel_value=255.0
126 | ),
127 | ToTensorV2(),
128 | ],
129 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False)
130 | )
131 |
132 | tokenizer = Tokenizer(
133 | num_classes=1,
134 | num_bins=CFG.NUM_BINS,
135 | width=CFG.INPUT_WIDTH,
136 | height=CFG.INPUT_HEIGHT,
137 | max_len=CFG.MAX_LEN
138 | )
139 | CFG.PAD_IDX = tokenizer.PAD_code
140 |
141 | val_ds = InriaCocoDataset_val(
142 | cfg=CFG,
143 | dataset_dir=VAL_DATASET_DIR,
144 | transform=valid_transforms,
145 | tokenizer=tokenizer,
146 | shuffle_tokens=CFG.SHUFFLE_TOKENS
147 | )
148 | val_loader = DataLoader(
149 | val_ds,
150 | batch_size=BATCH_SIZE,
151 | collate_fn=partial(collate_fn, max_len=CFG.MAX_LEN, pad_idx=CFG.PAD_IDX),
152 | num_workers=2
153 | )
154 |
155 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256)
156 | decoder = Decoder(
157 | cfg=CFG,
158 | vocab_size=tokenizer.vocab_size,
159 | encoder_len=CFG.NUM_PATCHES,
160 | dim=256,
161 | num_heads=8,
162 | num_layers=6
163 | )
164 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder)
165 | model.to(CFG.DEVICE)
166 | model.eval()
167 |
168 | checkpoint = torch.load(CHECKPOINT_PATH)
169 | model.load_state_dict(checkpoint['state_dict'])
170 | epoch = checkpoint['epochs_run']
171 |
172 | print(f"Model loaded from epoch: {epoch}")
173 | ckpt_desc = f"epoch_{epoch}"
174 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH):
175 | ckpt_desc = f"epoch_{epoch}_bestValLoss"
176 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH):
177 | ckpt_desc = f"epoch_{epoch}_bestValMetric"
178 | else:
179 | pass
180 |
181 | mean_iou_metric = BinaryJaccardIndex()
182 | mean_acc_metric = BinaryAccuracy()
183 |
184 |
185 | with torch.no_grad():
186 | cumulative_miou = []
187 | cumulative_macc = []
188 | speed = []
189 | predictions = []
190 | for i_batch, (x, y_mask, y_corner_mask, y, y_perm, idx) in enumerate(tqdm(val_loader)):
191 | all_coords = []
192 | all_confs = []
193 | t0 = time.time()
194 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1)
195 | speed.append(time.time() - t0)
196 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer)
197 |
198 | all_coords.extend(vertex_coords)
199 | all_confs.extend(confs)
200 |
201 | coords = []
202 | for i in range(len(all_coords)):
203 | if all_coords[i] is not None:
204 | coord = torch.from_numpy(all_coords[i])
205 | else:
206 | coord = torch.tensor([])
207 |
208 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX)
209 | coord = torch.cat([coord, padd], dim=0)
210 | coords.append(coord)
211 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224]
212 | # pred_polygons = permutations_to_polygons(perm_preds, coords, out='coco') # [0, 224]
213 |
214 | for ip, pp in enumerate(batch_polygons):
215 | for p in pp:
216 | p = torch.fliplr(p)
217 | p = p[p[:, 0] != CFG.PAD_IDX]
218 | p = p * (CFG.IMG_SIZE / CFG.INPUT_WIDTH)
219 | p = p.view(-1).tolist()
220 | if len(p) > 0:
221 | predictions.append(single_annotation(idx[ip], [p]))
222 |
223 | B, C, H, W = x.shape
224 |
225 | polygons_mask = np.zeros((B, 1, H, W))
226 | for b in range(len(batch_polygons)):
227 | for c in range(len(batch_polygons[b])):
228 | poly = batch_polygons[b][c]
229 | poly = poly[poly[:, 0] != CFG.PAD_IDX]
230 | cnt = np.flip(np.int32(poly.cpu()), 1)
231 | if len(cnt) > 0:
232 | cv2.fillPoly(polygons_mask[b, 0], pts=[cnt], color=1.)
233 | polygons_mask = torch.from_numpy(polygons_mask)
234 |
235 | batch_miou = mean_iou_metric(polygons_mask, y_mask)
236 | batch_macc = mean_acc_metric(polygons_mask, y_mask)
237 |
238 | cumulative_miou.append(batch_miou)
239 | cumulative_macc.append(batch_macc)
240 |
241 | pred_grid = make_grid(polygons_mask).permute(1, 2, 0)
242 | gt_grid = make_grid(y_mask).permute(1, 2, 0)
243 |             plt.subplot(211), plt.imshow(pred_grid), plt.title("Predicted Polygons"), plt.axis('off')
244 |             plt.subplot(212), plt.imshow(gt_grid), plt.title("Ground Truth"), plt.axis('off')
245 |
246 | if not os.path.exists(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc)):
247 | os.makedirs(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc))
248 | plt.savefig(f"runs/{EXPERIMENT_NAME}/val_preds/{DATASET}/{PART_DESC}/{ckpt_desc}/batch_{i_batch}.png")
249 | plt.close()
250 |
251 | print("Average model speed: ", np.mean(speed) / BATCH_SIZE, " [s / image]")
252 |
253 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}")
254 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}")
255 |
256 | with open(f"runs/{EXPERIMENT_NAME}/predictions_{DATASET}_{PART_DESC}_{ckpt_desc}.json", "w") as fp:
257 | fp.write(json.dumps(predictions))
258 |
259 | with open(f"runs/{EXPERIMENT_NAME}/val_metrics_{DATASET}_{PART_DESC}_{ckpt_desc}.txt", 'w') as ff:
260 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}", file=ff)
261 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}", file=ff)
262 |
263 |
264 | if __name__ == "__main__":
265 | main()
266 |
267 |
--------------------------------------------------------------------------------
/predict_mass_roads_test_set.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import cv2
5 | import json
6 | import matplotlib.pyplot as plt
7 | from matplotlib import patches
8 | # import gif
9 | from tqdm import tqdm
10 |
11 | import torch
12 | import albumentations as A
13 | from albumentations.pytorch import ToTensorV2
14 | from tokenizer import Tokenizer
15 | from test_config import CFG
16 | from models.model import (
17 | Encoder,
18 | Decoder,
19 | EncoderDecoder,
20 | )
21 | from utils import (
22 | seed_everything,
23 | test_generate,
24 | postprocess,
25 | permutations_to_polygons,
26 | )
27 | import time
28 |
29 |
30 | # adapted from https://github.com/obss/sahi/blob/e798c80d6e09079ae07a672c89732dd602fe9001/sahi/slicing.py#L30, MIT License
31 | def calculate_slice_bboxes(
32 | image_height: int,
33 | image_width: int,
34 | slice_height: int = 512,
35 | slice_width: int = 512,
36 | overlap_height_ratio: float = 0.2,
37 | overlap_width_ratio: float = 0.2,
38 | ) -> list[list[int]]:
39 | """
40 | Given the height and width of an image, calculates how to divide the image into
41 | overlapping slices according to the height and width provided. These slices are returned
42 | as bounding boxes in xyxy format.
43 | :param image_height: Height of the original image.
44 | :param image_width: Width of the original image.
45 | :param slice_height: Height of each slice
46 | :param slice_width: Width of each slice
47 | :param overlap_height_ratio: Fractional overlap in height of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels)
48 | :param overlap_width_ratio: Fractional overlap in width of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels)
49 | :return: a list of bounding boxes in xyxy format
50 | """
51 |
52 | slice_bboxes = []
53 | y_max = y_min = 0
54 | y_overlap = int(overlap_height_ratio * slice_height)
55 | x_overlap = int(overlap_width_ratio * slice_width)
56 | while y_max < image_height:
57 | x_min = x_max = 0
58 | y_max = y_min + slice_height
59 | while x_max < image_width:
60 | x_max = x_min + slice_width
61 | if y_max > image_height or x_max > image_width:
62 | xmax = min(image_width, x_max)
63 | ymax = min(image_height, y_max)
64 | xmin = max(0, xmax - slice_width)
65 | ymin = max(0, ymax - slice_height)
66 | slice_bboxes.append([xmin, ymin, xmax, ymax])
67 | else:
68 | slice_bboxes.append([x_min, y_min, x_max, y_max])
69 | x_min = x_max - x_overlap
70 | y_min = y_max - y_overlap
71 |
72 | return slice_bboxes
73 |
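# Illustrative worked example for calculate_slice_bboxes (toy image size; the
# 224 px slices and 0.2 overlap only mirror the values passed further below):
#
#   calculate_slice_bboxes(image_height=300, image_width=400,
#                          slice_height=224, slice_width=224,
#                          overlap_height_ratio=0.2, overlap_width_ratio=0.2)
#   # -> [[0, 0, 224, 224], [176, 0, 400, 224],
#   #     [0, 76, 224, 300], [176, 76, 400, 300]]
#   # Slices step by 224 - int(0.2 * 224) = 180 px, and the last slice in each
#   # row/column is shifted back so it stays fully inside the image.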
74 |
75 | def get_rectangle_params_from_pascal_bbox(bbox):
76 | xmin_top_left, ymin_top_left, xmax_bottom_right, ymax_bottom_right = bbox
77 |
78 | bottom_left = (xmin_top_left, ymax_bottom_right)
79 | width = xmax_bottom_right - xmin_top_left
80 | height = ymin_top_left - ymax_bottom_right
81 |
82 | return bottom_left, width, height
83 |
84 |
85 | def draw_bboxes(
86 | plot_ax,
87 | bboxes,
88 | class_labels,
89 | get_rectangle_corners_fn=get_rectangle_params_from_pascal_bbox,
90 | ):
91 | for bbox, label in zip(bboxes, class_labels):
92 | bottom_left, width, height = get_rectangle_corners_fn(bbox)
93 |
94 | rect_1 = patches.Rectangle(
95 | bottom_left, width, height, linewidth=4, edgecolor="black", fill=False,
96 | )
97 | rect_2 = patches.Rectangle(
98 | bottom_left, width, height, linewidth=2, edgecolor="white", fill=False,
99 | )
100 | rx, ry = rect_1.get_xy()
101 |
102 | # Add the patch to the Axes
103 | plot_ax.add_patch(rect_1)
104 | plot_ax.add_patch(rect_2)
105 | plot_ax.annotate(label, (rx+width, ry+height), color='white', fontsize=20)
106 |
107 | # @gif.frame
108 | def show_image(image, bboxes=None, class_labels=None, draw_bboxes_fn=draw_bboxes):
109 | fig, ax = plt.subplots(1, figsize=(10, 10))
110 | ax.imshow(image)
111 |
112 | if bboxes:
113 | draw_bboxes_fn(ax, bboxes, class_labels)
114 |
115 | # plt.show()
116 |
117 |
118 | def bounding_box_from_points(points):
119 | points = np.array(points).flatten()
120 |     even_locations = np.arange(points.shape[0] // 2) * 2
121 | odd_locations = even_locations + 1
122 | X = np.take(points, even_locations.tolist())
123 | Y = np.take(points, odd_locations.tolist())
124 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()]
125 | bbox = [int(b) for b in bbox]
126 | return bbox
127 |
128 |
129 | def single_annotation(image_id, poly):
130 | _result = {}
131 | _result["image_id"] = int(image_id)
132 | _result["category_id"] = 100
133 | _result["score"] = 1
134 | _result["segmentation"] = poly
135 | _result["bbox"] = bounding_box_from_points(_result["segmentation"])
136 | return _result
137 |
138 |
139 | def main(args):
140 | BATCH_SIZE = int(args.batch_size) # 24
141 | PATCH_SIZE = int(args.img_size) # 224
142 | INPUT_HEIGHT = int(args.input_size) # 224
143 | INPUT_WIDTH = int(args.input_size) # 224
144 |
145 | # EXPERIMENT_NAME = f"CYENS_CLUSTER_train_Pix2PolyFullDataExps_inria_coco_224_negAug_run1_deit3_small_patch16_384_in21ft1k_Rotaugs_LinearWarmupLRS_NoShuffle_1.0xVertexLoss_10.0xPermLoss_0.0xVertexRegLoss__2xScoreNet_initialLR_0.0004_bs_16_Nv_192_Nbins384_LbSm_0.0_500epochs"
146 | EXPERIMENT_PATH = args.experiment_path
147 | EXPERIMENT_NAME = os.path.basename(os.path.abspath(EXPERIMENT_PATH))
148 | CHECKPOINT_NAME = args.checkpoint_name
149 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{CHECKPOINT_NAME}.pth"
150 |
151 | SPLIT = args.split # 'val' or 'test'
152 |
153 | test_image_dir = f"data/mass_roads_1500/test/sat"
154 | val_image_dir = f"data/mass_roads_1500/val/sat"
155 |
156 | if SPLIT == "test":
157 | test_images = []
158 | for im in os.listdir(test_image_dir):
159 | test_images.append(im)
160 | image_dir = test_image_dir
161 | images = test_images
162 | elif SPLIT == "val":
163 | val_images = []
164 | for im in os.listdir(val_image_dir):
165 | val_images.append(im)
166 | image_dir = val_image_dir
167 | images = val_images
168 | else:
169 | raise ValueError("Specify either test or val split for prediction.")
170 |
171 | test_transforms = A.Compose(
172 | [
173 | A.Resize(height=INPUT_HEIGHT, width=INPUT_WIDTH),
174 | A.Normalize(
175 | mean=[0.0, 0.0, 0.0],
176 | std=[1.0, 1.0, 1.0],
177 | max_pixel_value=255.0
178 | ),
179 | ToTensorV2(),
180 | ],
181 | )
182 |
183 | tokenizer = Tokenizer(
184 | num_classes=1,
185 | num_bins=CFG.NUM_BINS,
186 | width=INPUT_WIDTH,
187 | height=INPUT_HEIGHT,
188 | max_len=CFG.MAX_LEN
189 | )
190 | CFG.PAD_IDX = tokenizer.PAD_code
191 |
192 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256)
193 | decoder = Decoder(
194 | cfg=CFG,
195 | vocab_size=tokenizer.vocab_size,
196 | encoder_len=CFG.NUM_PATCHES,
197 | dim=256,
198 | num_heads=8,
199 | num_layers=6
200 | )
201 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder)
202 | model.to(CFG.DEVICE)
203 | model.eval()
204 |
205 | checkpoint = torch.load(CHECKPOINT_PATH)
206 | model.load_state_dict(checkpoint['state_dict'])
207 | epoch = checkpoint['epochs_run']
208 |
209 | print(f"Model loaded from epoch: {epoch}")
210 | ckpt_desc = f"epoch_{epoch}"
211 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH):
212 | ckpt_desc = f"epoch_{epoch}_bestValLoss"
213 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH):
214 | ckpt_desc = f"epoch_{epoch}_bestValMetric"
215 | else:
216 | pass
217 |
218 |
219 | results_dir = os.path.join(f"runs/{EXPERIMENT_NAME}", f"{SPLIT}_predictions", ckpt_desc)
220 | os.makedirs(results_dir, exist_ok=True)
221 | os.makedirs(os.path.join(results_dir, "raster_preds"), exist_ok=True)
222 | os.makedirs(os.path.join(results_dir, "polygon_preds"), exist_ok=True)
223 |
224 |
225 | with torch.no_grad():
226 | for idx, image in enumerate(tqdm(images)):
227 | print(f"<---------Processing {idx+1}/{len(images)}: {image}----------->")
228 | img_name = image
229 | if os.path.exists(os.path.join(results_dir, 'raster_preds', img_name)):
230 | continue
231 | img = Image.open(os.path.join(image_dir, img_name))
232 | img = np.array(img)
233 |
234 | slice_bboxes = calculate_slice_bboxes(
235 |                 image_height=img.shape[0],  # img is (H, W, C)
236 |                 image_width=img.shape[1],
237 | slice_height=PATCH_SIZE,
238 | slice_width=PATCH_SIZE,
239 | overlap_height_ratio=0.2,
240 | overlap_width_ratio=0.2
241 | )
242 |
243 | speed = []
244 | predictions = []
245 | for bi, box in enumerate(tqdm(slice_bboxes)):
246 | xmin_top_left, ymin_top_left, xmax_bottom_right, ymax_bottom_right = box
247 | patch = img[ymin_top_left:ymax_bottom_right, xmin_top_left:xmax_bottom_right]
248 | patch = test_transforms(image=patch.astype(np.float32))['image'][None]
249 |
250 | all_coords = []
251 | all_confs = []
252 | t0 = time.time()
253 | batch_preds, batch_confs, perm_preds = test_generate(model, patch, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1)
254 | speed.append(time.time() - t0)
255 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer)
256 |
257 | all_coords.extend(vertex_coords)
258 | all_confs.extend(confs)
259 |
260 | coords = []
261 | for i in range(len(all_coords)):
262 | if all_coords[i] is not None:
263 | coord = torch.from_numpy(all_coords[i])
264 | else:
265 | coord = torch.tensor([])
266 |
267 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX)
268 | coord = torch.cat([coord, padd], dim=0)
269 | coords.append(coord)
270 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224]
271 |
272 | for ip, pp in enumerate(batch_polygons):
273 | if pp is not None:
274 | for p in pp:
275 | if p is not None:
276 | p = torch.fliplr(p)
277 | p = p[p[:, 0] != CFG.PAD_IDX]
278 | p = p * (PATCH_SIZE / INPUT_WIDTH)
279 | p[:, 0] = p[:, 0] + xmin_top_left
280 | p[:, 1] = p[:, 1] + ymin_top_left
281 | if len(p) > 0:
282 |                                     if (p[0] == p[-1]).all():
283 |                                         p = p[:-1]
284 | p = p.view(-1).tolist()
285 | if len(p) > 0:
286 | predictions.append(single_annotation(idx, [p]))
287 | # For debugging
288 | # if bi >= 10:
289 | # break
290 |
291 | H, W = img.shape[0], img.shape[1]
292 |
293 | polygons_mask = np.zeros((H, W))
294 | for pred in predictions:
295 | poly = np.array(pred['segmentation'])
296 | poly = poly.reshape((poly.shape[-1]//2, 2))
297 | cv2.polylines(polygons_mask, [np.int32(poly)], isClosed=False, color=1., thickness=5)
298 | polygons_mask = (polygons_mask*255).astype(np.uint8)
299 |
300 | cv2.imwrite(os.path.join(results_dir, 'raster_preds', img_name), polygons_mask)
301 | print("Average model speed: ", np.mean(speed), " [s / patch]")
302 | print("Time for a single tile: ", np.sum(speed), " [s / tile]")
303 |
304 | with open(f"{results_dir}/polygon_preds/{img_name.split('.')[0]}.json", "w") as fp:
305 | fp.write(json.dumps(predictions))
306 |
307 |
308 | ############# Visualizations #################:
309 | # frames = []
310 | # for slice in tqdm(slice_bboxes):
311 | # frames.append(show_image(img, [slice], ['']))
312 | # # if sid > 40:
313 | # # break
314 |
315 | # gif.save(frames, "overlapping_patches.gif",
316 | # duration=15)
317 |
318 |
319 | if __name__ == "__main__":
320 | import argparse
321 | parser = argparse.ArgumentParser()
322 |
323 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.")
324 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.")
325 | parser.add_argument("-s", "--split", help="Dataset split to use for prediction ('test' or 'val').")
326 | parser.add_argument("--img_size", help="Original image size.")
327 | parser.add_argument("--input_size", help="Image size of input to network.")
328 | parser.add_argument("--batch_size", help="Batch size to network.")
329 | args = parser.parse_args()
330 |
331 | main(args)
332 |
333 |
--------------------------------------------------------------------------------
/predict_spacenet_coco_val_set.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | from tqdm import tqdm
5 | import numpy as np
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | import argparse
9 |
10 | from functools import partial
11 | import torch
12 | from torch.utils.data.distributed import DistributedSampler
13 | from torchvision.utils import make_grid
14 | import albumentations as A
15 | from albumentations.pytorch import ToTensorV2
16 |
17 | from test_config import CFG
18 | from tokenizer import Tokenizer
19 | from utils import (
20 | seed_everything,
21 | load_checkpoint,
22 | test_generate,
23 | postprocess,
24 | permutations_to_polygons,
25 | )
26 | from models.model import (
27 | Encoder,
28 | Decoder,
29 | EncoderDecoder
30 | )
31 |
32 | from torch.utils.data import DataLoader
33 | from datasets.dataset_spacenet_coco import SpacenetCocoDataset_val
34 | from torch.nn.utils.rnn import pad_sequence
35 | from torchmetrics.classification import BinaryJaccardIndex, BinaryAccuracy
36 | from torch import distributed as dist
37 | import torch.multiprocessing
38 | torch.multiprocessing.set_sharing_strategy("file_system")
39 |
40 |
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("-d", "--dataset", help="Dataset to use for evaluation.")
43 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.")
44 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.")
45 | parser.add_argument("-o", "--output_dir", help="Name of output subdirectory to store part predictions.")
46 | args = parser.parse_args()
47 |
48 |
49 | torch.backends.cuda.matmul.allow_tf32 = True
50 | torch.backends.cudnn.allow_tf32 = True
51 |
52 | DATASET = f"{args.dataset}"
53 | VAL_DATASET_DIR = f"./data/{DATASET}/val"
54 | # PART_DESC = "val_images"
55 | PART_DESC = f"{args.output_dir}"
56 |
57 | EXPERIMENT_NAME = os.path.basename(os.path.realpath(args.experiment_path))
58 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{args.checkpoint_name}.pth"
59 | BATCH_SIZE = 24
60 |
61 |
62 | def bounding_box_from_points(points):
63 | points = np.array(points).flatten()
64 |     even_locations = np.arange(points.shape[0] // 2) * 2
65 | odd_locations = even_locations + 1
66 | X = np.take(points, even_locations.tolist())
67 | Y = np.take(points, odd_locations.tolist())
68 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()]
69 | bbox = [int(b) for b in bbox]
70 | return bbox
71 |
72 |
73 | def single_annotation(image_id, poly):
74 | _result = {}
75 | _result["image_id"] = int(image_id)
76 | _result["category_id"] = 100
77 | _result["score"] = 1
78 | _result["segmentation"] = poly
79 | _result["bbox"] = bounding_box_from_points(_result["segmentation"])
80 | return _result
81 |
82 |
83 | def collate_fn(batch, max_len, pad_idx):
84 | """
85 | if max_len:
86 | the sequences will all be padded to that length.
87 | """
88 |
89 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch = [], [], [], [], [], []
90 | for image, mask, c_mask, seq, perm_mat, idx in batch:
91 | image_batch.append(image)
92 | mask_batch.append(mask)
93 | coords_mask_batch.append(c_mask)
94 | coords_seq_batch.append(seq)
95 | perm_matrix_batch.append(perm_mat)
96 | idx_batch.append(idx)
97 |
98 | coords_seq_batch = pad_sequence(
99 | coords_seq_batch,
100 | padding_value=pad_idx,
101 | batch_first=True
102 | )
103 |
104 | if max_len:
105 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long()
106 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1)
107 |
108 | image_batch = torch.stack(image_batch)
109 | mask_batch = torch.stack(mask_batch)
110 | coords_mask_batch = torch.stack(coords_mask_batch)
111 | perm_matrix_batch = torch.stack(perm_matrix_batch)
112 | idx_batch = torch.stack(idx_batch)
113 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch
114 |
115 |
116 | def main():
117 | seed_everything(42)
118 |
119 | valid_transforms = A.Compose(
120 | [
121 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH),
122 | A.Normalize(
123 | mean=[0.0, 0.0, 0.0],
124 | std=[1.0, 1.0, 1.0],
125 | max_pixel_value=255.0
126 | ),
127 | ToTensorV2(),
128 | ],
129 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False)
130 | )
131 |
132 | tokenizer = Tokenizer(
133 | num_classes=1,
134 | num_bins=CFG.NUM_BINS,
135 | width=CFG.INPUT_WIDTH,
136 | height=CFG.INPUT_HEIGHT,
137 | max_len=CFG.MAX_LEN
138 | )
139 | CFG.PAD_IDX = tokenizer.PAD_code
140 |
141 | val_ds = SpacenetCocoDataset_val(
142 | cfg=CFG,
143 | dataset_dir=VAL_DATASET_DIR,
144 | transform=valid_transforms,
145 | tokenizer=tokenizer,
146 | shuffle_tokens=CFG.SHUFFLE_TOKENS
147 | )
148 | val_loader = DataLoader(
149 | val_ds,
150 | batch_size=BATCH_SIZE,
151 | collate_fn=partial(collate_fn, max_len=CFG.MAX_LEN, pad_idx=CFG.PAD_IDX),
152 | num_workers=2
153 | )
154 |
155 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256)
156 | decoder = Decoder(
157 | cfg=CFG,
158 | vocab_size=tokenizer.vocab_size,
159 | encoder_len=CFG.NUM_PATCHES,
160 | dim=256,
161 | num_heads=8,
162 | num_layers=6
163 | )
164 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder)
165 | model.to(CFG.DEVICE)
166 | model.eval()
167 |
168 | checkpoint = torch.load(CHECKPOINT_PATH)
169 | model.load_state_dict(checkpoint['state_dict'])
170 | epoch = checkpoint['epochs_run']
171 |
172 | print(f"Model loaded from epoch: {epoch}")
173 | ckpt_desc = f"epoch_{epoch}"
174 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH):
175 | ckpt_desc = f"epoch_{epoch}_bestValLoss"
176 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH):
177 | ckpt_desc = f"epoch_{epoch}_bestValMetric"
178 | else:
179 | pass
180 |
181 | mean_iou_metric = BinaryJaccardIndex()
182 | mean_acc_metric = BinaryAccuracy()
183 |
184 |
185 | with torch.no_grad():
186 | cumulative_miou = []
187 | cumulative_macc = []
188 | speed = []
189 | predictions = []
190 | for i_batch, (x, y_mask, y_corner_mask, y, y_perm, idx) in enumerate(tqdm(val_loader)):
191 | all_coords = []
192 | all_confs = []
193 | t0 = time.time()
194 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1)
195 | speed.append(time.time() - t0)
196 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer)
197 |
198 | all_coords.extend(vertex_coords)
199 | all_confs.extend(confs)
200 |
201 | coords = []
202 | for i in range(len(all_coords)):
203 | if all_coords[i] is not None:
204 | coord = torch.from_numpy(all_coords[i])
205 | else:
206 | coord = torch.tensor([])
207 |
208 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX)
209 | coord = torch.cat([coord, padd], dim=0)
210 | coords.append(coord)
211 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224]
212 | # pred_polygons = permutations_to_polygons(perm_preds, coords, out='coco') # [0, 224]
213 |
214 | for ip, pp in enumerate(batch_polygons):
215 | for p in pp:
216 | p = torch.fliplr(p)
217 | p = p[p[:, 0] != CFG.PAD_IDX]
218 | p = p * (CFG.IMG_SIZE / CFG.INPUT_WIDTH)
219 | p = p.view(-1).tolist()
220 | if len(p) > 0:
221 | predictions.append(single_annotation(idx[ip], [p]))
222 |
223 | B, C, H, W = x.shape
224 |
225 | polygons_mask = np.zeros((B, 1, H, W))
226 | for b in range(len(batch_polygons)):
227 | for c in range(len(batch_polygons[b])):
228 | poly = batch_polygons[b][c]
229 | poly = poly[poly[:, 0] != CFG.PAD_IDX]
230 | cnt = np.flip(np.int32(poly.cpu()), 1)
231 | if len(cnt) > 0:
232 | cv2.fillPoly(polygons_mask[b, 0], pts=[cnt], color=1.)
233 | polygons_mask = torch.from_numpy(polygons_mask)
234 |
235 | batch_miou = mean_iou_metric(polygons_mask, y_mask)
236 | batch_macc = mean_acc_metric(polygons_mask, y_mask)
237 |
238 | cumulative_miou.append(batch_miou)
239 | cumulative_macc.append(batch_macc)
240 |
241 | pred_grid = make_grid(polygons_mask).permute(1, 2, 0)
242 | gt_grid = make_grid(y_mask).permute(1, 2, 0)
243 |             plt.subplot(211), plt.imshow(pred_grid), plt.title("Predicted Polygons"), plt.axis('off')
244 |             plt.subplot(212), plt.imshow(gt_grid), plt.title("Ground Truth"), plt.axis('off')
245 |
246 | if not os.path.exists(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc)):
247 | os.makedirs(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc))
248 | plt.savefig(f"runs/{EXPERIMENT_NAME}/val_preds/{DATASET}/{PART_DESC}/{ckpt_desc}/batch_{i_batch}.png")
249 | plt.close()
250 |
251 | print("Average model speed: ", np.mean(speed) / BATCH_SIZE, " [s / image]")
252 |
253 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}")
254 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}")
255 |
256 | with open(f"runs/{EXPERIMENT_NAME}/predictions_{DATASET}_{PART_DESC}_{ckpt_desc}.json", "w") as fp:
257 | fp.write(json.dumps(predictions))
258 |
259 | with open(f"runs/{EXPERIMENT_NAME}/val_metrics_{DATASET}_{PART_DESC}_{ckpt_desc}.txt", 'w') as ff:
260 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}", file=ff)
261 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}", file=ff)
262 |
263 |
264 | if __name__ == "__main__":
265 | main()
266 |
267 |
--------------------------------------------------------------------------------
/predict_whu_buildings_coco_test_set.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | from tqdm import tqdm
5 | import numpy as np
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | import argparse
9 |
10 | from functools import partial
11 | import torch
12 | from torchvision.utils import make_grid
13 | import albumentations as A
14 | from albumentations.pytorch import ToTensorV2
15 |
16 | from test_config import CFG
17 | from tokenizer import Tokenizer
18 | from utils import (
19 | seed_everything,
20 | load_checkpoint,
21 | test_generate,
22 | postprocess,
23 | permutations_to_polygons,
24 | )
25 | from models.model import (
26 | Encoder,
27 | Decoder,
28 | EncoderDecoder
29 | )
30 |
31 | from torch.utils.data import DataLoader
32 | from datasets.dataset_whu_buildings_coco import WHUBuildingsCocoDataset_val
33 | from torch.nn.utils.rnn import pad_sequence
34 | from torchmetrics.classification import BinaryJaccardIndex, BinaryAccuracy
35 | import torch.multiprocessing
36 | torch.multiprocessing.set_sharing_strategy("file_system")
37 |
38 |
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument("-d", "--dataset", help="Dataset to use for evaluation.")
41 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.")
42 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.")
43 | parser.add_argument("-o", "--output_dir", help="Name of output subdirectory to store part predictions.")
44 | args = parser.parse_args()
45 |
46 |
47 | torch.backends.cuda.matmul.allow_tf32 = True
48 | torch.backends.cudnn.allow_tf32 = True
49 |
50 | DATASET = f"{args.dataset}"
51 | VAL_DATASET_DIR = f"./data/{DATASET}/test"
52 | PART_DESC = f"{args.output_dir}"
53 |
54 | EXPERIMENT_NAME = os.path.basename(os.path.realpath(args.experiment_path))
55 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{args.checkpoint_name}.pth"
56 | BATCH_SIZE = 24
57 |
58 |
59 | def bounding_box_from_points(points):
60 | points = np.array(points).flatten()
61 |     even_locations = np.arange(points.shape[0] // 2) * 2
62 | odd_locations = even_locations + 1
63 | X = np.take(points, even_locations.tolist())
64 | Y = np.take(points, odd_locations.tolist())
65 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()]
66 | bbox = [int(b) for b in bbox]
67 | return bbox
68 |
69 |
70 | def single_annotation(image_id, poly):
71 | _result = {}
72 | _result["image_id"] = int(image_id)
73 | _result["category_id"] = 100
74 | _result["score"] = 1
75 | _result["segmentation"] = poly
76 | _result["bbox"] = bounding_box_from_points(_result["segmentation"])
77 | return _result
78 |
79 |
80 | def collate_fn(batch, max_len, pad_idx):
81 | """
82 | if max_len:
83 | the sequences will all be padded to that length.
84 | """
85 |
86 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch = [], [], [], [], [], []
87 | for image, mask, c_mask, seq, perm_mat, idx in batch:
88 | image_batch.append(image)
89 | mask_batch.append(mask)
90 | coords_mask_batch.append(c_mask)
91 | coords_seq_batch.append(seq)
92 | perm_matrix_batch.append(perm_mat)
93 | idx_batch.append(idx)
94 |
95 | coords_seq_batch = pad_sequence(
96 | coords_seq_batch,
97 | padding_value=pad_idx,
98 | batch_first=True
99 | )
100 |
101 | if max_len:
102 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long()
103 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1)
104 |
105 | image_batch = torch.stack(image_batch)
106 | mask_batch = torch.stack(mask_batch)
107 | coords_mask_batch = torch.stack(coords_mask_batch)
108 | perm_matrix_batch = torch.stack(perm_matrix_batch)
109 | idx_batch = torch.stack(idx_batch)
110 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch
111 |
112 |
113 | def main():
114 | seed_everything(42)
115 |
116 | valid_transforms = A.Compose(
117 | [
118 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH),
119 | A.Normalize(
120 | mean=[0.0, 0.0, 0.0],
121 | std=[1.0, 1.0, 1.0],
122 | max_pixel_value=255.0
123 | ),
124 | ToTensorV2(),
125 | ],
126 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False)
127 | )
128 |
129 | tokenizer = Tokenizer(
130 | num_classes=1,
131 | num_bins=CFG.NUM_BINS,
132 | width=CFG.INPUT_WIDTH,
133 | height=CFG.INPUT_HEIGHT,
134 | max_len=CFG.MAX_LEN
135 | )
136 | CFG.PAD_IDX = tokenizer.PAD_code
137 |
138 | val_ds = WHUBuildingsCocoDataset_val(
139 | cfg=CFG,
140 | dataset_dir=VAL_DATASET_DIR,
141 | transform=valid_transforms,
142 | tokenizer=tokenizer,
143 | shuffle_tokens=CFG.SHUFFLE_TOKENS
144 | )
145 | val_loader = DataLoader(
146 | val_ds,
147 | batch_size=BATCH_SIZE,
148 | collate_fn=partial(collate_fn, max_len=CFG.MAX_LEN, pad_idx=CFG.PAD_IDX),
149 | num_workers=2
150 | )
151 |
152 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256)
153 | decoder = Decoder(
154 | cfg=CFG,
155 | vocab_size=tokenizer.vocab_size,
156 | encoder_len=CFG.NUM_PATCHES,
157 | dim=256,
158 | num_heads=8,
159 | num_layers=6
160 | )
161 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder)
162 | model.to(CFG.DEVICE)
163 | model.eval()
164 |
165 | checkpoint = torch.load(CHECKPOINT_PATH)
166 | model.load_state_dict(checkpoint['state_dict'])
167 | epoch = checkpoint['epochs_run']
168 |
169 | print(f"Model loaded from epoch: {epoch}")
170 | ckpt_desc = f"epoch_{epoch}"
171 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH):
172 | ckpt_desc = f"epoch_{epoch}_bestValLoss"
173 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH):
174 | ckpt_desc = f"epoch_{epoch}_bestValMetric"
175 | else:
176 | pass
177 |
178 | mean_iou_metric = BinaryJaccardIndex()
179 | mean_acc_metric = BinaryAccuracy()
180 |
181 |
182 | with torch.no_grad():
183 | cumulative_miou = []
184 | cumulative_macc = []
185 | speed = []
186 | predictions = []
187 | for i_batch, (x, y_mask, y_corner_mask, y, y_perm, idx) in enumerate(tqdm(val_loader)):
188 | all_coords = []
189 | all_confs = []
190 | t0 = time.time()
191 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1)
192 | speed.append(time.time() - t0)
193 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer)
194 |
195 | all_coords.extend(vertex_coords)
196 | all_confs.extend(confs)
197 |
198 | coords = []
199 | for i in range(len(all_coords)):
200 | if all_coords[i] is not None:
201 | coord = torch.from_numpy(all_coords[i])
202 | else:
203 | coord = torch.tensor([])
204 |
205 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX)
206 | coord = torch.cat([coord, padd], dim=0)
207 | coords.append(coord)
208 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224]
209 | # pred_polygons = permutations_to_polygons(perm_preds, coords, out='coco') # [0, 224]
210 |
211 | for ip, pp in enumerate(batch_polygons):
212 | for p in pp:
213 | p = torch.fliplr(p)
214 | p = p[p[:, 0] != CFG.PAD_IDX]
215 | p = p * (CFG.IMG_SIZE / CFG.INPUT_WIDTH)
216 | p = p.view(-1).tolist()
217 | if len(p) > 0:
218 | predictions.append(single_annotation(idx[ip], [p]))
219 |
220 | B, C, H, W = x.shape
221 |
222 | polygons_mask = np.zeros((B, 1, H, W))
223 | for b in range(len(batch_polygons)):
224 | for c in range(len(batch_polygons[b])):
225 | poly = batch_polygons[b][c]
226 | poly = poly[poly[:, 0] != CFG.PAD_IDX]
227 | cnt = np.flip(np.int32(poly.cpu()), 1)
228 | if len(cnt) > 0:
229 | cv2.fillPoly(polygons_mask[b, 0], pts=[cnt], color=1.)
230 | polygons_mask = torch.from_numpy(polygons_mask)
231 |
232 | batch_miou = mean_iou_metric(polygons_mask, y_mask)
233 | batch_macc = mean_acc_metric(polygons_mask, y_mask)
234 |
235 | cumulative_miou.append(batch_miou)
236 | cumulative_macc.append(batch_macc)
237 |
238 | pred_grid = make_grid(polygons_mask).permute(1, 2, 0)
239 | gt_grid = make_grid(y_mask).permute(1, 2, 0)
240 |             plt.subplot(211), plt.imshow(pred_grid), plt.title("Predicted Polygons"), plt.axis('off')
241 |             plt.subplot(212), plt.imshow(gt_grid), plt.title("Ground Truth"), plt.axis('off')
242 |
243 | if not os.path.exists(os.path.join(f"runs/{EXPERIMENT_NAME}", 'test_preds', DATASET, PART_DESC, ckpt_desc)):
244 | os.makedirs(os.path.join(f"runs/{EXPERIMENT_NAME}", 'test_preds', DATASET, PART_DESC, ckpt_desc))
245 | plt.savefig(f"runs/{EXPERIMENT_NAME}/test_preds/{DATASET}/{PART_DESC}/{ckpt_desc}/batch_{i_batch}.png")
246 | plt.close()
247 |
248 | print("Average model speed: ", np.mean(speed) / BATCH_SIZE, " [s / image]")
249 |
250 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}")
251 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}")
252 |
253 | with open(f"runs/{EXPERIMENT_NAME}/predictions_{DATASET}_{PART_DESC}_{ckpt_desc}.json", "w") as fp:
254 | fp.write(json.dumps(predictions))
255 |
256 | with open(f"runs/{EXPERIMENT_NAME}/test_metrics_{DATASET}_{PART_DESC}_{ckpt_desc}.txt", 'w') as ff:
257 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}", file=ff)
258 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}", file=ff)
259 |
260 |
261 | if __name__ == "__main__":
262 | main()
263 |
264 |
--------------------------------------------------------------------------------
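
A minimal sketch (with hypothetical values) of the COCO-style record that the loop above appends to predictions through single_annotation: one flattened polygon ring per building, with a fixed category_id of 100 and score 1. The two helpers are assumed to be copied into scope rather than imported, since the script above reads its CLI arguments at module level.

    ring = [10.0, 10.0, 60.0, 10.0, 60.0, 40.0, 10.0, 40.0]   # flattened x0, y0, x1, y1, ... in pixels
    record = single_annotation(image_id=7, poly=[ring])
    print(record["bbox"])            # [10, 10, 50, 30] -> x_min, y_min, width, height
    print(record["segmentation"])    # [[10.0, 10.0, 60.0, 10.0, ...]]
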
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "typeCheckingMode": "off"
3 | }
4 |
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from config import CFG
4 |
5 |
6 | class Tokenizer:
7 | def __init__(self, num_classes: int, num_bins: int, width: int, height: int, max_len=256):
8 | self.num_classes = num_classes
9 | self.num_bins = num_bins
10 | self.width = width
11 | self.height = height
12 | self.max_len = max_len
13 |
14 | self.BOS_code = num_bins
15 | self.EOS_code = self.BOS_code + 1
16 | self.PAD_code = self.EOS_code + 1
17 |
18 | self.vocab_size = num_bins + 3 #+ num_classes
19 |
20 |     def quantize(self, x: np.ndarray):
21 | """
22 | x is a real number in [0, 1]
23 | """
24 |
25 | return (x * (self.num_bins - 1)).round(0).astype('int')
26 |
27 |     def dequantize(self, x: np.ndarray):
28 |         """
29 |         x is an integer in [0, num_bins-1]
30 | """
31 |
32 | return x.astype('float32') / (self.num_bins - 1)
33 |
34 |     def __call__(self, coords: np.ndarray, shuffle=True):
35 |
36 | if len(coords) > 0:
37 | coords[:, 0] = coords[:, 0] / self.width
38 | coords[:, 1] = coords[:, 1] / self.height
39 |
40 | coords = self.quantize(coords)[:self.max_len]
41 |
42 | if shuffle:
43 | rand_idxs = np.arange(0, len(coords))
44 | if 'debug' in CFG.EXPERIMENT_NAME:
45 | rand_idxs = rand_idxs[::-1]
46 | else:
47 | np.random.shuffle(rand_idxs)
48 | coords = coords[rand_idxs]
49 | else:
50 | rand_idxs = np.arange(0, len(coords))
51 |
52 | tokenized = [self.BOS_code]
53 | for coord in coords:
54 | tokens = list(coord)
55 |
56 | tokenized.extend(list(map(int, tokens)))
57 | tokenized.append(self.EOS_code)
58 |
59 | return tokenized, rand_idxs
60 |
61 | def decode(self, tokens: torch.Tensor):
62 | """
63 | tokens: torch.LongTensor with shape [L]
64 | """
65 |
66 | mask = tokens != self.PAD_code
67 | tokens = tokens[mask]
68 | tokens = tokens[1:-1]
69 | assert len(tokens) % 2 == 0, "Invalid tokens!"
70 |
71 | coords = []
72 | for i in range(2, len(tokens)+1, 2):
73 | coord = tokens[i-2: i]
74 | coords.append([int(item) for item in coord])
75 | coords = np.array(coords)
76 | coords = self.dequantize(coords)
77 |
78 | if len(coords) > 0:
79 | coords[:, 0] = coords[:, 0] * self.width
80 | coords[:, 1] = coords[:, 1] * self.height
81 |
82 | return coords
83 |
84 |
--------------------------------------------------------------------------------
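
A minimal round-trip sketch of the Tokenizer above. The bin count and image size used here are assumptions standing in for CFG.NUM_BINS, CFG.INPUT_WIDTH and CFG.INPUT_HEIGHT; the exact token values depend on those settings.

    import numpy as np
    import torch
    from tokenizer import Tokenizer

    tok = Tokenizer(num_classes=1, num_bins=224, width=224, height=224, max_len=256)
    coords = np.array([[10.0, 20.0], [100.0, 50.0]])    # two vertices in pixel coordinates
    seq, order = tok(coords.copy(), shuffle=False)       # [BOS, c00, c01, c10, c11, EOS]
    recovered = tok.decode(torch.tensor(seq))            # approximately the input coordinates
    print(seq)
    print(recovered)
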
/train_ddp.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path as osp
3 | import torch
4 | from torch import nn
5 | from torch import optim
6 | import albumentations as A
7 | from albumentations.pytorch import ToTensorV2
8 | from transformers import (
9 | get_linear_schedule_with_warmup,
10 | )
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | from config import CFG
14 | from tokenizer import Tokenizer
15 | from utils import (
16 | seed_everything,
17 | load_checkpoint,
18 | )
19 | from ddp_utils import (
20 | get_inria_loaders,
21 | get_spacenet_loaders,
22 | get_whu_buildings_loaders,
23 | get_mass_roads_loaders,
24 | )
25 |
26 | from models.model import (
27 | Encoder,
28 | Decoder,
29 | EncoderDecoder
30 | )
31 |
32 | from engine import train_eval
33 |
34 | from torch import distributed as dist
35 | import torch.multiprocessing
36 | torch.multiprocessing.set_sharing_strategy("file_system")
37 |
38 |
39 | def init_distributed():
40 |
41 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs.
42 | dist_url = "env://" # default
43 |
44 |     # only works when launched with torch.distributed.launch or torchrun.
45 | rank = int(os.environ["RANK"])
46 | world_size = int(os.environ["WORLD_SIZE"])
47 | local_rank = int(os.environ["LOCAL_RANK"])
48 | dist.init_process_group(
49 | backend="nccl",
50 | init_method=dist_url,
51 | world_size=world_size,
52 | rank=rank
53 | )
54 |
55 | # this will make all .cuda() calls work properly.
56 | torch.cuda.set_device(local_rank)
57 |
58 | # synchronizes all threads to reach this point before moving on.
59 | dist.barrier()
60 |
61 |
62 | def main():
63 | # setup the process groups
64 | init_distributed()
65 | seed_everything(42)
66 |
67 | # Define tensorboard for logging.
68 | writer = SummaryWriter(f"runs/{CFG.EXPERIMENT_NAME}/logs/tensorboard")
69 | attrs = vars(CFG)
70 | with open(f"runs/{CFG.EXPERIMENT_NAME}/config.txt", "w") as f:
71 | print("\n".join("%s: %s" % item for item in attrs.items()), file=f)
72 |
73 | train_transforms = A.Compose(
74 | [
75 |             A.Affine(rotate=[-360, 360], fit_output=True, p=0.8),  # rotation with fit_output is applied before resizing so the enlarged rotated canvas is resized back to the network input size.
76 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH),
77 | A.RandomRotate90(p=1.),
78 | A.RandomBrightnessContrast(p=0.5),
79 | A.ColorJitter(),
80 | A.ToGray(p=0.4),
81 | A.GaussNoise(),
82 |             # Albumentations' ToTensorV2 does not divide by 255 like torchvision's ToTensor;
83 |             # the scaling is handled by Normalize via max_pixel_value.
84 | A.Normalize(
85 | mean=[0.0, 0.0, 0.0],
86 | std=[1.0, 1.0, 1.0],
87 | max_pixel_value=255.0
88 | ),
89 | ToTensorV2(),
90 | ],
91 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False)
92 | )
93 |
94 | valid_transforms = A.Compose(
95 | [
96 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH),
97 | A.Normalize(
98 | mean=[0.0, 0.0, 0.0],
99 | std=[1.0, 1.0, 1.0],
100 | max_pixel_value=255.0
101 | ),
102 | ToTensorV2(),
103 | ],
104 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False)
105 | )
106 |
107 | if "debug" in CFG.EXPERIMENT_NAME:
108 | train_transforms = valid_transforms
109 |
110 | tokenizer = Tokenizer(
111 | num_classes=1,
112 | num_bins=CFG.NUM_BINS,
113 | width=CFG.INPUT_WIDTH,
114 | height=CFG.INPUT_HEIGHT,
115 | max_len=CFG.MAX_LEN
116 | )
117 | CFG.PAD_IDX = tokenizer.PAD_code
118 |
119 | if "inria" in CFG.DATASET:
120 | train_loader, val_loader, _ = get_inria_loaders(
121 | CFG.TRAIN_DATASET_DIR,
122 | CFG.VAL_DATASET_DIR,
123 | CFG.TEST_IMAGES_DIR,
124 | tokenizer,
125 | CFG.MAX_LEN,
126 | tokenizer.PAD_code,
127 | CFG.SHUFFLE_TOKENS,
128 | CFG.BATCH_SIZE,
129 | train_transforms,
130 | valid_transforms,
131 | CFG.NUM_WORKERS,
132 | CFG.PIN_MEMORY,
133 | )
134 | elif "spacenet" in CFG.DATASET:
135 | train_loader, val_loader, _ = get_spacenet_loaders(
136 | CFG.TRAIN_DATASET_DIR,
137 | CFG.VAL_DATASET_DIR,
138 | CFG.TEST_IMAGES_DIR,
139 | tokenizer,
140 | CFG.MAX_LEN,
141 | tokenizer.PAD_code,
142 | CFG.SHUFFLE_TOKENS,
143 | CFG.BATCH_SIZE,
144 | train_transforms,
145 | valid_transforms,
146 | CFG.NUM_WORKERS,
147 | CFG.PIN_MEMORY,
148 | )
149 | elif "whu_buildings" in CFG.DATASET:
150 | train_loader, val_loader, _ = get_whu_buildings_loaders(
151 | CFG.TRAIN_DATASET_DIR,
152 | CFG.VAL_DATASET_DIR,
153 | CFG.TEST_IMAGES_DIR,
154 | tokenizer,
155 | CFG.MAX_LEN,
156 | tokenizer.PAD_code,
157 | CFG.SHUFFLE_TOKENS,
158 | CFG.BATCH_SIZE,
159 | train_transforms,
160 | valid_transforms,
161 | CFG.NUM_WORKERS,
162 | CFG.PIN_MEMORY,
163 | )
164 | elif "mass_roads" in CFG.DATASET:
165 | train_loader, val_loader, test_loader = get_mass_roads_loaders(
166 | CFG.TRAIN_DATASET_DIR,
167 | CFG.VAL_DATASET_DIR,
168 | CFG.TEST_IMAGES_DIR,
169 | tokenizer,
170 | CFG.MAX_LEN,
171 | tokenizer.PAD_code,
172 | CFG.SHUFFLE_TOKENS,
173 | CFG.BATCH_SIZE,
174 | train_transforms,
175 | valid_transforms,
176 | CFG.NUM_WORKERS,
177 | CFG.PIN_MEMORY,
178 | )
179 | else:
180 | pass
181 |
182 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256)
183 | decoder = Decoder(
184 | cfg=CFG,
185 | vocab_size=tokenizer.vocab_size,
186 | encoder_len=CFG.NUM_PATCHES,
187 | dim=256,
188 | num_heads=8,
189 | num_layers=6
190 | )
191 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder)
192 | model.to(CFG.DEVICE)
193 |
194 | weight = torch.ones(CFG.PAD_IDX + 1, device=CFG.DEVICE)
195 | weight[tokenizer.num_bins:tokenizer.BOS_code] = 0.0
196 | vertex_loss_fn = nn.CrossEntropyLoss(ignore_index=CFG.PAD_IDX, label_smoothing=CFG.LABEL_SMOOTHING, weight=weight)
197 | perm_loss_fn = nn.BCELoss()
198 |
199 | optimizer = optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY, betas=(0.9, 0.95))
200 |
201 | num_training_steps = CFG.NUM_EPOCHS * (len(train_loader.dataset) // CFG.BATCH_SIZE // torch.cuda.device_count())
202 | num_warmup_steps = int(0.05 * num_training_steps)
203 | lr_scheduler = get_linear_schedule_with_warmup(
204 | optimizer,
205 | num_training_steps=num_training_steps,
206 | num_warmup_steps=num_warmup_steps
207 | )
208 |
209 | local_rank = int(os.environ["LOCAL_RANK"])
210 | CFG.START_EPOCH = 0
211 | if CFG.LOAD_MODEL:
212 | checkpoint_name = osp.basename(osp.realpath(CFG.CHECKPOINT_PATH))
213 | map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank}
214 | start_epoch = load_checkpoint(
215 | torch.load(f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/{checkpoint_name}", map_location=map_location),
216 | model,
217 | optimizer,
218 | lr_scheduler
219 | )
220 | CFG.START_EPOCH = start_epoch + 1
221 | dist.barrier()
222 |
223 | # Convert BatchNorm in model to SyncBatchNorm.
224 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
225 | # Wrap model with distributed data parallel.
226 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
227 |
228 | train_eval(
229 | model,
230 | train_loader,
231 | val_loader,
232 | val_loader,
233 | tokenizer,
234 | vertex_loss_fn,
235 | perm_loss_fn,
236 | optimizer,
237 | lr_scheduler=lr_scheduler,
238 | step='batch',
239 | writer=writer
240 | )
241 |
242 |
243 | if __name__ == "__main__":
244 | main()
245 |
246 |
--------------------------------------------------------------------------------
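
A minimal sketch of the environment init_distributed() above expects. Assumption: the script is started with torchrun (or the older torch.distributed.launch), e.g. something like torchrun --nproc_per_node=4 train_ddp.py, which exports the three variables read at the top of init_distributed(); launching it as a plain python process leaves them unset, which is why init_distributed() raises a KeyError in that case.

    import os

    # RANK / WORLD_SIZE / LOCAL_RANK are provided by the launcher, not by the script.
    for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK"):
        print(var, os.environ.get(var, "<unset: start the script with torchrun>"))
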
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 | from scipy.optimize import linear_sum_assignment
6 |
7 | import torch
8 | import torchvision
9 | from transformers import top_k_top_p_filtering
10 | from torchmetrics.functional.classification import binary_jaccard_index, binary_accuracy
11 | from config import CFG
12 |
13 |
14 | def seed_everything(seed=1234):
15 | random.seed(seed)
16 | os.environ['PYTHONHASHSEED'] = str(seed)
17 | np.random.seed(seed)
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed(seed)
20 | torch.backends.cudnn.deterministic = True
21 |
22 |
23 | def save_checkpoint(state, folder="logs/checkpoint/run1", filename="my_checkpoint.pth.tar"):
24 | print("=> Saving checkpoint")
25 | if not os.path.exists(folder):
26 | os.makedirs(folder)
27 | torch.save(state, os.path.join(folder, filename))
28 |
29 |
30 | def load_checkpoint(checkpoint, model, optimizer, scheduler):
31 | print("=> Loading checkpoint")
32 | model.load_state_dict(checkpoint["state_dict"])
33 | optimizer.load_state_dict(checkpoint["optimizer"])
34 | scheduler.load_state_dict(checkpoint["scheduler"])
35 |
36 | return checkpoint["epochs_run"]
37 |
38 |
39 | def generate_square_subsequent_mask(sz):
40 | mask = (
41 | torch.triu(torch.ones((sz, sz), device=CFG.DEVICE)) == 1
42 | ).transpose(0, 1)
43 |
44 | mask = mask.float().masked_fill(mask==0, float('-inf')).masked_fill(mask==1, float(0.0))
45 |
46 | return mask
47 |
48 |
49 | def create_mask(tgt, pad_idx):
50 | """
51 | tgt shape: (N, L)
52 | """
53 |
54 | tgt_seq_len = tgt.size(1)
55 | tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
56 | tgt_padding_mask = (tgt == pad_idx)
57 |
58 | return tgt_mask, tgt_padding_mask
59 |
60 |
61 | class AverageMeter:
62 | def __init__(self, name="Metric"):
63 | self.name = name
64 | self.reset()
65 |
66 | def reset(self):
67 | self.avg, self.sum, self.count = [0]*3
68 |
69 | def update(self, val, count=1):
70 | self.count += count
71 | self.sum += val * count
72 | self.avg = self.sum / self.count
73 |
74 | def __repr__(self) -> str:
75 | text = f"{self.name}: {self.avg:.4f}"
76 | return text
77 |
78 |
79 | def get_lr(optimizer):
80 | for param_group in optimizer.param_groups:
81 | return param_group['lr']
82 |
83 |
84 | def scores_to_permutations(scores):
85 | """
86 | Input a batched array of scores and returns the hungarian optimized
87 | permutation matrices.
88 | """
89 | B, N, N = scores.shape
90 |
91 | scores = scores.detach().cpu().numpy()
92 | perm = np.zeros_like(scores)
93 | for b in range(B):
94 | r, c = linear_sum_assignment(-scores[b])
95 | perm[b,r,c] = 1
96 | return torch.tensor(perm)
97 |
98 |
99 | # TODO: add permalink to polyworld repo
100 | def permutations_to_polygons(perm, graph, out='torch'):
101 | B, N, N = perm.shape
102 | device = perm.device
103 |
104 | def bubble_merge(poly):
105 | s = 0
106 | P = len(poly)
107 | while s < P:
108 | head = poly[s][-1]
109 |
110 | t = s+1
111 | while t < P:
112 | tail = poly[t][0]
113 | if head == tail:
114 | poly[s] = poly[s] + poly[t][1:]
115 | del poly[t]
116 | poly = bubble_merge(poly)
117 | P = len(poly)
118 | t += 1
119 | s += 1
120 | return poly
121 |
122 | diag = torch.logical_not(perm[:,range(N),range(N)])
123 | batch = []
124 | for b in range(B):
125 | b_perm = perm[b]
126 | b_graph = graph[b]
127 | b_diag = diag[b]
128 |
129 | idx = torch.arange(N, device=perm.device)[b_diag]
130 |
131 | if idx.shape[0] > 0:
132 | # If there are vertices in the batch
133 |
134 | b_perm = b_perm[idx,:]
135 | b_graph = b_graph[idx,:]
136 | b_perm = b_perm[:,idx]
137 |
138 | first = torch.arange(idx.shape[0]).unsqueeze(1).to(device=device)
139 | second = torch.argmax(b_perm, dim=1).unsqueeze(1)
140 |
141 | polygons_idx = torch.cat((first, second), dim=1).tolist()
142 | polygons_idx = bubble_merge(polygons_idx)
143 |
144 | batch_poly = []
145 | for p_idx in polygons_idx:
146 | if out == 'torch':
147 | batch_poly.append(b_graph[p_idx,:])
148 | elif out == 'numpy':
149 | batch_poly.append(b_graph[p_idx,:].cpu().numpy())
150 | elif out == 'list':
151 | g = b_graph[p_idx,:] * 300 / 320
152 | g[:,0] = -g[:,0]
153 | g = torch.fliplr(g)
154 | batch_poly.append(g.tolist())
155 | elif out == 'coco':
156 | g = b_graph[p_idx,:]# * CFG.IMG_SIZE / CFG.INPUT_WIDTH
157 | g = torch.fliplr(g)
158 | batch_poly.append(g.view(-1).tolist())
159 | elif out == 'inria-torch':
160 | batch_poly.append(b_graph[p_idx,:])
161 | else:
162 | print("Indicate a valid output polygon format")
163 | exit()
164 |
165 | batch.append(batch_poly)
166 |
167 | else:
168 | # If the batch has no vertices
169 | batch.append([])
170 |
171 | return batch
172 |
173 |
174 | def test_generate(model, x, tokenizer, max_len=50, top_k=0, top_p=1):
175 | x = x.to(CFG.DEVICE)
176 | batch_preds = torch.ones((x.size(0), 1), device=CFG.DEVICE).fill_(tokenizer.BOS_code).long()
177 | confs = []
178 |
179 | if top_k != 0 or top_p != 1:
180 | sample = lambda preds: torch.softmax(preds, dim=-1).multinomial(num_samples=1).view(-1, 1)
181 | else:
182 | sample = lambda preds: torch.softmax(preds, dim=-1).argmax(dim=-1).view(-1, 1)
183 |
184 | with torch.no_grad():
185 | for i in range(max_len):
186 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
187 | preds, feats = model.module.predict(x, batch_preds)
188 | else:
189 | preds, feats = model.predict(x, batch_preds)
190 | preds = top_k_top_p_filtering(preds, top_k=top_k, top_p=top_p) # if top_k and top_p are set to default, this line does nothing.
191 | if i % 2 == 0:
192 | confs_ = torch.softmax(preds, dim=-1).sort(axis=-1, descending=True)[0][:, 0].cpu()
193 | confs.append(confs_)
194 | preds = sample(preds)
195 | batch_preds = torch.cat([batch_preds, preds], dim=1)
196 |
197 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
198 | perm_preds = model.module.scorenet1(feats) + torch.transpose(model.module.scorenet2(feats), 1, 2)
199 | else:
200 | perm_preds = model.scorenet1(feats) + torch.transpose(model.scorenet2(feats), 1, 2)
201 |
202 | perm_preds = scores_to_permutations(perm_preds)
203 |
204 | return batch_preds.cpu(), confs, perm_preds
205 |
206 |
207 | def postprocess(batch_preds, batch_confs, tokenizer):
208 | EOS_idxs = (batch_preds == tokenizer.EOS_code).float().argmax(dim=-1)
209 | ## sanity check
210 | invalid_idxs = ((EOS_idxs - 1) % 2 != 0).nonzero().view(-1)
211 | EOS_idxs[invalid_idxs] = 0
212 |
213 | all_coords = []
214 | all_confs = []
215 | for i, EOS_idx in enumerate(EOS_idxs.tolist()):
216 | if EOS_idx == 0:
217 | all_coords.append(None)
218 | all_confs.append(None)
219 | continue
220 | coords = tokenizer.decode(batch_preds[i, :EOS_idx+1])
221 | confs = [round(batch_confs[j][i].item(), 3) for j in range(len(coords))]
222 |
223 | all_coords.append(coords)
224 | all_confs.append(confs)
225 |
226 | return all_coords, all_confs
227 |
228 |
229 | def save_single_predictions_as_images(
230 | loader, model, tokenizer, epoch, writer, folder="saved_outputs/", device="cuda"
231 | ):
232 | print(f"=> Saving val predictions...")
233 | if not os.path.exists(folder):
234 | print(f"==> Creating output subdirectory...")
235 | os.makedirs(folder)
236 |
237 | model.eval()
238 |
239 | all_coords = []
240 | all_confs = []
241 |
242 | with torch.no_grad():
243 | loader_iterator = iter(loader)
244 | idx, (x, y_mask, y_corner_mask, y, y_perm) = 0, next(loader_iterator)
245 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1)
246 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer)
247 |
248 | all_coords.extend(vertex_coords)
249 | all_confs.extend(confs)
250 |
251 | coords = []
252 | for i in range(len(all_coords)):
253 | if all_coords[i] is not None:
254 | coord = torch.from_numpy(all_coords[i])
255 | else:
256 | coord = torch.tensor([])
257 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(tokenizer.PAD_code)
258 | coord = torch.cat((coord, padd), dim=0)
259 | coords.append(coord)
260 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # list of polygon coordinate tensors
261 |
262 | B, C, H, W = x.shape
263 | # Write predicted vertices as mask to disk.
264 | vertex_mask = np.zeros((B, 1, H, W))
265 | for b in range(len(all_coords)):
266 | if all_coords[b] is not None:
267 | print(f"Vertices found!")
268 | for i in range(len(all_coords[b])):
269 | coord = all_coords[b][i]
270 | cx, cy = coord
271 | cv2.circle(vertex_mask[b, 0], (int(cy), int(cx)), 0, 255, -1)
272 | vertex_mask = torch.from_numpy(vertex_mask)
273 | if not os.path.exists(os.path.join(folder, 'corners_mask')):
274 | os.makedirs(os.path.join(folder, 'corners_mask'))
275 | vertex_pred_vis = torch.zeros_like(x)
276 | for b in range(B):
277 | vertex_pred_vis[b] = torchvision.utils.draw_segmentation_masks(
278 | (x[b]*255).to(dtype=torch.uint8),
279 | torch.zeros_like(x[b, 0]).bool()
280 | )
281 | vertex_pred_vis = vertex_pred_vis.cpu().numpy().astype(np.uint8)
282 | for b in range(len(all_coords)):
283 | if all_coords[b] is not None:
284 | for i in range(len(all_coords[b])):
285 | coord = all_coords[b][i]
286 | cx, cy = coord
287 | cv2.circle(vertex_pred_vis[b, 0], (int(cy), int(cx)), 3, 255, -1)
288 | vertex_pred_vis = torch.from_numpy(vertex_pred_vis)
289 | torchvision.utils.save_image(
290 | vertex_pred_vis.float()/255, f"{folder}/corners_mask/corners_mask_{b}_{epoch}.png"
291 | )
292 |
293 | # Write predicted polygons as mask to disk.
294 | polygons = np.zeros((B, 1, H, W))
295 | for b in range(B):
296 | for c in range(len(batch_polygons[b])):
297 | poly = batch_polygons[b][c]
298 | poly = poly[poly[:, 0] != tokenizer.PAD_code]
299 | cnt = np.flip(np.int32(poly.cpu()), 1)
300 | if len(cnt) > 0:
301 | cv2.fillPoly(polygons[b, 0], pts=[cnt], color=1.)
302 | polygons = torch.from_numpy(polygons)
303 | if not os.path.exists(os.path.join(folder, 'pred_polygons')):
304 | os.makedirs(os.path.join(folder, 'pred_polygons'))
305 | poly_out = torch.zeros_like(x)
306 | for b in range(B):
307 | poly_out[b] = torchvision.utils.draw_segmentation_masks(
308 | (x[b]*255).to(dtype=torch.uint8),
309 | polygons[b, 0].bool()
310 | )
311 | poly_out = poly_out.cpu().numpy().astype(np.uint8)
312 | for b in range(len(all_coords)):
313 | if all_coords[b] is not None:
314 | for i in range(len(all_coords[b])):
315 | coord = all_coords[b][i]
316 | cx, cy = coord
317 | cv2.circle(poly_out[b, 0], (int(cy), int(cx)), 2, 255, -1)
318 | poly_out = torch.from_numpy(poly_out)
319 | torchvision.utils.save_image(
320 | poly_out.float()/255, f"{folder}/pred_polygons/polygons_{idx}_{epoch}.png"
321 | )
322 |
323 | batch_miou = binary_jaccard_index(polygons, y_mask)
324 | batch_biou = binary_jaccard_index(polygons, y_mask, ignore_index=0)
325 | batch_macc = binary_accuracy(polygons, y_mask)
326 | batch_bacc = binary_accuracy(polygons, y_mask, ignore_index=0)
327 |
328 | writer.add_scalar('Val_Metrics/Mean_IoU', batch_miou, epoch)
329 | writer.add_scalar('Val_Metrics/Building_IoU', batch_biou, epoch)
330 | writer.add_scalar('Val_Metrics/Mean_Accuracy', batch_macc, epoch)
331 | writer.add_scalar('Val_Metrics/Building_Accuracy', batch_bacc, epoch)
332 |
333 | metrics_dict = {
334 | "miou": batch_miou,
335 | "biou": batch_biou,
336 | "macc": batch_macc,
337 | "bacc": batch_bacc
338 | }
339 |
340 | torchvision.utils.save_image(x, f"{folder}/image_{idx}.png")
341 | ymask_out = torch.zeros_like(x)
342 | for b in range(B):
343 | ymask_out[b] = torchvision.utils.draw_segmentation_masks(
344 | (x[b]*255).to(dtype=torch.uint8),
345 | y_mask[b, 0].bool()
346 | )
347 | ymask_out = ymask_out.cpu().numpy().astype(np.uint8)
348 | gt_corner_coords, _ = postprocess(y, batch_confs, tokenizer)
349 | for b in range(B):
350 | for corner in gt_corner_coords[b]:
351 | cx, cy = corner
352 | cv2.circle(ymask_out[b, 0], (int(cy), int(cx)), 3, 255, -1)
353 | ymask_out = torch.from_numpy(ymask_out)
354 | torchvision.utils.save_image(ymask_out/255., f"{folder}/gt_mask_{idx}.png")
355 | torchvision.utils.save_image(y_corner_mask*255, f"{folder}/gt_corners_{idx}.png")
356 | torchvision.utils.save_image(y_perm[:, None, :, :]*255, f"{folder}/gt_perm_matrix_{idx}.png")
357 |
358 | return metrics_dict
359 |
--------------------------------------------------------------------------------
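
A small self-contained sketch of the permutation-to-polygon step implemented in utils.py above, assuming the repository environment (utils.py pulls in config and transformers). The scores and vertex coordinates are made up: large entries at (0, 1), (1, 2) and (2, 0) make the Hungarian assignment in scores_to_permutations pick the cycle 0 -> 1 -> 2 -> 0, which permutations_to_polygons then merges into a single closed ring.

    import torch
    from utils import scores_to_permutations, permutations_to_polygons

    scores = torch.full((1, 3, 3), -10.0)    # one sample, three detected vertices
    scores[0, 0, 1] = 10.0
    scores[0, 1, 2] = 10.0
    scores[0, 2, 0] = 10.0
    perm = scores_to_permutations(scores)    # hard (1, 3, 3) permutation matrix

    vertices = [torch.tensor([[10.0, 10.0], [10.0, 50.0], [50.0, 30.0]])]  # one (N, 2) tensor per sample
    polygons = permutations_to_polygons(perm, vertices, out='torch')
    print(polygons[0][0])   # four rows: the first vertex is repeated to close the ring
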