├── dev_floodsml_Dockerfile ├── FloodsML ├── outpit_file.tif.aux.xml ├── install.txt ├── data_generators │ └── data_generator.py ├── merge_tifs.py ├── README.md ├── utils │ ├── metrics.py │ ├── lr_scheduler.py │ ├── datagen_utils.py │ └── training_utils.py ├── .gitignore ├── models │ ├── unet.py │ └── backbone │ │ └── xception.py ├── check_dimensions.py ├── flood_dataset.py ├── check_sentinel_values.py ├── train_unet.py ├── inference.py ├── sweep_review.ipynb ├── mg_test.ipynb ├── trainers │ └── trainer.py ├── 0_analyze_labels.py ├── 0_create_splits.py ├── optimize_parameters.py ├── optimize_parameters_extended.py └── train_inference_full.py ├── readme.md ├── postprocessing.py ├── stac_generator.py ├── .gitignore ├── HAND ├── compute_flood_map.py ├── mosaic_dem_tools.py └── compute_hand.py └── preprocessing.py /dev_floodsml_Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/osgeo/gdal:ubuntu-full-latest 2 | 3 | # Install latest pip 4 | RUN apt-get update && apt-get install -y python3-pip curl && apt-get clean \ 5 | && apt-get remove -y python3-pip \ 6 | && curl -sS https://bootstrap.pypa.io/get-pip.py -o get-pip.py \ 7 | && python3 get-pip.py --break-system-packages \ 8 | && rm get-pip.py 9 | 10 | # Copy requirements 11 | COPY requirements.txt . 
12 | 13 | # Install Python deps 14 | RUN pip3 install --break-system-packages --no-cache-dir -r requirements.txt --ignore-installed jsonschema 15 | 16 | # Set workdir 17 | WORKDIR /app 18 | 19 | # Default command 20 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /FloodsML/outpit_file.tif.aux.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -0.25 6 | 1.25 7 | 2 8 | 0 9 | 0 10 | 3181282|65954 11 | 12 | 13 | 14 | 0 15 | 1 16 | 0.020310812025981 17 | 0.14106127371049 18 | 100 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /FloodsML/install.txt: -------------------------------------------------------------------------------- 1 | https://visualstudio.microsoft.com/visual-cpp-build-tools/ 2 | 3 | mamba create -n poplave python=3.10 -y && mamba activate poplave 4 | mamba install pytorch torchvision torchaudio torchmetrics tqdm wandb jupyterlab pandas numpy matplotlib seaborn gdal pytorch-cuda=12.4 -c pytorch -c nvidia -y 5 | python -c "import torch; print(torch.cuda.is_available())" 6 | 7 | #transformers 8 | mamba create -n poplave_transformers python=3.12 -y 9 | mamba activate poplave_transformers 10 | mamba install conda-forge::transformers gdal 11 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 torchmetrics 12 | 13 | #uv 14 | uv venv poplave_transformers --python=3.12 && poplave_transformers\Scripts\uv pip install transformers gdal torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 torchmetrics 15 | 16 | uv venv poplave_transformers --python=3.12 17 | poplave_transformers\Scripts\activate 18 | uv pip install transformers 19 | uv pip install gdal rasterio 20 | 21 | cd Documents\poplave && activate poplave_transformers -------------------------------------------------------------------------------- 
/FloodsML/data_generators/data_generator.py: -------------------------------------------------------------------------------- 1 | #from data_generators.datasets import cityscapes, coco, combine_dbs, pascal, sbd, deepfashion 2 | from torch.utils.data import DataLoader 3 | from data_generators.deepfashion import DeepFashionSegmentation 4 | 5 | 6 | def initialize_data_loader(config): 7 | 8 | 9 | if config['dataset']['dataset_name'] == 'deepfashion': 10 | train_set = DeepFashionSegmentation(config, split='train') 11 | val_set = DeepFashionSegmentation(config, split='val') 12 | test_set = DeepFashionSegmentation(config, split='test') 13 | 14 | else: 15 | raise Exception('dataset not implemented yet!') 16 | 17 | num_classes = train_set.num_classes 18 | train_loader = DataLoader(train_set, batch_size=config['training']['batch_size'], shuffle=True, num_workers=config['training']['workers'], pin_memory=True) 19 | val_loader = DataLoader(val_set, batch_size=config['training']['batch_size'], shuffle=False, num_workers=config['training']['workers'], pin_memory=True) 20 | test_loader = DataLoader(test_set, batch_size=config['training']['batch_size'], shuffle=False, num_workers=config['training']['workers'], pin_memory=True) 21 | 22 | return train_loader, val_loader, test_loader, num_classes 23 | 24 | -------------------------------------------------------------------------------- /FloodsML/merge_tifs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import argparse 4 | from osgeo import gdal 5 | gdal.UseExceptions() 6 | 7 | def merge_tifs(input_folder, output_tif): 8 | # collect all .tif files from the input folder 9 | tif_files = [ 10 | os.path.join(input_folder, f) 11 | for f in os.listdir(input_folder) 12 | if f.lower().endswith(".tif") 13 | ] 14 | 15 | if not tif_files: 16 | print(f"⚠️ No .tif files found in folder: {input_folder}") 17 | return 18 | 19 | # merge into a single GeoTIFF 20 | 
gdal.Warp(output_tif, tif_files, format="GTiff") 21 | print(f"✅ Merged raster created: {output_tif}") 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser( 25 | description="Merge all GeoTIFF files from a folder into a single raster." 26 | ) 27 | parser.add_argument( 28 | "-i", "--input", 29 | required=True, 30 | help="Input folder containing GeoTIFF files" 31 | ) 32 | parser.add_argument( 33 | "-o", "--output", 34 | required=True, 35 | help="Output GeoTIFF file (e.g. merged.tif)" 36 | ) 37 | args = parser.parse_args() 38 | 39 | merge_tifs(args.input, args.output) 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /FloodsML/README.md: -------------------------------------------------------------------------------- 1 | # Flood area mapping using deep learning 2 | 3 | ## Virtual environment 4 | First you need to create a virtual environment. 5 | 6 | Using Conda you can type: 7 | 8 | ``` 9 | mamba create -n floods python=3.10 -y 10 | mamba activate floods 11 | mamba install pytorch torchvision torchaudio torchmetrics tqdm wandb psutils jupyterlab pandas numpy matplotlib seaborn gdal pytorch-cuda=12.4 -c pytorch -c nvidia -y 12 | ``` 13 | 14 | Check if cuda is available 15 | ``` 16 | python -c "import torch; print(torch.cuda.is_available())" 17 | ``` 18 | 19 | ## Files 20 | 21 | Files starting with 0_* were used to analyse the dataset and provide splits that had flood pixels. Files starting with check_* were used to ensure that the expected values were available and ensure correctness of the dataset. 22 | 23 | ## Training 24 | 25 | Files starting with optimize_* were used to determine optimal training parameters for the selected model and dataset. While the train_unet.py file was used to train a single model. With the sweep_review.ipynb providing insights into the top performing models. 
26 | 27 | ## Inference 28 | 29 | Model inference can be run using the inference.py script by providing input, output folder and model path. Additionally flag --device can be added to manually select the desired device. 30 | ``` 31 | python inference.py /path/to/input_tiffs /path/to/output_predictions /path/to/your/best_model.pth 32 | ``` 33 | -------------------------------------------------------------------------------- /FloodsML/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Evaluator(object): 5 | def __init__(self, num_class): 6 | self.num_class = num_class 7 | self.confusion_matrix = np.zeros((self.num_class,)*2) 8 | 9 | def Pixel_Accuracy(self): 10 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 11 | return Acc 12 | 13 | def Pixel_Accuracy_Class(self): 14 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 15 | Acc = np.nanmean(Acc) 16 | return Acc 17 | 18 | def Mean_Intersection_over_Union(self): 19 | MIoU = np.diag(self.confusion_matrix) / ( 20 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 21 | np.diag(self.confusion_matrix)) 22 | MIoU = np.nanmean(MIoU) 23 | return MIoU 24 | 25 | def Frequency_Weighted_Intersection_over_Union(self): 26 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 27 | iu = np.diag(self.confusion_matrix) / ( 28 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 29 | np.diag(self.confusion_matrix)) 30 | 31 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 32 | return FWIoU 33 | 34 | def _generate_matrix(self, gt_image, pre_image): 35 | mask = (gt_image >= 0) & (gt_image < self.num_class) 36 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 37 | count = np.bincount(label, minlength=self.num_class**2) 38 | confusion_matrix = count.reshape(self.num_class, self.num_class) 39 | return 
confusion_matrix 40 | 41 | def add_batch(self, gt_image, pre_image): 42 | assert gt_image.shape == pre_image.shape 43 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 44 | 45 | def reset(self): 46 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 47 | -------------------------------------------------------------------------------- /FloodsML/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import math 12 | 13 | class LR_Scheduler(object): 14 | """Learning Rate Scheduler 15 | 16 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 17 | 18 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 19 | 20 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 21 | 22 | Args: 23 | args: 24 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 25 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 26 | :attr:`args.lr_step` 27 | 28 | iters_per_epoch: number of iterations per epoch 29 | """ 30 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 31 | lr_step=0, warmup_epochs=0): 32 | self.mode = mode 33 | print('Using {} LR Scheduler!'.format(self.mode)) 34 | self.lr = base_lr 35 | if mode == 'step': 36 | assert lr_step 37 | self.lr_step = lr_step 38 | self.iters_per_epoch = iters_per_epoch 39 | self.N = num_epochs * iters_per_epoch 40 | self.epoch = -1 41 | self.warmup_iters = warmup_epochs * iters_per_epoch 42 | 43 | def __call__(self, optimizer, i, epoch, best_pred): 44 | T = epoch 
* self.iters_per_epoch + i 45 | if self.mode == 'cos': 46 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 47 | elif self.mode == 'poly': 48 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 49 | elif self.mode == 'step': 50 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 51 | else: 52 | raise NotImplemented 53 | # warm up lr schedule 54 | if self.warmup_iters > 0 and T < self.warmup_iters: 55 | lr = lr * 1.0 * T / self.warmup_iters 56 | if epoch > self.epoch: 57 | print('\n=>Epoches %i, learning rate = %.4f, \ 58 | previous best = %.4f' % (epoch, lr, best_pred)) 59 | self.epoch = epoch 60 | assert lr >= 0 61 | self._adjust_learning_rate(optimizer, lr) 62 | 63 | def _adjust_learning_rate(self, optimizer, lr): 64 | if len(optimizer.param_groups) == 1: 65 | optimizer.param_groups[0]['lr'] = lr 66 | else: 67 | # enlarge the lr at the head 68 | optimizer.param_groups[0]['lr'] = lr 69 | for i in range(1, len(optimizer.param_groups)): 70 | optimizer.param_groups[i]['lr'] = lr * 10 71 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # FloodsML – Flood Detection with Deep Learning 2 | 3 | FloodsML is a deep learning–based workflow for flood extent detection using Sentinel-1 SAR imagery and DEM data. The pipeline is packaged in Docker for reproducibility and portability, and it integrates seamlessly with STAC-compliant datasets. 4 | 5 | --- 6 | 7 | ## 🚀 Setup 8 | 9 | ### 1. Build Docker image 10 | From the project root directory, build the Docker image: 11 | 12 | ```bash 13 | docker build -t dev_floodsml -f dev_floodsml_Dockerfile . 14 | ``` 15 | 16 | ### 2. 
Run container with mounted input data 17 | Prepare an input folder on your local machine (e.g., `C:\Users\<username>\Desktop\Dev FloodsML\input_src`) and mount it into the container (the example below uses Windows `cmd` syntax with `^` line continuations; on Linux/macOS use `\` and a POSIX path instead): 18 | 19 | ```bat 20 | docker run -it --rm --name floods_demo_container ^ 21 | -v "C:\Users\<username>\Desktop\Dev FloodsML\input_src:/app/input_src" ^ 22 | dev_floodsml 23 | ``` 24 | 25 | --- 26 | 27 | ## 📥 Input Data 28 | 29 | The mounted folder **must contain**: 30 | - **Sentinel-1 before-flood acquisitions** (VV and VH) 31 | - **Sentinel-1 after-flood acquisitions** (VV and VH) 32 | - **DEM file(s)** (one or more, named `copdem.tif`) 33 | 34 | > ⚠️ All input data must follow the **STAC item structure** so that metadata, georeferencing, and provenance are preserved. 35 | 36 | --- 37 | 38 | ## ▶️ Running the model 39 | 40 | Inside the container, run: 41 | 42 | ```bash 43 | python /app/env/main.py \ 44 | --aoi_wkt "POLYGON((...))" \ 45 | [--hand] \ 46 | [--buffer 5000] \ 47 | [--treshold 5] 48 | ``` 49 | 50 | ### Parameters 51 | - `--aoi_wkt` (required): AOI polygon in WGS84 (EPSG:4326). Must be of type **POLYGON**. 52 | - `--hand` (optional): Apply HAND-based filtering in post-processing. 53 | - `--buffer` (optional, default: `5000`): Buffer distance for HAND calculation. 54 | - `--treshold` (optional, default: `5`): Threshold parameter for HAND filtering.
55 | 56 | --- 57 | 58 | ## 📤 Outputs 59 | 60 | Results are published to: 61 | 62 | ``` 63 | /app/tmp/results 64 | ``` 65 | 66 | If you want to make the results visible on the host system, also mount the results directory when starting the container (Windows `cmd` syntax shown), e.g.: 67 | 68 | ```bat 69 | docker run -it --rm --name floods_demo_container ^ 70 | -v "C:\Users\<username>\Desktop\Dev FloodsML\input_src:/app/input_src" ^ 71 | -v "C:\Users\<username>\Desktop\Dev FloodsML\results:/app/tmp/results" ^ 72 | dev_floodsml 73 | ``` 74 | 75 | The outputs are: 76 | - **Cloud-Optimized GeoTIFF (COG) masks of flooded areas** 77 | - Wrapped as **STAC items** for interoperability 78 | 79 | The masks contain three classes: 80 | - `0` – non-water 81 | - `1` – permanent hydrography (water already present before the event) 82 | - `2` – flooded areas 83 | 84 | This structure ensures compatibility with catalog-based workflows and easy integration with other disaster mapping products. 85 | 86 | --- 87 | 88 | ## 🔄 Workflow Overview 89 | 90 | **STAC input items → Preprocessing & Inference → STAC output items (COG masks)** 91 | 92 | The workflow ensures that both input and output remain fully standardized, enabling transparent use in broader EO platforms and services. 93 | 94 | --- 95 | 96 | 📌 Ready-to-use, portable, and interoperable — FloodsML can be deployed wherever flood mapping support is required.
97 | "# floodsdl-mcube" 98 | -------------------------------------------------------------------------------- /FloodsML/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | *.wandb 162 | *.json 163 | *.yaml 164 | *.pth 165 | *.tif 166 | /wandb 167 | *.csv 168 | /data 169 | /.vs 170 | /predictions 171 | /predictionsDice 172 | -------------------------------------------------------------------------------- /postprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import rasterio 4 | from rasterio.warp import reproject, Resampling 5 | from scipy.ndimage import median_filter 6 | from skimage.morphology import (remove_small_objects, remove_small_holes, disk, opening, closing) 7 | import numpy as np 8 | 9 | def process_raster(in_path, out_path=None, 10 | min_obj_size=4, min_hole_size=4, min_region_size=20): 11 | """ 12 | Load a raster, clean it, and save the mask as GeoTIFF. 13 | Returns: path to the output raster. 14 | """ 15 | # Read input raster 16 | with rasterio.open(in_path) as src: 17 | arr = src.read(1, out_dtype='uint8') 18 | profile = src.profile 19 | 20 | # Convert to binary mask (0/1) 21 | arr = (arr > 0).astype(np.uint8) 22 | 23 | # Apply cleaning steps 24 | arr = remove_small_objects(arr.astype(bool), min_size=min_obj_size) 25 | arr = remove_small_holes(arr, area_threshold=min_hole_size) 26 | arr = median_filter(arr.astype(np.uint8), size=2) 27 | arr = closing(opening(arr, disk(1)), disk(1)) 28 | arr = remove_small_objects(arr.astype(bool), min_size=min_region_size) 29 | 30 | arr = arr.astype(np.uint8) 31 | 32 | # Set default output path if not provided 33 | if out_path is None: 34 | out_path = os.path.splitext(in_path)[0] + "_clean.tif" 35 | 36 | # Save cleaned raster 37 | profile.update(dtype=rasterio.uint8, count=1, compress='lzw') 38 | with rasterio.open(out_path, 'w', **profile) as dst: 39 | dst.write(arr, 1) 40 | 41 | # Notify user 42 | print(f"✅ Cleaned raster successfully exported: 
{out_path}") 43 | 44 | return out_path 45 | 46 | 47 | 48 | def classify_flood(before_path, after_path, output_path): 49 | """ 50 | Create a 0/1/2 classification layer (aligned to AFTER raster grid): 51 | 0 = no change / no water 52 | 1 = water present before event 53 | 2 = new flooded area (after - before > 0) 54 | 55 | :param before_path: path to 'before' mask (.tif) 56 | :param after_path: path to 'after' mask (.tif) 57 | :param output_path: path to save output raster 58 | :return: output_path 59 | """ 60 | 61 | with rasterio.open(after_path) as src_after, rasterio.open(before_path) as src_before: 62 | # Read AFTER array (this is the reference grid) 63 | arr_after = src_after.read(1).astype(int) 64 | 65 | # privzeto uporabimo original BEFORE array 66 | arr_before_aligned = src_before.read(1).astype(int) 67 | 68 | # če se dimenzije ne ujemajo -> reproject 69 | if (src_before.width != src_after.width) or (src_before.height != src_after.height): 70 | print("⚠️ BEFORE and AFTER rasters differ in size. 
Reprojecting BEFORE → AFTER grid.") 71 | 72 | arr_before_aligned = np.empty_like(arr_after, dtype=np.int32) 73 | 74 | reproject( 75 | source=rasterio.band(src_before, 1), 76 | destination=arr_before_aligned, 77 | src_transform=src_before.transform, 78 | src_crs=src_before.crs, 79 | dst_transform=src_after.transform, 80 | dst_crs=src_after.crs, 81 | resampling=Resampling.nearest 82 | ) 83 | 84 | # Initialize classification with 0 (no change) 85 | classification = np.zeros_like(arr_after, dtype=np.uint8) 86 | 87 | # Areas that were already water before 88 | classification[arr_before_aligned > 0] = 1 89 | 90 | # New flooded areas (after but not before) 91 | classification[(arr_before_aligned == 0) & (arr_after > 0)] = 2 92 | 93 | # Prepare metadata (use AFTER as reference) 94 | out_meta = src_after.meta.copy() 95 | out_meta.update(dtype=rasterio.uint8, count=1) 96 | 97 | # Write output raster 98 | with rasterio.open(output_path, "w", **out_meta) as dst: 99 | dst.write(classification, 1) 100 | 101 | print(f"✅ Flood classification saved to {output_path}") 102 | return output_path -------------------------------------------------------------------------------- /stac_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | import rasterio 4 | import os 5 | 6 | def build_stac_item( 7 | tif_path: str, 8 | out_path: str, 9 | title: str, 10 | description: str, 11 | ): 12 | """ 13 | Build a STAC Item JSON for a flood mask raster. 14 | 15 | Args: 16 | tif_path (str): Path to input GeoTIFF (COG). 17 | out_path (str): Path to save STAC JSON. 18 | title (str): Human-readable title. 19 | description (str): Description of the dataset. 
20 | """ 21 | 22 | with rasterio.open(tif_path) as src: 23 | # bounding box + geometry 24 | bounds = src.bounds 25 | bbox = [bounds.left, bounds.bottom, bounds.right, bounds.top] 26 | geometry = { 27 | "type": "Polygon", 28 | "coordinates": [[ 29 | [bounds.left, bounds.bottom], 30 | [bounds.right, bounds.bottom], 31 | [bounds.right, bounds.top], 32 | [bounds.left, bounds.top], 33 | [bounds.left, bounds.bottom] 34 | ]] 35 | } 36 | 37 | # spatial resolution (assume square pixels) 38 | res_x, res_y = src.res 39 | spatial_resolution = round(abs(res_x), 2) 40 | 41 | # CRS 42 | proj_epsg = None 43 | if src.crs and src.crs.to_epsg(): 44 | proj_epsg = src.crs.to_epsg() 45 | 46 | # timestamps 47 | created = datetime.utcnow().isoformat() + "Z" 48 | base = os.path.splitext(os.path.basename(tif_path))[0] 49 | item_id = f"{base}_{datetime.now().strftime('%Y%m%d%H%M%S')}" 50 | 51 | # file size in MB 52 | file_size_mb = round(os.path.getsize(tif_path) / (1024 * 1024), 6) 53 | 54 | # classification legend 55 | classes = [ 56 | { 57 | "name": "Permanent Water", 58 | "title": "Permanent Water", 59 | "value": 1, 60 | "color_hint": "#1f78b499", 61 | "description": "Areas consistently covered with water." 62 | }, 63 | { 64 | "name": "Flooded Areas", 65 | "title": "Flooded Areas", 66 | "value": 2, 67 | "color_hint": "#e31a1c99", 68 | "description": "Detected flooded areas." 
69 | } 70 | ] 71 | 72 | stac_item = { 73 | "type": "Feature", 74 | "stac_version": "1.0.0", 75 | "id": item_id, 76 | "geometry": geometry, 77 | "bbox": bbox, 78 | "properties": { 79 | "title": title, 80 | "description": description, 81 | "created": created, 82 | "datetime": created, # same as created 83 | "renders": { 84 | "overview": { 85 | "title": title, 86 | "assets": ["flood"], 87 | "resampling": "nearest", 88 | "colormap": { 89 | "1": "#1f78b499", 90 | "2": "#e31a1c99" 91 | }, 92 | "nodata": 0 93 | } 94 | } 95 | }, 96 | "assets": { 97 | "flood": { 98 | "href": tif_path, 99 | "type": "image/tiff; application=geotiff", 100 | "title": f"{title} raster", 101 | "raster:bands": [ 102 | { 103 | "data_type": "uint8", 104 | "spatial_resolution": spatial_resolution, 105 | "scale": 1, 106 | "offset": 0 107 | } 108 | ], 109 | "classification:classes": classes, 110 | "roles": ["data", "visual"], 111 | "file:size": file_size_mb 112 | } 113 | } 114 | } 115 | 116 | # add proj:epsg if available 117 | if proj_epsg: 118 | stac_item["properties"]["proj:epsg"] = proj_epsg 119 | 120 | with open(out_path, "w", encoding="utf-8") as f: 121 | json.dump(stac_item, f, indent=2) 122 | 123 | print(f"✅ STAC item saved to {out_path}") 124 | -------------------------------------------------------------------------------- /FloodsML/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class DoubleConv(nn.Module): 6 | def __init__(self, in_channels, out_channels, mid_channels=None, dropout_rate=0.0): 7 | super().__init__() 8 | if not mid_channels: 9 | mid_channels = out_channels 10 | self.double_conv = nn.Sequential( 11 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 12 | nn.BatchNorm2d(mid_channels), 13 | nn.ReLU(inplace=True), 14 | nn.Dropout2d(p=dropout_rate), 15 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, 
bias=False), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Dropout2d(p=dropout_rate) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.double_conv(x) 23 | 24 | class Down(nn.Module): 25 | def __init__(self, in_channels, out_channels, dropout_rate=0.0): 26 | super().__init__() 27 | self.maxpool_conv = nn.Sequential( 28 | nn.MaxPool2d(2), 29 | DoubleConv(in_channels, out_channels, dropout_rate=dropout_rate) 30 | ) 31 | 32 | def forward(self, x): 33 | return self.maxpool_conv(x) 34 | 35 | class Up(nn.Module): 36 | def __init__(self, in_channels, out_channels, bilinear=True, dropout_rate=0.0): 37 | super().__init__() 38 | 39 | if bilinear: 40 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 41 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2, dropout_rate=dropout_rate) 42 | else: 43 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 44 | self.conv = DoubleConv(in_channels, out_channels, dropout_rate=dropout_rate) 45 | 46 | def forward(self, x1, x2): 47 | x1 = self.up(x1) 48 | # Handle input sizes that aren't perfectly divisible by 2 49 | diff_y = x2.size()[2] - x1.size()[2] 50 | diff_x = x2.size()[3] - x1.size()[3] 51 | 52 | x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2, 53 | diff_y // 2, diff_y - diff_y // 2]) 54 | x = torch.cat([x2, x1], dim=1) 55 | return self.conv(x) 56 | 57 | class OutConv(nn.Module): 58 | def __init__(self, in_channels, out_channels): 59 | super(OutConv, self).__init__() 60 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 61 | 62 | def forward(self, x): 63 | return self.conv(x) 64 | 65 | class UNet(nn.Module): 66 | def __init__(self, n_channels=3, n_classes=1, bilinear=True, dropout_rate=0.0): 67 | """ 68 | Args: 69 | n_channels (int): Number of input channels (3 for RGB) 70 | n_classes (int): Number of output classes (1 for binary segmentation) 71 | bilinear (bool): Use bilinear upsampling instead of transposed 
convolutions 72 | dropout_rate (float): Dropout rate for regularization (0.0 to 1.0) 73 | """ 74 | super(UNet, self).__init__() 75 | self.n_channels = n_channels 76 | self.n_classes = n_classes 77 | self.bilinear = bilinear 78 | self.dropout_rate = dropout_rate 79 | 80 | factor = 2 if bilinear else 1 81 | 82 | self.inc = DoubleConv(n_channels, 64, dropout_rate=dropout_rate) 83 | self.down1 = Down(64, 128, dropout_rate=dropout_rate) 84 | self.down2 = Down(128, 256, dropout_rate=dropout_rate) 85 | self.down3 = Down(256, 512, dropout_rate=dropout_rate) 86 | self.down4 = Down(512, 1024 // factor, dropout_rate=dropout_rate) 87 | 88 | self.up1 = Up(1024, 512 // factor, bilinear, dropout_rate=dropout_rate) 89 | self.up2 = Up(512, 256 // factor, bilinear, dropout_rate=dropout_rate) 90 | self.up3 = Up(256, 128 // factor, bilinear, dropout_rate=dropout_rate) 91 | self.up4 = Up(128, 64, bilinear, dropout_rate=dropout_rate) 92 | 93 | self.outc = OutConv(64, n_classes) 94 | 95 | def forward(self, x): 96 | x1 = self.inc(x) 97 | x2 = self.down1(x1) 98 | x3 = self.down2(x2) 99 | x4 = self.down3(x3) 100 | x5 = self.down4(x4) 101 | 102 | x = self.up1(x5, x4) 103 | x = self.up2(x, x3) 104 | x = self.up3(x, x2) 105 | x = self.up4(x, x1) 106 | 107 | logits = self.outc(x) 108 | return logits -------------------------------------------------------------------------------- /FloodsML/check_dimensions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from osgeo import gdal 4 | from tqdm import tqdm 5 | from collections import defaultdict 6 | 7 | def check_dimensions(data_dir='data'): 8 | """Check if all files in the dataset are 512x512 pixels and their value ranges""" 9 | data_dir = Path(data_dir) 10 | labels_dir = data_dir / '512_labels' 11 | 12 | if not labels_dir.exists(): 13 | print(f"Error: Labels directory not found at {labels_dir}") 14 | return 15 | 16 | # Statistics storage 17 | stats = { 18 | 
'total_files': 0, 19 | 'correct_size': 0, 20 | 'incorrect_size': 0, 21 | 'errors': [], 22 | 'incorrect_files': defaultdict(list), # Store files with incorrect dimensions 23 | 'value_ranges': { 24 | 'min': float('inf'), 25 | 'max': float('-inf'), 26 | 'unique_values': set() 27 | } 28 | } 29 | 30 | # Count total EMSR folders 31 | emsr_folders = [f for f in labels_dir.iterdir() if f.is_dir()] 32 | 33 | # Iterate through all EMSR folders 34 | for emsr_folder in tqdm(emsr_folders, desc="Processing EMSR folders"): 35 | # Get all FloodMask files 36 | flood_files = list(emsr_folder.glob('*_FloodMask.tif')) 37 | if not flood_files: 38 | stats['errors'].append(f"No FloodMask files found in {emsr_folder.name}") 39 | continue 40 | 41 | for flood_file in flood_files: 42 | stats['total_files'] += 1 43 | 44 | # Read label 45 | try: 46 | ds = gdal.Open(str(flood_file)) 47 | if ds is None: 48 | stats['errors'].append(f"Could not open {flood_file}") 49 | continue 50 | 51 | # Get dimensions 52 | width = ds.RasterXSize 53 | height = ds.RasterYSize 54 | 55 | # Read data and check values 56 | band = ds.GetRasterBand(1) 57 | data = band.ReadAsArray() 58 | 59 | # Update min and max values 60 | stats['value_ranges']['min'] = min(stats['value_ranges']['min'], np.min(data)) 61 | stats['value_ranges']['max'] = max(stats['value_ranges']['max'], np.max(data)) 62 | stats['value_ranges']['unique_values'].update(np.unique(data)) 63 | 64 | ds = None 65 | 66 | # Check dimensions 67 | if width == 512 and height == 512: 68 | stats['correct_size'] += 1 69 | else: 70 | stats['incorrect_size'] += 1 71 | stats['incorrect_files'][emsr_folder.name].append({ 72 | 'file': flood_file.name, 73 | 'width': width, 74 | 'height': height 75 | }) 76 | 77 | except Exception as e: 78 | stats['errors'].append(f"Error reading {flood_file}: {str(e)}") 79 | continue 80 | 81 | # Print summary 82 | print("\n=== File Dimension Check Summary ===") 83 | print(f"Total files checked: {stats['total_files']}") 84 | print(f"Files 
with correct dimensions (512x512): {stats['correct_size']}") 85 | print(f"Files with incorrect dimensions: {stats['incorrect_size']}") 86 | 87 | print("\n=== Value Range Summary ===") 88 | print(f"Minimum value: {stats['value_ranges']['min']}") 89 | print(f"Maximum value: {stats['value_ranges']['max']}") 90 | print(f"Unique values found: {sorted(list(stats['value_ranges']['unique_values']))}") 91 | 92 | if stats['incorrect_size'] > 0: 93 | print("\n=== Files with Incorrect Dimensions ===") 94 | for emsr, files in stats['incorrect_files'].items(): 95 | print(f"\nEMSR: {emsr}") 96 | for file_info in files: 97 | print(f" - {file_info['file']}: {file_info['width']}x{file_info['height']}") 98 | 99 | if stats['errors']: 100 | print("\n=== Errors and Warnings ===") 101 | for error in stats['errors']: 102 | print(f"- {error}") 103 | 104 | if stats['total_files'] == 0: 105 | print("\nNo files found! Please check:") 106 | print("1. The directory structure is correct") 107 | print("2. The files are valid GeoTIFFs") 108 | return 109 | 110 | if __name__ == "__main__": 111 | check_dimensions() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,git 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,git 3 | 4 | ### Git ### 5 | # Created by git for backups. 
To disable backups in Git: 6 | # $ git config --global mergetool.keepBackup false 7 | *.orig 8 | 9 | # Created by git when using merge tools for conflicts 10 | *.BACKUP.* 11 | *.BASE.* 12 | *.LOCAL.* 13 | *.REMOTE.* 14 | *_BACKUP_*.txt 15 | *_BASE_*.txt 16 | *_LOCAL_*.txt 17 | *_REMOTE_*.txt 18 | 19 | ### Python ### 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | cover/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | .pybuilder/ 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | # For a library or package, you might want to ignore these files since the code is 106 | # intended to run in multiple environments; otherwise, check them in: 107 | # .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, 
it is recommended to include Pipfile.lock in version control. 111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | #Pipfile.lock 115 | 116 | # poetry 117 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 118 | # This is especially recommended for binary packages to ensure reproducibility, and is more 119 | # commonly ignored for libraries. 120 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 121 | #poetry.lock 122 | 123 | # pdm 124 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 125 | #pdm.lock 126 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 127 | # in version control. 128 | # https://pdm.fming.dev/#use-with-ide 129 | .pdm.toml 130 | 131 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 132 | __pypackages__/ 133 | 134 | # Celery stuff 135 | celerybeat-schedule 136 | celerybeat.pid 137 | 138 | # SageMath parsed files 139 | *.sage.py 140 | 141 | # Environments 142 | .env 143 | .venv 144 | env/ 145 | venv/ 146 | ENV/ 147 | env.bak/ 148 | venv.bak/ 149 | 150 | # Spyder project settings 151 | .spyderproject 152 | .spyproject 153 | 154 | # Rope project settings 155 | .ropeproject 156 | 157 | # mkdocs documentation 158 | /site 159 | 160 | # mypy 161 | .mypy_cache/ 162 | .dmypy.json 163 | dmypy.json 164 | 165 | # Pyre type checker 166 | .pyre/ 167 | 168 | # pytype static type analyzer 169 | .pytype/ 170 | 171 | # Cython debug symbols 172 | cython_debug/ 173 | 174 | # PyCharm 175 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 176 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 177 | # and can be added to the global gitignore or merged into this file. For a more nuclear 178 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
179 | #.idea/ 180 | 181 | ### Python Patch ### 182 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 183 | poetry.toml 184 | 185 | # ruff 186 | .ruff_cache/ 187 | 188 | # LSP config files 189 | pyrightconfig.json 190 | 191 | # End of https://www.toptal.com/developers/gitignore/api/python,git -------------------------------------------------------------------------------- /HAND/compute_flood_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import rasterio 5 | from skimage.morphology import remove_small_objects, remove_small_holes, opening, closing, disk 6 | from scipy.ndimage import median_filter 7 | from rasterio.features import shapes 8 | import geopandas as gpd 9 | from shapely.geometry import shape 10 | 11 | def clean_binary_raster(in_fn, band=1, min_obj_size=4, min_hole_size=4, min_region_size=20, selem_radius=1, median_size=3): 12 | with rasterio.open(in_fn) as src: 13 | mask = src.read(band, out_dtype="uint8") 14 | profile = src.profile.copy() 15 | mask = (mask > 0) 16 | mask = remove_small_objects(mask, min_size=min_obj_size) 17 | mask = remove_small_holes(mask, area_threshold=min_hole_size) 18 | mask = median_filter(mask.astype(np.uint8), size=median_size) > 0 19 | if selem_radius > 0: 20 | selem = disk(selem_radius) 21 | mask = opening(mask, selem) 22 | mask = closing(mask, selem) 23 | mask = remove_small_objects(mask, min_size=min_region_size) 24 | mask_u8 = mask.astype(np.uint8) 25 | profile.update(dtype="uint8", count=1, nodata=None) 26 | return mask_u8, profile 27 | 28 | def compute_flood_map(s1_before_fn, s1_after_fn, hand_fn, out_flood_map_fn, out_vector_fn, valley_thresh=5.0): 29 | # Filter S1 before 30 | s1_before_cleaned, s1_before_profile = clean_binary_raster(s1_before_fn) 31 | # Filter S1 after 32 | s1_after_cleaned, s1_after_profile = clean_binary_raster(s1_after_fn) 33 | # Compute valley mask 
if HAND is provided 34 | valley_mask = None 35 | if hand_fn is not None and os.path.exists(hand_fn): 36 | with rasterio.open(hand_fn) as src: 37 | hand = src.read(1) 38 | h0 = valley_thresh # meters 39 | hand_filtered = median_filter(hand, size=3) 40 | valley_mask = hand_filtered <= h0 41 | valley_mask = remove_small_holes(valley_mask, area_threshold=64) 42 | valley_mask = remove_small_objects(valley_mask, min_size=64) 43 | # Compute flood map 44 | s1_diff = s1_after_cleaned - s1_before_cleaned 45 | meta = s1_before_profile.copy() 46 | meta.update(dtype="uint8", count=1, nodata=None) 47 | flood_map = np.zeros_like(s1_diff, dtype=np.uint8) 48 | # Permanent water (before) in valley or everywhere 49 | if valley_mask is not None: 50 | flood_map[(s1_before_cleaned == 1) & (valley_mask)] = 1 51 | flood_map[(s1_diff == 1) & (valley_mask)] = 2 52 | else: 53 | flood_map[(s1_before_cleaned == 1)] = 1 54 | flood_map[(s1_diff == 1)] = 2 55 | with rasterio.open(out_flood_map_fn, 'w', **meta) as dst: 56 | dst.write(flood_map, 1) 57 | # Vectorize flood extent (only value 2) 58 | with rasterio.open(out_flood_map_fn) as src: 59 | flood_data = src.read(1) 60 | transform = src.transform 61 | crs = src.crs 62 | flood_mask = flood_data == 2 63 | shapes_gen = shapes(flood_data, mask=flood_mask, transform=transform) 64 | geoms = [] 65 | flooded = [] 66 | for geom, value in shapes_gen: 67 | if value == 2: 68 | geoms.append(shape(geom)) 69 | flooded.append(1) 70 | gdf = gpd.GeoDataFrame({'Flooded': flooded}, geometry=geoms, crs=crs) 71 | num_polygons = len(gdf) 72 | total_area = gdf.geometry.area.sum() 73 | print(f"Number of polygons: {num_polygons}") 74 | gdf.to_file(out_vector_fn, driver='GPKG') 75 | 76 | if __name__ == "__main__": 77 | import argparse 78 | import time 79 | parser = argparse.ArgumentParser(description="Compute flood map and vectorize extent.") 80 | parser.add_argument('--s1_before', type=str, required=True, help='Path to S1 before flood raster') 81 | 
parser.add_argument('--s1_after', type=str, required=True, help='Path to S1 after flood raster') 82 | parser.add_argument('--hand', type=str, required=False, default=None, help='Path to HAND raster (optional)') 83 | parser.add_argument('--out_flood_map', type=str, required=True, help='Output flood map raster') 84 | parser.add_argument('--out_vector', type=str, required=True, help='Output vector file (GeoPackage)') 85 | parser.add_argument('--valley_thresh', type=float, default=5.0, help='Valley threshold for HAND (default: 5)') 86 | args = parser.parse_args() 87 | 88 | # Start processing 89 | print('Starting flood map computation...') 90 | t0 = time.time() 91 | compute_flood_map( 92 | args.s1_before, 93 | args.s1_after, 94 | args.hand, 95 | args.out_flood_map, 96 | args.out_vector, 97 | valley_thresh=args.valley_thresh 98 | ) 99 | print('Flood map computation done.') 100 | print(f"Output flood map: {args.out_flood_map}") 101 | print(f"Output vector file: {args.out_vector}") 102 | t1 = time.time() 103 | print(f"Total elapsed: {t1-t0:.2f} s") 104 | print('Done.') -------------------------------------------------------------------------------- /FloodsML/utils/datagen_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 6 | rgb_masks = [] 7 | for label_mask in label_masks: 8 | rgb_mask = decode_segmap(label_mask, dataset) 9 | rgb_masks.append(rgb_mask) 10 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 11 | return rgb_masks 12 | 13 | 14 | def decode_segmap(label_mask, dataset, plot=False): 15 | """Decode segmentation class labels into a color image 16 | Args: 17 | label_mask (np.ndarray): an (M,N) array of integer values denoting 18 | the class label at each spatial location. 19 | plot (bool, optional): whether to show the resulting color image 20 | in a figure. 
21 | Returns: 22 | (np.ndarray, optional): the resulting decoded color image. 23 | """ 24 | if dataset == 'pascal' or dataset == 'coco': 25 | n_classes = 21 26 | label_colours = get_pascal_labels() 27 | elif dataset == 'cityscapes': 28 | n_classes = 19 29 | label_colours = get_cityscapes_labels() 30 | 31 | elif dataset == 'deepfashion': 32 | n_classes = 13 33 | label_colours = get_deepfashion_labels() 34 | 35 | elif dataset == 'braintumor': 36 | n_classes = 3 37 | label_colours = get_braintumor_labels() 38 | 39 | 40 | else: 41 | raise NotImplementedError 42 | 43 | r = label_mask.copy() 44 | g = label_mask.copy() 45 | b = label_mask.copy() 46 | for ll in range(0, n_classes): 47 | r[label_mask == ll] = label_colours[ll, 0] 48 | g[label_mask == ll] = label_colours[ll, 1] 49 | b[label_mask == ll] = label_colours[ll, 2] 50 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 51 | rgb[:, :, 0] = r / 255.0 52 | rgb[:, :, 1] = g / 255.0 53 | rgb[:, :, 2] = b / 255.0 54 | if plot: 55 | plt.imshow(rgb) 56 | plt.show() 57 | else: 58 | return rgb 59 | 60 | 61 | def encode_segmap(mask): 62 | """Encode segmentation label images as pascal classes 63 | Args: 64 | mask (np.ndarray): raw segmentation label image of dimension 65 | (M, N, 3), in which the Pascal classes are encoded as colours. 66 | Returns: 67 | (np.ndarray): class map with dimensions (M,N), where the value at 68 | a given location is the integer denoting the class index. 
69 | """ 70 | mask = mask.astype(int) 71 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 72 | for ii, label in enumerate(get_pascal_labels()): 73 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 74 | label_mask = label_mask.astype(int) 75 | return label_mask 76 | 77 | 78 | def get_cityscapes_labels(): 79 | return np.array([ 80 | [128, 64, 128], 81 | [244, 35, 232], 82 | [70, 70, 70], 83 | [102, 102, 156], 84 | [190, 153, 153], 85 | [153, 153, 153], 86 | [250, 170, 30], 87 | [220, 220, 0], 88 | [107, 142, 35], 89 | [152, 251, 152], 90 | [0, 130, 180], 91 | [220, 20, 60], 92 | [255, 0, 0], 93 | [0, 0, 142], 94 | [0, 0, 70], 95 | [0, 60, 100], 96 | [0, 80, 100], 97 | [0, 0, 230], 98 | [119, 11, 32]]) 99 | 100 | 101 | def get_pascal_labels(): 102 | """Load the mapping that associates pascal classes with label colors 103 | Returns: 104 | np.ndarray with dimensions (21, 3) 105 | """ 106 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 107 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 108 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 109 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 110 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 111 | [0, 64, 128]]) 112 | 113 | 114 | def get_deepfashion_labels(): 115 | return np.array([ 116 | [128, 64, 128], 117 | [244, 35, 232], 118 | [70, 70, 70], 119 | [102, 102, 156], 120 | [190, 153, 153], 121 | [153, 153, 153], 122 | [250, 170, 30], 123 | [220, 220, 0], 124 | [107, 142, 35], 125 | [152, 251, 152], 126 | [0, 130, 180], 127 | [220, 20, 60], 128 | [255, 0, 0], 129 | [0, 0, 142], 130 | [0, 0, 70], 131 | [0, 60, 100], 132 | [0, 80, 100], 133 | [0, 0, 230], 134 | [119, 11, 32]]) 135 | 136 | def get_braintumor_labels(): 137 | return np.array([ 138 | [128, 64, 128], 139 | [244, 35, 232], 140 | [70, 70, 70]]) 141 | 142 | 143 | 144 | def denormalize_image(image): 145 | mean=(0.485, 0.456, 0.406) 146 | std=(0.229, 0.224, 
0.225) 147 | 148 | image *= std 149 | image += mean 150 | 151 | return image -------------------------------------------------------------------------------- /FloodsML/flood_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from osgeo import gdal 5 | import torch 6 | from torchvision import transforms 7 | 8 | class FloodSegmentationDataset(Dataset): 9 | def __init__(self, map_file, base_path, split='train', transform=None): 10 | """ 11 | Args: 12 | map_file (string): Path to the map file with VH|VV|DEM|label pairs 13 | base_path (string): Base path to the dataset 14 | split (string): 'train' or 'val' split 15 | transform (callable, optional): Optional transform to be applied 16 | """ 17 | self.base_path = base_path 18 | self.split = split 19 | self.transform = transform 20 | self.target_size = (512, 512) 21 | 22 | # VH channel statistics 23 | self.vh_mean = -11.2832 24 | self.vh_std = 5.6656 25 | 26 | # VV channel statistics 27 | self.vv_mean = -6.9761 28 | self.vv_std = 4.3096 29 | 30 | # Read the map file and verify dimensions 31 | self.image_label_pairs = [] 32 | with open(map_file, 'r') as f: 33 | for line in f.readlines(): 34 | if not line.strip(): 35 | continue 36 | 37 | vh_path, vv_path, dem_path, label_path = line.strip().split('|') 38 | 39 | # Verify dimensions of VH image 40 | vh_ds = gdal.Open(os.path.join(self.base_path, vh_path)) 41 | if vh_ds is None: 42 | continue 43 | 44 | if vh_ds.RasterXSize != 512 or vh_ds.RasterYSize != 512: 45 | vh_ds = None 46 | continue 47 | 48 | # Verify dimensions of VV image 49 | vv_ds = gdal.Open(os.path.join(self.base_path, vv_path)) 50 | if vv_ds is None: 51 | vh_ds = None 52 | continue 53 | 54 | if vv_ds.RasterXSize != 512 or vv_ds.RasterYSize != 512: 55 | vh_ds = None 56 | vv_ds = None 57 | continue 58 | 59 | # Verify dimensions of label 60 | label_ds = gdal.Open(os.path.join(self.base_path, 
label_path)) 61 | if label_ds is None: 62 | vh_ds = None 63 | vv_ds = None 64 | continue 65 | 66 | if label_ds.RasterXSize != 512 or label_ds.RasterYSize != 512: 67 | vh_ds = None 68 | vv_ds = None 69 | label_ds = None 70 | continue 71 | 72 | # If we get here, all images are 512x512 73 | self.image_label_pairs.append([vh_path, vv_path, dem_path, label_path]) 74 | 75 | # Close datasets 76 | vh_ds = None 77 | vv_ds = None 78 | label_ds = None 79 | 80 | print(f"Found {len(self.image_label_pairs)} image pairs for {split} split") 81 | 82 | def __len__(self): 83 | return len(self.image_label_pairs) 84 | 85 | def __getitem__(self, idx): 86 | vh_path, vv_path, dem_path, label_path = self.image_label_pairs[idx] 87 | 88 | # Load VH and VV images using GDAL 89 | vh_ds = gdal.Open(os.path.join(self.base_path, vh_path)) 90 | vv_ds = gdal.Open(os.path.join(self.base_path, vv_path)) 91 | 92 | # Read arrays (we know they're 512x512) 93 | vh = vh_ds.ReadAsArray() 94 | vv = vv_ds.ReadAsArray() 95 | 96 | # Normalize VH and VV channels 97 | vh = (vh - self.vh_mean) / self.vh_std 98 | vv = (vv - self.vv_mean) / self.vv_std 99 | 100 | # Stack VH and VV channels 101 | img = np.stack([vh, vv], axis=0) 102 | 103 | # Load label 104 | label_ds = gdal.Open(os.path.join(self.base_path, label_path)) 105 | label = label_ds.ReadAsArray() 106 | label = label.astype(np.float32) 107 | 108 | # Map values 1 and 2 to 1 to create binary mask 109 | label = np.where((label == 1) | (label == 2), 1, 0) 110 | 111 | # Apply transforms if provided 112 | if self.transform: 113 | transformed = self.transform(image=img, mask=label) 114 | img = transformed['image'] 115 | label = transformed['mask'] 116 | 117 | # Convert to torch tensors 118 | img = torch.from_numpy(img).float() 119 | label = torch.from_numpy(label).float().unsqueeze(0) # Add channel dimension 120 | 121 | # Close GDAL datasets 122 | vh_ds = None 123 | vv_ds = None 124 | label_ds = None 125 | 126 | return { 127 | 'image': img, 128 | 'label': label, 
129 | 'image_path': vh_path, 130 | 'label_path': label_path, 131 | #'dem_path': dem_path # Store DEM path for future use 132 | } -------------------------------------------------------------------------------- /FloodsML/check_sentinel_values.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from osgeo import gdal 4 | from tqdm import tqdm 5 | from collections import defaultdict 6 | 7 | def check_sentinel_values(data_dir='data'): 8 | """Check value ranges in Sentinel-1 data (VH and VV channels)""" 9 | data_dir = Path(data_dir) 10 | 11 | if not data_dir.exists(): 12 | print(f"Error: Data directory not found at {data_dir}") 13 | return 14 | 15 | # Statistics storage 16 | stats = { 17 | 'total_files': 0, 18 | 'errors': [], 19 | 'vh_stats': { 20 | 'min': float('inf'), 21 | 'max': float('-inf'), 22 | 'mean': 0, 23 | 'std': 0 24 | }, 25 | 'vv_stats': { 26 | 'min': float('inf'), 27 | 'max': float('-inf'), 28 | 'mean': 0, 29 | 'std': 0 30 | } 31 | } 32 | 33 | # Read both train and val map files 34 | map_files = [data_dir / 'train.txt', data_dir / 'val.txt'] 35 | 36 | for map_file in map_files: 37 | if not map_file.exists(): 38 | print(f"Warning: Map file not found at {map_file}") 39 | continue 40 | 41 | # Read the map file 42 | try: 43 | with open(map_file, 'r') as f: 44 | image_pairs = [line.strip().split('|') for line in f.readlines() if line.strip()] 45 | except Exception as e: 46 | print(f"Error reading map file {map_file}: {str(e)}") 47 | continue 48 | 49 | # Process each pair 50 | for vh_path, vv_path, dem_path, label_path in tqdm(image_pairs, desc=f"Processing {map_file.name}"): 51 | stats['total_files'] += 1 52 | 53 | # Process VH channel 54 | try: 55 | vh_ds = gdal.Open(str(data_dir / vh_path)) 56 | if vh_ds is None: 57 | stats['errors'].append(f"Could not open VH file: {vh_path}") 58 | continue 59 | 60 | vh_data = vh_ds.ReadAsArray() 61 | stats['vh_stats']['min'] = 
min(stats['vh_stats']['min'], np.min(vh_data)) 62 | stats['vh_stats']['max'] = max(stats['vh_stats']['max'], np.max(vh_data)) 63 | stats['vh_stats']['mean'] += np.mean(vh_data) 64 | stats['vh_stats']['std'] += np.std(vh_data) 65 | vh_ds = None 66 | 67 | except Exception as e: 68 | stats['errors'].append(f"Error processing VH file {vh_path}: {str(e)}") 69 | 70 | # Process VV channel 71 | try: 72 | vv_ds = gdal.Open(str(data_dir / vv_path)) 73 | if vv_ds is None: 74 | stats['errors'].append(f"Could not open VV file: {vv_path}") 75 | continue 76 | 77 | vv_data = vv_ds.ReadAsArray() 78 | stats['vv_stats']['min'] = min(stats['vv_stats']['min'], np.min(vv_data)) 79 | stats['vv_stats']['max'] = max(stats['vv_stats']['max'], np.max(vv_data)) 80 | stats['vv_stats']['mean'] += np.mean(vv_data) 81 | stats['vv_stats']['std'] += np.std(vv_data) 82 | vv_ds = None 83 | 84 | except Exception as e: 85 | stats['errors'].append(f"Error processing VV file {vv_path}: {str(e)}") 86 | 87 | # Calculate final statistics 88 | if stats['total_files'] > 0: 89 | stats['vh_stats']['mean'] /= stats['total_files'] 90 | stats['vh_stats']['std'] /= stats['total_files'] 91 | stats['vv_stats']['mean'] /= stats['total_files'] 92 | stats['vv_stats']['std'] /= stats['total_files'] 93 | 94 | # Print summary 95 | print("\n=== Sentinel-1 Data Statistics ===") 96 | print(f"Total files processed: {stats['total_files']}") 97 | 98 | print("\nVH Channel Statistics:") 99 | print(f" Minimum value: {stats['vh_stats']['min']:.4f}") 100 | print(f" Maximum value: {stats['vh_stats']['max']:.4f}") 101 | print(f" Mean value: {stats['vh_stats']['mean']:.4f}") 102 | print(f" Standard deviation: {stats['vh_stats']['std']:.4f}") 103 | 104 | print("\nVV Channel Statistics:") 105 | print(f" Minimum value: {stats['vv_stats']['min']:.4f}") 106 | print(f" Maximum value: {stats['vv_stats']['max']:.4f}") 107 | print(f" Mean value: {stats['vv_stats']['mean']:.4f}") 108 | print(f" Standard deviation: {stats['vv_stats']['std']:.4f}") 
109 | 110 | if stats['errors']: 111 | print("\n=== Errors and Warnings ===") 112 | for error in stats['errors']: 113 | print(f"- {error}") 114 | 115 | if stats['total_files'] == 0: 116 | print("\nNo files found! Please check:") 117 | print("1. The data directory exists") 118 | print("2. The train.txt and val.txt files exist and are readable") 119 | print("3. The paths in the map files are correct") 120 | print("4. The files are valid GeoTIFFs") 121 | 122 | if __name__ == "__main__": 123 | check_sentinel_values() -------------------------------------------------------------------------------- /HAND/mosaic_dem_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | import rasterio 5 | from rasterio.merge import merge 6 | from rasterio.warp import transform_bounds 7 | from rasterio.windows import from_bounds 8 | import numpy as np 9 | import time 10 | 11 | def find_dem_files(folder, pattern='*.tif'): 12 | """Recursively find all DEM files in folder and subfolders.""" 13 | dem_files = [] 14 | for root, _, files in os.walk(folder): 15 | for file in files: 16 | # Add if filename is copdem.tif 17 | if file.lower() == 'copdem.tif': 18 | dem_files.append(os.path.join(root, file)) 19 | return dem_files 20 | 21 | def mosaic_dems(dem_files, out_fn): 22 | """Mosaic a list of DEM files and save to out_fn.""" 23 | src_files_to_mosaic = [rasterio.open(fp) for fp in dem_files] 24 | mosaic, out_trans = merge(src_files_to_mosaic) 25 | out_meta = src_files_to_mosaic[0].meta.copy() 26 | out_meta.update({ 27 | "driver": "GTiff", 28 | "height": mosaic.shape[1], 29 | "width": mosaic.shape[2], 30 | "transform": out_trans, 31 | "count": 1, 32 | "dtype": "float32", 33 | "nodata": -9999.0, 34 | "compress": "deflate", 35 | "tiled": True, 36 | "blockxsize": 256, 37 | "blockysize": 256, 38 | }) 39 | with rasterio.open(out_fn, "w", **out_meta) as dest: 40 | dest.write(mosaic[0].astype(np.float32), 1) 41 | for 
src in src_files_to_mosaic: 42 | src.close() 43 | 44 | def crop_to_reference_with_buffer(in_fn, ref_fn, out_fn, buffer_m=0): 45 | """Crop raster to reference extent with buffer (meters).""" 46 | with rasterio.open(ref_fn) as ref: 47 | ref_bounds = ref.bounds 48 | ref_crs = ref.crs 49 | ref_transform = ref.transform 50 | ref_height = ref.height 51 | ref_width = ref.width 52 | ref_dtype = "float32" 53 | ref_nodata = -9999.0 54 | ref_res = ref.res[0] # Assume square pixels 55 | # Buffer in pixels 56 | minx = ref_bounds.left - buffer_m 57 | miny = ref_bounds.bottom - buffer_m 58 | maxx = ref_bounds.right + buffer_m 59 | maxy = ref_bounds.top + buffer_m 60 | buffered_bounds = (minx, miny, maxx, maxy) 61 | 62 | with rasterio.open(in_fn) as src: 63 | # Prepare destination array and profile 64 | dst_array = np.full((ref_height, ref_width), ref_nodata, dtype=np.float32) 65 | out_profile = src.profile.copy() 66 | out_profile.update({ 67 | "driver": "GTiff", 68 | "height": ref_height, 69 | "width": ref_width, 70 | "transform": ref_transform, 71 | "crs": ref_crs, 72 | "count": 1, 73 | "dtype": "float32", 74 | "nodata": -9999.0, 75 | "compress": "deflate", 76 | "tiled": True, 77 | "blockxsize": 256, 78 | "blockysize": 256, 79 | }) 80 | from rasterio.warp import reproject, Resampling 81 | reproject( 82 | source=rasterio.band(src, 1), 83 | destination=dst_array, 84 | src_transform=src.transform, 85 | src_crs=src.crs, 86 | src_nodata=src.nodata, 87 | dst_transform=ref_transform, 88 | dst_crs=ref_crs, 89 | dst_nodata=ref_nodata, 90 | resampling=Resampling.bilinear, 91 | ) 92 | with rasterio.open(out_fn, "w", **out_profile) as dest: 93 | dest.write(dst_array, 1) 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser(description="Mosaic DEMs and crop to reference raster with buffer.") 97 | 98 | parser.add_argument("--dem_folder", type=str, required=True, help="Folder containing DEM files (searches recursively)") 99 | parser.add_argument("--reference", type=str, 
required=True, help="Reference raster to crop to") 100 | parser.add_argument("--out_cropped", type=str, required=True, help="Output path for cropped DEM GeoTIFF") 101 | parser.add_argument("--buffer", type=float, default=5000, help="Buffer size in meters (default: 5000)") 102 | 103 | args = parser.parse_args() 104 | 105 | print("Finding DEM files...") 106 | 107 | dem_files = find_dem_files(args.dem_folder) 108 | if not dem_files: 109 | print(f"No DEM files found in {args.dem_folder}") 110 | exit(1) 111 | # Use a temporary file for the mosaic 112 | print(f"Found {len(dem_files)} DEM files. Creating mosaic...") 113 | import tempfile 114 | with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp: 115 | tmp_mosaic = tmp.name 116 | t0 = time.time() 117 | mosaic_dems(dem_files, tmp_mosaic) 118 | t1 = time.time() 119 | crop_to_reference_with_buffer(tmp_mosaic, args.reference, args.out_cropped, buffer_m=args.buffer) 120 | t2 = time.time() 121 | # Delete intermediate mosaic file 122 | if os.path.exists(tmp_mosaic): 123 | os.remove(tmp_mosaic) 124 | print(f"Deleted intermediate mosaic file") 125 | print(f"Mosaic step elapsed: {t1-t0:.2f} s") 126 | print(f"Crop/reproject step elapsed: {t2-t1:.2f} s") 127 | print(f"Total elapsed: {t2-t0:.2f} s") 128 | 129 | print("Done.") -------------------------------------------------------------------------------- /HAND/compute_hand.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from whitebox.whitebox_tools import WhiteboxTools 4 | import numpy as np 5 | import rasterio 6 | import time 7 | 8 | def compute_threshold_map(dem_filled_fn, out_thr_fn, s_lo_pct=10, s_hi_pct=90, thr_hilly=500, thr_flat=10000): 9 | """ 10 | Compute per-pixel threshold map from DEM slope. 
11 | Args: 12 | dem_filled_fn: Path to filled DEM (GeoTIFF) 13 | out_thr_fn: Output threshold map path (GeoTIFF) 14 | s_lo_pct: Lower percentile for slope (default 10) 15 | s_hi_pct: Upper percentile for slope (default 90) 16 | thr_hilly: Min threshold in steep terrain (default 500) 17 | thr_flat: Max threshold in flat terrain (default 10000) 18 | """ 19 | with rasterio.open(dem_filled_fn) as ds: 20 | dem = ds.read(1).astype("float64") 21 | dem_nodata = ds.nodata 22 | transform = ds.transform 23 | res_x = abs(transform.a) 24 | res_y = abs(transform.e) 25 | out_meta = ds.meta.copy() 26 | dem_mask = (dem == dem_nodata) if dem_nodata is not None else np.isnan(dem) 27 | dem = np.where(dem_mask, np.nan, dem) 28 | gy, gx = np.gradient(dem, res_y, res_x) 29 | slope_rad = np.arctan(np.hypot(gx, gy)) 30 | slope_deg = np.degrees(slope_rad) 31 | slope_deg[dem_mask] = np.nan 32 | finite = np.isfinite(slope_deg) 33 | if not np.any(finite): 34 | raise ValueError("No valid slope pixels.") 35 | s_lo, s_hi = np.percentile(slope_deg[finite], (s_lo_pct, s_hi_pct)) 36 | if s_hi <= s_lo: 37 | thr_map = np.full_like(slope_deg, (thr_hilly + thr_flat) / 2.0, dtype="float64") 38 | else: 39 | thr_map = np.interp(slope_deg, [s_lo, s_hi], [thr_flat, thr_hilly]) 40 | thr_map = np.clip(thr_map, thr_hilly, thr_flat) 41 | thr_map[~finite] = np.nan 42 | thr_meta = {**out_meta, "count": 1, "dtype": "float32", "nodata": -9999.0} 43 | thr_write = thr_map.astype("float32") 44 | thr_write[~np.isfinite(thr_write)] = thr_meta["nodata"] 45 | with rasterio.open(out_thr_fn, "w", **thr_meta) as dst: 46 | dst.write(thr_write, 1) 47 | 48 | 49 | def compute_thresholded_streams(fac_fn, thr_map_fn, out_streams_fn): 50 | """ 51 | Compute per-pixel thresholded streams from flow accumulation and threshold map. 
52 | Args: 53 | fac_fn: Path to flow accumulation raster (GeoTIFF) 54 | thr_map_fn: Path to threshold map raster (GeoTIFF) 55 | out_streams_fn: Output streams raster path (GeoTIFF) 56 | """ 57 | import numpy as np 58 | import rasterio 59 | with rasterio.open(fac_fn) as fac_ds, rasterio.open(thr_map_fn) as thr_ds: 60 | fac = fac_ds.read(1).astype("float64") 61 | thr_map = thr_ds.read(1).astype("float64") 62 | if fac_ds.nodata is not None: 63 | fac = np.where(fac == fac_ds.nodata, np.nan, fac) 64 | mask = ~np.isfinite(fac) | ~np.isfinite(thr_map) 65 | streams = np.where(~mask & (fac >= thr_map), 1, 0).astype("uint8") 66 | fac_meta = fac_ds.meta.copy() 67 | fac_meta.update(driver="GTiff", count=1, dtype="uint8", nodata=0) 68 | with rasterio.open(out_streams_fn, "w", **fac_meta) as dst: 69 | dst.write(streams, 1) 70 | 71 | 72 | def compute_hand_and_streams(dem_fn, out_hand_fn, out_streams_fn): 73 | # Temporary files 74 | tmp = [] 75 | wbt = WhiteboxTools() 76 | wbt.verbose = False 77 | wbt.work_dir = os.path.dirname(out_hand_fn) 78 | # Set no-data values 79 | print("Setting no-data values in DEM") 80 | dem_nd_fn = os.path.join(wbt.work_dir, 'dem_mosaic.tif') 81 | wbt.set_nodata_value( 82 | i=dem_fn, 83 | output=dem_nd_fn, 84 | back_value=-9999 85 | ) 86 | tmp.append(dem_nd_fn) 87 | # Fill depressions 88 | print("Filling depressions in DEM") 89 | dem_filled = os.path.join(wbt.work_dir, "dem_filled.tif") 90 | wbt.fill_depressions(dem_nd_fn, dem_filled) 91 | tmp.append(dem_filled) 92 | # Flow direction (D8) 93 | print("Computing D8 flow directions") 94 | fdr_fn = os.path.join(wbt.work_dir, "fdr.tif") 95 | wbt.d8_pointer(dem_filled, fdr_fn) 96 | tmp.append(fdr_fn) 97 | # Flow accumulation (D8) 98 | print("Computing D8 flow accumulation") 99 | fac_fn = os.path.join(wbt.work_dir, "fac.tif") 100 | wbt.d8_flow_accumulation(dem_filled, fac_fn, out_type="cells") 101 | tmp.append(fac_fn) 102 | # Compute threshold map 103 | print("Computing per-pixel threshold map") 104 | 
thr_map_fn = os.path.join(wbt.work_dir, "thr_map.tif") 105 | compute_threshold_map(dem_filled, thr_map_fn) 106 | tmp.append(thr_map_fn) 107 | # Compute thresholded streams 108 | print("Computing per-pixel thresholded streams") 109 | out_streams_fn = os.path.join(wbt.work_dir, "streams.tif") 110 | compute_thresholded_streams(fac_fn, thr_map_fn, out_streams_fn) 111 | # Compute HAND (Height Above Nearest Drainage) 112 | print("Computing HAND") 113 | out_hand_fn = os.path.join(wbt.work_dir, "hand.tif") 114 | wbt.elevation_above_stream(dem_filled, out_streams_fn, out_hand_fn) 115 | # Finishing up 116 | print("Saving final outputs") 117 | print(f"HAND saved to {out_hand_fn}") 118 | print(f"Streams saved to {out_streams_fn}") 119 | 120 | # Clean up intermediate files 121 | print('Cleaning up intermediate files...') 122 | for f in tmp: 123 | if os.path.exists(f): 124 | os.remove(f) 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser(description="Compute HAND and extract streams from DEM using WhiteboxTools.") 129 | parser.add_argument("--dem", type=str, required=True, help="Input DEM file (GeoTIFF)") 130 | parser.add_argument("--out_hand", type=str, required=True, help="Output HAND file (GeoTIFF)") 131 | parser.add_argument("--out_streams", type=str, required=True, help="Output streams file (GeoTIFF)") 132 | args = parser.parse_args() 133 | 134 | # Start processing 135 | print("HAND and streams computation...") 136 | t0 = time.time() 137 | compute_hand_and_streams(args.dem, args.out_hand, args.out_streams) 138 | t1 = time.time() 139 | print(f"Total elapsed: {t1-t0:.2f} s") 140 | print('Done.') 141 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import geopandas as gpd 4 | import rasterio 5 | from shapely import wkt 6 | from shapely.geometry import shape, box 7 | from 
rasterio.mask import mask 8 | from rasterio.warp import reproject, Resampling 9 | from rasterio.features import geometry_mask 10 | from rasterio import Affine 11 | 12 | 13 | # Step 1: Create AOI layer in raster CRS 14 | def create_aoi_layer(s1_path, aoi_wkt): 15 | """ 16 | Create an AOI GeoDataFrame aligned to the CRS of the raster. 17 | 18 | :param s1_path: path to Sentinel-1 raster 19 | :param aoi_wkt: AOI polygon in WKT format (WGS84) 20 | :return: AOI as GeoDataFrame in raster CRS 21 | """ 22 | with rasterio.open(s1_path) as src: 23 | raster_crs = src.crs 24 | 25 | geom = wkt.loads(aoi_wkt) 26 | gdf_aoi = gpd.GeoDataFrame(geometry=[geom], crs="EPSG:4326") 27 | 28 | # Reproject AOI to raster CRS if necessary 29 | if gdf_aoi.crs != raster_crs: 30 | gdf_aoi = gdf_aoi.to_crs(raster_crs) 31 | 32 | return gdf_aoi 33 | 34 | 35 | # Step 2: Crop Sentinel-1 image (VV and VH) 36 | def crop_s1(s1_path, gdf_aoi): 37 | """ 38 | Crop a Sentinel-1 raster to the AOI extent. 39 | 40 | :param s1_path: path to Sentinel-1 raster 41 | :param gdf_aoi: AOI as GeoDataFrame in raster CRS 42 | :return: cropped image (numpy array) and metadata (dict) 43 | """ 44 | base, ext = os.path.splitext(s1_path) 45 | output_path = f"{base}_cropped{ext}" 46 | 47 | with rasterio.open(s1_path) as src: 48 | # Reproject AOI if CRS does not match 49 | aoi_mask = gdf_aoi 50 | if src.crs != gdf_aoi.crs: 51 | aoi_mask = gdf_aoi.to_crs(src.crs) 52 | 53 | # Apply mask 54 | out_image, out_transform = mask(src, aoi_mask.geometry, crop=True) 55 | 56 | # Update metadata for cropped raster 57 | out_meta = src.meta.copy() 58 | out_meta.update({ 59 | "driver": "GTiff", 60 | "height": out_image.shape[1], 61 | "width": out_image.shape[2], 62 | "transform": out_transform 63 | }) 64 | 65 | # Cropped raster is kept in memory, not written to disk 66 | return out_image, out_meta 67 | 68 | 69 | # Step 3: Crop DEM raster and align to Sentinel-1 grid 70 | def crop_dem(dem_path, gdf_aoi, s1_meta): 71 | """ 72 | Crop DEM raster 
to AOI and reproject to match Sentinel-1 grid. 73 | 74 | :param dem_path: path to DEM raster 75 | :param gdf_aoi: AOI as GeoDataFrame 76 | :param s1_meta: Sentinel-1 metadata dict (CRS, transform, shape) 77 | :return: cropped DEM image (numpy array) and metadata (dict) 78 | """ 79 | s1_crs = s1_meta["crs"] 80 | s1_transform = s1_meta["transform"] 81 | s1_shape = (s1_meta["height"], s1_meta["width"]) 82 | 83 | # Reproject DEM onto Sentinel-1 grid 84 | with rasterio.open(dem_path) as dem_src: 85 | out_array = np.empty(s1_shape, dtype=dem_src.dtypes[0]) 86 | reproject( 87 | source=rasterio.band(dem_src, 1), 88 | destination=out_array, 89 | src_transform=dem_src.transform, 90 | src_crs=dem_src.crs, 91 | dst_transform=s1_transform, 92 | dst_crs=s1_crs, 93 | resampling=Resampling.bilinear, 94 | ) 95 | 96 | # Clip DEM to AOI mask 97 | aoi_mask = gdf_aoi 98 | if gdf_aoi.crs != s1_crs: 99 | aoi_mask = gdf_aoi.to_crs(s1_crs) 100 | 101 | mask_geom = geometry_mask( 102 | geometries=aoi_mask.geometry, 103 | transform=s1_transform, 104 | invert=True, 105 | out_shape=s1_shape 106 | ) 107 | 108 | out_image = np.where(mask_geom, out_array, np.nan).astype(out_array.dtype) 109 | 110 | # Build output metadata 111 | out_meta = { 112 | "driver": "GTiff", 113 | "height": out_image.shape[0], 114 | "width": out_image.shape[1], 115 | "count": 1, 116 | "dtype": str(out_array.dtype), 117 | "crs": s1_crs, 118 | "transform": s1_transform 119 | } 120 | 121 | # Save cropped DEM to disk 122 | base, ext = os.path.splitext(dem_path) 123 | output_path = f"{base}_cropped{ext}" 124 | with rasterio.open(output_path, "w", **out_meta) as dest: 125 | dest.write(out_image, 1) 126 | 127 | return out_image, out_meta 128 | 129 | 130 | # Step 4: Generate tiles from raster held in memory 131 | def tile_from_memory(image, meta, tile_size, suffix, working_dir): 132 | """ 133 | Split raster (from memory) into tiles and save each to disk. 
134 | 135 | :param image: numpy array (raster data) 136 | :param meta: metadata dict for raster 137 | :param tile_size: tile size in pixels 138 | :param suffix: suffix for naming tiles 139 | :param working_dir: directory where tiles will be saved 140 | :return: dict of {tile_name: (tile_data, tile_meta)} 141 | """ 142 | os.makedirs(working_dir, exist_ok=True) 143 | 144 | # Ensure 3D shape (bands, rows, cols) 145 | if image.ndim == 2: 146 | image = image[np.newaxis, :, :] 147 | 148 | height, width = image.shape[1], image.shape[2] 149 | tiles = {} 150 | tile_id = 1 151 | 152 | for row in range(0, height, tile_size): 153 | for col in range(0, width, tile_size): 154 | # Handle edge tiles (no overflow) 155 | row_off = min(row, height - tile_size) if row + tile_size > height else row 156 | col_off = min(col, width - tile_size) if col + tile_size > width else col 157 | 158 | # Extract subarray 159 | window_data = image[:, row_off:row_off+tile_size, col_off:col_off+tile_size] 160 | 161 | # (Optional) skip empty tiles 162 | # if np.all(np.isnan(window_data)): 163 | # continue 164 | 165 | # Build metadata for tile 166 | transform = meta["transform"] * Affine.translation(col_off, row_off) 167 | tile_meta = meta.copy() 168 | tile_meta.update({ 169 | "height": window_data.shape[1], 170 | "width": window_data.shape[2], 171 | "transform": transform 172 | }) 173 | 174 | # Define tile name and path 175 | tile_name = f"tile_{tile_id:03d}_0_{suffix}" 176 | out_path = os.path.join(working_dir, f"{tile_name}.tif") 177 | 178 | # Save tile to disk 179 | with rasterio.open(out_path, "w", **tile_meta) as dst: 180 | dst.write(window_data) 181 | 182 | tiles[tile_name] = (window_data, tile_meta) 183 | tile_id += 1 184 | 185 | return tiles 186 | -------------------------------------------------------------------------------- /FloodsML/train_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim 
as optim 4 | from torch.utils.data import DataLoader 5 | from torchmetrics import JaccardIndex 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | import wandb 9 | import random 10 | from osgeo import gdal 11 | from utils.training_utils import ( 12 | MaskedDiceLoss, MaskedBCELoss, setup_wandb_run, log_summary_metrics, log_metrics, 13 | save_checkpoint, create_save_directory, 14 | train_epoch, validate_epoch 15 | ) 16 | import time 17 | import os 18 | 19 | # Set GDAL to use exceptions 20 | gdal.UseExceptions() 21 | 22 | 23 | 24 | def train_model(model, train_loader, val_loader, config): 25 | """Training loop for flood segmentation model""" 26 | start_time = time.time() 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | model = model.to(device) 29 | 30 | # Initialize loss functions and metrics 31 | criterion_bce = MaskedBCELoss() 32 | criterion_dice = MaskedDiceLoss() 33 | jaccard = JaccardIndex(task='binary', num_classes=2).to(device) 34 | 35 | # Initialize optimizer 36 | optimizer = optim.AdamW( 37 | model.parameters(), 38 | lr=config['learning_rate'], 39 | weight_decay=config['weight_decay'] 40 | ) 41 | 42 | # Create warmup scheduler 43 | warmup_scheduler = optim.lr_scheduler.LinearLR( 44 | optimizer, 45 | start_factor=0.1, 46 | end_factor=1.0, 47 | total_iters=config['warmup_epochs'] 48 | ) 49 | 50 | # Create plateau scheduler 51 | plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau( 52 | optimizer, 53 | mode='min', 54 | factor=config.get('plateau_factor', 0.5), 55 | patience=config.get('plateau_patience', 5), 56 | verbose=True 57 | ) 58 | 59 | # Setup directories and wandb 60 | save_dir = create_save_directory(config['save_dir']) 61 | if config.get('use_wandb', False) and not wandb.run: # Only initialize if not already initialized 62 | setup_wandb_run(config, "UNet") 63 | 64 | best_val_loss = float('inf') 65 | early_stopping_counter = 0 66 | final_train_iou = 0.0 67 | final_val_iou = 0.0 68 | 69 | for epoch in 
tqdm(range(config['epochs'])): 70 | # Training phase 71 | train_metrics = train_epoch( 72 | model=model, 73 | loader=train_loader, 74 | optimizer=optimizer, 75 | criterion_bce=criterion_bce, 76 | criterion_dice=criterion_dice, 77 | jaccard=jaccard, 78 | config=config, 79 | device=device 80 | ) 81 | 82 | # Validation phase 83 | val_metrics = validate_epoch( 84 | model=model, 85 | loader=val_loader, 86 | criterion_bce=criterion_bce, 87 | criterion_dice=criterion_dice, 88 | jaccard=jaccard, 89 | config=config, 90 | device=device 91 | ) 92 | 93 | # Update learning rate schedulers 94 | if epoch < config['warmup_epochs']: 95 | warmup_scheduler.step() 96 | else: 97 | plateau_scheduler.step(val_metrics['val_loss']) 98 | 99 | # Log metrics 100 | if config.get('use_wandb', False): 101 | log_metrics({**train_metrics, **val_metrics}, epoch) 102 | 103 | # Early stopping check 104 | if val_metrics['val_loss'] < best_val_loss: 105 | best_val_loss = val_metrics['val_loss'] 106 | early_stopping_counter = 0 107 | save_checkpoint(model, optimizer, epoch, val_metrics, save_dir) 108 | else: 109 | early_stopping_counter += 1 110 | if early_stopping_counter >= config.get('early_stopping_patience', 10): 111 | print('Early stopping triggered') 112 | break 113 | 114 | if (epoch + 1) % config['checkpoint_freq'] == 0: 115 | save_checkpoint(model, optimizer, epoch, val_metrics, save_dir) 116 | 117 | # Update final metrics 118 | final_train_iou = train_metrics['train_iou'] 119 | final_val_iou = val_metrics['val_iou'] 120 | 121 | # Log final summary metrics 122 | if config.get('use_wandb', False): 123 | training_duration = time.time() - start_time 124 | summary_metrics = { 125 | 'train_iou': final_train_iou, 126 | 'val_iou': final_val_iou, 127 | 'best_val_loss': best_val_loss, 128 | 'training_duration': training_duration, 129 | 'total_epochs': epoch + 1 # Use actual number of epochs completed 130 | } 131 | log_summary_metrics(summary_metrics, config) 132 | 133 | return best_val_loss 134 | 135 | 
# Example usage:
if __name__ == "__main__":
    # Project-local imports are kept inside the guard so importing this module
    # (e.g. from a sweep script) does not require the dataset/model packages.
    from flood_dataset import FloodSegmentationDataset
    from models.unet import UNet

    # Create data directory if it doesn't exist
    data_dir = Path('data')
    data_dir.mkdir(parents=True, exist_ok=True)

    # Training configuration
    config = {
        'learning_rate': 1e-4,
        'weight_decay': 1e-4,
        'epochs': 10,
        'checkpoint_freq': 5,
        'save_dir': 'checkpoints/unet',
        'use_wandb': True,
        'early_stopping_patience': 100,
        'warmup_epochs': 3,  # Number of epochs for warmup
        'model': {
            'n_channels': 2,  # VH and VV channels
            'n_classes': 1,
            'bilinear': True
        },
        'batch_size': 32,
        'bce_weight': 0.5,
        'dice_weight': 0.5
    }

    # Create datasets; split lists are read from data/train.txt and data/val.txt
    train_dataset = FloodSegmentationDataset(
        map_file=str(data_dir / 'train.txt'),
        base_path=str(data_dir),
        split='train'
    )

    val_dataset = FloodSegmentationDataset(
        map_file=str(data_dir / 'val.txt'),
        base_path=str(data_dir),
        split='val'
    )

    # Number of loader workers: 8 per visible GPU, capped at the CPU core count
    # (falling back to 4 if the core count is unknown). On a CPU-only machine
    # device_count() is 0, so this evaluates to 0 (load in the main process).
    num_workers = min(8 * torch.cuda.device_count(), os.cpu_count() or 4)

    # Options shared by both loaders. persistent_workers and prefetch_factor
    # are only valid with worker processes: DataLoader raises ValueError if
    # they are passed together with num_workers == 0, so add them
    # conditionally. Pinned memory only helps host->GPU copies.
    loader_kwargs = {
        'batch_size': config['batch_size'],
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
    }
    if num_workers > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Initialize model
    model = UNet(**config['model'])

    # Train the model
    train_model(model, train_loader, val_loader, config)
-------------------------------------------------------------------------------- /FloodsML/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import sys 4 | import numpy as np 5 | from osgeo import gdal 6 | import torch 7 | from models.unet import UNet 8 | 9 | def load_tiff(filepath: Path) -> tuple[np.ndarray | None, gdal.Dataset | None]: 10 | """Opens TIFF, returns NumPy array and GDAL dataset object (for geo-info).""" 11 | ds = gdal.Open(str(filepath)) 12 | if ds is None: 13 | print(f"Error: Could not open TIFF file: {filepath}", file=sys.stderr) 14 | return None, None 15 | data = ds.ReadAsArray() 16 | return data, ds 17 | 18 | def save_tiff(data: np.ndarray, output_path: Path, ref_ds: gdal.Dataset): 19 | """Saves a NumPy array as a GeoTIFF using reference geo-information.""" 20 | driver = gdal.GetDriverByName('GTiff') 21 | if len(data.shape) == 2: 22 | rows, cols = data.shape 23 | num_bands = 1 24 | elif len(data.shape) == 3: 25 | num_bands, rows, cols = data.shape 26 | else: 27 | print(f"Error: Unexpected data shape for saving: {data.shape}", file=sys.stderr) 28 | return 29 | 30 | output_datatype = gdal.GDT_Byte 31 | 32 | out_ds = driver.Create(str(output_path), cols, rows, num_bands, output_datatype) 33 | if out_ds is None: 34 | print(f"Error: Could not create output file: {output_path}", file=sys.stderr) 35 | return 36 | 37 | out_ds.SetGeoTransform(ref_ds.GetGeoTransform()) 38 | out_ds.SetProjection(ref_ds.GetProjection()) 39 | 40 | if num_bands == 1: 41 | out_band = out_ds.GetRasterBand(1) 42 | out_band.WriteArray(data) 43 | out_band.FlushCache() 44 | else: 45 | for i in range(num_bands): 46 | out_band = out_ds.GetRasterBand(i + 1) 47 | out_band.WriteArray(data[i, :, :]) 48 | out_band.FlushCache() 49 | 50 | out_ds = None 51 | 52 | 53 | def main(): 54 | parser = argparse.ArgumentParser( 55 | description="Load VV/VH/DEM TIFFs, run inference, save prediction masks." 
56 | ) 57 | parser.add_argument("input_dir", type=str, help="Path to the input directory containing TIFF files.") 58 | parser.add_argument("output_dir", type=str, help="Path to the output directory for prediction masks.") 59 | parser.add_argument("model_path", type=str, help="Path to the pre-trained model checkpoint (.pth file).") 60 | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use ('cuda' or 'cpu').") 61 | parser.add_argument("--n_channels", type=int, default=2, help="Number of input channels for the model (e.g., 3 for VV+VH+DEM).") 62 | parser.add_argument("--n_classes", type=int, default=1, help="Number of output classes (e.g., 1 for binary segmentation).") 63 | # Add bilinear argument matching the example config 64 | parser.add_argument("--bilinear", action='store_true', default=True, help="Use bilinear upsampling in UNet.") # Default True as per example 65 | parser.add_argument("--no-bilinear", action='store_false', dest='bilinear', help="Use transposed convolutions instead of bilinear upsampling.") 66 | 67 | 68 | args = parser.parse_args() 69 | 70 | input_path = Path(args.input_dir).resolve() 71 | output_path = Path(args.output_dir).resolve() 72 | model_path = Path(args.model_path).resolve() 73 | device = torch.device(args.device) 74 | 75 | 76 | if not input_path.is_dir(): 77 | print(f"Error: Input directory not found: {input_path}", file=sys.stderr) 78 | sys.exit(1) 79 | if not model_path.is_file(): 80 | print(f"Error: Model file not found: {model_path}", file=sys.stderr) 81 | sys.exit(1) 82 | output_path.mkdir(parents=True, exist_ok=True) # Create output directory 83 | 84 | print(f"Using device: {device}") 85 | print(f"Loading model from: {model_path}") 86 | 87 | 88 | try: 89 | model = UNet(n_channels=args.n_channels, n_classes=args.n_classes, bilinear=args.bilinear) 90 | checkpoint = torch.load(model_path, map_location=device) 91 | 92 | model_state_dict = 
checkpoint.get('model_state_dict', checkpoint) 93 | model.load_state_dict(model_state_dict) 94 | model.to(device) 95 | model.eval() 96 | print(f"Model loaded successfully ({args.n_channels} inputs, {args.n_classes} outputs, bilinear={args.bilinear}).") 97 | except Exception as e: 98 | print(f"Error loading model: {e}", file=sys.stderr) 99 | # traceback.print_exc() # Uncomment for detailed loading errors 100 | sys.exit(1) 101 | 102 | vh_mean = -11.2832 103 | vh_std = 5.6656 104 | vv_mean = -6.9761 105 | vv_std = 4.3096 106 | 107 | 108 | 109 | print(f"Scanning for '*vv.tif' files in: {input_path}") 110 | vv_files = sorted(list(input_path.glob('*vv.tif'))) 111 | 112 | if not vv_files: 113 | print("No '*vv.tif' files found.") 114 | sys.exit(0) 115 | 116 | processed_count = 0 117 | skipped_count = 0 118 | for vv_filepath in vv_files: 119 | base_name = vv_filepath.name[:-len('01_Sentinel1_vv.tif')] 120 | vh_filename = base_name + '01_Sentinel1_vh.tif' 121 | dem_filename = base_name + '02_DEM.tif' 122 | 123 | vh_filepath = input_path / vh_filename 124 | dem_filepath = input_path / dem_filename 125 | 126 | if vh_filepath.is_file() and dem_filepath.is_file(): 127 | print(f"\nProcessing group based on: {vv_filepath.name}") 128 | 129 | vv_data, vv_ds = load_tiff(vv_filepath) 130 | vh_data, vh_ds = load_tiff(vh_filepath) 131 | dem_data, dem_ds = load_tiff(dem_filepath) 132 | 133 | # --- Data Integrity Check --- 134 | if (vv_data is not None and vh_data is not None and dem_data is not None and 135 | vv_data.shape == vh_data.shape and vv_data.shape == dem_data.shape): 136 | 137 | print(f" Loaded VV: {vv_data.shape}, VH: {vh_data.shape}, DEM: {dem_data.shape}") 138 | 139 | # --- Preprocessing --- 140 | vv_norm = (vv_data.astype(np.float32) - vv_mean) / vv_std 141 | vh_norm = (vh_data.astype(np.float32) - vh_mean) / vh_std 142 | dem_processed = dem_data.astype(np.float32) 143 | 144 | input_image = np.stack([vv_norm, vh_norm], axis=0) # dem_processed 145 | input_tensor = 
torch.from_numpy(input_image).unsqueeze(0).to(device) 146 | 147 | # --- Inference --- 148 | with torch.no_grad(): # Disable gradient calculation 149 | output_tensor = model(input_tensor) 150 | 151 | pred_prob = torch.sigmoid(output_tensor) 152 | pred_binary = (pred_prob > 0.5).squeeze().cpu().numpy().astype(np.uint8) 153 | 154 | # --- Save Output --- 155 | output_filename = output_path / (base_name + 'pred.tif') 156 | print(f" Saving prediction to: {output_filename}") 157 | # Use geo-info from one of the inputs (e.g., VV) 158 | save_tiff(pred_binary, output_filename, vv_ds) 159 | processed_count += 1 160 | 161 | else: 162 | print(f" Skipping group: Loading error or shape mismatch.") 163 | print(f" VV: {'OK' if vv_data is not None else 'Fail'}, VH: {'OK' if vh_data is not None else 'Fail'}, DEM: {'OK' if dem_data is not None else 'Fail'}") 164 | if vv_data is not None and vh_data is not None and dem_data is not None: 165 | print(f" Shapes: VV={vv_data.shape}, VH={vh_data.shape}, DEM={dem_data.shape}") 166 | skipped_count += 1 167 | 168 | if vv_ds: vv_ds = None 169 | if vh_ds: vh_ds = None 170 | if dem_ds: dem_ds = None 171 | 172 | else: 173 | # Warning if counterparts are missing 174 | print(f"\nSkipping {vv_filepath.name}: Missing counterparts.") 175 | if not vh_filepath.is_file(): print(f" Missing: {vh_filename}") 176 | if not dem_filepath.is_file(): print(f" Missing: {dem_filename}") 177 | skipped_count += 1 178 | 179 | print(f"\nFinished. 
Processed {processed_count} groups, skipped {skipped_count} groups.") 180 | 181 | if __name__ == "__main__": 182 | main() -------------------------------------------------------------------------------- /FloodsML/sweep_review.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import wandb\n", 10 | "import pandas as pd\n", 11 | "import seaborn as sns\n", 12 | "import matplotlib.pyplot as plt" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "🏆 Top 5 Runs by Validation IoU:\n", 25 | "\n", 26 | " run_id val_iou best_val_loss learning_rate weight_decay \\\n", 27 | "19 d2l878my 0.450551 0.546182 0.000016 0.000182 \n", 28 | "8 sfek7m88 0.450266 0.540528 0.000018 0.000717 \n", 29 | "1 ij6bpizs 0.449752 0.570179 0.000029 0.000011 \n", 30 | "0 kfoyx1x6 0.449238 0.578700 0.000013 0.000080 \n", 31 | "17 vbzyhodf 0.445841 0.511774 0.000146 0.000010 \n", 32 | "\n", 33 | " warmup_epochs plateau_factor plateau_patience batch_size num_workers \\\n", 34 | "19 4 0.2 7 16 8 \n", 35 | "8 3 0.7 5 16 4 \n", 36 | "1 4 0.5 5 32 16 \n", 37 | "0 5 0.7 5 16 16 \n", 38 | "17 3 0.5 3 32 16 \n", 39 | "\n", 40 | " prefetch_factor \n", 41 | "19 1 \n", 42 | "8 1 \n", 43 | "1 1 \n", 44 | "0 2 \n", 45 | "17 1 \n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "# Set your sweep details\n", 51 | "ENTITY = \"mracic\"\n", 52 | "PROJECT = \"flood-segmentation\"\n", 53 | "SWEEP_ID = \"sk0uzqz0\"\n", 54 | "\n", 55 | "# Initialize API\n", 56 | "api = wandb.Api()\n", 57 | "\n", 58 | "# Load sweep\n", 59 | "sweep = api.sweep(f\"{ENTITY}/{PROJECT}/{SWEEP_ID}\")\n", 60 | "runs = sweep.runs\n", 61 | "\n", 62 | "# Convert to DataFrame\n", 63 | "records = []\n", 64 | "for run in runs:\n", 65 | " if 
run.state != \"finished\":\n", 66 | " continue\n", 67 | " summary = run.summary\n", 68 | " config = run.config\n", 69 | " records.append({\n", 70 | " \"run_id\": run.id,\n", 71 | " \"val_iou\": summary.get(\"val_iou\"),\n", 72 | " \"best_val_loss\": summary.get(\"best_val_loss\"),\n", 73 | " \"learning_rate\": config.get(\"learning_rate\"),\n", 74 | " \"weight_decay\": config.get(\"weight_decay\"),\n", 75 | " \"warmup_epochs\": config.get(\"warmup_epochs\"),\n", 76 | " \"plateau_factor\": config.get(\"plateau_factor\"),\n", 77 | " \"plateau_patience\": config.get(\"plateau_patience\"),\n", 78 | " \"batch_size\": config.get(\"batch_size\"),\n", 79 | " \"num_workers\": config.get(\"num_workers\"),\n", 80 | " \"prefetch_factor\": config.get(\"prefetch_factor\"),\n", 81 | " })\n", 82 | "\n", 83 | "df = pd.DataFrame(records)\n", 84 | "\n", 85 | "# Drop rows with missing values\n", 86 | "df = df.dropna(subset=[\"val_iou\"])\n", 87 | "\n", 88 | "# Sort by best val_iou\n", 89 | "df_sorted = df.sort_values(by=\"val_iou\", ascending=False)\n", 90 | "\n", 91 | "# Show top 5 runs\n", 92 | "print(\"🏆 Top 5 Runs by Validation IoU:\\n\")\n", 93 | "print(df_sorted.head())\n", 94 | "\n", 95 | "# Plot\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "SWEEP_ID = 'jlcywgkr'\n", 105 | "sweep = api.sweep(f\"{ENTITY}/{PROJECT}/{SWEEP_ID}\")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "🏆 Top 5 Runs by Validation IoU:\n", 118 | "\n", 119 | " run_id val_iou best_val_loss learning_rate weight_decay warmup_epochs plateau_factor plateau_patience batch_size num_workers prefetch_factor dropout_rate bce_weight dice_weight\n", 120 | "otpnzi48 0.481520 0.508222 0.000008 0.000236 5 0.2 9 8 6 2 0.0 0.5 0.5\n", 121 | "az3w121w 0.469485 
0.161217 0.000010 0.000348 5 0.3 9 8 6 1 0.0 1.0 0.0\n", 122 | "kl9io4im 0.469250 0.173235 0.000009 0.000191 5 0.3 9 8 6 2 0.0 1.0 0.0\n", 123 | "ivryz770 0.466911 0.400988 0.000011 0.000322 5 0.2 9 16 6 2 0.0 0.7 0.3\n", 124 | "3t8pofh1 0.459811 0.656754 0.000020 0.000234 4 0.2 9 24 8 1 0.0 0.3 0.7\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "runs = sweep.runs\n", 130 | "\n", 131 | "# Extract metrics and config for each run\n", 132 | "records = []\n", 133 | "for run in runs:\n", 134 | " if run.state != \"finished\":\n", 135 | " continue\n", 136 | " try:\n", 137 | " val_iou = run.summary.get(\"val_iou\")\n", 138 | " best_val_loss = run.summary.get(\"best_val_loss\")\n", 139 | " config = run.config\n", 140 | " \n", 141 | " records.append({\n", 142 | " \"run_id\": run.id,\n", 143 | " \"val_iou\": val_iou,\n", 144 | " \"best_val_loss\": best_val_loss,\n", 145 | " \"learning_rate\": config.get(\"learning_rate\"),\n", 146 | " \"weight_decay\": config.get(\"weight_decay\"),\n", 147 | " \"warmup_epochs\": config.get(\"warmup_epochs\"),\n", 148 | " \"plateau_factor\": config.get(\"plateau_factor\"),\n", 149 | " \"plateau_patience\": config.get(\"plateau_patience\"),\n", 150 | " \"batch_size\": config.get(\"batch_size\"),\n", 151 | " \"num_workers\": config.get(\"num_workers\"),\n", 152 | " \"prefetch_factor\": config.get(\"prefetch_factor\"),\n", 153 | " \"dropout_rate\": config.get(\"dropout_rate\"),\n", 154 | " \"bce_weight\": config.get(\"bce_weight\"),\n", 155 | " \"dice_weight\": 1.0 - config.get(\"bce_weight\") if config.get(\"bce_weight\") is not None else None,\n", 156 | " })\n", 157 | " except Exception as e:\n", 158 | " print(f\"Skipping run {run.id} due to error: {e}\")\n", 159 | "\n", 160 | "# Create DataFrame\n", 161 | "df = pd.DataFrame(records)\n", 162 | "\n", 163 | "# Sort by val_iou (descending)\n", 164 | "df = df.sort_values(by=\"val_iou\", ascending=False)\n", 165 | "\n", 166 | "# Select top 5\n", 167 | "top_5 = df.head(5)\n", 168 | "\n", 169 | 
"# Display\n", 170 | "print(\"🏆 Top 5 Runs by Validation IoU:\\n\")\n", 171 | "print(top_5.to_string(index=False))" 172 | ] 173 | } 174 | ], 175 | "metadata": { 176 | "kernelspec": { 177 | "display_name": "poplave", 178 | "language": "python", 179 | "name": "python3" 180 | }, 181 | "language_info": { 182 | "codemirror_mode": { 183 | "name": "ipython", 184 | "version": 3 185 | }, 186 | "file_extension": ".py", 187 | "mimetype": "text/x-python", 188 | "name": "python", 189 | "nbconvert_exporter": "python", 190 | "pygments_lexer": "ipython3", 191 | "version": "3.10.16" 192 | } 193 | }, 194 | "nbformat": 4, 195 | "nbformat_minor": 2 196 | } 197 | -------------------------------------------------------------------------------- /FloodsML/mg_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "56dba08f-bad7-44bf-872f-2dc87d008a0f", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "Hello, World!\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "print(\"Hello, World!\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 10, 24 | "id": "df6c268f-46dd-498d-a4e6-8d0aadc12495", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "Using device: cpu\n", 32 | "Loading model from: /app/models/best_model.pth\n", 33 | "Model loaded successfully (2 inputs, 1 outputs, bilinear=True).\n", 34 | "Scanning for '*vv.tif' files in: /app/sample_data/512_images_EMSR708_AOI01_DEL_PRODUCT\n", 35 | "\n", 36 | "Processing group based on: tile_001_0_01_Sentinel1_vv.tif\n", 37 | "/usr/lib/python3/dist-packages/osgeo/gdal.py:606: FutureWarning: Neither gdal.UseExceptions() nor gdal.DontUseExceptions() has been explicitly called. 
In GDAL 4.0, exceptions will be enabled by default.\n", 38 | " warnings.warn(\n", 39 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 40 | " Saving prediction to: /app/sample_data/512_output_labels/tile_001_0_pred.tif\n", 41 | "\n", 42 | "Processing group based on: tile_002_0_01_Sentinel1_vv.tif\n", 43 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 44 | " Saving prediction to: /app/sample_data/512_output_labels/tile_002_0_pred.tif\n", 45 | "\n", 46 | "Processing group based on: tile_003_0_01_Sentinel1_vv.tif\n", 47 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 48 | " Saving prediction to: /app/sample_data/512_output_labels/tile_003_0_pred.tif\n", 49 | "\n", 50 | "Processing group based on: tile_004_0_01_Sentinel1_vv.tif\n", 51 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 52 | " Saving prediction to: /app/sample_data/512_output_labels/tile_004_0_pred.tif\n", 53 | "\n", 54 | "Processing group based on: tile_005_0_01_Sentinel1_vv.tif\n", 55 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 56 | " Saving prediction to: /app/sample_data/512_output_labels/tile_005_0_pred.tif\n", 57 | "\n", 58 | "Processing group based on: tile_006_0_01_Sentinel1_vv.tif\n", 59 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 60 | " Saving prediction to: /app/sample_data/512_output_labels/tile_006_0_pred.tif\n", 61 | "\n", 62 | "Processing group based on: tile_007_0_01_Sentinel1_vv.tif\n", 63 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 64 | " Saving prediction to: /app/sample_data/512_output_labels/tile_007_0_pred.tif\n", 65 | "\n", 66 | "Processing group based on: tile_008_0_01_Sentinel1_vv.tif\n", 67 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 68 | " Saving prediction to: /app/sample_data/512_output_labels/tile_008_0_pred.tif\n", 69 | "\n", 70 | "Processing group based on: tile_009_0_01_Sentinel1_vv.tif\n", 71 | " Loaded VV: (512, 512), VH: (512, 512), DEM: 
(512, 512)\n", 72 | " Saving prediction to: /app/sample_data/512_output_labels/tile_009_0_pred.tif\n", 73 | "\n", 74 | "Processing group based on: tile_010_0_01_Sentinel1_vv.tif\n", 75 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 76 | " Saving prediction to: /app/sample_data/512_output_labels/tile_010_0_pred.tif\n", 77 | "\n", 78 | "Processing group based on: tile_011_0_01_Sentinel1_vv.tif\n", 79 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 80 | " Saving prediction to: /app/sample_data/512_output_labels/tile_011_0_pred.tif\n", 81 | "\n", 82 | "Processing group based on: tile_012_0_01_Sentinel1_vv.tif\n", 83 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 84 | " Saving prediction to: /app/sample_data/512_output_labels/tile_012_0_pred.tif\n", 85 | "\n", 86 | "Processing group based on: tile_013_0_01_Sentinel1_vv.tif\n", 87 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 88 | " Saving prediction to: /app/sample_data/512_output_labels/tile_013_0_pred.tif\n", 89 | "\n", 90 | "Processing group based on: tile_014_0_01_Sentinel1_vv.tif\n", 91 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 92 | " Saving prediction to: /app/sample_data/512_output_labels/tile_014_0_pred.tif\n", 93 | "\n", 94 | "Processing group based on: tile_015_0_01_Sentinel1_vv.tif\n", 95 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 96 | " Saving prediction to: /app/sample_data/512_output_labels/tile_015_0_pred.tif\n", 97 | "\n", 98 | "Processing group based on: tile_016_0_01_Sentinel1_vv.tif\n", 99 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n", 100 | " Saving prediction to: /app/sample_data/512_output_labels/tile_016_0_pred.tif\n", 101 | "\n", 102 | "Finished. 
Processed 16 groups, skipped 0 groups.\n" 103 | ] 104 | } 105 | ], 106 | "source": [] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "id": "675edde6-6e87-47c5-8694-bf814195b452", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "GDAL version (library): 3120000\n", 119 | "GDAL version (Python binding): 3.12.0dev-adbc_bigquery\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "from osgeo import gdal\n", 125 | "print(\"GDAL version (library):\", gdal.VersionInfo())\n", 126 | "print(\"GDAL version (Python binding):\", gdal.__version__)\n" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 11, 132 | "id": "b79493fd-9edc-4ae4-9d6b-6e07824dd04c", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "#!pip install \"numpy<2\" --force-reinstall --break-system-packages\n" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "92937c78-6418-4190-8057-7bd538a3e63e", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "37fedb26-2a7d-4cd5-92b3-3f1195b02d55", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "7f3e74cc-295c-4136-bc19-3597e67ab4ba", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "457f7a30-1778-41ba-86db-e8ca9a78f7c4", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 14, 174 | "id": "e575d7cd-10e0-4be8-acd2-27f747c762af", 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "Narejen enoten raster: outpit_file.tif\n" 182 | ] 183 | } 
184 | ], 185 | "source": [ 186 | "!python inference.py \"sample_data/512_images_EMSR708_AOI01_DEL_PRODUCT\" \"sample_data/512_output_labels\" \"models/best_model.pth\"" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 17, 192 | "id": "0dde0937-966f-4bda-8576-a19d5dacf298", 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "✅ Merged raster created: sample_data/output.tif\n" 200 | ] 201 | } 202 | ], 203 | "source": [ 204 | "!python merge_tifs.py --input \"sample_data/512_images_EMSR708_AOI01_DEL_PRODUCT\" --output \"sample_data/output.tif\"" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "71bff668-398d-4158-a2a6-8d3e74b9a0f7", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": "Python 3 (ipykernel)", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 3 226 | }, 227 | "file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython3", 232 | "version": "3.12.3" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 5 237 | } 238 | -------------------------------------------------------------------------------- /FloodsML/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from data_generators.data_generator import initialize_data_loader 7 | from models.sync_batchnorm.replicate import patch_replication_callback 8 | from models.deeplab import DeepLab 9 | from losses.loss import SegmentationLosses 10 | from utils.calculate_weights import calculate_weigths_labels 11 | from utils.lr_scheduler import 
LR_Scheduler 12 | from utils.saver import Saver 13 | from utils.summaries import TensorboardSummary 14 | from utils.metrics import Evaluator 15 | import torch 16 | import yaml 17 | 18 | 19 | class Trainer(object): 20 | def __init__(self, config): 21 | 22 | self.config = config 23 | self.best_pred = 0.0 24 | 25 | # Define Saver 26 | self.saver = Saver(config) 27 | self.saver.save_experiment_config() 28 | # Define Tensorboard Summary 29 | self.summary = TensorboardSummary(self.config['training']['tensorboard']['log_dir']) 30 | self.writer = self.summary.create_summary() 31 | 32 | self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config) 33 | 34 | # Define network 35 | model = DeepLab(num_classes=self.nclass, 36 | backbone=self.config['network']['backbone'], 37 | output_stride=self.config['image']['out_stride'], 38 | sync_bn=self.config['network']['sync_bn'], 39 | freeze_bn=self.config['network']['freeze_bn']) 40 | 41 | train_params = [{'params': model.get_1x_lr_params(), 'lr': self.config['training']['lr']}, 42 | {'params': model.get_10x_lr_params(), 'lr': self.config['training']['lr'] * 10}] 43 | 44 | # Define Optimizer 45 | optimizer = torch.optim.SGD(train_params, momentum=self.config['training']['momentum'], 46 | weight_decay=self.config['training']['weight_decay'], nesterov=self.config['training']['nesterov']) 47 | 48 | # Define Criterion 49 | # whether to use class balanced weights 50 | if self.config['training']['use_balanced_weights']: 51 | classes_weights_path = os.path.join(self.config['dataset']['base_path'], self.config['dataset']['dataset_name'] + '_classes_weights.npy') 52 | if os.path.isfile(classes_weights_path): 53 | weight = np.load(classes_weights_path) 54 | else: 55 | weight = calculate_weigths_labels(self.config, self.config['dataset']['dataset_name'], self.train_loader, self.nclass) 56 | weight = torch.from_numpy(weight.astype(np.float32)) 57 | else: 58 | weight = None 59 | 60 | self.criterion = 
SegmentationLosses(weight=weight, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type']) 61 | self.model, self.optimizer = model, optimizer 62 | 63 | # Define Evaluator 64 | self.evaluator = Evaluator(self.nclass) 65 | # Define lr scheduler 66 | self.scheduler = LR_Scheduler(self.config['training']['lr_scheduler'], self.config['training']['lr'], 67 | self.config['training']['epochs'], len(self.train_loader)) 68 | 69 | 70 | # Using cuda 71 | if self.config['network']['use_cuda']: 72 | self.model = torch.nn.DataParallel(self.model) 73 | patch_replication_callback(self.model) 74 | self.model = self.model.cuda() 75 | 76 | # Resuming checkpoint 77 | 78 | if self.config['training']['weights_initialization']['use_pretrained_weights']: 79 | if not os.path.isfile(self.config['training']['weights_initialization']['restore_from']): 80 | raise RuntimeError("=> no checkpoint found at '{}'" .format(self.config['training']['weights_initialization']['restore_from'])) 81 | 82 | if self.config['network']['use_cuda']: 83 | checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from']) 84 | else: 85 | checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'], map_location={'cuda:0': 'cpu'}) 86 | 87 | self.config['training']['start_epoch'] = checkpoint['epoch'] 88 | 89 | if self.config['network']['use_cuda']: 90 | self.model.load_state_dict(checkpoint['state_dict']) 91 | else: 92 | self.model.load_state_dict(checkpoint['state_dict']) 93 | 94 | # if not self.config['ft']: 95 | self.optimizer.load_state_dict(checkpoint['optimizer']) 96 | self.best_pred = checkpoint['best_pred'] 97 | print("=> loaded checkpoint '{}' (epoch {})" 98 | .format(self.config['training']['weights_initialization']['restore_from'], checkpoint['epoch'])) 99 | 100 | 101 | def training(self, epoch): 102 | train_loss = 0.0 103 | self.model.train() 104 | tbar = tqdm(self.train_loader) 105 | num_img_tr = 
len(self.train_loader) 106 | for i, sample in enumerate(tbar): 107 | image, target = sample['image'], sample['label'] 108 | if self.config['network']['use_cuda']: 109 | image, target = image.cuda(), target.cuda() 110 | self.scheduler(self.optimizer, i, epoch, self.best_pred) 111 | self.optimizer.zero_grad() 112 | output = self.model(image) 113 | loss = self.criterion(output, target) 114 | loss.backward() 115 | self.optimizer.step() 116 | train_loss += loss.item() 117 | tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) 118 | self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) 119 | 120 | # Show 10 * 3 inference results each epoch 121 | if i % (num_img_tr // 10) == 0: 122 | global_step = i + num_img_tr * epoch 123 | self.summary.visualize_image(self.writer, self.config['dataset']['dataset_name'], image, target, output, global_step) 124 | 125 | self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) 126 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0])) 127 | print('Loss: %.3f' % train_loss) 128 | 129 | #save last checkpoint 130 | self.saver.save_checkpoint({ 131 | 'epoch': epoch + 1, 132 | # 'state_dict': self.model.module.state_dict(), 133 | 'state_dict': self.model.state_dict(), 134 | 'optimizer': self.optimizer.state_dict(), 135 | 'best_pred': self.best_pred, 136 | }, is_best = False, filename='checkpoint_last.pth.tar') 137 | 138 | #if training on a subset reshuffle the data 139 | if self.config['training']['train_on_subset']['enabled']: 140 | self.train_loader.dataset.shuffle_dataset() 141 | 142 | 143 | def validation(self, epoch): 144 | self.model.eval() 145 | self.evaluator.reset() 146 | tbar = tqdm(self.val_loader, desc='\r') 147 | test_loss = 0.0 148 | for i, sample in enumerate(tbar): 149 | image, target = sample['image'], sample['label'] 150 | if self.config['network']['use_cuda']: 151 | image, target = image.cuda(), target.cuda() 152 
| with torch.no_grad(): 153 | output = self.model(image) 154 | loss = self.criterion(output, target) 155 | test_loss += loss.item() 156 | tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1))) 157 | pred = output.data.cpu().numpy() 158 | target = target.cpu().numpy() 159 | pred = np.argmax(pred, axis=1) 160 | # Add batch sample into evaluator 161 | self.evaluator.add_batch(target, pred) 162 | 163 | # Fast test during the training 164 | Acc = self.evaluator.Pixel_Accuracy() 165 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 166 | mIoU = self.evaluator.Mean_Intersection_over_Union() 167 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 168 | self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch) 169 | self.writer.add_scalar('val/mIoU', mIoU, epoch) 170 | self.writer.add_scalar('val/Acc', Acc, epoch) 171 | self.writer.add_scalar('val/Acc_class', Acc_class, epoch) 172 | self.writer.add_scalar('val/fwIoU', FWIoU, epoch) 173 | print('Validation:') 174 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0])) 175 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 176 | print('Loss: %.3f' % test_loss) 177 | 178 | new_pred = mIoU 179 | if new_pred > self.best_pred: 180 | self.best_pred = new_pred 181 | self.saver.save_checkpoint({ 182 | 'epoch': epoch + 1, 183 | # 'state_dict': self.model.module.state_dict(), 184 | 'state_dict': self.model.state_dict(), 185 | 'optimizer': self.optimizer.state_dict(), 186 | 'best_pred': self.best_pred, 187 | }, is_best = True, filename='checkpoint_best.pth.tar') -------------------------------------------------------------------------------- /FloodsML/0_analyze_labels.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from osgeo import gdal 4 | import matplotlib.pyplot as plt 5 | from tqdm import tqdm 6 | import pandas 
as pd 7 | from collections import defaultdict 8 | 9 | def analyze_labels(data_dir='data'): 10 | """Analyze the distribution of labels in the dataset""" 11 | data_dir = Path(data_dir) 12 | labels_dir = data_dir / '512_labels' 13 | 14 | if not labels_dir.exists(): 15 | print(f"Error: Labels directory not found at {labels_dir}") 16 | return 17 | 18 | # Statistics storage 19 | stats = { 20 | 'total_tiles': 0, 21 | 'flooded_tiles': 0, 22 | 'non_flooded_tiles': 0, 23 | 'flood_ratio': 0, 24 | 'tiles_per_emsr': defaultdict(int), 25 | 'flooded_per_emsr': defaultdict(int), 26 | 'errors': [], 27 | 'unique_values': set(), # Track unique values across all files 28 | 'value_counts': defaultdict(int), # Track count of each value 29 | 'flood_ratios': [], # Store flood ratios for histogram 30 | 'total_pixels': 0, # Track total pixels across all tiles 31 | 'threshold_stats': defaultdict(lambda: {'tiles': 0, 'pixels': defaultdict(int)}) # Stats for different thresholds 32 | } 33 | 34 | # Define thresholds to analyze 35 | thresholds = [0.01, 0.02, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3] 36 | 37 | # Count total EMSR folders 38 | emsr_folders = [f for f in labels_dir.iterdir() if f.is_dir()] 39 | 40 | # Iterate through all EMSR folders 41 | for emsr_folder in tqdm(emsr_folders, desc="Processing EMSR folders"): 42 | # Get all FloodMask files 43 | flood_files = list(emsr_folder.glob('*_FloodMask.tif')) 44 | if not flood_files: 45 | stats['errors'].append(f"No FloodMask files found in {emsr_folder.name}") 46 | continue 47 | 48 | for flood_file in flood_files: 49 | # Read label 50 | try: 51 | ds = gdal.Open(str(flood_file)) 52 | label = ds.ReadAsArray() 53 | ds = None 54 | 55 | # Skip tiles that only contain 0 values 56 | if np.all(label == 0): 57 | continue 58 | 59 | # Track unique values and their counts 60 | unique_vals, counts = np.unique(label, return_counts=True) 61 | stats['unique_values'].update(unique_vals) 62 | for val, count in zip(unique_vals, counts): 63 | 
stats['value_counts'][int(val)] += count 64 | 65 | # Update total pixels 66 | stats['total_pixels'] += label.size 67 | 68 | # Calculate statistics 69 | total_pixels = label.size 70 | # Consider both values 1 and 2 as flooded 71 | flooded_pixels = np.sum((label == 1) | (label == 2)) 72 | flood_ratio = flooded_pixels / total_pixels 73 | 74 | # Only count tiles that have at least some flooded pixels 75 | if flooded_pixels > 0: 76 | # Update statistics 77 | stats['total_tiles'] += 1 78 | stats['tiles_per_emsr'][emsr_folder.name] += 1 79 | stats['flood_ratios'].append(flood_ratio) # Store flood ratio for histogram 80 | 81 | # Update threshold statistics 82 | for threshold in thresholds: 83 | if flood_ratio > threshold: 84 | stats['threshold_stats'][threshold]['tiles'] += 1 85 | for val, count in zip(unique_vals, counts): 86 | stats['threshold_stats'][threshold]['pixels'][int(val)] += count 87 | 88 | if flood_ratio > 0.02: # Consider tile flooded if more than 2% is flooded 89 | stats['flooded_tiles'] += 1 90 | stats['flooded_per_emsr'][emsr_folder.name] += 1 91 | else: 92 | stats['non_flooded_tiles'] += 1 93 | 94 | except Exception as e: 95 | stats['errors'].append(f"Error reading {flood_file}: {str(e)}") 96 | continue 97 | 98 | # Calculate overall flood ratio 99 | if stats['total_tiles'] > 0: 100 | stats['flood_ratio'] = stats['flooded_tiles'] / stats['total_tiles'] 101 | 102 | # Print summary 103 | print("\n=== Label Distribution Summary ===") 104 | print(f"Total tiles with flood pixels: {stats['total_tiles']}") 105 | print(f"Flooded tiles (>2% flood): {stats['flooded_tiles']}") 106 | print(f"Non-flooded tiles (<2% flood): {stats['non_flooded_tiles']}") 107 | print(f"Overall flood ratio: {stats['flood_ratio']:.2%}") 108 | print("\nValue Distribution:") 109 | for val in sorted(stats['value_counts'].keys()): 110 | count = stats['value_counts'][val] 111 | percentage = (count / stats['total_pixels']) * 100 112 | print(f"Value {val}: {count:,} pixels ({percentage:.2f}% of 
total)") 113 | 114 | print("\nDistribution after discarding tiles with less than X% flood pixels:") 115 | for threshold in thresholds: 116 | threshold_pixels = sum(stats['threshold_stats'][threshold]['pixels'].values()) 117 | if threshold_pixels > 0: 118 | flood_pixels = (stats['threshold_stats'][threshold]['pixels'][1] + 119 | stats['threshold_stats'][threshold]['pixels'][2]) 120 | print(f"\nThreshold {threshold*100:.1f}%:") 121 | print(f"Remaining tiles: {stats['threshold_stats'][threshold]['tiles']} " 122 | f"({stats['threshold_stats'][threshold]['tiles']/stats['total_tiles']*100:.1f}% of total)") 123 | print(f"Flood pixel ratio: {flood_pixels/threshold_pixels*100:.2f}%") 124 | 125 | if stats['errors']: 126 | print("\n=== Errors and Warnings ===") 127 | for error in stats['errors']: 128 | print(f"- {error}") 129 | 130 | if stats['total_tiles'] == 0: 131 | print("\nNo valid tiles found! Please check:") 132 | print("1. The directory structure is correct") 133 | print("2. The FloodMask files are valid GeoTIFFs") 134 | return 135 | 136 | # Create EMSR distribution DataFrame 137 | emsr_stats = [] 138 | for emsr in stats['tiles_per_emsr'].keys(): 139 | emsr_stats.append({ 140 | 'EMSR': emsr, 141 | 'Total Tiles': stats['tiles_per_emsr'][emsr], 142 | 'Flooded Tiles': stats['flooded_per_emsr'][emsr], 143 | 'Flood Ratio': stats['flooded_per_emsr'][emsr] / stats['tiles_per_emsr'][emsr] if stats['tiles_per_emsr'][emsr] > 0 else 0 144 | }) 145 | 146 | if emsr_stats: # Only create DataFrame if we have data 147 | df = pd.DataFrame(emsr_stats) 148 | df = df.sort_values('Total Tiles', ascending=False) 149 | 150 | # Save to CSV 151 | output_file = data_dir / 'label_distribution.csv' 152 | df.to_csv(output_file, index=False) 153 | print(f"\nDetailed statistics saved to {output_file}") 154 | 155 | # Create visualizations 156 | plt.figure(figsize=(15, 12)) 157 | 158 | # Plot 1: Overall distribution 159 | plt.subplot(2, 2, 1) 160 | plt.pie([stats['flooded_tiles'], 
stats['non_flooded_tiles']], 161 | labels=[f'Flooded (>2%)\n{stats["flooded_tiles"]} tiles', 162 | f'Non-flooded (<2%)\n{stats["non_flooded_tiles"]} tiles'], 163 | autopct='%.1f%%') 164 | plt.title('Overall Label Distribution') 165 | 166 | # Plot 2: Histogram of flood ratios 167 | plt.subplot(2, 2, 2) 168 | plt.hist(stats['flood_ratios'], bins=20, edgecolor='black') 169 | plt.title('Distribution of Flood Ratios') 170 | plt.xlabel('Flood Ratio') 171 | plt.ylabel('Number of Tiles') 172 | plt.axvline(x=0.02, color='r', linestyle='--', label='2% Threshold') 173 | plt.legend() 174 | 175 | # Plot 3: Threshold impact on tile count 176 | plt.subplot(2, 2, 3) 177 | threshold_tiles = [stats['threshold_stats'][t]['tiles'] for t in thresholds] 178 | plt.plot(np.array(thresholds) * 100, threshold_tiles, 'b-o') 179 | plt.title('Number of Tiles vs Threshold') 180 | plt.xlabel('Threshold (%)') 181 | plt.ylabel('Number of Tiles') 182 | plt.grid(True) 183 | 184 | # Plot 4: Threshold impact on flood pixel ratio 185 | plt.subplot(2, 2, 4) 186 | flood_ratios = [] 187 | for t in thresholds: 188 | total_pixels = sum(stats['threshold_stats'][t]['pixels'].values()) 189 | if total_pixels > 0: 190 | flood_pixels = (stats['threshold_stats'][t]['pixels'][1] + 191 | stats['threshold_stats'][t]['pixels'][2]) 192 | flood_ratios.append(flood_pixels / total_pixels * 100) 193 | else: 194 | flood_ratios.append(0) 195 | 196 | plt.plot(np.array(thresholds) * 100, flood_ratios, 'g-o') 197 | plt.title('Flood Pixel Ratio vs Threshold') 198 | plt.xlabel('Threshold (%)') 199 | plt.ylabel('Flood Pixel Ratio (%)') 200 | plt.grid(True) 201 | 202 | plt.tight_layout() 203 | plt.savefig(data_dir / 'label_distribution.png') 204 | print(f"Visualization saved to {data_dir / 'label_distribution.png'}") 205 | 206 | if __name__ == "__main__": 207 | analyze_labels() -------------------------------------------------------------------------------- /FloodsML/0_create_splits.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | from osgeo import gdal 4 | import random 5 | from tqdm import tqdm 6 | import sys 7 | 8 | def create_filtered_splits(data_dir='data', train_ratio=0.8, flood_threshold=0.01): 9 | """ 10 | Create train and validation splits with filtering criteria: 11 | - Exclude tiles with only 0 values 12 | - Exclude tiles with less than 1% of flood pixels (values 1 and 2) 13 | 14 | Args: 15 | data_dir (str): Path to data directory 16 | train_ratio (float): Ratio of training data (default: 0.8) 17 | flood_threshold (float): Minimum ratio of flood pixels (default: 0.01) 18 | """ 19 | data_dir = Path(data_dir) 20 | images_dir = data_dir / '512_images' 21 | labels_dir = data_dir / '512_labels' 22 | 23 | if not images_dir.exists(): 24 | print(f"Error: Images directory not found at {images_dir}", flush=True) 25 | return 26 | if not labels_dir.exists(): 27 | print(f"Error: Labels directory not found at {labels_dir}", flush=True) 28 | return 29 | 30 | # Statistics for reporting 31 | stats = { 32 | 'total_pairs': 0, 33 | 'filtered_pairs': 0, 34 | 'skipped_zero_only': 0, 35 | 'skipped_low_flood': 0, 36 | 'pairs_per_emsr': {} 37 | } 38 | 39 | # Store valid pairs 40 | pairs = [] 41 | 42 | # Get all EMSR folders first 43 | emsr_folders = list(images_dir.iterdir()) 44 | emsr_folders = [f for f in emsr_folders if f.is_dir()] 45 | 46 | # Create progress bar for EMSR folders 47 | pbar = tqdm(emsr_folders, desc="Processing EMSR folders", file=sys.stdout) 48 | 49 | # Iterate through all EMSR folders 50 | for emsr_folder in pbar: 51 | # Update progress bar description with current folder 52 | pbar.set_description(f"Processing {emsr_folder.name}") 53 | 54 | # Get corresponding label folder 55 | label_folder = labels_dir / emsr_folder.name 56 | if not label_folder.exists(): 57 | print(f"Warning: Label folder missing for {emsr_folder.name}", flush=True) 58 | continue 
59 | 60 | # Get all unique base names (without extensions and suffixes) 61 | base_names = set() 62 | for vh_file in emsr_folder.glob('*_Sentinel1_vh.tif'): 63 | # Extract base name (e.g., 'tile_001_1_01' from 'tile_001_1_01_Sentinel1_vh.tif') 64 | base_name = vh_file.stem.replace('_Sentinel1_vh', '') 65 | base_names.add(base_name) 66 | 67 | if not base_names: 68 | print(f"Warning: No valid image files found in {emsr_folder.name}", flush=True) 69 | continue 70 | 71 | emsr_pairs = 0 72 | for base_name in base_names: 73 | # Extract tile number and other components 74 | # e.g., from 'tile_001_1_01' get '001', '1', '01' 75 | parts = base_name.split('_') 76 | if len(parts) < 4: 77 | print(f"Warning: Invalid base name format: {base_name}", flush=True) 78 | continue 79 | 80 | tile_num = parts[1] 81 | suffix_num = parts[2] # e.g., '1' from 'tile_001_1_01' 82 | version_num = parts[3] # e.g., '01' from 'tile_001_1_01' 83 | if emsr_folder == 'EMSR783_AOI01_DEL_MONIT01': 84 | print("Working on EMSR783_AOI01_DEL_MONIT01", flush=True) 85 | # Construct paths for all required files 86 | vh_file = emsr_folder / f"{base_name}_Sentinel1_vh.tif" 87 | vv_file = emsr_folder / f"{base_name}_Sentinel1_vv.tif" 88 | # DEM files use _02 instead of _01 89 | dem_file = emsr_folder / f"tile_{tile_num}_{suffix_num}_02_DEM.tif" 90 | 91 | # Find corresponding label file with FloodMask pattern 92 | label_files = list(label_folder.glob(f"tile_{tile_num}_{suffix_num}_*_FloodMask.tif")) 93 | if not label_files: 94 | print(f"Warning: No FloodMask file found for {base_name}", flush=True) 95 | continue 96 | 97 | label_file = label_files[0] # Take the first matching label file 98 | 99 | # Check if all required files exist 100 | if not all(f.exists() for f in [vh_file, vv_file, dem_file, label_file]): 101 | missing_files = [f.name for f in [vh_file, vv_file, dem_file, label_file] if not f.exists()] 102 | print(f"Warning: Missing files for {base_name}: {', '.join(missing_files)}", flush=True) 103 | continue 
104 | 105 | # Read and analyze label 106 | try: 107 | ds = gdal.Open(str(label_file)) 108 | label = ds.ReadAsArray() 109 | ds = None 110 | 111 | # Check if tile has only zeros 112 | if np.all(label == 0): 113 | stats['skipped_zero_only'] += 1 114 | continue 115 | 116 | # Calculate flood ratio (values 1 and 2) 117 | total_pixels = label.size 118 | flood_pixels = np.sum((label == 1) | (label == 2)) 119 | flood_ratio = flood_pixels / total_pixels 120 | 121 | # Check label and vv band for valid flood ratio 122 | try: 123 | # Read vv_file 124 | ds_vv = gdal.Open(str(vv_file)) 125 | vv_band = ds_vv.ReadAsArray() 126 | ds_vv = None 127 | 128 | # Check if vv_band dimensions are smaller than label 129 | if vv_band.shape != label.shape: 130 | # Calculate padding 131 | pad_height = label.shape[0] - vv_band.shape[0] 132 | pad_width = label.shape[1] - vv_band.shape[1] 133 | 134 | # Ensure padding values are non-negative 135 | if pad_height < 0 or pad_width < 0: 136 | #apply padding to label 137 | label = np.pad( 138 | label, 139 | ((0, -pad_height), (0, -pad_width)), 140 | mode='constant', 141 | constant_values=0 142 | ) 143 | print(f"Error: vv_file dimensions are larger than label {label.shape}", flush=True) 144 | else: 145 | # Apply padding to vv_band 146 | vv_band = np.pad( 147 | vv_band, 148 | ((0, pad_height), (0, pad_width)), 149 | mode='constant', 150 | constant_values=0 151 | ) 152 | 153 | # Mask the label using vv_file where vv_band != 0 154 | masked_label = np.where(vv_band != 0, label, 0) 155 | 156 | # Check if masked label has only zeros 157 | if np.all(masked_label == 0): 158 | stats['skipped_zero_only'] += 1 159 | continue 160 | 161 | # Calculate flood ratio for the remaining pixels 162 | # total_pixels = np.count_nonzero(vv_band != 0) 163 | flood_pixels = np.sum((masked_label == 1) | (masked_label == 2)) 164 | flood_ratio = flood_pixels / total_pixels 165 | 166 | # Skip if flood ratio is below threshold 167 | if flood_ratio < flood_threshold: 168 | 
stats['skipped_low_flood'] += 1 169 | continue 170 | except Exception as e: 171 | print(f"Error processing {emsr_folder}/{base_name} with vv_file: {str(e)}", flush=True) 172 | continue 173 | if emsr_folder == 'EMSR783_AOI01_DEL_MONIT01': 174 | print(f"EMSR783_AOI01_DEL_MONIT01 flood ratio: {flood_ratio}", flush=True) 175 | print(flood_pixels, total_pixels, flood_ratio, flush=True) 176 | 177 | # Store relative paths 178 | rel_vh_path = vh_file.relative_to(data_dir) 179 | rel_vv_path = vv_file.relative_to(data_dir) 180 | rel_dem_path = dem_file.relative_to(data_dir) 181 | rel_label_path = label_file.relative_to(data_dir) 182 | 183 | # Store as a quadruplet: vh|vv|dem|label 184 | pairs.append(f"{rel_vh_path}|{rel_vv_path}|{rel_dem_path}|{rel_label_path}") 185 | emsr_pairs += 1 186 | 187 | except Exception as e: 188 | print(f"Error reading {label_file}: {str(e)}", flush=True) 189 | continue 190 | if emsr_folder == 'EMSR783_AOI01_DEL_MONIT01': 191 | print(f"EMSR783_AOI01_DEL_MONIT01 pairs: {emsr_pairs}", flush=True) 192 | stats['pairs_per_emsr'][emsr_folder.name] = emsr_pairs 193 | stats['total_pairs'] += len(base_names) 194 | 195 | if not pairs: 196 | print("Error: No valid pairs found!", flush=True) 197 | return 198 | 199 | # Shuffle the pairs 200 | random.seed(42) # for reproducibility 201 | random.shuffle(pairs) 202 | 203 | # Split into train/val 204 | train_size = int(train_ratio * len(pairs)) 205 | train_pairs = pairs[:train_size] 206 | val_pairs = pairs[train_size:] 207 | 208 | # Save splits 209 | train_file = data_dir / 'train.txt' 210 | val_file = data_dir / 'val.txt' 211 | 212 | with open(train_file, 'w') as f: 213 | f.write('\n'.join(train_pairs)) 214 | with open(val_file, 'w') as f: 215 | f.write('\n'.join(val_pairs)) 216 | 217 | # Print summary 218 | print("\n=== Split Creation Summary ===", flush=True) 219 | print(f"Total image pairs found: {stats['total_pairs']}", flush=True) 220 | print(f"Pairs skipped (zero-only): {stats['skipped_zero_only']}", 
flush=True) 221 | print(f"Pairs skipped (low flood): {stats['skipped_low_flood']}", flush=True) 222 | print(f"Valid pairs kept: {len(pairs)}", flush=True) 223 | print(f"Training pairs: {len(train_pairs)}", flush=True) 224 | print(f"Validation pairs: {len(val_pairs)}", flush=True) 225 | 226 | print("\nPairs per EMSR:") 227 | for emsr, count in sorted(stats['pairs_per_emsr'].items(), key=lambda x: x[1], reverse=True): 228 | print(f"{emsr}: {count} pairs", flush=True) 229 | 230 | if __name__ == "__main__": 231 | create_filtered_splits() -------------------------------------------------------------------------------- /FloodsML/utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import wandb 4 | from pathlib import Path 5 | from torch.optim import Optimizer 6 | from torch.utils.data import DataLoader 7 | from torchmetrics import JaccardIndex 8 | from typing import Dict, Any 9 | from osgeo import gdal 10 | 11 | gdal.UseExceptions() 12 | 13 | class DiceLoss(nn.Module): 14 | def __init__(self, smooth=1.0): 15 | super(DiceLoss, self).__init__() 16 | self.smooth = smooth 17 | 18 | def forward(self, predictions, targets): 19 | predictions = torch.sigmoid(predictions) 20 | predictions = predictions.view(-1) 21 | targets = targets.view(-1).float() 22 | 23 | intersection = (predictions * targets).sum() 24 | dice = (2. 
* intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth) 25 | 26 | return 1 - dice 27 | 28 | class MaskedDiceLoss(DiceLoss): 29 | def forward(self, predictions, targets, images): 30 | predictions = torch.sigmoid(predictions) 31 | 32 | # Create mask for pixels where either VH or VV is 0 33 | # images shape: [batch_size, 2, height, width] 34 | mask = (images[:, 0] != 0) & (images[:, 1] != 0) # [batch_size, height, width] 35 | mask = mask.unsqueeze(1) # Add channel dimension 36 | 37 | # Apply mask to predictions and targets 38 | predictions = predictions * mask 39 | targets = targets * mask 40 | 41 | # Use parent class's dice calculation 42 | return super().forward(predictions, targets) 43 | 44 | class MaskedBCELoss(nn.BCEWithLogitsLoss): 45 | def forward(self, predictions, targets, images): 46 | # Create mask for pixels where either VH or VV is 0 47 | # images shape: [batch_size, 2, height, width] 48 | mask = (images[:, 0] != 0) & (images[:, 1] != 0) # [batch_size, height, width] 49 | mask = mask.unsqueeze(1) # Add channel dimension 50 | 51 | # Apply mask to predictions and targets 52 | predictions = predictions * mask 53 | targets = targets * mask 54 | 55 | # Use parent class's BCE calculation 56 | return super().forward(predictions, targets) 57 | 58 | def log_summary_metrics(metrics, config): 59 | """Log summary metrics and create visualizations in wandb""" 60 | if not config.get('use_wandb', False): 61 | return 62 | 63 | # Create a dictionary with default values for missing metrics 64 | summary_dict = { 65 | "final_train_iou": metrics.get('train_iou', 0.0), 66 | "val_iou": metrics.get('val_iou', 0.0), 67 | "best_val_loss": metrics.get('best_val_loss', float('inf')), 68 | "training_duration": metrics.get('training_duration', 0.0), 69 | "total_epochs": metrics.get('total_epochs', 0) 70 | } 71 | 72 | # Only log if wandb is initialized 73 | if wandb.run: 74 | wandb.log(summary_dict) 75 | 76 | def setup_wandb_run(config, model_type): 77 | 
"""Setup wandb run with proper grouping and tagging""" 78 | # Create meaningful group name based on model type 79 | if model_type.lower() == "unet": 80 | group_name = f"unet_{config['model']['n_channels']}ch_{config['batch_size']}bs" 81 | elif model_type.lower() == "transformersegmentation": 82 | group_name = f"transformer_{config['batch_size']}bs" 83 | else: # DeepLab 84 | group_name = f"deeplab_{config['backbone']}_{config['batch_size']}bs" 85 | 86 | # Common tags 87 | tags = [ 88 | f"lr_{config['learning_rate']}", 89 | f"bs_{config['batch_size']}", 90 | model_type.lower(), 91 | ] 92 | 93 | # Model-specific tags 94 | if model_type.lower() == "unet": 95 | tags.extend([ 96 | "bilinear" if config['model']['bilinear'] else "transpose", 97 | f"channels_{config['model']['n_channels']}" 98 | ]) 99 | elif model_type.lower() == "transformersegmentation": 100 | tags.extend([ 101 | "segformer", 102 | f"image_size_{config['model']['image_size']}" 103 | ]) 104 | else: # DeepLab 105 | tags.extend([ 106 | f"backbone_{config['backbone']}", 107 | f"output_stride_{config['output_stride']}" 108 | ]) 109 | 110 | wandb.init( 111 | project="flood-segmentation", 112 | config=config, 113 | group=group_name, 114 | tags=tags, 115 | name=f"{model_type}_lr{config['learning_rate']}_bs{config['batch_size']}" 116 | ) 117 | 118 | # Define custom metrics 119 | wandb.define_metric("train/loss", summary="min") 120 | wandb.define_metric("train/iou", summary="max") 121 | wandb.define_metric("val/loss", summary="min") 122 | wandb.define_metric("val/iou", summary="max") 123 | wandb.define_metric("learning_rate", summary="min") 124 | 125 | def log_metrics(metrics, epoch): 126 | """ 127 | Unified metric logging for both models 128 | """ 129 | if not wandb.run: 130 | return 131 | 132 | wandb.log({ 133 | # Loss components 134 | "train/total_loss": metrics['total_loss'], # Combined weighted loss 135 | "train/bce_loss": metrics.get('bce_loss', 0), 136 | "train/dice_loss": metrics.get('dice_loss', 0), 137 | 
138 | # Training metrics 139 | "train/iou": metrics['train_iou'], 140 | "train/dice_score": 1 - metrics.get('train_dice_loss', 0), 141 | 142 | # Validation metrics 143 | "val/total_loss": metrics['val_loss'], # Combined weighted loss for validation 144 | "val/iou": metrics['val_iou'], 145 | "val/dice_score": 1 - metrics.get('val_dice_loss', 0), 146 | 147 | # Learning rate 148 | "train/learning_rate": metrics['learning_rate'], 149 | 150 | "epoch": epoch, 151 | }) 152 | 153 | def save_checkpoint(model, optimizer, epoch, metrics, save_dir): 154 | """ 155 | Unified checkpoint saving for both models 156 | 157 | Args: 158 | model: PyTorch model 159 | optimizer: PyTorch optimizer 160 | epoch: Current epoch number 161 | metrics: Dictionary of metrics 162 | save_dir: Directory to save the checkpoint 163 | """ 164 | save_dir = Path(save_dir) 165 | save_dir.mkdir(parents=True, exist_ok=True) 166 | 167 | # Save best model 168 | best_model_path = save_dir / 'best_model.pth' 169 | torch.save({ 170 | 'epoch': epoch, 171 | 'model_state_dict': model.state_dict(), 172 | 'optimizer_state_dict': optimizer.state_dict(), 173 | 'metrics': metrics, 174 | 'best_val_loss': metrics.get('val_loss', float('inf')), 175 | 'val_iou': metrics.get('val_iou', 0.0) 176 | }, best_model_path) 177 | 178 | # Save periodic checkpoint 179 | if (epoch + 1) % 5 == 0: # Save every 5 epochs 180 | periodic_path = save_dir / f'checkpoint_epoch_{epoch+1}.pth' 181 | torch.save({ 182 | 'epoch': epoch, 183 | 'model_state_dict': model.state_dict(), 184 | 'optimizer_state_dict': optimizer.state_dict(), 185 | 'metrics': metrics 186 | }, periodic_path) 187 | 188 | def create_save_directory(save_dir): 189 | """ 190 | Create directory for saving model checkpoints 191 | """ 192 | save_dir = Path(save_dir) 193 | save_dir.mkdir(parents=True, exist_ok=True) 194 | return save_dir 195 | 196 | def train_epoch( 197 | model: nn.Module, 198 | loader: DataLoader, 199 | optimizer: Optimizer, 200 | criterion_bce: nn.Module, 201 | 
criterion_dice: nn.Module, 202 | jaccard: JaccardIndex, 203 | config: Dict[str, Any], 204 | device: torch.device 205 | ) -> Dict[str, float]: 206 | """ 207 | Single training epoch for both UNet and DeepLab models 208 | """ 209 | model.train() 210 | metrics = { 211 | 'total_loss': 0, 212 | 'bce_loss': 0, 213 | 'dice_loss': 0, 214 | 'train_iou': 0, 215 | 'learning_rate': optimizer.param_groups[0]['lr'] # Add learning rate to metrics 216 | } 217 | 218 | for batch in loader: 219 | images = batch['image'].to(device) 220 | masks = batch['label'].to(device) 221 | 222 | outputs = model(images) 223 | 224 | # Calculate losses 225 | loss_bce = criterion_bce(outputs, masks.float(), images) # Pass images for masking 226 | loss_dice = criterion_dice(outputs, masks, images) # Pass images for masking 227 | 228 | # Get loss weights from config, default to 0.5/0.5 if not specified 229 | bce_weight = config.get('bce_weight', 0.5) 230 | dice_weight = config.get('dice_weight', 0.5) 231 | loss = bce_weight * loss_bce + dice_weight * loss_dice 232 | 233 | # Calculate metrics 234 | preds = (torch.sigmoid(outputs) > 0.5).float() 235 | batch_iou = jaccard(preds, masks) 236 | 237 | # Update metrics 238 | metrics['total_loss'] += loss.item() 239 | metrics['bce_loss'] += loss_bce.item() 240 | metrics['dice_loss'] += loss_dice.item() 241 | metrics['train_iou'] += batch_iou.item() 242 | 243 | # Backward pass 244 | optimizer.zero_grad() 245 | loss.backward() 246 | optimizer.step() 247 | 248 | # Average metrics 249 | for key in metrics: 250 | if key != 'learning_rate': # Don't average the learning rate 251 | metrics[key] /= len(loader) 252 | 253 | return metrics 254 | 255 | def validate_epoch( 256 | model: nn.Module, 257 | loader: DataLoader, 258 | criterion_bce: nn.Module, 259 | criterion_dice: nn.Module, 260 | jaccard: JaccardIndex, 261 | config: Dict[str, Any], 262 | device: torch.device 263 | ) -> Dict[str, float]: 264 | """ 265 | Single validation epoch for both UNet and DeepLab models 266 | 
""" 267 | model.eval() 268 | metrics = { 269 | 'val_loss': 0, 270 | 'val_iou': 0, 271 | 'val_dice_loss': 0, 272 | 'learning_rate': config.get('learning_rate', 0.0) # Add learning rate to metrics 273 | } 274 | 275 | with torch.no_grad(): 276 | for batch in loader: 277 | images = batch['image'].to(device) 278 | masks = batch['label'].to(device) 279 | 280 | outputs = model(images) 281 | 282 | # Calculate losses 283 | loss_bce = criterion_bce(outputs, masks.float(), images) # Pass images for masking 284 | loss_dice = criterion_dice(outputs, masks, images) # Pass images for masking 285 | 286 | # Get loss weights from config, default to 0.5/0.5 if not specified 287 | bce_weight = config.get('bce_weight', 0.5) 288 | dice_weight = config.get('dice_weight', 0.5) 289 | loss = bce_weight * loss_bce + dice_weight * loss_dice 290 | 291 | # Calculate metrics 292 | preds = (torch.sigmoid(outputs) > 0.5).float() 293 | batch_iou = jaccard(preds, masks) 294 | 295 | metrics['val_loss'] += loss.item() 296 | metrics['val_iou'] += batch_iou.item() 297 | metrics['val_dice_loss'] += loss_dice.item() 298 | 299 | # Average metrics 300 | for key in metrics: 301 | if key != 'learning_rate': # Don't average the learning rate 302 | metrics[key] /= len(loader) 303 | 304 | return metrics -------------------------------------------------------------------------------- /FloodsML/optimize_parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from pathlib import Path 6 | import wandb 7 | import os 8 | import psutil 9 | import time 10 | from train_unet import train_model 11 | from flood_dataset import FloodSegmentationDataset 12 | from models.unet import UNet 13 | 14 | # Define project structure 15 | PROJECT_BASE = 'flood-segmentation' 16 | OPTIMIZATION_PHASE = 'data-loading-opt' # Change this for different optimization phases 17 | 
PROJECT_NAME = f"{PROJECT_BASE}/{OPTIMIZATION_PHASE}" 18 | 19 | def get_gpu_memory_usage(): 20 | """Get GPU memory usage if available""" 21 | if torch.cuda.is_available(): 22 | return torch.cuda.memory_allocated() / 1024**3 # Convert to GB 23 | return 0 24 | 25 | def create_sweep_config(): 26 | """Create sweep configuration for parameter optimization""" 27 | sweep_config = { 28 | 'method': 'bayes', # Bayesian optimization 29 | 'metric': { 30 | 'name': 'val_iou', 31 | 'goal': 'maximize' 32 | }, 33 | 'name': OPTIMIZATION_PHASE, # Name of the sweep 34 | 'project': PROJECT_BASE, # Project name 35 | 'parameters': { 36 | # Learning rate parameters 37 | 'learning_rate': { 38 | 'distribution': 'log_uniform_values', 39 | 'min': 1e-5, 40 | 'max': 1e-3, 41 | }, 42 | 'weight_decay': { 43 | 'distribution': 'log_uniform_values', 44 | 'min': 1e-5, 45 | 'max': 1e-3, 46 | }, 47 | 'warmup_epochs': { 48 | 'values': [1, 2, 3, 4, 5] 49 | }, 50 | 'plateau_factor': { 51 | 'values': [0.1, 0.2, 0.5, 0.7] 52 | }, 53 | 'plateau_patience': { 54 | 'values': [3, 5, 7, 10] 55 | }, 56 | 57 | # Data loading parameters 58 | 'batch_size': { 59 | 'values': [16, 32, 40] 60 | }, 61 | 'num_workers': { 62 | 'values': [2, 4, 8, 16] 63 | }, 64 | 'prefetch_factor': { 65 | 'values': [1, 2, 4] 66 | } 67 | } 68 | } 69 | return sweep_config 70 | 71 | def train_sweep(): 72 | """Training function for sweep""" 73 | # Initialize wandb with project structure first 74 | if not wandb.run: # Only initialize if not already initialized 75 | wandb.init(project=PROJECT_BASE, group=OPTIMIZATION_PHASE) 76 | 77 | try: 78 | # Get hyperparameters from sweep 79 | config = { 80 | 'learning_rate': wandb.config.learning_rate, 81 | 'weight_decay': wandb.config.weight_decay, 82 | 'epochs': 60, 83 | 'checkpoint_freq': 5, 84 | 'save_dir': f'checkpoints/{OPTIMIZATION_PHASE}/{wandb.run.id}', 85 | 'use_wandb': True, 86 | 'early_stopping_patience': 10, 87 | 'warmup_epochs': wandb.config.warmup_epochs, 88 | 'model': { 89 | 'n_channels': 2, 
90 | 'n_classes': 1, 91 | 'bilinear': True # Fixed to True for now 92 | }, 93 | 'batch_size': wandb.config.batch_size, 94 | 'bce_weight': 0.5, 95 | 'dice_weight': 0.5, 96 | 'plateau_factor': wandb.config.plateau_factor, 97 | 'plateau_patience': wandb.config.plateau_patience, 98 | 'num_workers': wandb.config.num_workers, 99 | 'prefetch_factor': wandb.config.prefetch_factor, 100 | 'optimization_phase': OPTIMIZATION_PHASE # Track which phase this run belongs to 101 | } 102 | 103 | # Log system information immediately after wandb initialization 104 | if wandb.run: 105 | wandb.log({ 106 | 'system/cpu_count': os.cpu_count(), 107 | 'system/gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0, 108 | 'system/gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None', 109 | 'system/memory_total': psutil.virtual_memory().total / 1024**3, 110 | 'system/memory_available': psutil.virtual_memory().available / 1024**3, 111 | 'system/gpu_memory_reserved': torch.cuda.memory_reserved() / 1024**3, 112 | 'system/max_gpu_memory_allocated': torch.cuda.max_memory_allocated() / 1024**3, 113 | 'config': config # Log the full config 114 | }) 115 | 116 | # Create data directory if it doesn't exist 117 | data_dir = Path('data') 118 | data_dir.mkdir(parents=True, exist_ok=True) 119 | 120 | # Create datasets 121 | train_dataset = FloodSegmentationDataset( 122 | map_file=str(data_dir / 'train.txt'), 123 | base_path=str(data_dir), 124 | split='train' 125 | ) 126 | 127 | val_dataset = FloodSegmentationDataset( 128 | map_file=str(data_dir / 'val.txt'), 129 | base_path=str(data_dir), 130 | split='val' 131 | ) 132 | 133 | # Create dataloaders with sweep parameters 134 | train_loader = DataLoader( 135 | train_dataset, 136 | batch_size=config['batch_size'], 137 | shuffle=True, 138 | num_workers=config['num_workers'], 139 | pin_memory=True, 140 | persistent_workers=True, 141 | prefetch_factor=config['prefetch_factor'] 142 | ) 143 | 144 | val_loader = 
DataLoader( 145 | val_dataset, 146 | batch_size=config['batch_size'], 147 | shuffle=False, 148 | num_workers=config['num_workers'], 149 | pin_memory=True, 150 | persistent_workers=True, 151 | prefetch_factor=config['prefetch_factor'] 152 | ) 153 | 154 | # Initialize model 155 | model = UNet(**config['model']) 156 | 157 | # Train the model 158 | best_val_loss = train_model(model, train_loader, val_loader, config) 159 | 160 | # Log final system metrics 161 | if wandb.run: 162 | wandb.log({ 163 | 'system/final_gpu_memory': get_gpu_memory_usage(), 164 | 'system/final_cpu_percent': psutil.cpu_percent(), 165 | 'system/final_memory_percent': psutil.virtual_memory().percent 166 | }) 167 | 168 | # Set the validation IoU as a summary metric for the sweep 169 | wandb.run.summary.update({ 170 | 'val_iou': wandb.run.summary.get('val_iou', 0.0), 171 | 'best_val_loss': best_val_loss 172 | }) 173 | 174 | except torch.cuda.OutOfMemoryError as e: 175 | # Log OOM error specifically 176 | if wandb.run: 177 | wandb.log({ 178 | 'error': str(e), 179 | 'error_type': 'OutOfMemoryError', 180 | 'failed_config': wandb.config, 181 | 'system/final_gpu_memory': get_gpu_memory_usage(), 182 | 'system/final_cpu_percent': psutil.cpu_percent(), 183 | 'system/final_memory_percent': psutil.virtual_memory().percent 184 | }) 185 | raise e 186 | except Exception as e: 187 | # Log other errors 188 | if wandb.run: 189 | wandb.log({ 190 | 'error': str(e), 191 | 'error_type': type(e).__name__, 192 | 'failed_config': wandb.config, 193 | 'system/final_gpu_memory': get_gpu_memory_usage(), 194 | 'system/final_cpu_percent': psutil.cpu_percent(), 195 | 'system/final_memory_percent': psutil.virtual_memory().percent 196 | }) 197 | raise e 198 | finally: 199 | # Ensure wandb is properly closed 200 | if wandb.run: 201 | wandb.finish() 202 | 203 | def run_sweep_with_retries(sweep_id, max_retries=2, total_trials=20): 204 | """Run sweep with retry mechanism for failed trials""" 205 | api = wandb.Api() 206 | 207 | 
successful_trials = 0 208 | failed_trials = 0 209 | 210 | while successful_trials < total_trials: 211 | try: 212 | # Run a single trial 213 | wandb.agent(sweep_id, function=train_sweep, count=1) 214 | successful_trials += 1 215 | print(f"Successfully completed trial {successful_trials}/{total_trials}") 216 | 217 | except Exception as e: 218 | failed_trials += 1 219 | print(f"Trial {successful_trials + failed_trials} failed with error: {str(e)}") 220 | 221 | if failed_trials >= max_retries: 222 | print(f"Too many consecutive failures ({failed_trials}). Stopping sweep.") 223 | break 224 | 225 | print(f"Retrying... (Attempt {failed_trials}/{max_retries})") 226 | continue 227 | 228 | print(f"\nSweep completed with {successful_trials} successful trials and {failed_trials} failed trials") 229 | return successful_trials, failed_trials 230 | 231 | def print_optimal_parameters(sweep_id): 232 | """Print optimal parameters from sweep""" 233 | api = wandb.Api() 234 | 235 | # Add a delay to allow for synchronization 236 | time.sleep(5) 237 | 238 | try: 239 | # Get sweep 240 | sweep = api.sweep(f"{PROJECT_NAME}/{sweep_id}") 241 | 242 | # Get best run 243 | best_run = sweep.best_run() 244 | 245 | print("\nOptimal Parameters Found:") 246 | print("-" * 50) 247 | print(f"Best Validation IoU: {best_run.summary['val_iou']:.4f}") 248 | print("\nLearning Rate Parameters:") 249 | print(f"Learning Rate: {wandb.config.learning_rate:.2e}") 250 | print(f"Weight Decay: {wandb.config.weight_decay:.2e}") 251 | print(f"Warmup Epochs: {wandb.config.warmup_epochs}") 252 | print(f"Plateau Factor: {wandb.config.plateau_factor}") 253 | print(f"Plateau Patience: {wandb.config.plateau_patience}") 254 | print("\nData Loading Parameters:") 255 | print(f"Batch Size: {wandb.config.batch_size}") 256 | print(f"Number of Workers: {wandb.config.num_workers}") 257 | print(f"Prefetch Factor: {wandb.config.prefetch_factor}") 258 | print("-" * 50) 259 | except Exception as e: 260 | print("\nCould not retrieve 
optimal parameters yet.") 261 | print("This is normal if the sweep is still running or results haven't synced.") 262 | print("You can view the results at:") 263 | print(f"https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}") 264 | print(f"\nError details: {str(e)}") 265 | 266 | if __name__ == "__main__": 267 | # Create sweep configuration 268 | sweep_config = create_sweep_config() 269 | 270 | # Initialize sweep 271 | sweep_id = wandb.sweep(sweep_config) 272 | 273 | print(f"\nSweep created with ID: {sweep_id}") 274 | print(f"View sweep at: https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}") 275 | 276 | # Run sweep with retry mechanism 277 | successful_trials, failed_trials = run_sweep_with_retries(sweep_id) 278 | 279 | if successful_trials > 0: 280 | print(f"\nSweep completed with {successful_trials} successful trials and {failed_trials} failed trials") 281 | print("Attempting to retrieve optimal parameters...") 282 | print_optimal_parameters(sweep_id) 283 | else: 284 | print("\nNo successful trials completed. 
Cannot determine optimal parameters.") -------------------------------------------------------------------------------- /FloodsML/models/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | 
rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = self.skipbn(skip) 86 | else: 87 | skip = inp 88 | 89 | x = x + skip 90 | 91 | return x 92 | 93 | 94 | class AlignedXception(nn.Module): 95 | """ 96 | Modified Alighed Xception 97 | """ 98 | def __init__(self, output_stride, BatchNorm, 99 | pretrained=True): 100 | super(AlignedXception, self).__init__() 101 | 102 | if output_stride == 16: 103 | entry_block3_stride = 2 104 | middle_block_dilation = 1 105 | exit_block_dilations = (1, 2) 106 | elif output_stride == 8: 107 | entry_block3_stride = 1 108 | middle_block_dilation = 2 109 | exit_block_dilations = (2, 4) 110 | else: 111 | raise NotImplementedError 112 | 113 | 114 | # Entry flow 115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 116 | self.bn1 = BatchNorm(32) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 120 | self.bn2 = BatchNorm(64) 121 | 122 | self.block1 = 
Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 123 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, 124 | grow_first=True) 125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm, 126 | start_with_relu=True, grow_first=True, is_last=True) 127 | 128 | # Middle flow 129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 150 | 
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 152 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 157 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 158 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle 
flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return x, low_level_feat 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. / n)) 239 | elif isinstance(m, SynchronizedBatchNorm2d): 240 | m.weight.data.fill_(1) 241 | m.bias.data.zero_() 242 | elif isinstance(m, nn.BatchNorm2d): 243 | m.weight.data.fill_(1) 244 | m.bias.data.zero_() 245 | 246 | 247 | def _load_pretrained_model(self): 248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 249 | model_dict = {} 250 | state_dict = self.state_dict() 251 | 252 | for k, v in pretrain_dict.items(): 253 | if k in model_dict: 254 | if 'pointwise' in k: 255 | v = v.unsqueeze(-1).unsqueeze(-1) 256 | if k.startswith('block11'): 257 | model_dict[k] = v 258 | model_dict[k.replace('block11', 'block12')] = v 259 | model_dict[k.replace('block11', 'block13')] = v 260 | model_dict[k.replace('block11', 'block14')] = v 261 | model_dict[k.replace('block11', 'block15')] = v 262 | model_dict[k.replace('block11', 'block16')] = v 263 | model_dict[k.replace('block11', 'block17')] = v 264 | model_dict[k.replace('block11', 'block18')] = v 265 | model_dict[k.replace('block11', 
'block19')] = v 266 | elif k.startswith('block12'): 267 | model_dict[k.replace('block12', 'block20')] = v 268 | elif k.startswith('bn3'): 269 | model_dict[k] = v 270 | model_dict[k.replace('bn3', 'bn4')] = v 271 | elif k.startswith('conv4'): 272 | model_dict[k.replace('conv4', 'conv5')] = v 273 | elif k.startswith('bn4'): 274 | model_dict[k.replace('bn4', 'bn5')] = v 275 | else: 276 | model_dict[k] = v 277 | state_dict.update(model_dict) 278 | self.load_state_dict(state_dict) 279 | 280 | 281 | 282 | if __name__ == "__main__": 283 | import torch 284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /FloodsML/optimize_parameters_extended.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from pathlib import Path 6 | import wandb 7 | import os 8 | import psutil 9 | import time 10 | from train_unet import train_model 11 | from flood_dataset import FloodSegmentationDataset 12 | from models.unet import UNet 13 | import pandas as pd 14 | 15 | # Define project structure 16 | PROJECT_BASE = 'flood-segmentation' 17 | OPTIMIZATION_PHASE = 'extended-optimization' # New phase for extended optimization 18 | PROJECT_NAME = f"{PROJECT_BASE}/{OPTIMIZATION_PHASE}" 19 | 20 | def get_gpu_memory_usage(): 21 | """Get GPU memory usage if available""" 22 | if torch.cuda.is_available(): 23 | return torch.cuda.memory_allocated() / 1024**3 # Convert to GB 24 | return 0 25 | 26 | def get_best_parameters_from_previous_sweep(sweep_name): 27 | """Retrieve best parameters from the previous optimization phase""" 28 | api = wandb.Api() 29 | 30 | # Get all sweeps in the project 31 | 
project = api.project(PROJECT_BASE) 32 | sweeps = project.sweeps() 33 | 34 | # Find the sweep with name 'data-loading-opt' 35 | previous_sweep = None 36 | for sweep in sweeps: 37 | if sweep.name == sweep_name: 38 | previous_sweep = sweep 39 | break 40 | 41 | if previous_sweep is None: 42 | raise ValueError(f"Could not find sweep with name '{sweep_name}'. Please make sure the previous optimization phase has completed.") 43 | 44 | best_run = previous_sweep.best_run() 45 | 46 | # Get the initial learning rate from the run history 47 | history = best_run.scan_history() 48 | #history_df = pd.DataFrame(history) 49 | initial_lr = best_run.config['learning_rate'] # history_df['learning_rate'].iloc[0] 50 | 51 | return { 52 | 'learning_rate': initial_lr, # Use initial learning rate instead of final config value 53 | 'weight_decay': best_run.config['weight_decay'], 54 | 'warmup_epochs': best_run.config['warmup_epochs'], 55 | 'plateau_factor': best_run.config['plateau_factor'], 56 | 'plateau_patience': best_run.config['plateau_patience'], 57 | 'batch_size': best_run.config['batch_size'], 58 | 'num_workers': best_run.config['num_workers'], 59 | 'prefetch_factor': best_run.config['prefetch_factor'] 60 | } 61 | 62 | def create_sweep_config(): 63 | """Create sweep configuration for extended parameter optimization""" 64 | # Get best parameters from previous sweep 65 | best_params = get_best_parameters_from_previous_sweep('data-loading-opt') 66 | 67 | sweep_config = { 68 | 'method': 'bayes', # Bayesian optimization 69 | 'metric': { 70 | 'name': 'val_iou', 71 | 'goal': 'maximize' 72 | }, 73 | 'name': OPTIMIZATION_PHASE, 74 | 'project': PROJECT_BASE, 75 | 'parameters': { 76 | # Learning rate parameters - finer range around best values 77 | 'learning_rate': { 78 | 'distribution': 'log_uniform_values', 79 | 'min': best_params['learning_rate'] * 0.5, 80 | 'max': best_params['learning_rate'] * 2.0, 81 | }, 82 | 'weight_decay': { 83 | 'distribution': 'log_uniform_values', 84 | 'min': 
best_params['weight_decay'] * 0.5, 85 | 'max': best_params['weight_decay'] * 2.0, 86 | }, 87 | 'warmup_epochs': { 88 | 'values': [max(1, best_params['warmup_epochs'] - 1), 89 | best_params['warmup_epochs'], 90 | min(10, best_params['warmup_epochs'] + 1)] 91 | }, 92 | 'plateau_factor': { 93 | 'values': [max(0.1, best_params['plateau_factor'] - 0.1), 94 | best_params['plateau_factor'], 95 | min(0.8, best_params['plateau_factor'] + 0.1)] 96 | }, 97 | 'plateau_patience': { 98 | 'values': [max(3, best_params['plateau_patience'] - 2), 99 | best_params['plateau_patience'], 100 | min(15, best_params['plateau_patience'] + 2)] 101 | }, 102 | 103 | # Data loading parameters - refined ranges 104 | 'batch_size': { 105 | 'values': [max(8, best_params['batch_size'] - 8), 106 | best_params['batch_size'], 107 | min(50, best_params['batch_size'] + 8)] 108 | }, 109 | 'num_workers': { 110 | 'values': [max(1, best_params['num_workers'] - 2), 111 | best_params['num_workers'], 112 | min(32, best_params['num_workers'] + 2)] 113 | }, 114 | 'prefetch_factor': { 115 | 'values': [max(1, best_params['prefetch_factor'] - 1), 116 | best_params['prefetch_factor'], 117 | min(8, best_params['prefetch_factor'] + 1)] 118 | }, 119 | 120 | # Additional parameters for extended optimization 121 | 'dropout_rate': { 122 | 'values': [0.0, 0.1, 0.2] 123 | }, 124 | 'bce_weight': { 125 | 'values': [0, 0.3, 0.5, 0.7, 1] # Dice weight will be 1 - bce_weight 126 | } 127 | } 128 | } 129 | return sweep_config 130 | 131 | def train_sweep(): 132 | """Training function for extended sweep""" 133 | wandb.init(project=PROJECT_BASE, group=OPTIMIZATION_PHASE) 134 | 135 | try: 136 | # Get hyperparameters from sweep 137 | config = { 138 | 'learning_rate': wandb.config.learning_rate, 139 | 'weight_decay': wandb.config.weight_decay, 140 | 'epochs': 50, # Increased for better convergence 141 | 'checkpoint_freq': 5, 142 | 'save_dir': f'checkpoints/{OPTIMIZATION_PHASE}/{wandb.run.id}', 143 | 'use_wandb': True, 144 | 
'early_stopping_patience': 15, # Increased patience 145 | 'warmup_epochs': wandb.config.warmup_epochs, 146 | 'model': { 147 | 'n_channels': 2, 148 | 'n_classes': 1, 149 | 'bilinear': True, 150 | 'dropout_rate': wandb.config.dropout_rate # Added dropout_rate to model config 151 | }, 152 | 'batch_size': wandb.config.batch_size, 153 | 'bce_weight': wandb.config.bce_weight, 154 | 'dice_weight': 1.0 - wandb.config.bce_weight, # Ensure weights sum to 1.0 155 | 'plateau_factor': wandb.config.plateau_factor, 156 | 'plateau_patience': wandb.config.plateau_patience, 157 | 'num_workers': wandb.config.num_workers, 158 | 'prefetch_factor': wandb.config.prefetch_factor, 159 | 'optimization_phase': OPTIMIZATION_PHASE 160 | } 161 | 162 | # Log system information 163 | wandb.log({ 164 | 'system/cpu_count': os.cpu_count(), 165 | 'system/gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0, 166 | 'system/gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None', 167 | 'system/memory_total': psutil.virtual_memory().total / 1024**3, 168 | 'system/memory_available': psutil.virtual_memory().available / 1024**3, 169 | 'config': dict(config) # Convert to regular dict for JSON serialization 170 | }) 171 | 172 | # Create data directory if it doesn't exist 173 | data_dir = Path('data') 174 | data_dir.mkdir(parents=True, exist_ok=True) 175 | 176 | # Create datasets 177 | train_dataset = FloodSegmentationDataset( 178 | map_file=str(data_dir / 'train.txt'), 179 | base_path=str(data_dir), 180 | split='train' 181 | ) 182 | 183 | val_dataset = FloodSegmentationDataset( 184 | map_file=str(data_dir / 'val.txt'), 185 | base_path=str(data_dir), 186 | split='val' 187 | ) 188 | 189 | # Create dataloaders with sweep parameters 190 | train_loader = DataLoader( 191 | train_dataset, 192 | batch_size=config['batch_size'], 193 | shuffle=True, 194 | num_workers=config['num_workers'], 195 | pin_memory=True, 196 | persistent_workers=True, 197 | 
prefetch_factor=config['prefetch_factor'] 198 | ) 199 | 200 | val_loader = DataLoader( 201 | val_dataset, 202 | batch_size=config['batch_size'], 203 | shuffle=False, 204 | num_workers=config['num_workers'], 205 | pin_memory=True, 206 | persistent_workers=True, 207 | prefetch_factor=config['prefetch_factor'] 208 | ) 209 | 210 | # Initialize model 211 | model = UNet(**config['model']) 212 | 213 | # Train the model 214 | train_model(model, train_loader, val_loader, config) 215 | 216 | # Log final system metrics 217 | wandb.log({ 218 | 'system/final_gpu_memory': get_gpu_memory_usage(), 219 | 'system/final_cpu_percent': psutil.cpu_percent(), 220 | 'system/final_memory_percent': psutil.virtual_memory().percent 221 | }) 222 | 223 | except Exception as e: 224 | wandb.log({ 225 | 'error': str(e), 226 | 'error_type': type(e).__name__, 227 | 'failed_config': dict(wandb.config), # Convert to regular dict for JSON serialization 228 | 'system/final_gpu_memory': get_gpu_memory_usage(), 229 | 'system/final_cpu_percent': psutil.cpu_percent(), 230 | 'system/final_memory_percent': psutil.virtual_memory().percent 231 | }) 232 | raise e 233 | finally: 234 | if wandb.run: 235 | wandb.finish() 236 | 237 | def run_sweep_with_retries(sweep_id, max_retries=3, total_trials=30): 238 | """Run sweep with retry mechanism for failed trials""" 239 | api = wandb.Api() 240 | 241 | successful_trials = 0 242 | failed_trials = 0 243 | 244 | while successful_trials < total_trials: 245 | try: 246 | wandb.agent(sweep_id, function=train_sweep, count=1) 247 | successful_trials += 1 248 | print(f"Successfully completed trial {successful_trials}/{total_trials}") 249 | 250 | except Exception as e: 251 | failed_trials += 1 252 | print(f"Trial {successful_trials + failed_trials} failed with error: {str(e)}") 253 | 254 | if failed_trials >= max_retries: 255 | print(f"Too many consecutive failures ({failed_trials}). Stopping sweep.") 256 | break 257 | 258 | print(f"Retrying... 
(Attempt {failed_trials}/{max_retries})") 259 | continue 260 | sweep = api.sweep(f"{PROJECT_NAME}/{sweep_id}") 261 | sweep.state = "FINISHED" # or "FINISHED" (both stop the sweep) 262 | sweep.update() 263 | print(f"\nSweep completed with {successful_trials} successful trials and {failed_trials} failed trials") 264 | return successful_trials, failed_trials 265 | 266 | def print_optimal_parameters(sweep_id): 267 | """Print optimal parameters from sweep""" 268 | api = wandb.Api() 269 | 270 | time.sleep(5) # Allow for synchronization 271 | 272 | try: 273 | sweep = api.sweep(f"{PROJECT_NAME}/{sweep_id}") 274 | best_run = sweep.best_run() 275 | 276 | print("\nOptimal Parameters Found:") 277 | print("-" * 50) 278 | print(f"Best Validation IoU: {best_run.summary['val_iou']:.4f}") 279 | print("\nLearning Rate Parameters:") 280 | print(f"Learning Rate: {wandb.config.learning_rate:.2e}") 281 | print(f"Weight Decay: {wandb.config.weight_decay:.2e}") 282 | print(f"Warmup Epochs: {wandb.config.warmup_epochs}") 283 | print(f"Plateau Factor: {wandb.config.plateau_factor}") 284 | print(f"Plateau Patience: {wandb.config.plateau_patience}") 285 | print("\nData Loading Parameters:") 286 | print(f"Batch Size: {wandb.config.batch_size}") 287 | print(f"Number of Workers: {wandb.config.num_workers}") 288 | print(f"Prefetch Factor: {wandb.config.prefetch_factor}") 289 | print("\nAdditional Parameters:") 290 | print(f"Dropout Rate: {wandb.config.dropout_rate}") 291 | print(f"BCE Weight: {wandb.config.bce_weight}") 292 | print(f"Dice Weight: {wandb.config.dice_weight}") 293 | print("-" * 50) 294 | except Exception as e: 295 | print("\nCould not retrieve optimal parameters yet.") 296 | print("This is normal if the sweep is still running or results haven't synced.") 297 | print("You can view the results at:") 298 | print(f"https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}") 299 | print(f"\nError details: {str(e)}") 300 | 301 | if __name__ == "__main__": 302 | # Create sweep 
configuration 303 | sweep_config = create_sweep_config() 304 | 305 | # Initialize sweep 306 | sweep_id = wandb.sweep(sweep_config) 307 | 308 | print(f"\nExtended sweep created with ID: {sweep_id}") 309 | print(f"View sweep at: https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}") 310 | 311 | # Run sweep with retry mechanism 312 | successful_trials, failed_trials = run_sweep_with_retries(sweep_id) 313 | 314 | if successful_trials > 0: 315 | print(f"\nExtended sweep completed with {successful_trials} successful trials and {failed_trials} failed trials") 316 | print("Attempting to retrieve optimal parameters...") 317 | print_optimal_parameters(sweep_id) 318 | else: 319 | print("\nNo successful trials completed. Cannot determine optimal parameters.") -------------------------------------------------------------------------------- /FloodsML/train_inference_full.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from pathlib import Path 6 | import wandb 7 | import os 8 | import psutil 9 | import time 10 | import numpy as np 11 | from tqdm import tqdm 12 | from torchmetrics import JaccardIndex 13 | import matplotlib.pyplot as plt 14 | from models.unet import UNet 15 | from flood_dataset import FloodSegmentationDataset 16 | 17 | # Define project structure 18 | PROJECT_BASE = 'flood-segmentation' 19 | EXPERIMENT_NAME = 'single-optimized-run' 20 | PROJECT_NAME = f"{PROJECT_BASE}/{EXPERIMENT_NAME}" 21 | 22 | def get_gpu_memory_usage(): 23 | """Get GPU memory usage if available""" 24 | if torch.cuda.is_available(): 25 | return torch.cuda.memory_allocated() / 1024**3 # Convert to GB 26 | return 0 27 | 28 | def train_model(model, train_loader, val_loader, config, device): 29 | """Train the model with the given configuration""" 30 | model = model.to(device) 31 | 32 | # Setup loss function 33 | bce_weight = 
config['bce_weight'] 34 | dice_weight = config['dice_weight'] 35 | 36 | def combined_loss(pred, target): 37 | # Binary Cross Entropy Loss 38 | bce_loss = nn.BCEWithLogitsLoss()(pred, target) 39 | 40 | # Dice Loss 41 | pred_sigmoid = torch.sigmoid(pred) 42 | smooth = 1.0 43 | intersection = (pred_sigmoid * target).sum() 44 | dice_loss = 1 - (2.0 * intersection + smooth) / (pred_sigmoid.sum() + target.sum() + smooth) 45 | 46 | # Combine losses 47 | return bce_weight * bce_loss + dice_weight * dice_loss 48 | 49 | # Setup optimizer 50 | optimizer = optim.AdamW( 51 | model.parameters(), 52 | lr=config['learning_rate'], 53 | weight_decay=config['weight_decay'] 54 | ) 55 | 56 | # Setup learning rate scheduler with warmup 57 | def lr_lambda(epoch): 58 | if epoch < config['warmup_epochs']: 59 | return epoch / config['warmup_epochs'] 60 | return 1.0 61 | 62 | warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 63 | plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau( 64 | optimizer, 65 | mode='min', 66 | factor=config['plateau_factor'], 67 | patience=config['plateau_patience'], 68 | verbose=True 69 | ) 70 | 71 | # Metrics 72 | jaccard = JaccardIndex(task='binary').to(device) 73 | 74 | # Setup checkpointing 75 | save_dir = Path(config['save_dir']) 76 | save_dir.mkdir(parents=True, exist_ok=True) 77 | 78 | best_val_iou = 0.0 79 | best_epoch = 0 80 | no_improvement_count = 0 81 | 82 | # Initialize WandB if enabled 83 | if config.get('use_wandb', False): 84 | wandb.init(project=PROJECT_BASE, name=EXPERIMENT_NAME) 85 | wandb.config.update(config) 86 | 87 | for epoch in range(config['epochs']): 88 | model.train() 89 | train_loss = 0.0 90 | train_iou = 0.0 91 | 92 | # Training loop 93 | pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}") 94 | for batch in pbar: 95 | images = batch['image'].to(device) 96 | masks = batch['label'].to(device) 97 | 98 | optimizer.zero_grad() 99 | outputs = model(images) 100 | loss = 
combined_loss(outputs, masks) 101 | 102 | loss.backward() 103 | optimizer.step() 104 | 105 | # Calculate IoU 106 | preds = (torch.sigmoid(outputs) > 0.5).float() 107 | batch_iou = jaccard(preds, masks) 108 | 109 | train_loss += loss.item() 110 | train_iou += batch_iou.item() 111 | 112 | pbar.set_postfix({'loss': loss.item(), 'iou': batch_iou.item()}) 113 | 114 | # Update warmup scheduler 115 | if epoch < config['warmup_epochs']: 116 | warmup_scheduler.step() 117 | 118 | # Calculate average metrics 119 | avg_train_loss = train_loss / len(train_loader) 120 | avg_train_iou = train_iou / len(train_loader) 121 | 122 | # Validation loop 123 | model.eval() 124 | val_loss = 0.0 125 | val_iou = 0.0 126 | 127 | with torch.no_grad(): 128 | for batch in tqdm(val_loader, desc="Validation"): 129 | images = batch['image'].to(device) 130 | masks = batch['label'].to(device) 131 | 132 | outputs = model(images) 133 | loss = combined_loss(outputs, masks) 134 | 135 | # Calculate IoU 136 | preds = (torch.sigmoid(outputs) > 0.5).float() 137 | batch_iou = jaccard(preds, masks) 138 | 139 | val_loss += loss.item() 140 | val_iou += batch_iou.item() 141 | 142 | # Calculate average validation metrics 143 | avg_val_loss = val_loss / len(val_loader) 144 | avg_val_iou = val_iou / len(val_loader) 145 | 146 | # Update plateau scheduler 147 | plateau_scheduler.step(avg_val_loss) 148 | 149 | # Log metrics 150 | print(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {avg_train_loss:.4f}, Train IoU: {avg_train_iou:.4f}, Val Loss: {avg_val_loss:.4f}, Val IoU: {avg_val_iou:.4f}") 151 | 152 | if config.get('use_wandb', False): 153 | wandb.log({ 154 | 'epoch': epoch + 1, 155 | 'train_loss': avg_train_loss, 156 | 'train_iou': avg_train_iou, 157 | 'val_loss': avg_val_loss, 158 | 'val_iou': avg_val_iou, 159 | 'learning_rate': optimizer.param_groups[0]['lr'], 160 | 'gpu_memory': get_gpu_memory_usage(), 161 | 'cpu_percent': psutil.cpu_percent(), 162 | 'memory_percent': psutil.virtual_memory().percent 163 | }) 
164 | 165 | # Save checkpoint based on validation IoU 166 | if avg_val_iou > best_val_iou: 167 | best_val_iou = avg_val_iou 168 | best_epoch = epoch + 1 169 | no_improvement_count = 0 170 | 171 | # Save best model 172 | torch.save({ 173 | 'epoch': epoch + 1, 174 | 'model_state_dict': model.state_dict(), 175 | 'optimizer_state_dict': optimizer.state_dict(), 176 | 'train_loss': avg_train_loss, 177 | 'val_loss': avg_val_loss, 178 | 'train_iou': avg_train_iou, 179 | 'val_iou': avg_val_iou, 180 | 'config': config 181 | }, save_dir / 'best_model.pth') 182 | 183 | print(f"Saved best model with validation IoU: {best_val_iou:.4f}") 184 | else: 185 | no_improvement_count += 1 186 | 187 | # Save checkpoint periodically 188 | if (epoch + 1) % config['checkpoint_freq'] == 0: 189 | torch.save({ 190 | 'epoch': epoch + 1, 191 | 'model_state_dict': model.state_dict(), 192 | 'optimizer_state_dict': optimizer.state_dict(), 193 | 'train_loss': avg_train_loss, 194 | 'val_loss': avg_val_loss, 195 | 'train_iou': avg_train_iou, 196 | 'val_iou': avg_val_iou, 197 | 'config': config 198 | }, save_dir / f'checkpoint_epoch_{epoch+1}.pth') 199 | 200 | # Early stopping 201 | if no_improvement_count >= config['early_stopping_patience']: 202 | print(f"Early stopping triggered. No improvement for {no_improvement_count} epochs.") 203 | break 204 | 205 | print(f"Training completed. 
Best validation IoU: {best_val_iou:.4f} at epoch {best_epoch}") 206 | return best_val_iou, best_epoch 207 | 208 | def load_model(checkpoint_path, model, device): 209 | """Load trained model from checkpoint""" 210 | checkpoint = torch.load(checkpoint_path, map_location=device) 211 | model.load_state_dict(checkpoint['model_state_dict']) 212 | return model, checkpoint['epoch'], checkpoint['val_iou'] 213 | 214 | def predict(model, dataloader, device, save_dir=None): 215 | """ 216 | Run inference on the dataloader 217 | 218 | Args: 219 | model: Trained PyTorch model 220 | dataloader: DataLoader containing test/validation data 221 | device: torch device 222 | save_dir: Optional directory to save prediction visualizations 223 | 224 | Returns: 225 | metrics: dict containing evaluation metrics 226 | """ 227 | model.eval() 228 | jaccard_metric = JaccardIndex(task='binary').to(device) 229 | total_iou_metric = 0.0 230 | num_batches = 0 231 | 232 | if save_dir: 233 | save_dir = Path(save_dir) 234 | save_dir.mkdir(parents=True, exist_ok=True) 235 | 236 | with torch.no_grad(): 237 | for idx, batch in enumerate(tqdm(dataloader, desc='Predicting')): 238 | images = batch['image'].to(device) # Shape: (N, C, H, W) 239 | masks = batch['label'].to(device) # Shape: (N, 1, H, W) 240 | image_paths = batch['image_path'] 241 | 242 | # Create a mask for pixels where *input* image is zero 243 | zero_mask = (images == 0).all(dim=1, keepdim=True) # Shape: (N, 1, H, W) 244 | 245 | # Forward pass 246 | outputs = model(images) 247 | preds_sigmoid = torch.sigmoid(outputs) 248 | preds_binary = (preds_sigmoid > 0.5).float() 249 | 250 | valid_pixels_mask = ~zero_mask.expand_as(preds_binary) # Shape: (N, 1, H, W), bool 251 | 252 | # Filter predictions and masks based on valid pixels *before* calculating batch IoU 253 | # Only calculate IoU on batches where there are *any* valid pixels 254 | if valid_pixels_mask.any(): 255 | preds_masked = preds_binary[valid_pixels_mask] 256 | masks_masked = 
masks[valid_pixels_mask] 257 | 258 | # Calculate torchmetrics IoU on the filtered pixels 259 | # Handle cases where masks_masked might be all zeros after filtering 260 | if masks_masked.sum() > 0 or preds_masked.sum() > 0: # Avoid calculating IoU if both are empty 261 | batch_iou_val = jaccard_metric(preds_masked, masks_masked) 262 | total_iou_metric += batch_iou_val.item() 263 | num_batches += 1 # Only count batches where IoU was meaningfully calculated 264 | elif masks_masked.sum() == 0 and preds_masked.sum() == 0: 265 | # If both ground truth and prediction are empty in the valid region, IoU is 1 266 | total_iou_metric += 1.0 267 | num_batches += 1 268 | # Else: one is empty, the other isn't -> IoU is 0, torchmetrics handles this, but we add 0 explicitly if needed 269 | # Or simply rely on torchmetrics default behaviour for these cases. 270 | 271 | # --- Calculate Per-Sample IoU for Filenames & Save Images --- 272 | if save_dir: 273 | # Iterate through samples in the *current* batch 274 | for i in range(images.shape[0]): 275 | # Get individual sample tensors 276 | single_pred = preds_binary[i] # Shape: (1, H, W) 277 | single_mask = masks[i] # Shape: (1, H, W) 278 | single_zero_mask = zero_mask[i] # Shape: (1, H, W) 279 | 280 | # Apply the zero mask to the single sample 281 | single_valid_pixels = ~single_zero_mask.expand_as(single_pred) # Shape: (1, H, W), bool 282 | 283 | sample_iou = 0.0 # Default IoU if no valid pixels or calculation fails 284 | if single_valid_pixels.any(): 285 | pred_masked_single = single_pred[single_valid_pixels] # Flattened valid pixels 286 | mask_masked_single = single_mask[single_valid_pixels] # Flattened valid pixels 287 | 288 | # Calculate IoU for this single sample (ignoring zero pixels) 289 | jaccard_metric.reset() # Reset internal state for single sample calc 290 | try: 291 | # Handle case where both are empty after masking -> IoU=1 292 | if mask_masked_single.sum() == 0 and pred_masked_single.sum() == 0: 293 | sample_iou = 1.0 
294 | else: 295 | sample_iou = jaccard_metric(pred_masked_single, mask_masked_single).item() 296 | except Exception as e: 297 | print(f"Warning: Could not compute IoU for sample {i} in batch {idx}. Error: {e}") 298 | sample_iou = 0.0 # Assign 0 IoU if calculation fails 299 | 300 | # --- Plotting --- 301 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) 302 | 303 | # Original image (assuming 2 channels, need to handle display) 304 | img_display = images[i].cpu().numpy().transpose(1, 2, 0) # H, W, C=2 305 | # Create a 3-channel image for display (e.g., use one channel for R/G, zero for B) 306 | if img_display.shape[2] == 2: 307 | display_rgb = np.zeros((img_display.shape[0], img_display.shape[1], 3), dtype=img_display.dtype) 308 | display_rgb[:, :, 0] = img_display[:, :, 0] # VV channel to Red 309 | display_rgb[:, :, 1] = img_display[:, :, 1] # VH channel to Green 310 | max_val = display_rgb.max() 311 | if max_val > 0: 312 | display_rgb = display_rgb / max_val 313 | ax1.imshow(display_rgb) 314 | else: # Handle other channel numbers if necessary 315 | ax1.imshow(np.squeeze(img_display), cmap='gray') # Fallback for single channel 316 | 317 | ax1.set_title('Original Image') 318 | ax1.axis('off') 319 | 320 | # Ground truth 321 | ax2.imshow(single_mask.cpu().numpy().squeeze(), cmap='gray') 322 | ax2.set_title('Ground Truth') 323 | ax2.axis('off') 324 | 325 | # Prediction 326 | ax3.imshow(single_pred.cpu().numpy().squeeze(), cmap='gray') 327 | ax3.set_title(f'Prediction (IoU: {sample_iou:.3f})') 328 | ax3.axis('off') 329 | 330 | # Generate filename 331 | parts = image_paths[i].replace('\\', '/').split("/") 332 | try: 333 | folder = parts[-2] # e.g., "EMSR756_AOI03_DEL_PRODUCT" 334 | filename_with_ext = parts[-1] # e.g., "tile_006_0_01_Sentinel1_vh.tif" 335 | filename = Path(filename_with_ext).stem # remove extension robustly 336 | new_name = f"{folder}_{filename}" 337 | except IndexError: 338 | new_name = f"batch_{idx}_sample_{i}" # Fallback name 339 | 340 | # 
Save the figure with IoU calculated *ignoring zero pixels* 341 | plt.savefig(save_dir / f'{str(int(sample_iou*100)).zfill(3)}_prediction_{new_name}.png') 342 | plt.close(fig) # Close the figure explicitly 343 | 344 | # Calculate final average IoU 345 | mean_iou = total_iou_metric / num_batches if num_batches > 0 else 0.0 346 | metrics = { 347 | 'mean_iou': mean_iou 348 | } 349 | 350 | return metrics 351 | 352 | def main(): 353 | # Set the most promising parameters from top runs 354 | config = { 355 | 'learning_rate': 0.000008, 356 | 'weight_decay': 0.000236, 357 | 'warmup_epochs': 5, 358 | 'plateau_factor': 0.2, 359 | 'plateau_patience': 9, 360 | 'batch_size': 16, 361 | 'num_workers': 6, 362 | 'prefetch_factor': 2, 363 | 'dropout_rate': 0.0, 364 | 'bce_weight': 0.2, 365 | 'dice_weight': 0.8, 366 | 'epochs': 200, 367 | 'checkpoint_freq': 5, 368 | 'save_dir': f'checkpointsDice/{EXPERIMENT_NAME}', 369 | 'use_wandb': True, 370 | 'early_stopping_patience': 15, 371 | 'model': { 372 | 'n_channels': 2, 373 | 'n_classes': 1, 374 | 'bilinear': True, 375 | 'dropout_rate': 0.0 # Use dropout_rate from config 376 | } 377 | } 378 | 379 | # Initialize WandB 380 | if config['use_wandb']: 381 | wandb.init(project=PROJECT_BASE, name=EXPERIMENT_NAME) 382 | wandb.config.update(config) 383 | 384 | # Log system information 385 | if config['use_wandb']: 386 | wandb.log({ 387 | 'system/cpu_count': os.cpu_count(), 388 | 'system/gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0, 389 | 'system/gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None', 390 | 'system/memory_total': psutil.virtual_memory().total / 1024**3, 391 | 'system/memory_available': psutil.virtual_memory().available / 1024**3 392 | }) 393 | 394 | # Create data directory if it doesn't exist 395 | data_dir = Path('data') 396 | data_dir.mkdir(parents=True, exist_ok=True) 397 | 398 | # Create datasets 399 | train_dataset = FloodSegmentationDataset( 400 | map_file=str(data_dir / 
'full_list.txt'), 401 | base_path=str(data_dir), 402 | split='train' 403 | ) 404 | 405 | val_dataset = FloodSegmentationDataset( 406 | map_file=str(data_dir / 'full_list.txt'), 407 | base_path=str(data_dir), 408 | split='val' 409 | ) 410 | 411 | # Create dataloaders 412 | train_loader = DataLoader( 413 | train_dataset, 414 | batch_size=config['batch_size'], 415 | shuffle=True, 416 | num_workers=config['num_workers'], 417 | pin_memory=True, 418 | persistent_workers=True, 419 | prefetch_factor=config['prefetch_factor'] 420 | ) 421 | 422 | val_loader = DataLoader( 423 | val_dataset, 424 | batch_size=config['batch_size'], 425 | shuffle=False, 426 | num_workers=config['num_workers'], 427 | pin_memory=True, 428 | persistent_workers=True, 429 | prefetch_factor=config['prefetch_factor'] 430 | ) 431 | 432 | # Initialize model 433 | model = UNet(**config['model']) 434 | 435 | # Set device 436 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 437 | print(f"Using device: {device}") 438 | 439 | # Train the model 440 | best_val_iou, best_epoch = train_model(model, train_loader, val_loader, config, device) 441 | print(f"\nTraining completed. 
Best validation IoU: {best_val_iou:.4f} at epoch {best_epoch}") 442 | 443 | # Load the best model for inference 444 | checkpoint_path = Path(config['save_dir']) / 'best_model.pth' 445 | if checkpoint_path.exists(): 446 | print(f"\nRunning inference with best model from {checkpoint_path}") 447 | model, epoch, val_iou = load_model(checkpoint_path, model, device) 448 | model = model.to(device) 449 | print(f"Loaded model from epoch {epoch} with validation IoU: {val_iou:.4f}") 450 | 451 | # Create test dataset and dataloader for inference 452 | test_dataset = FloodSegmentationDataset( 453 | map_file=str(data_dir / 'full_list.txt'), # Using validation set for testing 454 | base_path=str(data_dir), 455 | split='test' 456 | ) 457 | 458 | test_loader = DataLoader( 459 | test_dataset, 460 | batch_size=config['batch_size'], 461 | shuffle=False, 462 | num_workers=config['num_workers'], 463 | pin_memory=True 464 | ) 465 | 466 | # Run inference 467 | save_dir = f'predictionsDice/{EXPERIMENT_NAME}' 468 | metrics = predict(model, test_loader, device, save_dir) 469 | 470 | print(f"Inference completed. 
Mean IoU: {metrics['mean_iou']:.4f}") 471 | 472 | # Log final results to WandB 473 | if config['use_wandb']: 474 | wandb.log({ 475 | 'final_val_iou': best_val_iou, 476 | 'test_iou': metrics['mean_iou'], 477 | 'best_epoch': best_epoch, 478 | 'system/final_gpu_memory': get_gpu_memory_usage(), 479 | 'system/final_cpu_percent': psutil.cpu_percent(), 480 | 'system/final_memory_percent': psutil.virtual_memory().percent 481 | }) 482 | 483 | # Log sample prediction images to WandB 484 | #if Path(save_dir).exists(): 485 | # for image_path in Path(save_dir).glob('prediction_*.png'): 486 | # wandb.log({f"predictions/{image_path.name}": wandb.Image(str(image_path))}) 487 | else: 488 | print(f"Warning: Could not find checkpoint at {checkpoint_path}") 489 | 490 | # Finish WandB run 491 | if config['use_wandb'] and wandb.run: 492 | wandb.finish() 493 | 494 | if __name__ == "__main__": 495 | main() --------------------------------------------------------------------------------