├── .gitignore ├── CITATION.cff ├── README.md ├── configs ├── callbacks │ ├── default.yaml │ ├── segmentation_callback.yaml │ └── siamese_callback.yaml ├── datamodule │ └── siamese_datamodule.yaml ├── debug │ ├── default.yaml │ ├── limit_batches.yaml │ ├── overfit.yaml │ ├── profiler.yaml │ ├── step.yaml │ └── test_only.yaml ├── experiment │ ├── Siamese_Type1_hokkaido.yaml │ ├── Siamese_Type1_kaikoura.yaml │ ├── seg_hokk_hokk_pretrain_cnn.yaml │ └── seg_hokk_kaik_pretrain_cnn.yaml ├── hparams_search │ └── mnist_optuna.yaml ├── local │ └── .gitkeep ├── log_dir │ ├── debug.yaml │ ├── default.yaml │ └── evaluation.yaml ├── logger │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── augmented_segmentation_model.yaml │ └── siamese_model.yaml ├── test.yaml ├── train.yaml └── trainer │ ├── ddp.yaml │ └── default.yaml ├── data ├── stats_dict_hokkaido.pkl └── stats_dict_kaikoura.pkl ├── notebooks ├── .gitkeep ├── Analyze_results.ipynb └── paper_plots.ipynb ├── requirements.txt ├── scripts ├── run_pretext_tasks.sh └── run_segmentation_experiments.sh ├── src ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── probe_callback.py │ └── wandb_callbacks.py ├── datamodules │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── chips.py │ │ ├── transforms.py │ │ └── utils.py │ └── siamese_datamodule.py ├── models │ ├── __init__.py │ ├── siamese_downstream_module.py │ └── siamese_module.py ├── testing_pipeline.py └── training_pipeline.py ├── test.py ├── tests ├── __init__.py ├── helpers │ ├── __init__.py │ ├── module_available.py │ ├── run_command.py │ └── runif.py ├── shell │ ├── __init__.py │ ├── test_basic_commands.py │ ├── test_debug_configs.py │ └── test_sweeps.py └── unit │ ├── __init__.py │ └── test_mnist_datamodule.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: >- 6 | SAR-based landslide classification pretraining leads to 7 | better segmentation 8 | message: >- 9 | If you use this software, please cite it using the 10 | metadata from this file. 11 | type: software 12 | repository-code: "https://github.com/VMBoehm/SAR-landslide-detection-pretraining/" 13 | authors: 14 | - given-names: Vanessa 15 | family-names: Boehm 16 | affiliation: >- 17 | University of California Berkeley, United 18 | States 19 | orcid: 'https://orcid.org/0000-0003-3801-1912' 20 | - given-names: Wei Ji 21 | family-names: Leong 22 | affiliation: 'The Ohio State University, United States' 23 | orcid: 'https://orcid.org/0000-0003-2354-1988' 24 | - given-names: Ragini Bal 25 | family-names: Mahesh 26 | affiliation: 'German Aerospace Center DLR, Germany' 27 | orcid: 'https://orcid.org/0000-0002-2747-9811' 28 | - given-names: Ioannis 29 | family-names: Prapas 30 | affiliation: 'University of Valencia, Spain' 31 | orcid: 'https://orcid.org/0000-0002-9111-4112' 32 | - given-names: Edoardo 33 | family-names: Nemni 34 | affiliation: 'United Nations Satellite Centre, Switzerland' 35 | orcid: 'https://orcid.org/0000-0002-0166-4613' 36 | - family-names: Kalaitzis 37 | given-names: Freddie 38 | affiliation: 'University of Oxford, United Kingdom' 39 | orcid: 'https://orcid.org/0000-0002-1471-646X' 40 | - given-names: Siddha 41 | family-names: Ganju 42 | affiliation: 'NVIDIA, United States' 43 | orcid: 'https://orcid.org/0000-0002-9462-4898' 44 | - given-names: Raul 45 | family-names: Ramos-Pollan 46 | affiliation: 'Universidad de Antioquia, Colombia' 47 | orcid: 'https://orcid.org/0000-0001-6195-3612' 48 | preferred-citation: 49 | type: conference-paper 50 | authors: 51 | - given-names: Vanessa 52 | family-names: Boehm 53 | affiliation: >- 54 | University of California Berkeley, United 55 | States 56 | orcid: 'https://orcid.org/0000-0003-3801-1912' 57 | - given-names: Wei Ji 58 | family-names: Leong 59 | affiliation: 'The Ohio State University, United States' 60 | orcid: 'https://orcid.org/0000-0003-2354-1988' 61 | - given-names: Ragini Bal 62 | family-names: Mahesh 63 | affiliation: 'German Aerospace Center DLR, Germany' 64 | orcid: 'https://orcid.org/0000-0002-2747-9811' 65 | - given-names: Ioannis 66 | family-names: Prapas 67 | affiliation: 'University of Valencia, Spain' 68 | orcid: 'https://orcid.org/0000-0002-9111-4112' 69 | - given-names: Edoardo 70 | family-names: Nemni 71 | affiliation: 'United Nations Satellite Centre, Switzerland' 72 | orcid: 'https://orcid.org/0000-0002-0166-4613' 73 | - family-names: Kalaitzis 74 | given-names: Freddie 75 | affiliation: 'University of Oxford, United Kingdom' 76 | orcid: 'https://orcid.org/0000-0002-1471-646X' 77 | - given-names: Siddha 78 | family-names: Ganju 79 | affiliation: 'NVIDIA, United States' 80 | orcid: 'https://orcid.org/0000-0002-9462-4898' 81 | - given-names: Raul 82 | family-names: Ramos-Pollan 83 | affiliation: 'Universidad de Antioquia, Colombia' 84 | orcid: 'https://orcid.org/0000-0001-6195-3612' 85 | doi: "" 86 | conference: 87 | name: "NeurIPS 2022 workshop on Artificial Intelligence for Humanitarian Assistance and Disaster Response Workshop" 88 | date-end: "2022-12-03" 89 | title: "Artificial Intelligence for Humanitarian Assistance and Disaster Response Workshop" 90 | year: 2022 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAR-landslide-detection-pretraining 2 | Repository for the paper [SAR-based landslide classification pretraining leads to better segmentation](https://arxiv.org/abs/2211.09927) (accepted at [AI+HADR](https://www.hadr.ai/home) @NeurIPS 2022) 3 | 4 | ## Installing the requirements 5 | To run the experiments presented in the paper make sure to install the requirements. 6 | 7 | `pip install -r requirements.txt` 8 | 9 | ## Downloading the data 10 | 11 | Download the data from [Zenodo](https://doi.org/10.5281/zenodo.7248056). You will need the [hokkaido](https://zenodo.org/record/7248056/files/hokkaido_japan.zip) and the [kaikoura](https://zenodo.org/record/7248056/files/kaikoura_newzealand.zip) datacubes. 12 | 13 | ## Running the experiments 14 | 15 | Follow these steps to reproduce the experiments from the paper: 16 | 17 | 1) Train models on the pretext tasks 18 | 19 | `bash ./scripts/run_pretext_tasks.sh` 20 | 21 | 2) Train the downstream tasks 22 | 23 | `bash ./scripts/run_segmentation_experiments.sh` 24 | 25 | 3) Analyze results and create figures by running the notebooks in the [notebook folder](https://github.com/VMBoehm/SAR-landslide-detection-pretraining/tree/main/notebooks). 26 | 27 | **IMPORTANT:** Before running the experiments, you will need to adapt the filepaths in the configurations files located in [configs/experiment/](https://github.com/VMBoehm/SAR-landslide-detection-pretraining/tree/main/configs/experiment). 28 | 29 | ## Notes 30 | 31 | The original experiments were run on an NVIDIA V100 GPU in Google Cloud. 32 | 33 | ## Citation 34 | 35 | If you use this code for your research, please cite our [paper](https://arxiv.org/abs/2211.09927): 36 | 37 | ``` 38 | @misc{https://doi.org/10.48550/arxiv.2211.09927, 39 | doi = {10.48550/ARXIV.2211.09927}, 40 | 41 | url = {https://arxiv.org/abs/2211.09927}, 42 | 43 | author = {Böhm, Vanessa and Leong, Wei Ji and Mahesh, Ragini Bal and Prapas, Ioannis and Nemni, Edoardo and Kalaitzis, Freddie and Ganju, Siddha and Ramos-Pollan, Raul}, 44 | 45 | keywords = {Computer Vision and Pattern Recognition (cs.CV), Image and Video Processing (eess.IV), Signal Processing (eess.SP), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, FOS: Electrical engineering, electronic engineering, information engineering}, 46 | 47 | title = {SAR-based landslide classification pretraining leads to better segmentation}, 48 | 49 | publisher = {arXiv}, 50 | 51 | year = {2022}, 52 | 53 | copyright = {arXiv.org perpetual, non-exclusive license} 54 | } 55 | 56 | ``` 57 | 58 | 59 | ## Acknowledgements 60 | 61 | This work has been enabled by the Frontier Development Lab Program (FDL). FDL is a collaboration between SETI Institute and Trillium Technologies Inc., in partnership with the Department of Energy (DOE), National Aeronautics and Space Administration (NASA), the U.S. Geological Survey (USGS), Google Cloud and NVIDIA. The material is based upon work supported by NASA under award No(s) NNX14AT27A. 62 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/loss" # name of the logged metric which determines when model is improving 4 | mode: "max" # "max" means higher metric value is better, can be also "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionaly always save model from last epoch 7 | verbose: True 8 | dirpath: "/home/jupyter/deepslide/trained_models" 9 | filename: "epoch_{epoch:03d}" 10 | auto_insert_metric_name: False 11 | 12 | early_stopping: 13 | _target_: pytorch_lightning.callbacks.EarlyStopping 14 | monitor: "val/loss" # name of the logged metric which determines when model is improving 15 | mode: "min" # "max" means higher metric value is better, can be also "min" 16 | patience: 100 # how many validation epochs of not improving until training stops 17 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 18 | 19 | model_summary: 20 | _target_: pytorch_lightning.callbacks.RichModelSummary 21 | max_depth: -1 22 | 23 | rich_progress_bar: 24 | _target_: pytorch_lightning.callbacks.RichProgressBar 25 | 26 | log_val_predictions: 27 | _target_: src.callbacks.wandb_callbacks.LogValPredictions 28 | num_samples: 8 29 | 30 | log_train_predictions: 31 | _target_: src.callbacks.wandb_callbacks.LogTrainPredictions 32 | num_samples: 4 33 | -------------------------------------------------------------------------------- /configs/callbacks/segmentation_callback.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint_AP: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "valid/AP" # name of the logged metric which determines when model is improving 4 | mode: "max" # "max" means higher metric value is better, can be also "min" 5 | save_top_k: 5 # save k best models (determined by above metric) 6 | save_last: True # additionaly always save model from last epoch 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "epoch_ap_{epoch:03d}" 10 | auto_insert_metric_name: False 11 | 12 | model_checkpoint_f1: 13 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 14 | monitor: "valid/f1" # name of the logged metric which determines when model is improving 15 | mode: "max" # "max" means higher metric value is better, can be also "min" 16 | save_top_k: 5 # save k best models (determined by above metric) 17 | save_last: True # additionaly always save model from last epoch 18 | verbose: False 19 | dirpath: "checkpoints/" 20 | filename: "epoch_f1_{epoch:03d}" 21 | auto_insert_metric_name: False 22 | 23 | early_stopping: 24 | _target_: pytorch_lightning.callbacks.EarlyStopping 25 | monitor: "valid/loss" # name of the logged metric which determines when model is improving 26 | mode: "min" # "max" means higher metric value is better, can be also "min" 27 | patience: 300 # how many validation epochs of not improving until training stops 28 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 29 | 30 | model_summary: 31 | _target_: pytorch_lightning.callbacks.RichModelSummary 32 | max_depth: -1 33 | 34 | rich_progress_bar: 35 | _target_: pytorch_lightning.callbacks.RichProgressBar 36 | 37 | # log_val_predictions: 38 | # _target_: src.callbacks.wandb_callbacks.LogValPredictions_MAE_Downstream 39 | # num_samples: 4 40 | # counts: False 41 | 42 | # log_train_predictions: 43 | # _target_: src.callbacks.wandb_callbacks.LogTrainPredictions_MAE_Downstream 44 | # num_samples: 4 45 | # counts: False 46 | 47 | learning_rate: 48 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 49 | 50 | # linear_probe: 51 | # _target_: src.callbacks.probe_callback.SSLOnlineEvaluator 52 | 53 | # bottleneck_probe: 54 | # _target_: src.callbacks.probe_callback.SSLOnlineEvaluator_bottleneck -------------------------------------------------------------------------------- /configs/callbacks/siamese_callback.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "valid/acc" # name of the logged metric which determines when model is improving 4 | mode: "max" # "max" means higher metric value is better, can be also "min" 5 | save_top_k: 5 # save k best models (determined by above metric) 6 | save_last: True # additionaly always save model from last epoch 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "epoch_{epoch:03d}" 10 | auto_insert_metric_name: False 11 | 12 | early_stopping: 13 | _target_: pytorch_lightning.callbacks.EarlyStopping 14 | monitor: "valid/loss" # name of the logged metric which determines when model is improving 15 | mode: "min" # "max" means higher metric value is better, can be also "min" 16 | patience: 400 # how many validation epochs of not improving until training stops 17 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 18 | 19 | model_summary: 20 | _target_: pytorch_lightning.callbacks.RichModelSummary 21 | max_depth: -1 22 | 23 | rich_progress_bar: 24 | _target_: pytorch_lightning.callbacks.RichProgressBar 25 | 26 | # log_val_predictions: 27 | # _target_: src.callbacks.wandb_callbacks.LogValPredictions_MAE 28 | # num_samples: 2 29 | # counts: False 30 | 31 | # log_train_predictions: 32 | # _target_: src.callbacks.wandb_callbacks.LogTrainPredictions_MAE 33 | # num_samples: 2 34 | # counts: False 35 | 36 | learning_rate: 37 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 38 | 39 | # linear_probe: 40 | # _target_: src.callbacks.probe_callback.SSLOnlineEvaluator 41 | 42 | # bottleneck_probe: 43 | # _target_: src.callbacks.probe_callback.SSLOnlineEvaluator_bottleneck -------------------------------------------------------------------------------- /configs/datamodule/siamese_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.siamese_datamodule.Siamese_Landslide_Datamodule 2 | 3 | data_dir: ${data_dir} # data_dir is specified in config.yaml 4 | dict_dir: ${stats_dir} 5 | batch_size: 32 6 | num_workers: 8 7 | pin_memory: False 8 | input_channels: ['vh', 'vv'] #, 'los.rdr_0', 'los.rdr_1', 'topophase.cor_1', 'topophase.flat_imag', 'topophase.flat_real', 'dem'] 9 | input_transforms: ['Log_transform','Standardize'] 10 | num_time_steps: 1 11 | setting: 'pretraining' -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | defaults: 7 | - override /log_dir: debug.yaml 8 | 9 | trainer: 10 | max_epochs: 1 11 | gpus: 0 # debuggers don't like gpus 12 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 13 | track_grad_norm: 2 # track gradient norm with loggers 14 | 15 | datamodule: 16 | num_workers: 0 # debuggers don't like multiprocessing 17 | pin_memory: False # disable gpu memory pin 18 | 19 | # sets level of all command line loggers to 'DEBUG' 20 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 21 | hydra: 22 | verbose: True 23 | 24 | # use this to set level of only chosen command line loggers to 'DEBUG': 25 | # verbose: [src.train, src.utils] 26 | 27 | # config is already printed by hydra when `hydra/verbose: True` 28 | print_config: False 29 | -------------------------------------------------------------------------------- /configs/debug/limit_batches.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/debug/step.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/test_only.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs only test epoch 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | train: False 9 | test: True 10 | -------------------------------------------------------------------------------- /configs/experiment/Siamese_Type1_hokkaido.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment:example 5 | 6 | defaults: 7 | - override /datamodule: siamese_datamodule.yaml 8 | - override /model: siamese_model.yaml 9 | - override /callbacks: siamese_callback.yaml 10 | - override /logger: wandb.yaml 11 | - override /trainer: default.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | name: "Siamese_Type1_hokkaido" 18 | 19 | seed: 12345 20 | 21 | trainer: 22 | min_epochs: 1 23 | max_epochs: 1000 24 | gradient_clip_val: 0.5 25 | log_every_n_steps: 500 26 | gpus: 2 27 | strategy: 'ddp' 28 | #resume_from_checkpoint: '/home/jupyter/deepslide/logs/experiments/runs/MAE_Downstream/2022-09-21_06-04-52/checkpoints/last.ckpt' 29 | 30 | model: 31 | input_size: [2,128,128] 32 | embedding_size: 32 33 | decoder_depth: 1 34 | encoder_depth: 1 35 | base_lr: 0.001 36 | unet: True 37 | cnn: True 38 | decoder_channels: [32] 39 | 40 | datamodule: 41 | data_dir: #add path to data here 42 | dict_dir: # dictionaries are located in the /data folder 43 | batch_size: 32 44 | num_workers: 8 45 | pin_memory: False 46 | input_channels: ['vh', 'vv'] #, 'los.rdr_0', 'los.rdr_1', 'topophase.cor_1', 'topophase.flat_imag', 'topophase.flat_real', 'dem'] 47 | input_transforms: ['Log_transform','Standardize'] 48 | num_time_steps: 1 49 | setting: 'pretraining' 50 | datasets: ['hokkaido'] 51 | 52 | logger: 53 | wandb: 54 | tags: ["${name}"] 55 | project: 'siamese_pretraining' 56 | -------------------------------------------------------------------------------- /configs/experiment/Siamese_Type1_kaikoura.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment:example 5 | 6 | defaults: 7 | - override /datamodule: siamese_datamodule.yaml 8 | - override /model: siamese_model.yaml 9 | - override /callbacks: siamese_callback.yaml 10 | - override /logger: wandb.yaml 11 | - override /trainer: default.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | name: "Siamese_Type1_kaikora" 18 | 19 | seed: 12345 20 | 21 | trainer: 22 | min_epochs: 1 23 | max_epochs: 1000 24 | gradient_clip_val: 0.5 25 | log_every_n_steps: 500 26 | gpus: 2 27 | strategy: 'ddp' 28 | #resume_from_checkpoint: '/home/jupyter/deepslide/logs/experiments/runs/MAE_Downstream/2022-09-21_06-04-52/checkpoints/last.ckpt' 29 | 30 | model: 31 | input_size: [2,128,128] 32 | embedding_size: 32 33 | decoder_depth: 1 34 | encoder_depth: 1 35 | base_lr: 0.001 36 | unet: True 37 | cnn: True 38 | decoder_channels: [32] 39 | 40 | datamodule: 41 | data_dir: #add path to data here 42 | dict_dir: # dictionaries are located in the /data folder 43 | batch_size: 32 44 | num_workers: 8 45 | pin_memory: False 46 | input_channels: ['vh', 'vv'] #, 'los.rdr_0', 'los.rdr_1', 'topophase.cor_1', 'topophase.flat_imag', 'topophase.flat_real', 'dem'] 47 | input_transforms: ['Log_transform','Standardize'] 48 | num_time_steps: 1 49 | setting: 'pretraining' 50 | datasets: ['kaikoura'] 51 | 52 | logger: 53 | wandb: 54 | tags: ["${name}"] 55 | project: 'siamese_pretraining' 56 | -------------------------------------------------------------------------------- /configs/experiment/seg_hokk_hokk_pretrain_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment:example 5 | 6 | defaults: 7 | - override /datamodule: siamese_datamodule.yaml 8 | - override /model: augmented_segmentation_model.yaml 9 | - override /callbacks: segmentation_callback.yaml 10 | - override /logger: wandb.yaml 11 | - override /trainer: default.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | name: "segment_hokk_pretrain_hokk_cnn" 18 | 19 | seed: 12345 20 | 21 | trainer: 22 | min_epochs: 1 23 | max_epochs: 1000 24 | gradient_clip_val: 0.5 25 | log_every_n_steps: 500 26 | gpus: 2 27 | strategy: 'ddp' 28 | #resume_from_checkpoint: '/home/jupyter/deepslide/logs/experiments/runs/MAE_Downstream/2022-09-21_06-04-52/checkpoints/last.ckpt' 29 | 30 | model: 31 | input_size: [4,128,128] 32 | embedding_size: 64 33 | pre_train_augmented: True 34 | pretrain_path: #add location of pretrained model 35 | unet: False 36 | base_lr: 0.001 37 | pretrain_params: {'input_size':[2,128,128],'embedding_size':32,'unet':True,'decoder_depth':1, 'encoder_depth':1, 'cnn':True,'base_lr':0.001,'decoder_channels':[32]} 38 | encoder_depth: 1 39 | decoder_channels: [32] 40 | loss: 'dice' 41 | 42 | datamodule: 43 | data_dir: #add path to data here 44 | dict_dir: # dictionaries are located in the /data folder 45 | batch_size: 16 46 | num_workers: 8 47 | pin_memory: False 48 | input_channels: ['vh', 'vv'] #, 'los.rdr_0', 'los.rdr_1', 'topophase.cor_1', 'topophase.flat_imag', 'topophase.flat_real', 'dem'] 49 | input_transforms: ['Log_transform','Standardize'] 50 | num_time_steps: 1 51 | trainsize: -1 52 | setting: 'downstream' 53 | datasets: ['hokkaido'] 54 | 55 | logger: 56 | wandb: 57 | tags: ["${name}"] 58 | project: 'segmentation_task' 59 | -------------------------------------------------------------------------------- /configs/experiment/seg_hokk_kaik_pretrain_cnn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment:example 5 | 6 | defaults: 7 | - override /datamodule: siamese_datamodule.yaml 8 | - override /model: augmented_segmentation_model.yaml 9 | - override /callbacks: segmentation_callback.yaml 10 | - override /logger: wandb.yaml 11 | - override /trainer: default.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | name: "segment_hokk_pretrain_kaik_cnn" 18 | 19 | seed: 12345 20 | 21 | trainer: 22 | min_epochs: 1 23 | max_epochs: 1000 24 | gradient_clip_val: 0.5 25 | log_every_n_steps: 500 26 | gpus: 2 27 | strategy: 'ddp' 28 | #resume_from_checkpoint: '/home/jupyter/deepslide/logs/experiments/runs/MAE_Downstream/2022-09-21_06-04-52/checkpoints/last.ckpt' 29 | 30 | model: 31 | input_size: [4,128,128] 32 | embedding_size: 64 33 | pre_train_augmented: True 34 | pretrain_path: #add location of pretrained model 35 | unet: False 36 | base_lr: 0.001 37 | pretrain_params: {'input_size':[2,128,128],'embedding_size':32,'unet':True,'decoder_depth':1, 'encoder_depth':1, 'cnn':True,'base_lr':0.001,'decoder_channels':[32]} 38 | encoder_depth: 1 39 | decoder_channels: [32] 40 | loss: 'dice' 41 | 42 | datamodule: 43 | data_dir: #add path to data here 44 | dict_dir: # dictionaries are located in the /data folder 45 | batch_size: 16 46 | num_workers: 8 47 | pin_memory: False 48 | input_channels: ['vh', 'vv'] #, 'los.rdr_0', 'los.rdr_1', 'topophase.cor_1', 'topophase.flat_imag', 'topophase.flat_real', 'dem'] 49 | input_transforms: ['Log_transform','Standardize'] 50 | num_time_steps: 1 51 | trainsize: -1 52 | setting: 'downstream' 53 | datasets: ['hokkaido'] 54 | 55 | logger: 56 | wandb: 57 | tags: ["${name}"] 58 | project: 'segmentation_task' 59 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | sweeper: 18 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 19 | 20 | # storage URL to persist optimization results 21 | # for example, you can use SQLite if you set 'sqlite:///example.db' 22 | storage: null 23 | 24 | # name of the study to persist optimization results 25 | study_name: null 26 | 27 | # number of parallel workers 28 | n_jobs: 1 29 | 30 | # 'minimize' or 'maximize' the objective 31 | direction: maximize 32 | 33 | # total number of runs that will be executed 34 | n_trials: 25 35 | 36 | # choose Optuna hyperparameter sampler 37 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 38 | sampler: 39 | _target_: optuna.samplers.TPESampler 40 | seed: 12345 41 | n_startup_trials: 10 # number of random sampling runs before optimization starts 42 | 43 | # define range of hyperparameters 44 | search_space: 45 | datamodule.batch_size: 46 | type: categorical 47 | choices: [32, 64, 128] 48 | model.lr: 49 | type: float 50 | low: 0.0001 51 | high: 0.2 52 | model.net.lin1_size: 53 | type: categorical 54 | choices: [32, 64, 128, 256, 512] 55 | model.net.lin2_size: 56 | type: categorical 57 | choices: [32, 64, 128, 256, 512] 58 | model.net.lin3_size: 59 | type: categorical 60 | choices: [32, 64, 128, 256, 512] 61 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/log_dir/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: logs/debugs/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: logs/debugs/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /configs/log_dir/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: logs/experiments/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: logs/experiments/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /configs/log_dir/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: logs/evaluations/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: logs/evaluations/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | project_name: "template-tests" 7 | experiment_name: ${name} 8 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | experiment_name: ${name} 6 | tracking_uri: ${original_work_dir}/logs/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 7 | tags: null 8 | prefix: "" 9 | artifact_location: null 10 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project_name: your_name/template-tests 7 | close_after_fit: True 8 | offline_mode: False 9 | experiment_name: ${name} 10 | experiment_id: null 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: null 7 | version: ${name} 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: ${oc.env:WANDB_PROJECT} 6 | name: ${oc.env:WANDB_NAME_PREFIX}_${now:%Y%m%d_%H%M} 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # pass correct id to resume experiment! 10 | entity: ${oc.env:WANDB_ENTITY} 11 | log_model: False 12 | prefix: "" 13 | job_type: "train" 14 | group: "" 15 | tags: [] 16 | -------------------------------------------------------------------------------- /configs/model/augmented_segmentation_model.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.siamese_downstream_module.Segmentation_Model 2 | input_size: [4,128,128] 3 | embedding_size: 64 4 | pre_train_augmented: False 5 | pretrain_path: "" 6 | unet: False 7 | base_lr: 0.001 -------------------------------------------------------------------------------- /configs/model/siamese_model.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.siamese_module.Siamese_Type_1 2 | input_size: [2,128,128] 3 | embedding_size: 32 4 | base_lr: 0.001 -------------------------------------------------------------------------------- /configs/test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default evaluation configuration 4 | defaults: 5 | - _self_ 6 | - datamodule: mnist.yaml # choose the datamodule for evaluation 7 | - model: mnist.yaml 8 | - callbacks: null 9 | - logger: null 10 | - trainer: default.yaml 11 | - log_dir: evaluation.yaml 12 | 13 | - experiment: null 14 | 15 | # enable color logging 16 | - override hydra/hydra_logging: colorlog 17 | - override hydra/job_logging: colorlog 18 | 19 | original_work_dir: ${hydra:runtime.cwd} 20 | 21 | data_dir: ${original_work_dir}/data/ 22 | 23 | print_config: True 24 | 25 | ignore_warnings: True 26 | 27 | seed: null 28 | 29 | name: "default" 30 | 31 | # passing checkpoint path is necessary 32 | ckpt_path: ??? 33 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - datamodule: pr_datamodule.yaml 7 | - model: supervised_model.yaml 8 | - callbacks: default.yaml 9 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 10 | - trainer: default.yaml 11 | - log_dir: default.yaml 12 | 13 | # experiment configs allow for version control of specific configurations 14 | # e.g. best hyperparameters for each combination of model and datamodule 15 | - experiment: null 16 | 17 | # debugging config (enable through command line, e.g. `python train.py debug=default) 18 | - debug: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # enable color logging 28 | - override hydra/hydra_logging: colorlog 29 | - override hydra/job_logging: colorlog 30 | 31 | # path to original working directory 32 | # hydra hijacks working directory by changing it to the new log directory 33 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 34 | original_work_dir: ${hydra:runtime.cwd} 35 | 36 | # path to folder with data 37 | data_dir: /home/jupyter/chips_128x128/ 38 | stats_dir: /home/jupyter/deepslide/data/landslides-puertorico/ 39 | 40 | # pretty print config at the start of the run using Rich library 41 | print_config: True 42 | 43 | # disable python warnings if they annoy you 44 | ignore_warnings: True 45 | 46 | # set False to skip model training 47 | train: True 48 | 49 | # evaluate on test set, using best model weights achieved during training 50 | # lightning chooses best weights based on the metric specified in checkpoint callback 51 | test: True 52 | 53 | # seed for random number generators in pytorch, numpy and python.random 54 | seed: null 55 | 56 | # default name for the experiment, determines logging folder path 57 | # (you can overwrite this name in experiment configs) 58 | name: "default" 59 | 60 | #ckpt_path: '/home/jupyter/deepslide/logs/experiments/runs/CNN_MAE_Test/2022-08-04_22-31-57/checkpoints/last.ckpt' -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | gpus: 4 5 | strategy: ddp 6 | sync_batchnorm: True 7 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: 1 4 | 5 | min_epochs: 1 6 | max_epochs: 150 7 | 8 | # number of validation steps to execute at the beginning of the training 9 | # num_sanity_val_steps: 0 10 | 11 | # ckpt path 12 | resume_from_checkpoint: null 13 | #'/home/jupyter/deepslide/logs/experiments/runs/First_MAE_Test/2022-08-02_19-28-40/checkpoints/last.ckpt' 14 | 15 | log_every_n_steps: 80 16 | -------------------------------------------------------------------------------- /data/stats_dict_hokkaido.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/data/stats_dict_hokkaido.pkl -------------------------------------------------------------------------------- /data/stats_dict_kaikoura.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/data/stats_dict_kaikoura.pkl -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/notebooks/.gitkeep -------------------------------------------------------------------------------- /notebooks/Analyze_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c93a017f-b8a4-4234-a3fe-5bcf412a483f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Notebook for evaluating Trained Models" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "f46f64fe-631a-4361-abc8-aaa82f6730a6", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2 " 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "id": "ff6445cd-64bc-4f2a-89d7-48aa9a1330f9", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import sys\n", 30 | "from pathlib import Path\n", 31 | "from torchmetrics.functional import precision_recall\n", 32 | "from torchmetrics import AveragePrecision,Accuracy\n", 33 | "import torch\n", 34 | "### adding model to path\n", 35 | "sys.path.append('/home/jupyter/deepslide')\n", 36 | "from src.datamodules.siamese_datamodule import Siamese_Landslide_Datamodule\n", 37 | "from src.models.siamese_downstream_module import Segmentation_Model\n", 38 | "from src.models.siamese_module import Siamese_Type_1\n", 39 | "import numpy as np\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import pickle" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "d3049679-87c0-43f0-93e7-bf30435943a9", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import scipy" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "5ba01490-8a35-4ede-a372-2c1fee8345a7", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "scipy.special.expit(0.)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "4053f3ca-3cbd-4e85-896f-6ffb4777d777", 67 | "metadata": {}, 68 | "source": [ 69 | "### choose which experiment to evaluate" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "010b67ed-cd4f-49a7-8b5b-c6c2a7d60054", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# settings\n", 80 | "dataset = 'Hokkaido'\n", 81 | "pretraining = 'Hokkaido'\n", 82 | "# same as dataset but in small caption\n", 83 | "dataset2 = 'hokkaido' \n", 84 | "loss = 'dice'\n", 85 | "experiment_name = 'segment_hokk_pretrain_hokk_cnn'\n", 86 | "trainsize = '5'\n", 87 | "unet = False" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "38e312b2-97b7-4663-a489-4704c1c23e44", 93 | "metadata": {}, 94 | "source": [ 95 | "### you will need to adapt the model paths here" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "e7c8107b-af01-41c7-9590-b32b80685907", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# pretrained models\n", 106 | "net = Siamese_Type_1(**{'input_size':[2,128,128],'embedding_size':32,'unet':True,'decoder_depth':1, 'encoder_depth':1, 'cnn':True,'base_lr':0.001,'decoder_channels':[32]})\n", 107 | "if pretraining=='Hokkaido':\n", 108 | " pretrain_path = '/home/jupyter/deepslide/logs/experiments/runs/Siamese_Type1_hokkaido/2022-09-27_12-14-28/checkpoints/epoch_269.ckpt'\n", 109 | "elif pretraining==\"Kaikoura\":\n", 110 | " pretrain_path = '/home/jupyter/deepslide/logs/experiments/runs/Siamese_Type1_kaikora/2022-09-27_11-38-01/checkpoints/last.ckpt'\n", 111 | " \n", 112 | "path = Path('/home/jupyter/deepslide/logs/experiments/runs/{}/'.format(experiment_name))" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "ca7f6d87-f4fc-40ea-9221-36ae8fa71726", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "def compute_metrics(preds, targets, threshold=0.5):\n", 123 | " #print(preds)\n", 124 | " preds =preds.view((preds.size()[0],-1))\n", 125 | " targets =targets.view((targets.size()[0],-1))\n", 126 | "\n", 127 | " prec, rec = precision_recall(preds, targets, threshold=threshold)\n", 128 | " f1_score = 2*(prec*rec)/(prec+rec)\n", 129 | "\n", 130 | " # pr_curve = PrecisionRecallCurve(num_classes=5)\n", 131 | " # precision, recall, thresholds = pr_curve(preds, targets)\n", 132 | "\n", 133 | " average_precision = AveragePrecision()\n", 134 | " AP_score = average_precision(preds, targets)\n", 135 | " accuracy = Accuracy(threshold=threshold).cuda()\n", 136 | " acc = accuracy(preds, targets)\n", 137 | " return f1_score.item(), AP_score.item(), prec.item(), rec.item(), acc.detach().item()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "id": "74855a47-7656-4949-9383-93a2bc151b28", 143 | "metadata": {}, 144 | "source": [ 145 | "### finds the right model files, evluates them, computes scores and dumps results" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "9ca3a94c-fa0a-42ee-ae78-346defb18977", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "import json\n", 156 | "\n", 157 | "\n", 158 | "filepaths1=[]\n", 159 | "filepaths2=[]\n", 160 | "for filepath in path.rglob(\"*config_tree.log\"):\n", 161 | " file = open(filepath, 'r')\n", 162 | " lines = file.read().splitlines()\n", 163 | " for line in lines:\n", 164 | " if 'loss:' in line:\n", 165 | " if \" \"+loss in str(line):\n", 166 | " filepaths1.append(filepath.parent)\n", 167 | "\n", 168 | "print(len(filepaths1))\n", 169 | "\n", 170 | "for pp in filepaths1:\n", 171 | " for filepath in pp.glob(\"config_tree.log\"):\n", 172 | " file = open(filepath, 'r')\n", 173 | " lines = file.read().splitlines()\n", 174 | " for line in lines:\n", 175 | " if 'pretrain_path:' in line:\n", 176 | " if pretrain_path in line:\n", 177 | " filepaths2.append(filepath.parent)\n", 178 | "\n", 179 | "for trainsize in ['2','5','10','20','-1']:\n", 180 | " model_dirs_augmented = []\n", 181 | " model_dirs = []\n", 182 | " timestamps_augmented =[]\n", 183 | " timestamps = []\n", 184 | "\n", 185 | "\n", 186 | " filepaths3=[]\n", 187 | "\n", 188 | " print(len(filepaths2)) \n", 189 | " for pp in filepaths2:\n", 190 | " for filepath in pp.rglob(\"*config_tree.log\"):\n", 191 | " file = open(filepath, 'r')\n", 192 | " lines = file.read().splitlines()\n", 193 | " for line in lines:\n", 194 | " if 'trainsize' in line:\n", 195 | " if trainsize in line:\n", 196 | " filepaths3.append(filepath.parent)\n", 197 | "\n", 198 | " print(len(filepaths3))\n", 199 | "\n", 200 | " count1, count2= 0,0\n", 201 | " gss=[]\n", 202 | "\n", 203 | " for pp in filepaths3:\n", 204 | " count1+=1\n", 205 | " for filepath in pp.glob(\"*config_tree.log\"):\n", 206 | " file = open(filepath, 'r')\n", 207 | " lines = file.read().splitlines()\n", 208 | " for line in lines:\n", 209 | " if '- hokkaido' in line:\n", 210 | " count2+=1\n", 211 | " pp = filepath.parent/'checkpoints/'\n", 212 | " if 'pre_train_augmented' in line:\n", 213 | " if 'true' in line:\n", 214 | " if len(list(pp.glob(\"*.ckpt\")))>3:\n", 215 | " model_dirs_augmented.append([file for file in pp.glob(\"*.ckpt\")])\n", 216 | " with open(filepath.parent/'wandb/latest-run/files/wandb-summary.json') as user_file:\n", 217 | " file_contents = user_file.read()\n", 218 | " timestamps_augmented.append(dict(json.loads(file_contents))['_timestamp'])\n", 219 | " else:\n", 220 | " with open(filepath.parent/'wandb/latest-run/files/wandb-summary.json') as user_file:\n", 221 | " file_contents = user_file.read()\n", 222 | " timestamps.append(dict(json.loads(file_contents))['_timestamp'])\n", 223 | " if len(list(pp.glob(\"*.ckpt\")))>3:\n", 224 | " model_dirs.append([file for file in pp.glob(\"*.ckpt\")])\n", 225 | " print(len(model_dirs_augmented), len(model_dirs)) \n", 226 | " if trainsize=='2':\n", 227 | " model_dirs_augmented = np.asarray(model_dirs_augmented)[np.argsort(np.asarray(timestamps_augmented))[-5::]].flatten()\n", 228 | " model_dirs = np.asarray(model_dirs)[np.argsort(np.asarray(timestamps))[-5::]].flatten()\n", 229 | " else:\n", 230 | " model_dirs_augmented = np.asarray(model_dirs_augmented)[np.argsort(np.asarray(timestamps_augmented))[-3::]].flatten()\n", 231 | " model_dirs = np.asarray(model_dirs)[np.argsort(np.asarray(timestamps))[-3::]].flatten()\n", 232 | " print(len(model_dirs_augmented), len(model_dirs)) \n", 233 | "\n", 234 | " model_dirs\n", 235 | "\n", 236 | " datadict = {'data_dir': '/home/jupyter/deepslide/data/',\n", 237 | " 'dict_dir': '/home/jupyter/deepslide/data/',\n", 238 | " 'batch_size': 32,\n", 239 | " 'num_workers': 8,\n", 240 | " 'pin_memory': False,\n", 241 | " 'input_channels': ['vh', 'vv'], #, 'los.rdr_0', 'los.rdr_1', 'topophase.cor_1', 'topophase.flat_imag', 'topophase.flat_real', 'dem']\n", 242 | " 'input_transforms': ['Log_transform','Standardize'],\n", 243 | " 'num_time_steps': 1,\n", 244 | " 'setting': 'downstream',\n", 245 | " 'datasets': [dataset2]}\n", 246 | "\n", 247 | " dloader = Siamese_Landslide_Datamodule(**datadict)\n", 248 | " test_loader = dloader.test_dataloader()\n", 249 | "\n", 250 | "\n", 251 | " # from collections import OrderedDict\n", 252 | " # new_dict=OrderedDict()\n", 253 | " # for key in pretrain_dict.keys():\n", 254 | " # new_dict['pretrained.'+key] = pretrain_dict[key]\n", 255 | "\n", 256 | " mean_predictions={}\n", 257 | " targets={}\n", 258 | " metrics ={} \n", 259 | " for name, pre_train_augmented, paths in zip(['pretrain','no_pretrain'],[True, False],[model_dirs_augmented, model_dirs]):\n", 260 | "\n", 261 | " metrics[name]={}\n", 262 | " mean_predictions[name]={}\n", 263 | " mean_predictions[name]['all']=np.zeros((56,128,128))\n", 264 | " mean_predictions[name]['nl']=np.zeros((24,128,128))\n", 265 | " mean_predictions[name]['l']=np.zeros((32,128,128))\n", 266 | " metrics[name]['APRC'] = []\n", 267 | "\n", 268 | "\n", 269 | " downstream_model_dict={ \n", 270 | " 'input_size': [4,128,128],\n", 271 | " 'embedding_size': 64,\n", 272 | " 'pre_train_augmented': pre_train_augmented,\n", 273 | " 'pretrain_path': pretrain_path,\n", 274 | " 'unet': unet,\n", 275 | " 'base_lr': 0.001,\n", 276 | " 'pretrain_params': {'input_size':[2,128,128],'embedding_size':32,'decoder_depth':1,'encoder_depth':1,'unet':True,'cnn':True,\n", 277 | " 'base_lr':0.001,'decoder_channels':[32]},\n", 278 | " 'encoder_depth': 1,\n", 279 | " 'decoder_channels': [32],\n", 280 | " 'loss': loss}\n", 281 | "\n", 282 | " count=0\n", 283 | " model = Segmentation_Model(**downstream_model_dict)\n", 284 | " #print(model)\n", 285 | " for path in paths:\n", 286 | " if not path.name=='last.ckpt':\n", 287 | " if 'ap' in path.name:\n", 288 | " checkpoint = torch.load(path)['state_dict']\n", 289 | " # for key in checkpoint.keys():\n", 290 | " # print(key, checkpoint[key].shape)\n", 291 | " # pretrain_dict = torch.load(downstream_model_dict['pretrain_path'])['state_dict']\n", 292 | " # checkpoint.update(new_dict)\n", 293 | " model.load_state_dict(checkpoint)\n", 294 | " model.eval()\n", 295 | " model.cuda()\n", 296 | " for pre, post, label, names, weight in test_loader:\n", 297 | " with torch.no_grad():\n", 298 | " count+=1\n", 299 | " preds = model.forward(pre.cuda(),post.cuda())\n", 300 | " label = label.view(preds.size()).cuda()\n", 301 | " mean_predictions[name]['all']+=np.squeeze(preds.detach().cpu().numpy())\n", 302 | " targets['all'] = label.detach().cpu().numpy()\n", 303 | " landslide=torch.sum(label, axis=(1,2,3))>0\n", 304 | "\n", 305 | " preds_nl = preds[~landslide]\n", 306 | " targets['nl'] = label[~landslide].detach().cpu().numpy()\n", 307 | " mean_predictions[name]['nl']+=np.squeeze(preds_nl.detach().cpu().numpy())\n", 308 | " \n", 309 | " preds_l = preds[landslide]\n", 310 | " targets['l'] = label[landslide].detach().cpu().numpy()\n", 311 | " mean_predictions[name]['l']+=np.squeeze(preds_l.detach().cpu().numpy())\n", 312 | " \n", 313 | " preds =torch.sigmoid(preds)\n", 314 | "\n", 315 | " res = compute_metrics(preds, label.cuda(), threshold=0.5)\n", 316 | " \n", 317 | "\n", 318 | " metrics[name]['APRC'].append(res[1])\n", 319 | " if name=='pretrain':\n", 320 | " save_model=model\n", 321 | "\n", 322 | " \n", 323 | " results={}\n", 324 | " results[dataset] = {}\n", 325 | " \n", 326 | " for name in ['pretrain','no_pretrain']:\n", 327 | " results[dataset][name] = {}\n", 328 | " for subfix in ['all','nl','l']:\n", 329 | " results[dataset][name]['APRC_'+subfix]=[]\n", 330 | " results[dataset][name]['f1_'+subfix]=[]\n", 331 | " results[dataset][name]['iou_'+subfix]=[]\n", 332 | " for ii in range(len(mean_predictions[name][subfix])):\n", 333 | " preds = scipy.special.expit(mean_predictions[name][subfix]/count)[ii:ii+1]\n", 334 | " res = compute_metrics(torch.tensor(preds).cuda(), torch.tensor(targets[subfix]).cuda()[ii], threshold=0.5)\n", 335 | " results[dataset][name]['APRC_'+subfix].append(res[1])\n", 336 | " results[dataset][name]['f1_'+subfix].append(res[0])\n", 337 | " intersection = np.logical_and(preds>0.5,targets[subfix][ii])\n", 338 | " union = np.logical_or(preds>0.5,targets[subfix][ii])\n", 339 | " iou_score = np.sum(intersection) / np.sum(union)\n", 340 | " results[dataset][name]['iou_'+subfix].append(iou_score)\n", 341 | " \n", 342 | " results[dataset][name]['APRC_'+subfix]=np.asarray(results[dataset][name]['APRC_'+subfix])\n", 343 | " results[dataset][name]['f1_'+subfix]=np.asarray(results[dataset][name]['f1_'+subfix])\n", 344 | " results[dataset][name]['iou_'+subfix]=np.asarray(results[dataset][name]['iou_'+subfix])\n", 345 | " \n", 346 | "\n", 347 | " results[dataset][name]['APRC_'+subfix+'_mean']=np.mean(metrics[name]['APRC'])\n", 348 | " results[dataset][name]['f1_'+subfix+'_mean']=np.mean(results[dataset][name]['f1_'+subfix])\n", 349 | " results[dataset][name]['iou_'+subfix+'_mean']=np.mean(results[dataset][name]['iou_'+subfix])\n", 350 | "\n", 351 | " results[dataset][name]['APRC_'+subfix+'_median']=np.median(metrics[name]['APRC'])\n", 352 | " results[dataset][name]['f1_'+subfix+'_median']=np.median(results[dataset][name]['f1_'+subfix])\n", 353 | " results[dataset][name]['iou_'+subfix+'_median']=np.median(results[dataset][name]['iou_'+subfix])\n", 354 | "\n", 355 | " results[dataset][name]['APRC_'+subfix+'_std']=np.std(metrics[name]['APRC'])/len(metrics[name]['APRC'])\n", 356 | " results[dataset][name]['f1_'+subfix+'_std']=np.std(results[dataset][name]['f1_'+subfix])/len(results[dataset][name]['f1_'+subfix])\n", 357 | " results[dataset][name]['iou_'+subfix+'_std']=np.std(results[dataset][name]['iou_'+subfix])/len(results[dataset][name]['iou_'+subfix])\n", 358 | "\n", 359 | "\n", 360 | " mean_predictions[name][subfix]=np.squeeze(scipy.special.expit(mean_predictions[name][subfix]/count))\n", 361 | " results[dataset][name]['l1_'+subfix+'_mean'] = np.mean(np.abs(np.sum(mean_predictions[name][subfix], axis=(1,2))-np.sum(targets[subfix], axis=(1,2,3))))\n", 362 | " results[dataset][name]['sl1_'+subfix+'_mean']= np.mean(np.sum(mean_predictions[name][subfix], axis=(1,2))-np.sum(targets[subfix], axis=(1,2,3)))\n", 363 | "\n", 364 | " results[dataset][name]['l1_'+subfix+'_median'] = np.median(np.abs(np.sum(mean_predictions[name][subfix], axis=(1,2))-np.sum(targets[subfix], axis=(1,2,3))))\n", 365 | " results[dataset][name]['sl1_'+subfix+'_median']= np.median(np.sum(mean_predictions[name][subfix], axis=(1,2))-np.sum(targets[subfix], axis=(1,2,3)))\n", 366 | "\n", 367 | " results[dataset][name]['l1_'+subfix+'_std'] = np.std(np.abs(np.sum(mean_predictions[name][subfix], axis=(1,2))-np.sum(targets[subfix], axis=(1,2,3))))/len(targets[subfix])\n", 368 | " results[dataset][name]['sl1_'+subfix+'_std'] = np.std(np.sum(mean_predictions[name][subfix], axis=(1,2))-np.sum(targets[subfix], axis=(1,2,3)))/len(targets[subfix])\n", 369 | "\n", 370 | " pickle.dump(results, open('scores_{}_{}.pkl'.format(experiment_name, trainsize), 'wb'))\n", 371 | "\n", 372 | " pickle.dump(mean_predictions,open('label_predictions_{}_{}.pkl'.format(experiment_name, trainsize), 'wb'))" 373 | ] 374 | } 375 | ], 376 | "metadata": { 377 | "environment": { 378 | "kernel": "deepslide", 379 | "name": "pytorch-gpu.1-11.m94", 380 | "type": "gcloud", 381 | "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-11:m94" 382 | }, 383 | "kernelspec": { 384 | "display_name": "deepslide", 385 | "language": "python", 386 | "name": "deepslide" 387 | }, 388 | "language_info": { 389 | "codemirror_mode": { 390 | "name": "ipython", 391 | "version": 3 392 | }, 393 | "file_extension": ".py", 394 | "mimetype": "text/x-python", 395 | "name": "python", 396 | "nbconvert_exporter": "python", 397 | "pygments_lexer": "ipython3", 398 | "version": "3.10.5" 399 | } 400 | }, 401 | "nbformat": 4, 402 | "nbformat_minor": 5 403 | } 404 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.2.0 2 | matplotlib==3.6.2 3 | numpy==1.23.4 4 | omegaconf==2.2.3 5 | osgeo==0.0.1 6 | packaging==21.3 7 | Pillow==9.3.0 8 | pytest==7.2.0 9 | python-dotenv==0.21.0 10 | pytorch_lightning==1.8.1 11 | rlxutils==0.1.10 12 | scikit_learn==1.1.3 13 | scipy==1.9.3 14 | seaborn==0.12.1 15 | segmentation_models_pytorch==0.3.0 16 | setuptools==52.0.0.post20210125 17 | sh==1.14.3 18 | skimage==0.0 19 | torch==1.13.0 20 | torchmetrics==0.10.3 21 | torchvision==0.14.0 22 | tqdm==4.61.2 23 | wandb==0.13.5 24 | -------------------------------------------------------------------------------- /scripts/run_pretext_tasks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # baseline 3 | python train.py logger=wandb experiment=Siamese_Type1_hokkaido.yaml model.unet=True model.cnn=True datamodule.balanced=True datamodule.balance_weights=False seed=12345 4 | python train.py logger=wandb experiment=Siamese_Type1_kaikoura.yaml model.unet=True model.cnn=True datamodule.balanced=True datamodule.balance_weights=False seed=12345 5 | 6 | -------------------------------------------------------------------------------- /scripts/run_segmentation_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for ii in 20 2 5 10 4 | do 5 | for seed in 12345 55555 54321 6 | do 7 | python train.py logger=wandb experiment=seg_hokk_hokk_pretrain_cnn.yaml model.pre_train_augmented=False datamodule.trainsize=$ii seed=$seed 8 | python train.py logger=wandb experiment=seg_hokk_hokk_pretrain_cnn.yaml model.pre_train_augmented=True datamodule.trainsize=$ii seed=$seed 9 | done 10 | done 11 | 12 | ii=2 13 | for seed in 54321 56762 9867 14 | do 15 | python train.py logger=wandb experiment=seg_hokk_hokk_pretrain_cnn.yaml model.pre_train_augmented=False datamodule.trainsize=$ii seed=$seed 16 | python train.py logger=wandb experiment=seg_hokk_hokk_pretrain_cnn.yaml model.pre_train_augmented=True datamodule.trainsize=$ii seed=$seed 17 | done 18 | 19 | 20 | for ii in 20 2 5 10 21 | do 22 | for seed in 12345 55555 54321 23 | do 24 | python train.py logger=wandb experiment=seg_hokk_kaik_pretrain_cnn.yaml model.pre_train_augmented=False datamodule.trainsize=$ii seed=$seed 25 | python train.py logger=wandb experiment=seg_hokk_kaik_pretrain_cnn.yaml model.pre_train_augmented=True datamodule.trainsize=$ii seed=$seed 26 | done 27 | done 28 | 29 | ii=2 30 | for seed in 54321 56762 9867 31 | do 32 | python train.py logger=wandb experiment=seg_hokk_kaik_pretrain_cnn.yaml model.pre_train_augmented=False datamodule.trainsize=$ii seed=$seed 33 | python train.py logger=wandb experiment=seg_hokk_kaik_pretrain_cnn.yaml model.pre_train_augmented=True datamodule.trainsize=$ii seed=$seed 34 | done 35 | 36 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/src/__init__.py -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/src/callbacks/__init__.py -------------------------------------------------------------------------------- /src/callbacks/probe_callback.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 3 | 4 | import torch 5 | from pytorch_lightning import Callback, LightningModule, Trainer 6 | from pytorch_lightning.utilities import rank_zero_warn 7 | from torch import Tensor, nn 8 | from torch.nn import functional as F 9 | from torch.optim import Optimizer 10 | from torchmetrics.functional import precision_recall 11 | 12 | from src.models.components import networks 13 | import segmentation_models_pytorch as smp 14 | 15 | 16 | 17 | class SSLOnlineEvaluator(Callback): # pragma: no cover 18 | """Attaches a MLP for fine-tuning using the standard self-supervised protocol. 19 | Example:: 20 | # your datamodule must have 2 attributes 21 | dm = DataModule() 22 | dm.num_classes = ... # the num of classes in the datamodule 23 | dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10) 24 | # your model must have 1 attribute 25 | model = Model() 26 | model.z_dim = ... # the representation dim 27 | online_eval = SSLOnlineEvaluator( 28 | z_dim=model.z_dim 29 | ) 30 | """ 31 | 32 | def __init__( 33 | self, 34 | in_channels: int = 4, 35 | out_channels: int = 1, 36 | optimizer = None, 37 | dataset = None, 38 | activation = 'Sigmoid' 39 | ): 40 | """ 41 | Args: 42 | z_dim: Representation dimension 43 | drop_p: Dropout probability 44 | hidden_dim: Hidden dimension for the fine-tune MLP 45 | """ 46 | super().__init__() 47 | 48 | self.optimizer = optimizer 49 | 50 | self.dataset = dataset 51 | 52 | self.in_channels = in_channels 53 | self.out_channels = out_channels 54 | 55 | self.criterion = smp.losses.DiceLoss(mode="binary",from_logits=False) 56 | self.activation = activation 57 | self._recovered_callback_state: Optional[Dict[str, Any]] = None 58 | 59 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: 60 | if self.dataset is None: 61 | self.dataset = trainer.datamodule.name 62 | 63 | def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 64 | # must move to device after setup, as during setup, pl_module is still on cpu 65 | self.online_evaluator = networks.DecoderHead(self.in_channels,self.out_channels,final_activation=self.activation).to(pl_module.device) 66 | 67 | # switch fo PL compatibility reasons 68 | accel = ( 69 | trainer.accelerator_connector 70 | if hasattr(trainer, "accelerator_connector") 71 | else trainer._accelerator_connector 72 | ) 73 | if accel.is_distributed: 74 | if accel._strategy_flag=='ddp': 75 | from torch.nn.parallel import DistributedDataParallel as DDP 76 | 77 | self.online_evaluator = DDP(self.online_evaluator, device_ids=[pl_module.device]) 78 | elif accel._strategy_flag=='dp': 79 | from torch.nn.parallel import DataParallel as DP 80 | 81 | self.online_evaluator = DP(self.online_evaluator, device_ids=[pl_module.device]) 82 | # elif accel._strategy_flag=='ddp_spawn': 83 | # from torch.nn.parallel import DDPSpawnStrategy as DPS 84 | # self.online_evaluator = DPS(self.online_evaluator, device_ids=[pl_module.device]) 85 | else: 86 | rank_zero_warn( 87 | "Does not support this type of distributed accelerator. The online evaluator will not sync." 88 | ) 89 | 90 | self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(), lr=1e-4) 91 | 92 | if self._recovered_callback_state is not None: 93 | self.online_evaluator.load_state_dict(self._recovered_callback_state["state_dict"]) 94 | self.optimizer.load_state_dict(self._recovered_callback_state["optimizer_state"]) 95 | 96 | def to_device(self, batch: Sequence, device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]: 97 | # adapted 98 | inputs, _, _, y = batch 99 | 100 | # last input is for online eval 101 | x = inputs#[-1] 102 | x = x.to(device) 103 | y = y.to(device) 104 | 105 | return x, y 106 | 107 | # adapted 108 | def shared_step( 109 | self, 110 | pl_module: LightningModule, 111 | batch: Sequence, 112 | ): 113 | with torch.no_grad(): 114 | with set_training(pl_module, False): 115 | x, y = self.to_device(batch, pl_module.device) 116 | representations = pl_module(x) 117 | 118 | # forward pass 119 | preds = self.online_evaluator(representations) # type: ignore[operator] 120 | 121 | loss = self.criterion(preds, y) 122 | 123 | prec, rec = precision_recall(preds.view(-1), y.view(-1)) 124 | 125 | return loss, prec, rec 126 | 127 | def on_train_batch_end( 128 | self, 129 | trainer: Trainer, 130 | pl_module: LightningModule, 131 | outputs: Sequence, 132 | batch: Sequence, 133 | batch_idx: int, 134 | dataloader_idx: int, 135 | ) -> None: 136 | loss, prec, rec = self.shared_step(pl_module, batch) 137 | 138 | # update finetune weights 139 | loss.backward() 140 | self.optimizer.step() 141 | self.optimizer.zero_grad() 142 | 143 | #adapted 144 | pl_module.log("train/online_prec", prec, on_step=False, on_epoch=True, sync_dist=True) 145 | pl_module.log("train/online_recall", rec, on_step=False, on_epoch=True, sync_dist=True) 146 | pl_module.log("train/online_f1", 2*(rec*prec)/(rec+prec), on_step=False, on_epoch=True, sync_dist=True) 147 | pl_module.log("train/online_loss", loss, on_step=True, on_epoch=False) 148 | 149 | def on_validation_batch_end( 150 | self, 151 | trainer: Trainer, 152 | pl_module: LightningModule, 153 | outputs: Sequence, 154 | batch: Sequence, 155 | batch_idx: int, 156 | dataloader_idx: int, 157 | ) -> None: 158 | loss, prec, rec = self.shared_step(pl_module, batch) 159 | #adapted 160 | pl_module.log("valid/online_prec", prec, on_step=False, on_epoch=True, sync_dist=True) 161 | pl_module.log("valid/online_recall", rec, on_step=False, on_epoch=True, sync_dist=True) 162 | pl_module.log("valid/online_f1", 2*(rec*prec)/(rec+prec), on_step=False, on_epoch=True, sync_dist=True) 163 | pl_module.log("valid/online_loss", loss, on_step=True, on_epoch=False) 164 | 165 | def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict: 166 | return {"state_dict": self.online_evaluator.state_dict(), "optimizer_state": self.optimizer.state_dict()} 167 | 168 | def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, callback_state: Dict[str, Any]) -> None: 169 | self._recovered_callback_state = callback_state 170 | 171 | 172 | class SSLOnlineEvaluator_bottleneck(Callback): # pragma: no cover 173 | """Attaches a MLP for fine-tuning using the standard self-supervised protocol. 174 | Example:: 175 | # your datamodule must have 2 attributes 176 | dm = DataModule() 177 | dm.num_classes = ... # the num of classes in the datamodule 178 | dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10) 179 | # your model must have 1 attribute 180 | model = Model() 181 | model.z_dim = ... # the representation dim 182 | online_eval = SSLOnlineEvaluator( 183 | z_dim=model.z_dim 184 | ) 185 | """ 186 | 187 | def __init__( 188 | self, 189 | in_channels: int = 4, 190 | out_channels: int = 1, 191 | optimizer = None, 192 | dataset = None, 193 | activation = 'Sigmoid' 194 | ): 195 | """ 196 | Args: 197 | z_dim: Representation dimension 198 | drop_p: Dropout probability 199 | hidden_dim: Hidden dimension for the fine-tune MLP 200 | """ 201 | super().__init__() 202 | 203 | self.optimizer = optimizer 204 | 205 | self.dataset = dataset 206 | 207 | self.in_channels = in_channels 208 | self.out_channels = out_channels 209 | 210 | self.criterion = smp.losses.DiceLoss(mode="binary",from_logits=False) 211 | self.activation = activation 212 | self._recovered_callback_state: Optional[Dict[str, Any]] = None 213 | 214 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: 215 | if self.dataset is None: 216 | self.dataset = trainer.datamodule.name 217 | 218 | def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 219 | # must move to device after setup, as during setup, pl_module is still on cpu 220 | self.online_evaluator = networks.Combined_Upsample_CNN_decoder(self.in_channels,self.out_channels,final_activation=self.activation).to(pl_module.device) 221 | 222 | # switch fo PL compatibility reasons 223 | accel = ( 224 | trainer.accelerator_connector 225 | if hasattr(trainer, "accelerator_connector") 226 | else trainer._accelerator_connector 227 | ) 228 | if accel.is_distributed: 229 | if accel._strategy_flag=='ddp': 230 | from torch.nn.parallel import DistributedDataParallel as DDP 231 | 232 | self.online_evaluator = DDP(self.online_evaluator, device_ids=[pl_module.device]) 233 | elif accel._strategy_flag=='dp': 234 | from torch.nn.parallel import DataParallel as DP 235 | 236 | self.online_evaluator = DP(self.online_evaluator, device_ids=[pl_module.device]) 237 | # elif accel._strategy_flag=='ddp_spawn': 238 | # from torch.nn.parallel import DDPSpawnStrategy as DPS 239 | # self.online_evaluator = DPS(self.online_evaluator, device_ids=[pl_module.device]) 240 | else: 241 | rank_zero_warn( 242 | "Does not support this type of distributed accelerator. The online evaluator will not sync." 243 | ) 244 | 245 | self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(), lr=1e-4) 246 | 247 | if self._recovered_callback_state is not None: 248 | self.online_evaluator.load_state_dict(self._recovered_callback_state["state_dict"]) 249 | self.optimizer.load_state_dict(self._recovered_callback_state["optimizer_state"]) 250 | 251 | def to_device(self, batch: Sequence, device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]: 252 | # adapted 253 | inputs, _, _, y = batch 254 | 255 | # last input is for online eval 256 | x = inputs#[-1] 257 | x = x.to(device) 258 | y = y.to(device) 259 | 260 | return x, y 261 | 262 | # adapted 263 | def shared_step( 264 | self, 265 | pl_module: LightningModule, 266 | batch: Sequence, 267 | ): 268 | with torch.no_grad(): 269 | with set_training(pl_module, False): 270 | x, y = self.to_device(batch, pl_module.device) 271 | representations = pl_module.encoder(x, torch.ones(x.shape).cuda()) 272 | 273 | # forward pass 274 | 275 | preds = self.online_evaluator(representations) # type: ignore[operator] 276 | 277 | loss = self.criterion(preds, y) 278 | 279 | prec, rec = precision_recall(preds.view(-1), y.view(-1)) 280 | 281 | return loss, prec, rec 282 | 283 | def on_train_batch_end( 284 | self, 285 | trainer: Trainer, 286 | pl_module: LightningModule, 287 | outputs: Sequence, 288 | batch: Sequence, 289 | batch_idx: int, 290 | dataloader_idx: int, 291 | ) -> None: 292 | loss, prec, rec = self.shared_step(pl_module, batch) 293 | 294 | # update finetune weights 295 | loss.backward() 296 | self.optimizer.step() 297 | self.optimizer.zero_grad() 298 | 299 | #adapted 300 | pl_module.log("train/bottleneck_online_prec", prec, on_step=False, on_epoch=True, sync_dist=True) 301 | pl_module.log("train/bottleneck_online_recall", rec, on_step=False, on_epoch=True, sync_dist=True) 302 | pl_module.log("train/bottleneck_online_f1", 2*(rec*prec)/(rec+prec), on_step=False, on_epoch=True, sync_dist=True) 303 | pl_module.log("train/bottleneck_online_loss", loss, on_step=True, on_epoch=False) 304 | 305 | def on_validation_batch_end( 306 | self, 307 | trainer: Trainer, 308 | pl_module: LightningModule, 309 | outputs: Sequence, 310 | batch: Sequence, 311 | batch_idx: int, 312 | dataloader_idx: int, 313 | ) -> None: 314 | loss, prec, rec = self.shared_step(pl_module, batch) 315 | #adapted 316 | pl_module.log("valid/bottleneck_online_prec", prec, on_step=False, on_epoch=True, sync_dist=True) 317 | pl_module.log("valid/bottleneck_online_recall", rec, on_step=False, on_epoch=True, sync_dist=True) 318 | pl_module.log("valid/bottleneck_online_f1", 2*(rec*prec)/(rec+prec), on_step=False, on_epoch=True, sync_dist=True) 319 | pl_module.log("valid/bottleneck_online_loss", loss, on_step=True, on_epoch=False) 320 | 321 | def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict: 322 | return {"state_dict": self.online_evaluator.state_dict(), "optimizer_state": self.optimizer.state_dict()} 323 | 324 | def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, callback_state: Dict[str, Any]) -> None: 325 | self._recovered_callback_state = callback_state 326 | 327 | @contextmanager 328 | def set_training(module: nn.Module, mode: bool): 329 | """Context manager to set training mode. 330 | When exit, recover the original training mode. 331 | Args: 332 | module: module to set training mode 333 | mode: whether to set training mode (True) or evaluation mode (False). 334 | """ 335 | original_mode = module.training 336 | 337 | try: 338 | module.train(mode) 339 | yield module 340 | finally: 341 | module.train(original_mode) -------------------------------------------------------------------------------- /src/callbacks/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from typing import List 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | import torch 9 | import wandb 10 | from pytorch_lightning import Callback, Trainer 11 | from pytorch_lightning.loggers import LoggerCollection, WandbLogger 12 | from pytorch_lightning.utilities import rank_zero_only 13 | from sklearn import metrics 14 | from sklearn.metrics import f1_score, precision_score, recall_score 15 | import gc 16 | import numpy as np 17 | 18 | 19 | 20 | from typing import List 21 | 22 | import torch 23 | from pl_bolts.models.self_supervised.evaluator import SSLEvaluator 24 | from pytorch_lightning import Callback 25 | from torch import nn 26 | from tqdm import tqdm 27 | from tqdm import trange 28 | 29 | 30 | def make_figure(data): 31 | fig, ax = plt.subplots(1, 1) 32 | im = ax.imshow(data) 33 | fig.colorbar(im, ax=ax) 34 | plt.close() 35 | return fig 36 | 37 | def make_plot(data1, data2=None, label1='target', label2='prediction'): 38 | fig, ax = plt.subplots(1, 1) 39 | ax.plot(data1, label=label1, ls='', marker='o') 40 | if np.all(data2)!=None: 41 | ax.plot(data2, label=label2, ls='', marker='+') 42 | plt.legend() 43 | plt.close() 44 | return fig 45 | 46 | def make_scatter(preds, targets): 47 | fig, ax = plt.subplots(1, 1) 48 | ax.scatter(targets, preds, alpha=.5) 49 | plt.grid() 50 | plt.xlabel("targets") 51 | plt.ylabel("predictions") 52 | plt.close() 53 | return fig 54 | 55 | def get_wandb_logger(trainer: Trainer) -> WandbLogger: 56 | """Safely get Weights&Biases logger from Trainer.""" 57 | 58 | if isinstance(trainer.logger, WandbLogger): 59 | return trainer.logger 60 | 61 | if isinstance(trainer.logger, LoggerCollection): 62 | for logger in trainer.logger: 63 | if isinstance(logger, WandbLogger): 64 | return logger 65 | 66 | raise Exception( 67 | "You are using wandb related callback, but WandbLogger was not found for some reason..." 68 | ) 69 | 70 | 71 | 72 | class LogValPredictions_MAE(Callback): 73 | """Logs a validation batch and their predictions to wandb. 74 | Example adapted from: 75 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 76 | """ 77 | 78 | def __init__(self, num_samples: int = 8, counts=True): 79 | super().__init__() 80 | self.num_samples = num_samples 81 | self.preds = [] 82 | self.inputs = [] 83 | #self.names = [] 84 | #self.masks =[] 85 | 86 | def on_sanity_check_start(self, trainer, pl_module): 87 | self.ready = False 88 | 89 | def on_sanity_check_end(self, trainer, pl_module): 90 | """Start executing this callback only after all validation sanity checks end.""" 91 | self.ready = True 92 | 93 | def on_validation_batch_end( 94 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 95 | ): 96 | 97 | """Gather data from single batch.""" 98 | if self.ready: 99 | """Gather data from single batch.""" 100 | self.preds.append(outputs["preds"]) 101 | 102 | self.inputs.append(outputs["inputs"]) 103 | 104 | #self.names.append(outputs['input_names']) 105 | 106 | #self.masks.append(outputs['mask']) 107 | 108 | 109 | def on_validation_epoch_end(self, trainer, pl_module): 110 | if self.ready: 111 | logger = get_wandb_logger(trainer=trainer) 112 | experiment = logger.experiment 113 | preds = torch.cat(self.preds[:self.num_samples]).detach().cpu().numpy() 114 | inputs = torch.cat(self.inputs[:self.num_samples]).detach().cpu().numpy() 115 | #masks = torch.cat(self.masks[:self.num_samples]).detach().cpu().numpy() 116 | #masked_inputs = inputs*masks 117 | #recons = preds 118 | #recons[masks.astype(bool)] = inputs[masks.astype(bool)] 119 | 120 | names = np.concatenate(self.names[:self.num_samples]) 121 | 122 | input_imgs = [] 123 | 124 | 125 | for i in range(self.num_samples): 126 | if inputs[i].shape[0]>1: 127 | for jj in range(inputs[i].squeeze().shape[0]): 128 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()[jj]), caption=f'{names[i,jj]} target')) 129 | #input_imgs.append(wandb.Image(make_figure(masked_inputs[i].squeeze()[jj]), caption=f'{names[i,jj]} inputs')) 130 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()[jj]), caption=f'{names[i,jj]} recon')) 131 | #input_imgs.append(wandb.Image(make_figure(recons[i].squeeze()[jj]), caption=f'{names[i,jj]} recons inpainted')) 132 | else: 133 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()), caption=f'{names[i]} inputs')) 134 | #input_imgs.append(wandb.Image(make_figure(masked_inputs[i].squeeze()), caption=f'{names[i]} inputs')) 135 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()), caption=f'{names[i]} recon')) 136 | #input_imgs.append(wandb.Image(make_figure(recons[i].squeeze()), caption=f'{names[i]} recons inpainted')) 137 | 138 | 139 | experiment.log( 140 | { 141 | "valid/maps": input_imgs, 142 | } 143 | ) 144 | 145 | self.preds.clear() 146 | self.inputs.clear() 147 | self.names.clear() 148 | self.masks.clear() 149 | 150 | 151 | 152 | class LogTrainPredictions_MAE_Downstream(Callback): 153 | """Logs a validation batch and their predictions to wandb. 154 | Example adapted from: 155 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 156 | """ 157 | 158 | def __init__(self, num_samples: int = 2, counts=True): 159 | super().__init__() 160 | self.num_samples = num_samples 161 | self.preds = [] 162 | self.inputs = [] 163 | self.names = [] 164 | 165 | def on_sanity_check_start(self, trainer, pl_module): 166 | self.ready = False 167 | 168 | def on_sanity_check_end(self, trainer, pl_module): 169 | """Start executing this callback only after all validation sanity checks end.""" 170 | self.ready = True 171 | 172 | def on_train_batch_end( 173 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 174 | ): 175 | if self.ready: 176 | """Gather data from single batch.""" 177 | self.preds.append(outputs["preds"]) 178 | 179 | self.inputs.append(outputs["inputs"]) 180 | 181 | #self.names.append(outputs['input_names']) 182 | 183 | 184 | def on_train_epoch_end(self, trainer, pl_module): 185 | if self.ready: 186 | logger = get_wandb_logger(trainer=trainer) 187 | experiment = logger.experiment 188 | preds = torch.cat(self.preds[:self.num_samples]).detach().cpu().numpy() 189 | inputs = torch.cat(self.inputs[:self.num_samples]).detach().cpu().numpy() 190 | 191 | #names = np.concatenate(self.names[:self.num_samples]) 192 | 193 | input_imgs = [] 194 | 195 | 196 | for i in range(self.num_samples): 197 | if inputs[i].shape[0]>1: 198 | for jj in range(inputs[i].squeeze().shape[0]): 199 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()[jj]), caption=f'inputs')) 200 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()[jj]>0.5), caption=f'recon>0.5')) 201 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()[jj]), caption=f'recon')) 202 | else: 203 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()), caption=f'original')) 204 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()>0.5), caption=f'recon>0.5')) 205 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()), caption=f'recon')) 206 | 207 | 208 | experiment.log( 209 | { 210 | "train/maps": input_imgs, 211 | } 212 | ) 213 | 214 | self.preds.clear() 215 | self.inputs.clear() 216 | self.names.clear() 217 | 218 | 219 | class LogValPredictions_MAE_Downstream(Callback): 220 | """Logs a validation batch and their predictions to wandb. 221 | Example adapted from: 222 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 223 | """ 224 | 225 | def __init__(self, num_samples: int = 8, counts=True): 226 | super().__init__() 227 | self.num_samples = num_samples 228 | self.preds = [] 229 | self.inputs = [] 230 | self.names = [] 231 | 232 | def on_sanity_check_start(self, trainer, pl_module): 233 | self.ready = False 234 | 235 | def on_sanity_check_end(self, trainer, pl_module): 236 | """Start executing this callback only after all validation sanity checks end.""" 237 | self.ready = True 238 | 239 | def on_validation_batch_end( 240 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 241 | ): 242 | 243 | """Gather data from single batch.""" 244 | if self.ready: 245 | """Gather data from single batch.""" 246 | self.preds.append(outputs["preds"]) 247 | 248 | self.inputs.append(outputs["inputs"]) 249 | 250 | #self.names.append(outputs['input_names']) 251 | 252 | 253 | 254 | def on_validation_epoch_end(self, trainer, pl_module): 255 | if self.ready: 256 | logger = get_wandb_logger(trainer=trainer) 257 | experiment = logger.experiment 258 | preds = torch.cat(self.preds[:self.num_samples]).detach().cpu().numpy() 259 | inputs = torch.cat(self.inputs[:self.num_samples]).detach().cpu().numpy() 260 | 261 | #names = np.concatenate(self.names[:self.num_samples]) 262 | 263 | input_imgs = [] 264 | 265 | for i in range(self.num_samples): 266 | if inputs[i].shape[0]>1: 267 | for jj in range(inputs[i].squeeze().shape[0]): 268 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()[jj]), caption=f'inputs')) 269 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()[jj]>0.5), caption=f'recon>0.5')) 270 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()[jj]), caption=f'recon')) 271 | else: 272 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()), caption=f'original')) 273 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()>0.5), caption=f'recon>0.5')) 274 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()), caption=f'recon')) 275 | 276 | 277 | experiment.log( 278 | { 279 | "valid/maps": input_imgs, 280 | } 281 | ) 282 | 283 | self.preds.clear() 284 | self.inputs.clear() 285 | self.names.clear() 286 | 287 | 288 | 289 | class LogTrainPredictions_MAE(Callback): 290 | """Logs a validation batch and their predictions to wandb. 291 | Example adapted from: 292 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 293 | """ 294 | 295 | def __init__(self, num_samples: int = 2, counts=True): 296 | super().__init__() 297 | self.num_samples = num_samples 298 | self.preds = [] 299 | self.inputs = [] 300 | self.names = [] 301 | self.masks = [] 302 | 303 | def on_sanity_check_start(self, trainer, pl_module): 304 | self.ready = False 305 | 306 | def on_sanity_check_end(self, trainer, pl_module): 307 | """Start executing this callback only after all validation sanity checks end.""" 308 | self.ready = True 309 | 310 | def on_train_batch_end( 311 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 312 | ): 313 | if self.ready: 314 | """Gather data from single batch.""" 315 | self.preds.append(outputs["preds"]) 316 | 317 | self.inputs.append(outputs["inputs"]) 318 | 319 | self.names.append(outputs['input_names']) 320 | self.masks.append(outputs['mask']) 321 | 322 | 323 | def on_train_epoch_end(self, trainer, pl_module): 324 | if self.ready: 325 | logger = get_wandb_logger(trainer=trainer) 326 | experiment = logger.experiment 327 | preds = torch.cat(self.preds[:self.num_samples]).detach().cpu().numpy() 328 | inputs = torch.cat(self.inputs[:self.num_samples]).detach().cpu().numpy() 329 | masks = torch.cat(self.masks[:self.num_samples]).detach().cpu().numpy() 330 | masked_inputs = inputs*masks 331 | 332 | names = np.concatenate(self.names[:self.num_samples]) 333 | 334 | input_imgs = [] 335 | 336 | 337 | for i in range(self.num_samples): 338 | if inputs[i].shape[0]>1: 339 | for jj in range(inputs[i].squeeze().shape[0]): 340 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()[jj]), caption=f'{names[i,jj]} original')) 341 | input_imgs.append(wandb.Image(make_figure(masked_inputs[i].squeeze()[jj]), caption=f'{names[i,jj]} inputs')) 342 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()[jj]), caption=f'{names[i,jj]} recon')) 343 | else: 344 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()), caption=f'{names[i]} original')) 345 | input_imgs.append(wandb.Image(make_figure(masked_inputs[i].squeeze()), caption=f'{names[i]} inputs')) 346 | input_imgs.append(wandb.Image(make_figure(preds[i].squeeze()), caption=f'{names[i]} recon')) 347 | 348 | 349 | experiment.log( 350 | { 351 | "train/maps": input_imgs, 352 | } 353 | ) 354 | 355 | self.preds.clear() 356 | self.inputs.clear() 357 | self.names.clear() 358 | self.masks.clear() 359 | 360 | 361 | 362 | class LogTrainPredictions(Callback): 363 | """Logs a validation batch and their predictions to wandb. 364 | Example adapted from: 365 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 366 | """ 367 | 368 | def __init__(self, num_samples: int = 2, counts=True): 369 | super().__init__() 370 | self.num_samples = num_samples 371 | self.counts = counts 372 | self.ready = True 373 | self.preds = [] 374 | self.targets = [] 375 | self.inputs = [] 376 | self.other_inputs = [] 377 | self.other_names = [] 378 | self.hists = [] 379 | self.counts_true = [] 380 | self.counts_pred = [] 381 | self.names = [] 382 | 383 | def on_sanity_check_start(self, trainer, pl_module): 384 | self.ready = False 385 | 386 | def on_sanity_check_end(self, trainer, pl_module): 387 | """Start executing this callback only after all validation sanity checks end.""" 388 | self.ready = True 389 | 390 | def on_train_batch_end( 391 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 392 | ): 393 | """Gather data from single batch.""" 394 | if self.ready: 395 | self.preds.append(outputs["preds"]) 396 | self.targets.append(outputs["targets"]) 397 | self.inputs.append(outputs["inputs"]) 398 | ### yeah.... don't ask. Feel free to clean this up. 399 | try: 400 | self.other_inputs.append(outputs["other_inputs"]) 401 | self.other_names.append(outputs["other_names"]) 402 | self.other=True 403 | except: 404 | self.other=False 405 | 406 | if self.counts: 407 | self.counts_true.append(outputs["true_count"]) 408 | self.counts_pred.append(outputs["pred_count"]) 409 | self.names.append(outputs['input_names']) 410 | 411 | 412 | def on_train_epoch_end(self, trainer, pl_module): 413 | if self.ready: 414 | logger = get_wandb_logger(trainer=trainer) 415 | experiment = logger.experiment 416 | preds = torch.cat(self.preds[:self.num_samples]).detach().cpu().numpy() 417 | targets = torch.cat(self.targets[:self.num_samples]).detach().cpu().numpy() 418 | inputs = torch.cat(self.inputs[:self.num_samples]).detach().cpu().numpy() 419 | 420 | try: 421 | count_t = torch.cat(self.counts_true).detach().cpu().numpy() 422 | count_p = torch.cat(self.counts_pred).detach().cpu().numpy() 423 | except: 424 | try: 425 | count_t = self.counts_true.detach().cpu().numpy() 426 | count_p = self.counts_pred.detach().cpu().numpy() 427 | except: 428 | try: 429 | count_t = np.concatenate(self.counts_true) 430 | count_p = np.concatenate(self.counts_pred) 431 | except: 432 | if self.counts: 433 | count_t = self.counts_true 434 | count_p = self.counts_pred 435 | else: 436 | pass 437 | 438 | 439 | names = np.concatenate(self.names[:self.num_samples]) 440 | if self.other: 441 | other_names = np.concatenate(self.other_names[:self.num_samples]) 442 | other_inputs = torch.cat(self.other_inputs[:self.num_samples]).detach().cpu().numpy() 443 | 444 | output_imgs = [] 445 | input_imgs = [] 446 | count_imgs = [] 447 | 448 | for i in range(self.num_samples): 449 | 450 | output_imgs.append(wandb.Image(make_figure(targets[i].squeeze()), caption=f'Target {i}')) 451 | if len(preds.shape)>2: 452 | output_imgs.append(wandb.Image(make_figure(preds[i].squeeze()>0.5), caption=f'Prediction {i}>0.5')) 453 | output_imgs.append(wandb.Image(make_figure(preds[i].squeeze()), caption=f'Prediction {i}')) 454 | 455 | for jj in range(inputs[i].squeeze().shape[0]): 456 | input_imgs.append(wandb.Image(make_figure(inputs[i].squeeze()[jj]), caption=f'{names[i,jj]}')) 457 | if self.other: 458 | input_imgs.append(wandb.Image(make_figure(other_inputs[i].squeeze()[jj]), caption=f'{other_names[i,jj]}')) 459 | input_imgs.append(wandb.Image(make_figure(targets[i].squeeze()), caption=f'Target {i}')) 460 | 461 | if self.counts: 462 | count_imgs.append(wandb.Image(make_scatter(count_p, count_t))) 463 | 464 | 465 | # log the images as wandb Image 466 | if self.counts: 467 | experiment.log( 468 | { 469 | "train/target-prediction maps": output_imgs, 470 | "train/input maps": input_imgs, 471 | "train/counts": count_imgs 472 | } 473 | ) 474 | else: 475 | experiment.log( 476 | { 477 | "train/target-prediction maps": output_imgs, 478 | "train/input maps": input_imgs 479 | } 480 | ) 481 | self.preds.clear() 482 | self.targets.clear() 483 | self.inputs.clear() 484 | self.other_inputs.clear() 485 | self.names.clear() 486 | self.other_names.clear() 487 | if self.counts: 488 | self.counts_pred.clear() 489 | self.counts_true.clear() 490 | 491 | 492 | class LogValPredictions(Callback): 493 | """Logs a validation batch and their predictions to wandb. 494 | Example adapted from: 495 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 496 | """ 497 | 498 | def __init__(self, num_samples: int = 8, counts=True): 499 | super().__init__() 500 | self.num_samples = num_samples 501 | self.ready = True 502 | self.preds = [] 503 | self.targets = [] 504 | self.counts = counts 505 | if self.counts: 506 | self.counts_true = [] 507 | self.counts_pred = [] 508 | 509 | def on_sanity_check_start(self, trainer, pl_module): 510 | self.ready = False 511 | 512 | def on_sanity_check_end(self, trainer, pl_module): 513 | """Start executing this callback only after all validation sanity checks end.""" 514 | self.ready = True 515 | 516 | def on_validation_batch_end( 517 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 518 | ): 519 | """Gather data from single batch.""" 520 | if self.ready: 521 | self.preds.append(outputs["preds"]) 522 | self.targets.append(outputs["targets"]) 523 | if self.counts: 524 | self.counts_true.append(outputs["true_count"]) 525 | self.counts_pred.append(outputs["pred_count"]) 526 | 527 | def on_validation_epoch_end(self, trainer, pl_module): 528 | if self.ready: 529 | logger = get_wandb_logger(trainer=trainer) 530 | experiment = logger.experiment 531 | preds = torch.cat(self.preds[:self.num_samples]).detach().cpu().numpy() 532 | targets = torch.cat(self.targets[:self.num_samples]).detach().cpu().numpy() 533 | 534 | try: 535 | count_t = torch.cat(self.counts_true).detach().cpu().numpy() 536 | count_p = torch.cat(self.counts_pred).detach().cpu().numpy() 537 | except: 538 | try: 539 | count_t = self.counts_true.detach().cpu().numpy() 540 | count_p = self.counts_pred.detach().cpu().numpy() 541 | except: 542 | try: 543 | count_t = np.concatenate(self.counts_true) 544 | count_p = np.concatenate(self.counts_pred) 545 | except: 546 | if self.counts: 547 | count_t = self.counts_true 548 | count_p = self.counts_pred 549 | else: 550 | pass 551 | 552 | 553 | imgs = [] 554 | 555 | for i in range(self.num_samples): 556 | 557 | imgs.append(wandb.Image(make_figure(targets[i].squeeze()), caption=f'Target {i}')) 558 | if len(preds.shape)>2: 559 | imgs.append(wandb.Image(make_figure(preds[i].squeeze()>0.5), caption=f'Prediction {i}>0.5')) 560 | imgs.append(wandb.Image(make_figure(preds[i].squeeze()), caption=f'Prediction {i}')) 561 | if self.counts: 562 | imgs.append(wandb.Image(make_scatter(count_p, count_t))) 563 | 564 | # log the images as wandb Image 565 | 566 | experiment.log( 567 | { 568 | "Validation/target-prediction maps": imgs 569 | } 570 | ) 571 | self.preds.clear() 572 | self.targets.clear() 573 | if self.counts: 574 | self.counts_pred.clear() 575 | self.counts_true.clear() -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/src/datamodules/__init__.py -------------------------------------------------------------------------------- /src/datamodules/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/src/datamodules/components/__init__.py -------------------------------------------------------------------------------- /src/datamodules/components/chips.py: -------------------------------------------------------------------------------- 1 | #from osgeo import gdal, osr, ogr # Python bindings for GDAL 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import sys 5 | import os 6 | import pickle 7 | import json 8 | from rlxutils import subplots 9 | from .utils import pimshow 10 | 11 | class Chipset: 12 | 13 | def __init__(self, chipset_folder=None, data=None, metadata=None): 14 | 15 | self.folder = chipset_folder 16 | self.chip_fnames = os.listdir(chipset_folder) 17 | 18 | chip_ids_npz = [i.split(".")[0] for i in self.chip_fnames if i.endswith(".npz")] 19 | chip_ids_pkl = [i.split(".")[0] for i in self.chip_fnames if i.endswith("metadata.pkl")] 20 | # keep the chips with both metadata and data 21 | self.chip_ids = [i for i in chip_ids_npz if i in chip_ids_pkl] 22 | 23 | def get_chip(self, chip_id): 24 | if not chip_id in self.chip_ids: 25 | raise ValueError(f"{chip_id} does not exist") 26 | 27 | return Chip(self.folder, chip_id) 28 | 29 | def random_chip(self): 30 | chip_id = self.chip_ids[np.random.randint(len(self.chip_ids))] 31 | return Chip(self.folder, chip_id) 32 | 33 | def chips(self): 34 | for chip_id in self.chip_ids: 35 | yield Chip(self.folder, chip_id) 36 | 37 | class Chip: 38 | def __init__(self, chipset_folder, chip_id, data=None, metadata=None): 39 | self.chipset_folder = chipset_folder 40 | self.chip_id = chip_id 41 | 42 | assert (data is None and metadata is None) or (data is not None and metadata is not None), "'data' and 'metadata' must be both set or unset" 43 | 44 | if data is None: 45 | with np.load(f"{self.chipset_folder}/{self.chip_id}.npz") as my_file: 46 | self.data = my_file['arr_0'] 47 | my_file.close() 48 | with open(f"{self.chipset_folder}/{self.chip_id}.metadata.pkl", "rb") as f: 49 | self.metadata = pickle.load(f) 50 | f.close() 51 | else: 52 | self.data = data 53 | self.metadata = metadata 54 | 55 | def clone(self): 56 | return self.__class__(self.chipset_folder, self.chip_id, self.data.copy(), self.metadata.copy()) 57 | 58 | def get_varnames(self, exceptions=[]): 59 | r = self.get_timestep_varnames() + self.get_timepair_varnames() + self.get_static_varnames() 60 | r = [i for i in r if not i in exceptions] 61 | return r 62 | 63 | def get_time_varnames(self): 64 | return self.get_timestep_varnames() + self.get_timepair_varnames() 65 | 66 | def get_timestep_varnames(self): 67 | return list(np.unique([i.split("::")[1] for i in self.metadata['variables'] if i.startswith("TS::")])) 68 | 69 | def get_timesteps(self): 70 | return list(np.unique([i.split("::")[2] for i in self.metadata['variables'] if i.startswith("TS::")])) 71 | 72 | def get_timepair_varnames(self): 73 | return list(np.unique([i.split("::")[1] for i in self.metadata['variables'] if i.startswith("TP::")])) 74 | 75 | def get_timepairs(self): 76 | return list(np.unique([i.split("::")[2] for i in self.metadata['variables'] if i.startswith("TP::")])) 77 | 78 | def get_static_varnames(self): 79 | return list(np.unique([i.split("::")[1] for i in self.metadata['variables'] if i.startswith("ST::")])) 80 | 81 | def apply_with_other(self, other, func): 82 | """ 83 | applies a function element-wise to the data of two chips 84 | which must have exactly the same vars. 85 | 86 | the function must have signature func(x,y) with x,y and the return value 87 | must all be arrays of the same size 88 | """ 89 | thisvars = self.get_varnames() 90 | othervars = other.get_varnames() 91 | 92 | if self.get_varnames() != other.get_varnames(): 93 | raise ValueError("chips must have the same number of variables") 94 | 95 | r = self.clone() 96 | 97 | r.data = func(self.data, other.data) 98 | 99 | for name in self.metadata['variables']: 100 | self.metadata['variables'] 101 | 102 | return r 103 | 104 | def apply(self, var_name, func, args): 105 | """ 106 | applies a function element-wise to the data of two chips 107 | which must have exactly the same vars. 108 | 109 | the function must have signature func(x,y) with x,y and the return value 110 | must all be arrays of the same size 111 | """ 112 | try: 113 | assert(var_name in self.get_varnames()) 114 | except: 115 | raise ValueError(f'{var_name} not in chip') 116 | 117 | r = self.sel([var_name]) 118 | 119 | r.data, name_tag = func(np.nan_to_num(r.data), **args) 120 | 121 | if name_tag: 122 | r.metadata['variables']= [name.replace(var_name,var_name+'_'+name_tag) for name in r.metadata['variables']] 123 | 124 | return r 125 | 126 | def apply_across_time(self, func): 127 | """ 128 | applies a function across the time dimension rendering time variables into static 129 | returns a new chip 130 | """ 131 | 132 | tsvars = self.get_timestep_varnames() 133 | tpvars = self.get_timepair_varnames() 134 | selected_vars = tsvars + tpvars 135 | new_data = np.vstack([np.apply_over_axes(func, self.sel(v).data,0) for v in selected_vars]) 136 | new_metadata = self.metadata.copy() 137 | # all variables now become static 138 | new_metadata['variables'] = [f"ST::{i}__{func.__name__}" for i in tsvars] + [f"ST::{i}__{func.__name__}" for i in tpvars] 139 | new_chip_id = self.chip_id + f"_{np.random.randint(1000000):07d}" 140 | new_metadata['chip_id'] = new_chip_id 141 | return self.__class__(self.chipset_folder, new_chip_id, new_data, new_metadata) 142 | 143 | def diff_channels(self, tag1, tag2): 144 | var_names = self.get_varnames() 145 | 146 | new_metadata = self.metadata.copy() 147 | 148 | new_metadata['variables'] =[] 149 | new_data = [] 150 | new_vars = [] 151 | for varname in var_names: 152 | if tag1 in varname: 153 | v, idx = self.get_array_idxs(varnames=[varname.replace(tag1, tag2), varname]) 154 | new_data.append(np.expand_dims(np.squeeze(self.data[idx[0]]-self.data[idx[1]]),axis=0)) 155 | 156 | if tag1 in v[0]: 157 | v = v[0].replace(tag1,'diff') 158 | else: 159 | v = v[0].replace(tag2,'diff') 160 | 161 | new_metadata['variables'].append(v) 162 | elif tag2 in varname: 163 | pass 164 | else: 165 | v, idx = self.get_array_idxs(varnames=[varname]) 166 | new_data.append(np.expand_dims(np.squeeze(self.data[idx]),0)) 167 | new_metadata['variables'].append(v[0]) 168 | 169 | new_data = np.vstack(new_data) 170 | new_chip_id = self.chip_id + f"_{np.random.randint(1000000):07d}" 171 | new_metadata['chip_id'] = new_chip_id 172 | 173 | return self.__class__(self.chipset_folder, new_chip_id, new_data, new_metadata) 174 | 175 | def get_array_idxs(self, varnames=None, start_date=None, end_date=None): 176 | if varnames is None: 177 | varnames = self.get_varnames() 178 | elif not type(varnames)==list: 179 | varnames = [varnames] 180 | vspecs = self.metadata['variables'] 181 | selected_idxs = [] 182 | selected_vars = [] 183 | for i in range(len(vspecs)): 184 | vspec = vspecs[i] 185 | if vspec.startswith('TS::'): 186 | _,vname,vdate = vspec.split("::") 187 | if vname in varnames\ 188 | and (start_date is None or start_date<=vdate)\ 189 | and (end_date is None or end_date>=vdate): 190 | selected_idxs.append(i) 191 | selected_vars.append(vspec) 192 | elif vspec.startswith('TP::'): 193 | _,vname,vdate = vspec.split("::") 194 | vdate1, vdate2 = vdate.split("_") 195 | if vname in varnames\ 196 | and (start_date is None or (start_date<=vdate1 and start_date<=vdate2))\ 197 | and (end_date is None or (end_date>=vdate2 and end_date>=vdate2)): 198 | selected_idxs.append(i) 199 | selected_vars.append(vspec) 200 | elif vspec.startswith('ST::'): 201 | _, vname = vspec.split("::") 202 | if vname in varnames: 203 | selected_idxs.append(i) 204 | selected_vars.append(vspec) 205 | return selected_vars, selected_idxs 206 | 207 | def get_array(self, varnames=None, start_date=None, end_date=None): 208 | _, selected_idxs = self.get_array_idxs(varnames, start_date, end_date) 209 | return self.data[selected_idxs] 210 | 211 | def plot(self, overlay=None, log=False, **kwargs): 212 | if not 'n_cols' in kwargs: 213 | kwargs['n_cols'] = 5 214 | if not 'usizex' in kwargs: 215 | kwargs['usizex'] = 4 216 | if not 'usizey' in kwargs: 217 | kwargs['usizey'] = 3 218 | 219 | for ax,i in subplots(len(self.data), **kwargs): 220 | if log and np.nanmin(self.data[i])>=0: 221 | x = np.log10(self.data[i]+1e-4) 222 | else: 223 | x = self.data[i] 224 | pimshow(x) 225 | if np.sum(overlay): 226 | pimshow(np.squeeze(overlay), alpha=0.2) 227 | plt.colorbar() 228 | varname = self.metadata['variables'][i].split("::") 229 | if len(varname)==2: 230 | tit = varname[1] 231 | else: 232 | tit = f"{varname[1]}\n{varname[2]}" 233 | 234 | plt.title(tit) 235 | 236 | plt.tight_layout() 237 | 238 | def sel(self, varnames=None, start_date=None, end_date=None): 239 | selected_vars, selected_idxs = self.get_array_idxs(varnames, start_date, end_date) 240 | new_data = self.data[selected_idxs] 241 | new_metadata = self.metadata.copy() 242 | new_metadata['variables'] = selected_vars 243 | new_chip_id = self.chip_id + f"_{np.random.randint(1000000):07d}" 244 | new_metadata['chip_id'] = new_chip_id 245 | return self.__class__(self.chipset_folder, new_chip_id, new_data, new_metadata) 246 | 247 | def save_as_geotif(self, dest_folder): 248 | from osgeo import gdal, osr, ogr 249 | 250 | def getGeoTransform(extent, nlines, ncols): 251 | resx = (extent[2] - extent[0]) / ncols 252 | resy = (extent[3] - extent[1]) / nlines 253 | return [extent[0], resx, 0, extent[3] , 0, -resy] 254 | 255 | # Define the data extent (min. lon, min. lat, max. lon, max. lat) 256 | extent = list(self.metadata['bounds'].values()) # South America 257 | 258 | # Export the test array to GeoTIFF ================================================ 259 | 260 | # Get GDAL driver GeoTiff 261 | driver = gdal.GetDriverByName('GTiff') 262 | 263 | data = self.data 264 | 265 | # Get dimensions 266 | nlines = data.shape[1] 267 | ncols = data.shape[2] 268 | nbands = len(data) 269 | data_type = gdal.GDT_Float32 # gdal.GDT_Float32 270 | 271 | # Create a temp grid 272 | #options = ['COMPRESS=JPEG', 'JPEG_QUALITY=80', 'TILED=YES'] 273 | grid_data = driver.Create('grid_data', ncols, nlines, nbands, data_type)#, options) 274 | 275 | # Write data for each bands 276 | for i in range(len(data)): 277 | grid_data.GetRasterBand(i+1).WriteArray(self.data[i]) 278 | 279 | # Lat/Lon WSG84 Spatial Reference System 280 | import os 281 | import sys 282 | proj_lib = "/".join(sys.executable.split("/")[:-2]+['share', 'proj']) 283 | 284 | os.environ['PROJ_LIB']=proj_lib 285 | srs = osr.SpatialReference() 286 | #srs.ImportFromProj4('+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs') 287 | srs.ImportFromEPSG(int(self.metadata['crs'].split(':')[1])) 288 | 289 | # Setup projection and geo-transform 290 | grid_data.SetProjection(srs.ExportToWkt()) 291 | grid_data.SetGeoTransform(getGeoTransform(extent, nlines, ncols)) 292 | 293 | # Save the file 294 | file_name = f'{dest_folder}/{self.chip_id}.tif' 295 | print(f'saved {file_name} with {nbands} bands') 296 | driver.CreateCopy(file_name, grid_data, 0) 297 | 298 | # Close the file 299 | driver = None 300 | grid_data = None 301 | 302 | # Delete the temp grid 303 | import os 304 | os.remove('grid_data') 305 | #=========================== 306 | 307 | @staticmethod 308 | def concat_static(chip_list,sufixes=None): 309 | """ 310 | concats two chips containing only static variables 311 | """ 312 | 313 | if sufixes is None: 314 | sufixes = [""] * len(chip_list) 315 | 316 | assert len(chip_list)==len(sufixes), f"you have {len(chip_list)} chips but {len(sufixes)} sufixes" 317 | 318 | c = chip_list[0] 319 | p = sufixes[0] 320 | 321 | # all variables must be static 322 | assert len(c.get_time_varnames())==0, "chips can only contain static variables" 323 | 324 | r = c.clone() 325 | r.chip_id += "_concat" 326 | r.metadata['variables'] = [f"{i}{p}" for i in r.metadata['variables']] 327 | 328 | for i in range(1,len(chip_list)): 329 | c = chip_list[i] 330 | p = sufixes[i] 331 | 332 | if c.metadata['bounds'] != r.metadata['bounds']: 333 | raise ValueError("all chips must have the same bounds") 334 | 335 | if c.metadata['crs'] != r.metadata['crs']: 336 | raise ValueError("all chips must have the same crs") 337 | 338 | r.data = np.vstack([r.data, c.data]) 339 | r.metadata['variables'] += [f"{i}{p}" for i in c.metadata['variables']] 340 | 341 | if len(np.unique( r.metadata['variables'] )) != sum([len(i.metadata['variables'] ) for i in chip_list] ): 342 | raise ValueError("there were overlapping variable names in the chips. use 'sufixes'") 343 | 344 | return r -------------------------------------------------------------------------------- /src/datamodules/components/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | from scipy.spatial import KDTree as KDTree 4 | from src.utils.loss_kernels import gaussian_kernel 5 | from scipy.signal import convolve2d 6 | 7 | class KDE_smoothing(): 8 | 9 | def __init__(self, sigma, norm=False): 10 | 11 | self.sigma = sigma 12 | self.norm = norm 13 | self.kernel = gaussian_kernel(size=5*sigma,sigma=sigma) 14 | self.kernel = self.kernel/self.kernel.sum() 15 | 16 | def __call__(self,data): 17 | shape = data.shape 18 | data = convolve2d(np.squeeze(data),self.kernel, mode='same') 19 | if self.norm: 20 | data/=np.sum(data) 21 | return np.reshape(data, shape) 22 | 23 | 24 | class Ignore_multiples(): 25 | 26 | def __init__(self, threshold=0., new_val = 1.0): 27 | self.threshold = threshold 28 | self.new_val = new_val 29 | 30 | def __call__(self,data): 31 | data = (data>self.threshold).astype(np.float32)*self.new_val 32 | return data 33 | 34 | 35 | class Log_transform(): 36 | 37 | def __init__(self, stats_dict, offset=1e-4, log=np.log10, name_tag='log'): 38 | """ 39 | 40 | """ 41 | self.offset = offset 42 | self.log = log 43 | self.stats_dict = stats_dict 44 | self.name_tag = name_tag 45 | 46 | 47 | def __call__(self,data, channel): 48 | data = np.nan_to_num(data) 49 | if min(self.stats_dict[f'min_{channel}'],np.min(data))<=0.: 50 | data+=abs(self.stats_dict[f'min_{channel}']) 51 | data+=self.offset 52 | data = self.log(data) 53 | 54 | return data, self.name_tag 55 | 56 | #TODO: adapt 57 | class Standardize(): 58 | 59 | def __init__(self, stats_dict, name_tag='stded'): 60 | self.stats_dict = stats_dict 61 | self.name_tag = name_tag 62 | 63 | # VB uses stas_dict for all channels together (ig broadcasting is correct) 64 | def __call__(self,data, channel): 65 | 66 | if 'log' in channel: 67 | mean = self.stats_dict[f'log_mean_{channel}'] 68 | std = self.stats_dict[f'log_std_{channel}'] 69 | else: 70 | mean = self.stats_dict[f'mean_{channel}'] 71 | std = self.stats_dict[f'std_{channel}'] 72 | 73 | data = np.nan_to_num(data) 74 | data = (data-mean)/std 75 | 76 | 77 | return data, self.name_tag 78 | 79 | 80 | class Identity_transform(): 81 | 82 | def __init__(self): 83 | pass 84 | 85 | def __call__(self,data): 86 | return data 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/datamodules/components/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def qplot(x): 5 | pctls = np.linspace(0,100,21)[1:-1] 6 | qtles = np.percentile(x, pctls) 7 | plt.plot(pctls, qtles) 8 | plt.grid(); plt.xlabel("percentiles"); plt.ylabel("quantiles") 9 | 10 | def pimshow(x, pmin=1, pmax=99, set_nans_to=0, alpha=None): 11 | xx = x.copy() 12 | xx[np.isnan(xx)]=set_nans_to 13 | vmin, vmax = np.percentile(xx, [pmin, pmax]) 14 | plt.imshow(x, vmin=vmin, vmax=vmax, alpha=alpha) -------------------------------------------------------------------------------- /src/datamodules/siamese_datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import ConcatDataset, DataLoader, Dataset 5 | from torchvision.transforms import transforms 6 | from src.datamodules.components import transforms 7 | import numpy as np 8 | import pickle 9 | import warnings 10 | from src.datamodules.components import chips 11 | 12 | 13 | event_dates = {} 14 | event_dates['hokkaido'] = {'event_start_date':'20180905', 'event_end_date':'20180907'} 15 | event_dates['indonesia'] = {'event_start_date':'20220224', 'event_end_date':'20220225'} 16 | event_dates['kaikoura'] = {'event_start_date':'20161113', 'event_end_date':'20161114'} 17 | 18 | def identity(x, dummy): 19 | return x 20 | 21 | class Siamese_Landslide_Dataloader(Dataset): 22 | """ 23 | Torch Dataset class for Siamese Change Detection Training 24 | 25 | """ 26 | def __init__(self, 27 | data_dir, 28 | dict_dir, 29 | split, 30 | channels, 31 | setting, 32 | num_time_steps, 33 | time_step_summary=None, 34 | input_transforms=None, 35 | trainsize=-1): 36 | 37 | """ 38 | attributes: 39 | data_dir: data directory 40 | dict_dir: data summary statistics directory 41 | split: one of 'train','valid', 'test' 42 | channels: list of strings, which channels to return 43 | setting: 'pretraining' or 'downstream' 44 | num_time_steps: int, number of time steps before and after the event 45 | balance_weights: weights for imbalanced dataset 46 | time_step_summary: numpy function name for combining time steps, None for not combining 47 | input_transforms: list of input transforms 48 | """ 49 | 50 | self.datasets = data_dir.keys() 51 | self.trainsize = trainsize 52 | self.split = split 53 | assert split in ['train', 'valid', 'test'] 54 | 55 | self.setting = setting 56 | assert setting in ['pretraining','downstream'] 57 | 58 | # decide which 59 | if self.split=='train': 60 | if self.setting=='pretraining': 61 | self.split='train_pre' 62 | elif self.setting=='downstream': 63 | self.split='train_down' 64 | 65 | # get file numbers for split 66 | self.dataset_dicts = {} 67 | self.length = 0 68 | for name in self.datasets: 69 | self.dataset_dicts[name] = {} 70 | self.dataset_dicts[name]['file_nums'] = pickle.load(open(dict_dir/'split_dict_balanced_siamese_{}.pkl'.format(name), 'rb'))[self.split] 71 | self.dataset_dicts[name]['length'] = len(self.dataset_dicts[name]['file_nums']) 72 | self.dataset_dicts[name]['start_date'], self.dataset_dicts[name]['end_date'] = pickle.load(open(dict_dir/'timestep_dict_siamese_{}.pkl'.format(name),'rb'))[str(num_time_steps)] 73 | self.length+=self.dataset_dicts[name]['length'] 74 | self.dataset_dicts[name]['start_date'], self.dataset_dicts[name]['end_date'] 75 | if self.split=='train_down': 76 | if len(self.datasets)==1: 77 | if self.trainsize!=-1: 78 | inds = np.random.choice(np.arange(self.length),self.trainsize, replace=False) 79 | self.dataset_dicts[name]['file_nums'] = self.dataset_dicts[name]['file_nums'][inds] 80 | self.dataset_dicts[name]['length'] = self.trainsize 81 | self.length = self.trainsize 82 | 83 | self.input_transforms = input_transforms 84 | self.channels = list(channels) 85 | 86 | 87 | if time_step_summary: 88 | self.summary = getattr(np,time_step_summary) 89 | else: 90 | self.summary = False 91 | 92 | self.chipsets = {} 93 | for name in self.datasets: 94 | self.chipsets[name] = chips.Chipset(data_dir[name]) 95 | 96 | print(self.length) 97 | 98 | def __len__(self): 99 | return self.length 100 | 101 | def __getitem__(self, index): 102 | 103 | count=0 104 | for name_ in self.datasets: 105 | bound = self.dataset_dicts[name_]['length']+count 106 | if (index=count): 107 | chip = self.chipsets[name_].get_chip(self.dataset_dicts[name_]['file_nums'][index-count]) 108 | name = name_ 109 | count+=bound 110 | 111 | # add reactivated landslides? 112 | if self.setting=='pretraining': 113 | label = int(np.sum(np.nan_to_num(chip.sel([f'landslides']).data))>0) 114 | else: 115 | label = np.nan_to_num(chip.sel([f'landslides']).data).astype(int) 116 | 117 | input_chip = chip.sel(self.channels) 118 | 119 | pre = input_chip.sel(input_chip.get_time_varnames(), 120 | start_date=self.dataset_dicts[name]['start_date'],end_date=event_dates[name]['event_start_date']) 121 | if self.summary: 122 | pre = pre.apply_across_time(self.summary) 123 | else: 124 | pre = pre.apply_across_time(identity) 125 | 126 | pos = input_chip.sel(input_chip.get_time_varnames(), 127 | start_date=event_dates[name]['event_end_date'], end_date=self.dataset_dicts[name]['end_date']) 128 | if self.summary: 129 | pos= pos.apply_across_time(self.summary) 130 | else: 131 | pos= pos.apply_across_time(identity) 132 | 133 | names = pre.get_varnames() 134 | 135 | if self.input_transforms[name]: 136 | pre = self.apply_transforms(pre,name).data.astype(np.float32) 137 | pos = self.apply_transforms(pos,name).data.astype(np.float32) 138 | 139 | return pre, pos, label, names, 1. 140 | 141 | 142 | def apply_transforms(self,input_chip,name): 143 | chans = [] 144 | for ch in input_chip.get_varnames(): 145 | chan = input_chip.sel([ch]) 146 | ch_base = ch.split('__')[0] 147 | if 'dem' not in ch_base: 148 | for trafo in list(self.input_transforms[name]): 149 | if isinstance(trafo, transforms.Log_transform): 150 | if ch_base in ['vh', 'vv','topophase.cor_0']: 151 | chan = chan.apply(ch,trafo,{'channel':ch_base}) 152 | ch+='_log' 153 | else: 154 | pass 155 | else: 156 | chan=chan.apply(ch,trafo,{'channel':ch_base}) 157 | chans.append(chan) 158 | input_chip = chan.concat_static(chans) 159 | return input_chip 160 | 161 | 162 | class Siamese_Landslide_Datamodule(LightningDataModule): 163 | 164 | def __init__( 165 | self, 166 | data_dir: str = "data/", 167 | dict_dir: str = "data_dicts/", 168 | batch_size: int = 16, 169 | num_workers: int = 1, 170 | pin_memory: bool = False, 171 | input_channels = ['vh', 'vv'], 172 | time_step_summary = 'mean', 173 | input_transforms: list = ['Log_transform', 'Standardize'], 174 | num_time_steps = 5, 175 | setting = 'pretraining', 176 | datasets = ['hokkaido','kaikoura'], 177 | trainsize = -1 178 | ): 179 | super().__init__() 180 | 181 | self.save_hyperparameters(logger=False) 182 | data_dirs = {} 183 | datasets = list(datasets) 184 | for name in datasets: 185 | data_dirs[name] = Path(data_dir)/"landslides-{}".format(name)/'chips_128x128' 186 | 187 | dict_dir = Path(dict_dir) 188 | 189 | self.input_channels = list(input_channels) 190 | self.num_channels = len(self.input_channels) 191 | ### see hokkaido_siamese_landslide_prep.ipynb 192 | self.stats_dicts = {} 193 | for name in datasets: 194 | self.stats_dicts[name] = pickle.load(open(dict_dir/'stats_dict_{}.pkl'.format(name), 'rb')) 195 | 196 | input_transforms = list(input_transforms) 197 | if input_transforms: 198 | input_transforms_dict = {} 199 | for name_ in datasets: 200 | input_transforms_dict[name_] = [getattr(transforms, name)(self.stats_dicts[name_]) for name in input_transforms] 201 | 202 | self.data_train: Optional[Dataset] = Siamese_Landslide_Dataloader(data_dir=data_dirs, dict_dir=dict_dir, split="train", 203 | channels=input_channels, setting=setting,num_time_steps=num_time_steps, 204 | time_step_summary=time_step_summary, input_transforms=input_transforms_dict, trainsize=trainsize) 205 | self.data_val: Optional[Dataset] = Siamese_Landslide_Dataloader(data_dir=data_dirs, dict_dir=dict_dir, split="valid", 206 | channels=input_channels, setting=setting,num_time_steps=num_time_steps, 207 | time_step_summary=time_step_summary, input_transforms=input_transforms_dict, trainsize=trainsize) 208 | self.data_test: Optional[Dataset] = Siamese_Landslide_Dataloader(data_dir=data_dirs, dict_dir=dict_dir, split="test", 209 | channels=input_channels, setting=setting,num_time_steps=num_time_steps, 210 | time_step_summary=time_step_summary, input_transforms=input_transforms_dict, trainsize=trainsize) 211 | 212 | @property 213 | def num_classes(self) -> int: 214 | return 1 215 | 216 | def train_dataloader(self): 217 | return DataLoader( 218 | dataset=self.data_train, 219 | batch_size=self.hparams.batch_size, 220 | num_workers=self.hparams.num_workers, 221 | pin_memory=self.hparams.pin_memory, 222 | shuffle=True, 223 | persistent_workers=True, 224 | prefetch_factor=8 225 | ) 226 | 227 | def val_dataloader(self): 228 | return DataLoader( 229 | dataset=self.data_val, 230 | batch_size=self.hparams.batch_size*8, 231 | num_workers=self.hparams.num_workers, 232 | pin_memory=self.hparams.pin_memory, 233 | shuffle=False, 234 | persistent_workers=True, 235 | prefetch_factor=8 236 | ) 237 | 238 | def test_dataloader(self): 239 | return DataLoader( 240 | dataset=self.data_test, 241 | batch_size=self.hparams.batch_size*8, 242 | num_workers=self.hparams.num_workers, 243 | pin_memory=self.hparams.pin_memory, 244 | shuffle=False, 245 | persistent_workers=True, 246 | prefetch_factor=8 247 | ) 248 | 249 | 250 | ###-------------------------------------------------------------------------------------------------------------------### 251 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/siamese_downstream_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese Network trained to predict landlside induced change 3 | """ 4 | from pathlib import Path 5 | import torch 6 | from pytorch_lightning import LightningModule 7 | import segmentation_models_pytorch as smp 8 | import torchvision.models as models 9 | import numpy as np 10 | import math 11 | import torch.nn as nn 12 | from torchmetrics.functional import precision_recall 13 | from torchmetrics import AveragePrecision,Accuracy 14 | from src.models.siamese_module import Siamese_Type_1 15 | 16 | 17 | def compute_metrics(preds, targets, threshold=0.5): 18 | #print(preds) 19 | prec, rec = precision_recall(preds, targets, threshold=threshold) 20 | f1_score = 2*(prec*rec)/(prec+rec) 21 | 22 | # pr_curve = PrecisionRecallCurve(num_classes=5) 23 | # precision, recall, thresholds = pr_curve(preds, targets) 24 | 25 | average_precision = AveragePrecision() 26 | AP_score = average_precision(preds, targets) 27 | accuracy = Accuracy(threshold=threshold).cuda() 28 | acc = accuracy(preds, targets) 29 | return f1_score, AP_score, prec, rec, acc 30 | 31 | 32 | class Segmentation_Model(LightningModule): 33 | 34 | def __init__( 35 | self, 36 | input_size: list = [2,128,128], 37 | embedding_size: int = 32, 38 | pre_train_augmented = False, 39 | pretrain_path = None, 40 | pretrain_params = {}, 41 | encoder_depth: int = 3, 42 | decoder_channels = [128, 64, 32], 43 | base_lr = 1.5e-4, 44 | unet = False, 45 | loss = 'ce'): 46 | 47 | super().__init__() 48 | 49 | self.input_size = list(input_size) 50 | 51 | self.channels = input_size[0] 52 | self.image_size = input_size[1::] 53 | self.base_lr = base_lr 54 | self.pretrain_augmented = pre_train_augmented 55 | 56 | if unet: 57 | self.basemodel = smp.Unet(encoder_name='resnet34', encoder_depth=encoder_depth, decoder_channels=decoder_channels, encoder_weights=None, decoder_use_batchnorm=True, in_channels=self.input_size[0], classes=embedding_size, activation='identity', aux_params=None) 58 | 59 | 60 | else: 61 | self.basemodel = nn.Sequential(torch.nn.Conv2d(self.channels,self.channels*2, kernel_size=8,stride=1, bias=True, padding='same'), 62 | torch.nn.BatchNorm2d(self.channels*2), 63 | torch.nn.LeakyReLU(), 64 | torch.nn.Conv2d(self.channels*2,self.channels*4, kernel_size=8,stride=1, bias=True, padding='same'), 65 | torch.nn.BatchNorm2d(self.channels*4), 66 | torch.nn.LeakyReLU(), 67 | torch.nn.Conv2d(self.channels*4,embedding_size, kernel_size=8,stride=1, bias=True, padding='same'), 68 | torch.nn.BatchNorm2d(embedding_size), 69 | torch.nn.LeakyReLU() 70 | ) 71 | 72 | depth = embedding_size 73 | 74 | if self.pretrain_augmented: 75 | depth = embedding_size+pretrain_params['embedding_size']*2 76 | self.pretrained = Siamese_Type_1(**pretrain_params) 77 | checkpoint = torch.load(pretrain_path) 78 | self.pretrained.load_state_dict(checkpoint['state_dict']) 79 | self.pretrained.eval() 80 | 81 | self.cnn = nn.Sequential(torch.nn.Conv2d(depth, depth, kernel_size=8,stride=1, bias=True, padding='same'), 82 | torch.nn.BatchNorm2d(depth), 83 | torch.nn.LeakyReLU(), 84 | torch.nn.Conv2d(depth, 1, kernel_size=8,stride=1, bias=True, padding='same')) 85 | self.loss = loss 86 | if loss == 'ce': 87 | self.criterion = nn.CrossEntropyLoss(reduction='none') 88 | elif loss=='dice': 89 | self.criterion = smp.losses.DiceLoss(mode='binary') 90 | 91 | if torch.cuda.is_available(): 92 | self.net = self.basemodel.cuda() 93 | self.fc = self.cnn.cuda() 94 | if pre_train_augmented: 95 | self.pretrained = self.pretrained.cuda() 96 | for params in self.pretrained.parameters(): 97 | params.requires_grad = False 98 | 99 | 100 | def forward(self, pre, post): 101 | 102 | inputs = torch.cat((pre,post),1) 103 | emb = self.basemodel(inputs) 104 | 105 | if self.pretrain_augmented: 106 | with torch.no_grad(): 107 | pre_emb = self.pretrained.forward_once(pre) 108 | pre_emb = torch.cat((pre_emb, self.pretrained.forward_once(post)),1) 109 | 110 | emb = torch.cat((emb, pre_emb), 1) 111 | 112 | output = self.cnn(emb) 113 | if self.loss=='ce': 114 | output = output.view(output.size()[0], -1) 115 | 116 | return output 117 | 118 | def step(self, batch): 119 | 120 | pre, post, label, names, weight = batch 121 | 122 | names = np.asarray(names).T 123 | 124 | preds = self.forward(pre,post) 125 | label = label.view(preds.size()) 126 | 127 | loss = torch.mean(weight*self.criterion(preds.float(),label.float())) 128 | return loss, pre, post, preds, label, names 129 | 130 | 131 | def training_step(self, batch, batch_idx: int): 132 | loss, pre, post, preds, targets, names = self.step(batch) 133 | 134 | self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 135 | 136 | f1, auprc, prec, rec, acc = compute_metrics(torch.sigmoid(preds.detach()), targets.detach()) 137 | self.log("train/precision", prec, on_step=False, on_epoch=True, prog_bar=False) 138 | self.log("train/recall", rec, on_step=False, on_epoch=True, prog_bar=False) 139 | self.log("train/f1", f1, on_step=False, on_epoch=True, prog_bar=False) 140 | self.log("train/AP", auprc, on_step=False, on_epoch=True, prog_bar=False) 141 | self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=False) 142 | 143 | return {"loss": loss, "inputs": torch.reshape(targets,(-1,1,128,128)),"preds": torch.reshape(preds,(-1,1,128,128))} 144 | 145 | def training_epoch_end(self, outputs): 146 | pass 147 | 148 | def validation_step(self, batch, batch_idx: int): 149 | loss, pre, post, preds, targets, names = self.step(batch) 150 | 151 | self.log("valid/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 152 | 153 | f1, auprc, prec, rec, acc = compute_metrics(torch.sigmoid(preds.detach()), targets.detach()) 154 | self.log("valid/precision", prec, on_step=False, on_epoch=True, prog_bar=False) 155 | self.log("valid/recall", rec, on_step=False, on_epoch=True, prog_bar=False) 156 | self.log("valid/f1", f1, on_step=False, on_epoch=True, prog_bar=False) 157 | self.log("valid/AP", auprc, on_step=False, on_epoch=True, prog_bar=False) 158 | self.log("valid/acc", acc, on_step=False, on_epoch=True, prog_bar=False) 159 | 160 | 161 | return {"loss": loss, "inputs": torch.reshape(targets,(-1,1,128,128)),"preds": torch.reshape(preds,(-1,1,128,128))} 162 | 163 | 164 | def test_step(self, batch, batch_idx: int): 165 | loss, pre, post, preds, targets, names = self.step(batch) 166 | 167 | self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 168 | 169 | f1, auprc, prec, rec, acc = compute_metrics(torch.sigmoid(preds.detach()), targets.detach()) 170 | self.log("test/precision", prec, on_step=False, on_epoch=True, prog_bar=False) 171 | self.log("test/recall", rec, on_step=False, on_epoch=True, prog_bar=False) 172 | self.log("test/f1", f1, on_step=False, on_epoch=True, prog_bar=False) 173 | self.log("test/AP", auprc, on_step=False, on_epoch=True, prog_bar=False) 174 | self.log("test/acc", acc, on_step=False, on_epoch=True, prog_bar=False) 175 | 176 | return {"loss": loss} 177 | 178 | 179 | def test_epoch_end(self, outputs): 180 | pass 181 | 182 | 183 | def configure_optimizers(self): 184 | 185 | optimizer = torch.optim.AdamW(self.parameters(),lr=self.base_lr, capturable=True) 186 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', threshold=1e-4, patience=50, factor=0.9) 187 | 188 | return {'lr_scheduler':scheduler, 'optimizer':optimizer, 'monitor': 'train/loss'} 189 | -------------------------------------------------------------------------------- /src/models/siamese_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese Network trained to predict landlside induced change 3 | """ 4 | from pathlib import Path 5 | import torch 6 | from pytorch_lightning import LightningModule 7 | import segmentation_models_pytorch as smp 8 | import torchvision.models as models 9 | import numpy as np 10 | import math 11 | import torch.nn as nn 12 | from torchmetrics.functional import precision_recall 13 | from torchmetrics import AveragePrecision,Accuracy 14 | 15 | 16 | def compute_metrics(preds, targets, threshold=0.5): 17 | #print(preds) 18 | prec, rec = precision_recall(preds, targets, threshold=threshold) 19 | f1_score = 2*(prec*rec)/(prec+rec) 20 | 21 | # pr_curve = PrecisionRecallCurve(num_classes=5) 22 | # precision, recall, thresholds = pr_curve(preds, targets) 23 | 24 | average_precision = AveragePrecision() 25 | AP_score = average_precision(preds, targets) 26 | accuracy = Accuracy(threshold=threshold).cuda() 27 | acc = accuracy(preds, targets) 28 | return f1_score, AP_score, prec, rec, acc 29 | 30 | 31 | class Siamese_Type_1(LightningModule): 32 | 33 | def __init__( 34 | self, 35 | input_size: list = [2,128,128], 36 | embedding_size: int = 32, 37 | decoder_depth: int = 3, 38 | encoder_depth: int = 3, 39 | base_lr = 1.5e-4, 40 | unet = False, 41 | cnn = True, 42 | decoder_channels = [128,64,32]): 43 | 44 | super().__init__() 45 | 46 | self.input_size = list(input_size) 47 | 48 | self.channels = input_size[0] 49 | self.image_size = input_size[1::] 50 | self.base_lr = base_lr 51 | 52 | #decoder_channels = [embedding_size for ii in range(decoder_depth)] 53 | 54 | dropout_rate = 0.8 55 | 56 | if unet: 57 | self.net = smp.Unet(encoder_name='resnet34', encoder_depth=encoder_depth, encoder_weights=None, decoder_use_batchnorm=True, in_channels=self.input_size[0], decoder_channels=list(decoder_channels)[:decoder_depth], classes=embedding_size, activation='identity', aux_params=None) 58 | 59 | 60 | else: 61 | self.net = nn.Sequential(torch.nn.Conv2d(self.channels,self.channels*4, kernel_size=8,stride=1, bias=True, padding='same'), 62 | torch.nn.BatchNorm2d(self.channels*4), 63 | torch.nn.LeakyReLU(), 64 | # torch.nn.Conv2d(self.channels*4,self.channels*4*4, kernel_size=8,stride=1, bias=True, padding='same'), 65 | # torch.nn.BatchNorm2d(self.channels*4*4), 66 | # torch.nn.LeakyReLU(), 67 | torch.nn.Conv2d(self.channels*4,embedding_size, kernel_size=8,stride=1, bias=True, padding='same'), 68 | torch.nn.BatchNorm2d(embedding_size), 69 | torch.nn.LeakyReLU() 70 | ) 71 | 72 | size = np.prod(self.image_size)*embedding_size*2 73 | if cnn: 74 | self.cnn = nn.Sequential(torch.nn.Conv2d(embedding_size*2, 1, kernel_size=8 ,stride=1, bias=True, padding='valid'), 75 | torch.nn.BatchNorm2d(1), 76 | torch.nn.LeakyReLU()) 77 | size = 1*128**2#torch.nn.LeakyReLU() 78 | 79 | self.fc = nn.Sequential(nn.Linear(size, 8), 80 | nn.BatchNorm1d(8), 81 | nn.LeakyReLU(inplace=True), 82 | nn.Dropout(dropout_rate), 83 | nn.Linear(8, 1)) 84 | 85 | if cnn: 86 | self.int_layer = True 87 | else: 88 | self.int_layer = False 89 | 90 | self.criterion = nn.CrossEntropyLoss() 91 | 92 | if torch.cuda.is_available(): 93 | self.net = self.net.cuda() 94 | #self.fc = self.fc.cuda() 95 | if cnn: 96 | self.cnn = self.cnn.cuda() 97 | 98 | 99 | def forward_once(self, x): 100 | output = self.net(x) 101 | 102 | return output 103 | 104 | def forward(self, input1, input2): 105 | # get two images' features 106 | output1 = self.forward_once(input1) 107 | output2 = self.forward_once(input2) 108 | 109 | # concatenate both images' features 110 | output = torch.cat((output1, output2), 1) 111 | 112 | #output = self.cnn(output) 113 | if self.int_layer: 114 | output = self.cnn(output) 115 | output = output.view(output.size()[0], -1) 116 | # pass the concatenation to the linear layers 117 | output = torch.mean(output, axis=-1)#self.fc(output) 118 | 119 | return output 120 | 121 | def step(self, batch): 122 | 123 | pre, post, label, names, weight = batch 124 | 125 | names = np.asarray(names).T 126 | 127 | preds = self.forward(pre,post) 128 | preds = torch.squeeze(preds) 129 | weight = torch.reshape(weight.float(),preds.shape) 130 | 131 | loss = self.criterion(preds.float(),label.float()) 132 | 133 | return loss, pre, post, preds, label, names 134 | 135 | 136 | def training_step(self, batch, batch_idx: int): 137 | loss, pre, post, preds, targets, names = self.step(batch) 138 | 139 | self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 140 | 141 | f1, auprc, prec, rec, acc = compute_metrics(torch.sigmoid(preds.detach()), targets.detach()) 142 | self.log("train/precision", prec, on_step=False, on_epoch=True, prog_bar=False) 143 | self.log("train/recall", rec, on_step=False, on_epoch=True, prog_bar=False) 144 | self.log("train/f1", f1, on_step=False, on_epoch=True, prog_bar=False) 145 | self.log("train/AP", auprc, on_step=False, on_epoch=True, prog_bar=False) 146 | self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=False) 147 | 148 | return {"loss": loss} 149 | 150 | def training_epoch_end(self, outputs): 151 | pass 152 | 153 | def validation_step(self, batch, batch_idx: int): 154 | loss, pre, post, preds, targets, names = self.step(batch) 155 | 156 | self.log("valid/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 157 | 158 | f1, auprc, prec, rec, acc = compute_metrics(torch.sigmoid(preds.detach()), targets.detach()) 159 | self.log("valid/precision", prec, on_step=False, on_epoch=True, prog_bar=False) 160 | self.log("valid/recall", rec, on_step=False, on_epoch=True, prog_bar=False) 161 | self.log("valid/f1", f1, on_step=False, on_epoch=True, prog_bar=False) 162 | self.log("valid/AP", auprc, on_step=False, on_epoch=True, prog_bar=False) 163 | self.log("valid/acc", acc, on_step=False, on_epoch=True, prog_bar=False) 164 | 165 | 166 | return {"loss": loss} 167 | 168 | 169 | def test_step(self, batch, batch_idx: int): 170 | loss, pre, post, preds, targets, names = self.step(batch) 171 | 172 | self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 173 | 174 | f1, auprc, prec, rec, acc = compute_metrics(torch.sigmoid(preds.detach()), targets.detach()) 175 | self.log("test/precision", prec, on_step=False, on_epoch=True, prog_bar=False) 176 | self.log("test/recall", rec, on_step=False, on_epoch=True, prog_bar=False) 177 | self.log("test/f1", f1, on_step=False, on_epoch=True, prog_bar=False) 178 | self.log("test/AP", auprc, on_step=False, on_epoch=True, prog_bar=False) 179 | self.log("test/acc", acc, on_step=False, on_epoch=True, prog_bar=False) 180 | 181 | return {"loss": loss} 182 | 183 | 184 | def test_epoch_end(self, outputs): 185 | pass 186 | 187 | 188 | def configure_optimizers(self): 189 | 190 | optimizer = torch.optim.AdamW(self.parameters(),lr=self.base_lr, capturable=True) 191 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', threshold=1e-4, patience=200, factor=0.9) 192 | 193 | return {'lr_scheduler':scheduler, 'optimizer':optimizer, 'monitor': 'valid/loss'} -------------------------------------------------------------------------------- /src/testing_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import hydra 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything 7 | from pytorch_lightning.loggers import LightningLoggerBase 8 | 9 | from src import utils 10 | 11 | log = utils.get_logger(__name__) 12 | 13 | 14 | def test(config: DictConfig) -> None: 15 | """Contains minimal example of the testing pipeline. Evaluates given checkpoint on a testset. 16 | 17 | Args: 18 | config (DictConfig): Configuration composed by Hydra. 19 | 20 | Returns: 21 | None 22 | """ 23 | 24 | # Set seed for random number generators in pytorch, numpy and python.random 25 | if config.get("seed"): 26 | seed_everything(config.seed, workers=True) 27 | 28 | # Convert relative ckpt path to absolute path if necessary 29 | if not os.path.isabs(config.ckpt_path): 30 | config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path) 31 | 32 | # Init lightning datamodule 33 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 34 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 35 | 36 | # Init lightning model 37 | log.info(f"Instantiating model <{config.model._target_}>") 38 | model: LightningModule = hydra.utils.instantiate(config.model) 39 | 40 | # Init lightning loggers 41 | logger: List[LightningLoggerBase] = [] 42 | if "logger" in config: 43 | for _, lg_conf in config.logger.items(): 44 | if "_target_" in lg_conf: 45 | log.info(f"Instantiating logger <{lg_conf._target_}>") 46 | logger.append(hydra.utils.instantiate(lg_conf)) 47 | 48 | # Init lightning trainer 49 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 50 | trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger) 51 | 52 | # Log hyperparameters 53 | if trainer.logger: 54 | trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path}) 55 | 56 | log.info("Starting testing!") 57 | trainer.test(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path) 58 | -------------------------------------------------------------------------------- /src/training_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import hydra 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import ( 7 | Callback, 8 | LightningDataModule, 9 | LightningModule, 10 | Trainer, 11 | seed_everything, 12 | ) 13 | from pytorch_lightning.loggers import LightningLoggerBase 14 | from pytorch_lightning.plugins import DDPPlugin 15 | from src import utils 16 | 17 | log = utils.get_logger(__name__) 18 | 19 | 20 | def train(config: DictConfig) -> Optional[float]: 21 | """Contains the training pipeline. Can additionally evaluate model on a testset, using best 22 | weights achieved during training. 23 | 24 | Args: 25 | config (DictConfig): Configuration composed by Hydra. 26 | 27 | Returns: 28 | Optional[float]: Metric score for hyperparameter optimization. 29 | """ 30 | 31 | # Set seed for random number generators in pytorch, numpy and python.random 32 | if config.get("seed"): 33 | seed_everything(config.seed, workers=True) 34 | 35 | # Convert relative ckpt path to absolute path if necessary 36 | ckpt_path = config.trainer.get("resume_from_checkpoint") 37 | if ckpt_path and not os.path.isabs(ckpt_path): 38 | config.trainer.resume_from_checkpoint = os.path.join( 39 | hydra.utils.get_original_cwd(), ckpt_path 40 | ) 41 | 42 | # Init lightning datamodule 43 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 44 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 45 | 46 | # Init lightning model 47 | log.info(f"Instantiating model <{config.model._target_}>") 48 | model: LightningModule = hydra.utils.instantiate(config.model) 49 | 50 | # Init lightning callbacks 51 | callbacks: List[Callback] = [] 52 | if "callbacks" in config: 53 | for _, cb_conf in config.callbacks.items(): 54 | if "_target_" in cb_conf: 55 | log.info(f"Instantiating callback <{cb_conf._target_}>") 56 | callbacks.append(hydra.utils.instantiate(cb_conf)) 57 | 58 | # Init lightning loggers 59 | logger: List[LightningLoggerBase] = [] 60 | if "logger" in config: 61 | for _, lg_conf in config.logger.items(): 62 | if "_target_" in lg_conf: 63 | log.info(f"Instantiating logger <{lg_conf._target_}>") 64 | logger.append(hydra.utils.instantiate(lg_conf)) 65 | 66 | # Init lightning trainer 67 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 68 | trainer: Trainer = hydra.utils.instantiate( 69 | config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" 70 | ) 71 | 72 | # Send some parameters from config to all lightning loggers 73 | log.info("Logging hyperparameters!") 74 | utils.log_hyperparameters( 75 | config=config, 76 | model=model, 77 | datamodule=datamodule, 78 | trainer=trainer, 79 | callbacks=callbacks, 80 | logger=logger, 81 | ) 82 | 83 | 84 | # Train the model 85 | if config.get("train"): 86 | log.info("Starting training!") 87 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 88 | 89 | # Get metric score for hyperparameter optimization 90 | optimized_metric = config.get("optimized_metric") 91 | if optimized_metric and optimized_metric not in trainer.callback_metrics: 92 | raise Exception( 93 | "Metric for hyperparameter optimization not found! " 94 | "Make sure the `optimized_metric` in `hparams_search` config is correct!" 95 | ) 96 | score = trainer.callback_metrics.get(optimized_metric) 97 | 98 | # Test the model 99 | if config.get("test"): 100 | ckpt_path = "best" 101 | if not config.get("train") or config.trainer.get("fast_dev_run"): 102 | ckpt_path = None 103 | log.info("Starting testing!") 104 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 105 | 106 | # Make sure everything closed properly 107 | log.info("Finalizing!") 108 | utils.finish( 109 | config=config, 110 | model=model, 111 | datamodule=datamodule, 112 | trainer=trainer, 113 | callbacks=callbacks, 114 | logger=logger, 115 | ) 116 | 117 | # Print path to best checkpoint 118 | if not config.trainer.get("fast_dev_run") and config.get("train"): 119 | log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}") 120 | 121 | # Return metric score for hyperparameter optimization 122 | return score 123 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | from omegaconf import DictConfig 4 | 5 | # load environment variables from `.env` file if it exists 6 | # recursively searches for `.env` in all folders starting from work dir 7 | dotenv.load_dotenv(override=True) 8 | 9 | 10 | @hydra.main(config_path="configs/", config_name="test.yaml") 11 | def main(config: DictConfig): 12 | 13 | # Imports can be nested inside @hydra.main to optimize tab completion 14 | # https://github.com/facebookresearch/hydra/issues/934 15 | from src import utils 16 | from src.testing_pipeline import test 17 | 18 | # Applies optional utilities 19 | utils.extras(config) 20 | 21 | # Evaluate model 22 | return test(config) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/tests/__init__.py -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/module_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from importlib.util import find_spec 3 | 4 | """ 5 | Adapted from: 6 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py 7 | """ 8 | 9 | 10 | def _module_available(module_path: str) -> bool: 11 | """Check if a path is available in your environment. 12 | 13 | >>> _module_available('os') 14 | True 15 | >>> _module_available('bla.bla') 16 | False 17 | """ 18 | try: 19 | return find_spec(module_path) is not None 20 | except ModuleNotFoundError: 21 | # Python 3.7+ 22 | return False 23 | 24 | 25 | _IS_WINDOWS = platform.system() == "Windows" 26 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available("deepspeed") 27 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") 28 | _RPC_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.rpc") 29 | -------------------------------------------------------------------------------- /tests/helpers/run_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import sh 5 | 6 | 7 | def run_command(command: List[str]): 8 | """Default method for executing shell commands with pytest.""" 9 | msg = None 10 | try: 11 | sh.python(command) 12 | except sh.ErrorReturnCode as e: 13 | msg = e.stderr.decode() 14 | if msg: 15 | pytest.fail(msg=msg) 16 | -------------------------------------------------------------------------------- /tests/helpers/runif.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional 3 | 4 | import pytest 5 | import torch 6 | from packaging.version import Version 7 | from pkg_resources import get_distribution 8 | 9 | """ 10 | Adapted from: 11 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 12 | """ 13 | 14 | from tests.helpers.module_available import ( 15 | _DEEPSPEED_AVAILABLE, 16 | _FAIRSCALE_AVAILABLE, 17 | _IS_WINDOWS, 18 | _RPC_AVAILABLE, 19 | ) 20 | 21 | 22 | class RunIf: 23 | """RunIf wrapper for conditional skipping of tests. 24 | 25 | Fully compatible with `@pytest.mark`. 26 | 27 | Example: 28 | 29 | @RunIf(min_torch="1.8") 30 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 31 | def test_wrapper(arg1): 32 | assert arg1 > 0 33 | """ 34 | 35 | def __new__( 36 | self, 37 | min_gpus: int = 0, 38 | min_torch: Optional[str] = None, 39 | max_torch: Optional[str] = None, 40 | min_python: Optional[str] = None, 41 | skip_windows: bool = False, 42 | rpc: bool = False, 43 | fairscale: bool = False, 44 | deepspeed: bool = False, 45 | **kwargs, 46 | ): 47 | """ 48 | Args: 49 | min_gpus: min number of gpus required to run test 50 | min_torch: minimum pytorch version to run test 51 | max_torch: maximum pytorch version to run test 52 | min_python: minimum python version required to run test 53 | skip_windows: skip test for Windows platform 54 | rpc: requires Remote Procedure Call (RPC) 55 | fairscale: if `fairscale` module is required to run the test 56 | deepspeed: if `deepspeed` module is required to run the test 57 | kwargs: native pytest.mark.skipif keyword arguments 58 | """ 59 | conditions = [] 60 | reasons = [] 61 | 62 | if min_gpus: 63 | conditions.append(torch.cuda.device_count() < min_gpus) 64 | reasons.append(f"GPUs>={min_gpus}") 65 | 66 | if min_torch: 67 | torch_version = get_distribution("torch").version 68 | conditions.append(Version(torch_version) < Version(min_torch)) 69 | reasons.append(f"torch>={min_torch}") 70 | 71 | if max_torch: 72 | torch_version = get_distribution("torch").version 73 | conditions.append(Version(torch_version) >= Version(max_torch)) 74 | reasons.append(f"torch<{max_torch}") 75 | 76 | if min_python: 77 | py_version = ( 78 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 79 | ) 80 | conditions.append(Version(py_version) < Version(min_python)) 81 | reasons.append(f"python>={min_python}") 82 | 83 | if skip_windows: 84 | conditions.append(_IS_WINDOWS) 85 | reasons.append("does not run on Windows") 86 | 87 | if rpc: 88 | conditions.append(not _RPC_AVAILABLE) 89 | reasons.append("RPC") 90 | 91 | if fairscale: 92 | conditions.append(not _FAIRSCALE_AVAILABLE) 93 | reasons.append("Fairscale") 94 | 95 | if deepspeed: 96 | conditions.append(not _DEEPSPEED_AVAILABLE) 97 | reasons.append("Deepspeed") 98 | 99 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 100 | return pytest.mark.skipif( 101 | condition=any(conditions), 102 | reason=f"Requires: [{' + '.join(reasons)}]", 103 | **kwargs, 104 | ) 105 | -------------------------------------------------------------------------------- /tests/shell/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/tests/shell/__init__.py -------------------------------------------------------------------------------- /tests/shell/test_basic_commands.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | from tests.helpers.runif import RunIf 5 | 6 | """ 7 | A couple of sanity checks to make sure the model doesn't crash with different running options. 8 | """ 9 | 10 | 11 | def test_fast_dev_run(): 12 | """Test running for 1 train, val and test batch.""" 13 | command = ["train.py", "++trainer.fast_dev_run=true"] 14 | run_command(command) 15 | 16 | 17 | @pytest.mark.slow 18 | def test_cpu(): 19 | """Test running 1 epoch on CPU.""" 20 | command = ["train.py", "++trainer.max_epochs=1", "++trainer.gpus=0"] 21 | run_command(command) 22 | 23 | 24 | # use RunIf to skip execution of some tests, e.g. when no gpus are available 25 | @RunIf(min_gpus=1) 26 | @pytest.mark.slow 27 | def test_gpu(): 28 | """Test running 1 epoch on GPU.""" 29 | command = [ 30 | "train.py", 31 | "++trainer.max_epochs=1", 32 | "++trainer.gpus=1", 33 | ] 34 | run_command(command) 35 | 36 | 37 | @RunIf(min_gpus=1) 38 | @pytest.mark.slow 39 | def test_mixed_precision(): 40 | """Test running 1 epoch with pytorch native automatic mixed precision (AMP).""" 41 | command = [ 42 | "train.py", 43 | "++trainer.max_epochs=1", 44 | "++trainer.gpus=1", 45 | "++trainer.precision=16", 46 | ] 47 | run_command(command) 48 | 49 | 50 | @pytest.mark.slow 51 | def test_double_validation_loop(): 52 | """Test running 1 epoch with validation loop twice per epoch.""" 53 | command = [ 54 | "train.py", 55 | "++trainer.max_epochs=1", 56 | "++trainer.val_check_interval=0.5", 57 | ] 58 | run_command(command) 59 | -------------------------------------------------------------------------------- /tests/shell/test_debug_configs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | 6 | @pytest.mark.slow 7 | def test_debug_default(): 8 | command = ["train.py", "debug=default"] 9 | run_command(command) 10 | 11 | 12 | def test_debug_limit_batches(): 13 | command = ["train.py", "debug=limit_batches"] 14 | run_command(command) 15 | 16 | 17 | def test_debug_overfit(): 18 | command = ["train.py", "debug=overfit"] 19 | run_command(command) 20 | 21 | 22 | @pytest.mark.slow 23 | def test_debug_profiler(): 24 | command = ["train.py", "debug=profiler"] 25 | run_command(command) 26 | 27 | 28 | def test_debug_step(): 29 | command = ["train.py", "debug=step"] 30 | run_command(command) 31 | 32 | 33 | def test_debug_test_only(): 34 | command = ["train.py", "debug=test_only"] 35 | run_command(command) 36 | -------------------------------------------------------------------------------- /tests/shell/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | """ 6 | A couple of tests executing hydra sweeps. 7 | 8 | Use the following command to skip slow tests: 9 | pytest -k "not slow" 10 | """ 11 | 12 | 13 | @pytest.mark.slow 14 | def test_experiments(): 15 | """Test running all available experiment configs for 1 epoch.""" 16 | command = ["train.py", "-m", "experiment=glob(*)", "++trainer.max_epochs=1"] 17 | run_command(command) 18 | 19 | 20 | @pytest.mark.slow 21 | def test_default_sweep(): 22 | """Test default Hydra sweeper.""" 23 | command = [ 24 | "train.py", 25 | "-m", 26 | "datamodule.batch_size=64,128", 27 | "model.lr=0.01,0.02", 28 | "trainer=default", 29 | "++trainer.fast_dev_run=true", 30 | ] 31 | run_command(command) 32 | 33 | 34 | @pytest.mark.slow 35 | def test_optuna_sweep(): 36 | """Test Optuna sweeper.""" 37 | command = [ 38 | "train.py", 39 | "-m", 40 | "hparams_search=mnist_optuna", 41 | "trainer=default", 42 | "++trainer.fast_dev_run=true", 43 | ] 44 | run_command(command) 45 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VMBoehm/SAR-landslide-detection-pretraining/177799a5f22a1fab38cf0da76b98b3d32843a765/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_mnist_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | from src.datamodules.mnist_datamodule import MNISTDataModule 7 | 8 | 9 | @pytest.mark.parametrize("batch_size", [32, 128]) 10 | def test_mnist_datamodule(batch_size): 11 | datamodule = MNISTDataModule(batch_size=batch_size) 12 | datamodule.prepare_data() 13 | 14 | assert not datamodule.data_train and not datamodule.data_val and not datamodule.data_test 15 | 16 | assert os.path.exists(os.path.join("data", "MNIST")) 17 | assert os.path.exists(os.path.join("data", "MNIST", "raw")) 18 | 19 | datamodule.setup() 20 | 21 | assert datamodule.data_train and datamodule.data_val and datamodule.data_test 22 | assert ( 23 | len(datamodule.data_train) + len(datamodule.data_val) + len(datamodule.data_test) == 70_000 24 | ) 25 | 26 | assert datamodule.train_dataloader() 27 | assert datamodule.val_dataloader() 28 | assert datamodule.test_dataloader() 29 | 30 | batch = next(iter(datamodule.train_dataloader())) 31 | x, y = batch 32 | 33 | assert len(x) == batch_size 34 | assert len(y) == batch_size 35 | assert x.dtype == torch.float32 36 | assert y.dtype == torch.int64 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | from omegaconf import DictConfig 4 | 5 | # load environment variables from `.env` file if it exists 6 | # recursively searches for `.env` in all folders starting from work dir 7 | dotenv.load_dotenv(override=True) 8 | 9 | 10 | @hydra.main(config_path="configs/", config_name="train.yaml",version_base="1.1") 11 | def main(config: DictConfig): 12 | 13 | # Imports can be nested inside @hydra.main to optimize tab completion 14 | # https://github.com/facebookresearch/hydra/issues/934 15 | from src import utils 16 | from src.training_pipeline import train 17 | 18 | # Applies optional utilities 19 | utils.extras(config) 20 | 21 | # Train model 22 | return train(config) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | --------------------------------------------------------------------------------