├── .gitignore ├── LICENSE ├── README.md ├── assets ├── optimal_matching_network.png ├── pix2poly_overall_bg_white.png ├── sfo7.png └── vertex_sequence_detector.png ├── config.py ├── data_preprocess ├── README.md ├── download_mass_roads_dataset.sh ├── inria_to_coco.py ├── mass_roads_clip_shapefile.py ├── mass_roads_clip_tile_vectors.py ├── mass_roads_tiles_to_patches.py ├── mass_roads_world_to_pixel_coords.py ├── spacenet_convert_16bit_to_8bit.py ├── spacenet_to_coco.py ├── spacenet_world_to_pixel_coords.py └── whu_buildings_to_coco.py ├── datasets ├── __init__.py ├── dataset_inria_coco.py ├── dataset_mass_roads.py ├── dataset_spacenet_coco.py └── dataset_whu_buildings_coco.py ├── ddp_utils.py ├── engine.py ├── eval ├── __init__.py ├── hisup_eval_utils │ └── metrics │ │ ├── angle_eval.py │ │ ├── cIoU.py │ │ └── polis.py └── topdig_eval_utils │ └── metrics │ └── topdig_metrics.py ├── evaluate_mass_roads_predictions.py ├── evaluation.py ├── models └── model.py ├── postprocess_coco_parts.py ├── predict_inria_coco_val_set.py ├── predict_mass_roads_test_set.py ├── predict_spacenet_coco_val_set.py ├── predict_whu_buildings_coco_test_set.py ├── pyrightconfig.json ├── tokenizer.py ├── train_ddp.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Directories 2 | __pycache__ 3 | data 4 | runs 5 | .vscode 6 | preds* 7 | crowdai_visualizations 8 | scratch 9 | 10 | # Files 11 | *.jpg 12 | *.png 13 | *.tar.xz 14 | *.gif 15 | *.txt 16 | 17 | # Files ignore 18 | !assets/sfo7.png 19 | !assets/pix2poly_overall_bg_white.png 20 | !assets/vertex_sequence_detector.png 21 | !assets/optimal_matching_network.png 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Yeshwanth Kumar Adimoolam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Pix2Poly: A Sequence Prediction Method for End-to-end Polygonal Building Footprint Extraction

3 |

WACV 2025

