├── LICENSE ├── README.md ├── configs ├── augmentations │ ├── augmentations.json │ ├── detection_augmentations.json │ └── point_augmentations.json ├── configs.json ├── datasets │ ├── bigearthnet.json │ ├── fomo_pretraining_datasets.json │ ├── mixed_detection.json │ ├── neontree_detection.json │ ├── reforestree.json │ ├── supervised_foundation_cls.json │ └── tallos.json ├── download │ └── download.json ├── method │ ├── classification │ │ ├── convnext.json │ │ ├── resnet50.json │ │ └── vit.json │ ├── detection │ │ ├── fasterrcnn.json │ │ ├── retinanet.json │ │ └── yolos.json │ ├── mae │ │ ├── mae_multi.json │ │ └── mae_single.json │ └── segmentation │ │ ├── deeplab.json │ │ ├── unet50.json │ │ └── upernet.json ├── stats │ └── stats.json └── training │ ├── detection │ ├── detection_finetuning_fasterrcnn_mixte_neontree.json │ ├── detection_finetuning_fasterrcnn_mixte_neontree.json~ │ ├── detection_finetuning_fasterrcnn_mixte_reforestree.json │ ├── detection_finetuning_fasterrcnn_mixte_reforestree.json~ │ └── detection_training_fasterrcnn_mixed.json │ ├── mae │ └── mae_training.json │ └── training_example.json ├── datasets ├── BigEarthNetDataset.py ├── FLAIR2Dataset.py ├── FLAIRDataset.py ├── FORinstanceDataset.py ├── FiveBillionPixelsDataset.py ├── ForestNetDataset.py ├── MixedDetectionDataset.py ├── MixedPCSegDataset.py ├── NeonTreeDataset.py ├── RapidAI4EODataset.py ├── ReforesTreeDataset.py ├── SSL4EOLDataset.py ├── Sen12MSDataset.py ├── TreeSatAIDataset.py ├── WaitituDataset.py ├── WoodyDataset.py └── __init__.py ├── downloader.py ├── downloading_scripts ├── bigearthnet.sh ├── flair.sh ├── flair_2.sh ├── forestnet.sh ├── neontree.sh ├── rapidai4eo.sh ├── sen12ms.sh ├── spekboom.sh ├── ssl4eos1s2.sh ├── treesat.sh ├── waititu.sh ├── wildforest.sh └── woody.sh ├── examples ├── pretrained_fomo_example.py └── pretrained_fomobench_example.py ├── main.py ├── model_zoo ├── faster_rcnn │ ├── __pycache__ │ │ ├── faster_rcnn.cpython-310.pyc │ │ └── generalized_rcnn.cpython-310.pyc │ ├── faster_rcnn.py │ ├── faster_rcnn.py~ │ ├── generalized_rcnn.py │ └── generalized_rcnn.py~ ├── multimodal_mae.py ├── point_transformer.py ├── pointnet.py ├── pointnet2.py ├── segformer.py ├── upernet.py └── yolov5 │ ├── __init__.py │ ├── __init__.py~ │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── general.cpython-310.pyc │ ├── loss.cpython-310.pyc │ ├── metrics.cpython-310.pyc │ └── torch_utils.cpython-310.pyc │ ├── general.py │ ├── general.py~ │ ├── hyp.scratch-low.yaml │ ├── loss.py │ ├── loss.py~ │ ├── metrics.py │ ├── metrics.py~ │ ├── torch_utils.py │ └── torch_utils.py~ ├── requirements.txt ├── training ├── classification.py ├── detection.py ├── mae_training.py ├── point_segmentation.py └── segmentation.py └── utilities ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── augmentations.cpython-310.pyc ├── distributed_utils.cpython-310.pyc ├── model_utilities.cpython-310.pyc ├── utils.cpython-310.pyc └── webdataset_writer.cpython-310.pyc ├── augmentations.py ├── detection_datasets ├── tilerize_neontree.py └── tilerize_reforestree.py ├── distributed_utils.py ├── model_utilities.py ├── pointcloud_datasets ├── tilerize_forinstance.py └── tilerize_neontree.py ├── tilerizer.py ├── utils.py └── webdataset_writer.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 RolnickLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/augmentations/augmentations.json: -------------------------------------------------------------------------------- 1 | { 2 | "augmentations": { 3 | "Resize":{ 4 | "value":224, 5 | "p":0.0 6 | }, 7 | "RandomResizedCrop": { 8 | "value": 128, 9 | "scale":[0.2, 1.0], 10 | "interpolation":3, 11 | "p": 1.0 12 | }, 13 | "GaussianBlur": { 14 | "value": [ 15 | 0.1, 16 | 2.0 17 | ], 18 | "p": 0.0 19 | }, 20 | "HorizontalFlip": { 21 | "value": "", 22 | "p": 0.5 23 | }, 24 | "VerticalFlip": { 25 | "value": "", 26 | "p": 0.0 27 | }, 28 | "RandomRotation":{ 29 | "p":0.0 30 | }, 31 | "ElasticTransform": { 32 | "value": "", 33 | "p": 0.0 34 | }, 35 | "Cutout": { 36 | "p": 0.0 37 | }, 38 | "MultNoise": { 39 | "p": 0.0 40 | }, 41 | "GaussianNoise": { 42 | "p": 0.0 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /configs/augmentations/detection_augmentations.json: -------------------------------------------------------------------------------- 1 | { 2 | "augmentations": { 3 | "Resize":{ 4 | "value":224, 5 | // "value":64, 6 | "p":1.0 7 | }, 8 | "RandomResizedCrop": { 9 | "value": 224, 10 | "scale":[0.2, 1.0], 11 | "interpolation":3, 12 | "p": 0.0 13 | }, 14 | "GaussianBlur": { 15 | "value": [ 16 | 0.1, 17 | 2.0 18 | ], 19 | "p": 0.0 20 | }, 21 | "HorizontalFlip": { 22 | "value": "", 23 | "p": 0.5 24 | }, 25 | "VerticalFlip": { 26 | "value": "", 27 | "p": 0.5 28 | }, 29 | "RandomRotation":{ 30 | "p": 0.0 31 | }, 32 | "ElasticTransform": { 33 | "value": "", 34 | "p": 0.0 35 | }, 36 | "Cutout": { 37 | "p": 0.0 38 | }, 39 | "MultNoise": { 40 | "p": 0.0 41 | }, 42 | "GaussianNoise": { 43 | "p": 0.2 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /configs/augmentations/point_augmentations.json: -------------------------------------------------------------------------------- 1 | { 2 | "augmentations": { 3 | //"SamplePoints":{ 4 | // "num":1024, 5 | // "remove_faces":true, 6 | // "include_normals":false 7 | //}, 8 | "RandomJitter": { 9 | "translate": 0.01 10 | }, 11 | "RandomRotate_x": { 12 | "degrees": 15 13 | }, 14 | "RandomRotate_y": { 15 | "degrees": 15 16 | }, 17 | "RandomRotate_z": { 18 | "degrees": 15 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /configs/configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset":"bigearthnet", // Options: All 
available datasets 3 | "wandb":false, // Use wandb for logging 4 | "wandb_project":"YourWandbProject", // Wandb project name 5 | "wandb_entity":"YourWandbEntity", // Wandb entity name 6 | "wandb_id_resume":null, // provide the run id to resume wandb logging on the same run. Run id can be found in the checkpoint path as id.json 7 | "phase":"train", // Options: train, test 8 | "eval_checkpoint":null, // Checkpoint to evaluate on when phase=test. If null the last checkpoint will be used. If not null, the evaluation will be based on the provided checkpoint path. 9 | "device":"cuda", 10 | "mixed_precision":true, // Mixed precision training 11 | "seed":999, // Seed for reproducibility 12 | "webdataset":true, // Use webdataset for data loading. If webdataset format is not available at webdataset_root_path, the webdataset will be created. 13 | "webdataset_parallel":true, // Use parallel processes for webdataset creation 14 | "webdataset_write_processes":32, // Number of processes for writing webdataset 15 | "webdataset_shuffle_size": 500,// 1000, 16 | "webdataset_initial_buffer":500,// 1000, 17 | "max_samples_per_shard": 256, //set upper limit 256 samples per shard 18 | "max_sample_resolution":null, // Store samples to shards with fixed resolution e.g Height x Width : 64x64. ONLY FOR PRETRAINING! The labels are deliberately not adapted and the resulting webdataset can only be used when the path is provided in the next option. For finetuning the resolution can be adapted at the augmentations.json 19 | "webdataset_root_path":"your_webdataset_root_path", // Where the webdataset will be stored. If null the webdataset will be stored at the project's root directory. 20 | "checkpoint_root_path":null, //root directory for the checkpoint paths. If null checkpoints will be stored at the project's root directory. 21 | } -------------------------------------------------------------------------------- /configs/datasets/bigearthnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"your_root", 3 | "task":"classification", // Possible Tasks: Depending on the dataset, 4 | "metrics": ["accuracy","fscore","map"], //Desired metrics to log 5 | "num_classes":19, 6 | "in_channels":14, 7 | "multilabel":true, 8 | "meta_info":"Data collected in Europe." 9 | } -------------------------------------------------------------------------------- /configs/datasets/fomo_pretraining_datasets.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets":["ssl4eol","rapidai4eo","tallos","flair","uav","fivebillionpixels"], 3 | "dataset":"all", 4 | "dataset_probabilities":[0.2,0.2,0.2,0.1,0.1,0.2], //list of probablities to sample the dataloader. 
If none all datasets have the same probability 5 | "task":"mae", 6 | "modality_channels": { 7 | "0": "planet-r", 8 | "1": "planet-g", 9 | "2": "planet-b", 10 | "3": "planet-nir", 11 | "4": "sentinel-1-vv", 12 | "5":"sentinel-1-vh", 13 | "6":"sentinel-2-b1", 14 | "7":"sentinel-2-b2", 15 | "8":"sentinel-2-b3", 16 | "9":"sentinel-2-b4", 17 | "10":"sentinel-2-b5", 18 | "11":"sentinel-2-b6", 19 | "12":"sentinel-2-b7", 20 | "13":"sentinel-2-b8", 21 | "14":"sentinel-2-b8a", 22 | "15":"sentinel-2-b9", 23 | "16":"sentinel-2-b10", 24 | "17":"sentinel-2-b11", 25 | "18":"sentinel-2-b12", 26 | "19":"landsat-r", 27 | "20":"landsat-g", 28 | "21":"landsat-b", 29 | "22":"landsat-nir", 30 | "23":"landsat-swir-1", 31 | "24":"landsat-swir-2", 32 | "25":"landsat-panchromatic", 33 | "26":"landsat-aerosol", 34 | "27":"landsat-cirrus", 35 | "28":"aerial-r", 36 | "29":"aerial-g", 37 | "30":"aerial-b", 38 | "31":"aerial-nir", 39 | "32":"dem", 40 | "33":"gaofen2-r", 41 | "34":"gaofen2-g", 42 | "35":"gaofen2-b", 43 | "36":"gaofen2-nir" 44 | }, 45 | "dataset_modality_index":{ 46 | "rapidai4eo":{ 47 | "planet-r":0, 48 | "planet-g":1, 49 | "planet-b":2, 50 | "planet-nir":3, 51 | "sentinel-2-b2":4, 52 | "sentinel-2-b3":5, 53 | "sentinel-2-b4":6, 54 | "sentinel-2-b8":7, 55 | "sentinel-2-b5":8, 56 | "sentinel-2-b6":9, 57 | "sentinel-2-b7":10, 58 | "sentinel-2-b8a":11, 59 | "sentinel-2-b11":12, 60 | "sentinel-2-b1":13, 61 | "sentinel-2-b9":14 62 | }, 63 | "ssl4eos1s2":{ 64 | "sentinel-1-vv":0, 65 | "sentinel-1-vh":1, 66 | "sentinel-2-b1":2, 67 | "sentinel-2-b2":3, 68 | "sentinel-2-b3":4, 69 | "sentinel-2-b4":5, 70 | "sentinel-2-b5":6, 71 | "sentinel-2-b6":7, 72 | "sentinel-2-b7":8, 73 | "sentinel-2-b8":9, 74 | "sentinel-2-b8a":10, 75 | "sentinel-2-b9":11, 76 | "sentinel-2-b11":12, 77 | "sentinel-2-b12":13, 78 | 79 | }, 80 | "ssl4eol":{ 81 | "landsat-aerosol":0, 82 | "landsat-r":3, 83 | "landsat-g":2, 84 | "landsat-b":1, 85 | "landsat-nir":4, 86 | "landsat-swir-1":5, 87 | "landsat-swir-2":6, 88 | "landsat-panchromatic":7, 89 | "landsat-cirrus":8 90 | }, 91 | "spekboom":{ 92 | "aerial-r":0, 93 | "aerial-g":1, 94 | "aerial-b":2, 95 | }, 96 | "waititu":{ 97 | "aerial-r":0, 98 | "aerial-g":1, 99 | "aerial-b":2 100 | }, 101 | "flair":{ 102 | "aerial-r":0, 103 | "aerial-g":1, 104 | "aerial-b":2, 105 | "aerial-nir":3, 106 | }, 107 | "tallos":{ 108 | "sentinel-2-b2":0, 109 | "sentinel-2-b3":1, 110 | "sentinel-2-b4":2, 111 | "sentinel-2-b5":3, 112 | "sentinel-2-b6":4, 113 | "sentinel-2-b7":5, 114 | "sentinel-2-b8":6, 115 | "sentinel-2-b8a":7, 116 | "sentinel-2-b11":8, 117 | "sentinel-2-b12":9, 118 | "sentinel-1-vv":10, 119 | "sentinel-1-vh":11, 120 | "dem":12 121 | }, 122 | "uav":{ 123 | "aerial-r":0, 124 | "aerial-g":1, 125 | "aerial-b":2 126 | }, 127 | "fivebillionpixels":{ 128 | "gaofen2-r":0, 129 | "gaofen2-g":1, 130 | "gaofen2-b":2, 131 | "gaofen2-nir":3 132 | }, 133 | "bigearthnet":{ 134 | "sentinel-1-vv":0, 135 | "sentinel-1-vh":1, 136 | "sentinel-2-b2":2, 137 | "sentinel-2-b3":3, 138 | "sentinel-2-b4":4, 139 | "sentinel-2-b5":5, 140 | "sentinel-2-b6":6, 141 | "sentinel-2-b7":7, 142 | "sentinel-2-b8":8, 143 | "sentinel-2-b8a":9, 144 | "sentinel-2-b9":10, 145 | "sentinel-2-b11":11, 146 | "sentinel-2-b12":12, 147 | } 148 | }, 149 | "split_workers":false, 150 | } -------------------------------------------------------------------------------- /configs/datasets/mixed_detection.json: -------------------------------------------------------------------------------- 1 | {"dataset_names":"neontree_detection,reforestree", 2 | 
"root_path_neontree":"/network/projects/fomobench/Datasets/NeonTree", 3 | "root_path_reforestree":"/network/projects/fomobench/Datasets/ReforesTree", 4 | "task":"detection", // Possible Tasks: Depending on the dataset, 5 | "segmentation_task":null, // Set to false if segmentation is not required 6 | "download":false, // Option to download the dataset 7 | "checksum": false, // check the MD5 of the downloaded files 8 | "num_classes":7, 9 | "in_channels":3, 10 | "modality":"rgb", 11 | "metrics":["iou","ciou","map"], 12 | // "metrics":["accuracy","fscore"], 13 | "meta_info":"" 14 | } 15 | -------------------------------------------------------------------------------- /configs/datasets/neontree_detection.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"/network/scratch/a/arthur.ouaknine/data/NeonTree", 3 | "task":"detection", // Possible Tasks: Depending on the dataset, 4 | "segmentation_task":null, // Set to false if segmentation is not required 5 | "num_classes":1, 6 | "in_channels":3, 7 | "modality":"rgb", 8 | "metrics":["iou","ciou","map"], 9 | // "metrics":["accuracy","fscore"], 10 | "meta_info":"" 11 | } 12 | -------------------------------------------------------------------------------- /configs/datasets/reforestree.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"/home/mila/a/arthur.ouaknine/scratch/data/fomobench/ReforesTree", 3 | "download":false, // Option to download the dataset 4 | "checksum": false, // check the MD5 of the downloaded files 5 | "task":"detection", 6 | "segmentation_task":null, // segmentation not required here 7 | "num_classes":6, 8 | "in_channels":3, 9 | // "metrics": ["accuracy","fscore"], //Desired metrics to log 10 | "metrics":["iou","ciou","map"], 11 | "meta_info":"" 12 | } 13 | -------------------------------------------------------------------------------- /configs/datasets/supervised_foundation_cls.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets":["rapidai4eo","tallos","bigearthnet","sen12ms"], 3 | "task":"supervised_foundation", 4 | "num_classes":1000, 5 | "in_channels": 9, 6 | "total_num_classes":{ 7 | "rapidai4eo":26, 8 | "tallos":1364, 9 | "bigearthnet":19, 10 | "sen12ms":17, 11 | "treesatai":15 12 | }, 13 | "modality_channels": { 14 | "0":"sentinel-2-b2", 15 | "1":"sentinel-2-b3", 16 | "2":"sentinel-2-b4", 17 | "3":"sentinel-2-b5", 18 | "4":"sentinel-2-b6", 19 | "5":"sentinel-2-b7", 20 | "6":"sentinel-2-b8", 21 | "7":"sentinel-2-b8a", 22 | "8":"sentinel-2-b11", 23 | }, 24 | "dataset_modality_index":{ 25 | "tallos":{ 26 | "sentinel-2-b2":0, 27 | "sentinel-2-b3":1, 28 | "sentinel-2-b4":2, 29 | "sentinel-2-b5":3, 30 | "sentinel-2-b6":4, 31 | "sentinel-2-b7":5, 32 | "sentinel-2-b8":6, 33 | "sentinel-2-b8a":7, 34 | "sentinel-2-b11":8, 35 | }, 36 | "bigearthnet":{ 37 | "sentinel-2-b2":2, 38 | "sentinel-2-b3":3, 39 | "sentinel-2-b4":4, 40 | "sentinel-2-b5":5, 41 | "sentinel-2-b6":6, 42 | "sentinel-2-b7":7, 43 | "sentinel-2-b8":8, 44 | "sentinel-2-b8a":9, 45 | "sentinel-2-b11":12, 46 | }, 47 | "sen12ms":{ 48 | "sentinel-2-b2":3, 49 | "sentinel-2-b3":4, 50 | "sentinel-2-b4":5, 51 | "sentinel-2-b5":6, 52 | "sentinel-2-b6":7, 53 | "sentinel-2-b7":8, 54 | "sentinel-2-b8":9, 55 | "sentinel-2-b8a":10, 56 | "sentinel-2-b11":13, 57 | }, 58 | "rapidai4eo":{ 59 | "sentinel-2-b2":4, 60 | "sentinel-2-b3":5, 61 | "sentinel-2-b4":6, 62 | "sentinel-2-b8":7, 63 | "sentinel-2-b5":8, 64 | 
"sentinel-2-b6":9, 65 | "sentinel-2-b7":10, 66 | "sentinel-2-b8a":11, 67 | "sentinel-2-b11":12, 68 | }, 69 | "treesatai":{ 70 | "sentinel-2-b2":0, 71 | "sentinel-2-b3":1, 72 | "sentinel-2-b4":2, 73 | "sentinel-2-b8":3, 74 | "sentinel-2-b5":4, 75 | "sentinel-2-b6":5, 76 | "sentinel-2-b7":6, 77 | "sentinel-2-b8a":7, 78 | "sentinel-2-b11":8, 79 | } 80 | }, 81 | "split_workers":false, 82 | "max_number_iterations":1280000 83 | } -------------------------------------------------------------------------------- /configs/datasets/tallos.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path":"YourRoot", 3 | "task":"classification", // Possible Tasks: Depending on the dataset, 4 | "metrics": ["accuracy","fscore"], //Desired metrics to log 5 | "num_classes":1364, 6 | "input_mode":"single_image", 7 | "split_path":"configs/datasets/tallos_split.json", 8 | "class_level":"genus", 9 | "multilabel":true, 10 | "in_channels":13, // Could include options to select season or ROI. 11 | "meta_info":"Global coverage." 12 | } 13 | -------------------------------------------------------------------------------- /configs/download/download.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset":"rapidai4eo", // Dataset name 3 | "root_path":"", // Where to save the data. 4 | } 5 | -------------------------------------------------------------------------------- /configs/method/classification/convnext.json: -------------------------------------------------------------------------------- 1 | { 2 | "backbone":"convnext_base.fb_in1k", 3 | "pretrained":true 4 | } -------------------------------------------------------------------------------- /configs/method/classification/resnet50.json: -------------------------------------------------------------------------------- 1 | { 2 | "backbone":"resnet50",//"convnext_base.fb_in1k", 3 | "pretrained":true 4 | } -------------------------------------------------------------------------------- /configs/method/classification/vit.json: -------------------------------------------------------------------------------- 1 | { 2 | "backbone":"vit_base_patch16_224.augreg_in21k", 3 | "pretrained":true 4 | } 5 | -------------------------------------------------------------------------------- /configs/method/detection/fasterrcnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "architecture":"fasterrcnn", 3 | "backbone":"resnet50", 4 | "pretrained":true 5 | } 6 | -------------------------------------------------------------------------------- /configs/method/detection/retinanet.json: -------------------------------------------------------------------------------- 1 | { 2 | "architecture":"retinanet", 3 | "backbone":"resnet50", 4 | "pretrained":true 5 | } 6 | -------------------------------------------------------------------------------- /configs/method/detection/yolos.json: -------------------------------------------------------------------------------- 1 | { 2 | "architecture":"yolos", 3 | "backbone":"", 4 | "pretrained":true 5 | } 6 | -------------------------------------------------------------------------------- /configs/method/mae/mae_multi.json: -------------------------------------------------------------------------------- 1 | { 2 | "spectral_mae":true, 3 | "image_size":128, 4 | "patch_size":16, 5 | "single_embedding_layer":false, 6 | "num_classes":1000, 7 | "dim":1024, 8 | "depth":12, 9 | "heads":16, 10 | "mlp_dim":2048, 11 | 
"masked_ratio":0.75, 12 | "decoder_dim":512, // decoder dim. Can be smaller than encoder dim. 13 | "decoder_depth":8, 14 | "decoder_heads":16, 15 | "num_samples_per_epoch":1280000, // Number of iterations for each epoch. Final number of steps = num_steps_per_epoch * epochs 16 | "log_reconstruction_every":100, // How often to log reconstruction for RGB data 17 | "accumulate_gradients":4, // If not null, gradients will be accumulated among N (value of "accumulate_gradients") batches, effectively optimizing towards a common direction for multiple modalities 18 | } -------------------------------------------------------------------------------- /configs/method/mae/mae_single.json: -------------------------------------------------------------------------------- 1 | { 2 | "spectral_mae":true, 3 | "image_size":128, 4 | "patch_size":16, 5 | "single_embedding_layer":true, 6 | "num_classes":1000, 7 | "dim":1024, 8 | "depth":12, 9 | "heads":16, 10 | "mlp_dim":2048, 11 | "masked_ratio":0.75, 12 | "decoder_dim":512, // decoder dim. Can be smaller than encoder dim. 13 | "decoder_depth":8, 14 | "decoder_heads":16, 15 | "num_samples_per_epoch":1280000, // Number of iterations for each epoch. Final number of steps = num_steps_per_epoch * epochs 16 | "log_reconstruction_every":100, // How often to log reconstruction for RGB data 17 | "accumulate_gradients":4, // If not null, gradients will be accumulated among N (value of "accumulate_gradients") batches, effectively optimizing towards a common direction for multiple modalities 18 | } -------------------------------------------------------------------------------- /configs/method/segmentation/deeplab.json: -------------------------------------------------------------------------------- 1 | { 2 | "architecture":"deeplab", 3 | "backbone":"resnet50", 4 | "pretrained":true, 5 | "output_size":64, //only used for finetuning SSL segmentation models 6 | } -------------------------------------------------------------------------------- /configs/method/segmentation/unet50.json: -------------------------------------------------------------------------------- 1 | { 2 | "architecture":"unet", 3 | "backbone":"resnet50", 4 | "pretrained":true 5 | } -------------------------------------------------------------------------------- /configs/method/segmentation/upernet.json: -------------------------------------------------------------------------------- 1 | { 2 | "architecture":"upernet", 3 | "backbone":"swin_base", 4 | "pretrained":true, 5 | "output_size":64, //only used for finetuning SSL segmentation models 6 | } -------------------------------------------------------------------------------- /configs/stats/stats.json: -------------------------------------------------------------------------------- 1 | { 2 | "bigearthnet":{ 3 | "mean":[-12.619993741972035, -19.29044597721542, 340.76769064, 429.9430203, 614.21682446, 590.23569706, 950.68368468, 1792.46290469, 2075.46795189, 4 | 2218.94553375, 2266.46036911, 2246.0605464, 1594.42694882, 1009.32729131], 5 | "std":[5.115911777546365, 5.464428464912864, 554.81258967, 572.41639287, 582.87945694, 675.88746967, 729.89827633, 1096.01480586, 1273.45393088, 6 | 1365.45589904, 1356.13789355, 1302.3292881, 1079.19066363, 818.86747235] 7 | }, 8 | "cactus":{ 9 | "mean":[119.3761, 115.2592, 128.4035], 10 | "std":[37.4228, 33.8390, 36.8710] 11 | }, 12 | "flair":{ 13 | "mean":[113.7753, 118.0812, 109.2739, 102.3642, 16.6973], 14 | "std":[52.4193, 46.0284, 45.2602, 39.4872, 29.5557] 15 | }, 16 | "flair2":{ 17 | "mean":[ 114.7360, 118.9935, 
110.1762, 102.9594, 16.7306, 1553.1101, 18 | 1706.3533, 1707.6664, 2072.3354, 2694.5828, 2910.1399, 3030.8735, 19 | 3046.7170, 2398.3044, 1809.3027], 20 | "std": [52.7019, 46.2726, 45.6095, 39.6502, 29.7648, 2267.0027, 21 | 2104.7576, 2078.3738, 2071.4827, 1900.3833, 1864.7145, 1924.3833, 22 | 1820.6776, 1362.4669, 1244.8425] 23 | }, 24 | "sen12ms":{ 25 | "mean":[ -11.7630, -18.2950, 1463.9355, 1230.3521, 1141.8270, 1144.4269, 26 | 1356.1825, 1940.9421, 2220.6450, 2163.8037, 2418.8733, 791.9455, 27 | 23.9868, 2005.3367, 1358.3806], 28 | "std":[ 4.5312, 4.3663, 721.8138, 746.7266, 745.9549, 965.9202, 29 | 951.8728, 988.6033, 1085.2832, 1060.7733, 1138.8784, 551.5662, 30 | 34.1710, 1138.1683, 996.9311] 31 | }, 32 | "woody":{ 33 | "mean":[81.3055, 83.2069, 58.9164],//, 175.1896], Remove alpha channel stats 34 | "std":[66.9682, 67.1797, 48.1175],//, 118.2445] 35 | }, 36 | "spekboom":{ 37 | "mean":[93.7223, 90.5660, 77.5175, 164.8042], 38 | "std":[81.9170, 76.7438, 68.8432, 121.9120] 39 | }, 40 | "waititu":{ 41 | "mean":[55.6236, 54.1820, 35.0456],// 42.4444], 42 | "std":[63.1487, 60.9398, 41.7646],// 45.6104] 43 | }, 44 | "tallos":{ 45 | "mean":[1283.97119, 1438.82031, 1378.67847, 1819.70667, 2835.08765, 3202.86938, 3276.03394, 3422.51392, 2305.23926, 1555.97595, 108.98360, 213.39993, 661.94061], 46 | "std":[2028.46753, 1908.08655, 1912.57837, 1931.35449, 1730.66138, 1714.06738, 1679.35315, 1671.14331, 1351.27197, 1257.49890, 52.72569, 115.37871, 792.45972], 47 | }, 48 | "forestnet":{ 49 | "mean":[15.0060, 29.8892, 21.9384], 50 | "std":[11.2357, 12.2747, 14.3598] 51 | }, 52 | "ssl4eol":{ 53 | "mean":[ 98.0528, 89.4049, 88.3023, 98.6711, 156.3476, 148.6828, 117.2159, 54 | 91.0294, 1.3732, 187.0349, 183.8071], 55 | "std":[29.0827, 33.7028, 46.0983, 66.7981, 74.7895, 82.9275, 81.2110, 53.1308, 56 | 4.3644, 46.4231, 44.9355] 57 | }, 58 | "reforestree":{ 59 | "mean":[191.8039, 199.4721, 171.0712], 60 | "std":[58.9717, 54.2983, 72.2315] 61 | }, 62 | "rapidai4eo":{ 63 | "mean":[-599.8201, -428.7419, -353.7449, 1370.6443, 563.6370, 759.2718, 64 | 790.4224, 2340.5337, 1163.8649, 1954.0909, 2226.9436, 2425.1335, 65 | 1818.0320, 1213.2483, 455.0158, 2433.6865], 66 | "std":[5890.3970, 5925.5083, 5950.9106, 6311.4429, 955.2887, 938.5280, 67 | 1012.0714, 1369.0884, 1030.0022, 1203.9899, 1325.6887, 1377.0955, 68 | 1131.5851, 948.6902, 920.5815, 1329.4309] 69 | }, 70 | "ssl4eos1s2":{ 71 | "mean":[-12.54847273, -20.19237134,752.40087073, 884.29673756, 1144.16202635, 1297.47289228, 1624.90992062, 2194.6423161, 2422.21248945, 2517.76053101, 2581.64687018, 2645.51888987, 2368.51236873, 1805.06846033], 72 | "std":[5.25697717, 5.91150917,1108.02887453, 1155.15170768, 1183.6292542, 1368.11351514, 1370.265037, 1355.55390699, 1416.51487101, 1474.78900051, 1439.3086061, 1582.28010962, 1455.52084939, 1343.48379601] 73 | }, 74 | "uav":{ 75 | "mean":[112.7954, 112.4210, 94.9576], 76 | "std":[87.6614, 86.0243, 80.2605] 77 | }, 78 | "fivebillionpixels":{ 79 | "mean":[417.4763, 321.7472, 247.5841, 295.8605], 80 | "std":[107.1275, 90.3392, 96.7788, 144.8598] 81 | }, 82 | "neontree":{ 83 | "mean":[152.454, 123.325, 144.287], 84 | "std":[44.944, 30.832, 38.879] 85 | }, 86 | "treesatai":{ 87 | "mean":[ 237.94443, 384.66238, 253.62634, 2762.06128, 625.80261, 2074.89673, 2651.58423, 2923.11133, 1307.51001, 600.52789, 253.16537, 2926.36133],//[160.28419, 91.28082, 84.00340, 75.07034], 88 | "std": [137.14885, 160.49547, 178.52460, 798.59344, 233.07007, 542.35138, 721.96539, 790.62885, 462.26263, 300.07938, 127.59349, 
717.23358],//[54.13877, 32.30051, 27.13653, 31.13494] 89 | }, 90 | "marida":{ 91 | "mean":[0.05197577, 0.04783991, 0.04056812, 0.03163572, 0.02972606, 0.03457443, 92 | 0.03875053, 0.03436435, 0.0392113, 0.02358126, 0.01588816], 93 | "std":[0.04725893, 0.04743808, 0.04699043, 0.04967381, 0.04946782, 0.06458357, 94 | 0.07594915, 0.07120246, 0.08251058, 0.05111466, 0.03524419], 95 | }, 96 | "treesat_aerial":{ 97 | "mean":[162.3750, 97.8018, 89.1425, 84.2689], 98 | "std":[53.3237, 32.6658, 29.0721, 30.4791] 99 | }, 100 | "treesat_sen200m":{ 101 | "mean":[ 237.9059, 384.5531, 253.3916, 2761.8970, 625.5364, 2074.8232, 102 | 2651.5908, 2923.0964, 1306.9153, 600.0505, 253.4281, 2926.2104], 103 | "std":[137.0299, 160.0558, 177.5254, 799.4852, 232.1957, 542.7112, 722.8084, 104 | 791.5959, 461.6962, 298.6560, 127.5402, 718.0933] 105 | }, 106 | "mlrsnet":{ 107 | "mean":[0.4006, 0.4121, 0.3878], 108 | "std":[0.2108, 0.1918, 0.1942] 109 | }, 110 | "all":{ 111 | //Placeholder. Stats will be replaced by dataset-specific statistics on runtime. 112 | "mean":[], 113 | "std":[], 114 | }, 115 | "supervised_foundation_cls":{ 116 | //Placeholder. Stats will be replaced by dataset-specific statistics on runtime. 117 | "mean":[], 118 | "std":[], 119 | }, 120 | "super_fomo":{ 121 | //Placeholder. Stats will be replaced by dataset-specific statistics on runtime. 122 | "mean":[], 123 | "std":[], 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /configs/training/detection/detection_finetuning_fasterrcnn_mixte_neontree.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs":200, 3 | "batch_size":64, 4 | "num_workers":2, 5 | "lr":1e-4, 6 | "weight_decay":5e-4, 7 | "momentum":0.9, 8 | "loss":"cross_entropy", 9 | "optimizer":"sgd", 10 | "lr_step":8000, 11 | "normalization":"other", // Options: standard (0 mean, 1 std), min-max (min 0, max 1), none 12 | "augment":true, 13 | "metric_aggregation_strategy":[null], // Options, micro, macro, none for per-class calculation 14 | "distributed":false, 15 | "pin_memory":true, 16 | "prefetch_factor":2, 17 | "linear_evaluation":false, // set to true to evaluate a pretrained model 18 | "finetune_backbone":false, 19 | "pretrained_model_path":"/network/projects/fomobench/checkpoints/detection/mixed_detection/fasterrcnn/resnet50/comic-yogurt-14/best_model.pt", // resnet_star 20 | // "pretrained_model_path":"/network/scratch/a/arthur.ouaknine/FoMo-Bench/checkpoints/detection/reforestree/fasterrcnn/resnet50/dancing-dumpling-10/best_model.pt", // pretrained on Reforestree ONLY 21 | // "checkpoint_path":"/network/projects/fomobench/checkpoints/detection/neontree_detection/fasterrcnn/resnet50_star/likely-night-7", // For testing (evaluation) ONLY 22 | "change_finetuning_resolution":null, 23 | "has_collate_fn":true, // indicates if dataloader has a collate_fn method 24 | "det_format":"pascal_voc", // define format for bounding boxes: pascal_voc, coco or yolo 25 | "resume_checkpoint_path":null, // checkpoint to resume training 26 | "enforce_resize":null, // if not none, data will be resized in both train/evaluation to the provided value (only for webdataset setting), 27 | "change_finetuning_resolution":null, // If not none, the finetuning vit will have its positional embedding changed to support the new res 28 | "finetuning_patch_size":null, // In accordance to finetuning resolution. 
Ideally keep it the same as in training 29 | "iou_thresh": 0.1, 30 | "conf_thresh": 0., 31 | "eval_metric":"map", 32 | "modality_channels": { 33 | "28": 0, 34 | "29": 1, 35 | "30": 2, 36 | }, 37 | "dataset_modality_index":{ 38 | "reforestree":{ 39 | "28":0, 40 | "29":1, 41 | "30":2, 42 | }, 43 | }, 44 | "output_size":224, //Generate embedding from FoMo 45 | "out_channels":256, //Generate embedding from FoMo 46 | "change_finetuning_resolution":224, 47 | "finetuning_patch_size":16 48 | } 49 | -------------------------------------------------------------------------------- /configs/training/detection/detection_finetuning_fasterrcnn_mixte_neontree.json~: -------------------------------------------------------------------------------- 1 | { 2 | "epochs":200, 3 | "batch_size":64, 4 | "num_workers":2, 5 | "lr":1e-4, 6 | "weight_decay":5e-4, 7 | "momentum":0.9, 8 | "loss":"cross_entropy", 9 | "optimizer":"sgd", 10 | "lr_step":8000, 11 | "normalization":"other", // Options: standard (0 mean, 1 std), min-max (min 0, max 1), none 12 | "augment":true, 13 | "metric_aggregation_strategy":[null], // Options, micro, macro, none for per-class calculation 14 | "distributed":false, 15 | "pin_memory":true, 16 | "prefetch_factor":2, 17 | "linear_evaluation":true, // set to true to evaluate a pretrained model 18 | "finetune_backbone":false, 19 | "pretrained_model_path":"/network/projects/fomobench/checkpoints/detection/mixed_detection/fasterrcnn/resnet50/comic-yogurt-14/best_model.pt", // resnet_star 20 | // "pretrained_model_path":"/network/scratch/a/arthur.ouaknine/FoMo-Bench/checkpoints/detection/reforestree/fasterrcnn/resnet50/dancing-dumpling-10/best_model.pt", // pretrained on Reforestree ONLY 21 | // "checkpoint_path":"/network/projects/fomobench/checkpoints/detection/neontree_detection/fasterrcnn/resnet50_star/likely-night-7", // For testing (evaluation) ONLY 22 | "change_finetuning_resolution":null, 23 | "has_collate_fn":true, // indicates if dataloader has a collate_fn method 24 | "det_format":"pascal_voc", // define format for bounding boxes: pascal_voc, coco or yolo 25 | "resume_checkpoint_path":null, // checkpoint to resume training 26 | "enforce_resize":null, // if not none, data will be resized in both train/evaluation to the provided value (only for webdataset setting), 27 | "change_finetuning_resolution":null, // If not none, the finetuning vit will have its positional embedding changed to support the new res 28 | "finetuning_patch_size":null, // In accordance to finetuning resolution. 
Ideally keep it the same as in training 29 | "iou_thresh": 0.1, 30 | "conf_thresh": 0., 31 | "eval_metric":"map", 32 | "modality_channels": { 33 | "28": 0, 34 | "29": 1, 35 | "30": 2, 36 | }, 37 | "dataset_modality_index":{ 38 | "reforestree":{ 39 | "28":0, 40 | "29":1, 41 | "30":2, 42 | }, 43 | }, 44 | "output_size":224, //Generate embedding from FoMo 45 | "out_channels":256, //Generate embedding from FoMo 46 | "change_finetuning_resolution":224, 47 | "finetuning_patch_size":16 48 | } 49 | -------------------------------------------------------------------------------- /configs/training/detection/detection_finetuning_fasterrcnn_mixte_reforestree.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs":200, 3 | "batch_size":64, 4 | "num_workers":2, 5 | "lr":1e-4, 6 | "weight_decay":5e-4, 7 | "momentum":0.9, 8 | "loss":"cross_entropy", 9 | "optimizer":"sgd", 10 | "lr_step":8000, 11 | "normalization":"other", // Options: standard (0 mean, 1 std), min-max (min 0, max 1), none 12 | "augment":true, 13 | "metric_aggregation_strategy":[null], // Options, micro, macro, none for per-class calculation 14 | "distributed":false, 15 | "pin_memory":true, 16 | "prefetch_factor":2, 17 | "linear_evaluation":false, // set to true to evaluate a pretrained model 18 | "finetune_backbone":false, 19 | "pretrained_model_path":"/network/projects/fomobench/checkpoints/detection/mixed_detection/fasterrcnn/resnet50/comic-yogurt-14/best_model.pt", // resnet_star 20 | // "pretrained_model_path":"/home/mila/a/arthur.ouaknine/scratch/FoMo-Bench/checkpoints/detection/neontree_detection/fasterrcnn/resnet50/unique-field-140/best_model.pt", // pretrained only on NeonTree 21 | // "checkpoint_path":"/home/mila/a/arthur.ouaknine/scratch/FoMo-Bench/checkpoints/detection/reforestree/fasterrcnn/resnet50_star/desert-cloud-16", // Evaluation only 22 | "record_pred":true, // Evaluation only 23 | "visu_pred":true, // Evaluation only 24 | "change_finetuning_resolution":null, 25 | "has_collate_fn":true, // indicates if dataloader has a collate_fn method 26 | "det_format":"pascal_voc", // define format for bounding boxes: pascal_voc, coco or yolo 27 | "resume_checkpoint_path":null, // checkpoint to resume training 28 | "enforce_resize":null, // if not none, data will be resized in both train/evaluation to the provided value (only for webdataset setting), 29 | "change_finetuning_resolution":null, // If not none, the finetuning vit will have its positional embedding changed to support the new res 30 | "finetuning_patch_size":null, // In accordance to finetuning resolution. 
Ideally keep it the same as in training 31 | "iou_thresh": 0.1, 32 | "conf_thresh": 0., 33 | "eval_metric":"map", 34 | "modality_channels": { 35 | "28": 0, 36 | "29": 1, 37 | "30": 2, 38 | }, 39 | "dataset_modality_index":{ 40 | "reforestree":{ 41 | "28":0, 42 | "29":1, 43 | "30":2, 44 | }, 45 | }, 46 | "output_size":224, //Generate embedding from FoMo 47 | "out_channels":256, //Generate embedding from FoMo 48 | "change_finetuning_resolution":224, 49 | "finetuning_patch_size":16 50 | } 51 | -------------------------------------------------------------------------------- /configs/training/detection/detection_finetuning_fasterrcnn_mixte_reforestree.json~: -------------------------------------------------------------------------------- 1 | { 2 | "epochs":200, 3 | "batch_size":64, 4 | "num_workers":2, 5 | "lr":1e-4, 6 | "weight_decay":5e-4, 7 | "momentum":0.9, 8 | "loss":"cross_entropy", 9 | "optimizer":"sgd", 10 | "lr_step":8000, 11 | "normalization":"other", // Options: standard (0 mean, 1 std), min-max (min 0, max 1), none 12 | "augment":true, 13 | "metric_aggregation_strategy":[null], // Options, micro, macro, none for per-class calculation 14 | "distributed":false, 15 | "pin_memory":true, 16 | "prefetch_factor":2, 17 | "linear_evaluation":true, // set to true to evaluate a pretrained model 18 | "finetune_backbone":false, 19 | "pretrained_model_path":"/network/projects/fomobench/checkpoints/detection/mixed_detection/fasterrcnn/resnet50/comic-yogurt-14/best_model.pt", // resnet_star 20 | // "pretrained_model_path":"/home/mila/a/arthur.ouaknine/scratch/FoMo-Bench/checkpoints/detection/neontree_detection/fasterrcnn/resnet50/unique-field-140/best_model.pt", // pretrained only on NeonTree 21 | // "checkpoint_path":"/home/mila/a/arthur.ouaknine/scratch/FoMo-Bench/checkpoints/detection/reforestree/fasterrcnn/resnet50_star/desert-cloud-16", // Evaluation only 22 | "record_pred":true, // Evaluation only 23 | "visu_pred":true, // Evaluation only 24 | "change_finetuning_resolution":null, 25 | "has_collate_fn":true, // indicates if dataloader has a collate_fn method 26 | "det_format":"pascal_voc", // define format for bounding boxes: pascal_voc, coco or yolo 27 | "resume_checkpoint_path":null, // checkpoint to resume training 28 | "enforce_resize":null, // if not none, data will be resized in both train/evaluation to the provided value (only for webdataset setting), 29 | "change_finetuning_resolution":null, // If not none, the finetuning vit will have its positional embedding changed to support the new res 30 | "finetuning_patch_size":null, // In accordance to finetuning resolution. 
Ideally keep it the same as in training 31 | "iou_thresh": 0.1, 32 | "conf_thresh": 0., 33 | "eval_metric":"map", 34 | "modality_channels": { 35 | "28": 0, 36 | "29": 1, 37 | "30": 2, 38 | }, 39 | "dataset_modality_index":{ 40 | "reforestree":{ 41 | "28":0, 42 | "29":1, 43 | "30":2, 44 | }, 45 | }, 46 | "output_size":224, //Generate embedding from FoMo 47 | "out_channels":256, //Generate embedding from FoMo 48 | "change_finetuning_resolution":224, 49 | "finetuning_patch_size":16 50 | } 51 | -------------------------------------------------------------------------------- /configs/training/detection/detection_training_fasterrcnn_mixed.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs": 200, 3 | "batch_size":32, 4 | "num_workers":2, 5 | "lr":1e-2, 6 | "weight_decay":5e-4, 7 | "momentum":0.9, 8 | "loss":"cross_entropy", 9 | "optimizer":"sgd", 10 | "lr_step":500, 11 | "normalization":"other", // Options: standard (0 mean, 1 std), min-max (min 0, max 1), none 12 | "augment":true, 13 | "metric_aggregation_strategy":[null], // Options, micro, macro, none for per-class calculation 14 | "distributed":false, 15 | "pin_memory":true, 16 | "prefetch_factor":2, 17 | "linear_evaluation":false, // set to true to evaluate a pretrained model 18 | "pretrained_model_path":null, //path for the checkpoint to be evaluated under the linear eval protocol 19 | "has_collate_fn":true, // indicates if dataloader has a collate_fn method 20 | "det_format":"pascal_voc", // Format to match for all annotations? 21 | "det_format_neontree_detection":"pascal_voc", // define format for bounding boxes: pascal_voc, coco or yolo 22 | "det_format_reforestree":"pascal_voc", // define format for bounding boxes: pascal_voc, coco or yolo 23 | "resume_checkpoint_path":null, // checkpoint to resume training 24 | "enforce_resize":null, // if not none, data will be resized in both train/evaluation to the provided value (only for webdataset setting), 25 | "change_finetuning_resolution":null, // If not none, the finetuning vit will have its positional embedding changed to support the new res 26 | "finetuning_patch_size":null, // In accordance to finetuning resolution. 
Ideally keep it the same as in training 27 | "iou_thresh": 0.3, 28 | "conf_thresh": 0.3, 29 | "eval_metric":"iou", 30 | } 31 | -------------------------------------------------------------------------------- /configs/training/mae/mae_training.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs":100, 3 | "batch_size":28, 4 | "num_workers":4, 5 | "lr":0.00001, 6 | "weight_decay":0.05, 7 | "schedule":"cos",//options: none, cosine, 8 | "warmup_epochs":20, //epochs to warmup 9 | "min_lr":0, // minimum learning rate bound for cyclic schedulers that hit 0 10 | "loss":"cross_entropy", 11 | "optimizer":"adam", 12 | "normalization":"standard", // Options: standard (0 mean, 1 std), min-max (min 0, max 1), none 13 | "augment":true, 14 | "metric_aggregation_strategy":["micro","macro","weighted"], // Options, micro, macro, none for per-class calculation 15 | "distributed":true, 16 | "pin_memory":true, 17 | "prefetch_factor":2, 18 | "persistent_workers":false, 19 | "linear_evaluation":false, // set to true to evaluate a pretrained model 20 | "pretrained_model_path":null, //path for the checkpoint to be evaluated under the linear eval protocol 21 | "resume_checkpoint_path":null, // checkpoint to resume training 22 | "enforce_resize":null, // if not none, data will be resized in both train/evaluation to the provided value (only for webdataset setting), 23 | "change_finetuning_resolution":null, // If not none, the finetuning vit will have its positional embedding changed to support the new res 24 | "finetuning_patch_size":null, // In accordance to finetuning resolution. Ideally keep it the same as in training 25 | "start_epoch":0, // epoch to continue training 26 | } 27 | -------------------------------------------------------------------------------- /configs/training/training_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "epochs":20, 3 | "batch_size":128, 4 | "num_workers":4, 5 | "lr":1e-3, 6 | "weight_decay":1e-4, 7 | "schedule":"cos", // cosine with warmup 8 | "warmup_epochs":0, //epochs to warmup 9 | "min_lr":0, // minimum learning rate bound for cyclic schedulers that hit 0 10 | "loss":"cross_entropy", 11 | "optimizer":"adam", 12 | "normalization":"standard", // Options: standard (0 mean, 1 std), min-max (min 0, max 1), none 13 | "augment":false, 14 | "metric_aggregation_strategy":["micro","macro","weighted"], // Options, micro, macro, none for per-class calculation 15 | "distributed":false, 16 | "pin_memory":true, 17 | "prefetch_factor":2, 18 | "persistent_workers":false, 19 | "linear_evaluation":false, // set to true to evaluate a pretrained model 20 | "pretrained_model_path":null, //path for the checkpoint to be evaluated under the linear eval protocol 21 | "resume_checkpoint_path":null, // checkpoint to resume training 22 | "enforce_resize":null, // if not none, data will be resized in both train/evaluation to the provided value (only for webdataset setting), 23 | "change_finetuning_resolution":120, // If not none, the finetuning vit will have its positional embedding changed to support the new res 24 | "finetuning_patch_size":16, // In accordance to finetuning resolution. 
Ideally keep it the same as in training 25 | "start_epoch":0, // epoch to continue training 26 | "log_images":false, // log segmentation masks and images to wandb 27 | } -------------------------------------------------------------------------------- /datasets/BigEarthNetDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import albumentations as A 4 | import cv2 as cv 5 | import einops 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | import pyjson5 as json 10 | import rasterio 11 | import torch 12 | from torchvision import transforms 13 | 14 | import utilities 15 | 16 | 17 | class BigEarthNetDataset(torch.utils.data.Dataset): 18 | def __init__(self, configs, mode="train"): 19 | print("=" * 40) 20 | print("Initializing BigEarthNet-MM mode - ", mode) 21 | print("=" * 40) 22 | 23 | self.configs = configs 24 | self.root_path = os.path.join(configs["root_path"], "BigEarthNet") 25 | self.mode = mode 26 | self.s1_path = os.path.join(self.root_path, "BigEarthNet-S1-v1.0") 27 | self.s2_path = os.path.join(self.root_path, "BigEarthNet-v1.0") 28 | mappings_path = os.path.join(self.root_path, "s1s2_mapping.csv") 29 | self.label_indices_path = os.path.join(self.root_path, "label_indices.json") 30 | self.mappings = pd.read_csv(mappings_path, header=None) 31 | print("Shuffling mappings") 32 | self.mappings = self.mappings.sample(frac=1).reset_index(drop=True) 33 | if self.configs["augment"]: 34 | self.augmentations = utilities.augmentations.get_augmentations(configs) 35 | else: 36 | self.augmentations = None 37 | if self.configs["normalization"] == "standard": 38 | self.normalization = transforms.Normalize(mean=self.configs["mean"], std=self.configs["std"]) 39 | 40 | # radar and spectral band names to read related GeoTIFF files 41 | self.band_names_s1 = ["VV", "VH"] 42 | self.band_names_s2 = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12"] 43 | 44 | self.splits_path = os.path.join(self.root_path, mode + ".csv") 45 | self.valid_ids = pd.read_csv(self.splits_path, header=None) 46 | self.mappings = pd.merge(self.mappings, self.valid_ids, how="inner") 47 | print(mode, " samples: ", len(self.mappings)) 48 | 49 | self.total_label = json.load(open(self.label_indices_path, "r")) 50 | self.label_indices = self.total_label["original_labels"] 51 | self.conversion = self.total_label["label_conversion"] 52 | 53 | self.num_examples = len(self.mappings) 54 | 55 | def __len__(self): 56 | return self.num_examples 57 | 58 | def prepare_array(self, patch_dict, resize=120): 59 | patch = None 60 | for key in patch_dict.keys(): 61 | if patch is None: 62 | patch = patch_dict[key] 63 | patch = cv.resize(patch, (resize, resize)) 64 | patch = einops.rearrange(patch, "h w -> 1 h w") 65 | else: 66 | channel = cv.resize(patch_dict[key], (resize, resize)) 67 | channel = einops.rearrange(channel, "h w -> 1 h w") 68 | patch = np.vstack((patch, channel)) 69 | return patch 70 | 71 | def read_all_bands(self, path, patch, bands): 72 | all_bands = {} 73 | for band in bands: 74 | band_path = os.path.join(path, patch + "_" + band + ".tif") 75 | with rasterio.open(band_path) as src: 76 | band_array = src.read().squeeze() 77 | 78 | all_bands[band] = band_array 79 | return all_bands 80 | 81 | def prepare_pairs(self, row): 82 | s2 = self.mappings.iloc[row][0] 83 | s1 = self.mappings.iloc[row][1] 84 | 85 | s2_patch_dict = self.read_all_bands(os.path.join(self.s2_path, s2), s2, self.band_names_s2) 86 | s2_patch = 
self.prepare_array(s2_patch_dict) 87 | 88 | s1_patch_dict = self.read_all_bands(os.path.join(self.s1_path, s1), s1, self.band_names_s1) 89 | s1_patch = self.prepare_array(s1_patch_dict) 90 | 91 | label_path = os.path.join(self.s2_path, s2, s2 + "_labels_metadata.json") 92 | file = open(label_path, "r") 93 | label = json.load(file)["labels"] 94 | 95 | one_hot = np.zeros((43,)) 96 | for i in label: 97 | one_hot[self.label_indices[i]] = 1 98 | 99 | return (s2_patch, s1_patch), one_hot 100 | 101 | def plot(self, index=0): 102 | (s2_patch, s1_patch), one_hot = self.prepare_pairs(index) 103 | inverted_labels = {v: k for k, v in self.total_label["BigEarthNet-19_labels"].items()} 104 | labels = [] 105 | for idx, elem in enumerate(self.total_label["label_conversion"]): 106 | for i in elem: 107 | if one_hot[i] == 1: 108 | labels.append(inverted_labels[idx]) 109 | labels = list(np.unique(labels)) 110 | labels = ", ".join(labels) 111 | _, ax = plt.subplots(nrows=1, ncols=3, figsize=((12, 4))) 112 | ax[0].imshow(s1_patch[0]) 113 | ax[0].set_title("VV") 114 | ax[1].imshow(s1_patch[1]) 115 | ax[1].set_title("VH") 116 | s2_patch = einops.rearrange(s2_patch[:3, :, :], "c h w -> h w c") 117 | s2_patch = cv.cvtColor(s2_patch, cv.COLOR_BGR2RGB) 118 | ax[2].imshow(s2_patch / s2_patch.max()) 119 | ax[2].set_title("RGB") 120 | text = " " * 70 + "Labels:\n" + labels 121 | plt.text(0.25, 0.9, text, fontsize=8, transform=plt.gcf().transFigure) 122 | plt.savefig("BigEarthNet_sample" + str(index) + ".png") 123 | plt.show() 124 | 125 | def __getitem__(self, index): 126 | (s2_patch, s1_patch), one_hot = self.prepare_pairs(index) 127 | 128 | s2_patch = torch.from_numpy(s2_patch.astype("float")) 129 | s1_patch = torch.from_numpy(s1_patch.astype("float")) 130 | 131 | image = torch.cat((s1_patch, s2_patch), dim=0).float().numpy() 132 | 133 | labels_19 = np.zeros((19,)) 134 | for idx, elem in enumerate(self.total_label["label_conversion"]): 135 | for i in elem: 136 | if one_hot[i] == 1: 137 | labels_19[idx] = 1 138 | 139 | if not self.configs["webdataset"]: 140 | if self.configs["augment"] and self.mode == "train": 141 | image = einops.rearrange(image, "c h w -> h w c") 142 | transform = self.augmentations(image=image) 143 | image = transform["image"] 144 | image = einops.rearrange(image, "h w c -> c h w") 145 | if self.configs["normalization"] == "minmax": 146 | image /= image.max() 147 | elif self.configs["normalization"] == "standard": 148 | image = torch.from_numpy(image).float() 149 | image = self.normalization(image) 150 | else: 151 | image = torch.from_numpy(image).float() 152 | else: 153 | image = torch.from_numpy(image).float() 154 | 155 | return image, torch.from_numpy(labels_19) 156 | -------------------------------------------------------------------------------- /datasets/FLAIR2Dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | import random 4 | 5 | import cv2 as cv 6 | import einops 7 | import kornia 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pyjson5 as json 11 | import rasterio 12 | import torch 13 | from tqdm import tqdm 14 | from torchvision import transforms 15 | import pyjson5 as json 16 | import glob 17 | import utilities 18 | import albumentations as A 19 | 20 | """ 21 | Data loading for the FLAIR Dataset published in: 22 | Garioud, Anatol, et al. "Challenge FLAIR #2: textural and temporal information for semantic segmentation 23 | from multi-source optical imagery." 
24 | 25 | Test data are not released yet. 26 | 27 | """ 28 | 29 | 30 | class FLAIR2Dataset(torch.utils.data.Dataset): 31 | def __init__(self, configs, mode="train"): 32 | self.configs = configs 33 | self.root_path = os.path.join(configs["root_path"], "FLAIR_2") 34 | self.mode = mode 35 | self.labels_path = os.path.join(self.root_path, "flair_labels_train") 36 | self.root_aerial_path = os.path.join(self.root_path, "flair_aerial_train") 37 | self.root_sentinel_path = os.path.join(self.root_path, "flair_sen_train") 38 | self.centroids_to_patch_path = os.path.join(self.root_path, "flair-2_centroids_sp_to_patch.json") 39 | self.centroids_to_patch = json.load(open(self.centroids_to_patch_path, "r")) 40 | self.sat_patch_size = configs["sentinel_size"] 41 | if self.configs["augment"]: 42 | self.augmentations = utilities.augmentations.get_augmentations(configs) 43 | else: 44 | self.augmentations = None 45 | 46 | if self.configs["normalization"] == "standard": 47 | self.normalization = transforms.Normalize(mean=self.configs["mean"], std=self.configs["std"]) 48 | areas = os.listdir(self.root_aerial_path) 49 | 50 | if self.mode == "train" or self.mode == "val": 51 | areas = areas[: int(0.9 * len(areas))] 52 | elif self.mode == "test": 53 | areas = areas[int(0.9 * len(areas)) :] 54 | else: 55 | print(mode, "is not a valid mode! Exiting!") 56 | exit(2) 57 | 58 | print("=" * 40) 59 | print("Initializing FLAIR-2 dataset - mode: ", mode) 60 | print("=" * 40) 61 | self.samples = [] 62 | for area in tqdm(areas): 63 | area_aerial_path = os.path.join(self.root_aerial_path, area) 64 | sub_areas = os.listdir(area_aerial_path) 65 | for sub_area in sub_areas: 66 | sub_area_aerial_path = os.path.join(area_aerial_path, sub_area) 67 | images_list = os.listdir(os.path.join(sub_area_aerial_path, "img")) 68 | for image in images_list: 69 | if image.endswith("tif"): 70 | sample = {} 71 | sample["aerial"] = os.path.join(sub_area_aerial_path, "img", image) 72 | sample["sentinel2"] = os.path.join(self.root_sentinel_path, area, sub_area, "sen") 73 | mask_file = sample["aerial"].replace(self.root_aerial_path, self.labels_path) 74 | mask_file = mask_file.replace("img", "msk").replace("IMG", "MSK") 75 | sample["label"] = mask_file 76 | sample["area"] = area 77 | sample["sub_area"] = sub_area 78 | self.samples.append(sample) 79 | 80 | self.num_examples = len(self.samples) 81 | print("Mode: ", mode, " - Number of examples :", self.num_examples) 82 | 83 | def read_img(self, raster_file: str) -> np.ndarray: 84 | with rasterio.open(raster_file) as src_img: 85 | array = src_img.read() 86 | return torch.from_numpy(array.astype(float)) 87 | 88 | def read_superarea_and_crop(self, numpy_file: str, idx_centroid: list) -> np.ndarray: 89 | data = np.load(numpy_file, mmap_mode="r") 90 | subset_sp = data[ 91 | :, 92 | :, 93 | idx_centroid[0] - int(self.sat_patch_size / 2) : idx_centroid[0] + int(self.sat_patch_size / 2), 94 | idx_centroid[1] - int(self.sat_patch_size / 2) : idx_centroid[1] + int(self.sat_patch_size / 2), 95 | ] 96 | return torch.from_numpy(subset_sp.astype(float)) 97 | 98 | def read_labels(self, raster_file: str) -> np.ndarray: 99 | with rasterio.open(raster_file) as src_label: 100 | labels = src_label.read()[0] 101 | labels[labels > self.configs["num_classes"]] = self.configs["num_classes"] 102 | labels = labels - 1 103 | return labels 104 | 105 | def __len__(self): 106 | return self.num_examples 107 | 108 | def __getitem__(self, index): 109 | sample = self.samples[index] 110 | image = self.read_img(sample["aerial"]) 111 | 
mask_path = sample["label"] 112 | mask = self.read_labels(mask_path) 113 | aerial_id = sample["aerial"].split("/")[-1] 114 | data_path = glob.glob(os.path.join(sample["sentinel2"], "*data.npy"))[0] 115 | sentinel = self.read_superarea_and_crop(data_path, self.centroids_to_patch[aerial_id]) 116 | 117 | if len(self.configs["data_source"]) == 1: 118 | if self.configs["data_source"] == "sentinel2": 119 | image = sentinel 120 | if self.configs["timeseries"]: 121 | image = einops.rearrange(image, "t c h w -> (t*c) h w") 122 | else: 123 | # Randomly pick a sentinel image to use 124 | choice = random.randint(0, sentinel.shape[0] - 1) 125 | image = image[choice] 126 | elif len(self.configs["data_source"]) == 2: 127 | if self.configs["timeseries"]: 128 | # Randomly pick a sentinel image to use 129 | timeseries_sequence = list(range(sentinel.shape[0])) 130 | timeseries_subset = sorted(random.sample(timeseries_sequence, self.configs["length_of_sequence"])) 131 | sentinel = sentinel[timeseries_subset, :, :, :] 132 | sentinel = einops.rearrange(sentinel, "t c h w -> (t*c) h w") 133 | else: 134 | # Randomly pick a sentinel image to use 135 | choice = random.randint(0, sentinel.shape[0] - 1) 136 | sentinel = sentinel[choice] 137 | sentinel = einops.rearrange(sentinel, "c h w -> h w c") 138 | resize = A.Compose([A.augmentations.Resize(height=image.shape[1], width=image.shape[2], p=1.0)]) 139 | transform = resize(image=sentinel.numpy()) 140 | sentinel = torch.from_numpy(transform["image"]) 141 | sentinel = einops.rearrange(sentinel, "h w c->c h w") 142 | 143 | image = torch.cat((image, sentinel), dim=0) 144 | else: 145 | print("FLAIR2 supports only 2 data sources! Validate config file!") 146 | exit(3) 147 | 148 | image = image.numpy() 149 | if not self.configs["webdataset"]: 150 | if self.configs["augment"] and self.mode == "train": 151 | image = einops.rearrange(image, "c h w -> h w c") 152 | transform = self.augmentations(image=image, mask=mask) 153 | image = transform["image"] 154 | mask = transform["mask"] 155 | image = einops.rearrange(image, "h w c -> c h w") 156 | if self.configs["normalization"] == "minmax": 157 | image /= image.max() 158 | elif self.configs["normalization"] == "standard": 159 | image = torch.from_numpy(image).float() 160 | image = self.normalization(image) 161 | else: 162 | image = torch.from_numpy(image).float() 163 | 164 | # Bring labels in range [0,..] 165 | mask = torch.from_numpy(mask).long() 166 | 167 | return image, mask 168 | -------------------------------------------------------------------------------- /datasets/FLAIRDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | import random 4 | 5 | import cv2 as cv 6 | import einops 7 | import kornia 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pyjson5 as json 11 | import rasterio 12 | import torch 13 | from tqdm import tqdm 14 | from torchvision import transforms 15 | 16 | import utilities 17 | 18 | """ 19 | Data loading for the FLAIR Dataset published in: 20 | Garioud, Anatol, et al. "FLAIR: French Land cover from Aerospace ImageRy." 
21 | """ 22 | 23 | 24 | class FLAIRDataset(torch.utils.data.Dataset): 25 | def __init__(self, configs, mode="train"): 26 | self.configs = configs 27 | self.root_path = os.path.join(configs["root_path"], "FLAIR") 28 | self.mode = mode 29 | if self.mode == "train" or self.mode == "val": 30 | self.labels_path = os.path.join(self.root_path, "flair_labels_train") 31 | self.root_path = os.path.join(self.root_path, "flair_aerial_train") 32 | elif self.mode == "test": 33 | self.labels_path = os.path.join(self.root_path, "flair_1_labels_test") 34 | self.root_path = os.path.join(self.root_path, "flair_1_aerial_test") 35 | 36 | if self.configs["augment"]: 37 | self.augmentations = utilities.augmentations.get_augmentations(configs) 38 | else: 39 | self.augmentations = None 40 | 41 | if self.configs["normalization"] == "standard": 42 | self.normalization = transforms.Normalize(mean=self.configs["mean"], std=self.configs["std"]) 43 | areas = os.listdir(self.root_path) 44 | print("=" * 40) 45 | print("Initializing FLAIR dataset - mode: ", mode) 46 | print("=" * 40) 47 | self.samples = [] 48 | for area in tqdm(areas): 49 | area_path = os.path.join(self.root_path, area) 50 | sub_areas = os.listdir(area_path) 51 | for sub_area in sub_areas: 52 | sub_area_path = os.path.join(area_path, sub_area) 53 | images_list = os.listdir(os.path.join(sub_area_path, "img")) 54 | for image in images_list: 55 | if image.endswith("tif"): 56 | sample = {} 57 | sample["path"] = os.path.join(sub_area_path, "img", image) 58 | mask_file = image.replace("IMG", "MSK") 59 | sample["label"] = os.path.join(self.labels_path, area, sub_area, "msk", mask_file) 60 | sample["area"] = area 61 | sample["sub_area"] = sub_area 62 | self.samples.append(sample) 63 | if mode == "train": 64 | random.Random(999).shuffle(self.samples) 65 | self.samples = self.samples[: int(0.9 * len(self.samples))] 66 | elif mode == "val": 67 | random.Random(999).shuffle(self.samples) 68 | self.samples = self.samples[int(0.9 * len(self.samples)) :] 69 | self.num_examples = len(self.samples) 70 | 71 | def __len__(self): 72 | return self.num_examples 73 | 74 | def plot(self, index=0): 75 | # TOY preprocessing for dev purposes 76 | sample = self.samples[index] 77 | path = sample["path"] 78 | mask_path = sample["label"] 79 | with rasterio.open(path) as src: 80 | image = src.read() 81 | 82 | with rasterio.open(mask_path) as src2: 83 | mask = src2.read() 84 | 85 | image = image[:3] 86 | image = einops.rearrange(image, "c h w -> h w c") 87 | _, ax = plt.subplots(nrows=1, ncols=2) 88 | 89 | ax[0].imshow(mask) 90 | ax[0].set_title("Mask") 91 | ax[1].imshow(image / image.max()) 92 | ax[1].set_title("Satellite image") 93 | plt.savefig("sample_" + str(index) + ".png") 94 | plt.show() 95 | 96 | def __getitem__(self, index): 97 | # TOY preprocessing for dev purposes 98 | sample = self.samples[index] 99 | path = sample["path"] 100 | mask_path = sample["label"] 101 | with rasterio.open(path) as src: 102 | image = src.read() 103 | 104 | with rasterio.open(mask_path) as src2: 105 | mask = src2.read() 106 | 107 | if not self.configs["webdataset"]: 108 | if self.configs["augment"] and self.mode == "train": 109 | transform = self.augmentations(image=image, mask=mask) 110 | image = transform["image"] 111 | mask = transform["mask"] 112 | image = torch.from_numpy(image).float() 113 | if self.configs["normalization"] == "minmax": 114 | image /= image.max() 115 | elif self.configs["normalization"] == "standard": 116 | image = self.normalization(image) 117 | else: 118 | image = 
torch.from_numpy(image).float() 119 | 120 | # Bring labels in range [0,..] 121 | mask = torch.from_numpy(mask).long().squeeze() - 1 122 | return image, mask 123 | -------------------------------------------------------------------------------- /datasets/FORinstanceDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pprint 4 | import random 5 | import warnings 6 | 7 | from pathlib import Path 8 | import laspy 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import pandas as pd 12 | import pyjson5 as json 13 | import rasterio 14 | import torch 15 | import xmltodict 16 | from tqdm import tqdm 17 | 18 | from torch_geometric.data import Data 19 | import torch_geometric.transforms as T 20 | 21 | import utilities.augmentations 22 | 23 | """ 24 | Data loading for the FORinstance Dataset published in: 25 | https://zenodo.org/record/8287792 26 | """ 27 | 28 | 29 | class FORinstanceDataset(torch.utils.data.Dataset): 30 | def __init__(self, configs, mode="train"): 31 | self.mode = mode 32 | self.configs = configs 33 | self.root_path = Path(configs["root_path"]) 34 | self.seg_task = self.configs["segmentation_task"] 35 | self.samples = self._load_dataset() 36 | 37 | if "nb_points" in self.configs.keys(): 38 | self.nb_points = self.configs["nb_points"] 39 | if self.configs["augment"]: 40 | self.augmentations = utilities.augmentations.get_augmentations(configs) 41 | else: 42 | self.augmentations = None 43 | self.normalization = T.NormalizeScale() 44 | 45 | # debug 46 | # self.samples = self.samples[:100] 47 | 48 | if self.mode == "train": 49 | random.Random(999).shuffle(self.samples) 50 | self.samples = self.samples[: int(0.9 * len(self.samples))] 51 | 52 | elif self.mode == "val": 53 | random.Random(999).shuffle(self.samples) 54 | self.samples = self.samples[int(0.9 * len(self.samples)) :] 55 | 56 | self.num_examples = len(self.samples) 57 | print("Number of samples in split {} = {}".format(self.mode, self.num_examples)) 58 | 59 | def _load_dataset(self): 60 | dataset = pd.read_csv(self.root_path / "data_split_metadata.csv") 61 | if self.mode in ("train", "val"): 62 | dataset = dataset[dataset["split"] == "dev"] 63 | else: 64 | dataset = dataset[dataset["split"] == "test"] 65 | if self.seg_task == "semantic_segmentation": 66 | dataset = dataset[dataset["folder"] != "RMIT"] 67 | dataset = dataset[dataset["folder"] != "TUWIEN"] 68 | path_files = [ 69 | self.root_path / sample_path for sample_path in dataset["path"] if (self.root_path / sample_path).exists() 70 | ] 71 | path_files = [list(path_file.parent.glob(path_file.stem + "*.pkl")) for path_file in path_files] 72 | path_files = [item for sublist in path_files for item in sublist] 73 | samples = [] 74 | 75 | for path_file in path_files: 76 | with open(path_file, "rb") as f: 77 | samples.append(pickle.load(f)) 78 | samples = [item["sub_pc"] for sublist in samples for item in sublist] 79 | return samples 80 | 81 | def __len__(self): 82 | return self.num_examples 83 | 84 | def __getitem__(self, index): 85 | sample = self.samples[index] 86 | with laspy.open(sample) as lidar_file: 87 | pc = lidar_file.read() 88 | if self.seg_task == "semantic_segmentation": 89 | # labels = np.array(pc.treeSP, dtype=np.int64) 90 | labels = np.array(pc.classification, dtype=np.int64) 91 | elif self.seg_task == "instance_segmentation": 92 | labels = np.array(pc.treeID, dtype=np.int64) 93 | else: 94 | raise Exception("Task {} is not supported yet.".format(self.seg_task)) 95 | 
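# Keep only the x/y/z point coordinates as an (N, 3) array; other LAS attributes
# (intensity, return number, ...) are not used by the baselines.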
point_cloud = np.stack([np.array(pc.x), np.array(pc.y), np.array(pc.z)]).T 96 | # get random sub sample 97 | if self.mode == "train": 98 | if self.nb_points > point_cloud.shape[0]: 99 | replace = True 100 | else: 101 | replace = False 102 | rand_idx = np.random.choice(list(range(point_cloud.shape[0])), size=self.nb_points, replace=replace) 103 | point_cloud = point_cloud[rand_idx, :] 104 | labels = labels[rand_idx] 105 | point_cloud = torch.tensor(point_cloud) 106 | labels = torch.tensor(labels) 107 | # No feature is considered for baselines 108 | x = torch.ones((point_cloud.shape[0], 3), dtype=torch.float) 109 | data = Data(pos=point_cloud, x=x, y=labels) 110 | data = self.normalization(data) 111 | if self.configs["augment"] and self.mode == "train": 112 | data = self.augmentations(data) 113 | return data, None 114 | 115 | def collate_fn(self, batch): 116 | return tuple(zip(*batch)) 117 | 118 | 119 | if __name__ == "__main__": 120 | dataset = FORinstanceDataset(path_augment_dict=True, mode="train") 121 | for i, data in enumerate(dataset): 122 | import ipdb 123 | 124 | ipdb.set_trace() 125 | # dataset.plot(0) 126 | -------------------------------------------------------------------------------- /datasets/FiveBillionPixelsDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import albumentations as A 4 | import cv2 as cv 5 | import einops 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | import pyjson5 as json 10 | import rasterio 11 | import torch 12 | from torchvision import transforms 13 | import rioxarray as rio 14 | import tqdm 15 | import utilities 16 | from pathlib import Path 17 | import random 18 | 19 | 20 | class FiveBillionPixelsDataset(torch.utils.data.Dataset): 21 | def __init__(self, configs, mode="train"): 22 | print("=" * 40) 23 | print("Initializing FiveBillionPixels Dataset mode - ", mode) 24 | print("=" * 40) 25 | self.mode = mode 26 | self.configs = configs 27 | self.root_path = os.path.join(configs["root_path"], "FiveBillionPixels") 28 | self.image_path = os.path.join(self.root_path, "Image_16bit_BGRNir") 29 | self.label_path = os.path.join(self.root_path, "Annotation__index") 30 | self.tile_path = os.path.join(self.root_path, "tiles") 31 | images = os.listdir(self.image_path) 32 | self.samples = [] 33 | self.tile_path = os.path.join(self.root_path, "tiles", mode) 34 | self.tile_label_path = os.path.join(self.root_path, "tile_labels", mode) 35 | self.val_scenes = [ 36 | "GF2_PMS1__L1A0001094941-MSS1.tiff", 37 | "GF2_PMS1__L1A0001680853-MSS1.tiff", 38 | "GF2_PMS2__L1A0000958144-MSS2.tiff", 39 | "GF2_PMS2__L1A0001757317-MSS2.tiff", 40 | "GF2_PMS1__L1A0001491417-MSS1.tiff", 41 | "GF2_PMS2__L1A0000564691-MSS2.tiff", 42 | "GF2_PMS2__L1A0001206072-MSS2.tiff", 43 | "GF2_PMS2__L1A0001573999-MSS2.tiff", 44 | "GF2_PMS2__L1A0001886305-MSS2.tiff", 45 | ] 46 | self.test_scenes = [ 47 | "GF2_PMS1__L1A0001064454-MSS1.tiff", 48 | "GF2_PMS1__L1A0001118839-MSS1.tiff", 49 | "GF2_PMS1__L1A0001344822-MSS1.tiff", 50 | "GF2_PMS1__L1A0001348919-MSS1.tiff", 51 | "GF2_PMS1__L1A0001366278-MSS1.tiff", 52 | "GF2_PMS1__L1A0001366284-MSS1.tiff", 53 | "GF2_PMS1__L1A0001395956-MSS1.tiff", 54 | "GF2_PMS1__L1A0001432972-MSS1.tiff", 55 | "GF2_PMS1__L1A0001670888-MSS1.tiff", 56 | "GF2_PMS1__L1A0001680857-MSS1.tiff", 57 | "GF2_PMS1__L1A0001680858-MSS1.tiff", 58 | "GF2_PMS1__L1A0001757429-MSS1.tiff", 59 | "GF2_PMS1__L1A0001765574-MSS1.tiff", 60 | "GF2_PMS2__L1A0000607677-MSS2.tiff", 61 | 
"GF2_PMS2__L1A0000607681-MSS2.tiff", 62 | "GF2_PMS2__L1A0000718813-MSS2.tiff", 63 | "GF2_PMS2__L1A0001038935-MSS2.tiff", 64 | "GF2_PMS2__L1A0001038936-MSS2.tiff", 65 | "GF2_PMS2__L1A0001119060-MSS2.tiff", 66 | "GF2_PMS2__L1A0001367840-MSS2.tiff", 67 | "GF2_PMS2__L1A0001378491-MSS2.tiff", 68 | "GF2_PMS2__L1A0001378501-MSS2.tiff", 69 | "GF2_PMS2__L1A0001396036-MSS2.tiff", 70 | "GF2_PMS2__L1A0001396037-MSS2.tiff", 71 | "GF2_PMS2__L1A0001416129-MSS2.tiff", 72 | "GF2_PMS2__L1A0001471436-MSS2.tiff", 73 | "GF2_PMS2__L1A0001517494-MSS2.tiff", 74 | "GF2_PMS2__L1A0001591676-MSS2.tiff", 75 | "GF2_PMS2__L1A0001787564-MSS2.tiff", 76 | ] 77 | 78 | if self.configs["augment"]: 79 | self.augmentations = utilities.augmentations.get_augmentations(configs) 80 | else: 81 | self.augmentations = None 82 | if self.configs["normalization"] == "standard": 83 | self.normalization = transforms.Normalize(mean=self.configs["mean"], std=self.configs["std"]) 84 | 85 | if not self.configs["tilerize"] or not os.path.isdir(self.tile_path): 86 | for index, image in tqdm.tqdm(enumerate(images)): 87 | if ".DS_Store" in image: 88 | continue 89 | image_path = os.path.join(self.image_path, image) 90 | image_name = image.split(".")[0] 91 | 92 | if image in self.val_scenes and mode != "val": 93 | continue 94 | if image in self.test_scenes and mode != "test": 95 | continue 96 | if image not in self.val_scenes and mode == "val": 97 | continue 98 | if image not in self.test_scenes and mode == "test": 99 | continue 100 | label_path = os.path.join(self.label_path, image_name + "_24label.png") 101 | if not self.configs["tilerize"]: 102 | self.samples.append({"image": image_path, "label": label_path}) 103 | else: 104 | if not os.path.isdir(self.tile_path): 105 | path = Path(self.tile_path) 106 | path.mkdir(parents=True, exist_ok=True) 107 | if not os.path.isdir(self.tile_label_path): 108 | path = Path(self.tile_label_path) 109 | path.mkdir(parents=True, exist_ok=True) 110 | self.samples.extend(self.tilerize(image_path, label_path, image_name)) 111 | else: 112 | images = os.listdir(self.tile_path) 113 | self.samples = [] 114 | for index, image in tqdm.tqdm(enumerate(images)): 115 | image_path = os.path.join(self.tile_path, image) 116 | image_name = image.split(".")[0] 117 | label_path = os.path.join(self.tile_label_path, image_name + ".png") 118 | self.samples.append({"image": image_path, "label": label_path}) 119 | 120 | random.Random(999).shuffle(self.samples) 121 | print(self.samples[0]) 122 | print("Samples for mode: ", mode, " = ", len(self.samples)) 123 | self.num_examples = len(self.samples) 124 | 125 | def tilerize(self, image_path, label_path, image_name): 126 | tif = rio.open_rasterio(image_path, engine="rasterio").sel(band=[3, 2, 1, 4]) 127 | 128 | image = tif.to_numpy() 129 | label = cv.imread(label_path, 0) 130 | image = einops.rearrange(image, "c h w -> h w c") 131 | label = label[: image.shape[0], : image.shape[1]] 132 | 133 | tiles = [] 134 | for i in tqdm.tqdm(range(0, image.shape[0], self.configs["tile_size"])): 135 | for j in range(0, image.shape[1], self.configs["tile_size"]): 136 | pad_x = False 137 | pad_y = False 138 | if image.shape[1] <= j + self.configs["tile_size"]: 139 | xmax_step = image.shape[1] - j 140 | pad_x = True 141 | else: 142 | xmax_step = self.configs["tile_size"] 143 | 144 | if image.shape[0] <= i + self.configs["tile_size"]: 145 | ymax_step = image.shape[0] - i 146 | pad_y = True 147 | else: 148 | ymax_step = self.configs["tile_size"] 149 | 150 | transform = A.augmentations.Crop(p=1.0, x_min=j, 
y_min=i, x_max=j + xmax_step, y_max=i + ymax_step) 151 | tile_transform = transform(image=image, mask=label) 152 | tile_image = tile_transform["image"] 153 | tile_mask = tile_transform["mask"] 154 | 155 | if pad_x or pad_y: 156 | tmp_tile = np.zeros((120, 120, 4)) 157 | tmp_label = np.zeros((120, 120)) 158 | 159 | tmp_tile[0 : tile_image.shape[0], 0 : tile_image.shape[1], :] = tile_image 160 | tmp_label[0 : tile_image.shape[0], 0 : tile_image.shape[1]] = tile_mask 161 | 162 | tile_image = tmp_tile 163 | tile_mask = tmp_label 164 | 165 | tile_id = os.path.join(self.tile_path, image_name + "_" + str(i) + "_" + str(j) + ".npy") 166 | label_id = os.path.join(self.tile_label_path, image_name + "_" + str(i) + "_" + str(j) + ".png") 167 | tile_image = tile_image.astype(np.float32) 168 | with open(tile_id, "wb") as f: 169 | np.save(f, tile_image) 170 | cv.imwrite(label_id, tile_mask) 171 | record = {"image": tile_id, "label": label_id} 172 | tiles.append(record) 173 | 174 | return tiles 175 | 176 | def __len__(self): 177 | return self.num_examples 178 | 179 | def __getitem__(self, index): 180 | sample = self.samples[index] 181 | image_path = sample["image"] 182 | label = cv.imread(sample["label"], 0) 183 | 184 | if not self.configs["tilerize"]: 185 | tif = rio.open_rasterio(image_path, engine="rasterio").sel(band=[3, 2, 1, 4]) 186 | image = tif.to_numpy() 187 | image = image[:, :6907, :7300] 188 | label = label[:6907, :7300] 189 | image = einops.rearrange(image, "c h w -> h w c") 190 | else: 191 | with open(image_path, "rb") as f: 192 | image = np.load(f) 193 | image = image.astype(np.float32) 194 | 195 | image = einops.rearrange(image, "h w c->c h w") 196 | if not self.configs["webdataset"]: 197 | if self.configs["augment"] and self.mode == "train": 198 | image = einops.rearrange(image, "c h w -> h w c") 199 | transform = self.augmentations(image=image) 200 | image = transform["image"] 201 | image = einops.rearrange(image, "h w c -> c h w") 202 | if self.configs["normalization"] == "minmax": 203 | image /= image.max() 204 | elif self.configs["normalization"] == "standard": 205 | image = torch.from_numpy(image).float() 206 | image = self.normalization(image) 207 | else: 208 | image = torch.from_numpy(image).float() 209 | else: 210 | image = torch.from_numpy(image).float() 211 | return image, torch.from_numpy(label) 212 | -------------------------------------------------------------------------------- /datasets/ForestNetDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pprint 4 | import random 5 | import warnings 6 | 7 | import cv2 as cv 8 | import einops 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import pandas as pd 12 | import pyjson5 as json 13 | import rasterio 14 | import torch 15 | from tqdm import tqdm 16 | 17 | import utilities.augmentations 18 | import albumentations as A 19 | 20 | warnings.simplefilter("ignore") 21 | """ 22 | Data loading for the ForestNet Dataset published in: 23 | Irvin, Jeremy, et al. "Forestnet: Classifying drivers of deforestation in indonesia using deep learning on satellite imagery." 
24 | """ 25 | 26 | 27 | class ForestNetDataset(torch.utils.data.Dataset): 28 | def __init__(self, configs, mode="train"): 29 | self.configs = configs 30 | self.root_path = os.path.join(configs["root_path"], "ForestNet") 31 | self.mode = mode 32 | self.label_category = {"Plantation": 1, "Smallholder agriculture": 2, "Grassland shrubland": 3, "Other": 4} 33 | if self.configs["augment"]: 34 | self.augmentations = utilities.augmentations.get_augmentations(configs) 35 | else: 36 | self.augmentations = None 37 | record_path = os.path.join(self.root_path, self.mode + ".csv") 38 | 39 | self.metadata = pd.read_csv(record_path) 40 | self.samples = [] 41 | for index, row in self.metadata.iterrows(): 42 | sample = {} 43 | 44 | sample["label"] = self.label_category[row["merged_label"]] 45 | sample["year"] = row["year"] 46 | sample["path"] = os.path.join(self.root_path, row["example_path"]) 47 | sample["forest_loss_path"] = os.path.join(sample["path"], "forest_loss_region.pkl") 48 | sample["auxiliary_path"] = os.path.join(sample["path"], "auxiliary") 49 | sample["images_path"] = os.path.join(sample["path"], "images") 50 | self.samples.append(sample) 51 | 52 | self.num_examples = len(self.samples) 53 | 54 | def __len__(self): 55 | return self.num_examples 56 | 57 | def plot(self, index=0): 58 | sample = self.samples[index] 59 | forest_loss = pickle.load(open(sample["forest_loss_path"], "rb")) 60 | 61 | image = cv.imread(os.path.join(sample["images_path"], "visible", "composite.png")) 62 | mask = rasterio.features.rasterize([forest_loss], fill=0, out_shape=image.shape[:2]) 63 | mask[mask > 0] = sample["label"] 64 | infrared = np.load(os.path.join(sample["images_path"], "infrared", "composite.npy")) 65 | _, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 4)) 66 | 67 | ax[0].imshow(mask) 68 | ax[0].set_title("Forest loss mask") 69 | ax[1].imshow(image) 70 | ax[1].set_title("Visible satellite image") 71 | ax[2].imshow(infrared) 72 | ax[2].set_title("Infrared") 73 | plt.savefig("ForestNetSample.png") 74 | plt.show() 75 | 76 | def __getitem__(self, index): 77 | sample = self.samples[index] 78 | forest_loss = pickle.load(open(sample["forest_loss_path"], "rb")) 79 | image = cv.imread(os.path.join(sample["images_path"], "visible", "composite.png")).astype(float) 80 | mask = rasterio.features.rasterize([forest_loss], fill=0, out_shape=image.shape[:2]) 81 | 82 | # Make it a multiclass problem 83 | mask[mask > 0] = sample["label"] 84 | 85 | # Composite image. According to the paper: A composite image is constructed by taking a per-pixel median over these cloud-filtered scenes, using the five least 86 | # cloudy scenes when less than five such scenes were available. 
87 | # We don't use it 88 | composite = np.load(os.path.join(sample["images_path"], "infrared", "composite.npy")) 89 | 90 | # Auxiliary data 91 | # aux = np.load(os.path.join(sample['images_path'],'auxiliary',..)) 92 | 93 | if self.configs["resize"] is not None: 94 | resize = A.Compose( 95 | [A.augmentations.Resize(height=self.configs["resize"], width=self.configs["resize"], p=1.0)] 96 | ) 97 | transform = resize(image=image, mask=mask) 98 | image = transform["image"] 99 | mask = transform["mask"] 100 | 101 | label = sample["label"] 102 | if not self.configs["webdataset"]: 103 | if self.configs["augment"] and self.mode == "train": 104 | transform = self.augmentations(image=image, mask=mask) 105 | image = transform["image"] 106 | mask = transform["mask"] 107 | if self.configs["normalization"] == "minmax": 108 | image /= image.max() 109 | elif self.configs["normalization"] == "standard": 110 | image = torch.from_numpy(image).float() 111 | image = self.normalization(image) 112 | else: 113 | image = torch.from_numpy(image).float() 114 | 115 | image = einops.rearrange(image, "h w c -> c h w") 116 | image = torch.from_numpy(image).float() 117 | 118 | if self.configs["task"] == "segmentation": 119 | return image, torch.from_numpy(mask).long() 120 | elif self.configs["task"] == "classification": 121 | return image, torch.tensor(label).long() 122 | -------------------------------------------------------------------------------- /datasets/MixedPCSegDataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | 4 | from pathlib import Path 5 | import einops 6 | import laspy 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import rasterio 10 | import torch 11 | from torchvision import transforms 12 | 13 | import albumentations as A 14 | from albumentations.augmentations import Resize 15 | from torch_geometric.data import Data 16 | import torch_geometric.transforms as T 17 | 18 | import utilities.augmentations 19 | from utilities.utils import format_bboxes_voc_to_yolo 20 | from datasets.NeonTreeDataset import NeonTreeDataset 21 | from datasets.FORinstanceDataset import FORinstanceDataset 22 | 23 | 24 | class MixedPCSegDataset(torch.utils.data.Dataset): 25 | def __init__(self, configs, mode="train", test_on="forinstance"): 26 | self.configs = configs 27 | self.mode = mode 28 | self.test_on = test_on 29 | self.dataset_names = self.configs["dataset_names"].split(",") 30 | self.nb_datasets = len(self.dataset_names) 31 | self.seg_task = self.configs["segmentation_task"] 32 | self.datasets = self._load_datasets() 33 | self.samples = self._group_samples() 34 | # This class support only pc data for the moment 35 | self.modality = "lidar" 36 | 37 | if "nb_points" in self.configs.keys(): 38 | self.nb_points = self.configs["nb_points"] 39 | if self.configs["augment"]: 40 | self.augmentations = utilities.augmentations.get_augmentations(configs) 41 | else: 42 | self.augmentations = None 43 | self.normalization = T.NormalizeScale() 44 | 45 | # Full debugging 46 | # if self.mode in ('train', 'val'): 47 | # Debugging train/val 48 | # self.samples = self.samples[:100] 49 | # else: 50 | # Debugging test 51 | # self.samples['neontree_point_cloud'] = self.samples['neontree_point_cloud'][:100] 52 | # self.samples['forinstance'] = self.samples['forinstance'][:100] 53 | 54 | if self.mode == "test": 55 | self.num_examples = len(self.samples[self.test_on]) 56 | # sum([len(self.samples[dataset]) for dataset in self.samples.keys()]) 57 | else: 58 
| self.num_examples = len(self.samples) 59 | print( 60 | "Total number of mixed samples in split " 61 | + "{} with modality: {} = {}".format(self.mode, self.modality, self.num_examples) 62 | ) 63 | 64 | def __len__(self): 65 | return self.num_examples 66 | 67 | def _load_datasets(self): 68 | datasets = dict() 69 | for dataset_name in self.dataset_names: 70 | if dataset_name.lower() == "neontree_point_cloud": 71 | self.configs["dataset_name"] = dataset_name.lower() 72 | self.configs["root_path"] = self.configs["root_path_neontree"] 73 | datasets[dataset_name.lower()] = NeonTreeDataset(self.configs, self.mode) 74 | if dataset_name.lower() == "forinstance": 75 | self.configs["dataset_name"] = dataset_name.lower() 76 | self.configs["root_path"] = self.configs["root_path_forinstance"] 77 | datasets[dataset_name.lower()] = FORinstanceDataset(self.configs, self.mode) 78 | if len(datasets) == 0: 79 | raise Exception("Dataset {} not supported".format(dataset_name)) 80 | return datasets 81 | 82 | def _group_samples(self): 83 | if self.mode in ("train", "val"): 84 | samples = [] 85 | for dataset_name in self.datasets.keys(): 86 | sub_samples = self.datasets[dataset_name].samples 87 | # Need to add dataset name to each sample 88 | if dataset_name == "neontree_point_cloud": 89 | sub_samples = [{"sub_pc": sample["sub_pc"], "dataset": dataset_name} for sample in sub_samples] 90 | else: 91 | sub_samples = [{"sub_pc": sample, "dataset": dataset_name} for sample in sub_samples] 92 | samples += sub_samples 93 | random.Random(999).shuffle(samples) 94 | else: # Supposed to be test 95 | samples = dict() 96 | for dataset_name in self.datasets.keys(): 97 | sub_samples = self.datasets[dataset_name].samples 98 | # Need to add dataset name to each sample 99 | if dataset_name == "neontree_point_cloud": 100 | sub_samples = [sample.update({"dataset": dataset_name}) for sample in sub_samples] 101 | else: 102 | sub_samples = [{"sub_pc": sample, "dataset": dataset_name} for sample in sub_samples] 103 | samples[dataset_name] = sub_samples 104 | return samples 105 | 106 | def __getitem__(self, index): 107 | if self.mode == "test": 108 | sample = self.samples[self.test_on][index] 109 | else: 110 | sample = self.samples[index] 111 | 112 | with laspy.open(sample["sub_pc"]) as lidar_file: 113 | pc = lidar_file.read() 114 | 115 | if sample["dataset"] == "neontree_point_cloud": 116 | # Only point location are considered for the moment 117 | point_cloud = np.vstack([pc.x, pc.y, pc.z]).T 118 | elif sample["dataset"] == "forinstance": 119 | point_cloud = np.stack([np.array(pc.x), np.array(pc.y), np.array(pc.z)]).T 120 | else: 121 | raise Exception("Dataset {} not supported yet") 122 | 123 | if self.seg_task == "semantic_segmentation": 124 | if sample["dataset"] == "neontree_point_cloud": 125 | labels = pc.instance_id.copy() 126 | labels[labels != 0] = 7 # 0-6 are FORinstance 127 | labels = np.array(labels, dtype=np.int64) 128 | else: 129 | labels = np.array(pc.classification, dtype=np.int64) 130 | else: 131 | raise Exception("Task {} is not supported yet.".format(self.seg_task)) 132 | 133 | # get random sub sample 134 | if self.mode == "train": 135 | if self.nb_points > point_cloud.shape[0]: 136 | replace = True 137 | else: 138 | replace = False 139 | rand_idx = np.random.choice(list(range(point_cloud.shape[0])), size=self.nb_points, replace=replace) 140 | point_cloud = point_cloud[rand_idx, :] 141 | labels = labels[rand_idx] 142 | 143 | point_cloud = torch.tensor(point_cloud) 144 | labels = torch.tensor(labels) 145 | # No feature 
is considered for baselines 146 | x = torch.ones((point_cloud.shape[0], 3), dtype=torch.float) 147 | data = Data(pos=point_cloud, x=x, y=labels) 148 | # normalize coords 149 | data = self.normalization(data) 150 | if self.configs["augment"] and self.mode == "train": 151 | # Augment PC if required 152 | data = self.augmentations(data) 153 | 154 | return data, None # forced by collate 155 | 156 | def collate_fn(self, batch): 157 | return tuple(zip(*batch)) 158 | -------------------------------------------------------------------------------- /datasets/Sen12MSDataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | import random 5 | 6 | import albumentations as A 7 | import cv2 as cv 8 | import einops 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import pandas as pd 12 | import pyjson5 as json 13 | import rasterio 14 | import torch 15 | import tqdm 16 | from torchvision import transforms 17 | 18 | 19 | class Sen12MSDataset(torch.utils.data.Dataset): 20 | def __init__(self, configs, mode="train"): 21 | print("=" * 40) 22 | print("Initializing SEN12MS mode - ", mode) 23 | print("=" * 40) 24 | 25 | self.configs = configs 26 | self.root_path = os.path.join(configs["root_path"], "Sen12MS") 27 | self.mode = mode 28 | if self.mode == "train" or self.mode == "val": 29 | sample_file = "train_list.txt" 30 | elif self.mode == "test": 31 | sample_file = "test_list.txt" 32 | else: 33 | print("Uknown phase") 34 | exit(2) 35 | self.valid_samples = open(os.path.join(self.root_path, sample_file), "r").readlines() 36 | self.labels = pickle.load(open(os.path.join(self.root_path, "IGBP_probability_labels.pkl"), "rb")) 37 | 38 | if self.configs["normalization"] == "standard": 39 | self.normalization = transforms.Normalize(mean=self.configs["mean"], std=self.configs["std"]) 40 | 41 | rois_path = os.path.join(self.root_path, "ROIs") 42 | self.rois = os.listdir() 43 | self.samples = [] 44 | for file in self.valid_samples: 45 | file = file.strip() 46 | roi = "_".join(file.split("_")[:2]) 47 | 48 | s2_folder = "_".join(file.split("_")[2:4]) 49 | s1_folder = s2_folder.replace("s2", "s1") 50 | lc_folder = s2_folder.replace("s2", "lc") 51 | 52 | s1_path = os.path.join(rois_path, roi, s1_folder, file.replace("_s2_", "_s1_")) 53 | s2_path = os.path.join(rois_path, roi, s2_folder, file) 54 | lc_path = os.path.join(rois_path, roi, lc_folder, file.replace("_s2_", "_lc_")) 55 | 56 | label = self.labels[file] 57 | sample = {"lc_path": lc_path, "s1_path": s1_path, "s2_path": s2_path, "labels": label} 58 | self.samples.append(sample) 59 | 60 | if mode == "train": 61 | random.Random(999).shuffle(self.samples) 62 | self.samples = self.samples[: int(0.9 * len(self.samples))] 63 | elif mode == "val": 64 | random.Random(999).shuffle(self.samples) 65 | self.samples = self.samples[int(0.9 * len(self.samples)) :] 66 | 67 | self.num_examples = len(self.samples) 68 | 69 | def __len__(self): 70 | return self.num_examples 71 | 72 | def plot(self, index=0): 73 | sample = self.samples[index] 74 | s1_path = sample["s1_path"] 75 | s2_path = sample["s2_path"] 76 | with rasterio.open(s1_path) as srcs1: 77 | s1_patch = srcs1.read() 78 | with rasterio.open(s2_path) as srcs2: 79 | s2_patch = srcs2.read() 80 | labels = sample["labels"] 81 | labels[labels >= 0.5] = 1 82 | labels[labels < 0.5] = 0 83 | _, ax = plt.subplots(nrows=1, ncols=3, figsize=((12, 4))) 84 | ax[0].imshow(s1_patch[0]) 85 | ax[0].set_title("VV") 86 | ax[1].imshow(s1_patch[1]) 
87 | ax[1].set_title("VH") 88 | s2_patch = einops.rearrange(s2_patch[1:4, :, :], "c h w -> h w c") 89 | s2_patch = cv.cvtColor(s2_patch, cv.COLOR_BGR2RGB) 90 | ax[2].imshow(s2_patch / s2_patch.max()) 91 | ax[2].set_title("RGB") 92 | 93 | plt.savefig("Sen12MS_sample_" + str(index) + ".png") 94 | plt.show() 95 | 96 | def __getitem__(self, index): 97 | sample = self.samples[index] 98 | s1_path = sample["s1_path"] 99 | s2_path = sample["s2_path"] 100 | with rasterio.open(s1_path) as srcs1: 101 | s1_patch = srcs1.read() 102 | with rasterio.open(s2_path) as srcs2: 103 | s2_patch = srcs2.read() 104 | labels = sample["labels"] 105 | labels[labels >= 0.5] = 1 106 | labels[labels < 0.5] = 0 107 | s1_patch = torch.from_numpy(s1_patch.astype("float")).float() 108 | s2_patch = torch.from_numpy(s2_patch.astype("float")).float() 109 | 110 | s2_patch = torch.clamp(s2_patch, min=0, max=10000) 111 | s1_patch = torch.clamp(s1_patch, min=-25, max=0) 112 | 113 | # Stack data 114 | image = torch.cat((s1_patch, s2_patch), dim=0).numpy() 115 | 116 | if not self.configs["webdataset"]: 117 | if self.configs["augment"] and self.mode == "train": 118 | image = einops.rearrange(image, "c h w -> h w c") 119 | transform = self.augmentations(image=image) 120 | image = transform["image"] 121 | image = einops.rearrange(image, "h w c -> c h w") 122 | if self.configs["normalization"] == "minmax": 123 | s1 = torch.from_numpy(image[:2, :, :]) / 25 + 1 124 | s2 = torch.from_numpy(image[2:, :, :]) / 10000 125 | image = torch.cat((s1, s2), dim=0) 126 | elif self.configs["normalization"] == "standard": 127 | image = torch.from_numpy(image).float() 128 | image = self.normalization(image) 129 | else: 130 | image = torch.from_numpy(image).float() 131 | else: 132 | image = torch.from_numpy(image).float() 133 | 134 | return image, torch.from_numpy(labels) 135 | -------------------------------------------------------------------------------- /datasets/TreeSatAIDataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | 5 | import cv2 as cv 6 | import einops 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import pyjson5 as json 10 | import rasterio 11 | import torch 12 | from torchvision import transforms 13 | from tqdm import tqdm 14 | import albumentations as A 15 | 16 | """ 17 | Data loading for TreeSat published in : 18 | S. Ahlswede, C. Schulz, C. Gava, P. Helber, B. Bischke, M. Frster, F. Arias, J. Hees, B. Demir, 19 | and B. Kleinschmit. TreeSatAI Benchmark Archive: A multi-sensor, multi-label dataset for tree 20 | species classification in remote sensing. ESSD, 2022. 
21 | """ 22 | 23 | 24 | class TreeSatAIDataset(torch.utils.data.Dataset): 25 | def __init__(self, configs, mode="train"): 26 | print("=" * 20) 27 | print("Initializing TreeSatAI dataset - mode: ", mode) 28 | print("=" * 20) 29 | self.configs = configs 30 | self.root_path = os.path.expandvars(os.path.join(configs["root_path"], "TreeSat")) 31 | self.mode = mode 32 | if self.configs["normalization"] == "standard": 33 | self.normalization = transforms.Normalize(mean=self.configs["mean"], std=self.configs["std"]) 34 | self.data_sources = [] 35 | 36 | for source in self.configs["data_source"]: 37 | if source == "aerial_60m": 38 | self.data_sources.append("aerial_60m") 39 | elif source == "s1_60m": 40 | self.data_sources.append("s1/60m") 41 | elif source == "s2_60m": 42 | self.data_sources.append("s2/60m") 43 | elif source == "s1_200m": 44 | self.data_sources.append("s1/200m") 45 | elif source == "s2_200m": 46 | self.data_sources.append("s2/200m") 47 | else: 48 | print("Data source: ", source, " not supported.") 49 | exit(2) 50 | self.species = os.listdir(os.path.join(self.root_path, "aerial_60m")) 51 | labels_file = open(os.path.join(self.root_path, "labels", "TreeSatBA_v9_60m_multi_labels.json"), "r") 52 | self.labels = json.load(labels_file) 53 | all_species = [] 54 | for v in self.labels.values(): 55 | for item in v: 56 | sp, v = item 57 | all_species.append(sp) 58 | all_species = np.unique(all_species) 59 | print(all_species) 60 | print("Num species: ", len(all_species)) 61 | self.species_dict = {} 62 | for idx, sp in enumerate(all_species): 63 | self.species_dict[sp] = idx 64 | self.valid_samples = None 65 | if self.mode == "train" or self.mode == "val": 66 | self.valid_samples = open(os.path.join(self.root_path, "train_filenames.lst"), "r").readlines() 67 | elif self.mode == "test": 68 | self.valid_samples = open(os.path.join(self.root_path, "test_filenames.lst"), "r").readlines() 69 | else: 70 | print("Mode not supported") 71 | exit(2) 72 | 73 | self.valid_samples = [v.strip() for v in self.valid_samples] 74 | 75 | self.samples = [] 76 | 77 | for data_source in self.data_sources: 78 | if data_source == "aerial_60m": 79 | for plant in tqdm(self.species): 80 | sample = {} 81 | sample["data"] = {} 82 | files = os.listdir(os.path.join(self.root_path, data_source, plant)) 83 | aerial_file = None 84 | for file in files: 85 | if file in self.valid_samples: 86 | aerial_file = file 87 | sample["data"][data_source] = os.path.join(self.root_path, data_source, plant, aerial_file) 88 | sample["labels"] = self.labels[aerial_file] 89 | self.samples.append(sample) 90 | elif "s2/60m" in data_source: 91 | files = os.listdir(os.path.join(self.root_path, data_source)) 92 | for file in files: 93 | if file in self.valid_samples: 94 | sample = {} 95 | sample["data"] = {} 96 | sample["data"][data_source] = os.path.join(self.root_path, data_source, file) 97 | if file in self.labels: 98 | sample["labels"] = self.labels[file] 99 | else: 100 | continue 101 | self.samples.append(sample) 102 | elif "s2/200m" in data_source: 103 | files = os.listdir(os.path.join(self.root_path, data_source)) 104 | for file in files: 105 | if file in self.valid_samples: 106 | sample = {} 107 | sample["data"] = {} 108 | sample["data"][data_source] = os.path.join(self.root_path, data_source, file) 109 | if file in self.labels: 110 | sample["labels"] = self.labels[file] 111 | else: 112 | continue 113 | self.samples.append(sample) 114 | 115 | if self.mode == "train" or self.mode == "val": 116 | random.Random(999).shuffle(self.samples) 117 | 
if self.mode == "train": 118 | self.samples = self.samples[: int(0.8 * len(self.samples))] 119 | else: 120 | self.samples = self.samples[int(0.8 * len(self.samples)) :] 121 | 122 | self.num_examples = len(self.samples) 123 | print("Number of samples: ", self.num_examples) 124 | 125 | def __len__(self): 126 | return self.num_examples 127 | 128 | def plot(self, index=0): 129 | sample = self.samples[index] 130 | 131 | data = sample["data"] 132 | label = sample["labels"] 133 | num_plots = len(self.data_sources) 134 | if "s1/60m" in self.data_sources: 135 | num_plots += 1 136 | _, ax = plt.subplots(nrows=1, ncols=num_plots, figsize=(12, 4)) 137 | 138 | for idx, source in enumerate(self.data_sources): 139 | with rasterio.open(data[source]) as src: 140 | if "s1/60m" in self.data_sources[:idx]: 141 | idx += 1 142 | img = src.read() 143 | img = einops.rearrange(img[:3, :, :], "c h w -> h w c") 144 | if "s2" in source: 145 | img = cv.cvtColor(img, cv.COLOR_BGR2RGB) 146 | img = cv.resize(img, (304, 304)) 147 | elif "s1" in source: 148 | img = cv.resize(img, (304, 304)) 149 | ax[idx].imshow(img[:, :, 0] / img[:, :, 0].max()) 150 | ax[idx].set_title("VV") 151 | ax[idx + 1].imshow(img[:, :, 1] / img[:, :, 1].max()) 152 | ax[idx + 1].set_title("VH") 153 | continue 154 | ax[idx].imshow(img / img.max()) 155 | ax[idx].set_title(source) 156 | 157 | labels = list(np.unique(label)) 158 | labels = ", ".join(labels) 159 | text = " " * 40 + "Labels:\n" + labels 160 | plt.text(0.4, 0.9, text, fontsize=8, transform=plt.gcf().transFigure) 161 | plt.savefig("TreeSat_sample_" + str(index) + ".png") 162 | 163 | def __getitem__(self, index): 164 | sample = self.samples[index] 165 | 166 | data = sample["data"] 167 | label = sample["labels"] 168 | image = None 169 | for source in self.data_sources: 170 | if len(data.keys()) == 0: 171 | print("Empty!") 172 | print(data) 173 | pass 174 | with rasterio.open(data[source]) as src: 175 | img = src.read() 176 | 177 | if source == "s2/200m" and (img.shape[1] != 20 or img.shape[2] != 20): 178 | img = einops.rearrange(img, "c h w -> h w c") 179 | resize = A.Compose([A.augmentations.Resize(height=20, width=20, p=1.0)]) 180 | transform = resize(image=img) 181 | img = transform["image"] 182 | img = einops.rearrange(img, "h w c->c h w") 183 | if image is None: 184 | img = img.astype(np.float32) 185 | image = torch.from_numpy(img) 186 | else: 187 | image = torch.cat((image, img), dim=0) 188 | image = image.numpy() 189 | if self.configs["enforce_resize"] is not None: 190 | image = einops.rearrange(image, "c h w -> h w c") 191 | resize = A.Compose( 192 | [ 193 | A.augmentations.Resize( 194 | height=self.configs["enforce_resize"], width=self.configs["enforce_resize"], p=1.0 195 | ) 196 | ] 197 | ) 198 | transform = resize(image=image) 199 | image = transform["image"] 200 | image = einops.rearrange(image, "h w c -> c h w") 201 | if not self.configs["webdataset"]: 202 | if self.configs["augment"] and self.mode == "train": 203 | image = einops.rearrange(image, "c h w -> h w c") 204 | transform = self.augmentations(image=image) 205 | image = transform["image"] 206 | image = einops.rearrange(image, "h w c -> c h w") 207 | if self.configs["normalization"] == "minmax": 208 | image /= image.max() + 1e-5 209 | elif self.configs["normalization"] == "standard": 210 | image = torch.from_numpy(image).float() 211 | image = self.normalization(image) 212 | else: 213 | image = torch.from_numpy(image).float() 214 | else: 215 | image = torch.from_numpy(image).float() 216 | 217 | final_label = 
torch.zeros(len(self.species_dict)) 218 | for l in label: 219 | if self.configs["filter_threshold"] is None: 220 | final_label[self.species_dict[l[0]]] = 1 221 | else: 222 | if l[1] > self.configs["filter_threshold"]: 223 | final_label[self.species_dict[l[0]]] = 1 224 | label = final_label.float() 225 | 226 | return image, label 227 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/datasets/__init__.py -------------------------------------------------------------------------------- /downloader.py: -------------------------------------------------------------------------------- 1 | import pyjson5 as json 2 | import subprocess 3 | import os 4 | from pathlib import Path 5 | import datasets.ReforesTreeDataset as ReforesTreeDataset 6 | import datasets.SSL4EOLDataset as SSL4EOL 7 | import datasets.SSL4EOS1S2Dataset as SSL4EOS1S2 8 | 9 | 10 | if __name__=='__main__': 11 | configs = json.load(open('configs/download/download.json','r')) 12 | dataset = configs['dataset'].lower() 13 | root_path = configs['root_path'] 14 | 15 | if dataset=='reforestree': 16 | root_path = os.path.join(root_path,'Reforestree') 17 | reforestree_configs = {"root_path":root_path,"download":True,"checksum":True,"augment":False,"normalization":"none"} 18 | ReforesTreeDataset.ReforesTreeDataset(reforestree_configs) 19 | exit(0) 20 | elif dataset=='ssl4eol': 21 | ssl4eol_configs = {"root_path":root_path,"download":True,"checksum":True,"augment":False,"normalization":"none","split":"oli_tirs_toa","seasons":4} 22 | SSL4EOL.SSL4EOL(ssl4eol_configs) 23 | exit(0) 24 | 25 | download_script = Path('downloading_scripts/' + dataset + '.sh') 26 | 27 | if not download_script.is_file(): 28 | print('Dataset is not supported for downloading!') 29 | exit(2) 30 | 31 | process = subprocess.Popen([str(download_script) + " " + root_path], shell=True, stdout=subprocess.PIPE) 32 | while True: 33 | output = process.stdout.readline().decode() 34 | if output == "" and process.poll() is not None: 35 | break 36 | print(output.strip()) 37 | 38 | process.wait() 39 | 40 | # Check if the download was successful 41 | if process.returncode == 0: 42 | print("Process finished successfully") 43 | else: 44 | print("Process failed!") 45 | 46 | 47 | -------------------------------------------------------------------------------- /downloading_scripts/bigearthnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 
4 | exit 1 5 | fi 6 | 7 | root_folder_path=$1 8 | 9 | full_path="$root_folder_path/BigEarthNet" 10 | 11 | mkdir $full_path 12 | 13 | sentinel2_url=https://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz 14 | sentinel1_url=https://bigearth.net/downloads/BigEarthNet-S1-v1.0.tar.gz 15 | label_url=https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/label_indices.json?inline=false 16 | 17 | train_url=https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/train.csv?inline=false 18 | val_url=https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/val.csv?inline=false 19 | test_url=https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/test.csv?inline=false 20 | 21 | echo "Downloading Sentinel-2 data" 22 | 23 | wget -O "$full_path/BigEarthNet-S2-v1.0.tar.gz" $sentinel2_url 24 | 25 | 26 | # Check if the download was successful 27 | if [ $? -eq 0 ]; then 28 | echo "Downloading finished normally." 29 | else 30 | echo "Downloading failed." 31 | fi 32 | 33 | echo "Downloading Sentinel-1 data" 34 | 35 | wget -O "$full_path/BigEarthNet-S1-v1.0.tar.gz" $sentinel1_url 36 | 37 | # Check if the download was successful 38 | if [ $? -eq 0 ]; then 39 | echo "Downloading finished normally." 40 | else 41 | echo "Downloading failed." 42 | fi 43 | 44 | echo "Downloading metadata" 45 | 46 | wget -O "$full_path/label_indices.json" $label_url 47 | 48 | # Check if the download was successful 49 | if [ $? -eq 0 ]; then 50 | echo "Downloading finished normally." 51 | else 52 | echo "Downloading failed." 53 | fi 54 | 55 | wget -O "$full_path/train.csv" $train_url 56 | # Check if the download was successful 57 | if [ $? -eq 0 ]; then 58 | echo "Downloading finished normally." 59 | else 60 | echo "Downloading failed." 61 | fi 62 | 63 | wget -O "$full_path/val.csv" $val_url 64 | # Check if the download was successful 65 | if [ $? -eq 0 ]; then 66 | echo "Downloading finished normally." 67 | else 68 | echo "Downloading failed." 69 | fi 70 | 71 | wget -O "$full_path/test.csv" $test_url 72 | # Check if the download was successful 73 | if [ $? -eq 0 ]; then 74 | echo "Downloading finished normally." 75 | else 76 | echo "Downloading failed." 77 | fi 78 | 79 | wget -O "$full_path/s1s2_mapping.csv" https://git.tu-berlin.de/rsim/BigEarthNet-MM_tools/-/raw/master/files/s1s2_mapping.csv?inline=false 80 | # Check if the download was successful 81 | if [ $? -eq 0 ]; then 82 | echo "Downloading finished normally." 83 | else 84 | echo "Downloading failed." 85 | fi 86 | 87 | echo "Untarring data" 88 | 89 | tar -xvf "$full_path/BigEarthNet-S2-v1.0.tar.gz" --directory $full_path 90 | 91 | # Check if the extraction was successful 92 | if [ $? -eq 0 ]; then 93 | echo "Extracted Sentinel-2 data successfully." 94 | else 95 | echo "Sentinel-2 extraction failed." 96 | fi 97 | 98 | 99 | tar -xvf "$full_path/BigEarthNet-S1-v1.0.tar.gz" --directory $full_path 100 | 101 | # Check if the extraction was successful 102 | if [ $? -eq 0 ]; then 103 | echo "Extracted Sentinel-1 data successfully." 104 | else 105 | echo "Sentinel-1 extraction failed." 106 | fi 107 | 108 | 109 | echo "Removing tar files" 110 | 111 | rm "$full_path/BigEarthNet-S2-v1.0.tar.gz" 112 | rm "$full_path/BigEarthNet-S1-v1.0.tar.gz" 113 | 114 | echo "Download finished."
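# Example usage (assuming the datasets root directory already exists and wget/tar are available):
#   bash downloading_scripts/bigearthnet.sh /path/to/datasets_root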
-------------------------------------------------------------------------------- /downloading_scripts/flair.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | train_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_1/flair_aerial_train.zip 9 | test_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_1/flair_1_aerial_test.zip 10 | labels_train=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_1/flair_labels_train.zip 11 | labels_test=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_1/flair_1_labels_test.zip 12 | 13 | dataset_folder="FLAIR" 14 | train_filename="train.zip" 15 | test_filename="test.zip" 16 | full_path="$root_folder_path/$dataset_folder" 17 | 18 | mkdir $full_path 19 | 20 | echo "Downloading training set." 21 | wget -O "$full_path/$train_filename" $train_url 22 | 23 | # Check if the download was successful 24 | if [ $? -eq 0 ]; then 25 | echo "Flair training set download finished normally." 26 | else 27 | echo "Downloading failed." 28 | fi 29 | 30 | 31 | echo "Downloading labels for the training set." 32 | wget -O "$full_path/train_labels.zip" $labels_train 33 | 34 | # Check if the download was successful 35 | if [ $? -eq 0 ]; then 36 | echo "Flair training labels download finished normally." 37 | else 38 | echo "Downloading failed." 39 | fi 40 | 41 | echo "Downloading test set." 42 | wget -O "$full_path/$test_filename" $test_url 43 | 44 | # Check if the download was successful 45 | if [ $? -eq 0 ]; then 46 | echo "Flair test set download finished normally." 47 | else 48 | echo "Downloading failed." 49 | fi 50 | 51 | echo "Downloading labels for the test set." 52 | wget -O "$full_path/test_labels.zip" $labels_test 53 | 54 | # Check if the download was successful 55 | if [ $? -eq 0 ]; then 56 | echo "Flair test labels download finished normally." 57 | else 58 | echo "Downloading failed." 59 | fi 60 | 61 | 62 | echo "Extracting dataset" 63 | 64 | unzip "$full_path/$train_filename" -d "$full_path" 65 | unzip "$full_path/$test_filename" -d "$full_path" 66 | 67 | unzip "$full_path/train_labels.zip" -d "$full_path" 68 | unzip "$full_path/test_labels.zip" -d "$full_path" 69 | 70 | rm "$full_path/$train_filename" 71 | rm "$full_path/$test_filename" 72 | rm "$full_path/train_labels.zip" 73 | rm "$full_path/test_labels.zip" 74 | 75 | echo "FLAIR extraction finished." -------------------------------------------------------------------------------- /downloading_scripts/flair_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!"
4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | dataset_folder="FLAIR_2" 9 | 10 | train_aerial_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_2/flair_aerial_train.zip 11 | train_sentinel_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_2/flair_sen_train.zip 12 | train_labels_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_2/flair_labels_train.zip 13 | 14 | test_aerial_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_2/flair_2_aerial_test.zip 15 | test_sentinel_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_2/flair_2_sen_test.zip 16 | 17 | aerial_meta_url=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_2/flair_2_aerial_metadata.zip 18 | aerial_sentinel_mapping=https://storage.gra.cloud.ovh.net/v1/AUTH_366279ce616242ebb14161b7991a8461/defi-ia/flair_data_2/flair_2_centroids_sp_to_patch.zip 19 | 20 | urls=($train_aerial_url $train_sentinel_url $train_labels_url $test_aerial_url $test_sentinel_url $aerial_meta_url $aerial_sentinel_mapping) 21 | 22 | full_path="$root_folder_path/$dataset_folder" 23 | 24 | mkdir $full_path 25 | for url in "${urls[@]}" 26 | do 27 | clear_name=$(basename "$url") 28 | wget -O "$full_path/$clear_name" $url 29 | if [ $? -eq 0 ]; then 30 | echo "$clear_name Download complete. Unzipping!" 31 | unzip "$full_path/$clear_name" -d "$full_path" 32 | echo "$clear_name extracted. Removing zip file" 33 | rm $full_path/$clear_name 34 | else 35 | echo "$clear_name Download failed." 36 | fi 37 | done 38 | -------------------------------------------------------------------------------- /downloading_scripts/forestnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | download_link=http://download.cs.stanford.edu/deep/ForestNetDataset.zip 9 | 10 | dataset_folder="ForestNet" 11 | filename="ForestNet.zip" 12 | full_path="$root_folder_path/$dataset_folder" 13 | 14 | mkdir $full_path 15 | 16 | wget -O "$full_path/$filename" $download_link 17 | 18 | # Check if the download was successful 19 | if [ $? -eq 0 ]; then 20 | echo "ForestNet download finished normally." 21 | else 22 | echo "Downloading failed." 23 | fi 24 | 25 | echo "Extracting dataset" 26 | 27 | unzip "$full_path/$filename" -d "$full_path" 28 | 29 | rm "$full_path/$filename" 30 | 31 | sub_dir="/deep/downloads/ForestNetDataset" 32 | mv "$full_path/$sub_dir"/* $full_path 33 | rm -rf "$full_path/deep" 34 | echo "ForestNet has been extracted." -------------------------------------------------------------------------------- /downloading_scripts/neontree.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 
4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | dataset_folder="NeonTree" 9 | full_path="$root_folder_path/$dataset_folder" 10 | 11 | mkdir $full_path 12 | annotations_url=https://zenodo.org/record/5914554/files/annotations.zip?download=1 13 | training_url=https://zenodo.org/record/5914554/files/training.zip?download=1 14 | evaluation_url=https://zenodo.org/record/5914554/files/evaluation.zip?download=1 15 | 16 | echo "Downloading files" 17 | 18 | wget -O "$full_path/annotations.zip" $annotations_url 19 | 20 | # Check if the download was successful 21 | if [ $? -eq 0 ]; then 22 | echo "NeonTree download finished normally." 23 | else 24 | echo "Downloading failed." 25 | fi 26 | 27 | wget -O "$full_path/training.zip" $training_url 28 | 29 | # Check if the download was successful 30 | if [ $? -eq 0 ]; then 31 | echo "NeonTree download finished normally." 32 | else 33 | echo "Downloading failed." 34 | fi 35 | 36 | wget -O "$full_path/evaluation.zip" $evaluation_url 37 | 38 | # Check if the download was successful 39 | if [ $? -eq 0 ]; then 40 | echo "NeonTree download finished normally." 41 | else 42 | echo "Downloading failed." 43 | fi 44 | echo "Extracting files" 45 | 46 | unzip "$full_path/annotations.zip" -d "$full_path" 47 | mkdir "$full_path/training" 48 | unzip "$full_path/training.zip" -d "$full_path/training/" 49 | unzip "$full_path/evaluation.zip" -d "$full_path" 50 | 51 | echo "Removing zip files" 52 | 53 | rm "$full_path/annotations.zip" 54 | rm "$full_path/training.zip" 55 | rm "$full_path/evaluation.zip" 56 | 57 | rm -rf "$full_path/__MACOSX" -------------------------------------------------------------------------------- /downloading_scripts/rapidai4eo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | full_path="$root_folder_path" 9 | 10 | mkdir $full_path 11 | 12 | azcopy copy --recursive https://radiantearth.blob.core.windows.net/mlhub/rapidai4eo/ "$full_path" -------------------------------------------------------------------------------- /downloading_scripts/sen12ms.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | full_path="$root_folder_path/Sen12MS" 8 | mkdir $full_path 9 | download_link=https://dataserv.ub.tum.de/s/m1474000/download 10 | 11 | wget --no-check-certificate -O "$full_path/sen12ms.zip" $download_link 12 | 13 | if [ $? -eq 0 ]; then 14 | echo "Downloading finished normally." 15 | else 16 | echo "Downloading failed." 17 | fi 18 | 19 | echo "Extracting data" 20 | 21 | unzip "$full_path/sen12ms.zip" -d $full_path 22 | 23 | if [ $? -eq 0 ]; then 24 | echo "Removing zip files" 25 | rm "$full_path/sen12ms.zip" 26 | else 27 | echo "Extraction failed." 28 | fi 29 | 30 | mv "$full_path/m1474000"/* "$full_path/" 31 | rm -rf "$full_path/m1474000/" 32 | mkdir "$full_path/ROIs" 33 | 34 | for file in "$full_path"/*.tar.gz; 35 | do 36 | tar -xvf "$file" --directory "$full_path/ROIs"; 37 | rm "$file" 38 | done 39 | 40 | echo "Downloading train/test splits and labels" 41 | 42 | wget -O "$full_path/test_list.txt" https://raw.githubusercontent.com/schmitt-muc/SEN12MS/master/splits/test_list.txt 43 | 44 | if [ $? -eq 0 ]; then 45 | echo "" 46 | else 47 | echo "Downloading failed." 
48 | fi 49 | 50 | wget -O "$full_path/train_list.txt" https://raw.githubusercontent.com/schmitt-muc/SEN12MS/master/splits/train_list.txt 51 | 52 | if [ $? -eq 0 ]; then 53 | echo "" 54 | else 55 | echo "Downloading failed." 56 | fi 57 | 58 | wget -O "$full_path/IGBP_probability_labels.pkl" https://raw.githubusercontent.com/schmitt-muc/SEN12MS/master/labels/IGBP_probability_labels.pkl 59 | 60 | if [ $? -eq 0 ]; then 61 | echo "Downloading finished" 62 | else 63 | echo "Downloading failed." 64 | fi 65 | -------------------------------------------------------------------------------- /downloading_scripts/spekboom.sh: -------------------------------------------------------------------------------- 1 | if [ -z "$1" ]; then 2 | echo "No root path. Exiting!" 3 | exit 1 4 | fi 5 | root_folder_path=$1 6 | 7 | dataset_folder="Spekboom" 8 | full_path="$root_folder_path/$dataset_folder" 9 | 10 | mkdir $full_path 11 | 12 | download_url=https://zenodo.org/record/7564954/files/data_spekboom.zip?download=1 13 | 14 | 15 | echo "Downloading files" 16 | 17 | wget -O "$full_path/spekboom.zip" $download_url 18 | 19 | # Check if the download was successful 20 | if [ $? -eq 0 ]; then 21 | echo "Spekboom download finished normally." 22 | else 23 | echo "Downloading failed." 24 | fi 25 | 26 | echo "Extracting data" 27 | 28 | unzip "$full_path/spekboom.zip" -d "$full_path" 29 | 30 | echo "Removing zip files" 31 | 32 | rm "$full_path/spekboom.zip" 33 | 34 | -------------------------------------------------------------------------------- /downloading_scripts/ssl4eos1s2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | full_path="$root_folder_path/SSL4EOS1S2" 9 | 10 | mkdir $full_path 11 | 12 | s2_url="https://dataserv.ub.tum.de/s/m1660427.001/download?path=%2Fssl4eo-s12&files=s2_l2a.tar.gz" 13 | 14 | s1_url="https://dataserv.ub.tum.de/s/m1660427.001/download?path=%2Fssl4eo-s12&files=s1.tar.gz" 15 | 16 | mkdir $full_path/"S2" 17 | mkdir $full_path/"S1" 18 | 19 | echo "Downloading Sentinel-1 from $s1_url" 20 | 21 | wget --no-check-certificate -O "$full_path/s1.tar.gz" $s1_url 22 | 23 | if [ $? -eq 0 ]; then 24 | echo "Sentinel-1 downloading finished!" 25 | else 26 | echo "Downloading failed." 27 | fi 28 | 29 | echo "Extracting Sentinel-1" 30 | 31 | tar -xvf "$full_path/s1.tar.gz" --directory "$full_path/S1"; 32 | 33 | 34 | if [ $? -eq 0 ]; then 35 | echo "Sentinel-1 extraction finished!" 36 | else 37 | echo "Extraction failed." 38 | fi 39 | 40 | 41 | 42 | echo "Downloading Sentinel-2 from $s2_url" 43 | 44 | wget --no-check-certificate -O "$full_path/s2.tar.gz" $s2_url 45 | 46 | if [ $? -eq 0 ]; then 47 | echo "Sentinel-2 downloading finished!" 48 | else 49 | echo "Downloading failed." 50 | fi 51 | 52 | 53 | echo "Extracting Sentinel-2" 54 | 55 | tar -xvf "$full_path/s2.tar.gz" --directory "$full_path/S2"; 56 | 57 | 58 | if [ $? -eq 0 ]; then 59 | echo "Sentinel-2 extraction finished!" 60 | else 61 | echo "Extraction failed." 62 | fi 63 | -------------------------------------------------------------------------------- /downloading_scripts/treesat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!"
4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | aerial_albies_alba=https://zenodo.org/record/6778154/files/aerial_60m_abies_alba.zip?download=1 9 | aerial_acer_pseudoplatanus=https://zenodo.org/record/6778154/files/aerial_60m_acer_pseudoplatanus.zip?download=1 10 | aerial_alnus_spec=https://zenodo.org/record/6778154/files/aerial_60m_alnus_spec.zip?download=1 11 | aerial_betula_spec=https://zenodo.org/record/6778154/files/aerial_60m_betula_spec.zip?download=1 12 | aerial_cleared=https://zenodo.org/record/6778154/files/aerial_60m_cleared.zip?download=1 13 | aerial_fagus_sylvatica=https://zenodo.org/record/6778154/files/aerial_60m_fagus_sylvatica.zip?download=1 14 | aerial_fraxinus_excelsior=https://zenodo.org/record/6778154/files/aerial_60m_fraxinus_excelsior.zip?download=1 15 | aerial_larix_decidua=https://zenodo.org/record/6778154/files/aerial_60m_larix_decidua.zip?download=1 16 | aerial_larix_kaempferi=https://zenodo.org/record/6778154/files/aerial_60m_larix_kaempferi.zip?download=1 17 | aerial_picea_albies=https://zenodo.org/record/6778154/files/aerial_60m_picea_abies.zip?download=1 18 | aerial_pinus_nigra=https://zenodo.org/record/6778154/files/aerial_60m_pinus_nigra.zip?download=1 19 | aerial_pinus_strobus=https://zenodo.org/record/6778154/files/aerial_60m_pinus_strobus.zip?download=1 20 | aerial_pinus_sylvestris=https://zenodo.org/record/6778154/files/aerial_60m_pinus_sylvestris.zip?download=1 21 | aerial_populus_spec=https://zenodo.org/record/6778154/files/aerial_60m_populus_spec.zip?download=1 22 | aerial_prunus_spec=https://zenodo.org/record/6778154/files/aerial_60m_prunus_spec.zip?download=1 23 | aerial_pseudotsuga=https://zenodo.org/record/6778154/files/aerial_60m_pseudotsuga_menziesii.zip?download=1 24 | aerial_quercus_petraea=https://zenodo.org/record/6778154/files/aerial_60m_quercus_petraea.zip?download=1 25 | aerial_quercus_robur=https://zenodo.org/record/6778154/files/aerial_60m_quercus_robur.zip?download=1 26 | aerial_quercus_rubra=https://zenodo.org/record/6778154/files/aerial_60m_quercus_rubra.zip?download=1 27 | aerial_tilia_spec=https://zenodo.org/record/6778154/files/aerial_60m_tilia_spec.zip?download=1 28 | 29 | geo_json_url=https://zenodo.org/record/6778154/files/geojson.zip?download=1 30 | labels_url=https://zenodo.org/record/6778154/files/labels.zip?download=1 31 | 32 | s1_url=https://zenodo.org/record/6778154/files/s1.zip?download=1 33 | s2_url=https://zenodo.org/record/6778154/files/s2.zip?download=1 34 | 35 | test_filenames_url=https://zenodo.org/record/6778154/files/test_filenames.lst?download=1 36 | train_filenames_url=https://zenodo.org/record/6778154/files/train_filenames.lst?download=1 37 | 38 | 39 | 40 | 41 | aerial_urls=($aerial_albies_alba $aerial_acer_pseudoplatanus $aerial_alnus_spec $aerial_betula_spec $aerial_cleared $aerial_fagus_sylvatica $aerial_fraxinus_excelsior $aerial_larix_decidua $aerial_larix_kaempferi $aerial_picea_albies $aerial_pinus_nigra $aerial_pinus_strobus $aerial_pinus_sylvestris $aerial_populus_spec $aerial_prunus_spec $aerial_pseudotsuga $aerial_quercus_petraea $aerial_quercus_robur $aerial_quercus_rubra $aerial_tilia_spec) 42 | #Filenames for each url 43 | aerial_filenames=("albies_alba.zip" "acer_pseudoplatanus.zip" "alnus_spec.zip" "betula_spec.zip" "cleared.zip" "fagus_sylvatica.zip" "fraxinus_excelsior.zip" "larix_decidua.zip" "larix_kaempferi.zip" "picea_albies.zip" "pinus_nigra.zip" "pinus_strobus.zip" "pinus_sylvestris.zip" "populus_spec.zip" "prunus_spec.zip" "pseudotsuga_menziesii.zip" "quercus_petraea.zip" 
"quercus_robur" "quercus_rubra.zip" "tilia_spec.zip") 44 | 45 | #Create folder if it doesn't exist 46 | dataset_folder="TreeSat" 47 | mkdir $root_folder_path/$dataset_folder 48 | 49 | mkdir $root_folder_path/$dataset_folder/"aerial_60m" $root_folder_path/$dataset_folder/"geojson" $root_folder_path/$dataset_folder/"labels" 50 | echo "Downloading aerial data." 51 | #Download each url 52 | full_path="$root_folder_path/$dataset_folder" 53 | 54 | for ((index=0;index<${#aerial_filenames[@]}; index++)) 55 | do 56 | file="aerial_60m/${aerial_filenames[index]}" 57 | echo "Saving to $full_path/$file" 58 | # Use the curl command to download the file 59 | wget -O "$full_path"/"$file" "${aerial_urls[index]}" 60 | clear_name=${file%.zip} 61 | mkdir "$full_path"/"$clear_name" 62 | unzip "$full_path"/"$file" -d "$full_path"/"$clear_name" 63 | rm "$full_path"/"$file" 64 | echo "File "${aerial_filenames[index]}" downloaded and extracted successfully!" 65 | done 66 | 67 | echo "Downloading Sentinel data" 68 | 69 | sentinel_urls=($s1_url $s2_url) 70 | sentinel_filenames=("s1.zip" "s2.zip") 71 | 72 | for ((index=0;index<${#sentinel_urls[@]}; index++)) 73 | do 74 | current_name=${sentinel_filenames[index]} 75 | clear_name=${current_name%.zip} 76 | file="${sentinel_filenames[index]}" 77 | echo "Saving to $full_path/$file" 78 | # Use the wget command to download the file 79 | wget -O "$full_path/$file" "${sentinel_urls[index]}" 80 | unzip "$full_path/$file" -d "$full_path" 81 | rm "$full_path/$file" 82 | echo "File "${sentinel_filenames[index]}" downloaded and extracted successfully!" 83 | done 84 | 85 | 86 | 87 | echo "Downloading labels" 88 | 89 | wget -O "$full_path/labels.zip" $labels_url 90 | unzip "$full_path/labels.zip" -d "$full_path" 91 | rm "$full_path/labels.zip" 92 | 93 | echo "Downloading geojson metadata" 94 | wget -O "$full_path/geojson.zip" $geo_json_url 95 | unzip "$full_path/geojson.zip" -d "$full_path" 96 | rm "$full_path/geojson.zip" 97 | 98 | 99 | echo "Downloading train/test split filenames" 100 | 101 | wget -O "$full_path/train_filenames.lst" $train_filenames_url 102 | wget -O "$full_path/test_filenames.lst" $test_filenames_url 103 | 104 | # Check if the download was successful 105 | if [ $? -eq 0 ]; then 106 | echo "Downloading finished normally." 107 | else 108 | echo "Downloading failed." 109 | fi 110 | 111 | echo "Process finished. Exiting!" -------------------------------------------------------------------------------- /downloading_scripts/waititu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | download_url=https://zenodo.org/record/7648984/files/data_treespecies_waitutu_nz.zip?download=1 9 | 10 | full_path="$root_folder_path/Waititu" 11 | 12 | mkdir $full_path 13 | 14 | echo "Downloading Data" 15 | 16 | wget -O "$full_path/waititu.zip" $download_url 17 | 18 | # Check if the download was successful 19 | if [ $? -eq 0 ]; then 20 | echo "Downloading finished normally." 21 | else 22 | echo "Downloading failed." 23 | fi 24 | 25 | unzip "$full_path/waititu.zip" -d $full_path 26 | 27 | echo "Removing zip files" 28 | 29 | rm "$full_path/waititu.zip" -------------------------------------------------------------------------------- /downloading_scripts/wildforest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 
4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | wildforest_url=https://raw.githubusercontent.com/ekalinicheva/multi_layer_vegetation/main/DATASET/WildForest3D.zip 9 | abbreviations_url=https://raw.githubusercontent.com/ekalinicheva/multi_layer_vegetation/main/DATASET/Abbreviations_species.xlsx 10 | 11 | #Filenames for each url 12 | filenames=("wildforest.zip" "abbreviations.zip") 13 | urls=($wildforest_url $abbreviations_url) 14 | 15 | #Create folder if it doesn't exist 16 | 17 | #Download each url 18 | full_path=$root_folder_path/"WildForest" 19 | mkdir $full_path 20 | 21 | echo "Downloading WildForest3D!" 22 | wget -O "$full_path/wildforest.zip" $wildforest_url 23 | unzip "$full_path/wildforest.zip" -d $full_path 24 | rm "$full_path/wildforest.zip" 25 | 26 | echo "WildForest3D downloaded and extracted successfully!" 27 | 28 | echo "Downloading abbreviations!" 29 | wget -O "$full_path/Abbreviations_species.xlsx" $abbreviations_url 30 | 31 | # Check if the download was successful 32 | if [ $? -eq 0 ]; then 33 | echo "Downloading finished normally." 34 | else 35 | echo "Downloading failed." 36 | fi 37 | 38 | echo "Process finished. Exiting!" -------------------------------------------------------------------------------- /downloading_scripts/woody.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ]; then 3 | echo "No root path. Exiting!" 4 | exit 1 5 | fi 6 | root_folder_path=$1 7 | 8 | #Urls to download 9 | pinus_url="https://zenodo.org/record/7565500/files/data_pinus_radiata.zip?download=1" 10 | europaeus_url="https://zenodo.org/record/7565490/files/data_ulex_europaeus.zip?download=1" 11 | acacia_url="https://zenodo.org/record/7565546/files/data_acacia_dealbata.zip?download=1" 12 | 13 | #Filenames for each url 14 | filenames=("data_pinus_radiata.zip" "data_ulex_europaeus.zip" "data_acacia_dealbata.zip") 15 | urls=($pinus_url $europaeus_url $acacia_url) 16 | 17 | #Create folder if it doesn't exist 18 | dataset_folder="Woody" 19 | mkdir $root_folder_path/$dataset_folder 20 | 21 | #Download each url 22 | full_path="$root_folder_path/$dataset_folder" 23 | for ((index=0;index<${#filenames[@]}; index++)) 24 | do 25 | echo "Saving to $full_path/"${filenames[index]}"" 26 | clean_name=${filenames[index]%.zip} 27 | # Use the curl command to download the file 28 | wget -O "$full_path"/"${filenames[index]}" "${urls[index]}" 29 | unzip "$full_path"/"${filenames[index]}" -d "$full_path" 30 | rm "$full_path"/"${filenames[index]}" 31 | echo "File "${filenames[index]}" downloaded successfully!" 32 | done 33 | # Check if the download was successful 34 | if [ $? -eq 0 ]; then 35 | echo "Download complete." 36 | else 37 | echo "Download failed." 
38 | fi 39 | -------------------------------------------------------------------------------- /examples/pretrained_fomo_example.py: -------------------------------------------------------------------------------- 1 | from model_zoo import multimodal_mae 2 | import torch 3 | import torch.nn as nn 4 | import pyjson5 as json 5 | import argparse 6 | 7 | 8 | def construct_fomo_configs(args): 9 | ''' 10 | Construct configurations for FoMo_1 model 11 | ''' 12 | 13 | configs = { 14 | "image_size":args.image_size, 15 | "patch_size":args.patch_size, 16 | "dim":args.dim, 17 | "depth":args.depth, 18 | "heads":args.heads, 19 | "mlp_dim":args.mlp_dim, 20 | "num_classes":args.num_classes, 21 | "single_embedding_layer":True, 22 | } 23 | 24 | #Update configs with modality specific configurations as defined during pretraining 25 | 26 | modality_configs = json.load(open(args.modality_configs,'r')) 27 | configs.update(modality_configs) 28 | 29 | return configs 30 | 31 | if __name__ == "__main__": 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--config", default=None) 35 | parser.add_argument("--checkpoint_path",default="fomo_single_embedding_layer_weights.pt") 36 | parser.add_argument("--modality_configs", default="configs/datasets/fomo_pretraining_datasets.json") 37 | parser.add_argument("--image_size", default=64, type=int) 38 | parser.add_argument("--patch_size", default=16, type=int) 39 | parser.add_argument("--dim", default=768, type=int) 40 | parser.add_argument("--depth", default=12, type=int) 41 | parser.add_argument("--heads", default=12, type=int) 42 | parser.add_argument("--mlp_dim", default=2048, type=int) 43 | parser.add_argument("--num_classes", default=1000, type=int) 44 | parser.add_argument("--single_embedding_layer", default=True, type=bool) 45 | 46 | args = parser.parse_args() 47 | 48 | configs = construct_fomo_configs(args) 49 | 50 | #Initialize FoMo model 51 | v = multimodal_mae.MultiSpectralViT( 52 | image_size=configs["image_size"], 53 | patch_size=configs["patch_size"], 54 | channels=1, 55 | num_classes=configs["num_classes"], 56 | dim=configs["dim"], 57 | depth=configs["depth"], 58 | heads=configs["heads"], 59 | mlp_dim=configs["mlp_dim"], 60 | configs=configs, 61 | ) 62 | v.load_state_dict(torch.load( args.checkpoint_path,map_location='cpu')) 63 | -------------------------------------------------------------------------------- /examples/pretrained_fomobench_example.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import timm 4 | import argparse 5 | import pyjson5 as json 6 | 7 | 8 | if __name__ == "__main__": 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--dataset_config", default="configs/datasets/tallos.json") 12 | parser.add_argument("--model_config", default="configs/method/classification/convnext.json") #Backbone configuration file. 
See configs/method directory 13 | parser.add_argument("--checkpoint_path",default="YOUR_CHECKPOINT_PATH.pt") 14 | 15 | 16 | args = parser.parse_args() 17 | 18 | checkpoint_path = args.checkpoint_path 19 | 20 | #load task specific configs 21 | with open(args.dataset_config, "r") as f: 22 | configs = json.load(f) 23 | 24 | #load model specific configs 25 | with open(args.model_config, "r") as f: 26 | model_configs = json.load(f) 27 | 28 | configs.update(model_configs) 29 | 30 | model = timm.create_model( 31 | configs["backbone"].lower(), 32 | pretrained=False, 33 | in_chans=configs["in_channels"], 34 | num_classes=configs["num_classes"], 35 | ) 36 | 37 | #Load pretrained model 38 | 39 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) 40 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pprint 4 | import random 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import wandb 10 | import utilities 11 | import utilities.utils as utils 12 | import argparse 13 | 14 | import torch.multiprocessing 15 | torch.multiprocessing.set_sharing_strategy('file_system') 16 | 17 | 18 | def seed_everything(seed): 19 | print("Setting seed to {}".format(seed)) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | random.seed(seed) 23 | 24 | 25 | def store_experiment_id(configs): 26 | json.dump( 27 | {"run_id": configs["wandb_run_id"]}, 28 | open(os.path.join(configs["checkpoint_path"], "id.json"), "w"), 29 | ) 30 | 31 | 32 | def init_wandb(configs): 33 | if configs["wandb_id_resume"] is None: 34 | id = wandb.util.generate_id() 35 | else: 36 | id = configs["wandb_id_resume"] 37 | wandb.init( 38 | project=configs["wandb_project"], 39 | entity=configs["wandb_entity"], 40 | config=configs, 41 | id=id, 42 | resume="allow", 43 | ) 44 | run = wandb.run 45 | name = run.name 46 | configs["wandb_run_name"] = name 47 | configs["wandb_run_id"] = id 48 | if "checkpoint_path" not in configs.keys(): 49 | checkpoint_path = utils.create_checkpoint_path(configs) 50 | configs["checkpoint_path"] = checkpoint_path 51 | store_experiment_id(configs) 52 | return configs 53 | 54 | 55 | def init_offline_experiment(configs): 56 | configs["wandb_run_name"] = "offline" 57 | configs["wandb_run_id"] = "None" 58 | if "checkpoint_path" not in configs.keys(): 59 | checkpoint_path = utils.create_checkpoint_path(configs) 60 | configs["checkpoint_path"] = checkpoint_path 61 | return configs 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--config", default=None) 67 | parser.add_argument("--training_config", default=None) 68 | parser.add_argument("--dataset_config", default=None) 69 | parser.add_argument("--method_config", default=None) 70 | parser.add_argument("--augmentation_config", default=None) 71 | parser.add_argument("--wandb_id_resume", default=None) 72 | parser.add_argument("--seed", default=0, type=int) 73 | 74 | args = parser.parse_args() 75 | 76 | # Setup configurations 77 | configs = utils.load_configs(args) 78 | pprint.pprint(configs) 79 | if args.seed is not None: 80 | configs["seed"] = args.seed 81 | seed_everything(configs["seed"]) 82 | 83 | # Setup wandb 84 | if configs["wandb"] and not configs["distributed"]: 85 | configs = init_wandb(configs) 86 | else: 87 | configs = init_offline_experiment(configs) 88 | 89 | trainer, tester = 
utils.create_procedures(configs) 90 | 91 | if configs["phase"] == "train": 92 | trainer(configs) 93 | if tester is not None: 94 | _, _, loader = utils.create_dataloaders(configs) 95 | if isinstance(loader, list): 96 | dataset_names = configs['dataset_names'].split(',') 97 | for i, sub_loader in enumerate(loader): 98 | tester(configs, loader=sub_loader, phase="test", dataset_name=dataset_names[i]) 99 | else: 100 | tester(configs, loader=loader, phase="test") 101 | -------------------------------------------------------------------------------- /model_zoo/faster_rcnn/__pycache__/faster_rcnn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/model_zoo/faster_rcnn/__pycache__/faster_rcnn.cpython-310.pyc -------------------------------------------------------------------------------- /model_zoo/faster_rcnn/__pycache__/generalized_rcnn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/model_zoo/faster_rcnn/__pycache__/generalized_rcnn.cpython-310.pyc -------------------------------------------------------------------------------- /model_zoo/faster_rcnn/generalized_rcnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the Generalized R-CNN framework 3 | """ 4 | 5 | import warnings 6 | from collections import OrderedDict 7 | from typing import Dict, List, Optional, Tuple, Union 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | import transformers # used for DinoV2 12 | from torchvision.models.detection.transform import ImageList 13 | 14 | from torchvision.utils import _log_api_usage_once 15 | 16 | 17 | class GeneralizedRCNN(nn.Module): 18 | """ 19 | Main class for Generalized R-CNN. 20 | 21 | Args: 22 | backbone (nn.Module): 23 | rpn (nn.Module): 24 | roi_heads (nn.Module): takes the features + the proposals from the RPN and computes 25 | detections / masks from it. 26 | transform (nn.Module): performs the data transformation from the inputs to feed into 27 | the model 28 | """ 29 | 30 | def __init__( 31 | self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module, embedding_shapes: dict 32 | ) -> None: 33 | super().__init__() 34 | _log_api_usage_once(self) 35 | self.transform = transform 36 | self.backbone = backbone 37 | self.rpn = rpn 38 | self.roi_heads = roi_heads 39 | self.embedding_shapes = embedding_shapes 40 | # used only on torchscript mode 41 | self._has_warned = False 42 | 43 | @torch.jit.unused 44 | def eager_outputs(self, losses, detections): 45 | # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] 46 | if self.training: 47 | return losses 48 | 49 | return detections 50 | 51 | def forward(self, images, spectral_keys=None, targets=None): 52 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 53 | """ 54 | Args: 55 | images (list[Tensor]): images to be processed 56 | targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional) 57 | spectral_keys (list[int]): special keys to fine tune FoMo Bench foundation model 58 | 59 | Returns: 60 | result (list[BoxList] or dict[Tensor]): the output from the model. 
61 | During training, it returns a dict[Tensor] which contains the losses. 62 | During testing, it returns list[BoxList] contains additional fields 63 | like `scores`, `labels` and `mask` (for Mask R-CNN models). 64 | 65 | """ 66 | if self.training: 67 | if targets is None: 68 | torch._assert(False, "targets should not be none when in training mode") 69 | else: 70 | for target in targets: 71 | boxes = target["boxes"] 72 | if isinstance(boxes, torch.Tensor): 73 | torch._assert( 74 | len(boxes.shape) == 2 and boxes.shape[-1] == 4, 75 | f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.", 76 | ) 77 | else: 78 | torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.") 79 | 80 | original_image_sizes: List[Tuple[int, int]] = [] 81 | for img in images: 82 | val = img.shape[-2:] 83 | torch._assert( 84 | len(val) == 2, 85 | f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", 86 | ) 87 | original_image_sizes.append((val[0], val[1])) 88 | 89 | if not isinstance(self.backbone, transformers.models.dinov2.modeling_dinov2.Dinov2Model): 90 | images, targets = self.transform(images, targets) 91 | 92 | # Check for degenerate boxes 93 | # TODO: Move this to a function 94 | if targets is not None: 95 | for target_idx, target in enumerate(targets): 96 | boxes = target["boxes"] 97 | degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] 98 | if degenerate_boxes.any(): 99 | # print the first degenerate box 100 | bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] 101 | degen_bb: List[float] = boxes[bb_idx].tolist() 102 | torch._assert( 103 | False, 104 | "All bounding boxes should have positive height and width." 105 | f" Found invalid box {degen_bb} for target at index {target_idx}.", 106 | ) 107 | 108 | if spectral_keys: 109 | # Custom backbone: FoMo Net foundation model 110 | # Requires 'spectral_keys' 111 | features = self.backbone((images.tensors, spectral_keys)) 112 | 113 | elif isinstance(self.backbone, transformers.models.dinov2.modeling_dinov2.Dinov2Model): 114 | # This is specific to DinoV2 115 | # Preprocessing is made outside the class 116 | images = torch.stack(images) 117 | img_sizes = list(images.size()) 118 | images = ImageList(images, [img_sizes[2:]] * img_sizes[0]) 119 | features = self.backbone(images.tensors, output_hidden_states=True) 120 | 121 | # Method to create squared patches and then resample to iamge size 122 | # features = features['last_hidden_state'][:, :-1, :] # from 257 to 256 123 | features = features["last_hidden_state"] 124 | # Create patches 125 | features = features.view(self.embedding_shapes["batch_size"], 16, 16, -1).permute(0, 3, 1, 2) 126 | features = nn.functional.interpolate( 127 | features, (self.embedding_shapes["output_size"], self.embedding_shapes["output_size"]) 128 | ) 129 | 130 | # Potential extension => re-order the squared features to map the orginal image 131 | # Ie try to create the inverse transformation from patchinzation of ViT 132 | 133 | """ 134 | # Method to interpolate features to square => too much interpolation 135 | features = torch.stack(features['hidden_states'], dim=1) 136 | # for squared features: 137 | features = nn.functional.interpolate(features, (features.shape[3], features.shape[3])) 138 | # resize to input image sizes 139 | features = nn.functional.interpolate(features, (self.embedding_shapes['output_size'], 140 | self.embedding_shapes['output_size'])) 141 | """ 142 | else: 143 | features = self.backbone(images.tensors) 144 | 145 | if 
isinstance(features, torch.Tensor): 146 | features = OrderedDict([("0", features)]) 147 | 148 | proposals, proposal_losses = self.rpn(images, features, targets) 149 | detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) 150 | detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator] 151 | 152 | losses = {} 153 | losses.update(detector_losses) 154 | losses.update(proposal_losses) 155 | 156 | if torch.jit.is_scripting(): 157 | if not self._has_warned: 158 | warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting") 159 | self._has_warned = True 160 | return losses, detections 161 | else: 162 | return self.eager_outputs(losses, detections) 163 | -------------------------------------------------------------------------------- /model_zoo/faster_rcnn/generalized_rcnn.py~: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the Generalized R-CNN framework 3 | """ 4 | 5 | import warnings 6 | from collections import OrderedDict 7 | from typing import Dict, List, Optional, Tuple, Union 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | import transformers # used for DinoV2 12 | from torchvision.models.detection.transform import ImageList 13 | 14 | from torchvision.utils import _log_api_usage_once 15 | 16 | 17 | class GeneralizedRCNN(nn.Module): 18 | """ 19 | Main class for Generalized R-CNN. 20 | 21 | Args: 22 | backbone (nn.Module): 23 | rpn (nn.Module): 24 | roi_heads (nn.Module): takes the features + the proposals from the RPN and computes 25 | detections / masks from it. 26 | transform (nn.Module): performs the data transformation from the inputs to feed into 27 | the model 28 | """ 29 | 30 | def __init__( 31 | self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module, embedding_shapes: dict 32 | ) -> None: 33 | super().__init__() 34 | _log_api_usage_once(self) 35 | self.transform = transform 36 | self.backbone = backbone 37 | self.rpn = rpn 38 | self.roi_heads = roi_heads 39 | self.embedding_shapes = embedding_shapes 40 | # used only on torchscript mode 41 | self._has_warned = False 42 | 43 | @torch.jit.unused 44 | def eager_outputs(self, losses, detections): 45 | # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] 46 | if self.training: 47 | return losses 48 | 49 | return detections 50 | 51 | def forward(self, images, spectral_keys=None, targets=None): 52 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 53 | """ 54 | Args: 55 | images (list[Tensor]): images to be processed 56 | targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional) 57 | spectral_keys (list[int]): special keys to fine tune FoMo Bench foundation model 58 | 59 | Returns: 60 | result (list[BoxList] or dict[Tensor]): the output from the model. 61 | During training, it returns a dict[Tensor] which contains the losses. 62 | During testing, it returns list[BoxList] contains additional fields 63 | like `scores`, `labels` and `mask` (for Mask R-CNN models). 
64 | 65 | """ 66 | if self.training: 67 | if targets is None: 68 | torch._assert(False, "targets should not be none when in training mode") 69 | else: 70 | for target in targets: 71 | boxes = target["boxes"] 72 | if isinstance(boxes, torch.Tensor): 73 | torch._assert( 74 | len(boxes.shape) == 2 and boxes.shape[-1] == 4, 75 | f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.", 76 | ) 77 | else: 78 | torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.") 79 | 80 | original_image_sizes: List[Tuple[int, int]] = [] 81 | for img in images: 82 | val = img.shape[-2:] 83 | torch._assert( 84 | len(val) == 2, 85 | f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", 86 | ) 87 | original_image_sizes.append((val[0], val[1])) 88 | 89 | if not isinstance(self.backbone, transformers.models.dinov2.modeling_dinov2.Dinov2Model): 90 | images, targets = self.transform(images, targets) 91 | 92 | # Check for degenerate boxes 93 | # TODO: Move this to a function 94 | if targets is not None: 95 | for target_idx, target in enumerate(targets): 96 | boxes = target["boxes"] 97 | degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] 98 | if degenerate_boxes.any(): 99 | # print the first degenerate box 100 | bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] 101 | degen_bb: List[float] = boxes[bb_idx].tolist() 102 | torch._assert( 103 | False, 104 | "All bounding boxes should have positive height and width." 105 | f" Found invalid box {degen_bb} for target at index {target_idx}.", 106 | ) 107 | 108 | if spectral_keys: 109 | # Custom backbone: FoMo Net foundation model 110 | # Requires 'spectral_keys' 111 | features = self.backbone((images.tensors, spectral_keys)) 112 | 113 | # Previous method // now included in model_utilities 114 | # features = torch.reshape(features, (self.embedding_shapes['batch_size'], 115 | # self.embedding_shapes['out_channels'], 116 | # self.embedding_shapes['output_size'], 117 | # self.embedding_shapes['output_size'])) 118 | 119 | elif isinstance(self.backbone, transformers.models.dinov2.modeling_dinov2.Dinov2Model): 120 | # with torch.no_grad(): # to freeze the backbone?? 
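(Note: torch.no_grad() around this call would only stop gradient tracking for the forward pass; fully freezing the DinoV2 backbone would typically also require setting requires_grad=False on its parameters and keeping it in eval mode so that training-time behaviour such as dropout stays disabled.)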
121 | # This is specific to DinoV2 122 | # Preprocessing is made outside the class 123 | images = torch.stack(images) 124 | img_sizes = list(images.size()) 125 | images = ImageList(images, [img_sizes[2:]] * img_sizes[0]) 126 | features = self.backbone(images.tensors, output_hidden_states=True) 127 | 128 | # Method to create squared patches and then resample to iamge size 129 | # features = features['last_hidden_state'][:, :-1, :] # from 257 to 256 130 | features = features["last_hidden_state"] 131 | # Create patches 132 | features = features.view(self.embedding_shapes["batch_size"], 16, 16, -1).permute(0, 3, 1, 2) 133 | features = nn.functional.interpolate( 134 | features, (self.embedding_shapes["output_size"], self.embedding_shapes["output_size"]) 135 | ) 136 | 137 | # Potential extension => re-order the squared features to map the orginal image 138 | # Ie try to create the inverse transformation from patchinzation of ViT 139 | 140 | """ 141 | # Method to interpolate features to square => too much interpolation 142 | features = torch.stack(features['hidden_states'], dim=1) 143 | # for squared features: 144 | features = nn.functional.interpolate(features, (features.shape[3], features.shape[3])) 145 | # resize to input image sizes 146 | features = nn.functional.interpolate(features, (self.embedding_shapes['output_size'], 147 | self.embedding_shapes['output_size'])) 148 | """ 149 | else: 150 | features = self.backbone(images.tensors) 151 | 152 | if isinstance(features, torch.Tensor): 153 | features = OrderedDict([("0", features)]) 154 | 155 | proposals, proposal_losses = self.rpn(images, features, targets) 156 | detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) 157 | detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator] 158 | 159 | losses = {} 160 | losses.update(detector_losses) 161 | losses.update(proposal_losses) 162 | 163 | if torch.jit.is_scripting(): 164 | if not self._has_warned: 165 | warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting") 166 | self._has_warned = True 167 | return losses, detections 168 | else: 169 | return self.eager_outputs(losses, detections) 170 | -------------------------------------------------------------------------------- /model_zoo/point_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits to pytorch_geometric. 
3 | Repo: git@github.com:pyg-team/pytorch_geometric.git 4 | """ 5 | import os.path as osp 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.nn import Linear as Lin 10 | from torch_geometric.nn import ( 11 | MLP, 12 | PointTransformerConv, 13 | fps, 14 | global_mean_pool, 15 | knn, 16 | knn_graph, 17 | knn_interpolate, 18 | ) 19 | from torch_geometric.utils import scatter 20 | from torch_geometric.typing import WITH_TORCH_CLUSTER 21 | 22 | if not WITH_TORCH_CLUSTER: 23 | quit("This example requires 'torch-cluster'") 24 | 25 | 26 | class TransformerBlock(torch.nn.Module): 27 | def __init__(self, in_channels, out_channels): 28 | super().__init__() 29 | self.lin_in = Lin(in_channels, in_channels) 30 | self.lin_out = Lin(out_channels, out_channels) 31 | 32 | self.pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False) 33 | 34 | self.attn_nn = MLP([out_channels, 64, out_channels], norm=None, 35 | plain_last=False) 36 | 37 | self.transformer = PointTransformerConv(in_channels, out_channels, 38 | pos_nn=self.pos_nn, 39 | attn_nn=self.attn_nn) 40 | 41 | def forward(self, x, pos, edge_index): 42 | x = self.lin_in(x).relu() 43 | x = self.transformer(x, pos, edge_index) 44 | x = self.lin_out(x).relu() 45 | return x 46 | 47 | 48 | class TransitionDown(torch.nn.Module): 49 | ''' 50 | Samples the input point cloud by a ratio percentage to reduce 51 | cardinality and uses an mlp to augment features dimensionnality 52 | ''' 53 | def __init__(self, in_channels, out_channels, ratio=0.25, k=16): 54 | super().__init__() 55 | self.k = k 56 | self.ratio = ratio 57 | self.mlp = MLP([in_channels, out_channels], plain_last=False) 58 | 59 | def forward(self, x, pos, batch): 60 | # FPS sampling 61 | id_clusters = fps(pos, ratio=self.ratio, batch=batch) 62 | 63 | # compute for each cluster the k nearest points 64 | sub_batch = batch[id_clusters] if batch is not None else None 65 | 66 | # beware of self loop 67 | id_k_neighbor = knn(pos, pos[id_clusters], k=self.k, batch_x=batch, 68 | batch_y=sub_batch) 69 | 70 | # transformation of features through a simple MLP 71 | x = self.mlp(x) 72 | 73 | # Max pool onto each cluster the features from knn in points 74 | x_out = scatter(x[id_k_neighbor[1]], id_k_neighbor[0], dim=0, 75 | dim_size=id_clusters.size(0), reduce='max') 76 | 77 | # keep only the clusters and their max-pooled features 78 | sub_pos, out = pos[id_clusters], x_out 79 | return out, sub_pos, sub_batch 80 | 81 | 82 | class TransitionUp(torch.nn.Module): 83 | ''' 84 | Reduce features dimensionnality and interpolate back to higher 85 | resolution and cardinality 86 | ''' 87 | def __init__(self, in_channels, out_channels): 88 | super().__init__() 89 | self.mlp_sub = MLP([in_channels, out_channels], plain_last=False) 90 | self.mlp = MLP([out_channels, out_channels], plain_last=False) 91 | 92 | def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None): 93 | # transform low-res features and reduce the number of features 94 | x_sub = self.mlp_sub(x_sub) 95 | 96 | # interpolate low-res feats to high-res points 97 | x_interpolated = knn_interpolate(x_sub, pos_sub, pos, k=3, 98 | batch_x=batch_sub, batch_y=batch) 99 | 100 | x = self.mlp(x) + x_interpolated 101 | 102 | return x 103 | 104 | 105 | class PointTransformer(torch.nn.Module): 106 | def __init__(self, in_channels, out_channels, dim_model, k=16): 107 | super().__init__() 108 | self.k = k 109 | 110 | # dummy feature is created if there is none given 111 | in_channels = max(in_channels, 1) 112 | 113 | # first block 114 | 
self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False) 115 | 116 | self.transformer_input = TransformerBlock( 117 | in_channels=dim_model[0], 118 | out_channels=dim_model[0], 119 | ) 120 | 121 | # backbone layers 122 | self.transformers_up = torch.nn.ModuleList() 123 | self.transformers_down = torch.nn.ModuleList() 124 | self.transition_up = torch.nn.ModuleList() 125 | self.transition_down = torch.nn.ModuleList() 126 | 127 | for i in range(0, len(dim_model) - 1): 128 | 129 | # Add Transition Down block followed by a Point Transformer block 130 | self.transition_down.append( 131 | TransitionDown(in_channels=dim_model[i], 132 | out_channels=dim_model[i + 1], k=self.k)) 133 | 134 | self.transformers_down.append( 135 | TransformerBlock(in_channels=dim_model[i + 1], 136 | out_channels=dim_model[i + 1])) 137 | 138 | # Add Transition Up block followed by Point Transformer block 139 | self.transition_up.append( 140 | TransitionUp(in_channels=dim_model[i + 1], 141 | out_channels=dim_model[i])) 142 | 143 | self.transformers_up.append( 144 | TransformerBlock(in_channels=dim_model[i], 145 | out_channels=dim_model[i])) 146 | 147 | # summit layers 148 | self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], norm=None, 149 | plain_last=False) 150 | 151 | self.transformer_summit = TransformerBlock( 152 | in_channels=dim_model[-1], 153 | out_channels=dim_model[-1], 154 | ) 155 | 156 | # class score computation 157 | self.mlp_output = MLP([dim_model[0], 64, out_channels], norm=None) 158 | 159 | def forward(self, x, pos, batch=None): 160 | 161 | # add dummy features in case there is none 162 | if x is None: 163 | x = torch.ones((pos.shape[0], 1)).to(pos.get_device()) 164 | 165 | out_x = [] 166 | out_pos = [] 167 | out_batch = [] 168 | 169 | # first block 170 | x = self.mlp_input(x) 171 | edge_index = knn_graph(pos, k=self.k, batch=batch) 172 | x = self.transformer_input(x, pos, edge_index) 173 | 174 | # save outputs for skipping connections 175 | out_x.append(x) 176 | out_pos.append(pos) 177 | out_batch.append(batch) 178 | 179 | # backbone down : #reduce cardinality and augment dimensionnality 180 | for i in range(len(self.transformers_down)): 181 | x, pos, batch = self.transition_down[i](x, pos, batch=batch) 182 | edge_index = knn_graph(pos, k=self.k, batch=batch) 183 | x = self.transformers_down[i](x, pos, edge_index) 184 | 185 | out_x.append(x) 186 | out_pos.append(pos) 187 | out_batch.append(batch) 188 | 189 | # summit 190 | x = self.mlp_summit(x) 191 | edge_index = knn_graph(pos, k=self.k, batch=batch) 192 | x = self.transformer_summit(x, pos, edge_index) 193 | 194 | # backbone up : augment cardinality and reduce dimensionnality 195 | n = len(self.transformers_down) 196 | for i in range(n): 197 | x = self.transition_up[-i - 1](x=out_x[-i - 2], x_sub=x, 198 | pos=out_pos[-i - 2], 199 | pos_sub=out_pos[-i - 1], 200 | batch_sub=out_batch[-i - 1], 201 | batch=out_batch[-i - 2]) 202 | 203 | edge_index = knn_graph(out_pos[-i - 2], k=self.k, 204 | batch=out_batch[-i - 2]) 205 | x = self.transformers_up[-i - 1](x, out_pos[-i - 2], edge_index) 206 | 207 | # Class score 208 | out = self.mlp_output(x) 209 | 210 | return F.log_softmax(out, dim=-1) 211 | 212 | -------------------------------------------------------------------------------- /model_zoo/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Sequential, Linear, ReLU 4 | from torch_geometric.nn import MessagePassing, 
global_max_pool 5 | from torch_cluster import knn_graph 6 | from torch_geometric.nn import PPFConv 7 | from torch_cluster import fps 8 | 9 | 10 | class PointNetLayer(MessagePassing): 11 | def __init__(self, in_dim, in_channels, out_channels): 12 | # Message passing with "max" aggregation. 13 | super().__init__(aggr="max") 14 | 15 | # Initialization of the MLP: 16 | # Here, the number of input features correspond to the hidden node 17 | # dimensionality plus point dimensionality (=3). 18 | self.mlp = Sequential(Linear(in_channels + in_dim, out_channels), ReLU(), Linear(out_channels, out_channels)) 19 | 20 | def forward(self, h, pos, edge_index): 21 | # Start propagating messages. 22 | return self.propagate(edge_index, h=h, pos=pos) 23 | 24 | def message(self, h_j, pos_j, pos_i): 25 | # h_j defines the features of neighboring nodes as shape [num_edges, in_channels] 26 | # pos_j defines the position of neighboring nodes as shape [num_edges, 3] 27 | # pos_i defines the position of central nodes as shape [num_edges, 3] 28 | 29 | input = pos_j - pos_i # Compute spatial relation. 30 | 31 | if h_j is not None: 32 | # In the first layer, we may not have any hidden node features, 33 | # so we only combine them in case they are present. 34 | input = torch.cat([h_j, input], dim=-1) 35 | 36 | return self.mlp(input) # Apply our final MLP. 37 | 38 | 39 | class PointNet(torch.nn.Module): 40 | def __init__(self, nb_classes, in_channels=3, in_dim=3): 41 | super().__init__() 42 | 43 | torch.manual_seed(12345) 44 | self.conv1 = PointNetLayer(in_dim, in_channels, 32) 45 | self.conv2 = PointNetLayer(in_dim, 32, 32) 46 | self.classifier = Linear(32, nb_classes) 47 | 48 | def forward(self, pos, batch=None): 49 | # Compute the kNN graph: 50 | # Here, we need to pass the batch vector to the function call in order 51 | # to prevent creating edges between points of different examples. 52 | # We also add `loop=True` which will add self-loops to the graph in 53 | # order to preserve central point information. 54 | edge_index = knn_graph(pos, k=16, batch=batch, loop=True) 55 | 56 | # 3. Start bipartite message passing. 57 | h = self.conv1(h=pos, pos=pos, edge_index=edge_index) 58 | h = h.relu() 59 | h = self.conv2(h=h, pos=pos, edge_index=edge_index) 60 | h = h.relu() 61 | 62 | # 4. Global Pooling. 63 | # h = global_max_pool(h, batch) # [num_examples, hidden_channels] 64 | 65 | # 5. Classifier. 66 | return self.classifier(h) 67 | 68 | 69 | class PPFNet(torch.nn.Module): 70 | def __init__(self, nb_classes, in_channels=3, n_dims=3): 71 | super().__init__() 72 | 73 | # torch.manual_seed(12345) 74 | mlp1 = Sequential(Linear(in_channels + n_dims, 32), ReLU(), Linear(32, 32)) 75 | self.conv1 = PPFConv(mlp1) # TODO 76 | mlp2 = Sequential(Linear(32 + n_dims, 32), ReLU(), Linear(32, 32)) 77 | self.conv2 = PPFConv(mlp2) 78 | self.classifier = Linear(32, nb_classes) 79 | 80 | def forward(self, pos, normal, batch=None): 81 | edge_index = knn_graph(pos, k=16, batch=batch, loop=False) 82 | 83 | x = self.conv1(x=None, pos=pos, normal=normal, edge_index=edge_index) 84 | x = x.relu() 85 | x = self.conv2(x=x, pos=pos, normal=normal, edge_index=edge_index) 86 | x = x.relu() 87 | 88 | x = global_max_pool(x, batch) # [num_examples, hidden_channels] 89 | return self.classifier(x) 90 | -------------------------------------------------------------------------------- /model_zoo/pointnet2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Credits to pytorch_geometric. 
3 | Repo: git@github.com:pyg-team/pytorch_geometric.git 4 | """ 5 | import torch 6 | 7 | # import torch_geometric.transforms as T 8 | from torch_geometric.nn import MLP, knn_interpolate, PointNetConv, fps, global_max_pool, radius 9 | from torch_geometric.typing import WITH_TORCH_CLUSTER 10 | 11 | 12 | if not WITH_TORCH_CLUSTER: 13 | quit("This PointNet++ implementation requires 'torch-cluster'") 14 | 15 | 16 | class SAModule(torch.nn.Module): 17 | def __init__(self, ratio, r, nn): 18 | super().__init__() 19 | self.ratio = ratio 20 | self.r = r 21 | self.conv = PointNetConv(nn, add_self_loops=False) 22 | 23 | def forward(self, x, pos, batch=None): 24 | idx = fps(pos, batch, ratio=self.ratio) 25 | row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64) 26 | edge_index = torch.stack([col, row], dim=0) 27 | x_dst = None if x is None else x[idx] 28 | x = self.conv((x, x_dst), (pos, pos[idx]), edge_index) 29 | pos, batch = pos[idx], batch[idx] 30 | return x, pos, batch 31 | 32 | 33 | class GlobalSAModule(torch.nn.Module): 34 | def __init__(self, nn): 35 | super().__init__() 36 | self.nn = nn 37 | 38 | def forward(self, x, pos, batch): 39 | x = self.nn(torch.cat([x, pos], dim=1)) 40 | x = global_max_pool(x, batch) 41 | pos = pos.new_zeros((x.size(0), 3)) 42 | batch = torch.arange(x.size(0), device=batch.device) 43 | return x, pos, batch 44 | 45 | 46 | class FPModule(torch.nn.Module): 47 | def __init__(self, k, nn): 48 | super().__init__() 49 | self.k = k 50 | self.nn = nn 51 | 52 | def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip): 53 | x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k) 54 | if x_skip is not None: 55 | x = torch.cat([x, x_skip], dim=1) 56 | x = self.nn(x) 57 | return x, pos_skip, batch_skip 58 | 59 | 60 | class PointNet2(torch.nn.Module): 61 | def __init__(self, nb_classes, in_channels=3, in_dim=3): 62 | super().__init__() 63 | 64 | # Input channels account for both `pos` and node features. 
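PointNetConv concatenates the 3-D relative position (pos_j - pos_i) to the incoming node features, so each set-abstraction MLP below expects the previous feature width plus the coordinate dimension; the "+ in_channels" terms in sa2/sa3 therefore line up with the coordinates only for the default 3-channel input (in_channels == in_dim == 3).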
65 | self.sa1_module = SAModule(0.2, 0.2, MLP([in_channels + in_dim, 64, 64, 128])) 66 | self.sa2_module = SAModule(0.25, 0.4, MLP([128 + in_channels, 128, 128, 256])) 67 | self.sa3_module = GlobalSAModule(MLP([256 + in_channels, 256, 512, 1024])) 68 | 69 | self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256])) 70 | self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128])) 71 | self.fp1_module = FPModule(3, MLP([128 + in_channels, 128, 128, 128])) 72 | 73 | self.mlp = MLP([128, 128, 128, nb_classes], dropout=0.5, norm=None) 74 | 75 | self.lin1 = torch.nn.Linear(128, 128) 76 | self.lin2 = torch.nn.Linear(128, 128) 77 | self.lin3 = torch.nn.Linear(128, nb_classes) 78 | 79 | def forward(self, x, pos, batch): 80 | # import ipdb; ipdb.set_trace() 81 | # sa0_out = (data.x, data.pos, data.batch) 82 | sa0_out = (x, pos, batch) 83 | sa1_out = self.sa1_module(*sa0_out) 84 | sa2_out = self.sa2_module(*sa1_out) 85 | sa3_out = self.sa3_module(*sa2_out) 86 | 87 | fp3_out = self.fp3_module(*sa3_out, *sa2_out) 88 | fp2_out = self.fp2_module(*fp3_out, *sa1_out) 89 | x, _, _ = self.fp1_module(*fp2_out, *sa0_out) 90 | 91 | return self.mlp(x).log_softmax(dim=-1) 92 | -------------------------------------------------------------------------------- /model_zoo/segformer.py: -------------------------------------------------------------------------------- 1 | from transformers import SegformerForSemanticSegmentation 2 | import torch 3 | import torch.nn as nn 4 | 5 | backbones = {"mit-b0-ade":"nvidia/mit-b0","mit-b1-ade":"nvidia/segformer-b1-finetuned-ade-512-512","mit-b2-ade":"nvidia/segformer-b2-finetuned-ade-512-512","mit-b3-ade":"nvidia/segformer-b3-finetuned-ade-512-512","mit-b4-ade":"nvidia/segformer-b4-finetuned-ade-512-512","mit-b5-ade":"nvidia/segformer-b5-finetuned-ade-640-640"} 6 | 7 | class Segformer(nn.Module): 8 | def __init__(self,config) -> None: 9 | super().__init__() 10 | if config['backbone'] not in backbones: 11 | print('Backbone: ',config['backbone'],' Not supported!') 12 | exit(2) 13 | self.model = SegformerForSemanticSegmentation.from_pretrained(backbones[config['backbone']],num_labels=config['num_classes'],ignore_mismatched_sizes=True) 14 | 15 | if config['in_channels']!=3: 16 | self.model.segformer.encoder.patch_embeddings[0].proj = nn.Conv2d(config['in_channels'],32,kernel_size=(7,7),stride=(4,4),padding=(3,3)) 17 | 18 | def forward(self,x): 19 | logits = self.model(x)['logits'] 20 | upsampled_logits = nn.functional.interpolate( 21 | logits, 22 | size=(x.shape[2],x.shape[3]), # (height, width) 23 | mode='bilinear', 24 | align_corners=False) 25 | 26 | return upsampled_logits 27 | 28 | ''' 29 | #Example usage: 30 | #================ 31 | config= {'in_channels':5,'num_classes':2,'backbone':'mit-b0-ade'} 32 | k = torch.randn((4,config['in_channels'],224,224)) 33 | model = Segformer(config) 34 | print(model) 35 | logits = model(k) 36 | print(logits.shape)''' -------------------------------------------------------------------------------- /model_zoo/upernet.py: -------------------------------------------------------------------------------- 1 | from transformers import ConvNextConfig, UperNetConfig, UperNetForSemanticSegmentation, AutoConfig 2 | import torch 3 | import torch.nn as nn 4 | 5 | backbones = { 6 | "swin_base": "openmmlab/upernet-swin-base", 7 | "swin_tiny": "openmmlab/upernet-swin-tiny", 8 | "swin_small": "openmmlab/upernet-swin-small", 9 | "convnext_tiny": "openmmlab/upernet-convnext-tiny", 10 | "convnext_small": "openmmlab/upernet-convnext-small", 11 | "convnext_base": 
"openmmlab/upernet-convnext-base", 12 | } 13 | 14 | 15 | class UperNet(nn.Module): 16 | def __init__(self, config) -> None: 17 | super().__init__() 18 | if config["backbone"] not in backbones: 19 | print("Backbone: ", config["backbone"], " Not supported!") 20 | exit(2) 21 | self.model = UperNetForSemanticSegmentation.from_pretrained( 22 | backbones[config["backbone"]], num_labels=config["num_classes"], ignore_mismatched_sizes=True 23 | ) 24 | if config["in_channels"] != 3: 25 | if "convnext" in config["backbone"]: 26 | out_channels = self.model.backbone.embeddings.patch_embeddings.out_channels 27 | kernel_size = self.model.backbone.embeddings.patch_embeddings.kernel_size 28 | stride = self.model.backbone.embeddings.patch_embeddings.stride 29 | self.model.backbone.embeddings.num_channels = config["in_channels"] 30 | self.model.config.backbone_config.num_channels = config["in_channels"] 31 | self.model.backbone.embeddings.patch_embeddings = nn.Conv2d( 32 | config["in_channels"], out_channels, kernel_size=kernel_size, stride=stride 33 | ) 34 | elif "swin" in config["backbone"]: 35 | out_channels = self.model.backbone.embeddings.patch_embeddings.projection.out_channels 36 | kernel_size = self.model.backbone.embeddings.patch_embeddings.projection.kernel_size 37 | stride = self.model.backbone.embeddings.patch_embeddings.projection.stride 38 | self.model.backbone.embeddings.patch_embeddings.num_channels = config["in_channels"] 39 | self.model.backbone.embeddings.patch_embeddings.projection = nn.Conv2d( 40 | config["in_channels"], out_channels, kernel_size=kernel_size, stride=stride 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.model(x, return_dict=True) 45 | return x["logits"] 46 | 47 | 48 | """ 49 | #Example usage: 50 | #================ 51 | config= {'in_channels':5,'num_classes':2,'backbone':'swin_tiny'} 52 | k = torch.randn((4,config['in_channels'],120,120)) 53 | model = UperNet(config) 54 | logits = model(k) 55 | print(logits.shape) 56 | """ 57 | -------------------------------------------------------------------------------- /model_zoo/yolov5/__init__.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license 2 | """ 3 | utils/initialization 4 | """ 5 | 6 | import contextlib 7 | import platform 8 | import threading 9 | 10 | 11 | def emojis(str=''): 12 | # Return platform-dependent emoji-safe version of string 13 | return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str 14 | 15 | 16 | class TryExcept(contextlib.ContextDecorator): 17 | # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager 18 | def __init__(self, msg=''): 19 | self.msg = msg 20 | 21 | def __enter__(self): 22 | pass 23 | 24 | def __exit__(self, exc_type, value, traceback): 25 | if value: 26 | print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) 27 | return True 28 | 29 | 30 | def threaded(func): 31 | # Multi-threads a target function and returns thread. Usage: @threaded decorator 32 | def wrapper(*args, **kwargs): 33 | thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) 34 | thread.start() 35 | return thread 36 | 37 | return wrapper 38 | 39 | 40 | def join_threads(verbose=False): 41 | # Join all daemon threads, i.e. 
atexit.register(lambda: join_threads()) 42 | main_thread = threading.current_thread() 43 | for t in threading.enumerate(): 44 | if t is not main_thread: 45 | if verbose: 46 | print(f'Joining thread {t.name}') 47 | t.join() 48 | 49 | 50 | def notebook_init(verbose=True): 51 | # Check system software and hardware 52 | print('Checking setup...') 53 | 54 | import os 55 | import shutil 56 | 57 | from ultralytics.utils.checks import check_requirements 58 | 59 | from utils.general import check_font, is_colab 60 | from utils.torch_utils import select_device # imports 61 | 62 | check_font() 63 | 64 | import psutil 65 | 66 | if check_requirements('wandb', install=False): 67 | os.system('pip uninstall -y wandb') # eliminate unexpected account creation prompt with infinite hang 68 | if is_colab(): 69 | shutil.rmtree('/content/sample_data', ignore_errors=True) # remove colab /sample_data directory 70 | 71 | # System info 72 | display = None 73 | if verbose: 74 | gb = 1 << 30 # bytes to GiB (1024 ** 3) 75 | ram = psutil.virtual_memory().total 76 | total, used, free = shutil.disk_usage('/') 77 | with contextlib.suppress(Exception): # clear display if ipython is installed 78 | from IPython import display 79 | display.clear_output() 80 | s = f'({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)' 81 | else: 82 | s = '' 83 | 84 | select_device(newline=False) 85 | print(emojis(f'Setup complete ✅ {s}')) 86 | return display 87 | -------------------------------------------------------------------------------- /model_zoo/yolov5/__init__.py~: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license 2 | """ 3 | utils/initialization 4 | """ 5 | 6 | import contextlib 7 | import platform 8 | import threading 9 | 10 | 11 | def emojis(str=''): 12 | # Return platform-dependent emoji-safe version of string 13 | return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str 14 | 15 | 16 | class TryExcept(contextlib.ContextDecorator): 17 | # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager 18 | def __init__(self, msg=''): 19 | self.msg = msg 20 | 21 | def __enter__(self): 22 | pass 23 | 24 | def __exit__(self, exc_type, value, traceback): 25 | if value: 26 | print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) 27 | return True 28 | 29 | 30 | def threaded(func): 31 | # Multi-threads a target function and returns thread. Usage: @threaded decorator 32 | def wrapper(*args, **kwargs): 33 | thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) 34 | thread.start() 35 | return thread 36 | 37 | return wrapper 38 | 39 | 40 | def join_threads(verbose=False): 41 | # Join all daemon threads, i.e. 
atexit.register(lambda: join_threads()) 42 | main_thread = threading.current_thread() 43 | for t in threading.enumerate(): 44 | if t is not main_thread: 45 | if verbose: 46 | print(f'Joining thread {t.name}') 47 | t.join() 48 | 49 | 50 | def notebook_init(verbose=True): 51 | # Check system software and hardware 52 | print('Checking setup...') 53 | 54 | import os 55 | import shutil 56 | 57 | from ultralytics.utils.checks import check_requirements 58 | 59 | from utils.general import check_font, is_colab 60 | from utils.torch_utils import select_device # imports 61 | 62 | check_font() 63 | 64 | import psutil 65 | 66 | if check_requirements('wandb', install=False): 67 | os.system('pip uninstall -y wandb') # eliminate unexpected account creation prompt with infinite hang 68 | if is_colab(): 69 | shutil.rmtree('/content/sample_data', ignore_errors=True) # remove colab /sample_data directory 70 | 71 | # System info 72 | display = None 73 | if verbose: 74 | gb = 1 << 30 # bytes to GiB (1024 ** 3) 75 | ram = psutil.virtual_memory().total 76 | total, used, free = shutil.disk_usage('/') 77 | with contextlib.suppress(Exception): # clear display if ipython is installed 78 | from IPython import display 79 | display.clear_output() 80 | s = f'({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)' 81 | else: 82 | s = '' 83 | 84 | select_device(newline=False) 85 | print(emojis(f'Setup complete ✅ {s}')) 86 | return display 87 | -------------------------------------------------------------------------------- /model_zoo/yolov5/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/model_zoo/yolov5/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /model_zoo/yolov5/__pycache__/general.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/model_zoo/yolov5/__pycache__/general.cpython-310.pyc -------------------------------------------------------------------------------- /model_zoo/yolov5/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/model_zoo/yolov5/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /model_zoo/yolov5/__pycache__/metrics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/model_zoo/yolov5/__pycache__/metrics.cpython-310.pyc -------------------------------------------------------------------------------- /model_zoo/yolov5/__pycache__/torch_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/model_zoo/yolov5/__pycache__/torch_utils.cpython-310.pyc -------------------------------------------------------------------------------- /model_zoo/yolov5/hyp.scratch-low.yaml: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, 
AGPL-3.0 license 2 | # Hyperparameters for low-augmentation COCO training from scratch 3 | # python train.py --batch 64 --cfg yolov5n6.yaml --weights '' --data coco.yaml --img 640 --epochs 300 --linear 4 | # See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials 5 | 6 | lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) 7 | lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) 8 | momentum: 0.937 # SGD momentum/Adam beta1 9 | weight_decay: 0.0005 # optimizer weight decay 5e-4 10 | warmup_epochs: 3.0 # warmup epochs (fractions ok) 11 | warmup_momentum: 0.8 # warmup initial momentum 12 | warmup_bias_lr: 0.1 # warmup initial bias lr 13 | box: 0.05 # box loss gain 14 | cls: 0.5 # cls loss gain 15 | cls_pw: 1.0 # cls BCELoss positive_weight 16 | obj: 1.0 # obj loss gain (scale with pixels) 17 | obj_pw: 1.0 # obj BCELoss positive_weight 18 | iou_t: 0.20 # IoU training threshold 19 | anchor_t: 4.0 # anchor-multiple threshold 20 | # anchors: 3 # anchors per output layer (0 to ignore) 21 | fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) 22 | hsv_h: 0.015 # image HSV-Hue augmentation (fraction) 23 | hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) 24 | hsv_v: 0.4 # image HSV-Value augmentation (fraction) 25 | degrees: 0.0 # image rotation (+/- deg) 26 | translate: 0.1 # image translation (+/- fraction) 27 | scale: 0.5 # image scale (+/- gain) 28 | shear: 0.0 # image shear (+/- deg) 29 | perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 30 | flipud: 0.0 # image flip up-down (probability) 31 | fliplr: 0.5 # image flip left-right (probability) 32 | mosaic: 1.0 # image mosaic (probability) 33 | mixup: 0.0 # image mixup (probability) 34 | copy_paste: 0.0 # segment copy-paste (probability) 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.2 2 | torch==2.1.2 3 | torchvision==0.16.2 4 | webdataset==0.2.86 5 | torchmetrics==1.2.1 6 | black==23.12.1 7 | pre_commit==3.6.0 8 | matplotlib==3.8.2 9 | einops==0.7.0 10 | pyjson5==1.6.5 11 | albumentations==1.3.1 12 | tqdm==4.66.1 13 | segmentation-models-pytorch==0.3.3 14 | torchsummary==1.5.1 15 | transformers==4.36.2 16 | torch_geometric==2.4.0 17 | kornia==0.7.0 18 | pandas==2.1.4 19 | ray==2.9.0 20 | wandb==0.16.1 21 | rasterio==1.3.9 22 | laspy==2.5.3 23 | xmltodict==0.13.0 24 | geopandas==0.14.3 25 | torchgeo==0.5.1 -------------------------------------------------------------------------------- /training/point_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | 4 | import kornia 5 | import numpy as np 6 | import pyjson5 as json 7 | import torch 8 | import tqdm 9 | 10 | import utilities.model_utilities as model_utils 11 | import utilities.utils as utils 12 | import wandb 13 | from torch.optim.lr_scheduler import ExponentialLR 14 | 15 | 16 | def train_epoch(train_loader, model, optimizer, criterion, epoch, configs, scaler, iteration, scheduler=None): 17 | if not configs["linear_evaluation"]: 18 | model.train() 19 | 20 | for idx, batch in enumerate(tqdm.tqdm(train_loader)): 21 | iteration += 1 22 | optimizer.zero_grad() 23 | with torch.cuda.amp.autocast(enabled=configs["mixed_precision"]): 24 | batch_data = batch[0] 25 | point_cloud = [data.pos for data in batch_data] 26 | point_cloud = torch.vstack(point_cloud) 27 | # Empty features for 
benchmarking 28 | x = [data.x for data in batch_data] 29 | x = torch.vstack(x) 30 | label = [data.y for data in batch_data] 31 | label = torch.cat(label).to(configs["device"]) 32 | point_cloud = point_cloud.to(configs["device"]).float() 33 | x = x.to(configs["device"]).float() 34 | label = label.to(configs["device"]) 35 | if configs["architecture"] in ("pointnet2", "point_transformer"): 36 | batch = x.new_zeros(point_cloud.shape[0], dtype=torch.int64).to(configs["device"]) 37 | out = model(x, point_cloud, batch) 38 | else: 39 | out = model(point_cloud) 40 | 41 | loss = criterion(out, label) 42 | 43 | if iteration % 10 == 0: 44 | log_dict = { 45 | "Epoch": epoch, 46 | "Iteration": iteration, 47 | "train loss": loss.item(), 48 | "lr": scheduler.get_last_lr()[0], 49 | } 50 | if configs["wandb"]: 51 | wandb.log(log_dict) 52 | else: 53 | print(log_dict) 54 | if configs["mixed_precision"]: 55 | scaler.scale(loss).backward() 56 | scaler.step(optimizer) 57 | scaler.update() 58 | if iteration % configs["lr_step"] == 0: 59 | scheduler.step() 60 | else: 61 | loss.backward() 62 | optimizer.step() 63 | if iteration % configs["lr_step"] == 0: 64 | scheduler.step() 65 | return iteration 66 | 67 | 68 | def train(configs): 69 | print("=" * 20) 70 | print("Initializing segmentation trainer") 71 | print("=" * 20) 72 | metrics = utils.initialize_metrics(configs) 73 | criterion = utils.create_loss(configs) 74 | model = model_utils.create_model(configs) 75 | optimizer = utils.create_optimizer(configs)(model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"]) 76 | print("Number of trainable parameters: {}".format(utils.count_params(model))) 77 | print("=" * 20) 78 | scheduler = ExponentialLR(optimizer, gamma=0.9) 79 | # compile model (torch 2.0) 80 | # model = torch.compile(model) 81 | 82 | if configs["mixed_precision"]: 83 | # Creates a GradScaler once at the beginning of training. 
84 | scaler = torch.cuda.amp.GradScaler() 85 | else: 86 | scaler = None 87 | 88 | train_loader, val_loader, _ = utils.create_dataloaders(configs) 89 | model.to(configs["device"]) 90 | best_loss = 10000.0 91 | iteration = 0 92 | for epoch in range(configs["epochs"]): 93 | iteration = train_epoch(train_loader, model, optimizer, criterion, epoch, configs, scaler, iteration, scheduler) 94 | if (epoch + 1) % configs["val_step"] == 0: 95 | val_loss = test(configs, phase="val", model=model, criterion=criterion, loader=val_loader, epoch=epoch) 96 | if val_loss < best_loss: 97 | best_loss = val_loss 98 | print("New best validation loss: ", best_loss) 99 | print("Saving checkpoint") 100 | # Store checkpoint 101 | torch.save(model, os.path.join(configs["checkpoint_path"], "best_model.pt")) 102 | 103 | 104 | def test(configs, phase, model=None, loader=None, criterion=None, epoch="Test"): 105 | if phase == "test": 106 | print("=" * 20) 107 | print("Begin Testing") 108 | print("=" * 20) 109 | _, _, loader = utils.create_dataloaders(configs) 110 | criterion = utils.create_loss(configs) 111 | 112 | # Load model from checkpoint 113 | model = torch.load(os.path.join(configs["checkpoint_path"], "best_model.pt"), map_location=configs["device"]) 114 | print("Number of trainable parameters: {}".format(utils.count_params(model))) 115 | print("=" * 20) 116 | 117 | # compile model (torch 2.0) 118 | # model = torch.compile(model) 119 | elif phase == "val": 120 | print("=" * 20) 121 | print("Begin Evaluation") 122 | print("=" * 20) 123 | else: 124 | print("Unknown phase!") 125 | exit(3) 126 | 127 | metrics = utils.initialize_metrics(configs) 128 | for metric in metrics: 129 | metric = metric.to("cpu") 130 | model.to(configs["device"]) 131 | model.eval() 132 | total_loss = 0.0 133 | 134 | # Images to log to wandb 135 | first_image = None 136 | first_prediction = None 137 | first_mask = None 138 | num_samples = 0 139 | random_batch_idx = np.random.randint(len(loader)) 140 | for idx, batch in enumerate(tqdm.tqdm(loader)): 141 | with torch.no_grad(): 142 | with torch.cuda.amp.autocast(enabled=configs["mixed_precision"]): 143 | batch_data = batch[0] 144 | point_cloud = [data.pos for data in batch_data] 145 | if idx == random_batch_idx: 146 | # for visualization purposes 147 | first_pc_len = point_cloud[0].shape[0] 148 | point_cloud = torch.vstack(point_cloud) 149 | # Empty features for benchmarking 150 | x = [data.x for data in batch_data] 151 | x = torch.vstack(x) 152 | label = [data.y for data in batch_data] 153 | label = torch.cat(label) 154 | point_cloud = point_cloud.to(configs["device"]).float() 155 | x = x.to(configs["device"]).float() 156 | label = label.to(configs["device"]) 157 | if configs["architecture"] in ("pointnet2", "point_transformer"): 158 | batch = x.new_zeros(point_cloud.shape[0], dtype=torch.int64).to(configs["device"]) 159 | out = model(x, point_cloud, batch) 160 | else: 161 | out = model(point_cloud) 162 | 163 | loss = criterion(out, label) 164 | total_loss += loss.item() 165 | if idx == random_batch_idx: 166 | # for visualization purposes 167 | predictions = out.argmax(1) 168 | first_pc = point_cloud.detach().cpu()[:first_pc_len].numpy() 169 | first_pc_prediction = predictions.detach().cpu()[:first_pc_len].numpy() 170 | first_pc_label = label.detach().cpu()[:first_pc_len].numpy() 171 | for metric in metrics: 172 | metric(out.detach().cpu(), label.detach().cpu()) 173 | num_samples += 1 # the entire PC is considered as a single sample for loss computation 174 | 175 | total_loss = total_loss / 
num_samples 176 | log_dict = {"Epoch": epoch, phase + " loss": total_loss} 177 | for idx, metric in enumerate(metrics): 178 | if metric.average != "none": 179 | log_dict[phase + " " + metric.average + " " + metric.__class__.__name__] = metric.compute() 180 | else: 181 | if phase != "val": 182 | scores = metric.compute() 183 | for idx in range(scores.shape[0]): 184 | log_dict[phase + " " + metric.__class__.__name__ + " Class: " + str(idx)] = scores[idx] 185 | if configs["wandb"]: 186 | class_labels = {} 187 | for i in list(range(configs["num_classes"])): 188 | class_labels[i] = str(i) 189 | 190 | first_pc_prediction += 1 191 | first_pc_label += 1 192 | pc_vis = { 193 | "predictions": wandb.Object3D(np.hstack((first_pc, first_pc_prediction.reshape(-1, 1)))), 194 | "labels": wandb.Object3D(np.hstack((first_pc, first_pc_label.reshape(-1, 1)))), 195 | } 196 | 197 | log_dict[phase + " sample"] = pc_vis 198 | wandb.log(log_dict) 199 | else: 200 | print(log_dict) 201 | return total_loss 202 | -------------------------------------------------------------------------------- /training/segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | 4 | import kornia 5 | import numpy as np 6 | import pyjson5 as json 7 | import torch 8 | import tqdm 9 | 10 | import utilities.model_utilities as model_utils 11 | import utilities.utils as utils 12 | import wandb 13 | from utilities.model_utilities import adjust_learning_rate, get_current_learning_rate 14 | 15 | 16 | def early_stopping(val_loss, best_loss, counter, patience=5): 17 | if val_loss < best_loss: 18 | counter = 0 19 | else: 20 | counter += 1 21 | 22 | if counter >= patience: 23 | print("Early stopping") 24 | return True, counter 25 | else: 26 | return False, counter 27 | 28 | 29 | def train_epoch(train_loader, model, optimizer, criterion, epoch, configs, scaler): 30 | if not configs["linear_evaluation"] and not configs["fully_finetune"]: 31 | model.train() 32 | else: 33 | model.train() 34 | modality_dictionary = {v: k for k, v in configs["modality_channels"].items()} 35 | available_modalities = configs["dataset_modality_index"][configs["dataset"]] 36 | spectral_keys = [] 37 | desired_indices = [] 38 | for modality in available_modalities: 39 | spectral_keys.append(int(modality_dictionary[modality])) 40 | desired_indices.append(available_modalities[modality]) 41 | 42 | for idx, batch in enumerate(tqdm.tqdm(train_loader)): 43 | if "samples_per_epoch" in configs: 44 | if (idx + 1) * configs["batch_size"] > configs["samples_per_epoch"]: 45 | break 46 | optimizer.zero_grad() 47 | with torch.cuda.amp.autocast(enabled=configs["mixed_precision"]): 48 | image, label = batch 49 | 50 | image = image.to(configs["device"]) 51 | label = label.to(configs["device"]) 52 | 53 | if not configs["linear_evaluation"] and not configs["fully_finetune"]: 54 | out = model(image) 55 | else: 56 | image = image[:, desired_indices, :, :] 57 | 58 | out = model((image, spectral_keys)) 59 | 60 | label = label.long() 61 | loss = criterion(out, label) 62 | if idx % 100 == 0: 63 | log_dict = {"Epoch": epoch, "Iteration": idx, "train loss": loss.item()} 64 | if configs["wandb"]: 65 | wandb.log(log_dict) 66 | else: 67 | print(log_dict) 68 | if configs["mixed_precision"]: 69 | scaler.scale(loss).backward() 70 | scaler.step(optimizer) 71 | scaler.update() 72 | else: 73 | loss.backward() 74 | optimizer.step() 75 | 76 | if configs["schedule"] == "cos": 77 | if "samples_per_epoch" in configs: 78 | num_steps = 
configs["samples_per_epoch"] // configs["batch_size"] 79 | else: 80 | num_steps = 10000 81 | adjust_learning_rate(optimizer, idx / num_steps + epoch, configs) 82 | 83 | 84 | def train(configs): 85 | print("=" * 20) 86 | print("Initializing segmentation trainer") 87 | print("=" * 20) 88 | metrics = utils.initialize_metrics(configs) 89 | criterion = utils.create_loss(configs) 90 | base_model = model_utils.create_model(configs) 91 | optimizer = utils.create_optimizer(configs)( 92 | base_model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"] 93 | ) 94 | 95 | # compile model (torch 2.0) 96 | model = base_model # torch.compile(base_model) 97 | 98 | if configs["mixed_precision"]: 99 | # Creates a GradScaler once at the beginning of training. 100 | scaler = torch.cuda.amp.GradScaler() 101 | else: 102 | scaler = None 103 | 104 | train_loader, val_loader, _ = utils.create_dataloaders(configs) 105 | model.to(configs["device"]) 106 | best_loss = 10000.0 107 | early_stop_counter = 0 108 | for epoch in range(configs["epochs"]): 109 | train_epoch(train_loader, model, optimizer, criterion, epoch, configs, scaler) 110 | 111 | val_loss = test(configs, phase="val", model=model, criterion=criterion, loader=val_loader, epoch=epoch) 112 | if "early_stopping" in configs and configs["early_stopping"] > 0: 113 | early_stop, early_stop_counter = early_stopping( 114 | val_loss, best_loss, early_stop_counter, patience=configs["early_stopping"] 115 | ) 116 | else: 117 | early_stop = False 118 | 119 | if val_loss < best_loss: 120 | best_loss = val_loss 121 | print("New best validation loss: ", best_loss) 122 | print("Saving checkpoint") 123 | # Store checkpoint 124 | torch.save(base_model, os.path.join(configs["checkpoint_path"], "best_model.pt")) 125 | 126 | torch.save( 127 | base_model.state_dict(), 128 | os.path.join(configs["checkpoint_path"], "best_model_state_dict_" + str(epoch) + ".pt"), 129 | ) 130 | if early_stop: 131 | print("Early stopping at epoch: ", epoch) 132 | break 133 | 134 | 135 | def test(configs, phase, model=None, loader=None, criterion=None, epoch="Test"): 136 | if phase == "test": 137 | print("=" * 20) 138 | print("Begin Testing") 139 | print("=" * 20) 140 | _, _, loader = utils.create_dataloaders(configs) 141 | criterion = utils.create_loss(configs) 142 | 143 | if configs["eval_checkpoint"] is not None: 144 | print("=" * 20) 145 | print("Evaluating segmentation for phase: ", phase, " with checkpoint: ") 146 | print(configs["eval_checkpoint"]) 147 | print("=" * 20) 148 | # Load model from checkpoint 149 | model = torch.load(os.path.join(configs["eval_checkpoint"]), map_location=configs["device"]) 150 | else: 151 | # infer checkpoint path from configs 152 | # Load model from checkpoint 153 | model = torch.load(os.path.join(configs["checkpoint_path"], "best_model.pt"), map_location=configs["device"]) 154 | 155 | # compile model (torch 2.0) 156 | # model = torch.compile(model) 157 | elif phase == "val": 158 | print("=" * 20) 159 | print("Begin Evaluation") 160 | print("=" * 20) 161 | else: 162 | print("Unknown phase!") 163 | exit(3) 164 | 165 | if configs["linear_evaluation"] or configs["fully_finetune"]: 166 | modality_dictionary = {v: k for k, v in configs["modality_channels"].items()} 167 | available_modalities = configs["dataset_modality_index"][configs["dataset"]] 168 | spectral_keys = [] 169 | desired_indices = [] 170 | for modality in available_modalities: 171 | spectral_keys.append(int(modality_dictionary[modality])) 172 | 
desired_indices.append(available_modalities[modality]) 173 | 174 | metrics = utils.initialize_metrics(configs) 175 | model.to(configs["device"]) 176 | model.eval() 177 | total_loss = 0.0 178 | 179 | # Images to log to wandb 180 | first_image = None 181 | first_prediction = None 182 | first_mask = None 183 | num_samples = 0 184 | for idx, batch in enumerate(tqdm.tqdm(loader)): 185 | with torch.no_grad(): 186 | with torch.cuda.amp.autocast(enabled=configs["mixed_precision"]): 187 | image, label = batch 188 | image = image.to(configs["device"]) 189 | label = label.to(configs["device"]) 190 | 191 | if not configs["linear_evaluation"] and not configs["fully_finetune"]: 192 | out = model(image) 193 | else: 194 | image = image[:, desired_indices, :, :] 195 | out = model((image, spectral_keys)) 196 | if idx == 0: 197 | predictions = out.argmax(1) 198 | first_image = image.detach().cpu()[0] 199 | first_prediction = predictions.detach().cpu()[0] 200 | first_mask = label.detach().cpu()[0] 201 | label = label.long() 202 | loss = criterion(out, label) 203 | total_loss += loss.item() 204 | for metric in metrics: 205 | metric(out, label) 206 | num_samples += image.shape[0] 207 | 208 | total_loss = total_loss / num_samples 209 | log_dict = {"Epoch": epoch, phase + " loss": total_loss} 210 | for idx, metric in enumerate(metrics): 211 | if metric.average != "none": 212 | log_dict[phase + " " + metric.average + " " + metric.__class__.__name__] = metric.compute() 213 | else: 214 | if phase != "val": 215 | scores = metric.compute() 216 | for idx in range(scores.shape[0]): 217 | log_dict[phase + " " + metric.__class__.__name__ + " Class: " + str(idx)] = scores[idx] 218 | if configs["wandb"]: 219 | class_labels = {} 220 | for i in list(range(configs["num_classes"])): 221 | class_labels[i] = str(i) 222 | 223 | if configs["log_images"]: 224 | if configs["normalization"] == "standard": 225 | first_image = first_image.unsqueeze(0) 226 | first_image = kornia.enhance.Denormalize(torch.tensor(configs["mean"]), torch.tensor(configs["std"]))( 227 | first_image 228 | ).squeeze() 229 | first_image = first_image[:3, :, :].permute(1, 2, 0) / first_image.max() 230 | first_image *= 255 231 | else: 232 | first_image = first_image[:3, :, :].permute(1, 2, 0) * 255 233 | mask_img = wandb.Image( 234 | (first_image).int().cpu().detach().numpy(), 235 | masks={ 236 | "predictions": {"mask_data": first_prediction.float().numpy(), "class_labels": class_labels}, 237 | "ground_truth": {"mask_data": first_mask.float().numpy(), "class_labels": class_labels}, 238 | }, 239 | ) 240 | log_dict[phase + " sample"] = mask_img 241 | wandb.log(log_dict) 242 | else: 243 | print(log_dict) 244 | return total_loss 245 | -------------------------------------------------------------------------------- /utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/utilities/__init__.py -------------------------------------------------------------------------------- /utilities/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/utilities/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utilities/__pycache__/augmentations.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/utilities/__pycache__/augmentations.cpython-310.pyc -------------------------------------------------------------------------------- /utilities/__pycache__/distributed_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/utilities/__pycache__/distributed_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utilities/__pycache__/model_utilities.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/utilities/__pycache__/model_utilities.cpython-310.pyc -------------------------------------------------------------------------------- /utilities/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/utilities/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /utilities/__pycache__/webdataset_writer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RolnickLab/FoMo-Bench/93f7218a4bf928b50e0eaf827f74d1d8c79f27e5/utilities/__pycache__/webdataset_writer.cpython-310.pyc -------------------------------------------------------------------------------- /utilities/augmentations.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import torch_geometric.transforms as T 3 | 4 | 5 | def get_augmentations(config): 6 | augmentations = config["augmentations"] 7 | task = config["task"] 8 | independend_aug = [] 9 | for k, v in augmentations.items(): 10 | if k == "RandomResizedCrop": 11 | aug = A.augmentations.RandomResizedCrop( 12 | height=v["value"], width=v["value"], p=v["p"], scale=tuple(v["scale"]), interpolation=v["interpolation"] 13 | ) 14 | elif k == "Resize": 15 | aug = A.augmentations.Resize(height=v["value"], width=v["value"], p=v["p"]) 16 | elif k == "ColorJitter": 17 | aug = A.augmentations.ColorJitter( 18 | brightness=v["value"][0], 19 | contrast=v["value"][1], 20 | saturation=v["value"][2], 21 | hue=v["value"][3], 22 | p=v["p"], 23 | ) 24 | elif k == "HorizontalFlip": 25 | aug = A.augmentations.HorizontalFlip(p=v["p"]) 26 | elif k == "VerticalFlip": 27 | aug = A.augmentations.VerticalFlip(p=v["p"]) 28 | elif k == "RandomRotation": 29 | aug = A.augmentations.Rotate(p=v["p"]) 30 | elif k == "GaussianBlur": 31 | aug = A.augmentations.GaussianBlur(sigma_limit=v["value"], p=v["p"]) 32 | elif k == "ElasticTransform": 33 | aug = A.augmentations.ElasticTransform(p=v["p"]) 34 | elif k == "Cutout": 35 | aug = A.augmentations.CoarseDropout(p=v["p"]) 36 | elif k == "GaussianNoise": 37 | aug = A.augmentations.GaussNoise(p=v["p"]) 38 | elif k == "MultNoise": 39 | aug = A.augmentations.MultiplicativeNoise(p=v["p"]) 40 | elif k == "SamplePoints": 41 | aug = T.SamplePoints(num=v["num"], remove_faces=v["remove_faces"], include_normals=v["include_normals"]) 42 | elif k == "RandomJitter": 43 | aug = T.RandomJitter(translate=v["translate"]) 44 | 
elif k == "RandomRotate_x": 45 | aug = T.RandomRotate(degrees=v["degrees"], axis=0) 46 | elif k == "RandomRotate_y": 47 | aug = T.RandomRotate(degrees=v["degrees"], axis=1) 48 | elif k == "RandomRotate_z": 49 | aug = T.RandomRotate(degrees=v["degrees"], axis=2) 50 | else: 51 | print("Augmentation: ", k, " not supported!") 52 | exit(2) 53 | independend_aug.append(aug) 54 | if task == "detection": 55 | return A.Compose( 56 | independend_aug, 57 | bbox_params=A.BboxParams(format=config["det_format"], min_visibility=0.01, label_fields=["class_labels"]), 58 | ) 59 | elif task == "point_segmentation": 60 | return T.Compose(independend_aug) 61 | return A.Compose(independend_aug) 62 | -------------------------------------------------------------------------------- /utilities/detection_datasets/tilerize_neontree.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | 5 | # Allow loading tilerizer from the above hierarchy 6 | higher_level_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 7 | sys.path.append(higher_level_folder) 8 | 9 | from tilerizer import Tilerizer 10 | 11 | 12 | def tile_training(tile_path, annot_path): 13 | tif_file_paths = tile_path.glob("*.tif") 14 | for tif_file_path in tif_file_paths: 15 | file_name = tif_file_path.stem 16 | annot_file_name = file_name + ".xml" 17 | annot_file_path = annot_path / annot_file_name 18 | print("Starting tiling TIF: {}".format(file_name)) 19 | tilerizer = Tilerizer(tif_file_path, annot_file_path) 20 | tilerizer.create_tiles() 21 | 22 | 23 | def tile_evaluation(tile_path, annot_path): 24 | tif_file_paths = tile_path.glob("*.tif") 25 | for tif_file_path in tif_file_paths: 26 | file_name = tif_file_path.stem 27 | annot_file_name = file_name + ".xml" 28 | annot_file_path = annot_path / annot_file_name 29 | if annot_file_path.exists(): 30 | print("Starting tiling TIF: {}".format(file_name)) 31 | tilerizer = Tilerizer(tif_file_path, annot_file_path) 32 | tilerizer.create_tiles() 33 | else: 34 | print("Following TIF file has no annotation: {}".format(file_name)) 35 | 36 | 37 | def main(): 38 | # Please modify the folder_path to the NeonTree official dataset 39 | # folder_path = "path/to/dataset" 40 | folder_path = "/network/scratch/a/arthur.ouaknine/data/NeonTree" 41 | neontree_path = Path(folder_path) 42 | annot_path = neontree_path / "annotations" 43 | evaluation_rgb_path = neontree_path / "evaluation" / "RGB" 44 | training_rgb_path = neontree_path / "training" / "RGB" 45 | tile_training(training_rgb_path, annot_path) 46 | tile_evaluation(evaluation_rgb_path, annot_path) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /utilities/detection_datasets/tilerize_reforestree.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | 5 | # Allow loading tilerizer from the above hierarchy 6 | higher_level_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 7 | sys.path.append(higher_level_folder) 8 | 9 | from tilerizer import Tilerizer 10 | 11 | 12 | def tile_dataset(tile_paths, annot_path): 13 | for rgb_file_path in tile_paths: 14 | file_name = rgb_file_path.stem 15 | print("Starting tiling RGB file: {}".format(file_name)) 16 | tilerizer = Tilerizer(rgb_file_path, annot_path) 17 | tilerizer.create_tiles(tile_size=1000) 18 | 19 | 20 | def main(): 21 | # 
Please modify the folder_path to the ReforesTree official dataset 22 | folder_path = "path/to/dataset" 23 | reforestree_path = Path(folder_path) 24 | annot_path = reforestree_path / "mapping" / "final_dataset.csv" 25 | rgb_data_paths = list((reforestree_path / "tiles").glob("*/*.png")) 26 | tile_dataset(rgb_data_paths, annot_path) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /utilities/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import warnings 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | import torch.utils.data 10 | import torch.utils.data.distributed 11 | 12 | 13 | def is_distributed(): 14 | if "WORLD_SIZE" in os.environ: 15 | return int(os.environ["WORLD_SIZE"]) > 1 16 | if "SLURM_NTASKS" in os.environ: 17 | return int(os.environ["SLURM_NTASKS"]) > 1 18 | return False 19 | 20 | 21 | def world_info_from_env(): 22 | local_rank = 0 23 | for v in ("LOCAL_RANK", "SLURM_LOCALID"): 24 | if v in os.environ: 25 | local_rank = int(os.environ[v]) 26 | break 27 | global_rank = 0 28 | for v in ("RANK", "SLURM_PROCID"): 29 | if v in os.environ: 30 | global_rank = int(os.environ[v]) 31 | break 32 | world_size = 1 33 | for v in ("WORLD_SIZE", "SLURM_NTASKS"): 34 | if v in os.environ: 35 | world_size = int(os.environ[v]) 36 | break 37 | return local_rank, global_rank, world_size 38 | 39 | 40 | def is_global_master(configs): 41 | return configs["rank"] == 0 42 | 43 | 44 | def seed(configs): 45 | random.seed(configs["seed"]) 46 | torch.manual_seed(configs["seed"]) 47 | cudnn.deterministic = True 48 | warnings.warn( 49 | "You have chosen to seed training. " 50 | "This will turn on the CUDNN deterministic setting, " 51 | "which can slow down your training considerably! " 52 | "You may see unexpected behavior when restarting " 53 | "from checkpoints." 
54 | ) 55 | 56 | 57 | def init_distributed(configs): 58 | if "SLURM_PROCID" in os.environ: 59 | configs["local_rank"], configs["rank"], configs["world_size"] = world_info_from_env() 60 | configs["num_workers"] = int(os.environ["SLURM_CPUS_PER_TASK"]) 61 | os.environ["LOCAL_RANK"] = str(configs["local_rank"]) 62 | os.environ["RANK"] = str(configs["rank"]) 63 | os.environ["WORLD_SIZE"] = str(configs["world_size"]) 64 | dist.init_process_group( 65 | backend="nccl", 66 | init_method="env://", 67 | world_size=configs["world_size"], 68 | rank=configs["rank"], 69 | ) 70 | else: 71 | configs["local_rank"], _, _ = world_info_from_env() 72 | dist.init_process_group(backend="nccl") 73 | configs["world_size"] = dist.get_world_size() 74 | configs["rank"] = dist.get_rank() 75 | 76 | return configs 77 | -------------------------------------------------------------------------------- /utilities/pointcloud_datasets/tilerize_forinstance.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | from utilities.tilerizer import Tilerizer 4 | 5 | 6 | def main(): 7 | # Please modify the folder_path to the FOR-instance official dataset 8 | folder_path = "path/to/dataset" 9 | forinstance_path = Path(folder_path) 10 | paths = pd.read_csv(forinstance_path / "data_split_metadata.csv")["path"] 11 | paths = [forinstance_path / path for path in paths] 12 | for i, pc_path in enumerate(paths): 13 | print("Starting sub point cloud {}:".format(pc_path)) 14 | tilerizer = Tilerizer(pc_path, task="segmentation", modality="point_cloud") 15 | samples = tilerizer.create_subpointcloud(nb_max_points=100000) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /utilities/pointcloud_datasets/tilerize_neontree.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from pathlib import Path 3 | from utilities.tilerizer import Tilerizer 4 | 5 | 6 | def main(): 7 | # Please modify the folder_path to the NeonTree official dataset 8 | folder_path = "path/to/dataset" 9 | neontree_path = Path(folder_path) 10 | with open(neontree_path / "lidar_annots_paths.yml", "r") as fp: 11 | paths = yaml.safe_load(fp) 12 | for split in ("training", "evaluation"): 13 | for i, pc_path in enumerate(paths[split]["annot_laz"]): 14 | pc_path = Path(pc_path) 15 | print("Starting sub point cloud {}:".format(pc_path)) 16 | tilerizer = Tilerizer(pc_path, task="segmentation", modality="point_cloud") 17 | samples = tilerizer.create_subpointcloud(nb_max_points=100000) 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /utilities/webdataset_writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import albumentations as A 5 | import einops 6 | import torch 7 | import tqdm 8 | import webdataset as wds 9 | 10 | import utilities.utils as utils 11 | import ray 12 | 13 | 14 | def wds_write(configs): 15 | for mode in ["train", "val", "test"]: 16 | dataset = utils.load_dataset(configs, mode=mode) 17 | print("=" * 40) 18 | print("Creating shards for dataset: ", configs["dataset"]) 19 | print("Mode: ", mode, " Size: ", len(dataset)) 20 | print("=" * 40) 21 | 22 | if configs["max_sample_resolution"] is None: 23 | shard_path = Path(os.path.join(configs["root_path"], "webdataset", 
configs["dataset"], mode)) 24 | shard_path.mkdir(parents=True, exist_ok=True) 25 | 26 | pattern = os.path.join(configs["root_path"], "webdataset", configs["dataset"], mode, f"sample-{mode}-%06d.tar") 27 | else: 28 | shard_path = Path( 29 | os.path.join( 30 | configs["root_path"], 31 | "webdataset" + "_" + str(configs["max_sample_resolution"]), 32 | "webdataset", 33 | configs["dataset"], 34 | mode, 35 | ) 36 | ) 37 | shard_path.mkdir(parents=True, exist_ok=True) 38 | pattern = os.path.join( 39 | configs["root_path"], 40 | "webdataset" + "_" + str(configs["max_sample_resolution"]), 41 | "webdataset", 42 | configs["dataset"], 43 | mode, 44 | f"sample-{mode}-%06d.tar", 45 | ) 46 | with wds.ShardWriter(pattern, maxcount=configs["max_samples_per_shard"]) as sink: 47 | for index, batch in enumerate(tqdm.tqdm(dataset)): 48 | if isinstance(batch, dict): 49 | image = batch["image"] 50 | else: 51 | (image, labels) = batch 52 | if configs["max_sample_resolution"] is not None: 53 | image = image.permute(1, 2, 0).numpy() 54 | resize = A.Compose( 55 | [ 56 | A.augmentations.Resize( 57 | height=configs["max_sample_resolution"], width=configs["max_sample_resolution"], p=1.0 58 | ) 59 | ] 60 | ) 61 | transform = resize(image=image) 62 | image = transform["image"] 63 | image = torch.from_numpy(einops.rearrange(image, "h w c -> c h w")) 64 | 65 | if isinstance(batch, dict): 66 | labels_dict = {} 67 | for key in batch: 68 | if key != "image": 69 | labels_dict[key] = batch[key] 70 | sink.write({"__key__": "sample%06d" % index, "image.pth": image, "labels.pth": labels_dict}) 71 | else: 72 | sink.write({"__key__": "sample%06d" % index, "image.pth": image, "labels.pth": labels}) 73 | 74 | 75 | @ray.remote 76 | def wds_write_ith_shard(configs, dataset, mode, i, n): 77 | if configs["max_sample_resolution"] is None: 78 | shard_path = Path(os.path.join(configs["root_path"], "webdataset", configs["dataset"], mode)) 79 | shard_path.mkdir(parents=True, exist_ok=True) 80 | 81 | pattern = os.path.join(configs["root_path"], "webdataset", configs["dataset"], mode, f"sample-{mode}-{i}-%06d.tar") 82 | else: 83 | shard_path = Path( 84 | os.path.join( 85 | configs["root_path"], 86 | "webdataset" + "_" + str(configs["max_sample_resolution"]), 87 | "webdataset", 88 | configs["dataset"], 89 | mode, 90 | ) 91 | ) 92 | shard_path.mkdir(parents=True, exist_ok=True) 93 | pattern = os.path.join( 94 | configs["root_path"], 95 | "webdataset" + "_" + str(configs["max_sample_resolution"]), 96 | "webdataset", 97 | configs["dataset"], 98 | mode, 99 | f"sample-{mode}-{i}-%06d.tar", 100 | ) 101 | 102 | with wds.ShardWriter(pattern, maxcount=configs["max_samples_per_shard"]) as sink: 103 | for index in tqdm.tqdm(range(i, len(dataset), n)): 104 | batch = dataset[index] 105 | if isinstance(batch, dict): 106 | image = batch["image"] 107 | else: 108 | (image, labels) = batch 109 | if configs["max_sample_resolution"] is not None: 110 | image = image.permute(1, 2, 0).numpy() 111 | resize = A.Compose( 112 | [ 113 | A.augmentations.Resize( 114 | height=configs["max_sample_resolution"], width=configs["max_sample_resolution"], p=1.0 115 | ) 116 | ] 117 | ) 118 | transform = resize(image=image) 119 | image = transform["image"] 120 | image = torch.from_numpy(einops.rearrange(image, "h w c -> c h w")) 121 | 122 | if isinstance(batch, dict): 123 | labels_dict = {} 124 | for key in batch: 125 | if key != "image": 126 | labels_dict[key] = batch[key] 127 | sink.write({"__key__": "sample%06d" % index, "image.pth": image, "labels.pth": labels_dict}) 128 | else: 
129 | sink.write({"__key__": "sample%06d" % index, "image.pth": image, "labels.pth": labels}) 130 | 131 | 132 | def wds_write_parallel(configs): 133 | ray.init() 134 | n = configs["webdataset_write_processes"] 135 | for mode in ["train", "val", "test"]: 136 | dataset = utils.load_dataset(configs, mode=mode) 137 | print("=" * 40) 138 | print("Creating shards for dataset: ", configs["dataset"]) 139 | print("Mode: ", mode, " Size: ", len(dataset)) 140 | print("=" * 40) 141 | 142 | ray.get([wds_write_ith_shard.remote(configs, dataset, mode, i, n) for i in range(n)]) 143 | --------------------------------------------------------------------------------