4 | Yeshwanth Kumar Adimoolam1, Charalambos Poullis2, Melinos Averkiou1
5 | 1CYENS CoE, Cyprus, 2Concordia University 6 | 7 |
8 | 9 | 10 | [[Project Webpage](https://yeshwanth95.github.io/Pix2Poly)] [[Paper](https://arxiv.org/abs/2412.07899)] [[Video]()] 11 | 12 | ## UPDATES: 13 | 14 | 1. 05.06.2025 - Pretrained checkpoints for Pix2Poly on the various datasets used in the paper are released. See [pretrained checkpoints](#pretrained-checkpoints). 15 | 2. 21.05.2025 - As reported by the authors of the [$P^3$ dataset](https://arxiv.org/abs/2505.15379), Pix2Poly achieves state-of-the-art results for multimodal building vectorization from image and LiDAR data sources. 16 | 17 | ### Abstract: 18 | 19 | Extraction of building footprint polygons from remotely sensed data is essential for several urban understanding tasks such as reconstruction, navigation, and mapping. Despite significant progress in the area, extracting accurate polygonal building footprints remains an open problem. In this paper, we introduce Pix2Poly, an attention-based end-to-end trainable and differentiable deep neural network capable of directly generating explicit high-quality building footprints in a ring graph format. Pix2Poly employs a generative encoder-decoder transformer to produce a sequence of graph vertex tokens whose connectivity information is learned by an optimal matching network. Compared to previous graph learning methods, ours is a truly end-to-end trainable approach that extracts high-quality building footprints and road networks without requiring complicated, computationally intensive raster loss functions and intricate training pipelines. Upon evaluating Pix2Poly on several complex and challenging datasets, we report that Pix2Poly outperforms state-of-the-art methods in several vector shape quality metrics while being an entirely explicit method. 20 | 21 | ### Method 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 |
30 | 31 | __Overview of the Pix2Poly architecture:__ The Pix2Poly architecture consists of three major components: (i) The Discrete Sequence Tokenizer, (ii) the Vertex Sequence Detector, (iii) and the Optimal Matching Network. The Discrete Sequence Tokenizer is used to convert the continuous building corner coordinates into discrete building corner coordinate tokens which form the ground truth for training Pix2Poly. The Vertex Sequence Detector is an encoder-decoder transformer that predicts a sequence of sequence of discrete building corner coordinate tokens. The Optimal Matching Network takes the predicted corner coordinate tokens and the per-corner features from the vertex sequence detector and predicts a N X N permutation matrix which contains the connectivity information between every possible corner pair. Together, the predicted building corners and permutation matrix are used to recover the final building polygons. 32 | 33 | ## Installation 34 | 35 | Pix2Poly was developed with `python=3.11`, `pytorch=2.1.2`, `pytorch-cuda=11.8`, `timm=0.9.12`, `transformers=4.32.1` 36 | 37 | Create a conda environment with the following specification: 38 | 39 | ``` 40 | Conda requirements: 41 | channels: 42 | - defaults 43 | dependencies: 44 | - torchvision=0.16.2 45 | - pytorch=2.1.2 46 | - pytorch-cuda=11.8 47 | - torchaudio=2.1.2 48 | - timm=0.9.12 49 | - transformers=4.32.1 50 | - pycocotools=2.0.6 51 | - torchmetrics=1.2.1 52 | - tensorboard=2.15.1 53 | - pip: 54 | - albumentations==1.3.1 55 | - imageio==2.33.1 56 | - matplotlib-inline==0.1.6 57 | - opencv-python-headless==4.8.1.78 58 | - scikit-image==0.22.0 59 | - scikit-learn==1.3.2 60 | - scipy==1.11.4 61 | - shapely==2.0.4 62 | ``` 63 | 64 | 65 | ## Datasets preparation 66 | 67 | See [datasets preprocessing](data_preprocess) for instructions on preparing the various datasets for training/inference. 68 | 69 | ## Pretrained Checkpoints 70 | 71 | Pretrained checkpoints for the various datasets used in the paper are available for download at the following links: [Google Drive](https://drive.google.com/file/d/1oEs2n81nMAzdY4G9bdrji13pOKk6MOET/view?usp=sharing) | [MEGA](https://mega.nz/file/ExQEBDxY#faK1yNaQ8KYvPGuxJY1snvFi7TfbF1kOx4mvhmUSb4s) 72 | 73 | Download the zip file, extract and place the individual runs folder in the `runs` directory at the root of the project. 74 | 75 | ## Configurations 76 | 77 | ## Training 78 | 79 | Start training with the following command: 80 | 81 | ``` 82 | torchrun --nproc_per_node= train_ddp.py 83 | ``` 84 | 85 | ## Prediction 86 | 87 | ### (i) INRIA Dataset 88 | To generate predictions for the INRIA dataset, run the following: 89 | ```shell 90 | python predict_inria_coco_val_set.py -d inria_coco_224_negAug \ 91 | -e \ 92 | -c \ 93 | -o 94 | python postprocess_coco_parts.py # change input and output paths in L006 to L010. 95 | ``` 96 | 97 | ### (ii) Spacenet 2 Dataset 98 | To generate predictions for the Spacenet 2 dataset, run the following: 99 | ```shell 100 | python predict_spacenet_coco_val_set.py -d spacenet_coco \ 101 | -e \ 102 | -c \ 103 | -o 104 | python postprocess_coco_parts.py # change input and output paths in L006 to L010. 105 | ``` 106 | 107 | ### (iii) WHU Buildings Dataset 108 | To generate predictions for the WHU Buildings dataset, run the following: 109 | ```shell 110 | python predict_whu_buildings_coco_test_set.py -d whu_buildings_224_coco \ 111 | -e \ 112 | -c \ 113 | -o 114 | python postprocess_coco_parts.py # change input and output paths in L006 to L010. 115 | ``` 116 | ### (iv) Massachusetts Roads Dataset 117 | To generate predictions for the Massachusetts Roads dataset, run the following: 118 | ```shell 119 | python predict_mass_roads_test_set.py -e \ 120 | -c \ 121 | -s \ # 'test' or 'val' 122 | --img_size 224 \ 123 | --input_size 224 \ 124 | --batch_size 24 \ # modify according to resources 125 | ``` 126 | 127 | ## Evaluation (buildings datasets) 128 | 129 | Once predictions are made, metrics can be computed for the predicted files as follows: 130 | 131 | ```bash 132 | python evaluation.py --gt-file path/to/ground/truth/annotation.json --dt-file path/to/prediction.json --eval-type 133 | ``` 134 | 135 | where `metric_type` can be one of the following: `ciou`, `angle`, `polis`, `topdig`. 136 | 137 | ## Evaluation (Massachusetts Roads Dataset) 138 | 139 | Once raster predictions are made for the Massachusetts Roads dataset, metrics can be computed for the predicted files as follows: 140 | 141 | ```bash 142 | python evaluate_mass_roads_predictions.py --gt-dir data/mass_roads_1500/test/map --dt-dir path/to/predicted/raster/masks/folder 143 | ``` 144 | 145 | ## Citation 146 | 147 | If you find our work useful, please consider citing: 148 | ```bibtex 149 | @misc{adimoolam2024pix2poly, 150 | title={Pix2Poly: A Sequence Prediction Method for End-to-end Polygonal Building Footprint Extraction from Remote Sensing Imagery}, 151 | author={Yeshwanth Kumar Adimoolam and Charalambos Poullis and Melinos Averkiou}, 152 | year={2024}, 153 | eprint={2412.07899}, 154 | archivePrefix={arXiv}, 155 | primaryClass={cs.CV}, 156 | url={https://arxiv.org/abs/2412.07899}, 157 | } 158 | ``` 159 | 160 | ## Acknowledgements 161 | 162 | This repository benefits from the following open-source work. We thank the authors for their great work. 163 | 164 | 1. [Pix2Seq - official repo](https://github.com/google-research/pix2seq) 165 | 2. [Pix2Seq - unofficial repo](https://github.com/moein-shariatnia/Pix2Seq) 166 | 3. [Frame Field Learning](https://github.com/Lydorn/Polygonization-by-Frame-Field-Learning) 167 | 4. [PolyWorld](https://github.com/zorzi-s/PolyWorldPretrainedNetwork) 168 | 5. [HiSup](https://github.com/SarahwXU/HiSup) 169 | -------------------------------------------------------------------------------- /assets/optimal_matching_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/optimal_matching_network.png -------------------------------------------------------------------------------- /assets/pix2poly_overall_bg_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/pix2poly_overall_bg_white.png -------------------------------------------------------------------------------- /assets/sfo7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/sfo7.png -------------------------------------------------------------------------------- /assets/vertex_sequence_detector.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/assets/vertex_sequence_detector.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CFG: 4 | IMG_PATH = '' 5 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | """ 8 | supported datasets are: 9 | - inria_coco_224_negAug 10 | - spacenet_coco 11 | - whu_buildings_224_coco 12 | - mass_roads_224 13 | """ 14 | DATASET = f"inria_coco_224_negAug" 15 | if "coco" in DATASET: 16 | TRAIN_DATASET_DIR = f"./data/{DATASET}/train" 17 | VAL_DATASET_DIR = f"./data/{DATASET}/val" 18 | TEST_IMAGES_DIR = f"./data/{DATASET}/val/images" 19 | elif "mass_roads" in DATASET: 20 | TRAIN_DATASET_DIR = f"./data/{DATASET}/train" 21 | VAL_DATASET_DIR = f"./data/{DATASET}/valid" 22 | TEST_IMAGES_DIR = f"./data/{DATASET}/test/images" 23 | 24 | 25 | TRAIN_DDP = True 26 | NUM_WORKERS = 2 27 | PIN_MEMORY = True 28 | LOAD_MODEL = False 29 | 30 | if "inria" in DATASET: 31 | N_VERTICES = 192 # maximum number of vertices per image in dataset. 32 | elif "spacenet" in DATASET: 33 | N_VERTICES = 192 # maximum number of vertices per image in dataset. 34 | elif "whu_buildings" in DATASET: 35 | N_VERTICES = 144 # maximum number of vertices per image in dataset. 36 | elif "mass_roads" in DATASET: 37 | N_VERTICES = 192 # maximum number of vertices per image in dataset. 38 | 39 | SINKHORN_ITERATIONS = 100 40 | MAX_LEN = (N_VERTICES*2) + 2 41 | if "inria" in DATASET: 42 | IMG_SIZE = 224 43 | elif "spacenet" in DATASET: 44 | IMG_SIZE = 224 45 | elif "whu_buildings" in DATASET: 46 | IMG_SIZE = 224 47 | elif "mass_roads" in DATASET: 48 | IMG_SIZE = 224 49 | INPUT_SIZE = 224 50 | PATCH_SIZE = 8 51 | INPUT_HEIGHT = INPUT_SIZE 52 | INPUT_WIDTH = INPUT_SIZE 53 | NUM_BINS = INPUT_HEIGHT*1 54 | LABEL_SMOOTHING = 0.0 55 | vertex_loss_weight = 1.0 56 | perm_loss_weight = 10.0 57 | SHUFFLE_TOKENS = False # order gt vertex tokens randomly every time 58 | 59 | BATCH_SIZE = 24 # batch size per gpu; effective batch size = BATCH_SIZE * NUM_GPUs 60 | START_EPOCH = 0 61 | NUM_EPOCHS = 500 62 | MILESTONE = 0 63 | SAVE_BEST = True 64 | SAVE_LATEST = True 65 | SAVE_EVERY = 10 66 | VAL_EVERY = 1 67 | 68 | MODEL_NAME = f'vit_small_patch{PATCH_SIZE}_{INPUT_SIZE}_dino' 69 | NUM_PATCHES = int((INPUT_SIZE // PATCH_SIZE) ** 2) 70 | 71 | LR = 4e-4 72 | WEIGHT_DECAY = 1e-4 73 | 74 | generation_steps = (N_VERTICES * 2) + 1 # sequence length during prediction. Should not be more than max_len 75 | run_eval = False 76 | 77 | # EXPERIMENT_NAME = f"debug_run_Pix2Poly224_Bins{NUM_BINS}_fullRotateAugs_permLossWeight{perm_loss_weight}_LR{LR}__{NUM_EPOCHS}epochs" 78 | EXPERIMENT_NAME = f"train_Pix2Poly_{DATASET}_run1_{MODEL_NAME}_AffineRotaugs0.8_LinearWarmupLRS_{vertex_loss_weight}xVertexLoss_{perm_loss_weight}xPermLoss__2xScoreNet_initialLR_{LR}_bs_{BATCH_SIZE}_Nv_{N_VERTICES}_Nbins{NUM_BINS}_{NUM_EPOCHS}epochs" 79 | 80 | if "debug" in EXPERIMENT_NAME: 81 | BATCH_SIZE = 10 82 | NUM_WORKERS = 0 83 | SAVE_BEST = False 84 | SAVE_LATEST = False 85 | SAVE_EVERY = NUM_EPOCHS 86 | VAL_EVERY = 50 87 | 88 | if LOAD_MODEL: 89 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/latest.pth" # full path to checkpoint to be loaded if LOAD_MODEL=True 90 | else: 91 | CHECKPOINT_PATH = "" 92 | 93 | -------------------------------------------------------------------------------- /data_preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Datasets preparation 2 | 3 | ## INRIA Dataset 4 | 5 | 1. Download the [INRIA Aerial Image Labeling Dataset](https://project.inria.fr/aerialimagelabeling/). 6 | 2. Extract and place the aerial image tiles in the `data` directory as follows: 7 | ``` 8 | data/inria_raw/ 9 | ├── test/ 10 | │   └── images/ 11 | └── train/ 12 | ├── gt/ 13 | └── images/ 14 | ``` 15 | 3. Set path to raw INRIA train tiles and gts in L255 & L256 in `inria_to_coco.py` 16 | 4. Run the following command to prepare the INRIA dataset's train and validation splits in MS COCO format. The first 5 tiles of each city are kept as validation split as per the official recommendation. 17 | ```shell 18 | # with pix2poly_env 19 | python inria_to_coco.py 20 | ``` 21 | --- 22 | 23 | 24 | ## SpaceNet 2 Building Detection v2 Dataset (Vegas Subset) 25 | 26 | NOTE: We only use the Vegas subset for all our experiments in the paper. 27 | 28 | 1. Download the [Spacenet 2 Building Detection v2 Dataset](https://spacenet.ai/spacenet-buildings-dataset-v2/). 29 | 2. Extract and place the satellite image tiles for the Vegas subset in the `data` folder in the following directory structure: 30 | ``` 31 | data/AOI_2_Vegas_Train/ 32 | └── geojson/ 33 | └── buildings/ 34 | └── RGB-PanSharpen/ 35 | ├── gt/ 36 | └── images/ 37 | ``` 38 | 3. Convert the pansharpened RGB image tiles from 16-bit to 8-bit using the following command: 39 | ```shell 40 | # with gdal_env 41 | python spacenet_convert_16bit_to_8bit.py 42 | ``` 43 | 4. Convert geojson annotations from world space coordinates to pixel space coordinates using the following command: 44 | ```shell 45 | # with gdal_env 46 | python spacenet_world_to_pixel_coords.py 47 | ``` 48 | 5. Set path to raw SpaceNet dataset's tiles and gts in L202 & L203 in `spacenet_to_coco.py` 49 | 6. Run the following command to prepare the SpaceNet dataset's train and validation splits in MS COCO format. The first 15% tiles kept as validation split. 50 | ```shell 51 | # with pix2poly_env 52 | python spacenet_to_coco.py 53 | ``` 54 | --- 55 | 56 | 57 | ## WHU Buildings Dataset 58 | 59 | 60 | 1. Download the 0.2 meter split of the [WHU Buildings Aerial Imagery Dataset](http://gpcv.whu.edu.cn/data/building_dataset.html). 61 | 2. Extract and place the aerial image tiles (512x512) in the `data` folder in the following directory structure: 62 | ``` 63 | data/WHU_aerial_0.2/ 64 | ├── test/ 65 | │   ├── image/ 66 | │   └── label/ 67 | ├── train/ 68 | │   ├── image/ 69 | │   └── label/ 70 | └── val/ 71 | ├── image/ 72 | └── label/ 73 | ``` 74 | 3. Set path to raw WHU Buildings tiles and gts (512x512) in L263 & L264 in `whu_buildings_to_coco.py` 75 | 4. Run the following command to prepare the WHU Buildings dataset's train, validation and test splits in MS COCO format. 76 | ```shell 77 | # with pix2poly_env 78 | python whu_buildings_to_coco.py 79 | ``` 80 | --- 81 | 82 | 83 | 84 | ## Massachusetts Roads Dataset 85 | 86 | 1. Download the [Massachusetts Roads Dataset](https://www.cs.toronto.edu/~vmnih/data/) using the following command: 87 | ```shell 88 | ./download_mass_roads_dataset.sh 89 | ``` 90 | 2. Download and extract the roads vector shapefile for the dataset [here](https://www.cs.toronto.edu/~vmnih/data/mass_roads/massachusetts_roads_shape.zip). Use QGIS or any preferred tool to convert the SHP file to a geojson. 91 | 3. From this vector roads geojson, generate vector annotations for each tile in the dataset by clipping the geojson to the corresponding raster extents: 92 | ```shell 93 | # using gdal_env 94 | python mass_roads_clip_shapefile.py 95 | ``` 96 | 4. This results in the following directory structure containing the 1500x1500 tiles of the Massachusetts Roads Dataset: 97 | ``` 98 | data/mass_roads_1500/ 99 | ├── test/ 100 | │   ├── map/ 101 | │   ├── sat/ 102 | │   └── shape/ 103 | ├── train/ 104 | │   ├── map/ 105 | │   ├── sat/ 106 | │   └── shape/ 107 | └── valid/ 108 | ├── map/ 109 | ├── sat/ 110 | └── shape/ 111 | ``` 112 | 5. Split the 1500x1500 tiles into 224x224 overlapping patches with the following command: 113 | ```shell 114 | # using gdal_env 115 | python mass_roads_tiles_to_patches.py 116 | ``` 117 | 6. Generate vector annotation files for the patches as follows: 118 | ```shell 119 | # using gdal_env 120 | python mass_roads_clip_tile_vectors.py 121 | python mass_roads_world_to_pixel_coords.py 122 | ``` 123 | 7. This results in the processed 224x224 patches of the Massachusetts Roads Dataset to be used for training Pix2Poly in the following directory structure: 124 | ``` 125 | data/mass_roads_224/ 126 | ├── test/ 127 | │   ├── map/ 128 | │   ├── pixel_annotations/ 129 | │   ├── sat/ 130 | │   └── shape/ 131 | ├── train/ 132 | │   ├── map/ 133 | │   ├── pixel_annotations/ 134 | │   ├── sat/ 135 | │   └── shape/ 136 | └── valid/ 137 | ├── map/ 138 | ├── pixel_annotations/ 139 | ├── sat/ 140 | └── shape/ 141 | ``` 142 | --- 143 | -------------------------------------------------------------------------------- /data_preprocess/download_mass_roads_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ../data 4 | cd ../data 5 | 6 | # Download train split 7 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/train/sat/ 8 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/train/map/ 9 | 10 | # Download valid split 11 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/valid/sat/ 12 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/valid/map/ 13 | 14 | # Download test split 15 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/test/sat/ 16 | wget -r --no-parent https://www.cs.toronto.edu/~vmnih/data/mass_roads/test/map/ 17 | 18 | mv www.cs.toronto.edu/\~vmnih/data/mass_roads . 19 | mv mass_roads mass_roads_1500 20 | rm -r www.cs.toronto.edu 21 | -------------------------------------------------------------------------------- /data_preprocess/mass_roads_clip_shapefile.py: -------------------------------------------------------------------------------- 1 | from osgeo import gdal 2 | import os 3 | import subprocess 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | 7 | 8 | def clip_shapefile(paths_dict): 9 | # get the extent of the raster 10 | raster_path = paths_dict['raster_path'] 11 | save_path = paths_dict['save_path'] 12 | vector_path = paths_dict['vector_path'] 13 | src = gdal.Open(raster_path) 14 | ulx, xres, xskew, uly, yskew, yres = src.GetGeoTransform() 15 | sizeX = src.RasterXSize * xres 16 | sizeY = src.RasterYSize * yres 17 | lrx = ulx + sizeX 18 | lry = uly + sizeY 19 | src = None 20 | 21 | # format the extent coords 22 | extent = f"{ulx} {lry} {lrx} {uly}" 23 | # print(extent) 24 | 25 | # make clip command with ogr2ogr 26 | cmd = f"ogr2ogr {save_path} {vector_path} -clipsrc {extent}" 27 | 28 | # call the command 29 | subprocess.call(cmd, shell=True) 30 | return 0 31 | 32 | 33 | def main(): 34 | split = "train" # "train" or "valid" or "test" 35 | data_root = "../data/mass_roads_1500" 36 | vector_path = os.path.join(data_root, "massachusetts_roads_shape.geojson") 37 | save_dir = os.path.join(data_root, split, "shape") 38 | os.makedirs(save_dir, exist_ok=True) 39 | 40 | rasters_dir = os.path.join(data_root, split, "sat") 41 | rasters = os.listdir(rasters_dir) 42 | 43 | param_inputs = [] 44 | for raster in rasters: 45 | raster_path = os.path.join(rasters_dir, raster) 46 | save_path = os.path.join(save_dir, raster.split('.')[0]+".geojson") 47 | param_inputs.append( 48 | { 49 | 'raster_path': raster_path, 50 | 'vector_path': vector_path, 51 | 'save_path': save_path, 52 | } 53 | ) 54 | # clip_shapefile(raster_path, vector_path, save_path) 55 | 56 | with Pool() as p: 57 | _ = list(tqdm(p.imap(clip_shapefile, param_inputs), total=len(param_inputs))) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | 63 | -------------------------------------------------------------------------------- /data_preprocess/mass_roads_clip_tile_vectors.py: -------------------------------------------------------------------------------- 1 | from osgeo import gdal 2 | import os 3 | import subprocess 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | 7 | 8 | def clip_shapefile(paths_dict): 9 | # get the extent of the raster 10 | raster_path = paths_dict['raster_path'] 11 | save_path = paths_dict['save_path'] 12 | vector_path = paths_dict['vector_path'] 13 | src = gdal.Open(raster_path) 14 | ulx, xres, xskew, uly, yskew, yres = src.GetGeoTransform() 15 | sizeX = src.RasterXSize * xres 16 | sizeY = src.RasterYSize * yres 17 | lrx = ulx + sizeX 18 | lry = uly + sizeY 19 | src = None 20 | 21 | # format the extent coords 22 | extent = f"{ulx} {lry} {lrx} {uly}" 23 | # print(extent) 24 | 25 | # make clip command with ogr2ogr 26 | cmd = f"ogr2ogr {save_path} {vector_path} -clipsrc {extent}" 27 | 28 | # call the command 29 | subprocess.call(cmd, shell=True) 30 | return 0 31 | 32 | 33 | def main(): 34 | split = "train" # "train" or "valid" or "test" 35 | data_root = f"../data/mass_roads_1500" 36 | vector_dir = os.path.join(data_root, split, "shape") 37 | save_dir = f'../data/mass_roads_224/{split}/shape' 38 | os.makedirs(save_dir, exist_ok=True) 39 | 40 | rasters_dir = f'../data/mass_roads_224/{split}/images/' 41 | rasters = os.listdir(rasters_dir) 42 | 43 | param_inputs = [] 44 | for raster in rasters: 45 | raster_path = os.path.join(rasters_dir, raster) 46 | save_path = os.path.join(save_dir, raster.split('.')[0]+".geojson") 47 | raster_info = raster.split('_') 48 | vec_desc = f"{raster_info[0]}_{raster_info[1]}.geojson" 49 | vector_path = os.path.join(vector_dir, vec_desc) 50 | param_inputs.append( 51 | { 52 | 'raster_path': raster_path, 53 | 'vector_path':vector_path, 54 | 'save_path': save_path 55 | } 56 | ) 57 | 58 | 59 | # clip_shapefile(raster_path, vector_path, save_path) 60 | with Pool() as p: 61 | _ = list(tqdm(p.imap(clip_shapefile, param_inputs), total=len(param_inputs))) 62 | 63 | if __name__ == "__main__": 64 | main() 65 | 66 | -------------------------------------------------------------------------------- /data_preprocess/mass_roads_tiles_to_patches.py: -------------------------------------------------------------------------------- 1 | # gdal_retile.py -targetDir gdal_retile/valid/images/ -ps 224 224 -overlap 34 valid/sat/*.tiff 2 | 3 | import os 4 | import subprocess 5 | from tqdm import tqdm 6 | 7 | 8 | def main(): 9 | split = 'train' 10 | data_root = f"../data/mass_roads_1500" 11 | images_dir = os.path.join(data_root, split, "sat") 12 | masks_dir = os.path.join(data_root, split, "map") 13 | 14 | out_imgs_dir = f"../data/mass_roads_224/{split}/images" 15 | out_mask_dir = f"../data/mass_roads_224/{split}/mask" 16 | os.makedirs(out_imgs_dir, exist_ok=True) 17 | os.makedirs(out_mask_dir, exist_ok=True) 18 | 19 | patch_size = 224 20 | ph = pw = patch_size 21 | overlap = int(round(0.15*patch_size)) 22 | 23 | images = os.listdir(images_dir) 24 | for img in tqdm(images): 25 | in_path = os.path.join(images_dir, img) 26 | cmd = f"gdal_retile.py -targetDir {out_imgs_dir} -ps {ph} {pw} -overlap {overlap} {in_path}" 27 | subprocess.call(cmd, shell=True) 28 | 29 | masks = os.listdir(masks_dir) 30 | for mask in tqdm(masks): 31 | in_path = os.path.join(masks_dir, mask) 32 | cmd = f"gdal_retile.py -targetDir {out_mask_dir} -ps {ph} {pw} -overlap {overlap} {in_path}" 33 | subprocess.call(cmd, shell=True) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /data_preprocess/mass_roads_world_to_pixel_coords.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | from osgeo import gdal # use gdal env for this script. 5 | 6 | 7 | def main(): 8 | split = "train" # "train" or "valid" or "test" 9 | # data_root = f"../data/mass_roads_1500" 10 | data_root = f"../data/mass_roads_224" 11 | # geoimages_dir = os.path.join(data_root, split, "sat") 12 | geoimages_dir = os.path.join(data_root, split, "images") 13 | # shapefiles_dir = os.path.join(data_root, split, "shape") 14 | shapefiles_dir = os.path.join(data_root, split, "shape") 15 | 16 | save_shapefiles_dir = os.path.join(data_root, split, "pixel_annotations") 17 | os.makedirs(save_shapefiles_dir, exist_ok=True) 18 | 19 | geoimages = os.listdir(geoimages_dir) 20 | shapefiles = os.listdir(shapefiles_dir) 21 | 22 | for i in tqdm(range(len(geoimages))): 23 | geo_im = geoimages[i] 24 | # im_desc = geo_im.split('_')[-1].split('.')[0] 25 | im_desc = geo_im.split('.')[0] 26 | # shp = [sh for sh in shapefiles if f"{im_desc}.geojson" in sh] 27 | # assert len(shp) == 1 28 | # shp = shp[0] 29 | shp = f"{im_desc}.geojson" 30 | # shp = shapefiles[0] 31 | 32 | driver = gdal.GetDriverByName('GTiff') 33 | dataset = gdal.Open(os.path.join(geoimages_dir, geo_im)) 34 | band = dataset.GetRasterBand(1) 35 | cols = dataset.RasterXSize 36 | rows = dataset.RasterYSize 37 | transform = dataset.GetGeoTransform() 38 | xOrigin = transform[0] 39 | yOrigin = transform[3] 40 | pixelWidth = transform[1] 41 | pixelHeight = -transform[5] 42 | data = band.ReadAsArray(0, 0, cols, rows) 43 | 44 | with open(os.path.join(shapefiles_dir, shp), 'r') as f: 45 | geo_shp = json.load(f) 46 | 47 | pixel_shp = {} 48 | pixel_shp['type'] = geo_shp['type'] 49 | pixel_shp['features'] = [] 50 | 51 | for feature in geo_shp['features']: 52 | out_feat = { 53 | 'type': 'Feature', 54 | 'geometry': { 55 | 'type': 'MultiLineString', 56 | 'coordinates': [] 57 | } 58 | } 59 | if feature['geometry']['type'] == "MultiLineString": 60 | feature_coords = feature['geometry']['coordinates'] 61 | for geo_coords in feature_coords: 62 | points_list = [(gc[0], gc[1]) for gc in geo_coords] 63 | coords_list = [] 64 | for point in points_list: 65 | col = (point[0] - xOrigin) / pixelWidth 66 | col = col if col < cols else cols 67 | row = (yOrigin - point[1]) / pixelHeight 68 | row = row if row < rows else rows 69 | # 'row' has negative sign to be compatible with qgis visualization. Must not be used for compatibility in image space. 70 | # coords_list.append([col, -row]) 71 | out_feat['geometry']['coordinates'].append(coords_list) 72 | pixel_shp['features'].append(out_feat) 73 | else: 74 | geo_coords = feature['geometry']['coordinates'] 75 | points_list = [(gc[0], gc[1]) for gc in geo_coords] 76 | coords_list = [] 77 | for point in points_list: 78 | col = (point[0] - xOrigin) / pixelWidth 79 | col = col if col < cols else cols 80 | row = (yOrigin - point[1]) / pixelHeight 81 | row = row if row < rows else rows 82 | # 'row' has negative sign to be compatible with qgis visualization. Must not be used for compatibility in image space. 83 | # coords_list.append([col, -row]) 84 | coords_list.append([col, row]) 85 | out_feat['geometry']['coordinates'].append(coords_list) 86 | pixel_shp['features'].append(out_feat) 87 | 88 | with open(os.path.join(save_shapefiles_dir, shp), 'w') as o: 89 | json.dump(pixel_shp, o) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | 95 | -------------------------------------------------------------------------------- /data_preprocess/spacenet_convert_16bit_to_8bit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from osgeo import gdal 3 | from tqdm import tqdm 4 | 5 | 6 | def convert_16bit_to_8bit(in_path, out_path, out_format='GTiff'): 7 | translate_options = gdal.TranslateOptions(format=out_format, 8 | outputType=gdal.GDT_Byte, 9 | scaleParams=[''], 10 | # scaleParams=[min_val, max_val], 11 | ) 12 | gdal.Translate(destName=out_path, srcDS=in_path, options=translate_options) 13 | 14 | 15 | def main(): 16 | src_dir = f"../data/AOI_2_Vegas_Train/RGB-PanSharpen/" 17 | dest_dir = f"../data/AOI_2_Vegas_Train/RGB_8bit/train/images" 18 | os.makedirs(dest_dir, exist_ok=True) 19 | 20 | src_ims = os.listdir(src_dir) 21 | for im in tqdm(src_ims): 22 | src_im = os.path.join(src_dir, im) 23 | dest_im = os.path.join(dest_dir, im) 24 | convert_16bit_to_8bit(src_im, dest_im) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | 30 | -------------------------------------------------------------------------------- /data_preprocess/spacenet_to_coco.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/SarahwXU/HiSup/blob/main/tools/inria_to_coco.py 2 | # Transform Spacenet 2 dataset (image and geojson pairs) to COCO format 3 | # 4 | # The first 15% images are kept as validation set 5 | 6 | import os 7 | import numpy as np 8 | from skimage import io 9 | import json 10 | from tqdm import tqdm 11 | from shapely.geometry import Polygon, mapping 12 | from shapely.ops import transform as poly_transform 13 | from shapely.ops import unary_union 14 | from shapely.geometry import box 15 | from skimage.measure import label as ski_label 16 | from skimage.measure import regionprops 17 | import cv2 18 | import math 19 | import shapely 20 | 21 | 22 | def clip_by_bound(poly, im_h, im_w): 23 | """ 24 | Bound poly coordinates by image shape 25 | """ 26 | p_x = poly[:, 0] 27 | p_y = poly[:, 1] 28 | p_x = np.clip(p_x, 0.0, im_w-1) 29 | p_y = np.clip(p_y, 0.0, im_h-1) 30 | return np.concatenate((p_x[:, np.newaxis], p_y[:, np.newaxis]), axis=1) 31 | 32 | 33 | def crop2patch(im_p, p_h, p_w, p_overlap): 34 | """ 35 | Get coordinates of upper-left point for image patch 36 | return: patch_list [X_upper-left, Y_upper-left, patch_width, patch_height] 37 | """ 38 | im_h, im_w, _ = im_p 39 | x = np.arange(0, im_w-p_w, p_w-p_overlap) 40 | x = np.append(x, im_w-p_w) 41 | y = np.arange(0, im_h-p_h, p_h-p_overlap) 42 | y = np.append(y, im_h-p_h) 43 | X, Y = np.meshgrid(x, y) 44 | patch_list = [[i, j, p_w, p_h] for i, j in zip(X.flatten(), Y.flatten())] 45 | return patch_list 46 | 47 | 48 | def rotate_image(image, angle): 49 | """ 50 | Rotates an OpenCV 2 / NumPy image about it's centre by the given angle 51 | (in degrees). The returned image will be large enough to hold the entire 52 | new image, with a black background 53 | """ 54 | 55 | # Get the image size 56 | # No that's not an error - NumPy stores image matricies backwards 57 | image_size = (image.shape[1], image.shape[0]) 58 | image_center = tuple(np.array(image_size) / 2) 59 | 60 | # Convert the OpenCV 3x2 rotation matrix to 3x3 61 | rot_mat = np.vstack( 62 | [cv2.getRotationMatrix2D(image_center, angle, 1.0), [0, 0, 1]] 63 | ) 64 | 65 | rot_mat_notranslate = np.matrix(rot_mat[0:2, 0:2]) 66 | 67 | # Shorthand for below calcs 68 | image_w2 = image_size[0] * 0.5 69 | image_h2 = image_size[1] * 0.5 70 | 71 | # Obtain the rotated coordinates of the image corners 72 | rotated_coords = [ 73 | (np.array([-image_w2, image_h2]) * rot_mat_notranslate).A[0], 74 | (np.array([ image_w2, image_h2]) * rot_mat_notranslate).A[0], 75 | (np.array([-image_w2, -image_h2]) * rot_mat_notranslate).A[0], 76 | (np.array([ image_w2, -image_h2]) * rot_mat_notranslate).A[0] 77 | ] 78 | 79 | # Find the size of the new image 80 | x_coords = [pt[0] for pt in rotated_coords] 81 | x_pos = [x for x in x_coords if x > 0] 82 | x_neg = [x for x in x_coords if x < 0] 83 | 84 | y_coords = [pt[1] for pt in rotated_coords] 85 | y_pos = [y for y in y_coords if y > 0] 86 | y_neg = [y for y in y_coords if y < 0] 87 | 88 | right_bound = max(x_pos) 89 | left_bound = min(x_neg) 90 | top_bound = max(y_pos) 91 | bot_bound = min(y_neg) 92 | 93 | new_w = int(abs(right_bound - left_bound)) 94 | new_h = int(abs(top_bound - bot_bound)) 95 | 96 | # We require a translation matrix to keep the image centred 97 | trans_mat = np.matrix([ 98 | [1, 0, int(new_w * 0.5 - image_w2)], 99 | [0, 1, int(new_h * 0.5 - image_h2)], 100 | [0, 0, 1] 101 | ]) 102 | 103 | # Compute the tranform for the combined rotation and translation 104 | affine_mat = (np.matrix(trans_mat) * np.matrix(rot_mat))[0:2, :] 105 | 106 | # Apply the transform 107 | result = cv2.warpAffine( 108 | image, 109 | affine_mat, 110 | (new_w, new_h), 111 | flags=cv2.INTER_LINEAR 112 | ) 113 | 114 | return result 115 | 116 | 117 | def largest_rotated_rect(w, h, angle): 118 | """ 119 | Given a rectangle of size wxh that has been rotated by 'angle' (in 120 | radians), computes the width and height of the largest possible 121 | axis-aligned rectangle within the rotated rectangle. 122 | 123 | Original JS code by 'Andri' and Magnus Hoff from Stack Overflow 124 | 125 | Converted to Python by Aaron Snoswell 126 | """ 127 | 128 | quadrant = int(math.floor(angle / (math.pi / 2))) & 3 129 | sign_alpha = angle if ((quadrant & 1) == 0) else math.pi - angle 130 | alpha = (sign_alpha % math.pi + math.pi) % math.pi 131 | 132 | bb_w = w * math.cos(alpha) + h * math.sin(alpha) 133 | bb_h = w * math.sin(alpha) + h * math.cos(alpha) 134 | 135 | gamma = math.atan2(bb_w, bb_w) if (w < h) else math.atan2(bb_w, bb_w) 136 | 137 | delta = math.pi - alpha - gamma 138 | 139 | length = h if (w < h) else w 140 | 141 | d = length * math.cos(alpha) 142 | a = d * math.sin(alpha) / math.sin(delta) 143 | 144 | y = a * math.cos(gamma) 145 | x = y * math.tan(gamma) 146 | 147 | return ( 148 | bb_w - 2 * x, 149 | bb_h - 2 * y 150 | ) 151 | 152 | 153 | def crop_around_center(image, width, height): 154 | """ 155 | Given a NumPy / OpenCV 2 image, crops it to the given width and height, 156 | around it's centre point 157 | """ 158 | 159 | image_size = (image.shape[1], image.shape[0]) 160 | image_center = (int(image_size[0] * 0.5), int(image_size[1] * 0.5)) 161 | 162 | if(width > image_size[0]): 163 | width = image_size[0] 164 | 165 | if(height > image_size[1]): 166 | height = image_size[1] 167 | 168 | x1 = int(image_center[0] - width * 0.5) 169 | x2 = int(image_center[0] + width * 0.5) 170 | y1 = int(image_center[1] - height * 0.5) 171 | y2 = int(image_center[1] + height * 0.5) 172 | 173 | return image[y1:y2, x1:x2] 174 | 175 | 176 | def rotate_crop(im, gt, crop_size, angle): 177 | h, w = im.shape[0:2] 178 | im_rotated = rotate_image(im, angle) 179 | gt_rotated = rotate_image(gt, angle) 180 | if largest_rotated_rect(w, h, math.radians(angle))[0] >= crop_size: 181 | im_cropped = crop_around_center(im_rotated, crop_size, crop_size) 182 | gt_cropped = crop_around_center(gt_rotated, crop_size, crop_size) 183 | else: 184 | # print('error') 185 | im_cropped = crop_around_center(im, crop_size, crop_size) 186 | gt_cropped = crop_around_center(gt, crop_size, crop_size) 187 | return im_cropped, gt_cropped 188 | 189 | 190 | def lt_crop(im, gt, crop_size): 191 | im_cropped = im[0:crop_size, 0:crop_size, :] 192 | gt_cropped = gt[0:crop_size, 0:crop_size] 193 | return im_cropped, gt_cropped 194 | 195 | 196 | # for polygon vflip 197 | def reflection(): 198 | return lambda x, y: (x, -y) 199 | 200 | 201 | if __name__ == '__main__': 202 | input_image_path = f"../data/AOI_2_Vegas_Train/RGB_8bit/train/images" 203 | input_annos_path = f"../data/AOI_2_Vegas_Train/pixel_geojson" 204 | 205 | save_path = f"../data/spacenet_coco/" 206 | 207 | all_images = os.listdir(input_image_path) 208 | val_count = int(0.15 * len(all_images)) 209 | print(f"No. of val images: {val_count}") 210 | val_images = all_images[0:val_count] 211 | train_images = all_images[val_count:] 212 | 213 | train_set = set(train_images) 214 | val_set = set(val_images) 215 | if len(train_set.intersection(val_set)) > 0 or len(val_set.intersection(train_set)): 216 | raise RuntimeError() 217 | 218 | split = 'train' 219 | 220 | if split == 'train': 221 | query_images = train_images 222 | elif split == 'val': 223 | query_images = val_images 224 | else: 225 | raise Exception(f'"{split}" is an incorrect split choice. Split choice must be either "train" or "val".') 226 | 227 | output_im_train = os.path.join(save_path, split, 'images') 228 | if not os.path.exists(output_im_train): 229 | os.makedirs(output_im_train) 230 | 231 | # patch_width = 725 232 | # patch_height = 725 233 | # patch_overlap = 300 234 | # patch_size = 512 235 | # rotation_list = [22.5, 45, 67.5] 236 | 237 | patch_width = 224 238 | patch_height = 224 239 | patch_overlap = 34 # ~15% of patch size 240 | patch_size = 224 241 | rotation_list = [] 242 | 243 | # main dict for annotation file 244 | output_data_train = { 245 | 'info': {'district': 'SpaceNetv2', 'description': 'building footprints', 'contributor': 'cyens'}, 246 | 'categories': [{'id': 100, 'name': 'building'}], 247 | 'images': [], 248 | 'annotations': [], 249 | } 250 | 251 | train_ob_id = 0 252 | train_im_id = 0 253 | # read in data with npy format 254 | input_label = os.listdir(input_annos_path) 255 | for g_id, label in enumerate(tqdm(input_label)): 256 | # read data 257 | # label_info = [''.join(list(g)) for k, g in groupby(label, key=lambda x: x.isdigit())] 258 | label_info = label.split('_') 259 | 260 | label_name = label_info[-1].split('.')[0] 261 | im_name = [im for im in all_images if label_name+".tif" in im] 262 | assert len(im_name) == 1 263 | im_name = im_name[0] 264 | image_data = io.imread(os.path.join(input_image_path, im_name)) 265 | with open(os.path.join(input_annos_path, label), 'r') as f: 266 | anno_data = json.load(f) 267 | im_h, im_w, _ = image_data.shape 268 | 269 | tile_polygons = [] 270 | for poly in anno_data['features']: 271 | poly = poly['geometry']['coordinates'] 272 | assert len(poly) == 1 273 | poly = np.array(poly[0]) 274 | poly = Polygon(poly) 275 | poly = poly_transform(reflection(), poly) 276 | tile_polygons.append(poly) 277 | tile_polygons = shapely.geometry.MultiPolygon(tile_polygons) 278 | tile_polygons = unary_union(tile_polygons) 279 | # tile_polygons = poly_transform(reflection(), tile_polygons) 280 | 281 | if im_name in query_images: 282 | # for training/val set, split image to 224x224 283 | patch_list = crop2patch(image_data.shape, patch_width, patch_height, patch_overlap) 284 | for pid, pa in enumerate(patch_list): 285 | x_ul, y_ul, pw, ph = pa 286 | # bbox_s = box(y_ul, y_ul+patch_height, x_ul, x_ul+patch_width) 287 | bbox_s = box(x_ul, y_ul, x_ul+patch_width, y_ul+patch_height) 288 | 289 | p_gt = tile_polygons.intersection(bbox_s) 290 | # print(type(p_gt)) 291 | if isinstance(p_gt, Polygon): 292 | p_gt = shapely.geometry.MultiPolygon([p_gt]) 293 | else: 294 | p_gt = shapely.geometry.MultiPolygon(p_gt) 295 | p_im = image_data[y_ul:y_ul+patch_height, x_ul:x_ul+patch_width, :] 296 | p_gts = [] 297 | p_ims = [] 298 | p_im_rd, _ = lt_crop(p_im, p_im[0], patch_size) 299 | p_gts.append(p_gt) 300 | p_ims.append(p_im_rd) 301 | # for angle in rotation_list: 302 | # rot_im, _ = rotate_crop(p_im, p_im, patch_size, angle) 303 | # # p_gts.append(rot_gt) 304 | # p_ims.append(rot_im) 305 | for p_im, p_gt in zip(p_ims, p_gts): 306 | if len(p_gt.geoms) > 0: 307 | p_polygons = p_gt.geoms 308 | for poly in p_polygons: 309 | # poly = poly['geometry']['coordinates'] 310 | # assert len(poly) == 1 311 | poly = np.asarray(poly.exterior.coords) 312 | poly -= np.array([x_ul, y_ul]) 313 | poly = Polygon(poly) 314 | # poly = poly_transform(reflection(), poly) 315 | p_area = round(poly.area, 2) 316 | if p_area > 0: 317 | p_bbox = [poly.bounds[0], poly.bounds[1], poly.bounds[2]-poly.bounds[0], poly.bounds[3]-poly.bounds[1]] 318 | if p_bbox[2] > 0 and p_bbox[3] > 0: 319 | p_seg = [] 320 | coor_list = mapping(poly)['coordinates'] 321 | assert len(coor_list) == 1 322 | # import code; code.interact(local=locals()) 323 | for part_poly in coor_list: 324 | p_seg.append(np.asarray(part_poly).ravel().tolist()) 325 | anno_info = { 326 | 'id': train_ob_id, 327 | 'image_id': train_im_id, 328 | 'segmentation': p_seg, 329 | 'area': p_area, 330 | 'bbox': p_bbox, 331 | 'category_id': 100, 332 | 'iscrowd': 0 333 | } 334 | output_data_train['annotations'].append(anno_info) 335 | train_ob_id += 1 336 | else: # for including negative samples. 337 | anno_info = { 338 | 'id': train_ob_id, 339 | 'image_id': train_im_id, 340 | 'segmentation': [], 341 | 'area': 0., 342 | 'bbox': [], 343 | 'category_id': 100, 344 | 'iscrowd': 1 345 | } 346 | output_data_train['annotations'].append(anno_info) 347 | train_ob_id += 1 348 | # get patch info 349 | p_name = label_name + '-' + str(train_im_id) + '.tif' 350 | patch_info = {'id': train_im_id, 'file_name': p_name, 'width': patch_size, 'height': patch_size} 351 | output_data_train['images'].append(patch_info) 352 | # save patch image 353 | io.imsave(os.path.join(output_im_train, p_name), p_im) 354 | train_im_id += 1 355 | 356 | if not os.path.exists(os.path.join(save_path, split)): 357 | os.makedirs(save_path) 358 | with open(os.path.join(save_path, split, 'annotation.json'), 'w') as f_json: 359 | json.dump(output_data_train, f_json) 360 | -------------------------------------------------------------------------------- /data_preprocess/spacenet_world_to_pixel_coords.py: -------------------------------------------------------------------------------- 1 | import code 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | from osgeo import gdal # use gdal env for this script. 6 | 7 | 8 | def main(): 9 | ## VEGAS SUBSET 10 | # NOTE: Set path to spacenet dataset images and geojson annotations here. 11 | spacenet_dataset_root = f"../data/AOI_2_Vegas_Train/" 12 | geoimages_dir = os.path.join(spacenet_dataset_root, 'RGB_8bit', 'train', 'images') 13 | shapefiles_dir = os.path.join(spacenet_dataset_root, 'geojson', 'buildings') 14 | 15 | save_shapefiles_dir = os.path.join(spacenet_dataset_root, 'pixel_geojson') 16 | os.makedirs(save_shapefiles_dir, exist_ok=True) 17 | 18 | geoimages = os.listdir(geoimages_dir) 19 | shapefiles = os.listdir(shapefiles_dir) 20 | 21 | for i in tqdm(range(len(geoimages))): 22 | geo_im = geoimages[i] 23 | im_desc = geo_im.split('_')[-1].split('.')[0] 24 | shp = [sh for sh in shapefiles if f"{im_desc}.geojson" in sh] 25 | assert len(shp) == 1 26 | shp = shp[0] 27 | 28 | driver = gdal.GetDriverByName('GTiff') 29 | dataset = gdal.Open(os.path.join(geoimages_dir, geo_im)) 30 | band = dataset.GetRasterBand(1) 31 | cols = dataset.RasterXSize 32 | rows = dataset.RasterYSize 33 | transform = dataset.GetGeoTransform() 34 | xOrigin = transform[0] 35 | yOrigin = transform[3] 36 | pixelWidth = transform[1] 37 | pixelHeight = -transform[5] 38 | data = band.ReadAsArray(0, 0, cols, rows) 39 | 40 | with open(os.path.join(shapefiles_dir, shp), 'r') as f: 41 | geo_shp = json.load(f) 42 | 43 | pixel_shp = {} 44 | pixel_shp['type'] = geo_shp['type'] 45 | pixel_shp['features'] = [] 46 | 47 | for feature in geo_shp['features']: 48 | out_feat = { 49 | 'type': 'Feature', 50 | 'geometry': { 51 | 'type': 'Polygon', 52 | 'coordinates': [] 53 | } 54 | } 55 | if feature['geometry']['type'] == "Polygon": 56 | geo_coords = feature['geometry']['coordinates'][0] 57 | points_list = [(gc[0], gc[1]) for gc in geo_coords] 58 | coords_list = [] 59 | for point in points_list: 60 | col = (point[0] - xOrigin) / pixelWidth 61 | col = col if col < cols else cols 62 | row = (yOrigin - point[1]) / pixelHeight 63 | row = row if row < rows else rows 64 | # 'row' has negative sign to be compatible with qgis visualization. Must be removed for compatibility in image space. 65 | coords_list.append([col, -row]) 66 | out_feat['geometry']['coordinates'].append(coords_list) 67 | pixel_shp['features'].append(out_feat) 68 | 69 | with open(os.path.join(save_shapefiles_dir, shp), 'w') as o: 70 | json.dump(pixel_shp, o) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/dataset_inria_coco.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | from os import path as osp 5 | from pycocotools.coco import COCO 6 | from config import CFG 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | 13 | class InriaCocoDataset(Dataset): 14 | def __init__(self, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False): 15 | image_dir = osp.join(dataset_dir, "images") 16 | self.image_dir = image_dir 17 | self.annotations_path = osp.join(dataset_dir, "annotation.json") 18 | self.transform = transform 19 | self.tokenizer = tokenizer 20 | self.shuffle_tokens = shuffle_tokens 21 | # self.images = os.listdir(self.image_dir) 22 | self.coco = COCO(self.annotations_path) 23 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 24 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 25 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images if im.split('-')[0] not in ['kitsap4', 'kitsap5']] 26 | 27 | def __len__(self): 28 | return len(self.image_ids) 29 | 30 | def annToMask(self): 31 | return 32 | 33 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray): 34 | Nv = old_perm.shape[0] 35 | padd_idxs = np.arange(len(shuffle_idxs), Nv) 36 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0) 37 | 38 | transform_arr = torch.zeros_like(old_perm) 39 | for i in range(len(shuffle_idxs)): 40 | transform_arr[i, shuffle_idxs[i]] = 1. 41 | 42 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices 43 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T) 44 | 45 | return new_perm 46 | 47 | def __getitem__(self, index): 48 | n_vertices = CFG.N_VERTICES 49 | img_id = self.image_ids[index] 50 | img = self.coco.loadImgs(img_id)[0] 51 | img_path = osp.join(self.image_dir, img["file_name"]) 52 | ann_ids = self.coco.getAnnIds(imgIds=img['id']) 53 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image. 54 | 55 | image = np.array(Image.open(img_path).convert("RGB")) 56 | 57 | mask = np.zeros((img['width'], img['height'])) 58 | corner_coords = [] 59 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32) 60 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32) 61 | for ins in annotations: 62 | segmentations = ins['segmentation'] 63 | for i, segm in enumerate(segmentations): 64 | segm = np.array(segm).reshape(-1, 2) 65 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1) 66 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1) 67 | points = segm[:-1] 68 | corner_coords.extend(points.tolist()) 69 | mask += self.coco.annToMask(ins) 70 | mask = mask / 255. if mask.max() == 255 else mask 71 | mask = np.clip(mask, 0, 1) 72 | 73 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32) 74 | 75 | if len(corner_coords) > 0.: 76 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1. 77 | 78 | ############# START: Generate gt permutation matrix. ############# 79 | v_count = 0 80 | for ins in annotations: 81 | segmentations = ins['segmentation'] 82 | for idx, segm in enumerate(segmentations): 83 | segm = np.array(segm).reshape(-1, 2) 84 | points = segm[:-1] 85 | for i in range(len(points)): 86 | j = (i + 1) % len(points) 87 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1: 88 | break 89 | perm_matrix[v_count+i, v_count+j] = 1. 90 | v_count += len(points) 91 | 92 | for i in range(v_count, n_vertices): 93 | perm_matrix[i, i] = 1. 94 | 95 | # Workaround for open contours: 96 | for i in range(n_vertices): 97 | row = perm_matrix[i, :] 98 | col = perm_matrix[:, i] 99 | if np.sum(row) == 0 or np.sum(col) == 0: 100 | perm_matrix[i, i] = 1. 101 | perm_matrix = torch.from_numpy(perm_matrix) 102 | ############# END: Generate gt permutation matrix. ############# 103 | 104 | masks = [mask, corner_mask] 105 | 106 | if len(corner_coords) > CFG.N_VERTICES: 107 | corner_coords = corner_coords[:CFG.N_VERTICES] 108 | 109 | if self.transform is not None: 110 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist()) 111 | image = augmentations['image'] 112 | mask = augmentations['masks'][0] 113 | corner_mask = augmentations['masks'][1] 114 | corner_coords = np.array(augmentations['keypoints']) 115 | 116 | if self.tokenizer is not None: 117 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens) 118 | coords_seqs = torch.LongTensor(coords_seqs) 119 | if self.shuffle_tokens: 120 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs) 121 | else: 122 | coords_seqs = corner_coords 123 | 124 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix 125 | 126 | 127 | class InriaCocoDataset_val(Dataset): 128 | def __init__(self, cfg, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False): 129 | self.CFG = cfg 130 | image_dir = osp.join(dataset_dir, "images") 131 | self.image_dir = image_dir 132 | self.annotations_path = osp.join(dataset_dir, "annotation.json") 133 | self.transform = transform 134 | self.tokenizer = tokenizer 135 | self.shuffle_tokens = shuffle_tokens 136 | # self.images = os.listdir(self.image_dir) 137 | self.coco = COCO(self.annotations_path) 138 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 139 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 140 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images] 141 | 142 | def __len__(self): 143 | return len(self.image_ids) 144 | 145 | def annToMask(self): 146 | return 147 | 148 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray): 149 | Nv = old_perm.shape[0] 150 | padd_idxs = np.arange(len(shuffle_idxs), Nv) 151 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0) 152 | 153 | transform_arr = torch.zeros_like(old_perm) 154 | for i in range(len(shuffle_idxs)): 155 | transform_arr[i, shuffle_idxs[i]] = 1. 156 | 157 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices 158 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T) 159 | 160 | return new_perm 161 | 162 | def __getitem__(self, index): 163 | n_vertices = self.CFG.N_VERTICES 164 | img_id = self.image_ids[index] 165 | img = self.coco.loadImgs(img_id)[0] 166 | img_path = osp.join(self.image_dir, img["file_name"]) 167 | ann_ids = self.coco.getAnnIds(imgIds=img['id']) 168 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image. 169 | 170 | image = np.array(Image.open(img_path).convert("RGB")) 171 | 172 | mask = np.zeros((img['width'], img['height'])) 173 | corner_coords = [] 174 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32) 175 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32) 176 | for ins in annotations: 177 | segmentations = ins['segmentation'] 178 | for i, segm in enumerate(segmentations): 179 | segm = np.array(segm).reshape(-1, 2) 180 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1) 181 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1) 182 | points = segm[:-1] 183 | corner_coords.extend(points.tolist()) 184 | mask += self.coco.annToMask(ins) 185 | mask = mask / 255. if mask.max() == 255 else mask 186 | mask = np.clip(mask, 0, 1) 187 | 188 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32) 189 | 190 | if len(corner_coords) > 0.: 191 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1. 192 | 193 | ############# START: Generate gt permutation matrix. ############# 194 | v_count = 0 195 | for ins in annotations: 196 | segmentations = ins['segmentation'] 197 | for idx, segm in enumerate(segmentations): 198 | segm = np.array(segm).reshape(-1, 2) 199 | points = segm[:-1] 200 | for i in range(len(points)): 201 | j = (i + 1) % len(points) 202 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1: 203 | break 204 | perm_matrix[v_count+i, v_count+j] = 1. 205 | v_count += len(points) 206 | 207 | for i in range(v_count, n_vertices): 208 | perm_matrix[i, i] = 1. 209 | 210 | # Workaround for open contours: 211 | for i in range(n_vertices): 212 | row = perm_matrix[i, :] 213 | col = perm_matrix[:, i] 214 | if np.sum(row) == 0 or np.sum(col) == 0: 215 | perm_matrix[i, i] = 1. 216 | perm_matrix = torch.from_numpy(perm_matrix) 217 | ############# END: Generate gt permutation matrix. ############# 218 | 219 | masks = [mask, corner_mask] 220 | 221 | if len(corner_coords) > self.CFG.N_VERTICES: 222 | corner_coords = corner_coords[:self.CFG.N_VERTICES] 223 | 224 | if self.transform is not None: 225 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist()) 226 | image = augmentations['image'] 227 | mask = augmentations['masks'][0] 228 | corner_mask = augmentations['masks'][1] 229 | corner_coords = np.array(augmentations['keypoints']) 230 | 231 | if self.tokenizer is not None: 232 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens) 233 | coords_seqs = torch.LongTensor(coords_seqs) 234 | if self.shuffle_tokens: 235 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs) 236 | else: 237 | coords_seqs = corner_coords 238 | 239 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix, torch.tensor([img['id']]) 240 | 241 | 242 | def collate_fn(batch, max_len, pad_idx): 243 | """ 244 | if max_len: 245 | the sequences will all be padded to that length. 246 | """ 247 | 248 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch = [], [], [], [], [] 249 | for image, mask, c_mask, seq, perm_mat in batch: 250 | image_batch.append(image) 251 | mask_batch.append(mask) 252 | coords_mask_batch.append(c_mask) 253 | coords_seq_batch.append(seq) 254 | perm_matrix_batch.append(perm_mat) 255 | 256 | coords_seq_batch = pad_sequence( 257 | coords_seq_batch, 258 | padding_value=pad_idx, 259 | batch_first=True 260 | ) 261 | 262 | if max_len: 263 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long() 264 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1) 265 | 266 | image_batch = torch.stack(image_batch) 267 | mask_batch = torch.stack(mask_batch) 268 | coords_mask_batch = torch.stack(coords_mask_batch) 269 | perm_matrix_batch = torch.stack(perm_matrix_batch) 270 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch 271 | 272 | 273 | class InriaCocoDatasetTest(Dataset): 274 | def __init__(self, image_dir, transform=None): 275 | self.image_dir = image_dir 276 | self.transform = transform 277 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 278 | 279 | def __getitem__(self, index): 280 | img_path = osp.join(self.image_dir, self.images[index]) 281 | image = np.array(Image.open(img_path).convert("RGB")) 282 | 283 | if self.transform is not None: 284 | image = self.transform(image=image)['image'] 285 | 286 | image = torch.FloatTensor(image) 287 | return image 288 | 289 | def __len__(self): 290 | return len(self.images) 291 | -------------------------------------------------------------------------------- /datasets/dataset_spacenet_coco.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | from os import path as osp 5 | from pycocotools.coco import COCO 6 | from config import CFG 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class SpacenetCocoDataset(Dataset): 13 | def __init__(self, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False): 14 | image_dir = osp.join(dataset_dir, "images") 15 | self.image_dir = image_dir 16 | self.annotations_path = osp.join(dataset_dir, "annotation.json") 17 | self.transform = transform 18 | self.tokenizer = tokenizer 19 | self.shuffle_tokens = shuffle_tokens 20 | # self.images = os.listdir(self.image_dir) 21 | self.coco = COCO(self.annotations_path) 22 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 23 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 24 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images] 25 | 26 | def __len__(self): 27 | return len(self.image_ids) 28 | 29 | def annToMask(self): 30 | return 31 | 32 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray): 33 | Nv = old_perm.shape[0] 34 | padd_idxs = np.arange(len(shuffle_idxs), Nv) 35 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0) 36 | 37 | transform_arr = torch.zeros_like(old_perm) 38 | for i in range(len(shuffle_idxs)): 39 | transform_arr[i, shuffle_idxs[i]] = 1. 40 | 41 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices 42 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T) 43 | 44 | return new_perm 45 | 46 | def __getitem__(self, index): 47 | n_vertices = CFG.N_VERTICES 48 | img_id = self.image_ids[index] 49 | img = self.coco.loadImgs(img_id)[0] 50 | img_path = osp.join(self.image_dir, img["file_name"]) 51 | ann_ids = self.coco.getAnnIds(imgIds=img['id']) 52 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image. 53 | 54 | image = np.array(Image.open(img_path).convert("RGB")) 55 | 56 | mask = np.zeros((img['width'], img['height'])) 57 | corner_coords = [] 58 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32) 59 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32) 60 | for ins in annotations: 61 | segmentations = ins['segmentation'] 62 | for i, segm in enumerate(segmentations): 63 | segm = np.array(segm).reshape(-1, 2) 64 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1) 65 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1) 66 | points = segm[:-1] 67 | corner_coords.extend(points.tolist()) 68 | mask += self.coco.annToMask(ins) 69 | mask = mask / 255. if mask.max() == 255 else mask 70 | mask = np.clip(mask, 0, 1) 71 | 72 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32) 73 | 74 | if len(corner_coords) > 0.: 75 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1. 76 | 77 | ############# START: Generate gt permutation matrix. ############# 78 | v_count = 0 79 | for ins in annotations: 80 | segmentations = ins['segmentation'] 81 | for idx, segm in enumerate(segmentations): 82 | segm = np.array(segm).reshape(-1, 2) 83 | points = segm[:-1] 84 | for i in range(len(points)): 85 | j = (i + 1) % len(points) 86 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1: 87 | break 88 | perm_matrix[v_count+i, v_count+j] = 1. 89 | v_count += len(points) 90 | 91 | for i in range(v_count, n_vertices): 92 | perm_matrix[i, i] = 1. 93 | 94 | # Workaround for open contours: 95 | for i in range(n_vertices): 96 | row = perm_matrix[i, :] 97 | col = perm_matrix[:, i] 98 | if np.sum(row) == 0 or np.sum(col) == 0: 99 | perm_matrix[i, i] = 1. 100 | perm_matrix = torch.from_numpy(perm_matrix) 101 | ############# END: Generate gt permutation matrix. ############# 102 | 103 | masks = [mask, corner_mask] 104 | 105 | if len(corner_coords) > CFG.N_VERTICES: 106 | corner_coords = corner_coords[:CFG.N_VERTICES] 107 | 108 | if self.transform is not None: 109 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist()) 110 | image = augmentations['image'] 111 | mask = augmentations['masks'][0] 112 | corner_mask = augmentations['masks'][1] 113 | corner_coords = np.array(augmentations['keypoints']) 114 | 115 | if self.tokenizer is not None: 116 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens) 117 | coords_seqs = torch.LongTensor(coords_seqs) 118 | if self.shuffle_tokens: 119 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs) 120 | else: 121 | coords_seqs = corner_coords 122 | 123 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix 124 | 125 | 126 | class SpacenetCocoDataset_val(Dataset): 127 | def __init__(self, cfg, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False): 128 | self.CFG = cfg 129 | image_dir = osp.join(dataset_dir, "images") 130 | self.image_dir = image_dir 131 | self.annotations_path = osp.join(dataset_dir, "annotation.json") 132 | self.transform = transform 133 | self.tokenizer = tokenizer 134 | self.shuffle_tokens = shuffle_tokens 135 | # self.images = os.listdir(self.image_dir) 136 | self.coco = COCO(self.annotations_path) 137 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 138 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 139 | self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images] 140 | 141 | def __len__(self): 142 | return len(self.image_ids) 143 | 144 | def annToMask(self): 145 | return 146 | 147 | def __getitem__(self, index): 148 | n_vertices = self.CFG.N_VERTICES 149 | img_id = self.image_ids[index] 150 | img = self.coco.loadImgs(img_id)[0] 151 | img_path = osp.join(self.image_dir, img["file_name"]) 152 | ann_ids = self.coco.getAnnIds(imgIds=img['id']) 153 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image. 154 | 155 | image = np.array(Image.open(img_path).convert("RGB")) 156 | 157 | mask = np.zeros((img['width'], img['height'])) 158 | corner_coords = [] 159 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32) 160 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32) 161 | for ins in annotations: 162 | segmentations = ins['segmentation'] 163 | for i, segm in enumerate(segmentations): 164 | segm = np.array(segm).reshape(-1, 2) 165 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1) 166 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1) 167 | points = segm[:-1] 168 | corner_coords.extend(points.tolist()) 169 | mask += self.coco.annToMask(ins) 170 | mask = mask / 255. if mask.max() == 255 else mask 171 | mask = np.clip(mask, 0, 1) 172 | 173 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32) 174 | 175 | if len(corner_coords) > 0.: 176 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1. 177 | 178 | ############# START: Generate gt permutation matrix. ############# 179 | v_count = 0 180 | for ins in annotations: 181 | segmentations = ins['segmentation'] 182 | for idx, segm in enumerate(segmentations): 183 | segm = np.array(segm).reshape(-1, 2) 184 | points = segm[:-1] 185 | for i in range(len(points)): 186 | j = (i + 1) % len(points) 187 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1: 188 | break 189 | perm_matrix[v_count+i, v_count+j] = 1. 190 | v_count += len(points) 191 | 192 | for i in range(v_count, n_vertices): 193 | perm_matrix[i, i] = 1. 194 | 195 | # Workaround for open contours: 196 | for i in range(n_vertices): 197 | row = perm_matrix[i, :] 198 | col = perm_matrix[:, i] 199 | if np.sum(row) == 0 or np.sum(col) == 0: 200 | perm_matrix[i, i] = 1. 201 | perm_matrix = torch.from_numpy(perm_matrix) 202 | ############# END: Generate gt permutation matrix. ############# 203 | 204 | masks = [mask, corner_mask] 205 | 206 | if len(corner_coords) > self.CFG.N_VERTICES: 207 | corner_coords = corner_coords[:self.CFG.N_VERTICES] 208 | 209 | if self.transform is not None: 210 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist()) 211 | image = augmentations['image'] 212 | mask = augmentations['masks'][0] 213 | corner_mask = augmentations['masks'][1] 214 | corner_coords = np.array(augmentations['keypoints']) 215 | 216 | if self.tokenizer is not None: 217 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens) 218 | coords_seqs = torch.LongTensor(coords_seqs) 219 | perm_matrix = torch.cat((perm_matrix[rand_idxs], perm_matrix[len(rand_idxs):])) 220 | else: 221 | coords_seqs = corner_coords 222 | 223 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix, torch.tensor([img['id']]) 224 | 225 | 226 | class SpacenetCocoDatasetTest(Dataset): 227 | def __init__(self, image_dir, transform=None): 228 | self.image_dir = image_dir 229 | self.transform = transform 230 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 231 | 232 | def __getitem__(self, index): 233 | img_path = osp.join(self.image_dir, self.images[index]) 234 | image = np.array(Image.open(img_path).convert("RGB")) 235 | 236 | if self.transform is not None: 237 | image = self.transform(image=image)['image'] 238 | 239 | image = torch.FloatTensor(image) 240 | return image 241 | 242 | def __len__(self): 243 | return len(self.images) 244 | -------------------------------------------------------------------------------- /datasets/dataset_whu_buildings_coco.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from functools import partial 4 | import os 5 | from os import path as osp 6 | from pycocotools.coco import COCO 7 | from pycocotools import mask as cocomask 8 | from config import CFG 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | import albumentations as A 13 | from torch.nn.utils.rnn import pad_sequence 14 | 15 | 16 | class WHUBuildingsCocoDataset(Dataset): 17 | def __init__(self, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False): 18 | image_dir = osp.join(dataset_dir, "images") 19 | self.image_dir = image_dir 20 | self.annotations_path = osp.join(dataset_dir, "annotation.json") 21 | self.transform = transform 22 | self.tokenizer = tokenizer 23 | self.shuffle_tokens = shuffle_tokens 24 | # self.images = os.listdir(self.image_dir) 25 | self.coco = COCO(self.annotations_path) 26 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 27 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 28 | # image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images] 29 | image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 30 | if "train" in dataset_dir: 31 | # remove images with more than 144 vertices from training. 32 | self.image_ids = [im for im in image_ids if int(im) not in [16608, 36020, 36021]] 33 | else: 34 | self.image_ids = image_ids 35 | 36 | def __len__(self): 37 | return len(self.image_ids) 38 | 39 | def annToMask(self): 40 | return 41 | 42 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray): 43 | Nv = old_perm.shape[0] 44 | padd_idxs = np.arange(len(shuffle_idxs), Nv) 45 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0) 46 | 47 | transform_arr = torch.zeros_like(old_perm) 48 | for i in range(len(shuffle_idxs)): 49 | transform_arr[i, shuffle_idxs[i]] = 1. 50 | 51 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices 52 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T) 53 | # new_perm = torch.zeros_like(old_perm) 54 | 55 | # # generate new perm matrix based on shuffling indices. 56 | # for i in range(len(shuffle_idxs)): 57 | # new_i = shuffle_idxs[i] 58 | # new_j = shuffle_idxs[old_perm[i].nonzero().item()] 59 | # new_perm[new_i, new_j] = 1. 60 | 61 | # # Add self connections of unconnected vertices. 62 | # for i in range(Nv): 63 | # row = new_perm[i, :] 64 | # col = new_perm[:, i] 65 | # if torch.sum(row) == 0 or torch.sum(col) == 0: 66 | # new_perm[i, i] = 1. 67 | 68 | return new_perm 69 | 70 | def __getitem__(self, index): 71 | n_vertices = CFG.N_VERTICES 72 | img_id = self.image_ids[index] 73 | img = self.coco.loadImgs(img_id)[0] 74 | img_path = osp.join(self.image_dir, img["file_name"]) 75 | ann_ids = self.coco.getAnnIds(imgIds=img['id']) 76 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image. 77 | 78 | image = np.array(Image.open(img_path).convert("RGB")) 79 | 80 | mask = np.zeros((img['width'], img['height'])) 81 | corner_coords = [] 82 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32) 83 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32) 84 | for ins in annotations: 85 | segmentations = ins['segmentation'] 86 | for i, segm in enumerate(segmentations): 87 | segm = np.array(segm).reshape(-1, 2) 88 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1) 89 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1) 90 | points = segm[:-1] 91 | corner_coords.extend(points.tolist()) 92 | mask += self.coco.annToMask(ins) 93 | mask = mask / 255. if mask.max() == 255 else mask 94 | mask = np.clip(mask, 0, 1) 95 | 96 | # corner_coords = np.clip(np.array(corner_coords), 0, 299) 97 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32) 98 | 99 | if len(corner_coords) > 0.: 100 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1. 101 | # corner_coords = (corner_coords / img['width']) * CFG.INPUT_WIDTH 102 | 103 | ############# START: Generate gt permutation matrix. ############# 104 | v_count = 0 105 | for ins in annotations: 106 | segmentations = ins['segmentation'] 107 | for idx, segm in enumerate(segmentations): 108 | segm = np.array(segm).reshape(-1, 2) 109 | points = segm[:-1] 110 | for i in range(len(points)): 111 | j = (i + 1) % len(points) 112 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1: 113 | break 114 | perm_matrix[v_count+i, v_count+j] = 1. 115 | v_count += len(points) 116 | 117 | for i in range(v_count, n_vertices): 118 | perm_matrix[i, i] = 1. 119 | 120 | # Workaround for open contours: 121 | for i in range(n_vertices): 122 | row = perm_matrix[i, :] 123 | col = perm_matrix[:, i] 124 | if np.sum(row) == 0 or np.sum(col) == 0: 125 | perm_matrix[i, i] = 1. 126 | perm_matrix = torch.from_numpy(perm_matrix) 127 | ############# END: Generate gt permutation matrix. ############# 128 | 129 | masks = [mask, corner_mask] 130 | 131 | if len(corner_coords) > CFG.N_VERTICES: 132 | corner_coords = corner_coords[:CFG.N_VERTICES] 133 | 134 | if self.transform is not None: 135 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist()) 136 | image = augmentations['image'] 137 | mask = augmentations['masks'][0] 138 | corner_mask = augmentations['masks'][1] 139 | corner_coords = np.array(augmentations['keypoints']) 140 | 141 | if self.tokenizer is not None: 142 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens) 143 | coords_seqs = torch.LongTensor(coords_seqs) 144 | # perm_matrix = torch.cat((perm_matrix[rand_idxs], perm_matrix[len(rand_idxs):])) 145 | if self.shuffle_tokens: 146 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs) 147 | else: 148 | coords_seqs = corner_coords 149 | 150 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix 151 | 152 | 153 | class WHUBuildingsCocoDataset_val(Dataset): 154 | def __init__(self, cfg, dataset_dir, transform=None, tokenizer=None, shuffle_tokens=False): 155 | self.CFG = cfg 156 | image_dir = osp.join(dataset_dir, "images") 157 | self.image_dir = image_dir 158 | self.annotations_path = osp.join(dataset_dir, "annotation.json") 159 | self.transform = transform 160 | self.tokenizer = tokenizer 161 | self.shuffle_tokens = shuffle_tokens 162 | # self.images = os.listdir(self.image_dir) 163 | self.coco = COCO(self.annotations_path) 164 | # self.image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 165 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 166 | # self.image_ids = [int(im.split('-')[-1].split('.')[0]) for im in self.images] 167 | image_ids = self.coco.getImgIds(catIds=self.coco.getCatIds()) 168 | if "train" in dataset_dir: 169 | # remove images with more than 144 vertices from training. 170 | self.image_ids = [im for im in image_ids if int(im) not in [16608, 36020, 36021]] 171 | else: 172 | self.image_ids = image_ids 173 | 174 | def __len__(self): 175 | return len(self.image_ids) 176 | 177 | def annToMask(self): 178 | return 179 | 180 | def shuffle_perm_matrix_by_indices(self, old_perm: torch.Tensor, shuffle_idxs: np.ndarray): 181 | Nv = old_perm.shape[0] 182 | padd_idxs = np.arange(len(shuffle_idxs), Nv) 183 | shuffle_idxs = np.concatenate([shuffle_idxs, padd_idxs], axis=0) 184 | 185 | transform_arr = torch.zeros_like(old_perm) 186 | for i in range(len(shuffle_idxs)): 187 | transform_arr[i, shuffle_idxs[i]] = 1. 188 | 189 | # https://math.stackexchange.com/questions/2481213/adjacency-matrix-and-changing-order-of-vertices 190 | new_perm = torch.mm(torch.mm(transform_arr, old_perm), transform_arr.T) 191 | 192 | return new_perm 193 | 194 | def __getitem__(self, index): 195 | n_vertices = self.CFG.N_VERTICES 196 | img_id = self.image_ids[index] 197 | img = self.coco.loadImgs(img_id)[0] 198 | img_path = osp.join(self.image_dir, img["file_name"]) 199 | ann_ids = self.coco.getAnnIds(imgIds=img['id']) 200 | annotations = self.coco.loadAnns(ann_ids) # annotations of all instances in an image. 201 | 202 | image = np.array(Image.open(img_path).convert("RGB")) 203 | 204 | mask = np.zeros((img['width'], img['height'])) 205 | corner_coords = [] 206 | corner_mask = np.zeros((img['width'], img['height']), dtype=np.float32) 207 | perm_matrix = np.zeros((n_vertices, n_vertices), dtype=np.float32) 208 | for ins in annotations: 209 | segmentations = ins['segmentation'] 210 | for i, segm in enumerate(segmentations): 211 | segm = np.array(segm).reshape(-1, 2) 212 | segm[:, 0] = np.clip(segm[:, 0], 0, img['width'] - 1) 213 | segm[:, 1] = np.clip(segm[:, 1], 0, img['height'] - 1) 214 | points = segm[:-1] 215 | corner_coords.extend(points.tolist()) 216 | mask += self.coco.annToMask(ins) 217 | mask = mask / 255. if mask.max() == 255 else mask 218 | mask = np.clip(mask, 0, 1) 219 | 220 | # corner_coords = np.clip(np.array(corner_coords), 0, 299) 221 | corner_coords = np.flip(np.round(corner_coords, 0), axis=-1).astype(np.int32) 222 | 223 | if len(corner_coords) > 0.: 224 | corner_mask[corner_coords[:, 0], corner_coords[:, 1]] = 1. 225 | # corner_coords = (corner_coords / img['width']) * CFG.INPUT_WIDTH 226 | 227 | ############# START: Generate gt permutation matrix. ############# 228 | v_count = 0 229 | for ins in annotations: 230 | segmentations = ins['segmentation'] 231 | for idx, segm in enumerate(segmentations): 232 | segm = np.array(segm).reshape(-1, 2) 233 | points = segm[:-1] 234 | for i in range(len(points)): 235 | j = (i + 1) % len(points) 236 | if v_count+i > n_vertices - 1 or v_count+j > n_vertices-1: 237 | break 238 | perm_matrix[v_count+i, v_count+j] = 1. 239 | v_count += len(points) 240 | 241 | for i in range(v_count, n_vertices): 242 | perm_matrix[i, i] = 1. 243 | 244 | # Workaround for open contours: 245 | for i in range(n_vertices): 246 | row = perm_matrix[i, :] 247 | col = perm_matrix[:, i] 248 | if np.sum(row) == 0 or np.sum(col) == 0: 249 | perm_matrix[i, i] = 1. 250 | perm_matrix = torch.from_numpy(perm_matrix) 251 | ############# END: Generate gt permutation matrix. ############# 252 | 253 | masks = [mask, corner_mask] 254 | 255 | if len(corner_coords) > self.CFG.N_VERTICES: 256 | corner_coords = corner_coords[:self.CFG.N_VERTICES] 257 | 258 | if self.transform is not None: 259 | augmentations = self.transform(image=image, masks=masks, keypoints=corner_coords.tolist()) 260 | image = augmentations['image'] 261 | mask = augmentations['masks'][0] 262 | corner_mask = augmentations['masks'][1] 263 | corner_coords = np.array(augmentations['keypoints']) 264 | 265 | if self.tokenizer is not None: 266 | coords_seqs, rand_idxs = self.tokenizer(corner_coords, shuffle=self.shuffle_tokens) 267 | coords_seqs = torch.LongTensor(coords_seqs) 268 | # perm_matrix = torch.cat((perm_matrix[rand_idxs], perm_matrix[len(rand_idxs):])) 269 | if self.shuffle_tokens: 270 | perm_matrix = self.shuffle_perm_matrix_by_indices(perm_matrix, rand_idxs) 271 | else: 272 | coords_seqs = corner_coords 273 | 274 | return image, mask[None, ...], corner_mask[None, ...], coords_seqs, perm_matrix, torch.tensor([img['id']]) 275 | 276 | 277 | def whu_buildings_collate_fn(batch, max_len, pad_idx): 278 | """ 279 | if max_len: 280 | the sequences will all be padded to that length. 281 | """ 282 | 283 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch = [], [], [], [], [] 284 | for image, mask, c_mask, seq, perm_mat in batch: 285 | image_batch.append(image) 286 | mask_batch.append(mask) 287 | coords_mask_batch.append(c_mask) 288 | coords_seq_batch.append(seq) 289 | perm_matrix_batch.append(perm_mat) 290 | 291 | coords_seq_batch = pad_sequence( 292 | coords_seq_batch, 293 | padding_value=pad_idx, 294 | batch_first=True 295 | ) 296 | 297 | if max_len: 298 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long() 299 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1) 300 | 301 | image_batch = torch.stack(image_batch) 302 | mask_batch = torch.stack(mask_batch) 303 | coords_mask_batch = torch.stack(coords_mask_batch) 304 | perm_matrix_batch = torch.stack(perm_matrix_batch) 305 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch 306 | 307 | 308 | class WHUBuildingsCocoDatasetTest(Dataset): 309 | def __init__(self, image_dir, transform=None): 310 | self.image_dir = image_dir 311 | self.transform = transform 312 | self.images = [file for file in os.listdir(self.image_dir) if osp.isfile(osp.join(self.image_dir, file))] 313 | 314 | def __getitem__(self, index): 315 | img_path = osp.join(self.image_dir, self.images[index]) 316 | image = np.array(Image.open(img_path).convert("RGB")) 317 | 318 | if self.transform is not None: 319 | image = self.transform(image=image)['image'] 320 | 321 | image = torch.FloatTensor(image) 322 | return image 323 | 324 | def __len__(self): 325 | return len(self.images) 326 | -------------------------------------------------------------------------------- /ddp_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from torch import distributed as dist 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data.distributed import DistributedSampler 5 | from torch.nn.utils.rnn import pad_sequence 6 | import torch 7 | 8 | from datasets.dataset_inria_coco import InriaCocoDataset, InriaCocoDatasetTest 9 | from datasets.dataset_spacenet_coco import SpacenetCocoDataset, SpacenetCocoDatasetTest 10 | from datasets.dataset_whu_buildings_coco import WHUBuildingsCocoDataset, WHUBuildingsCocoDatasetTest 11 | from datasets.dataset_mass_roads import MassRoadsDataset, MassRoadsDatasetTest 12 | 13 | 14 | def is_dist_avail_and_initialized(): 15 | if not dist.is_available(): 16 | return False 17 | if not dist.is_initialized(): 18 | return False 19 | return True 20 | 21 | 22 | def get_rank(): 23 | if not is_dist_avail_and_initialized(): 24 | return 0 25 | return dist.get_rank() 26 | 27 | 28 | def is_main_process(): 29 | return get_rank() == 0 30 | 31 | 32 | def collate_fn(batch, max_len, pad_idx): 33 | """ 34 | if max_len: 35 | the sequences will all be padded to that length. 36 | """ 37 | 38 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch = [], [], [], [], [] 39 | for image, mask, c_mask, seq, perm_mat in batch: 40 | image_batch.append(image) 41 | mask_batch.append(mask) 42 | coords_mask_batch.append(c_mask) 43 | coords_seq_batch.append(seq) 44 | perm_matrix_batch.append(perm_mat) 45 | 46 | coords_seq_batch = pad_sequence( 47 | coords_seq_batch, 48 | padding_value=pad_idx, 49 | batch_first=True 50 | ) 51 | 52 | if max_len: 53 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long() 54 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1) 55 | 56 | image_batch = torch.stack(image_batch) 57 | mask_batch = torch.stack(mask_batch) 58 | coords_mask_batch = torch.stack(coords_mask_batch) 59 | perm_matrix_batch = torch.stack(perm_matrix_batch) 60 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch 61 | 62 | 63 | def get_inria_loaders( 64 | train_dataset_dir, 65 | val_dataset_dir, 66 | test_images_dir, 67 | tokenizer, 68 | max_len, 69 | pad_idx, 70 | shuffle_tokens, 71 | batch_size, 72 | train_transform, 73 | val_transform, 74 | num_workers=2, 75 | pin_memory=True 76 | ): 77 | 78 | train_ds = InriaCocoDataset( 79 | dataset_dir=train_dataset_dir, 80 | transform=train_transform, 81 | tokenizer=tokenizer, 82 | shuffle_tokens=shuffle_tokens 83 | ) 84 | 85 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True) 86 | 87 | train_loader = DataLoader( 88 | train_ds, 89 | batch_size=batch_size, 90 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 91 | sampler=train_sampler, 92 | num_workers=num_workers, 93 | pin_memory=pin_memory, 94 | drop_last=True 95 | ) 96 | 97 | valid_ds = InriaCocoDataset( 98 | dataset_dir=val_dataset_dir, 99 | transform=val_transform, 100 | tokenizer=tokenizer, 101 | shuffle_tokens=shuffle_tokens 102 | ) 103 | 104 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False) 105 | 106 | valid_loader = DataLoader( 107 | valid_ds, 108 | batch_size=batch_size, 109 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 110 | sampler=valid_sampler, 111 | num_workers=0, 112 | pin_memory=True, 113 | ) 114 | 115 | test_ds = InriaCocoDatasetTest( 116 | image_dir=test_images_dir, 117 | transform=val_transform 118 | ) 119 | 120 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False) 121 | 122 | test_loader = DataLoader( 123 | test_ds, 124 | batch_size=batch_size, 125 | sampler=test_sampler, 126 | num_workers=num_workers, 127 | pin_memory=pin_memory, 128 | ) 129 | 130 | return train_loader, valid_loader, test_loader 131 | 132 | 133 | def get_spacenet_loaders( 134 | train_dataset_dir, 135 | val_dataset_dir, 136 | test_images_dir, 137 | tokenizer, 138 | max_len, 139 | pad_idx, 140 | shuffle_tokens, 141 | batch_size, 142 | train_transform, 143 | val_transform, 144 | num_workers=2, 145 | pin_memory=True 146 | ): 147 | 148 | train_ds = SpacenetCocoDataset( 149 | dataset_dir=train_dataset_dir, 150 | transform=train_transform, 151 | tokenizer=tokenizer, 152 | shuffle_tokens=shuffle_tokens 153 | ) 154 | 155 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True) 156 | 157 | train_loader = DataLoader( 158 | train_ds, 159 | batch_size=batch_size, 160 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 161 | sampler=train_sampler, 162 | num_workers=num_workers, 163 | pin_memory=pin_memory, 164 | drop_last=True 165 | ) 166 | 167 | valid_ds = SpacenetCocoDataset( 168 | dataset_dir=val_dataset_dir, 169 | transform=val_transform, 170 | tokenizer=tokenizer, 171 | shuffle_tokens=shuffle_tokens 172 | ) 173 | 174 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False) 175 | 176 | valid_loader = DataLoader( 177 | valid_ds, 178 | batch_size=batch_size, 179 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 180 | sampler=valid_sampler, 181 | num_workers=0, 182 | pin_memory=True, 183 | ) 184 | 185 | test_ds = SpacenetCocoDatasetTest( 186 | image_dir=test_images_dir, 187 | transform=val_transform 188 | ) 189 | 190 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False) 191 | 192 | test_loader = DataLoader( 193 | test_ds, 194 | batch_size=batch_size, 195 | sampler=test_sampler, 196 | num_workers=num_workers, 197 | pin_memory=pin_memory, 198 | ) 199 | 200 | return train_loader, valid_loader, test_loader 201 | 202 | 203 | def get_whu_buildings_loaders( 204 | train_dataset_dir, 205 | val_dataset_dir, 206 | test_images_dir, 207 | tokenizer, 208 | max_len, 209 | pad_idx, 210 | shuffle_tokens, 211 | batch_size, 212 | train_transform, 213 | val_transform, 214 | num_workers=2, 215 | pin_memory=True 216 | ): 217 | 218 | train_ds = WHUBuildingsCocoDataset( 219 | dataset_dir=train_dataset_dir, 220 | transform=train_transform, 221 | tokenizer=tokenizer, 222 | shuffle_tokens=shuffle_tokens 223 | ) 224 | 225 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True) 226 | 227 | train_loader = DataLoader( 228 | train_ds, 229 | batch_size=batch_size, 230 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 231 | sampler=train_sampler, 232 | num_workers=num_workers, 233 | pin_memory=pin_memory, 234 | drop_last=True 235 | ) 236 | 237 | valid_ds = WHUBuildingsCocoDataset( 238 | dataset_dir=val_dataset_dir, 239 | transform=val_transform, 240 | tokenizer=tokenizer, 241 | shuffle_tokens=shuffle_tokens 242 | ) 243 | 244 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False) 245 | 246 | valid_loader = DataLoader( 247 | valid_ds, 248 | batch_size=batch_size, 249 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 250 | sampler=valid_sampler, 251 | num_workers=0, 252 | pin_memory=True, 253 | ) 254 | 255 | test_ds = WHUBuildingsCocoDatasetTest( 256 | image_dir=test_images_dir, 257 | transform=val_transform 258 | ) 259 | 260 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False) 261 | 262 | test_loader = DataLoader( 263 | test_ds, 264 | batch_size=batch_size, 265 | sampler=test_sampler, 266 | num_workers=num_workers, 267 | pin_memory=pin_memory, 268 | ) 269 | 270 | return train_loader, valid_loader, test_loader 271 | 272 | 273 | def get_mass_roads_loaders( 274 | train_dataset_dir, 275 | val_dataset_dir, 276 | test_images_dir, 277 | tokenizer, 278 | max_len, 279 | pad_idx, 280 | shuffle_tokens, 281 | batch_size, 282 | train_transform, 283 | val_transform, 284 | num_workers=2, 285 | pin_memory=True 286 | ): 287 | 288 | train_ds = MassRoadsDataset( 289 | dataset_dir=train_dataset_dir, 290 | transform=train_transform, 291 | tokenizer=tokenizer, 292 | shuffle_tokens=shuffle_tokens 293 | ) 294 | 295 | train_sampler = DistributedSampler(dataset=train_ds, shuffle=True) 296 | 297 | train_loader = DataLoader( 298 | train_ds, 299 | batch_size=batch_size, 300 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 301 | sampler=train_sampler, 302 | num_workers=num_workers, 303 | pin_memory=pin_memory, 304 | drop_last=True 305 | ) 306 | 307 | valid_ds = MassRoadsDataset( 308 | dataset_dir=val_dataset_dir, 309 | transform=val_transform, 310 | tokenizer=tokenizer, 311 | shuffle_tokens=shuffle_tokens 312 | ) 313 | 314 | valid_sampler = DistributedSampler(dataset=valid_ds, shuffle=False) 315 | 316 | valid_loader = DataLoader( 317 | valid_ds, 318 | batch_size=batch_size, 319 | collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx), 320 | sampler=valid_sampler, 321 | num_workers=0, 322 | pin_memory=True, 323 | ) 324 | 325 | test_ds = MassRoadsDatasetTest( 326 | image_dir=test_images_dir, 327 | transform=val_transform 328 | ) 329 | 330 | test_sampler = DistributedSampler(dataset=test_ds, shuffle=False) 331 | 332 | test_loader = DataLoader( 333 | test_ds, 334 | batch_size=batch_size, 335 | sampler=test_sampler, 336 | num_workers=num_workers, 337 | pin_memory=pin_memory, 338 | ) 339 | 340 | return train_loader, valid_loader, test_loader 341 | 342 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | 4 | from utils import ( 5 | AverageMeter, 6 | get_lr, 7 | save_checkpoint, 8 | save_single_predictions_as_images 9 | ) 10 | from config import CFG 11 | 12 | from ddp_utils import is_main_process 13 | 14 | 15 | def train_one_epoch(epoch, iter_idx, model, train_loader, optimizer, lr_scheduler, vertex_loss_fn, perm_loss_fn, writer): 16 | model.train() 17 | vertex_loss_fn.train() 18 | perm_loss_fn.train() 19 | 20 | loss_meter = AverageMeter() 21 | vertex_loss_meter = AverageMeter() 22 | perm_loss_meter = AverageMeter() 23 | 24 | loader = train_loader 25 | if is_main_process(): 26 | loader = tqdm(train_loader, total=len(train_loader)) 27 | 28 | # prof = torch.profiler.profile( 29 | # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), 30 | # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"runs/{CFG.EXPERIMENT_NAME}/logs/profiler"), 31 | # record_shapes=True, 32 | # with_stack=True 33 | # ) 34 | # prof.start() 35 | for x, y_mask, y_corner_mask, y, y_perm in loader: 36 | x = x.to(CFG.DEVICE, non_blocking=True) 37 | y = y.to(CFG.DEVICE, non_blocking=True) 38 | y_perm = y_perm.to(CFG.DEVICE, non_blocking=True) 39 | 40 | y_input = y[:, :-1] 41 | y_expected = y[:, 1:] 42 | 43 | preds, perm_mat = model(x, y_input) 44 | 45 | if epoch < CFG.MILESTONE: 46 | vertex_loss_weight = CFG.vertex_loss_weight 47 | perm_loss_weight = 0.0 48 | else: 49 | vertex_loss_weight = CFG.vertex_loss_weight 50 | perm_loss_weight = CFG.perm_loss_weight 51 | 52 | vertex_loss = vertex_loss_weight*vertex_loss_fn(preds.reshape(-1, preds.shape[-1]), y_expected.reshape(-1)) 53 | perm_loss = perm_loss_weight*perm_loss_fn(perm_mat, y_perm) 54 | 55 | loss = vertex_loss + perm_loss 56 | 57 | optimizer.zero_grad(set_to_none=True) 58 | loss.backward() 59 | # nn.utils.clip_grad_norm_(model.module.parameters(), max_norm=0.1) 60 | optimizer.step() 61 | 62 | if lr_scheduler is not None: 63 | lr_scheduler.step() 64 | 65 | loss_meter.update(loss.item(), x.size(0)) 66 | vertex_loss_meter.update(vertex_loss.item(), x.size(0)) 67 | perm_loss_meter.update(perm_loss.item(), x.size(0)) 68 | 69 | lr = get_lr(optimizer) 70 | if is_main_process(): 71 | loader.set_postfix(train_loss=loss_meter.avg, lr=f"{lr:.5f}") 72 | writer.add_scalar('Running_logs/Train_Loss', loss_meter.avg, iter_idx) 73 | writer.add_scalar('Running_logs/LR', lr, iter_idx) 74 | # writer.add_image(f"Running_logs/input_images", torchvision.utils.make_grid(x), iter_idx) 75 | # writer.add_graph(model, input_to_model=(x, y_input)) 76 | 77 | iter_idx += 1 78 | # prof.step() 79 | # prof.stop() 80 | print(f"Total train loss: {loss_meter.avg}\n\n") 81 | loss_dict = { 82 | 'total_loss': loss_meter.avg, 83 | 'vertex_loss': vertex_loss_meter.avg, 84 | 'perm_loss': perm_loss_meter.avg, 85 | } 86 | 87 | return loss_dict, iter_idx 88 | 89 | 90 | def valid_one_epoch(epoch, model, valid_loader, vertex_loss_fn, perm_loss_fn): 91 | print(f"\nValidating...") 92 | model.eval() 93 | vertex_loss_fn.eval() 94 | perm_loss_fn.eval() 95 | 96 | loss_meter = AverageMeter() 97 | vertex_loss_meter = AverageMeter() 98 | perm_loss_meter = AverageMeter() 99 | 100 | loader = valid_loader 101 | if is_main_process(): 102 | loader = tqdm(valid_loader, total=len(valid_loader)) 103 | 104 | with torch.no_grad(): 105 | for x, y_mask, y_corner_mask, y, y_perm in loader: 106 | x = x.to(CFG.DEVICE, non_blocking=True) 107 | y = y.to(CFG.DEVICE, non_blocking=True) 108 | y_perm = y_perm.to(CFG.DEVICE, non_blocking=True) 109 | 110 | y_input = y[:, :-1] 111 | y_expected = y[:, 1:] 112 | 113 | preds, perm_mat = model(x, y_input) 114 | 115 | if epoch < CFG.MILESTONE: 116 | vertex_loss_weight = CFG.vertex_loss_weight 117 | perm_loss_weight = 0.0 118 | else: 119 | vertex_loss_weight = CFG.vertex_loss_weight 120 | perm_loss_weight = CFG.perm_loss_weight 121 | vertex_loss = vertex_loss_weight*vertex_loss_fn(preds.reshape(-1, preds.shape[-1]), y_expected.reshape(-1)) 122 | perm_loss = perm_loss_weight*perm_loss_fn(perm_mat, y_perm) 123 | 124 | loss = vertex_loss + perm_loss 125 | 126 | loss_meter.update(loss.item(), x.size(0)) 127 | vertex_loss_meter.update(vertex_loss.item(), x.size(0)) 128 | perm_loss_meter.update(perm_loss.item(), x.size(0)) 129 | 130 | loss_dict = { 131 | 'total_loss': loss_meter.avg, 132 | 'vertex_loss': vertex_loss_meter.avg, 133 | 'perm_loss': perm_loss_meter.avg, 134 | } 135 | 136 | return loss_dict 137 | 138 | 139 | def train_eval( 140 | model, 141 | train_loader, 142 | valid_loader, 143 | test_loader, 144 | tokenizer, 145 | vertex_loss_fn, 146 | perm_loss_fn, 147 | optimizer, 148 | lr_scheduler, 149 | step, 150 | writer 151 | ): 152 | best_loss = float('inf') 153 | best_metric = float('-inf') 154 | 155 | iter_idx=CFG.START_EPOCH * len(train_loader) 156 | epoch_iterator = range(CFG.START_EPOCH, CFG.NUM_EPOCHS) 157 | if is_main_process(): 158 | epoch_iterator = tqdm(epoch_iterator) 159 | for epoch in epoch_iterator: 160 | if is_main_process(): 161 | print(f"\n\nEPOCH: {epoch + 1}\n\n") 162 | 163 | if CFG.TRAIN_DDP: 164 | train_loader.sampler.set_epoch(epoch) 165 | valid_loader.sampler.set_epoch(epoch) 166 | test_loader.sampler.set_epoch(epoch) 167 | 168 | train_loss_dict, iter_idx = train_one_epoch( 169 | epoch, 170 | iter_idx, 171 | model, 172 | train_loader, 173 | optimizer, 174 | lr_scheduler if step=='batch' else None, 175 | vertex_loss_fn, 176 | perm_loss_fn, 177 | writer 178 | ) 179 | if is_main_process(): 180 | writer.add_scalar('Train_Losses/Total_Loss', train_loss_dict['total_loss'], epoch) 181 | writer.add_scalar('Train_Losses/Vertex_Loss', train_loss_dict['vertex_loss'], epoch) 182 | writer.add_scalar('Train_Losses/Perm_Loss', train_loss_dict['perm_loss'], epoch) 183 | 184 | valid_loss_dict = valid_one_epoch( 185 | epoch, 186 | model, 187 | valid_loader, 188 | vertex_loss_fn, 189 | perm_loss_fn, 190 | ) # TODO: add eval metrics to validation function? 191 | if is_main_process(): 192 | print(f"Valid loss: {valid_loss_dict['total_loss']:.3f}\n\n") 193 | 194 | if step=='epoch': 195 | pass 196 | 197 | # Save best validation loss epoch. 198 | if valid_loss_dict['total_loss'] < best_loss and CFG.SAVE_BEST and is_main_process(): 199 | best_loss = valid_loss_dict['total_loss'] 200 | checkpoint = { 201 | "state_dict": model.module.state_dict(), 202 | "optimizer": optimizer.state_dict(), 203 | "scheduler": lr_scheduler.state_dict(), 204 | "epochs_run": epoch, 205 | "loss": train_loss_dict["total_loss"] 206 | } 207 | save_checkpoint( 208 | checkpoint, 209 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/", 210 | filename="best_valid_loss.pth" 211 | ) 212 | # torch.save(model.state_dict(), 'best_valid_loss.pth') 213 | print(f"Saved best val loss model.") 214 | 215 | # Save latest checkpoint every epoch. 216 | if CFG.SAVE_LATEST and is_main_process(): 217 | checkpoint = { 218 | "state_dict": model.module.state_dict(), 219 | "optimizer": optimizer.state_dict(), 220 | "scheduler": lr_scheduler.state_dict(), 221 | "epochs_run": epoch, 222 | "loss": train_loss_dict["total_loss"] 223 | } 224 | save_checkpoint( 225 | checkpoint, 226 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/", 227 | filename="latest.pth" 228 | ) 229 | 230 | if (epoch + 1) % CFG.SAVE_EVERY == 0 and is_main_process(): 231 | checkpoint = { 232 | "state_dict": model.module.state_dict(), 233 | "optimizer": optimizer.state_dict(), 234 | "scheduler": lr_scheduler.state_dict(), 235 | "epochs_run": epoch, 236 | "loss": train_loss_dict["total_loss"] 237 | } 238 | save_checkpoint( 239 | checkpoint, 240 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/", 241 | filename=f"epoch_{epoch}.pth" 242 | ) 243 | 244 | if is_main_process(): 245 | writer.add_scalar('Val_Losses/Total_Loss', valid_loss_dict['total_loss'], epoch) 246 | writer.add_scalar('Val_Losses/Vertex_Loss', valid_loss_dict['vertex_loss'], epoch) 247 | writer.add_scalar('Val_Losses/Perm_Loss', valid_loss_dict['perm_loss'], epoch) 248 | 249 | # output examples to a folder 250 | if (epoch + 1) % CFG.VAL_EVERY == 0 and is_main_process(): 251 | val_metrics_dict = save_single_predictions_as_images( 252 | test_loader, 253 | model, 254 | tokenizer, 255 | epoch, 256 | writer, 257 | folder=f"runs/{CFG.EXPERIMENT_NAME}/runtime_outputs/", 258 | device=CFG.DEVICE 259 | ) 260 | for metric, value in zip(val_metrics_dict.keys(), val_metrics_dict.values()): 261 | print(f"{metric}: {value}") 262 | 263 | # Save best single batch validation metric epoch. 264 | if val_metrics_dict["miou"] > best_metric and CFG.SAVE_BEST and is_main_process(): 265 | best_metric = val_metrics_dict["miou"] 266 | checkpoint = { 267 | "state_dict": model.module.state_dict(), 268 | "optimizer": optimizer.state_dict(), 269 | "scheduler": lr_scheduler.state_dict(), 270 | "epochs_run": epoch, 271 | "loss": train_loss_dict["total_loss"] 272 | } 273 | save_checkpoint( 274 | checkpoint, 275 | folder=f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/", 276 | filename="best_valid_metric.pth" 277 | ) 278 | print(f"Saved best val metric model.") 279 | 280 | -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeshwanth95/Pix2Poly/d2e288b4e267ae910cee6f54bece8eec98f43bea/eval/__init__.py -------------------------------------------------------------------------------- /eval/hisup_eval_utils/metrics/cIoU.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the code from https://github.com/zorzi-s/PolyWorldPretrainedNetwork. 3 | @article{zorzi2021polyworld, 4 | title={PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images}, 5 | author={Zorzi, Stefano and Bazrafkan, Shabab and Habenschuss, Stefan and Fraundorfer, Friedrich}, 6 | journal={arXiv preprint arXiv:2111.15491}, 7 | year={2021} 8 | } 9 | """ 10 | 11 | from pycocotools.coco import COCO 12 | from pycocotools import mask as cocomask 13 | import numpy as np 14 | import json 15 | import argparse 16 | from tqdm import tqdm 17 | 18 | def calc_IoU(a, b): 19 | i = np.logical_and(a, b) 20 | u = np.logical_or(a, b) 21 | I = np.sum(i) 22 | U = np.sum(u) 23 | 24 | iou = I/(U + 1e-9) 25 | 26 | is_void = U == 0 27 | if is_void: 28 | return 1.0 29 | else: 30 | return iou 31 | 32 | def compute_IoU_cIoU(input_json, gti_annotations): 33 | # Ground truth annotations 34 | coco_gti = COCO(gti_annotations) 35 | 36 | # Predictions annotations 37 | submission_file = json.loads(open(input_json).read()) 38 | coco = COCO(gti_annotations) 39 | coco = coco.loadRes(submission_file) 40 | 41 | 42 | image_ids = coco.getImgIds(catIds=coco.getCatIds()) 43 | bar = tqdm(image_ids) 44 | 45 | list_iou = [] 46 | list_ciou = [] 47 | pss = [] 48 | rel_difs = [] 49 | n_ratios = [] 50 | for image_id in bar: 51 | 52 | img = coco.loadImgs(image_id)[0] 53 | 54 | annotation_ids = coco.getAnnIds(imgIds=img['id']) 55 | annotations = coco.loadAnns(annotation_ids) 56 | N = 0 57 | for _idx, annotation in enumerate(annotations): 58 | try: 59 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width']) 60 | except Exception: 61 | import ipdb; ipdb.set_trace() 62 | m = cocomask.decode(rle) 63 | if _idx == 0: 64 | mask = m.reshape((img['height'], img['width'])) 65 | N = len(annotation['segmentation'][0]) // 2 66 | else: 67 | mask = mask + m.reshape((img['height'], img['width'])) 68 | N = N + len(annotation['segmentation'][0]) // 2 69 | 70 | mask = mask != 0 71 | 72 | 73 | annotation_ids = coco_gti.getAnnIds(imgIds=img['id']) 74 | annotations = coco_gti.loadAnns(annotation_ids) 75 | N_GT = 0 76 | for _idx, annotation in enumerate(annotations): 77 | if any(annotation['segmentation']): 78 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width']) 79 | m = cocomask.decode(rle) 80 | else: 81 | annotation['segmentation'] = [[]] 82 | m = np.zeros((img['height'], img['width'])) 83 | if m.ndim > 2: 84 | m = np.clip(0, 1, m.sum(axis=-1)) 85 | if _idx == 0: 86 | mask_gti = m.reshape((img['height'], img['width'])) 87 | N_GT = len(annotation['segmentation'][0]) // 2 88 | else: 89 | mask_gti = mask_gti + m.reshape((img['height'], img['width'])) 90 | N_GT = N_GT + len(annotation['segmentation'][0]) // 2 91 | 92 | mask_gti = mask_gti != 0 93 | 94 | ps = 1 - np.abs(N - N_GT) / (N + N_GT + 1e-9) 95 | rel_dif = np.abs(N - N_GT) / (N + N_GT + 1e-9) 96 | iou = calc_IoU(mask, mask_gti) 97 | list_iou.append(iou) 98 | list_ciou.append(iou * ps) 99 | pss.append(ps) 100 | rel_difs.append(rel_dif) 101 | if N_GT > 0: 102 | nr = N / N_GT 103 | n_ratios.append(nr) 104 | 105 | bar.set_description("iou: %2.4f, c-iou: %2.4f, ps:%2.4f, rd:%2.4f" % (np.mean(list_iou), np.mean(list_ciou), np.mean(pss), np.mean(rel_difs))) 106 | bar.refresh() 107 | 108 | print("Done!") 109 | print("Mean IoU: ", np.mean(list_iou)) 110 | print("Mean C-IoU: ", np.mean(list_ciou)) 111 | print("Mean N-Relative Difference: ", np.mean(rel_difs)) 112 | print("Mean N-Ratio: ", np.mean(n_ratios)) 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument("--gt-file", default="") 119 | parser.add_argument("--dt-file", default="") 120 | args = parser.parse_args() 121 | 122 | gt_file = args.gt_file 123 | dt_file = args.dt_file 124 | compute_IoU_cIoU(input_json=dt_file, 125 | gti_annotations=gt_file) 126 | -------------------------------------------------------------------------------- /eval/hisup_eval_utils/metrics/polis.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is adopted from https://github.com/spgriffin/polis 3 | """ 4 | 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | from pycocotools import mask as maskUtils 10 | from shapely import geometry 11 | from shapely.geometry import Polygon 12 | 13 | 14 | def bounding_box(points): 15 | """returns a list containing the bottom left and the top right 16 | points in the sequence 17 | Here, we traverse the collection of points only once, 18 | to find the min and max for x and y 19 | """ 20 | bot_left_x, bot_left_y = float('inf'), float('inf') 21 | top_right_x, top_right_y = float('-inf'), float('-inf') 22 | for x, y in points: 23 | bot_left_x = min(bot_left_x, x) 24 | bot_left_y = min(bot_left_y, y) 25 | top_right_x = max(top_right_x, x) 26 | top_right_y = max(top_right_y, y) 27 | 28 | return [bot_left_x, bot_left_y, top_right_x - bot_left_x, top_right_y - bot_left_y] 29 | 30 | def compare_polys(poly_a, poly_b): 31 | """Compares two polygons via the "polis" distance metric. 32 | See "A Metric for Polygon Comparison and Building Extraction 33 | Evaluation" by J. Avbelj, et al. 34 | Input: 35 | poly_a: A Shapely polygon. 36 | poly_b: Another Shapely polygon. 37 | Returns: 38 | The "polis" distance between these two polygons. 39 | """ 40 | bndry_a, bndry_b = poly_a.exterior, poly_b.exterior 41 | dist = polis(bndry_a.coords, bndry_b) 42 | dist += polis(bndry_b.coords, bndry_a) 43 | return dist 44 | 45 | 46 | def polis(coords, bndry): 47 | """Computes one side of the "polis" metric. 48 | Input: 49 | coords: A Shapley coordinate sequence (presumably the vertices 50 | of a polygon). 51 | bndry: A Shapely linestring (presumably the boundary of 52 | another polygon). 53 | 54 | Returns: 55 | The "polis" metric for this pair. You usually compute this in 56 | both directions to preserve symmetry. 57 | """ 58 | sum = 0.0 59 | for pt in (geometry.Point(c) for c in coords[:-1]): # Skip the last point (same as first) 60 | sum += bndry.distance(pt) 61 | return sum/float(2*len(coords)) 62 | 63 | 64 | class PolisEval(): 65 | 66 | def __init__(self, cocoGt=None, cocoDt=None): 67 | self.cocoGt = cocoGt 68 | self.cocoDt = cocoDt 69 | self.evalImgs = defaultdict(list) 70 | self.eval = {} 71 | self._gts = defaultdict(list) 72 | self._dts = defaultdict(list) 73 | self.stats = [] 74 | self.imgIds = list(sorted(self.cocoGt.imgs.keys())) 75 | 76 | def _prepare(self): 77 | gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=self.imgIds)) 78 | dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=self.imgIds)) 79 | self._gts = defaultdict(list) # gt for evaluation 80 | self._dts = defaultdict(list) # dt for evaluation 81 | for gt in gts: 82 | self._gts[gt['image_id']].append(gt) 83 | for dt in dts: 84 | self._dts[dt['image_id']].append(dt) 85 | self.evalImgs = defaultdict(list) # per-image per-category evaluation results 86 | self.eval = {} # accumulated evaluation results 87 | 88 | def evaluateImg(self, imgId): 89 | gts = self._gts[imgId] 90 | dts = self._dts[imgId] 91 | 92 | if len(gts) == 0 or len(dts) == 0: 93 | return 0 94 | 95 | for gt in gts: 96 | if not any(gt['segmentation']): 97 | gt['segmentation'] = [[]] 98 | gt_bboxs = [bounding_box(np.array(gt['segmentation'][0]).reshape(-1,2)) for gt in gts] 99 | dt_bboxs = [bounding_box(np.array(dt['segmentation'][0]).reshape(-1,2)) for dt in dts] 100 | gt_polygons = [np.array(gt['segmentation'][0]).reshape(-1,2) for gt in gts] 101 | dt_polygons = [np.array(dt['segmentation'][0]).reshape(-1,2) for dt in dts] 102 | 103 | # IoU match 104 | iscrowd = [0] * len(gt_bboxs) 105 | # ious = maskUtils.iou(gt_bboxs, dt_bboxs, iscrowd) 106 | ious = maskUtils.iou(dt_bboxs, gt_bboxs, iscrowd) 107 | 108 | # compute polis 109 | img_polis_avg = 0 110 | num_sample = 0 111 | for i, gt_poly in enumerate(gt_polygons): 112 | matched_idx = np.argmax(ious[:,i]) 113 | iou = ious[matched_idx, i] 114 | if iou > 0.5: # iouThres: 115 | polis = compare_polys(Polygon(gt_poly), Polygon(dt_polygons[matched_idx])) 116 | img_polis_avg += polis 117 | num_sample += 1 118 | 119 | if num_sample == 0: 120 | return 0 121 | else: 122 | return img_polis_avg / num_sample 123 | 124 | 125 | def evaluate(self): 126 | self._prepare() 127 | polis_tot = 0 128 | 129 | num_valid_imgs = 0 130 | for imgId in tqdm(self.imgIds): 131 | img_polis_avg = self.evaluateImg(imgId) 132 | 133 | if img_polis_avg == 0: 134 | continue 135 | else: 136 | polis_tot += img_polis_avg 137 | num_valid_imgs += 1 138 | 139 | polis_avg = polis_tot / num_valid_imgs 140 | 141 | print('average polis: %f' % (polis_avg)) 142 | 143 | return polis_avg 144 | 145 | -------------------------------------------------------------------------------- /eval/topdig_eval_utils/metrics/topdig_metrics.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | from pycocotools import mask as cocomask 3 | import numpy as np 4 | import json 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import cv2 10 | 11 | from eval.hisup_eval_utils.metrics.cIoU import calc_IoU 12 | from torchmetrics.functional.classification import binary_accuracy, binary_f1_score 13 | 14 | 15 | def calc_f1score(mask: np.ndarray, mask_gti: np.ndarray): 16 | union = np.logical_or(mask, mask_gti) 17 | U = np.sum(union) 18 | is_void = U == 0 19 | mask = torch.from_numpy(mask) 20 | mask_gti = torch.from_numpy(mask_gti) 21 | 22 | if is_void: 23 | return 1.0 24 | else: 25 | return binary_f1_score(preds=mask, target=mask_gti) 26 | 27 | 28 | def calc_acc(mask: np.ndarray, mask_gti: np.ndarray): 29 | union = np.logical_or(mask, mask_gti) 30 | U = np.sum(union) 31 | is_void = U == 0 32 | mask = torch.from_numpy(mask) 33 | mask_gti = torch.from_numpy(mask_gti) 34 | 35 | if is_void: 36 | return 1.0 37 | else: 38 | return binary_accuracy(preds=mask, target=mask_gti) 39 | 40 | 41 | def compute_mask_metrics(input_json, gti_annotations): 42 | # Ground truth annotations 43 | coco_gti = COCO(gti_annotations) 44 | 45 | # Predictions annotations 46 | submission_file = json.loads(open(input_json).read()) 47 | coco = COCO(gti_annotations) 48 | coco = coco.loadRes(submission_file) 49 | 50 | 51 | image_ids = coco.getImgIds(catIds=coco.getCatIds()) 52 | bar = tqdm(image_ids) 53 | 54 | buffer_thickness = 5 # dilation factor same as that used in TopDiG. 55 | 56 | list_acc = [] 57 | list_f1 = [] 58 | list_iou = [] 59 | 60 | list_acc_topo = [] 61 | list_f1_topo = [] 62 | list_iou_topo = [] 63 | 64 | for image_id in bar: 65 | 66 | img = coco.loadImgs(image_id)[0] 67 | 68 | # Predictions 69 | annotation_ids = coco.getAnnIds(imgIds=img['id']) 70 | annotations = coco.loadAnns(annotation_ids) 71 | topo_mask = np.zeros((img['height'], img['width'])) 72 | poly_lines = [] 73 | for _idx, annotation in enumerate(annotations): 74 | try: 75 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width']) 76 | except Exception: 77 | import ipdb; ipdb.set_trace() 78 | m = cocomask.decode(rle) 79 | if _idx == 0: 80 | mask = m.reshape((img['height'], img['width'])) 81 | else: 82 | mask = mask + m.reshape((img['height'], img['width'])) 83 | for ann in annotation['segmentation']: 84 | ann = np.array(ann).reshape(-1, 2) 85 | ann = np.round(ann).astype(np.int32) 86 | poly_lines.append(ann) 87 | cv2.polylines(topo_mask, poly_lines, isClosed=True, color=1., thickness=buffer_thickness) 88 | 89 | mask = mask != 0 90 | topo_mask = (topo_mask != 0).astype(np.float32) 91 | 92 | 93 | # Ground truth 94 | annotation_ids = coco_gti.getAnnIds(imgIds=img['id']) 95 | annotations = coco_gti.loadAnns(annotation_ids) 96 | topo_mask_gt = np.zeros((img['height'], img['width'])) 97 | poly_lines_gt = [] 98 | for _idx, annotation in enumerate(annotations): 99 | if any(annotation['segmentation']): 100 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width']) 101 | m = cocomask.decode(rle) 102 | else: 103 | annotation['segmentation'] = [[]] 104 | m = np.zeros((img['height'], img['width'])) 105 | if m.ndim > 2: 106 | m = np.clip(0, 1, m.sum(axis=-1)) 107 | if _idx == 0: 108 | mask_gti = m.reshape((img['height'], img['width'])) 109 | else: 110 | mask_gti = mask_gti + m.reshape((img['height'], img['width'])) 111 | for ann in annotation['segmentation']: 112 | ann = np.array(ann).reshape(-1, 2) 113 | ann = np.round(ann).astype(np.int32) 114 | poly_lines_gt.append(ann) 115 | cv2.polylines(topo_mask_gt, poly_lines_gt, isClosed=True, color=1., thickness=buffer_thickness) 116 | 117 | mask_gti = mask_gti != 0 118 | topo_mask_gt = (topo_mask_gt != 0).astype(np.float32) 119 | 120 | 121 | pacc = calc_acc(mask, mask_gti) 122 | list_acc.append(pacc) 123 | f1score = calc_f1score(mask, mask_gti) 124 | list_f1.append(f1score) 125 | iou = calc_IoU(mask, mask_gti) 126 | list_iou.append(iou) 127 | 128 | pacc_topo = calc_acc(topo_mask, topo_mask_gt) 129 | list_acc_topo.append(pacc_topo) 130 | f1score_topo = calc_f1score(topo_mask, topo_mask_gt) 131 | list_f1_topo.append(f1score_topo) 132 | iou_topo = calc_IoU(topo_mask, topo_mask_gt) 133 | list_iou_topo.append(iou_topo) 134 | 135 | bar.set_description("iou: %2.4f, p-acc: %2.4f, f1:%2.4f, iou-topo: %2.4f, p-acc-topo: %2.4f, f1-topo:%2.4f " % (np.mean(list_iou), np.mean(list_acc), np.mean(list_f1), np.mean(list_iou_topo), np.mean(list_acc_topo), np.mean(list_f1_topo))) 136 | bar.refresh() 137 | 138 | print("Done!") 139 | print("Mean IoU: ", np.mean(list_iou)) 140 | print("Mean P-Acc: ", np.mean(list_acc)) 141 | print("Mean F1-Score: ", np.mean(list_f1)) 142 | print("Mean IoU-Topo: ", np.mean(list_iou_topo)) 143 | print("Mean P-Acc-Topo: ", np.mean(list_acc_topo)) 144 | print("Mean F1-Score-Topo: ", np.mean(list_f1_topo)) 145 | 146 | 147 | if __name__ == "__main__": 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument("--gt-file", default="") 150 | parser.add_argument("--dt-file", default="") 151 | args = parser.parse_args() 152 | 153 | gt_file = args.gt_file 154 | dt_file = args.dt_file 155 | compute_mask_metrics(input_json=dt_file, 156 | gti_annotations=gt_file) 157 | -------------------------------------------------------------------------------- /evaluate_mass_roads_predictions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch 7 | import cv2 8 | 9 | from eval.hisup_eval_utils.metrics.cIoU import calc_IoU 10 | from torchmetrics.functional.classification import binary_accuracy, binary_f1_score 11 | 12 | def calc_f1score(mask: np.ndarray, mask_gti: np.ndarray): 13 | mask = torch.from_numpy(mask) 14 | mask_gti = torch.from_numpy(mask_gti) 15 | return binary_f1_score(preds=mask, target=mask_gti) 16 | 17 | 18 | def calc_acc(mask: np.ndarray, mask_gti: np.ndarray): 19 | mask = torch.from_numpy(mask) 20 | mask_gti = torch.from_numpy(mask_gti) 21 | return binary_accuracy(preds=mask, target=mask_gti) 22 | 23 | 24 | def compute_mask_metrics(predictions_dir, gt_dir): 25 | # Ground truth annotations 26 | # gt_masks = os.listdir(gt_dir) 27 | 28 | # Predictions annotations 29 | pred_masks = os.listdir(predictions_dir) 30 | 31 | 32 | images = pred_masks 33 | bar = tqdm(images) 34 | 35 | list_acc_topo = [] 36 | list_f1_topo = [] 37 | list_iou_topo = [] 38 | 39 | for image_id in bar: 40 | 41 | # img = cv2.imread(os.path.join(predictions_dir, image_id)) 42 | 43 | # Predictions 44 | topo_mask = cv2.imread(os.path.join(predictions_dir, image_id)) 45 | topo_mask = (topo_mask != 0).astype(np.float32) 46 | 47 | # Ground truth 48 | topo_mask_gt = cv2.imread(os.path.join(gt_dir, f"{image_id.split('.')[0]}.tif")) 49 | topo_mask_gt = (topo_mask_gt != 0).astype(np.float32) 50 | 51 | # Standard Torchmetrics Implementation 52 | pacc_orig = calc_acc(topo_mask, topo_mask_gt) 53 | list_acc_topo.append(pacc_orig) 54 | iou_orig = calc_IoU(topo_mask, topo_mask_gt) 55 | list_iou_topo.append(iou_orig) 56 | f1score_orig = calc_f1score(topo_mask, topo_mask_gt) 57 | list_f1_topo.append(f1score_orig) 58 | 59 | bar.set_description("iou-topo: %2.4f, p-acc-topo: %2.4f, f1-topo:%2.4f " % (np.mean(list_iou_topo), np.mean(list_acc_topo), np.mean(list_f1_topo))) 60 | bar.refresh() 61 | 62 | print("Done!") 63 | print("############## TOPO METRICS ############") 64 | print("Mean IoU-Topo: ", np.mean(list_iou_topo)) 65 | print("Mean P-Acc-Topo: ", np.mean(list_acc_topo)) 66 | print("Mean F1-Score-Topo: ", np.mean(list_f1_topo)) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--gt-dir", default="") 72 | parser.add_argument("--dt-dir", default="") 73 | args = parser.parse_args() 74 | 75 | gt_dir = args.gt_dir 76 | dt_dir = args.dt_dir 77 | compute_mask_metrics(predictions_dir=dt_dir, 78 | gt_dir=gt_dir) 79 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # Borrowed from https://github.com/SarahwXU/HiSup/blob/main/tools/evaluation.py 2 | 3 | import argparse 4 | 5 | from multiprocess import Pool 6 | from pycocotools.coco import COCO 7 | from pycocotools.cocoeval import COCOeval 8 | from eval.hisup_eval_utils.metrics.polis import PolisEval 9 | from eval.hisup_eval_utils.metrics.angle_eval import ContourEval 10 | from eval.hisup_eval_utils.metrics.cIoU import compute_IoU_cIoU 11 | from eval.topdig_eval_utils.metrics.topdig_metrics import compute_mask_metrics 12 | 13 | 14 | def coco_eval(annFile, resFile): 15 | type=1 16 | annType = ['bbox', 'segm'] 17 | print('Running demo for *%s* results.' % (annType[type])) 18 | 19 | cocoGt = COCO(annFile) 20 | cocoDt = cocoGt.loadRes(resFile) 21 | 22 | imgIds = cocoGt.getImgIds() 23 | imgIds = imgIds[:] 24 | 25 | cocoEval = COCOeval(cocoGt, cocoDt, annType[type]) 26 | cocoEval.params.imgIds = imgIds 27 | cocoEval.params.catIds = [100] 28 | cocoEval.evaluate() 29 | cocoEval.accumulate() 30 | cocoEval.summarize() 31 | return cocoEval.stats 32 | 33 | def polis_eval(annFile, resFile): 34 | gt_coco = COCO(annFile) 35 | dt_coco = gt_coco.loadRes(resFile) 36 | polisEval = PolisEval(gt_coco, dt_coco) 37 | polisEval.evaluate() 38 | 39 | def max_angle_error_eval(annFile, resFile): 40 | gt_coco = COCO(annFile) 41 | dt_coco = gt_coco.loadRes(resFile) 42 | contour_eval = ContourEval(gt_coco, dt_coco) 43 | pool = Pool(processes=20) 44 | max_angle_diffs = contour_eval.evaluate(pool=pool) 45 | print('Mean max tangent angle error(MTA): ', max_angle_diffs.mean()) 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--gt-file", default="") 50 | parser.add_argument("--dt-file", default="") 51 | parser.add_argument("--eval-type", default="coco_iou", choices=["coco_iou", "polis", "angle", "ciou", "topdig"]) 52 | args = parser.parse_args() 53 | 54 | eval_type = args.eval_type 55 | gt_file = args.gt_file 56 | dt_file = args.dt_file 57 | if eval_type == 'coco_iou': 58 | coco_eval(gt_file, dt_file) 59 | elif eval_type == 'polis': 60 | polis_eval(gt_file, dt_file) 61 | elif eval_type == 'angle': 62 | max_angle_error_eval(gt_file, dt_file) 63 | elif eval_type == 'ciou': 64 | compute_IoU_cIoU(dt_file, gt_file) 65 | elif eval_type == 'topdig': 66 | compute_mask_metrics(dt_file, gt_file) 67 | else: 68 | raise RuntimeError('please choose a correct type from \ 69 | ["coco_iou", "polis", "angle", "ciou", "topdig"]') 70 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from timm.models.layers import trunc_normal_ 6 | 7 | import os 8 | import sys 9 | sys.path.insert(1, os.getcwd()) 10 | 11 | from config import CFG 12 | from utils import ( 13 | create_mask, 14 | ) 15 | 16 | 17 | # Borrowed from https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/superglue.py#L143 18 | def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor: 19 | """ Perform Sinkhorn Normalization in Log-space for stability""" 20 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 21 | for _ in range(iters): 22 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) 23 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) 24 | return Z + u.unsqueeze(2) + v.unsqueeze(1) 25 | 26 | # Borrowed from https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/superglue.py#L152 27 | def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor: 28 | """ Perform Differentiable Optimal Transport in Log-space for stability""" 29 | b, m, n = scores.shape 30 | one = scores.new_tensor(1) 31 | ms, ns = (m*one).to(scores), (n*one).to(scores) 32 | 33 | bins0 = alpha.expand(b, m, 1) 34 | bins1 = alpha.expand(b, 1, n) 35 | alpha = alpha.expand(b, 1, 1) 36 | 37 | couplings = torch.cat([torch.cat([scores, bins0], -1), 38 | torch.cat([bins1, alpha], -1)], 1) 39 | 40 | norm = - (ms + ns).log() 41 | log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) 42 | log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) 43 | log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) 44 | 45 | Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) 46 | Z = Z - norm # multiply probabilities by M+N 47 | return Z 48 | 49 | 50 | class ScoreNet(nn.Module): 51 | def __init__(self, n_vertices, in_channels=512): 52 | super().__init__() 53 | self.n_vertices = n_vertices 54 | self.in_channels = in_channels 55 | self.relu = nn.ReLU(inplace=True) 56 | self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, padding=0, bias=True) 57 | self.bn1 = nn.BatchNorm2d(256) 58 | self.conv2 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True) 59 | self.bn2 = nn.BatchNorm2d(128) 60 | self.conv3 = nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=True) 61 | self.bn3 = nn.BatchNorm2d(64) 62 | self.conv4 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=True) 63 | 64 | def forward(self, feats): 65 | feats = feats[:, 1:] 66 | feats = feats.unsqueeze(2) 67 | feats = feats.view(feats.size(0), feats.size(1)//2, 2, feats.size(3)) 68 | feats = torch.mean(feats, dim=2) 69 | 70 | x = torch.transpose(feats, 1, 2) 71 | x = x.unsqueeze(-1) 72 | x = x.repeat(1, 1, 1, self.n_vertices) 73 | t = torch.transpose(x, 2, 3) 74 | x = torch.cat((x, t), dim=1) 75 | 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | 80 | x = self.conv2(x) 81 | x = self.bn2(x) 82 | x = self.relu(x) 83 | 84 | x = self.conv3(x) 85 | x = self.bn3(x) 86 | x = self.relu(x) 87 | 88 | x = self.conv4(x) 89 | 90 | return x[:, 0] 91 | 92 | 93 | class Encoder(nn.Module): 94 | def __init__(self, model_name='deit3_small_patch16_384_in21ft1k', pretrained=False, out_dim=256) -> None: 95 | super().__init__() 96 | self.model = timm.create_model( 97 | model_name=model_name, 98 | num_classes=0, 99 | global_pool='', 100 | pretrained=pretrained 101 | ) 102 | self.bottleneck = nn.AdaptiveAvgPool1d(out_dim) 103 | 104 | def forward(self, x): 105 | features = self.model(x) 106 | return self.bottleneck(features[:, 1:]) 107 | 108 | 109 | class Decoder(nn.Module): 110 | def __init__(self, cfg, vocab_size, encoder_len, dim, num_heads, num_layers): 111 | super().__init__() 112 | self.cfg = cfg 113 | self.dim = dim 114 | 115 | self.embedding = nn.Embedding(vocab_size, dim) 116 | self.decoder_pos_embed = nn.Parameter(torch.randn(1, self.cfg.MAX_LEN-1, dim) * .02) 117 | self.decoder_pos_drop = nn.Dropout(p=0.05) 118 | 119 | decoder_layer = nn.TransformerDecoderLayer(d_model=dim, nhead=num_heads) 120 | self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_layers) 121 | self.output = nn.Linear(dim, vocab_size) 122 | 123 | self.encoder_pos_embed = nn.Parameter(torch.randn(1, encoder_len, dim) * .02) 124 | self.encoder_pos_drop = nn.Dropout(p=0.05) 125 | 126 | self.init_weights() 127 | 128 | def init_weights(self): 129 | for name, p in self.named_parameters(): 130 | if 'encoder_pos_embed' in name or 'decoder_pos_embed' in name: 131 | print(f"Skipping initialization of pos embed layers...") 132 | continue 133 | if p.dim() > 1: 134 | nn.init.xavier_uniform_(p) 135 | 136 | trunc_normal_(self.encoder_pos_embed, std=.02) 137 | trunc_normal_(self.decoder_pos_embed, std=.02) 138 | 139 | def forward(self, encoder_out, tgt): 140 | """ 141 | encoder_out shape: (N, L, D) 142 | tgt shape: (N, L) 143 | """ 144 | 145 | tgt_mask, tgt_padding_mask = create_mask(tgt, self.cfg.PAD_IDX) 146 | tgt_embedding = self.embedding(tgt) 147 | tgt_embedding = self.decoder_pos_drop( 148 | tgt_embedding + self.decoder_pos_embed 149 | ) 150 | 151 | encoder_out = self.encoder_pos_drop( 152 | encoder_out + self.encoder_pos_embed 153 | ) 154 | 155 | encoder_out = encoder_out.transpose(0, 1) 156 | tgt_embedding = tgt_embedding.transpose(0, 1) 157 | 158 | preds = self.decoder( 159 | memory=encoder_out, 160 | tgt=tgt_embedding, 161 | tgt_mask=tgt_mask, 162 | tgt_key_padding_mask=tgt_padding_mask 163 | ) 164 | 165 | preds = preds.transpose(0, 1) 166 | return self.output(preds), preds 167 | 168 | def predict(self, encoder_out, tgt): 169 | length = tgt.size(1) 170 | padding = torch.ones((tgt.size(0), self.cfg.MAX_LEN-length-1), device=tgt.device).fill_(self.cfg.PAD_IDX).long() 171 | tgt = torch.cat([tgt, padding], dim=1) 172 | tgt_mask, tgt_padding_mask = create_mask(tgt, self.cfg.PAD_IDX) 173 | tgt_embedding = self.embedding(tgt) 174 | tgt_embedding = self.decoder_pos_drop( 175 | tgt_embedding + self.decoder_pos_embed 176 | ) 177 | 178 | encoder_out = self.encoder_pos_drop( 179 | encoder_out + self.encoder_pos_embed 180 | ) 181 | 182 | encoder_out = encoder_out.transpose(0, 1) 183 | tgt_embedding = tgt_embedding.transpose(0, 1) 184 | 185 | preds = self.decoder( 186 | memory=encoder_out, 187 | tgt=tgt_embedding, 188 | tgt_mask=tgt_mask, 189 | tgt_key_padding_mask = tgt_padding_mask 190 | ) 191 | 192 | preds = preds.transpose(0, 1) 193 | return self.output(preds)[:, length-1, :], preds 194 | 195 | 196 | class EncoderDecoder(nn.Module): 197 | def __init__(self, cfg, encoder, decoder): 198 | super().__init__() 199 | self.cfg = cfg 200 | self.encoder = encoder 201 | self.decoder = decoder 202 | self.scorenet1 = ScoreNet(self.cfg.N_VERTICES) 203 | self.scorenet2 = ScoreNet(self.cfg.N_VERTICES) 204 | bin_score = torch.nn.Parameter(torch.tensor(1.)) 205 | self.register_parameter('bin_score', bin_score) 206 | 207 | def forward(self, image, tgt): 208 | encoder_out = self.encoder(image) 209 | preds, feats = self.decoder(encoder_out, tgt) 210 | perm_mat1 = self.scorenet1(feats) 211 | perm_mat2 = self.scorenet2(feats) 212 | perm_mat = perm_mat1 + torch.transpose(perm_mat2, 1, 2) 213 | 214 | perm_mat = log_optimal_transport(perm_mat, self.bin_score, self.cfg.SINKHORN_ITERATIONS)[:, :perm_mat.shape[1], :perm_mat.shape[2]] 215 | perm_mat = F.softmax(perm_mat, dim=-1) # NOTE: perhaps try gumbel softmax here? 216 | # perm_mat = F.gumbel_softmax(perm_mat, tau=1.0, hard=False) 217 | 218 | return preds, perm_mat 219 | 220 | def predict(self, image, tgt): 221 | encoder_out = self.encoder(image) 222 | preds, feats = self.decoder.predict(encoder_out, tgt) 223 | return preds, feats 224 | 225 | 226 | if __name__ == "__main__": 227 | # run this script as main for debugging. 228 | from tokenizer import Tokenizer 229 | from torch.nn.utils.rnn import pad_sequence 230 | import numpy as np 231 | import torch 232 | from torch import nn 233 | 234 | image = torch.randn(1, 3, CFG.INPUT_HEIGHT, CFG.INPUT_WIDTH).to('cuda') 235 | 236 | n_vertices = 192 237 | gt_coords = np.random.randint(size=(n_vertices, 2), low=0, high=CFG.IMG_SIZE).astype(np.float32) 238 | # in dataset 239 | tokenizer = Tokenizer(num_classes=1, num_bins=CFG.NUM_BINS, width=CFG.IMG_SIZE, height=CFG.IMG_SIZE, max_len=CFG.MAX_LEN) 240 | gt_seqs, rand_idxs = tokenizer(gt_coords) 241 | # in dataloader collate 242 | gt_seqs = [torch.LongTensor(gt_seqs)] 243 | gt_seqs = pad_sequence(gt_seqs, batch_first=True, padding_value=tokenizer.PAD_code) 244 | pad = torch.ones(gt_seqs.size(0), CFG.MAX_LEN - gt_seqs.size(1)).fill_(tokenizer.PAD_code).long() 245 | gt_seqs = torch.cat([gt_seqs, pad], dim=1).to('cuda') 246 | # in train fn 247 | gt_seqs_input = gt_seqs[:, :-1] 248 | gt_seqs_expected = gt_seqs[:, 1:] 249 | CFG.PAD_IDX = tokenizer.PAD_code 250 | 251 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=False, out_dim=256) 252 | decoder = Decoder(vocab_size=tokenizer.vocab_size, encoder_len=CFG.NUM_PATCHES, dim=256, num_heads=8, num_layers=6) 253 | model = EncoderDecoder(encoder, decoder).to('cuda') 254 | vertex_loss_fn = nn.CrossEntropyLoss(ignore_index=CFG.PAD_IDX) 255 | 256 | # Forward pass during training. 257 | preds_f, perm_mat, batch_polygons = model(image, gt_seqs_input) 258 | loss = vertex_loss_fn(preds_f.reshape(-1, preds_f.shape[-1]), gt_seqs_expected.reshape(-1)) 259 | 260 | # Sequence generation during prediction. 261 | batch_preds = torch.ones(image.size(0), 1).fill_(tokenizer.BOS_code).long().to(CFG.DEVICE) 262 | batch_feats = torch.ones(image.size(0), 1).fill_(tokenizer.BOS_code).long().to(CFG.DEVICE) 263 | sample = lambda preds: torch.softmax(preds, dim=-1).argmax(dim=-1).view(-1, 1) 264 | 265 | out_coords = [] 266 | out_confs = [] 267 | 268 | confs = [] 269 | with torch.no_grad(): 270 | for i in range(1 + n_vertices*2): 271 | try: 272 | print(i) 273 | preds_p, feats_p = model.predict(image, batch_preds) 274 | # print(preds_p.shape, feats_p.shape) 275 | if i % 2 == 0: 276 | confs_ = torch.softmax(preds_p, dim=-1).sort(axis=-1, descending=True)[0][:, 0].cpu() 277 | confs.append(confs_) 278 | preds_p = sample(preds_p) 279 | batch_preds = torch.cat([batch_preds, preds_p], dim=1) 280 | except: 281 | print(f"Error at iteration: {i}") 282 | perm_pred = model.scorenet(feats_p) 283 | 284 | # Postprocessing. 285 | EOS_idxs = (batch_preds == tokenizer.EOS_code).float().argmax(dim=-1) 286 | invalid_idxs = ((EOS_idxs - 1) % 2 != 0).nonzero().view(-1) # sanity check 287 | EOS_idxs[invalid_idxs] = 0 288 | 289 | all_coords = [] 290 | all_confs = [] 291 | for i, EOS_idx in enumerate(EOS_idxs.tolist()): 292 | if EOS_idx == 0: 293 | all_coords.append(None) 294 | all_confs.append(None) 295 | continue 296 | coords = tokenizer.decode(batch_preds[i, :EOS_idx+1]) 297 | confs = [round(confs[j][i].item(), 3) for j in range(len(coords))] 298 | 299 | all_coords.append(coords) 300 | all_confs.append(confs) 301 | 302 | out_coords.extend(all_coords) 303 | out_confs.extend(out_confs) 304 | 305 | print(f"preds_f shape: {preds_f.shape}") 306 | print(f"preds_f grad: {preds_f.requires_grad}") 307 | print(f"preds_f min: {preds_f.min()}, max: {preds_f.max()}") 308 | 309 | print(f"perm_mat shape: {perm_mat.shape}") 310 | print(f"perm_mat grad: {perm_mat.requires_grad}") 311 | print(f"perm_mat min: {perm_mat.min()}, max: {preds_f.max()}") 312 | 313 | print(f"batch_preds shape: {batch_preds.shape}") 314 | print(f"batch_preds grad: {batch_preds.requires_grad}") 315 | print(f"batch_preds min: {batch_preds.min()}, max: {batch_preds.max()}") 316 | 317 | print(f"loss : {loss}") 318 | print(f"loss grad: {loss.requires_grad}") 319 | 320 | -------------------------------------------------------------------------------- /postprocess_coco_parts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from tqdm import tqdm 4 | import shapely 5 | 6 | parts = [ 7 | "runs/CYENS_CLUSTER_train_Pix2Poly_AUGSRUNS_inria_coco_224_negAug_run1_vit_small_patch8_224_dino_AffineRotaugs0.8_LinearWarmupLRS_1.0xVertexLoss_10.0xPermLoss__2xScoreNet_initialLR_0.0004_bs_24_Nv_192_Nbins224_500epochs/predictions_inria_coco_224_negAug_val_images_epoch_499.json" 8 | ] 9 | out_path = f"runs/CYENS_CLUSTER_train_Pix2Poly_AUGSRUNS_inria_coco_224_negAug_run1_vit_small_patch8_224_dino_AffineRotaugs0.8_LinearWarmupLRS_1.0xVertexLoss_10.0xPermLoss__2xScoreNet_initialLR_0.0004_bs_24_Nv_192_Nbins224_500epochs/combined_predictions_inria_coco_224_negAug_val_images_epoch_499.json" 10 | 11 | 12 | ################################################################ 13 | combined = [] 14 | part_lengths = [] 15 | for i, part in enumerate(parts): 16 | print(f"PART {i}") 17 | with open(part) as f: 18 | pred_part = json.loads(f.read()) 19 | part_lengths.append(len(pred_part)) 20 | for ins in tqdm(pred_part): 21 | assert len(ins['segmentation']) == 1 22 | segm = ins['segmentation'][0] 23 | segm = np.array(segm).reshape(-1, 2) 24 | # segm = np.flip(segm, axis=0) # invert order of vertices cw -> ccw or ccw -> cw 25 | if segm.shape[0] > 2 and shapely.Polygon(segm).area > 50.: 26 | segm = np.concatenate((segm, segm[0, None]), axis=0) 27 | segm = segm.reshape(-1).tolist() 28 | ins['segmentation'] = [segm] 29 | combined.append(ins) 30 | 31 | with open(out_path, "w") as fp: 32 | fp.write(json.dumps(combined)) 33 | ################################################################ 34 | 35 | -------------------------------------------------------------------------------- /predict_inria_coco_val_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | from tqdm import tqdm 5 | import numpy as np 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | from functools import partial 11 | import torch 12 | from torch import nn 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torchvision.utils import make_grid 15 | import albumentations as A 16 | from albumentations.pytorch import ToTensorV2 17 | 18 | from test_config import CFG 19 | from tokenizer import Tokenizer 20 | from utils import ( 21 | seed_everything, 22 | load_checkpoint, 23 | test_generate, 24 | postprocess, 25 | permutations_to_polygons, 26 | ) 27 | from models.model import ( 28 | Encoder, 29 | Decoder, 30 | EncoderDecoder 31 | ) 32 | 33 | from torch.utils.data import DataLoader 34 | from datasets.dataset_inria_coco import InriaCocoDataset_val 35 | from torch.nn.utils.rnn import pad_sequence 36 | from torchmetrics.classification import BinaryJaccardIndex, BinaryAccuracy 37 | import torch.multiprocessing 38 | torch.multiprocessing.set_sharing_strategy("file_system") 39 | 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("-d", "--dataset", help="Dataset to use for evaluation.") 43 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.") 44 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.") 45 | parser.add_argument("-o", "--output_dir", help="Name of output subdirectory to store part predictions.") 46 | args = parser.parse_args() 47 | 48 | 49 | torch.backends.cuda.matmul.allow_tf32 = True 50 | torch.backends.cudnn.allow_tf32 = True 51 | 52 | DATASET = f"{args.dataset}" 53 | VAL_DATASET_DIR = f"./data/{DATASET}/val" 54 | # PART_DESC = "val_images" 55 | PART_DESC = f"{args.output_dir}" 56 | 57 | EXPERIMENT_NAME = os.path.basename(os.path.realpath(args.experiment_path)) 58 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{args.checkpoint_name}.pth" 59 | BATCH_SIZE = 24 60 | 61 | 62 | def bounding_box_from_points(points): 63 | points = np.array(points).flatten() 64 | even_locations = np.arange(points.shape[0]/2) * 2 65 | odd_locations = even_locations + 1 66 | X = np.take(points, even_locations.tolist()) 67 | Y = np.take(points, odd_locations.tolist()) 68 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()] 69 | bbox = [int(b) for b in bbox] 70 | return bbox 71 | 72 | 73 | def single_annotation(image_id, poly): 74 | _result = {} 75 | _result["image_id"] = int(image_id) 76 | _result["category_id"] = 100 77 | _result["score"] = 1 78 | _result["segmentation"] = poly 79 | _result["bbox"] = bounding_box_from_points(_result["segmentation"]) 80 | return _result 81 | 82 | 83 | def collate_fn(batch, max_len, pad_idx): 84 | """ 85 | if max_len: 86 | the sequences will all be padded to that length. 87 | """ 88 | 89 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch = [], [], [], [], [], [] 90 | for image, mask, c_mask, seq, perm_mat, idx in batch: 91 | image_batch.append(image) 92 | mask_batch.append(mask) 93 | coords_mask_batch.append(c_mask) 94 | coords_seq_batch.append(seq) 95 | perm_matrix_batch.append(perm_mat) 96 | idx_batch.append(idx) 97 | 98 | coords_seq_batch = pad_sequence( 99 | coords_seq_batch, 100 | padding_value=pad_idx, 101 | batch_first=True 102 | ) 103 | 104 | if max_len: 105 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long() 106 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1) 107 | 108 | image_batch = torch.stack(image_batch) 109 | mask_batch = torch.stack(mask_batch) 110 | coords_mask_batch = torch.stack(coords_mask_batch) 111 | perm_matrix_batch = torch.stack(perm_matrix_batch) 112 | idx_batch = torch.stack(idx_batch) 113 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch 114 | 115 | 116 | def main(): 117 | seed_everything(42) 118 | 119 | valid_transforms = A.Compose( 120 | [ 121 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), 122 | A.Normalize( 123 | mean=[0.0, 0.0, 0.0], 124 | std=[1.0, 1.0, 1.0], 125 | max_pixel_value=255.0 126 | ), 127 | ToTensorV2(), 128 | ], 129 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False) 130 | ) 131 | 132 | tokenizer = Tokenizer( 133 | num_classes=1, 134 | num_bins=CFG.NUM_BINS, 135 | width=CFG.INPUT_WIDTH, 136 | height=CFG.INPUT_HEIGHT, 137 | max_len=CFG.MAX_LEN 138 | ) 139 | CFG.PAD_IDX = tokenizer.PAD_code 140 | 141 | val_ds = InriaCocoDataset_val( 142 | cfg=CFG, 143 | dataset_dir=VAL_DATASET_DIR, 144 | transform=valid_transforms, 145 | tokenizer=tokenizer, 146 | shuffle_tokens=CFG.SHUFFLE_TOKENS 147 | ) 148 | val_loader = DataLoader( 149 | val_ds, 150 | batch_size=BATCH_SIZE, 151 | collate_fn=partial(collate_fn, max_len=CFG.MAX_LEN, pad_idx=CFG.PAD_IDX), 152 | num_workers=2 153 | ) 154 | 155 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256) 156 | decoder = Decoder( 157 | cfg=CFG, 158 | vocab_size=tokenizer.vocab_size, 159 | encoder_len=CFG.NUM_PATCHES, 160 | dim=256, 161 | num_heads=8, 162 | num_layers=6 163 | ) 164 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) 165 | model.to(CFG.DEVICE) 166 | model.eval() 167 | 168 | checkpoint = torch.load(CHECKPOINT_PATH) 169 | model.load_state_dict(checkpoint['state_dict']) 170 | epoch = checkpoint['epochs_run'] 171 | 172 | print(f"Model loaded from epoch: {epoch}") 173 | ckpt_desc = f"epoch_{epoch}" 174 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH): 175 | ckpt_desc = f"epoch_{epoch}_bestValLoss" 176 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH): 177 | ckpt_desc = f"epoch_{epoch}_bestValMetric" 178 | else: 179 | pass 180 | 181 | mean_iou_metric = BinaryJaccardIndex() 182 | mean_acc_metric = BinaryAccuracy() 183 | 184 | 185 | with torch.no_grad(): 186 | cumulative_miou = [] 187 | cumulative_macc = [] 188 | speed = [] 189 | predictions = [] 190 | for i_batch, (x, y_mask, y_corner_mask, y, y_perm, idx) in enumerate(tqdm(val_loader)): 191 | all_coords = [] 192 | all_confs = [] 193 | t0 = time.time() 194 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1) 195 | speed.append(time.time() - t0) 196 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer) 197 | 198 | all_coords.extend(vertex_coords) 199 | all_confs.extend(confs) 200 | 201 | coords = [] 202 | for i in range(len(all_coords)): 203 | if all_coords[i] is not None: 204 | coord = torch.from_numpy(all_coords[i]) 205 | else: 206 | coord = torch.tensor([]) 207 | 208 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) 209 | coord = torch.cat([coord, padd], dim=0) 210 | coords.append(coord) 211 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224] 212 | # pred_polygons = permutations_to_polygons(perm_preds, coords, out='coco') # [0, 224] 213 | 214 | for ip, pp in enumerate(batch_polygons): 215 | for p in pp: 216 | p = torch.fliplr(p) 217 | p = p[p[:, 0] != CFG.PAD_IDX] 218 | p = p * (CFG.IMG_SIZE / CFG.INPUT_WIDTH) 219 | p = p.view(-1).tolist() 220 | if len(p) > 0: 221 | predictions.append(single_annotation(idx[ip], [p])) 222 | 223 | B, C, H, W = x.shape 224 | 225 | polygons_mask = np.zeros((B, 1, H, W)) 226 | for b in range(len(batch_polygons)): 227 | for c in range(len(batch_polygons[b])): 228 | poly = batch_polygons[b][c] 229 | poly = poly[poly[:, 0] != CFG.PAD_IDX] 230 | cnt = np.flip(np.int32(poly.cpu()), 1) 231 | if len(cnt) > 0: 232 | cv2.fillPoly(polygons_mask[b, 0], pts=[cnt], color=1.) 233 | polygons_mask = torch.from_numpy(polygons_mask) 234 | 235 | batch_miou = mean_iou_metric(polygons_mask, y_mask) 236 | batch_macc = mean_acc_metric(polygons_mask, y_mask) 237 | 238 | cumulative_miou.append(batch_miou) 239 | cumulative_macc.append(batch_macc) 240 | 241 | pred_grid = make_grid(polygons_mask).permute(1, 2, 0) 242 | gt_grid = make_grid(y_mask).permute(1, 2, 0) 243 | plt.subplot(211), plt.imshow(pred_grid) ,plt.title("Predicted Polygons") ,plt.axis('off') 244 | plt.subplot(212), plt.imshow(gt_grid) ,plt.title("Ground Truth") ,plt.axis('off') 245 | 246 | if not os.path.exists(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc)): 247 | os.makedirs(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc)) 248 | plt.savefig(f"runs/{EXPERIMENT_NAME}/val_preds/{DATASET}/{PART_DESC}/{ckpt_desc}/batch_{i_batch}.png") 249 | plt.close() 250 | 251 | print("Average model speed: ", np.mean(speed) / BATCH_SIZE, " [s / image]") 252 | 253 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}") 254 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}") 255 | 256 | with open(f"runs/{EXPERIMENT_NAME}/predictions_{DATASET}_{PART_DESC}_{ckpt_desc}.json", "w") as fp: 257 | fp.write(json.dumps(predictions)) 258 | 259 | with open(f"runs/{EXPERIMENT_NAME}/val_metrics_{DATASET}_{PART_DESC}_{ckpt_desc}.txt", 'w') as ff: 260 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}", file=ff) 261 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}", file=ff) 262 | 263 | 264 | if __name__ == "__main__": 265 | main() 266 | 267 | -------------------------------------------------------------------------------- /predict_mass_roads_test_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import cv2 5 | import json 6 | import matplotlib.pyplot as plt 7 | from matplotlib import patches 8 | # import gif 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import albumentations as A 13 | from albumentations.pytorch import ToTensorV2 14 | from tokenizer import Tokenizer 15 | from test_config import CFG 16 | from models.model import ( 17 | Encoder, 18 | Decoder, 19 | EncoderDecoder, 20 | ) 21 | from utils import ( 22 | seed_everything, 23 | test_generate, 24 | postprocess, 25 | permutations_to_polygons, 26 | ) 27 | import time 28 | 29 | 30 | # adapted from https://github.com/obss/sahi/blob/e798c80d6e09079ae07a672c89732dd602fe9001/sahi/slicing.py#L30, MIT License 31 | def calculate_slice_bboxes( 32 | image_height: int, 33 | image_width: int, 34 | slice_height: int = 512, 35 | slice_width: int = 512, 36 | overlap_height_ratio: float = 0.2, 37 | overlap_width_ratio: float = 0.2, 38 | ) -> list[list[int]]: 39 | """ 40 | Given the height and width of an image, calculates how to divide the image into 41 | overlapping slices according to the height and width provided. These slices are returned 42 | as bounding boxes in xyxy format. 43 | :param image_height: Height of the original image. 44 | :param image_width: Width of the original image. 45 | :param slice_height: Height of each slice 46 | :param slice_width: Width of each slice 47 | :param overlap_height_ratio: Fractional overlap in height of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels) 48 | :param overlap_width_ratio: Fractional overlap in width of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels) 49 | :return: a list of bounding boxes in xyxy format 50 | """ 51 | 52 | slice_bboxes = [] 53 | y_max = y_min = 0 54 | y_overlap = int(overlap_height_ratio * slice_height) 55 | x_overlap = int(overlap_width_ratio * slice_width) 56 | while y_max < image_height: 57 | x_min = x_max = 0 58 | y_max = y_min + slice_height 59 | while x_max < image_width: 60 | x_max = x_min + slice_width 61 | if y_max > image_height or x_max > image_width: 62 | xmax = min(image_width, x_max) 63 | ymax = min(image_height, y_max) 64 | xmin = max(0, xmax - slice_width) 65 | ymin = max(0, ymax - slice_height) 66 | slice_bboxes.append([xmin, ymin, xmax, ymax]) 67 | else: 68 | slice_bboxes.append([x_min, y_min, x_max, y_max]) 69 | x_min = x_max - x_overlap 70 | y_min = y_max - y_overlap 71 | 72 | return slice_bboxes 73 | 74 | 75 | def get_rectangle_params_from_pascal_bbox(bbox): 76 | xmin_top_left, ymin_top_left, xmax_bottom_right, ymax_bottom_right = bbox 77 | 78 | bottom_left = (xmin_top_left, ymax_bottom_right) 79 | width = xmax_bottom_right - xmin_top_left 80 | height = ymin_top_left - ymax_bottom_right 81 | 82 | return bottom_left, width, height 83 | 84 | 85 | def draw_bboxes( 86 | plot_ax, 87 | bboxes, 88 | class_labels, 89 | get_rectangle_corners_fn=get_rectangle_params_from_pascal_bbox, 90 | ): 91 | for bbox, label in zip(bboxes, class_labels): 92 | bottom_left, width, height = get_rectangle_corners_fn(bbox) 93 | 94 | rect_1 = patches.Rectangle( 95 | bottom_left, width, height, linewidth=4, edgecolor="black", fill=False, 96 | ) 97 | rect_2 = patches.Rectangle( 98 | bottom_left, width, height, linewidth=2, edgecolor="white", fill=False, 99 | ) 100 | rx, ry = rect_1.get_xy() 101 | 102 | # Add the patch to the Axes 103 | plot_ax.add_patch(rect_1) 104 | plot_ax.add_patch(rect_2) 105 | plot_ax.annotate(label, (rx+width, ry+height), color='white', fontsize=20) 106 | 107 | # @gif.frame 108 | def show_image(image, bboxes=None, class_labels=None, draw_bboxes_fn=draw_bboxes): 109 | fig, ax = plt.subplots(1, figsize=(10, 10)) 110 | ax.imshow(image) 111 | 112 | if bboxes: 113 | draw_bboxes_fn(ax, bboxes, class_labels) 114 | 115 | # plt.show() 116 | 117 | 118 | def bounding_box_from_points(points): 119 | points = np.array(points).flatten() 120 | even_locations = np.arange(points.shape[0]/2) * 2 121 | odd_locations = even_locations + 1 122 | X = np.take(points, even_locations.tolist()) 123 | Y = np.take(points, odd_locations.tolist()) 124 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()] 125 | bbox = [int(b) for b in bbox] 126 | return bbox 127 | 128 | 129 | def single_annotation(image_id, poly): 130 | _result = {} 131 | _result["image_id"] = int(image_id) 132 | _result["category_id"] = 100 133 | _result["score"] = 1 134 | _result["segmentation"] = poly 135 | _result["bbox"] = bounding_box_from_points(_result["segmentation"]) 136 | return _result 137 | 138 | 139 | def main(args): 140 | BATCH_SIZE = int(args.batch_size) # 24 141 | PATCH_SIZE = int(args.img_size) # 224 142 | INPUT_HEIGHT = int(args.input_size) # 224 143 | INPUT_WIDTH = int(args.input_size) # 224 144 | 145 | # EXPERIMENT_NAME = f"CYENS_CLUSTER_train_Pix2PolyFullDataExps_inria_coco_224_negAug_run1_deit3_small_patch16_384_in21ft1k_Rotaugs_LinearWarmupLRS_NoShuffle_1.0xVertexLoss_10.0xPermLoss_0.0xVertexRegLoss__2xScoreNet_initialLR_0.0004_bs_16_Nv_192_Nbins384_LbSm_0.0_500epochs" 146 | EXPERIMENT_PATH = args.experiment_path 147 | EXPERIMENT_NAME = os.path.basename(os.path.abspath(EXPERIMENT_PATH)) 148 | CHECKPOINT_NAME = args.checkpoint_name 149 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{CHECKPOINT_NAME}.pth" 150 | 151 | SPLIT = args.split # 'val' or 'test' 152 | 153 | test_image_dir = f"data/mass_roads_1500/test/sat" 154 | val_image_dir = f"data/mass_roads_1500/val/sat" 155 | 156 | if SPLIT == "test": 157 | test_images = [] 158 | for im in os.listdir(test_image_dir): 159 | test_images.append(im) 160 | image_dir = test_image_dir 161 | images = test_images 162 | elif SPLIT == "val": 163 | val_images = [] 164 | for im in os.listdir(val_image_dir): 165 | val_images.append(im) 166 | image_dir = val_image_dir 167 | images = val_images 168 | else: 169 | raise ValueError("Specify either test or val split for prediction.") 170 | 171 | test_transforms = A.Compose( 172 | [ 173 | A.Resize(height=INPUT_HEIGHT, width=INPUT_WIDTH), 174 | A.Normalize( 175 | mean=[0.0, 0.0, 0.0], 176 | std=[1.0, 1.0, 1.0], 177 | max_pixel_value=255.0 178 | ), 179 | ToTensorV2(), 180 | ], 181 | ) 182 | 183 | tokenizer = Tokenizer( 184 | num_classes=1, 185 | num_bins=CFG.NUM_BINS, 186 | width=INPUT_WIDTH, 187 | height=INPUT_HEIGHT, 188 | max_len=CFG.MAX_LEN 189 | ) 190 | CFG.PAD_IDX = tokenizer.PAD_code 191 | 192 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256) 193 | decoder = Decoder( 194 | cfg=CFG, 195 | vocab_size=tokenizer.vocab_size, 196 | encoder_len=CFG.NUM_PATCHES, 197 | dim=256, 198 | num_heads=8, 199 | num_layers=6 200 | ) 201 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) 202 | model.to(CFG.DEVICE) 203 | model.eval() 204 | 205 | checkpoint = torch.load(CHECKPOINT_PATH) 206 | model.load_state_dict(checkpoint['state_dict']) 207 | epoch = checkpoint['epochs_run'] 208 | 209 | print(f"Model loaded from epoch: {epoch}") 210 | ckpt_desc = f"epoch_{epoch}" 211 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH): 212 | ckpt_desc = f"epoch_{epoch}_bestValLoss" 213 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH): 214 | ckpt_desc = f"epoch_{epoch}_bestValMetric" 215 | else: 216 | pass 217 | 218 | 219 | results_dir = os.path.join(f"runs/{EXPERIMENT_NAME}", f"{SPLIT}_predictions", ckpt_desc) 220 | os.makedirs(results_dir, exist_ok=True) 221 | os.makedirs(os.path.join(results_dir, "raster_preds"), exist_ok=True) 222 | os.makedirs(os.path.join(results_dir, "polygon_preds"), exist_ok=True) 223 | 224 | 225 | with torch.no_grad(): 226 | for idx, image in enumerate(tqdm(images)): 227 | print(f"<---------Processing {idx+1}/{len(images)}: {image}----------->") 228 | img_name = image 229 | if os.path.exists(os.path.join(results_dir, 'raster_preds', img_name)): 230 | continue 231 | img = Image.open(os.path.join(image_dir, img_name)) 232 | img = np.array(img) 233 | 234 | slice_bboxes = calculate_slice_bboxes( 235 | image_height=img.shape[1], 236 | image_width=img.shape[0], 237 | slice_height=PATCH_SIZE, 238 | slice_width=PATCH_SIZE, 239 | overlap_height_ratio=0.2, 240 | overlap_width_ratio=0.2 241 | ) 242 | 243 | speed = [] 244 | predictions = [] 245 | for bi, box in enumerate(tqdm(slice_bboxes)): 246 | xmin_top_left, ymin_top_left, xmax_bottom_right, ymax_bottom_right = box 247 | patch = img[ymin_top_left:ymax_bottom_right, xmin_top_left:xmax_bottom_right] 248 | patch = test_transforms(image=patch.astype(np.float32))['image'][None] 249 | 250 | all_coords = [] 251 | all_confs = [] 252 | t0 = time.time() 253 | batch_preds, batch_confs, perm_preds = test_generate(model, patch, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1) 254 | speed.append(time.time() - t0) 255 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer) 256 | 257 | all_coords.extend(vertex_coords) 258 | all_confs.extend(confs) 259 | 260 | coords = [] 261 | for i in range(len(all_coords)): 262 | if all_coords[i] is not None: 263 | coord = torch.from_numpy(all_coords[i]) 264 | else: 265 | coord = torch.tensor([]) 266 | 267 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) 268 | coord = torch.cat([coord, padd], dim=0) 269 | coords.append(coord) 270 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224] 271 | 272 | for ip, pp in enumerate(batch_polygons): 273 | if pp is not None: 274 | for p in pp: 275 | if p is not None: 276 | p = torch.fliplr(p) 277 | p = p[p[:, 0] != CFG.PAD_IDX] 278 | p = p * (PATCH_SIZE / INPUT_WIDTH) 279 | p[:, 0] = p[:, 0] + xmin_top_left 280 | p[:, 1] = p[:, 1] + ymin_top_left 281 | if len(p) > 0: 282 | if (p[0] == p[-1]).all(): 283 | p = p [:-1] 284 | p = p.view(-1).tolist() 285 | if len(p) > 0: 286 | predictions.append(single_annotation(idx, [p])) 287 | # For debugging 288 | # if bi >= 10: 289 | # break 290 | 291 | H, W = img.shape[0], img.shape[1] 292 | 293 | polygons_mask = np.zeros((H, W)) 294 | for pred in predictions: 295 | poly = np.array(pred['segmentation']) 296 | poly = poly.reshape((poly.shape[-1]//2, 2)) 297 | cv2.polylines(polygons_mask, [np.int32(poly)], isClosed=False, color=1., thickness=5) 298 | polygons_mask = (polygons_mask*255).astype(np.uint8) 299 | 300 | cv2.imwrite(os.path.join(results_dir, 'raster_preds', img_name), polygons_mask) 301 | print("Average model speed: ", np.mean(speed), " [s / patch]") 302 | print("Time for a single tile: ", np.sum(speed), " [s / tile]") 303 | 304 | with open(f"{results_dir}/polygon_preds/{img_name.split('.')[0]}.json", "w") as fp: 305 | fp.write(json.dumps(predictions)) 306 | 307 | 308 | ############# Visualizations #################: 309 | # frames = [] 310 | # for slice in tqdm(slice_bboxes): 311 | # frames.append(show_image(img, [slice], [''])) 312 | # # if sid > 40: 313 | # # break 314 | 315 | # gif.save(frames, "overlapping_patches.gif", 316 | # duration=15) 317 | 318 | 319 | if __name__ == "__main__": 320 | import argparse 321 | parser = argparse.ArgumentParser() 322 | 323 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.") 324 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.") 325 | parser.add_argument("-s", "--split", help="Dataset split to use for prediction ('test' or 'val').") 326 | parser.add_argument("--img_size", help="Original image size.") 327 | parser.add_argument("--input_size", help="Image size of input to network.") 328 | parser.add_argument("--batch_size", help="Batch size to network.") 329 | args = parser.parse_args() 330 | 331 | main(args) 332 | 333 | -------------------------------------------------------------------------------- /predict_spacenet_coco_val_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | from tqdm import tqdm 5 | import numpy as np 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | from functools import partial 11 | import torch 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torchvision.utils import make_grid 14 | import albumentations as A 15 | from albumentations.pytorch import ToTensorV2 16 | 17 | from test_config import CFG 18 | from tokenizer import Tokenizer 19 | from utils import ( 20 | seed_everything, 21 | load_checkpoint, 22 | test_generate, 23 | postprocess, 24 | permutations_to_polygons, 25 | ) 26 | from models.model import ( 27 | Encoder, 28 | Decoder, 29 | EncoderDecoder 30 | ) 31 | 32 | from torch.utils.data import DataLoader 33 | from datasets.dataset_spacenet_coco import SpacenetCocoDataset_val 34 | from torch.nn.utils.rnn import pad_sequence 35 | from torchmetrics.classification import BinaryJaccardIndex, BinaryAccuracy 36 | from torch import distributed as dist 37 | import torch.multiprocessing 38 | torch.multiprocessing.set_sharing_strategy("file_system") 39 | 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("-d", "--dataset", help="Dataset to use for evaluation.") 43 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.") 44 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.") 45 | parser.add_argument("-o", "--output_dir", help="Name of output subdirectory to store part predictions.") 46 | args = parser.parse_args() 47 | 48 | 49 | torch.backends.cuda.matmul.allow_tf32 = True 50 | torch.backends.cudnn.allow_tf32 = True 51 | 52 | DATASET = f"{args.dataset}" 53 | VAL_DATASET_DIR = f"./data/{DATASET}/val" 54 | # PART_DESC = "val_images" 55 | PART_DESC = f"{args.output_dir}" 56 | 57 | EXPERIMENT_NAME = os.path.basename(os.path.realpath(args.experiment_path)) 58 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{args.checkpoint_name}.pth" 59 | BATCH_SIZE = 24 60 | 61 | 62 | def bounding_box_from_points(points): 63 | points = np.array(points).flatten() 64 | even_locations = np.arange(points.shape[0]/2) * 2 65 | odd_locations = even_locations + 1 66 | X = np.take(points, even_locations.tolist()) 67 | Y = np.take(points, odd_locations.tolist()) 68 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()] 69 | bbox = [int(b) for b in bbox] 70 | return bbox 71 | 72 | 73 | def single_annotation(image_id, poly): 74 | _result = {} 75 | _result["image_id"] = int(image_id) 76 | _result["category_id"] = 100 77 | _result["score"] = 1 78 | _result["segmentation"] = poly 79 | _result["bbox"] = bounding_box_from_points(_result["segmentation"]) 80 | return _result 81 | 82 | 83 | def collate_fn(batch, max_len, pad_idx): 84 | """ 85 | if max_len: 86 | the sequences will all be padded to that length. 87 | """ 88 | 89 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch = [], [], [], [], [], [] 90 | for image, mask, c_mask, seq, perm_mat, idx in batch: 91 | image_batch.append(image) 92 | mask_batch.append(mask) 93 | coords_mask_batch.append(c_mask) 94 | coords_seq_batch.append(seq) 95 | perm_matrix_batch.append(perm_mat) 96 | idx_batch.append(idx) 97 | 98 | coords_seq_batch = pad_sequence( 99 | coords_seq_batch, 100 | padding_value=pad_idx, 101 | batch_first=True 102 | ) 103 | 104 | if max_len: 105 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long() 106 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1) 107 | 108 | image_batch = torch.stack(image_batch) 109 | mask_batch = torch.stack(mask_batch) 110 | coords_mask_batch = torch.stack(coords_mask_batch) 111 | perm_matrix_batch = torch.stack(perm_matrix_batch) 112 | idx_batch = torch.stack(idx_batch) 113 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch 114 | 115 | 116 | def main(): 117 | seed_everything(42) 118 | 119 | valid_transforms = A.Compose( 120 | [ 121 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), 122 | A.Normalize( 123 | mean=[0.0, 0.0, 0.0], 124 | std=[1.0, 1.0, 1.0], 125 | max_pixel_value=255.0 126 | ), 127 | ToTensorV2(), 128 | ], 129 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False) 130 | ) 131 | 132 | tokenizer = Tokenizer( 133 | num_classes=1, 134 | num_bins=CFG.NUM_BINS, 135 | width=CFG.INPUT_WIDTH, 136 | height=CFG.INPUT_HEIGHT, 137 | max_len=CFG.MAX_LEN 138 | ) 139 | CFG.PAD_IDX = tokenizer.PAD_code 140 | 141 | val_ds = SpacenetCocoDataset_val( 142 | cfg=CFG, 143 | dataset_dir=VAL_DATASET_DIR, 144 | transform=valid_transforms, 145 | tokenizer=tokenizer, 146 | shuffle_tokens=CFG.SHUFFLE_TOKENS 147 | ) 148 | val_loader = DataLoader( 149 | val_ds, 150 | batch_size=BATCH_SIZE, 151 | collate_fn=partial(collate_fn, max_len=CFG.MAX_LEN, pad_idx=CFG.PAD_IDX), 152 | num_workers=2 153 | ) 154 | 155 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256) 156 | decoder = Decoder( 157 | cfg=CFG, 158 | vocab_size=tokenizer.vocab_size, 159 | encoder_len=CFG.NUM_PATCHES, 160 | dim=256, 161 | num_heads=8, 162 | num_layers=6 163 | ) 164 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) 165 | model.to(CFG.DEVICE) 166 | model.eval() 167 | 168 | checkpoint = torch.load(CHECKPOINT_PATH) 169 | model.load_state_dict(checkpoint['state_dict']) 170 | epoch = checkpoint['epochs_run'] 171 | 172 | print(f"Model loaded from epoch: {epoch}") 173 | ckpt_desc = f"epoch_{epoch}" 174 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH): 175 | ckpt_desc = f"epoch_{epoch}_bestValLoss" 176 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH): 177 | ckpt_desc = f"epoch_{epoch}_bestValMetric" 178 | else: 179 | pass 180 | 181 | mean_iou_metric = BinaryJaccardIndex() 182 | mean_acc_metric = BinaryAccuracy() 183 | 184 | 185 | with torch.no_grad(): 186 | cumulative_miou = [] 187 | cumulative_macc = [] 188 | speed = [] 189 | predictions = [] 190 | for i_batch, (x, y_mask, y_corner_mask, y, y_perm, idx) in enumerate(tqdm(val_loader)): 191 | all_coords = [] 192 | all_confs = [] 193 | t0 = time.time() 194 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1) 195 | speed.append(time.time() - t0) 196 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer) 197 | 198 | all_coords.extend(vertex_coords) 199 | all_confs.extend(confs) 200 | 201 | coords = [] 202 | for i in range(len(all_coords)): 203 | if all_coords[i] is not None: 204 | coord = torch.from_numpy(all_coords[i]) 205 | else: 206 | coord = torch.tensor([]) 207 | 208 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) 209 | coord = torch.cat([coord, padd], dim=0) 210 | coords.append(coord) 211 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224] 212 | # pred_polygons = permutations_to_polygons(perm_preds, coords, out='coco') # [0, 224] 213 | 214 | for ip, pp in enumerate(batch_polygons): 215 | for p in pp: 216 | p = torch.fliplr(p) 217 | p = p[p[:, 0] != CFG.PAD_IDX] 218 | p = p * (CFG.IMG_SIZE / CFG.INPUT_WIDTH) 219 | p = p.view(-1).tolist() 220 | if len(p) > 0: 221 | predictions.append(single_annotation(idx[ip], [p])) 222 | 223 | B, C, H, W = x.shape 224 | 225 | polygons_mask = np.zeros((B, 1, H, W)) 226 | for b in range(len(batch_polygons)): 227 | for c in range(len(batch_polygons[b])): 228 | poly = batch_polygons[b][c] 229 | poly = poly[poly[:, 0] != CFG.PAD_IDX] 230 | cnt = np.flip(np.int32(poly.cpu()), 1) 231 | if len(cnt) > 0: 232 | cv2.fillPoly(polygons_mask[b, 0], pts=[cnt], color=1.) 233 | polygons_mask = torch.from_numpy(polygons_mask) 234 | 235 | batch_miou = mean_iou_metric(polygons_mask, y_mask) 236 | batch_macc = mean_acc_metric(polygons_mask, y_mask) 237 | 238 | cumulative_miou.append(batch_miou) 239 | cumulative_macc.append(batch_macc) 240 | 241 | pred_grid = make_grid(polygons_mask).permute(1, 2, 0) 242 | gt_grid = make_grid(y_mask).permute(1, 2, 0) 243 | plt.subplot(211), plt.imshow(pred_grid) ,plt.title("Predicted Polygons") ,plt.axis('off') 244 | plt.subplot(212), plt.imshow(gt_grid) ,plt.title("Ground Truth") ,plt.axis('off') 245 | 246 | if not os.path.exists(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc)): 247 | os.makedirs(os.path.join(f"runs/{EXPERIMENT_NAME}", 'val_preds', DATASET, PART_DESC, ckpt_desc)) 248 | plt.savefig(f"runs/{EXPERIMENT_NAME}/val_preds/{DATASET}/{PART_DESC}/{ckpt_desc}/batch_{i_batch}.png") 249 | plt.close() 250 | 251 | print("Average model speed: ", np.mean(speed) / BATCH_SIZE, " [s / image]") 252 | 253 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}") 254 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}") 255 | 256 | with open(f"runs/{EXPERIMENT_NAME}/predictions_{DATASET}_{PART_DESC}_{ckpt_desc}.json", "w") as fp: 257 | fp.write(json.dumps(predictions)) 258 | 259 | with open(f"runs/{EXPERIMENT_NAME}/val_metrics_{DATASET}_{PART_DESC}_{ckpt_desc}.txt", 'w') as ff: 260 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}", file=ff) 261 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}", file=ff) 262 | 263 | 264 | if __name__ == "__main__": 265 | main() 266 | 267 | -------------------------------------------------------------------------------- /predict_whu_buildings_coco_test_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | from tqdm import tqdm 5 | import numpy as np 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | from functools import partial 11 | import torch 12 | from torchvision.utils import make_grid 13 | import albumentations as A 14 | from albumentations.pytorch import ToTensorV2 15 | 16 | from test_config import CFG 17 | from tokenizer import Tokenizer 18 | from utils import ( 19 | seed_everything, 20 | load_checkpoint, 21 | test_generate, 22 | postprocess, 23 | permutations_to_polygons, 24 | ) 25 | from models.model import ( 26 | Encoder, 27 | Decoder, 28 | EncoderDecoder 29 | ) 30 | 31 | from torch.utils.data import DataLoader 32 | from datasets.dataset_whu_buildings_coco import WHUBuildingsCocoDataset_val 33 | from torch.nn.utils.rnn import pad_sequence 34 | from torchmetrics.classification import BinaryJaccardIndex, BinaryAccuracy 35 | import torch.multiprocessing 36 | torch.multiprocessing.set_sharing_strategy("file_system") 37 | 38 | 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("-d", "--dataset", help="Dataset to use for evaluation.") 41 | parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate.") 42 | parser.add_argument("-c", "--checkpoint_name", help="Choice of checkpoint to evaluate in experiment.") 43 | parser.add_argument("-o", "--output_dir", help="Name of output subdirectory to store part predictions.") 44 | args = parser.parse_args() 45 | 46 | 47 | torch.backends.cuda.matmul.allow_tf32 = True 48 | torch.backends.cudnn.allow_tf32 = True 49 | 50 | DATASET = f"{args.dataset}" 51 | VAL_DATASET_DIR = f"./data/{DATASET}/test" 52 | PART_DESC = f"{args.output_dir}" 53 | 54 | EXPERIMENT_NAME = os.path.basename(os.path.realpath(args.experiment_path)) 55 | CHECKPOINT_PATH = f"runs/{EXPERIMENT_NAME}/logs/checkpoints/{args.checkpoint_name}.pth" 56 | BATCH_SIZE = 24 57 | 58 | 59 | def bounding_box_from_points(points): 60 | points = np.array(points).flatten() 61 | even_locations = np.arange(points.shape[0]/2) * 2 62 | odd_locations = even_locations + 1 63 | X = np.take(points, even_locations.tolist()) 64 | Y = np.take(points, odd_locations.tolist()) 65 | bbox = [X.min(), Y.min(), X.max()-X.min(), Y.max()-Y.min()] 66 | bbox = [int(b) for b in bbox] 67 | return bbox 68 | 69 | 70 | def single_annotation(image_id, poly): 71 | _result = {} 72 | _result["image_id"] = int(image_id) 73 | _result["category_id"] = 100 74 | _result["score"] = 1 75 | _result["segmentation"] = poly 76 | _result["bbox"] = bounding_box_from_points(_result["segmentation"]) 77 | return _result 78 | 79 | 80 | def collate_fn(batch, max_len, pad_idx): 81 | """ 82 | if max_len: 83 | the sequences will all be padded to that length. 84 | """ 85 | 86 | image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch = [], [], [], [], [], [] 87 | for image, mask, c_mask, seq, perm_mat, idx in batch: 88 | image_batch.append(image) 89 | mask_batch.append(mask) 90 | coords_mask_batch.append(c_mask) 91 | coords_seq_batch.append(seq) 92 | perm_matrix_batch.append(perm_mat) 93 | idx_batch.append(idx) 94 | 95 | coords_seq_batch = pad_sequence( 96 | coords_seq_batch, 97 | padding_value=pad_idx, 98 | batch_first=True 99 | ) 100 | 101 | if max_len: 102 | pad = torch.ones(coords_seq_batch.size(0), max_len - coords_seq_batch.size(1)).fill_(pad_idx).long() 103 | coords_seq_batch = torch.cat([coords_seq_batch, pad], dim=1) 104 | 105 | image_batch = torch.stack(image_batch) 106 | mask_batch = torch.stack(mask_batch) 107 | coords_mask_batch = torch.stack(coords_mask_batch) 108 | perm_matrix_batch = torch.stack(perm_matrix_batch) 109 | idx_batch = torch.stack(idx_batch) 110 | return image_batch, mask_batch, coords_mask_batch, coords_seq_batch, perm_matrix_batch, idx_batch 111 | 112 | 113 | def main(): 114 | seed_everything(42) 115 | 116 | valid_transforms = A.Compose( 117 | [ 118 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), 119 | A.Normalize( 120 | mean=[0.0, 0.0, 0.0], 121 | std=[1.0, 1.0, 1.0], 122 | max_pixel_value=255.0 123 | ), 124 | ToTensorV2(), 125 | ], 126 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False) 127 | ) 128 | 129 | tokenizer = Tokenizer( 130 | num_classes=1, 131 | num_bins=CFG.NUM_BINS, 132 | width=CFG.INPUT_WIDTH, 133 | height=CFG.INPUT_HEIGHT, 134 | max_len=CFG.MAX_LEN 135 | ) 136 | CFG.PAD_IDX = tokenizer.PAD_code 137 | 138 | val_ds = WHUBuildingsCocoDataset_val( 139 | cfg=CFG, 140 | dataset_dir=VAL_DATASET_DIR, 141 | transform=valid_transforms, 142 | tokenizer=tokenizer, 143 | shuffle_tokens=CFG.SHUFFLE_TOKENS 144 | ) 145 | val_loader = DataLoader( 146 | val_ds, 147 | batch_size=BATCH_SIZE, 148 | collate_fn=partial(collate_fn, max_len=CFG.MAX_LEN, pad_idx=CFG.PAD_IDX), 149 | num_workers=2 150 | ) 151 | 152 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256) 153 | decoder = Decoder( 154 | cfg=CFG, 155 | vocab_size=tokenizer.vocab_size, 156 | encoder_len=CFG.NUM_PATCHES, 157 | dim=256, 158 | num_heads=8, 159 | num_layers=6 160 | ) 161 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) 162 | model.to(CFG.DEVICE) 163 | model.eval() 164 | 165 | checkpoint = torch.load(CHECKPOINT_PATH) 166 | model.load_state_dict(checkpoint['state_dict']) 167 | epoch = checkpoint['epochs_run'] 168 | 169 | print(f"Model loaded from epoch: {epoch}") 170 | ckpt_desc = f"epoch_{epoch}" 171 | if "best_valid_loss" in os.path.basename(CHECKPOINT_PATH): 172 | ckpt_desc = f"epoch_{epoch}_bestValLoss" 173 | elif "best_valid_metric" in os.path.basename(CHECKPOINT_PATH): 174 | ckpt_desc = f"epoch_{epoch}_bestValMetric" 175 | else: 176 | pass 177 | 178 | mean_iou_metric = BinaryJaccardIndex() 179 | mean_acc_metric = BinaryAccuracy() 180 | 181 | 182 | with torch.no_grad(): 183 | cumulative_miou = [] 184 | cumulative_macc = [] 185 | speed = [] 186 | predictions = [] 187 | for i_batch, (x, y_mask, y_corner_mask, y, y_perm, idx) in enumerate(tqdm(val_loader)): 188 | all_coords = [] 189 | all_confs = [] 190 | t0 = time.time() 191 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1) 192 | speed.append(time.time() - t0) 193 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer) 194 | 195 | all_coords.extend(vertex_coords) 196 | all_confs.extend(confs) 197 | 198 | coords = [] 199 | for i in range(len(all_coords)): 200 | if all_coords[i] is not None: 201 | coord = torch.from_numpy(all_coords[i]) 202 | else: 203 | coord = torch.tensor([]) 204 | 205 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) 206 | coord = torch.cat([coord, padd], dim=0) 207 | coords.append(coord) 208 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # [0, 224] 209 | # pred_polygons = permutations_to_polygons(perm_preds, coords, out='coco') # [0, 224] 210 | 211 | for ip, pp in enumerate(batch_polygons): 212 | for p in pp: 213 | p = torch.fliplr(p) 214 | p = p[p[:, 0] != CFG.PAD_IDX] 215 | p = p * (CFG.IMG_SIZE / CFG.INPUT_WIDTH) 216 | p = p.view(-1).tolist() 217 | if len(p) > 0: 218 | predictions.append(single_annotation(idx[ip], [p])) 219 | 220 | B, C, H, W = x.shape 221 | 222 | polygons_mask = np.zeros((B, 1, H, W)) 223 | for b in range(len(batch_polygons)): 224 | for c in range(len(batch_polygons[b])): 225 | poly = batch_polygons[b][c] 226 | poly = poly[poly[:, 0] != CFG.PAD_IDX] 227 | cnt = np.flip(np.int32(poly.cpu()), 1) 228 | if len(cnt) > 0: 229 | cv2.fillPoly(polygons_mask[b, 0], pts=[cnt], color=1.) 230 | polygons_mask = torch.from_numpy(polygons_mask) 231 | 232 | batch_miou = mean_iou_metric(polygons_mask, y_mask) 233 | batch_macc = mean_acc_metric(polygons_mask, y_mask) 234 | 235 | cumulative_miou.append(batch_miou) 236 | cumulative_macc.append(batch_macc) 237 | 238 | pred_grid = make_grid(polygons_mask).permute(1, 2, 0) 239 | gt_grid = make_grid(y_mask).permute(1, 2, 0) 240 | plt.subplot(211), plt.imshow(pred_grid) ,plt.title("Predicted Polygons") ,plt.axis('off') 241 | plt.subplot(212), plt.imshow(gt_grid) ,plt.title("Ground Truth") ,plt.axis('off') 242 | 243 | if not os.path.exists(os.path.join(f"runs/{EXPERIMENT_NAME}", 'test_preds', DATASET, PART_DESC, ckpt_desc)): 244 | os.makedirs(os.path.join(f"runs/{EXPERIMENT_NAME}", 'test_preds', DATASET, PART_DESC, ckpt_desc)) 245 | plt.savefig(f"runs/{EXPERIMENT_NAME}/test_preds/{DATASET}/{PART_DESC}/{ckpt_desc}/batch_{i_batch}.png") 246 | plt.close() 247 | 248 | print("Average model speed: ", np.mean(speed) / BATCH_SIZE, " [s / image]") 249 | 250 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}") 251 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}") 252 | 253 | with open(f"runs/{EXPERIMENT_NAME}/predictions_{DATASET}_{PART_DESC}_{ckpt_desc}.json", "w") as fp: 254 | fp.write(json.dumps(predictions)) 255 | 256 | with open(f"runs/{EXPERIMENT_NAME}/test_metrics_{DATASET}_{PART_DESC}_{ckpt_desc}.txt", 'w') as ff: 257 | print(f"Average Mean IOU: {torch.tensor(cumulative_miou).nanmean()}", file=ff) 258 | print(f"Average Mean Acc: {torch.tensor(cumulative_macc).nanmean()}", file=ff) 259 | 260 | 261 | if __name__ == "__main__": 262 | main() 263 | 264 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "typeCheckingMode": "off" 3 | } 4 | -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from config import CFG 4 | 5 | 6 | class Tokenizer: 7 | def __init__(self, num_classes: int, num_bins: int, width: int, height: int, max_len=256): 8 | self.num_classes = num_classes 9 | self.num_bins = num_bins 10 | self.width = width 11 | self.height = height 12 | self.max_len = max_len 13 | 14 | self.BOS_code = num_bins 15 | self.EOS_code = self.BOS_code + 1 16 | self.PAD_code = self.EOS_code + 1 17 | 18 | self.vocab_size = num_bins + 3 #+ num_classes 19 | 20 | def quantize(self, x: np.array): 21 | """ 22 | x is a real number in [0, 1] 23 | """ 24 | 25 | return (x * (self.num_bins - 1)).round(0).astype('int') 26 | 27 | def dequantize(self, x: np.array): 28 | """ 29 | x is an integer between [0, num_bins-1] 30 | """ 31 | 32 | return x.astype('float32') / (self.num_bins - 1) 33 | 34 | def __call__(self, coords: np.array, shuffle=True): 35 | 36 | if len(coords) > 0: 37 | coords[:, 0] = coords[:, 0] / self.width 38 | coords[:, 1] = coords[:, 1] / self.height 39 | 40 | coords = self.quantize(coords)[:self.max_len] 41 | 42 | if shuffle: 43 | rand_idxs = np.arange(0, len(coords)) 44 | if 'debug' in CFG.EXPERIMENT_NAME: 45 | rand_idxs = rand_idxs[::-1] 46 | else: 47 | np.random.shuffle(rand_idxs) 48 | coords = coords[rand_idxs] 49 | else: 50 | rand_idxs = np.arange(0, len(coords)) 51 | 52 | tokenized = [self.BOS_code] 53 | for coord in coords: 54 | tokens = list(coord) 55 | 56 | tokenized.extend(list(map(int, tokens))) 57 | tokenized.append(self.EOS_code) 58 | 59 | return tokenized, rand_idxs 60 | 61 | def decode(self, tokens: torch.Tensor): 62 | """ 63 | tokens: torch.LongTensor with shape [L] 64 | """ 65 | 66 | mask = tokens != self.PAD_code 67 | tokens = tokens[mask] 68 | tokens = tokens[1:-1] 69 | assert len(tokens) % 2 == 0, "Invalid tokens!" 70 | 71 | coords = [] 72 | for i in range(2, len(tokens)+1, 2): 73 | coord = tokens[i-2: i] 74 | coords.append([int(item) for item in coord]) 75 | coords = np.array(coords) 76 | coords = self.dequantize(coords) 77 | 78 | if len(coords) > 0: 79 | coords[:, 0] = coords[:, 0] * self.width 80 | coords[:, 1] = coords[:, 1] * self.height 81 | 82 | return coords 83 | 84 | -------------------------------------------------------------------------------- /train_ddp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path as osp 3 | import torch 4 | from torch import nn 5 | from torch import optim 6 | import albumentations as A 7 | from albumentations.pytorch import ToTensorV2 8 | from transformers import ( 9 | get_linear_schedule_with_warmup, 10 | ) 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | from config import CFG 14 | from tokenizer import Tokenizer 15 | from utils import ( 16 | seed_everything, 17 | load_checkpoint, 18 | ) 19 | from ddp_utils import ( 20 | get_inria_loaders, 21 | get_spacenet_loaders, 22 | get_whu_buildings_loaders, 23 | get_mass_roads_loaders, 24 | ) 25 | 26 | from models.model import ( 27 | Encoder, 28 | Decoder, 29 | EncoderDecoder 30 | ) 31 | 32 | from engine import train_eval 33 | 34 | from torch import distributed as dist 35 | import torch.multiprocessing 36 | torch.multiprocessing.set_sharing_strategy("file_system") 37 | 38 | 39 | def init_distributed(): 40 | 41 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs. 42 | dist_url = "env://" # default 43 | 44 | # only works with torch.distributed.launch or torch.run. 45 | rank = int(os.environ["RANK"]) 46 | world_size = int(os.environ["WORLD_SIZE"]) 47 | local_rank = int(os.environ["LOCAL_RANK"]) 48 | dist.init_process_group( 49 | backend="nccl", 50 | init_method=dist_url, 51 | world_size=world_size, 52 | rank=rank 53 | ) 54 | 55 | # this will make all .cuda() calls work properly. 56 | torch.cuda.set_device(local_rank) 57 | 58 | # synchronizes all threads to reach this point before moving on. 59 | dist.barrier() 60 | 61 | 62 | def main(): 63 | # setup the process groups 64 | init_distributed() 65 | seed_everything(42) 66 | 67 | # Define tensorboard for logging. 68 | writer = SummaryWriter(f"runs/{CFG.EXPERIMENT_NAME}/logs/tensorboard") 69 | attrs = vars(CFG) 70 | with open(f"runs/{CFG.EXPERIMENT_NAME}/config.txt", "w") as f: 71 | print("\n".join("%s: %s" % item for item in attrs.items()), file=f) 72 | 73 | train_transforms = A.Compose( 74 | [ 75 | A.Affine(rotate=[-360, 360], fit_output=True, p=0.8), # scaled rotations are performed before resizing to ensure rotated and scaled images are correctly resized. 76 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), 77 | A.RandomRotate90(p=1.), 78 | A.RandomBrightnessContrast(p=0.5), 79 | A.ColorJitter(), 80 | A.ToGray(p=0.4), 81 | A.GaussNoise(), 82 | # ToTensorV2 of albumentations doesn't divide by 255 like in PyTorch, 83 | # it is done inside Normalize function. 84 | A.Normalize( 85 | mean=[0.0, 0.0, 0.0], 86 | std=[1.0, 1.0, 1.0], 87 | max_pixel_value=255.0 88 | ), 89 | ToTensorV2(), 90 | ], 91 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False) 92 | ) 93 | 94 | valid_transforms = A.Compose( 95 | [ 96 | A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), 97 | A.Normalize( 98 | mean=[0.0, 0.0, 0.0], 99 | std=[1.0, 1.0, 1.0], 100 | max_pixel_value=255.0 101 | ), 102 | ToTensorV2(), 103 | ], 104 | keypoint_params=A.KeypointParams(format='yx', remove_invisible=False) 105 | ) 106 | 107 | if "debug" in CFG.EXPERIMENT_NAME: 108 | train_transforms = valid_transforms 109 | 110 | tokenizer = Tokenizer( 111 | num_classes=1, 112 | num_bins=CFG.NUM_BINS, 113 | width=CFG.INPUT_WIDTH, 114 | height=CFG.INPUT_HEIGHT, 115 | max_len=CFG.MAX_LEN 116 | ) 117 | CFG.PAD_IDX = tokenizer.PAD_code 118 | 119 | if "inria" in CFG.DATASET: 120 | train_loader, val_loader, _ = get_inria_loaders( 121 | CFG.TRAIN_DATASET_DIR, 122 | CFG.VAL_DATASET_DIR, 123 | CFG.TEST_IMAGES_DIR, 124 | tokenizer, 125 | CFG.MAX_LEN, 126 | tokenizer.PAD_code, 127 | CFG.SHUFFLE_TOKENS, 128 | CFG.BATCH_SIZE, 129 | train_transforms, 130 | valid_transforms, 131 | CFG.NUM_WORKERS, 132 | CFG.PIN_MEMORY, 133 | ) 134 | elif "spacenet" in CFG.DATASET: 135 | train_loader, val_loader, _ = get_spacenet_loaders( 136 | CFG.TRAIN_DATASET_DIR, 137 | CFG.VAL_DATASET_DIR, 138 | CFG.TEST_IMAGES_DIR, 139 | tokenizer, 140 | CFG.MAX_LEN, 141 | tokenizer.PAD_code, 142 | CFG.SHUFFLE_TOKENS, 143 | CFG.BATCH_SIZE, 144 | train_transforms, 145 | valid_transforms, 146 | CFG.NUM_WORKERS, 147 | CFG.PIN_MEMORY, 148 | ) 149 | elif "whu_buildings" in CFG.DATASET: 150 | train_loader, val_loader, _ = get_whu_buildings_loaders( 151 | CFG.TRAIN_DATASET_DIR, 152 | CFG.VAL_DATASET_DIR, 153 | CFG.TEST_IMAGES_DIR, 154 | tokenizer, 155 | CFG.MAX_LEN, 156 | tokenizer.PAD_code, 157 | CFG.SHUFFLE_TOKENS, 158 | CFG.BATCH_SIZE, 159 | train_transforms, 160 | valid_transforms, 161 | CFG.NUM_WORKERS, 162 | CFG.PIN_MEMORY, 163 | ) 164 | elif "mass_roads" in CFG.DATASET: 165 | train_loader, val_loader, test_loader = get_mass_roads_loaders( 166 | CFG.TRAIN_DATASET_DIR, 167 | CFG.VAL_DATASET_DIR, 168 | CFG.TEST_IMAGES_DIR, 169 | tokenizer, 170 | CFG.MAX_LEN, 171 | tokenizer.PAD_code, 172 | CFG.SHUFFLE_TOKENS, 173 | CFG.BATCH_SIZE, 174 | train_transforms, 175 | valid_transforms, 176 | CFG.NUM_WORKERS, 177 | CFG.PIN_MEMORY, 178 | ) 179 | else: 180 | pass 181 | 182 | encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256) 183 | decoder = Decoder( 184 | cfg=CFG, 185 | vocab_size=tokenizer.vocab_size, 186 | encoder_len=CFG.NUM_PATCHES, 187 | dim=256, 188 | num_heads=8, 189 | num_layers=6 190 | ) 191 | model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) 192 | model.to(CFG.DEVICE) 193 | 194 | weight = torch.ones(CFG.PAD_IDX + 1, device=CFG.DEVICE) 195 | weight[tokenizer.num_bins:tokenizer.BOS_code] = 0.0 196 | vertex_loss_fn = nn.CrossEntropyLoss(ignore_index=CFG.PAD_IDX, label_smoothing=CFG.LABEL_SMOOTHING, weight=weight) 197 | perm_loss_fn = nn.BCELoss() 198 | 199 | optimizer = optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY, betas=(0.9, 0.95)) 200 | 201 | num_training_steps = CFG.NUM_EPOCHS * (len(train_loader.dataset) // CFG.BATCH_SIZE // torch.cuda.device_count()) 202 | num_warmup_steps = int(0.05 * num_training_steps) 203 | lr_scheduler = get_linear_schedule_with_warmup( 204 | optimizer, 205 | num_training_steps=num_training_steps, 206 | num_warmup_steps=num_warmup_steps 207 | ) 208 | 209 | local_rank = int(os.environ["LOCAL_RANK"]) 210 | CFG.START_EPOCH = 0 211 | if CFG.LOAD_MODEL: 212 | checkpoint_name = osp.basename(osp.realpath(CFG.CHECKPOINT_PATH)) 213 | map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank} 214 | start_epoch = load_checkpoint( 215 | torch.load(f"runs/{CFG.EXPERIMENT_NAME}/logs/checkpoints/{checkpoint_name}", map_location=map_location), 216 | model, 217 | optimizer, 218 | lr_scheduler 219 | ) 220 | CFG.START_EPOCH = start_epoch + 1 221 | dist.barrier() 222 | 223 | # Convert BatchNorm in model to SyncBatchNorm. 224 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 225 | # Wrap model with distributed data parallel. 226 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) 227 | 228 | train_eval( 229 | model, 230 | train_loader, 231 | val_loader, 232 | val_loader, 233 | tokenizer, 234 | vertex_loss_fn, 235 | perm_loss_fn, 236 | optimizer, 237 | lr_scheduler=lr_scheduler, 238 | step='batch', 239 | writer=writer 240 | ) 241 | 242 | 243 | if __name__ == "__main__": 244 | main() 245 | 246 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import numpy as np 5 | from scipy.optimize import linear_sum_assignment 6 | 7 | import torch 8 | import torchvision 9 | from transformers import top_k_top_p_filtering 10 | from torchmetrics.functional.classification import binary_jaccard_index, binary_accuracy 11 | from config import CFG 12 | 13 | 14 | def seed_everything(seed=1234): 15 | random.seed(seed) 16 | os.environ['PYTHONHASHSEED'] = str(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | def save_checkpoint(state, folder="logs/checkpoint/run1", filename="my_checkpoint.pth.tar"): 24 | print("=> Saving checkpoint") 25 | if not os.path.exists(folder): 26 | os.makedirs(folder) 27 | torch.save(state, os.path.join(folder, filename)) 28 | 29 | 30 | def load_checkpoint(checkpoint, model, optimizer, scheduler): 31 | print("=> Loading checkpoint") 32 | model.load_state_dict(checkpoint["state_dict"]) 33 | optimizer.load_state_dict(checkpoint["optimizer"]) 34 | scheduler.load_state_dict(checkpoint["scheduler"]) 35 | 36 | return checkpoint["epochs_run"] 37 | 38 | 39 | def generate_square_subsequent_mask(sz): 40 | mask = ( 41 | torch.triu(torch.ones((sz, sz), device=CFG.DEVICE)) == 1 42 | ).transpose(0, 1) 43 | 44 | mask = mask.float().masked_fill(mask==0, float('-inf')).masked_fill(mask==1, float(0.0)) 45 | 46 | return mask 47 | 48 | 49 | def create_mask(tgt, pad_idx): 50 | """ 51 | tgt shape: (N, L) 52 | """ 53 | 54 | tgt_seq_len = tgt.size(1) 55 | tgt_mask = generate_square_subsequent_mask(tgt_seq_len) 56 | tgt_padding_mask = (tgt == pad_idx) 57 | 58 | return tgt_mask, tgt_padding_mask 59 | 60 | 61 | class AverageMeter: 62 | def __init__(self, name="Metric"): 63 | self.name = name 64 | self.reset() 65 | 66 | def reset(self): 67 | self.avg, self.sum, self.count = [0]*3 68 | 69 | def update(self, val, count=1): 70 | self.count += count 71 | self.sum += val * count 72 | self.avg = self.sum / self.count 73 | 74 | def __repr__(self) -> str: 75 | text = f"{self.name}: {self.avg:.4f}" 76 | return text 77 | 78 | 79 | def get_lr(optimizer): 80 | for param_group in optimizer.param_groups: 81 | return param_group['lr'] 82 | 83 | 84 | def scores_to_permutations(scores): 85 | """ 86 | Input a batched array of scores and returns the hungarian optimized 87 | permutation matrices. 88 | """ 89 | B, N, N = scores.shape 90 | 91 | scores = scores.detach().cpu().numpy() 92 | perm = np.zeros_like(scores) 93 | for b in range(B): 94 | r, c = linear_sum_assignment(-scores[b]) 95 | perm[b,r,c] = 1 96 | return torch.tensor(perm) 97 | 98 | 99 | # TODO: add permalink to polyworld repo 100 | def permutations_to_polygons(perm, graph, out='torch'): 101 | B, N, N = perm.shape 102 | device = perm.device 103 | 104 | def bubble_merge(poly): 105 | s = 0 106 | P = len(poly) 107 | while s < P: 108 | head = poly[s][-1] 109 | 110 | t = s+1 111 | while t < P: 112 | tail = poly[t][0] 113 | if head == tail: 114 | poly[s] = poly[s] + poly[t][1:] 115 | del poly[t] 116 | poly = bubble_merge(poly) 117 | P = len(poly) 118 | t += 1 119 | s += 1 120 | return poly 121 | 122 | diag = torch.logical_not(perm[:,range(N),range(N)]) 123 | batch = [] 124 | for b in range(B): 125 | b_perm = perm[b] 126 | b_graph = graph[b] 127 | b_diag = diag[b] 128 | 129 | idx = torch.arange(N, device=perm.device)[b_diag] 130 | 131 | if idx.shape[0] > 0: 132 | # If there are vertices in the batch 133 | 134 | b_perm = b_perm[idx,:] 135 | b_graph = b_graph[idx,:] 136 | b_perm = b_perm[:,idx] 137 | 138 | first = torch.arange(idx.shape[0]).unsqueeze(1).to(device=device) 139 | second = torch.argmax(b_perm, dim=1).unsqueeze(1) 140 | 141 | polygons_idx = torch.cat((first, second), dim=1).tolist() 142 | polygons_idx = bubble_merge(polygons_idx) 143 | 144 | batch_poly = [] 145 | for p_idx in polygons_idx: 146 | if out == 'torch': 147 | batch_poly.append(b_graph[p_idx,:]) 148 | elif out == 'numpy': 149 | batch_poly.append(b_graph[p_idx,:].cpu().numpy()) 150 | elif out == 'list': 151 | g = b_graph[p_idx,:] * 300 / 320 152 | g[:,0] = -g[:,0] 153 | g = torch.fliplr(g) 154 | batch_poly.append(g.tolist()) 155 | elif out == 'coco': 156 | g = b_graph[p_idx,:]# * CFG.IMG_SIZE / CFG.INPUT_WIDTH 157 | g = torch.fliplr(g) 158 | batch_poly.append(g.view(-1).tolist()) 159 | elif out == 'inria-torch': 160 | batch_poly.append(b_graph[p_idx,:]) 161 | else: 162 | print("Indicate a valid output polygon format") 163 | exit() 164 | 165 | batch.append(batch_poly) 166 | 167 | else: 168 | # If the batch has no vertices 169 | batch.append([]) 170 | 171 | return batch 172 | 173 | 174 | def test_generate(model, x, tokenizer, max_len=50, top_k=0, top_p=1): 175 | x = x.to(CFG.DEVICE) 176 | batch_preds = torch.ones((x.size(0), 1), device=CFG.DEVICE).fill_(tokenizer.BOS_code).long() 177 | confs = [] 178 | 179 | if top_k != 0 or top_p != 1: 180 | sample = lambda preds: torch.softmax(preds, dim=-1).multinomial(num_samples=1).view(-1, 1) 181 | else: 182 | sample = lambda preds: torch.softmax(preds, dim=-1).argmax(dim=-1).view(-1, 1) 183 | 184 | with torch.no_grad(): 185 | for i in range(max_len): 186 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 187 | preds, feats = model.module.predict(x, batch_preds) 188 | else: 189 | preds, feats = model.predict(x, batch_preds) 190 | preds = top_k_top_p_filtering(preds, top_k=top_k, top_p=top_p) # if top_k and top_p are set to default, this line does nothing. 191 | if i % 2 == 0: 192 | confs_ = torch.softmax(preds, dim=-1).sort(axis=-1, descending=True)[0][:, 0].cpu() 193 | confs.append(confs_) 194 | preds = sample(preds) 195 | batch_preds = torch.cat([batch_preds, preds], dim=1) 196 | 197 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 198 | perm_preds = model.module.scorenet1(feats) + torch.transpose(model.module.scorenet2(feats), 1, 2) 199 | else: 200 | perm_preds = model.scorenet1(feats) + torch.transpose(model.scorenet2(feats), 1, 2) 201 | 202 | perm_preds = scores_to_permutations(perm_preds) 203 | 204 | return batch_preds.cpu(), confs, perm_preds 205 | 206 | 207 | def postprocess(batch_preds, batch_confs, tokenizer): 208 | EOS_idxs = (batch_preds == tokenizer.EOS_code).float().argmax(dim=-1) 209 | ## sanity check 210 | invalid_idxs = ((EOS_idxs - 1) % 2 != 0).nonzero().view(-1) 211 | EOS_idxs[invalid_idxs] = 0 212 | 213 | all_coords = [] 214 | all_confs = [] 215 | for i, EOS_idx in enumerate(EOS_idxs.tolist()): 216 | if EOS_idx == 0: 217 | all_coords.append(None) 218 | all_confs.append(None) 219 | continue 220 | coords = tokenizer.decode(batch_preds[i, :EOS_idx+1]) 221 | confs = [round(batch_confs[j][i].item(), 3) for j in range(len(coords))] 222 | 223 | all_coords.append(coords) 224 | all_confs.append(confs) 225 | 226 | return all_coords, all_confs 227 | 228 | 229 | def save_single_predictions_as_images( 230 | loader, model, tokenizer, epoch, writer, folder="saved_outputs/", device="cuda" 231 | ): 232 | print(f"=> Saving val predictions...") 233 | if not os.path.exists(folder): 234 | print(f"==> Creating output subdirectory...") 235 | os.makedirs(folder) 236 | 237 | model.eval() 238 | 239 | all_coords = [] 240 | all_confs = [] 241 | 242 | with torch.no_grad(): 243 | loader_iterator = iter(loader) 244 | idx, (x, y_mask, y_corner_mask, y, y_perm) = 0, next(loader_iterator) 245 | batch_preds, batch_confs, perm_preds = test_generate(model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1) 246 | vertex_coords, confs = postprocess(batch_preds, batch_confs, tokenizer) 247 | 248 | all_coords.extend(vertex_coords) 249 | all_confs.extend(confs) 250 | 251 | coords = [] 252 | for i in range(len(all_coords)): 253 | if all_coords[i] is not None: 254 | coord = torch.from_numpy(all_coords[i]) 255 | else: 256 | coord = torch.tensor([]) 257 | padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(tokenizer.PAD_code) 258 | coord = torch.cat((coord, padd), dim=0) 259 | coords.append(coord) 260 | batch_polygons = permutations_to_polygons(perm_preds, coords, out='torch') # list of polygon coordinate tensors 261 | 262 | B, C, H, W = x.shape 263 | # Write predicted vertices as mask to disk. 264 | vertex_mask = np.zeros((B, 1, H, W)) 265 | for b in range(len(all_coords)): 266 | if all_coords[b] is not None: 267 | print(f"Vertices found!") 268 | for i in range(len(all_coords[b])): 269 | coord = all_coords[b][i] 270 | cx, cy = coord 271 | cv2.circle(vertex_mask[b, 0], (int(cy), int(cx)), 0, 255, -1) 272 | vertex_mask = torch.from_numpy(vertex_mask) 273 | if not os.path.exists(os.path.join(folder, 'corners_mask')): 274 | os.makedirs(os.path.join(folder, 'corners_mask')) 275 | vertex_pred_vis = torch.zeros_like(x) 276 | for b in range(B): 277 | vertex_pred_vis[b] = torchvision.utils.draw_segmentation_masks( 278 | (x[b]*255).to(dtype=torch.uint8), 279 | torch.zeros_like(x[b, 0]).bool() 280 | ) 281 | vertex_pred_vis = vertex_pred_vis.cpu().numpy().astype(np.uint8) 282 | for b in range(len(all_coords)): 283 | if all_coords[b] is not None: 284 | for i in range(len(all_coords[b])): 285 | coord = all_coords[b][i] 286 | cx, cy = coord 287 | cv2.circle(vertex_pred_vis[b, 0], (int(cy), int(cx)), 3, 255, -1) 288 | vertex_pred_vis = torch.from_numpy(vertex_pred_vis) 289 | torchvision.utils.save_image( 290 | vertex_pred_vis.float()/255, f"{folder}/corners_mask/corners_mask_{b}_{epoch}.png" 291 | ) 292 | 293 | # Write predicted polygons as mask to disk. 294 | polygons = np.zeros((B, 1, H, W)) 295 | for b in range(B): 296 | for c in range(len(batch_polygons[b])): 297 | poly = batch_polygons[b][c] 298 | poly = poly[poly[:, 0] != tokenizer.PAD_code] 299 | cnt = np.flip(np.int32(poly.cpu()), 1) 300 | if len(cnt) > 0: 301 | cv2.fillPoly(polygons[b, 0], pts=[cnt], color=1.) 302 | polygons = torch.from_numpy(polygons) 303 | if not os.path.exists(os.path.join(folder, 'pred_polygons')): 304 | os.makedirs(os.path.join(folder, 'pred_polygons')) 305 | poly_out = torch.zeros_like(x) 306 | for b in range(B): 307 | poly_out[b] = torchvision.utils.draw_segmentation_masks( 308 | (x[b]*255).to(dtype=torch.uint8), 309 | polygons[b, 0].bool() 310 | ) 311 | poly_out = poly_out.cpu().numpy().astype(np.uint8) 312 | for b in range(len(all_coords)): 313 | if all_coords[b] is not None: 314 | for i in range(len(all_coords[b])): 315 | coord = all_coords[b][i] 316 | cx, cy = coord 317 | cv2.circle(poly_out[b, 0], (int(cy), int(cx)), 2, 255, -1) 318 | poly_out = torch.from_numpy(poly_out) 319 | torchvision.utils.save_image( 320 | poly_out.float()/255, f"{folder}/pred_polygons/polygons_{idx}_{epoch}.png" 321 | ) 322 | 323 | batch_miou = binary_jaccard_index(polygons, y_mask) 324 | batch_biou = binary_jaccard_index(polygons, y_mask, ignore_index=0) 325 | batch_macc = binary_accuracy(polygons, y_mask) 326 | batch_bacc = binary_accuracy(polygons, y_mask, ignore_index=0) 327 | 328 | writer.add_scalar('Val_Metrics/Mean_IoU', batch_miou, epoch) 329 | writer.add_scalar('Val_Metrics/Building_IoU', batch_biou, epoch) 330 | writer.add_scalar('Val_Metrics/Mean_Accuracy', batch_macc, epoch) 331 | writer.add_scalar('Val_Metrics/Building_Accuracy', batch_bacc, epoch) 332 | 333 | metrics_dict = { 334 | "miou": batch_miou, 335 | "biou": batch_biou, 336 | "macc": batch_macc, 337 | "bacc": batch_bacc 338 | } 339 | 340 | torchvision.utils.save_image(x, f"{folder}/image_{idx}.png") 341 | ymask_out = torch.zeros_like(x) 342 | for b in range(B): 343 | ymask_out[b] = torchvision.utils.draw_segmentation_masks( 344 | (x[b]*255).to(dtype=torch.uint8), 345 | y_mask[b, 0].bool() 346 | ) 347 | ymask_out = ymask_out.cpu().numpy().astype(np.uint8) 348 | gt_corner_coords, _ = postprocess(y, batch_confs, tokenizer) 349 | for b in range(B): 350 | for corner in gt_corner_coords[b]: 351 | cx, cy = corner 352 | cv2.circle(ymask_out[b, 0], (int(cy), int(cx)), 3, 255, -1) 353 | ymask_out = torch.from_numpy(ymask_out) 354 | torchvision.utils.save_image(ymask_out/255., f"{folder}/gt_mask_{idx}.png") 355 | torchvision.utils.save_image(y_corner_mask*255, f"{folder}/gt_corners_{idx}.png") 356 | torchvision.utils.save_image(y_perm[:, None, :, :]*255, f"{folder}/gt_perm_matrix_{idx}.png") 357 | 358 | return metrics_dict 359 | --------------------------------------------------------------------------------