├── dev_floodsml_Dockerfile
├── FloodsML
├── outpit_file.tif.aux.xml
├── install.txt
├── data_generators
│ └── data_generator.py
├── merge_tifs.py
├── README.md
├── utils
│ ├── metrics.py
│ ├── lr_scheduler.py
│ ├── datagen_utils.py
│ └── training_utils.py
├── .gitignore
├── models
│ ├── unet.py
│ └── backbone
│ │ └── xception.py
├── check_dimensions.py
├── flood_dataset.py
├── check_sentinel_values.py
├── train_unet.py
├── inference.py
├── sweep_review.ipynb
├── mg_test.ipynb
├── trainers
│ └── trainer.py
├── 0_analyze_labels.py
├── 0_create_splits.py
├── optimize_parameters.py
├── optimize_parameters_extended.py
└── train_inference_full.py
├── readme.md
├── postprocessing.py
├── stac_generator.py
├── .gitignore
├── HAND
├── compute_flood_map.py
├── mosaic_dem_tools.py
└── compute_hand.py
└── preprocessing.py
/dev_floodsml_Dockerfile:
--------------------------------------------------------------------------------
FROM ghcr.io/osgeo/gdal:ubuntu-full-latest

# Install latest pip: install curl (and the distro pip), remove the distro pip
# again, then bootstrap an up-to-date pip with get-pip.py.
# The apt package lists are deleted in the same layer so they do not bloat the image.
RUN apt-get update \
    && apt-get install -y python3-pip curl \
    && apt-get remove -y python3-pip \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* \
    && curl -sS https://bootstrap.pypa.io/get-pip.py -o get-pip.py \
    && python3 get-pip.py --break-system-packages \
    && rm get-pip.py

# Copy requirements (lands in the image root, because WORKDIR is set later)
COPY requirements.txt .

# Install Python deps.
# --break-system-packages allows installing into the system Python (PEP 668).
# NOTE: --ignore-installed takes no argument and applies to every package;
# "jsonschema" is simply an extra requirement installed alongside requirements.txt.
RUN pip3 install --break-system-packages --no-cache-dir -r requirements.txt --ignore-installed jsonschema

# Set workdir
WORKDIR /app

# Default command
CMD ["/bin/bash"]
--------------------------------------------------------------------------------
/FloodsML/outpit_file.tif.aux.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | -0.25
6 | 1.25
7 | 2
8 | 0
9 | 0
10 | 3181282|65954
11 |
12 |
13 |
14 | 0
15 | 1
16 | 0.020310812025981
17 | 0.14106127371049
18 | 100
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/FloodsML/install.txt:
--------------------------------------------------------------------------------
1 | https://visualstudio.microsoft.com/visual-cpp-build-tools/
2 |
3 | mamba create -n poplave python=3.10 -y && mamba activate poplave
4 | mamba install pytorch torchvision torchaudio torchmetrics tqdm wandb jupyterlab pandas numpy matplotlib seaborn gdal pytorch-cuda=12.4 -c pytorch -c nvidia -y
5 | python -c "import torch; print(torch.cuda.is_available())"
6 |
7 | #transformers
8 | mamba create -n poplave_transformers python=3.12 -y
9 | mamba activate poplave_transformers
10 | mamba install conda-forge::transformers gdal
11 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 torchmetrics
12 |
13 | #uv
14 | uv venv poplave_transformers --python=3.12 && poplave_transformers\Scripts\uv pip install transformers gdal torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 torchmetrics
15 |
16 | uv venv poplave_transformers --python=3.12
17 | poplave_transformers\Scripts\activate
18 | uv pip install transformers
19 | uv pip install gdal rasterio
20 |
21 | cd Documents\poplave && activate poplave_transformers
--------------------------------------------------------------------------------
/FloodsML/data_generators/data_generator.py:
--------------------------------------------------------------------------------
1 | #from data_generators.datasets import cityscapes, coco, combine_dbs, pascal, sbd, deepfashion
2 | from torch.utils.data import DataLoader
3 | from data_generators.deepfashion import DeepFashionSegmentation
4 |
5 |
def initialize_data_loader(config):
    """Build train/val/test data loaders for the dataset named in ``config``.

    Args:
        config (dict): Must contain ``config['dataset']['dataset_name']``; for a
            supported dataset also ``config['training']['batch_size']`` and
            ``config['training']['workers']``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, num_classes)``, where
        ``num_classes`` is read from the training split.

    Raises:
        NotImplementedError: if the dataset name is not supported. It subclasses
            ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    dataset_name = config['dataset']['dataset_name']

    if dataset_name == 'deepfashion':
        train_set = DeepFashionSegmentation(config, split='train')
        val_set = DeepFashionSegmentation(config, split='val')
        test_set = DeepFashionSegmentation(config, split='test')
    else:
        raise NotImplementedError(f'dataset {dataset_name!r} not implemented yet!')

    training = config['training']

    def _make_loader(dataset, shuffle):
        # All splits share batch size / worker count; only training is shuffled.
        return DataLoader(dataset, batch_size=training['batch_size'], shuffle=shuffle,
                          num_workers=training['workers'], pin_memory=True)

    num_classes = train_set.num_classes
    train_loader = _make_loader(train_set, shuffle=True)
    val_loader = _make_loader(val_set, shuffle=False)
    test_loader = _make_loader(test_set, shuffle=False)

    return train_loader, val_loader, test_loader, num_classes
23 |
24 |
--------------------------------------------------------------------------------
/FloodsML/merge_tifs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import argparse
4 | from osgeo import gdal
5 | gdal.UseExceptions()
6 |
def merge_tifs(input_folder, output_tif):
    """Mosaic every ``.tif`` file in ``input_folder`` into a single GeoTIFF.

    Inputs are passed to ``gdal.Warp`` in sorted filename order so the result is
    reproducible (gdalwarp writes sources in list order, so later files overwrite
    earlier ones where they overlap; ``os.listdir`` order is arbitrary).
    Sub-directories and the output file itself (when it already exists inside
    ``input_folder``, e.g. from a previous run) are never used as inputs.

    Args:
        input_folder (str): Folder scanned (non-recursively) for ``*.tif`` files.
        output_tif (str): Path of the merged GeoTIFF to create.

    Returns:
        str | None: ``output_tif`` when a raster was written, ``None`` when no
        input files were found.
    """
    output_key = os.path.normcase(os.path.abspath(output_tif))

    # collect all .tif files from the input folder, in a deterministic order
    tif_files = []
    for name in sorted(os.listdir(input_folder)):
        path = os.path.join(input_folder, name)
        if not name.lower().endswith(".tif") or not os.path.isfile(path):
            continue
        # Do not feed a previous output back in as an input to its own replacement.
        if os.path.normcase(os.path.abspath(path)) == output_key:
            continue
        tif_files.append(path)

    if not tif_files:
        print(f"⚠️ No .tif files found in folder: {input_folder}")
        return None

    # merge into a single GeoTIFF
    gdal.Warp(output_tif, tif_files, format="GTiff")
    print(f"✅ Merged raster created: {output_tif}")
    return output_tif
22 |
def _build_arg_parser():
    """Create the command-line parser (``-i/--input`` and ``-o/--output``, both required)."""
    parser = argparse.ArgumentParser(
        description="Merge all GeoTIFF files from a folder into a single raster."
    )
    option_specs = (
        (("-i", "--input"), "Input folder containing GeoTIFF files"),
        (("-o", "--output"), "Output GeoTIFF file (e.g. merged.tif)"),
    )
    for flags, help_text in option_specs:
        parser.add_argument(*flags, required=True, help=help_text)
    return parser


def main():
    """Command-line entry point: parse the arguments and run ``merge_tifs``."""
    cli_args = _build_arg_parser().parse_args()
    merge_tifs(cli_args.input, cli_args.output)


if __name__ == "__main__":
    main()
43 |
--------------------------------------------------------------------------------
/FloodsML/README.md:
--------------------------------------------------------------------------------
1 | # Flood area mapping using deep learning
2 |
3 | ## Virtual environment
4 | First you need to create a virtual environment.
5 |
6 | Using Conda you can type:
7 |
8 | ```
9 | mamba create -n floods python=3.10 -y
10 | mamba activate floods
11 | mamba install pytorch torchvision torchaudio torchmetrics tqdm wandb psutil jupyterlab pandas numpy matplotlib seaborn gdal pytorch-cuda=12.4 -c pytorch -c nvidia -y
12 | ```
13 |
14 | Check if cuda is available
15 | ```
16 | python -c "import torch; print(torch.cuda.is_available())"
17 | ```
18 |
19 | ## Files
20 |
21 | Files starting with 0_* were used to analyse the dataset and to create splits containing flood pixels. Files starting with check_* were used to verify that the expected values were present and that the dataset was correct.
22 |
23 | ## Training
24 |
25 | Files starting with optimize_* were used to determine the optimal training parameters for the selected model and dataset, while train_unet.py was used to train a single model. The sweep_review.ipynb notebook provides insights into the top-performing models.
26 |
27 | ## Inference
28 |
29 | Model inference can be run with the inference.py script by providing the input folder, the output folder and the model path. Optionally, the --device flag can be added to select the device manually.
30 | ```
31 | python inference.py /path/to/input_tiffs /path/to/output_predictions /path/to/your/best_model.pth
32 | ```
33 |
--------------------------------------------------------------------------------
/FloodsML/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
class Evaluator(object):
    """Accumulate a confusion matrix over batches and derive segmentation metrics.

    ``confusion_matrix[i, j]`` counts pixels whose ground-truth class is ``i`` and
    whose predicted class is ``j``. Ground-truth values outside ``[0, num_class)``
    (e.g. an ignore label such as 255) are skipped. Metrics that divide by zero
    (empty matrix, classes never seen) yield NaN without emitting RuntimeWarnings.
    """

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = self._empty_matrix()

    def _empty_matrix(self):
        """Return a zeroed ``(num_class, num_class)`` float64 matrix."""
        return np.zeros((self.num_class,) * 2)

    def _intersection_over_union(self):
        """Per-class IoU = TP / (TP + FP + FN); NaN for classes that never occur."""
        cm = self.confusion_matrix
        tp = np.diag(cm)
        with np.errstate(divide='ignore', invalid='ignore'):
            return tp / (cm.sum(axis=1) + cm.sum(axis=0) - tp)

    def Pixel_Accuracy(self):
        """Fraction of counted pixels that were classified correctly."""
        cm = self.confusion_matrix
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.diag(cm).sum() / cm.sum()

    def Pixel_Accuracy_Class(self):
        """Mean per-class recall, ignoring classes absent from the ground truth."""
        cm = self.confusion_matrix
        with np.errstate(divide='ignore', invalid='ignore'):
            per_class = np.diag(cm) / cm.sum(axis=1)
        return np.nanmean(per_class)

    def Mean_Intersection_over_Union(self):
        """Mean IoU over classes, ignoring NaN (never-seen) classes."""
        return np.nanmean(self._intersection_over_union())

    def Frequency_Weighted_Intersection_over_Union(self):
        """IoU averaged with weights equal to each class's ground-truth frequency."""
        cm = self.confusion_matrix
        with np.errstate(divide='ignore', invalid='ignore'):
            freq = cm.sum(axis=1) / cm.sum()
        iu = self._intersection_over_union()
        seen = freq > 0
        return (freq[seen] * iu[seen]).sum()

    def _generate_matrix(self, gt_image, pre_image):
        """Return the confusion matrix of one (gt, prediction) pair.

        Raises:
            ValueError: if a prediction at a valid ground-truth pixel lies outside
                ``[0, num_class)``; such values would otherwise be counted in
                the wrong matrix cell.
        """
        gt_image = np.asarray(gt_image)
        pre_image = np.asarray(pre_image)
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        gt = gt_image[mask].astype(np.int64)
        # Cast predictions too: np.bincount rejects float input.
        pred = pre_image[mask].astype(np.int64)
        if pred.size and (pred.min() < 0 or pred.max() >= self.num_class):
            raise ValueError(
                f"predictions must lie in [0, {self.num_class}); "
                f"got values in [{pred.min()}, {pred.max()}]")
        # Encode each (gt, pred) pair as a single index into the flattened matrix.
        label = self.num_class * gt + pred
        count = np.bincount(label, minlength=self.num_class ** 2)
        return count.reshape(self.num_class, self.num_class)

    def add_batch(self, gt_image, pre_image):
        """Add one batch of ground truth and predictions (same shape) to the matrix."""
        assert np.shape(gt_image) == np.shape(pre_image)
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.confusion_matrix = self._empty_matrix()
47 |
--------------------------------------------------------------------------------
/FloodsML/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | import math
12 |
class LR_Scheduler(object):
    """Learning-rate scheduler applied once per training iteration.

    With ``T = epoch * iters_per_epoch + i`` and ``N = num_epochs * iters_per_epoch``:

    Step mode:   ``lr = base_lr * 0.1 ** (epoch // lr_step)``

    Cosine mode: ``lr = base_lr * 0.5 * (1 + cos(pi * T / N))``

    Poly mode:   ``lr = base_lr * (1 - T / N) ** 0.9``, clamped to 0 once ``T >= N``

    During the first ``warmup_epochs * iters_per_epoch`` iterations the lr is
    additionally scaled linearly by ``T / warmup_iters``.

    Args:
        mode (str): ``'cos'``, ``'poly'`` or ``'step'``.
        base_lr (float): Base learning rate.
        num_epochs (int): Total number of epochs (defines ``N``).
        iters_per_epoch (int): Iterations per epoch. ``'cos'`` and ``'poly'``
            divide by ``N``, so this must be > 0 for them.
        lr_step (int): Epochs between 10x decays; required for ``'step'``.
        warmup_epochs (int): Number of linear warm-up epochs.
    """
    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
                 lr_step=0, warmup_epochs=0):
        self.mode = mode
        print('Using {} LR Scheduler!'.format(self.mode))
        self.lr = base_lr
        if mode == 'step':
            assert lr_step
        self.lr_step = lr_step
        self.iters_per_epoch = iters_per_epoch
        self.N = num_epochs * iters_per_epoch
        self.epoch = -1  # last epoch for which the lr was logged
        self.warmup_iters = warmup_epochs * iters_per_epoch

    def __call__(self, optimizer, i, epoch, best_pred):
        """Compute the lr for iteration ``i`` of ``epoch`` and apply it to ``optimizer``.

        ``best_pred`` is only used in the once-per-epoch log line.

        Raises:
            NotImplementedError: if ``mode`` is not one of the supported modes.
        """
        T = epoch * self.iters_per_epoch + i
        if self.mode == 'cos':
            lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
        elif self.mode == 'poly':
            # Clamp at 0: past the scheduled end (T > N) the base is negative,
            # and a negative float raised to 0.9 is a complex number.
            lr = self.lr * pow(max(0.0, 1 - 1.0 * T / self.N), 0.9)
        elif self.mode == 'step':
            lr = self.lr * (0.1 ** (epoch // self.lr_step))
        else:
            raise NotImplementedError(
                'Unknown LR scheduler mode: {}'.format(self.mode))
        # warm up lr schedule
        if self.warmup_iters > 0 and T < self.warmup_iters:
            lr = lr * 1.0 * T / self.warmup_iters
        if epoch > self.epoch:
            print('\n=>Epoches %i, learning rate = %.4f, '
                  'previous best = %.4f' % (epoch, lr, best_pred))
            self.epoch = epoch
        assert lr >= 0
        self._adjust_learning_rate(optimizer, lr)

    def _adjust_learning_rate(self, optimizer, lr):
        """Set ``lr`` on the first param group and ``10 * lr`` on every other group.

        The larger lr for later groups "enlarges the lr at the head"; presumably
        group 0 is the backbone and later groups a freshly initialised head —
        confirm against the trainer that builds the optimizer.
        """
        groups = optimizer.param_groups
        groups[0]['lr'] = lr
        for group in groups[1:]:
            group['lr'] = lr * 10
71 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # FloodsML – Flood Detection with Deep Learning
2 |
3 | FloodsML is a deep learning–based workflow for flood extent detection using Sentinel-1 SAR imagery and DEM data. The pipeline is packaged in Docker for reproducibility and portability, and it integrates seamlessly with STAC-compliant datasets.
4 |
5 | ---
6 |
7 | ## 🚀 Setup
8 |
9 | ### 1. Build Docker image
10 | From the project root directory, build the Docker image:
11 |
12 | ```bash
13 | docker build -t dev_floodsml -f dev_floodsml_Dockerfile .
14 | ```
15 |
16 | ### 2. Run container with mounted input data
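17 | Prepare an input folder on your local machine (e.g., `C:\Users\%USERNAME%\Desktop\Dev FloodsML\input_src`) and mount it into the container. The commands below use Windows `cmd` syntax (`^` continues a line); in a Unix shell use `\` instead: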
18 |
19 | ```bash
20 | docker run -it --rm --name floods_demo_container ^
21 | -v "C:\Users\%USERNAME%\Desktop\Dev FloodsML\input_src:/app/input_src" ^
22 | dev_floodsml
23 | ```
24 |
25 | ---
26 |
27 | ## 📥 Input Data
28 |
29 | The mounted folder **must contain**:
30 | - **Sentinel-1 before-flood acquisitions** (VV and VH)
31 | - **Sentinel-1 after-flood acquisitions** (VV and VH)
32 | - **DEM file(s)** (one or more, named `copdem.tif`)
33 |
34 | > ⚠️ All input data must follow the **STAC item structure** so that metadata, georeferencing, and provenance are preserved.
35 |
36 | ---
37 |
38 | ## ▶️ Running the model
39 |
40 | Inside the container, run:
41 |
42 | ```bash
43 | python /app/env/main.py \
44 | --aoi_wkt "POLYGON((...))" \
45 | [--hand] \
46 | [--buffer 5000] \
47 | [--treshold 5]
48 | ```
49 |
50 | ### Parameters
51 | - `--aoi_wkt` (required): AOI polygon in WGS84 (EPSG:4326). Must be of type **POLYGON**.
52 | - `--hand` (optional): Apply HAND-based filtering in post-processing.
53 | - `--buffer` (optional, default: `5000`): Buffer distance for HAND calculation.
54 | - `--treshold` (optional, default: `5`): Threshold parameter for HAND filtering.
55 |
56 | ---
57 |
58 | ## 📤 Outputs
59 |
60 | Results are published to:
61 |
62 | ```
63 | /app/tmp/results
64 | ```
65 |
66 | If you want to make the results visible on the host system, map the results directory when starting the container, e.g.:
67 |
68 | ```bash
69 | docker run -it --rm --name floods_demo_container ^
70 | -v "C:\Users\%USERNAME%\Desktop\Dev FloodsML\input_src:/app/input_src" ^
71 | -v "C:\Users\%USERNAME%\Desktop\Dev FloodsML\results:/app/tmp/results" ^
72 | dev_floodsml
73 | ```
74 |
75 | The outputs are:
76 | - **Cloud-Optimized GeoTIFF (COG) masks of flooded areas**
77 | - Wrapped as **STAC items** for interoperability
78 |
79 | The masks contain three classes:
80 | - `0` – non-water
81 | - `1` – permanent hydrography
82 | - `2` – flooded areas
83 |
84 | This structure ensures compatibility with catalog-based workflows and easy integration with other disaster mapping products.
85 |
86 | ---
87 |
88 | ## 🔄 Workflow Overview
89 |
90 | **STAC input items → Preprocessing & Inference → STAC output items (COG masks)**
91 |
92 | The workflow ensures that both input and output remain fully standardized, enabling transparent use in broader EO platforms and services.
93 |
94 | ---
95 |
96 | 📌 Ready-to-use, portable, and interoperable — FloodsML can be deployed wherever flood mapping support is required.
97 | "# floodsdl-mcube"
98 |
--------------------------------------------------------------------------------
/FloodsML/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | *.wandb
162 | *.json
163 | *.yaml
164 | *.pth
165 | *.tif
166 | /wandb
167 | *.csv
168 | /data
169 | /.vs
170 | /predictions
171 | /predictionsDice
172 |
--------------------------------------------------------------------------------
/postprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import rasterio
4 | from rasterio.warp import reproject, Resampling
5 | from scipy.ndimage import median_filter
6 | from skimage.morphology import (remove_small_objects, remove_small_holes, disk, opening, closing)
7 | import numpy as np
8 |
def process_raster(in_path, out_path=None,
                   min_obj_size=4, min_hole_size=4, min_region_size=20):
    """
    Load a raster, turn it into a cleaned binary (0/1) mask and save it as GeoTIFF.

    Pixels of band 1 with a value > 0 (and not equal to the band's nodata value)
    become 1; everything else becomes 0. The mask is then cleaned by:
      1. removing objects smaller than ``min_obj_size`` pixels,
      2. filling holes smaller than ``min_hole_size`` pixels,
      3. a 2x2 median filter,
      4. a morphological opening followed by a closing (radius-1 disk),
      5. removing regions smaller than ``min_region_size`` pixels.

    :param in_path: path to the input raster (band 1 is used)
    :param out_path: output path; defaults to ``<in_path without extension>_clean.tif``
    :param min_obj_size: minimum object size (pixels) kept in step 1
    :param min_hole_size: holes smaller than this (pixels) are filled in step 2
    :param min_region_size: minimum region size (pixels) kept in step 5
    :return: path to the output raster
    """
    # Read band 1 in its native dtype. Casting to uint8 *before* thresholding
    # truncates fractional values (e.g. 0.7 -> 0) and wraps negative values.
    with rasterio.open(in_path) as src:
        values = src.read(1)
        nodata = src.nodata
        profile = src.profile

    # Convert to binary mask (0/1); nodata pixels are treated as background.
    mask = values > 0
    if nodata is not None and not np.isnan(nodata):
        mask &= values != nodata

    # Apply cleaning steps
    cleaned = remove_small_objects(mask, min_size=min_obj_size)
    cleaned = remove_small_holes(cleaned, area_threshold=min_hole_size)
    # NOTE: an even-sized (2x2) median window is not centred, so this step can
    # shift edges by a pixel; kept as-is to preserve existing outputs.
    cleaned = median_filter(cleaned.astype(np.uint8), size=2)
    cleaned = closing(opening(cleaned, disk(1)), disk(1))
    cleaned = remove_small_objects(cleaned.astype(bool), min_size=min_region_size)

    cleaned = cleaned.astype(np.uint8)

    # Set default output path if not provided
    if out_path is None:
        out_path = os.path.splitext(in_path)[0] + "_clean.tif"

    # The input's nodata (e.g. -9999 or NaN from a float raster) may not be
    # representable as uint8; rasterio refuses a nodata value outside the output
    # dtype's range, so drop it in that case.
    out_nodata = profile.get("nodata")
    if out_nodata is not None and not (
            float(out_nodata).is_integer() and 0 <= out_nodata <= 255):
        out_nodata = None

    # Save cleaned raster
    profile.update(dtype=rasterio.uint8, count=1, compress='lzw', nodata=out_nodata)
    with rasterio.open(out_path, 'w', **profile) as dst:
        dst.write(cleaned, 1)

    # Notify user
    print(f"✅ Cleaned raster successfully exported: {out_path}")

    return out_path
45 |
46 |
47 |
def classify_flood(before_path, after_path, output_path):
    """
    Create a 0/1/2 classification layer (aligned to AFTER raster grid):
        0 = no change / no water
        1 = water present before event (BEFORE > 0)
        2 = new flooded area (AFTER > 0 where BEFORE == 0)

    The BEFORE mask is resampled (nearest neighbour) onto the AFTER grid whenever
    the two rasters differ in size, transform or CRS. After resampling, BEFORE
    pixels outside its footprint or marked as nodata count as "no water" (0).

    :param before_path: path to 'before' mask (.tif)
    :param after_path: path to 'after' mask (.tif)
    :param output_path: path to save output raster
    :return: output_path
    """

    with rasterio.open(after_path) as src_after, rasterio.open(before_path) as src_before:
        # Read AFTER array (this is the reference grid)
        arr_after = src_after.read(1).astype(int)

        # Equal size alone is not enough: a shifted transform or different CRS
        # would otherwise be compared pixel-by-pixel without alignment.
        same_grid = (
            src_before.width == src_after.width
            and src_before.height == src_after.height
            and src_before.transform == src_after.transform
            and src_before.crs == src_after.crs
        )

        if same_grid:
            # Grids match: use the original BEFORE array directly.
            arr_before_aligned = src_before.read(1).astype(int)
        else:
            print("⚠️ BEFORE and AFTER rasters are on different grids. Reprojecting BEFORE → AFTER grid.")

            # Zero-initialise and use dst_nodata=0: without a nodata value,
            # reproject may leave uncovered destination pixels untouched, and
            # BEFORE nodata must not read as water.
            arr_before_aligned = np.zeros_like(arr_after, dtype=np.int32)

            reproject(
                source=rasterio.band(src_before, 1),
                destination=arr_before_aligned,
                src_transform=src_before.transform,
                src_crs=src_before.crs,
                dst_transform=src_after.transform,
                dst_crs=src_after.crs,
                dst_nodata=0,
                resampling=Resampling.nearest
            )

        # Initialize classification with 0 (no change)
        classification = np.zeros_like(arr_after, dtype=np.uint8)

        # Areas that were already water before
        classification[arr_before_aligned > 0] = 1

        # New flooded areas (after but not before)
        classification[(arr_before_aligned == 0) & (arr_after > 0)] = 2

        # Prepare metadata (use AFTER as reference)
        out_meta = src_after.meta.copy()
        out_meta.update(dtype=rasterio.uint8, count=1)
        # AFTER's nodata (e.g. -9999 or NaN) may not fit uint8; drop it in that
        # case instead of failing when the output is opened for writing.
        out_nodata = out_meta.get("nodata")
        if out_nodata is not None and not (
                float(out_nodata).is_integer() and 0 <= out_nodata <= 255):
            out_meta["nodata"] = None

        # Write output raster
        with rasterio.open(output_path, "w", **out_meta) as dst:
            dst.write(classification, 1)

    print(f"✅ Flood classification saved to {output_path}")
    return output_path
--------------------------------------------------------------------------------
/stac_generator.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datetime import datetime
3 | import rasterio
4 | import os
5 |
def build_stac_item(
    tif_path: str,
    out_path: str,
    title: str,
    description: str,
):
    """
    Build a STAC Item JSON for a flood mask raster and save it to ``out_path``.

    ``bbox`` and ``geometry`` are written in WGS84 longitude/latitude, as the
    STAC Item specification requires. A raster in any other CRS has its bounds
    reprojected to EPSG:4326, and its native EPSG code (if any) is still
    recorded as ``proj:epsg``. A raster without a CRS keeps its native bounds.

    Args:
        tif_path (str): Path to input GeoTIFF (COG); also used verbatim as the asset href.
        out_path (str): Path to save STAC JSON.
        title (str): Human-readable title.
        description (str): Description of the dataset.

    Returns:
        dict: The STAC item that was written.
    """
    # Local imports keep this function self-contained; both modules are
    # already dependencies of this file.
    from datetime import timezone
    from rasterio.warp import transform_bounds

    with rasterio.open(tif_path) as src:
        left, bottom, right, top = src.bounds

        # CRS
        proj_epsg = src.crs.to_epsg() if src.crs else None

        # bounding box + geometry must be WGS84 lon/lat; densify the edges so
        # the reprojected box still covers edges that curve in lon/lat.
        if src.crs and proj_epsg != 4326:
            left, bottom, right, top = transform_bounds(
                src.crs, "EPSG:4326", left, bottom, right, top, densify_pts=21
            )

        bbox = [left, bottom, right, top]
        geometry = {
            "type": "Polygon",
            "coordinates": [[
                [left, bottom],
                [right, bottom],
                [right, top],
                [left, top],
                [left, bottom]
            ]]
        }

        # spatial resolution in the raster's CRS units (assume square pixels)
        res_x, _res_y = src.res
        spatial_resolution = round(abs(res_x), 2)

    # timestamps: one timezone-aware UTC instant for both `created` and the id
    now = datetime.now(timezone.utc)
    created = now.isoformat().replace("+00:00", "Z")
    base = os.path.splitext(os.path.basename(tif_path))[0]
    item_id = f"{base}_{now.strftime('%Y%m%d%H%M%S')}"

    # STAC file extension: `file:size` is an integer number of bytes
    file_size_bytes = os.path.getsize(tif_path)

    # classification legend
    classes = [
        {
            "name": "Permanent Water",
            "title": "Permanent Water",
            "value": 1,
            "color_hint": "#1f78b499",
            "description": "Areas consistently covered with water."
        },
        {
            "name": "Flooded Areas",
            "title": "Flooded Areas",
            "value": 2,
            "color_hint": "#e31a1c99",
            "description": "Detected flooded areas."
        }
    ]

    stac_item = {
        "type": "Feature",
        "stac_version": "1.0.0",
        "id": item_id,
        "geometry": geometry,
        "bbox": bbox,
        "properties": {
            "title": title,
            "description": description,
            "created": created,
            "datetime": created,  # same as created
            "renders": {
                "overview": {
                    "title": title,
                    "assets": ["flood"],
                    "resampling": "nearest",
                    "colormap": {
                        "1": "#1f78b499",
                        "2": "#e31a1c99"
                    },
                    "nodata": 0
                }
            }
        },
        "assets": {
            "flood": {
                "href": tif_path,
                "type": "image/tiff; application=geotiff",
                "title": f"{title} raster",
                "raster:bands": [
                    {
                        "data_type": "uint8",
                        "spatial_resolution": spatial_resolution,
                        "scale": 1,
                        "offset": 0
                    }
                ],
                "classification:classes": classes,
                "roles": ["data", "visual"],
                "file:size": file_size_bytes
            }
        }
    }

    # add proj:epsg if available
    if proj_epsg:
        stac_item["properties"]["proj:epsg"] = proj_epsg

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(stac_item, f, indent=2)

    print(f"✅ STAC item saved to {out_path}")
    return stac_item
124 |
--------------------------------------------------------------------------------
/FloodsML/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm2d, ReLU and Dropout2d.

    Args:
        in_channels (int): Channels entering the first convolution.
        out_channels (int): Channels produced by the second convolution.
        mid_channels (int, optional): Channels between the two convolutions;
            any falsy value (``None`` or ``0``) means ``out_channels``.
        dropout_rate (float): Probability used by both ``Dropout2d`` layers.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, dropout_rate=0.0):
        super().__init__()
        hidden = mid_channels or out_channels

        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_rate),
            ]
        # The attribute name and layer order form the state_dict keys
        # ("double_conv.0.weight", ...), so saved checkpoints depend on them.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
23 |
class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels, dropout_rate=dropout_rate)
        # Attribute name is part of the state_dict keys; keep it stable.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)
34 |
class Up(nn.Module):
    """Decoder step: upsample ``x1``, pad it to ``x2``'s size, concatenate, DoubleConv.

    Args:
        in_channels (int): Channels of the concatenated tensor fed to DoubleConv
            (with ``bilinear=False`` also the input channels of the transposed conv).
        out_channels (int): Output channels.
        bilinear (bool): Weight-free bilinear upsampling instead of a stride-2
            transposed convolution.
        dropout_rate (float): Passed to DoubleConv.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, dropout_rate=0.0):
        super().__init__()
        half = in_channels // 2

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            mid = half
        else:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            mid = None  # DoubleConv then uses out_channels as its middle width

        self.conv = DoubleConv(in_channels, out_channels, mid, dropout_rate=dropout_rate)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad (or crop, when negative) x1 to the skip tensor's spatial size;
        # needed when input sizes aren't perfectly divisible by 2.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        padding = [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2]
        x1 = F.pad(x1, padding)
        return self.conv(torch.cat([x2, x1], dim=1))
56 |
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits
64 |
class UNet(nn.Module):
    """Classic four-level U-Net encoder/decoder with skip connections.

    Args:
        n_channels (int): Number of input channels (3 for RGB)
        n_classes (int): Number of output classes (1 for binary segmentation)
        bilinear (bool): Use bilinear upsampling instead of transposed convolutions
        dropout_rate (float): Dropout rate for regularization (0.0 to 1.0)
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=True, dropout_rate=0.0):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.dropout_rate = dropout_rate

        # Bilinear upsampling keeps channel counts, so the bottleneck and the
        # decoder outputs are halved to keep the concatenations consistent.
        factor = 2 if bilinear else 1
        kw = dict(dropout_rate=dropout_rate)

        # Encoder (registration order fixes state_dict order and init order)
        self.inc = DoubleConv(n_channels, 64, **kw)
        self.down1 = Down(64, 128, **kw)
        self.down2 = Down(128, 256, **kw)
        self.down3 = Down(256, 512, **kw)
        self.down4 = Down(512, 1024 // factor, **kw)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear, **kw)
        self.up2 = Up(512, 256 // factor, bilinear, **kw)
        self.up3 = Up(256, 128 // factor, bilinear, **kw)
        self.up4 = Up(128, 64, bilinear, **kw)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level's output as a skip connection.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        x = self.down4(skips[-1])

        # Decoder: pair up1..up4 with the skips from deepest to shallowest.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            x = up(x, skip)

        return self.outc(x)
--------------------------------------------------------------------------------
/FloodsML/check_dimensions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from osgeo import gdal
4 | from tqdm import tqdm
5 | from collections import defaultdict
6 |
def check_dimensions(data_dir='data', expected_size=512):
    """Check label rasters for size mismatches and report their value ranges.

    Scans ``<data_dir>/512_labels/<EMSR folder>/*_FloodMask.tif``, reads band 1
    of each file with GDAL, and prints the number of files that are / are not
    ``expected_size`` x ``expected_size`` pixels, the global min/max and set of
    unique label values, and any open/read errors.

    Args:
        data_dir: dataset root that contains the ``512_labels`` directory.
        expected_size: required width and height in pixels (default 512).

    Returns:
        None; results are only printed.
    """
    data_dir = Path(data_dir)
    labels_dir = data_dir / '512_labels'

    if not labels_dir.exists():
        print(f"Error: Labels directory not found at {labels_dir}")
        return

    stats = {
        'total_files': 0,
        'correct_size': 0,
        'incorrect_size': 0,
        'errors': [],
        'incorrect_files': defaultdict(list),  # EMSR folder name -> list of mismatching files
        'value_ranges': {
            'min': float('inf'),
            'max': float('-inf'),
            'unique_values': set(),
        },
    }

    emsr_folders = [f for f in labels_dir.iterdir() if f.is_dir()]

    for emsr_folder in tqdm(emsr_folders, desc="Processing EMSR folders"):
        flood_files = list(emsr_folder.glob('*_FloodMask.tif'))
        if not flood_files:
            stats['errors'].append(f"No FloodMask files found in {emsr_folder.name}")
            continue

        for flood_file in flood_files:
            stats['total_files'] += 1
            ds = None
            try:
                ds = gdal.Open(str(flood_file))
                if ds is None:
                    stats['errors'].append(f"Could not open {flood_file}")
                    continue

                width = ds.RasterXSize
                height = ds.RasterYSize
                data = ds.GetRasterBand(1).ReadAsArray()

                ranges = stats['value_ranges']
                ranges['min'] = min(ranges['min'], np.min(data))
                ranges['max'] = max(ranges['max'], np.max(data))
                ranges['unique_values'].update(np.unique(data))

                if width == expected_size and height == expected_size:
                    stats['correct_size'] += 1
                else:
                    stats['incorrect_size'] += 1
                    stats['incorrect_files'][emsr_folder.name].append({
                        'file': flood_file.name,
                        'width': width,
                        'height': height,
                    })

            except Exception as e:
                stats['errors'].append(f"Error reading {flood_file}: {str(e)}")
            finally:
                # Drop the GDAL handle on every path, including read errors.
                ds = None

    print("\n=== File Dimension Check Summary ===")
    print(f"Total files checked: {stats['total_files']}")
    print(f"Files with correct dimensions ({expected_size}x{expected_size}): {stats['correct_size']}")
    print(f"Files with incorrect dimensions: {stats['incorrect_size']}")

    print("\n=== Value Range Summary ===")
    print(f"Minimum value: {stats['value_ranges']['min']}")
    print(f"Maximum value: {stats['value_ranges']['max']}")
    print(f"Unique values found: {sorted(stats['value_ranges']['unique_values'])}")

    if stats['incorrect_size'] > 0:
        print("\n=== Files with Incorrect Dimensions ===")
        for emsr, files in stats['incorrect_files'].items():
            print(f"\nEMSR: {emsr}")
            for file_info in files:
                print(f" - {file_info['file']}: {file_info['width']}x{file_info['height']}")

    if stats['errors']:
        print("\n=== Errors and Warnings ===")
        for error in stats['errors']:
            print(f"- {error}")

    if stats['total_files'] == 0:
        print("\nNo files found! Please check:")
        print("1. The directory structure is correct")
        print("2. The files are valid GeoTIFFs")


if __name__ == "__main__":
    check_dimensions()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/python,git
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,git
3 |
4 | ### Git ###
5 | # Created by git for backups. To disable backups in Git:
6 | # $ git config --global mergetool.keepBackup false
7 | *.orig
8 |
9 | # Created by git when using merge tools for conflicts
10 | *.BACKUP.*
11 | *.BASE.*
12 | *.LOCAL.*
13 | *.REMOTE.*
14 | *_BACKUP_*.txt
15 | *_BASE_*.txt
16 | *_LOCAL_*.txt
17 | *_REMOTE_*.txt
18 |
19 | ### Python ###
20 | # Byte-compiled / optimized / DLL files
21 | __pycache__/
22 | *.py[cod]
23 | *$py.class
24 |
25 | # C extensions
26 | *.so
27 |
28 | # Distribution / packaging
29 | .Python
30 | build/
31 | develop-eggs/
32 | dist/
33 | downloads/
34 | eggs/
35 | .eggs/
36 | lib/
37 | lib64/
38 | parts/
39 | sdist/
40 | var/
41 | wheels/
42 | share/python-wheels/
43 | *.egg-info/
44 | .installed.cfg
45 | *.egg
46 | MANIFEST
47 |
48 | # PyInstaller
49 | # Usually these files are written by a python script from a template
50 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
51 | *.manifest
52 | *.spec
53 |
54 | # Installer logs
55 | pip-log.txt
56 | pip-delete-this-directory.txt
57 |
58 | # Unit test / coverage reports
59 | htmlcov/
60 | .tox/
61 | .nox/
62 | .coverage
63 | .coverage.*
64 | .cache
65 | nosetests.xml
66 | coverage.xml
67 | *.cover
68 | *.py,cover
69 | .hypothesis/
70 | .pytest_cache/
71 | cover/
72 |
73 | # Translations
74 | *.mo
75 | *.pot
76 |
77 | # Django stuff:
78 | *.log
79 | local_settings.py
80 | db.sqlite3
81 | db.sqlite3-journal
82 |
83 | # Flask stuff:
84 | instance/
85 | .webassets-cache
86 |
87 | # Scrapy stuff:
88 | .scrapy
89 |
90 | # Sphinx documentation
91 | docs/_build/
92 |
93 | # PyBuilder
94 | .pybuilder/
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # IPython
101 | profile_default/
102 | ipython_config.py
103 |
104 | # pyenv
105 | # For a library or package, you might want to ignore these files since the code is
106 | # intended to run in multiple environments; otherwise, check them in:
107 | # .python-version
108 |
109 | # pipenv
110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
113 | # install all needed dependencies.
114 | #Pipfile.lock
115 |
116 | # poetry
117 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
118 | # This is especially recommended for binary packages to ensure reproducibility, and is more
119 | # commonly ignored for libraries.
120 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
121 | #poetry.lock
122 |
123 | # pdm
124 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
125 | #pdm.lock
126 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
127 | # in version control.
128 | # https://pdm.fming.dev/#use-with-ide
129 | .pdm.toml
130 |
131 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
132 | __pypackages__/
133 |
134 | # Celery stuff
135 | celerybeat-schedule
136 | celerybeat.pid
137 |
138 | # SageMath parsed files
139 | *.sage.py
140 |
141 | # Environments
142 | .env
143 | .venv
144 | env/
145 | venv/
146 | ENV/
147 | env.bak/
148 | venv.bak/
149 |
150 | # Spyder project settings
151 | .spyderproject
152 | .spyproject
153 |
154 | # Rope project settings
155 | .ropeproject
156 |
157 | # mkdocs documentation
158 | /site
159 |
160 | # mypy
161 | .mypy_cache/
162 | .dmypy.json
163 | dmypy.json
164 |
165 | # Pyre type checker
166 | .pyre/
167 |
168 | # pytype static type analyzer
169 | .pytype/
170 |
171 | # Cython debug symbols
172 | cython_debug/
173 |
174 | # PyCharm
175 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
176 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
177 | # and can be added to the global gitignore or merged into this file. For a more nuclear
178 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
179 | #.idea/
180 |
181 | ### Python Patch ###
182 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
183 | poetry.toml
184 |
185 | # ruff
186 | .ruff_cache/
187 |
188 | # LSP config files
189 | pyrightconfig.json
190 |
191 | # End of https://www.toptal.com/developers/gitignore/api/python,git
--------------------------------------------------------------------------------
/HAND/compute_flood_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import rasterio
5 | from skimage.morphology import remove_small_objects, remove_small_holes, opening, closing, disk
6 | from scipy.ndimage import median_filter
7 | from rasterio.features import shapes
8 | import geopandas as gpd
9 | from shapely.geometry import shape
10 |
def clean_binary_raster(in_fn, band=1, min_obj_size=4, min_hole_size=4, min_region_size=20, selem_radius=1, median_size=3):
    """Load one raster band as a binary (> 0) mask and remove speckle.

    Cleaning steps, in order:
      1. drop connected objects smaller than ``min_obj_size`` pixels,
      2. fill holes smaller than ``min_hole_size`` pixels,
      3. apply a ``median_size`` median filter,
      4. if ``selem_radius > 0``, morphological opening then closing with a disk,
      5. drop regions smaller than ``min_region_size`` pixels.

    Returns:
        (mask, profile): the cleaned mask as a uint8 0/1 array, and the source
        profile updated to a single uint8 band with no nodata value.
    """
    with rasterio.open(in_fn) as src:
        raw = src.read(band, out_dtype="uint8")
        profile = src.profile.copy()

    cleaned = remove_small_objects(raw > 0, min_size=min_obj_size)
    cleaned = remove_small_holes(cleaned, area_threshold=min_hole_size)
    cleaned = median_filter(cleaned.astype(np.uint8), size=median_size) > 0

    if selem_radius > 0:
        footprint = disk(selem_radius)
        cleaned = closing(opening(cleaned, footprint), footprint)

    cleaned = remove_small_objects(cleaned, min_size=min_region_size)

    profile.update(dtype="uint8", count=1, nodata=None)
    return cleaned.astype(np.uint8), profile
27 |
def compute_flood_map(s1_before_fn, s1_after_fn, hand_fn, out_flood_map_fn, out_vector_fn, valley_thresh=5.0):
    """Build a 3-class flood map from before/after water masks and vectorize new flooding.

    Output raster classes:
        0 = no water,
        1 = water in the (cleaned) "before" mask,
        2 = water in the "after" mask that was not water "before".
    When a HAND raster is supplied, classes 1 and 2 are only kept where the
    3x3 median-filtered HAND value is <= ``valley_thresh`` (after removing
    holes and objects smaller than 64 pixels from that valley mask).

    Args:
        s1_before_fn: path to the pre-event binary water raster.
        s1_after_fn: path to the post-event binary water raster (same grid).
        hand_fn: path to a HAND raster on the same grid, or None.
        out_flood_map_fn: output GeoTIFF for the class raster.
        out_vector_fn: output GeoPackage with one polygon per class-2 region
            and a ``Flooded`` column set to 1.
        valley_thresh: HAND threshold in meters.

    Raises:
        ValueError: if the input rasters do not share the same shape.
    """
    s1_before_cleaned, s1_before_profile = clean_binary_raster(s1_before_fn)
    s1_after_cleaned, _ = clean_binary_raster(s1_after_fn)
    if s1_before_cleaned.shape != s1_after_cleaned.shape:
        raise ValueError(
            f"Before/after rasters differ in shape: {s1_before_cleaned.shape} vs {s1_after_cleaned.shape}"
        )

    valley_mask = None
    if hand_fn is not None:
        if os.path.exists(hand_fn):
            with rasterio.open(hand_fn) as src:
                hand = src.read(1)
            if hand.shape != s1_before_cleaned.shape:
                raise ValueError(
                    f"HAND raster shape {hand.shape} does not match S1 shape {s1_before_cleaned.shape}"
                )
            # NOTE(review): HAND nodata is not masked; a negative nodata value
            # (e.g. -9999) passes the <= test below and counts as valley — confirm intended.
            hand_filtered = median_filter(hand, size=3)
            valley_mask = hand_filtered <= valley_thresh
            valley_mask = remove_small_holes(valley_mask, area_threshold=64)
            valley_mask = remove_small_objects(valley_mask, min_size=64)
        else:
            print(f"Warning: HAND raster not found at {hand_fn}; skipping valley masking.")

    permanent = s1_before_cleaned == 1
    # Water after but not before; same pixels as the former uint8 (after - before) == 1
    # test, without relying on 0 - 1 wrapping to 255.
    newly_flooded = (s1_after_cleaned == 1) & (s1_before_cleaned == 0)
    if valley_mask is not None:
        permanent &= valley_mask
        newly_flooded &= valley_mask

    flood_map = np.zeros(s1_before_cleaned.shape, dtype=np.uint8)
    flood_map[permanent] = 1
    flood_map[newly_flooded] = 2

    meta = s1_before_profile.copy()
    meta.update(dtype="uint8", count=1, nodata=None)
    with rasterio.open(out_flood_map_fn, 'w', **meta) as dst:
        dst.write(flood_map, 1)

    # Vectorize class 2 directly from the in-memory array; the profile carries
    # the same transform/CRS that were just written to disk.
    flood_mask = flood_map == 2
    geoms = [
        shape(geom)
        for geom, value in shapes(flood_map, mask=flood_mask, transform=meta["transform"])
        if value == 2
    ]
    gdf = gpd.GeoDataFrame({'Flooded': [1] * len(geoms)}, geometry=geoms, crs=meta.get("crs"))
    print(f"Number of polygons: {len(gdf)}")
    gdf.to_file(out_vector_fn, driver='GPKG')
75 |
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Compute flood map and vectorize extent.")
    cli.add_argument('--s1_before', type=str, required=True, help='Path to S1 before flood raster')
    cli.add_argument('--s1_after', type=str, required=True, help='Path to S1 after flood raster')
    cli.add_argument('--hand', type=str, required=False, default=None, help='Path to HAND raster (optional)')
    cli.add_argument('--out_flood_map', type=str, required=True, help='Output flood map raster')
    cli.add_argument('--out_vector', type=str, required=True, help='Output vector file (GeoPackage)')
    cli.add_argument('--valley_thresh', type=float, default=5.0, help='Valley threshold for HAND (default: 5)')
    opts = cli.parse_args()

    # Run the pipeline and report wall-clock time (``time`` is imported at module level).
    print('Starting flood map computation...')
    started = time.time()
    compute_flood_map(
        s1_before_fn=opts.s1_before,
        s1_after_fn=opts.s1_after,
        hand_fn=opts.hand,
        out_flood_map_fn=opts.out_flood_map,
        out_vector_fn=opts.out_vector,
        valley_thresh=opts.valley_thresh,
    )
    print('Flood map computation done.')
    print(f"Output flood map: {opts.out_flood_map}")
    print(f"Output vector file: {opts.out_vector}")
    finished = time.time()
    print(f"Total elapsed: {finished - started:.2f} s")
    print('Done.')
--------------------------------------------------------------------------------
/FloodsML/utils/datagen_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 |
def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Colour-decode a batch of label maps into a (B, 3, H, W) tensor.

    Each mask is passed through ``decode_segmap``, which yields an (H, W, 3)
    array; the stacked batch is then moved to channels-first order.
    """
    stacked = np.array([decode_segmap(mask, dataset) for mask in label_masks])
    return torch.from_numpy(stacked.transpose([0, 3, 1, 2]))
12 |
13 |
def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a colour image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer (or bool) values
            denoting the class label at each spatial location.
        dataset (str): one of 'pascal', 'coco', 'cityscapes', 'deepfashion',
            'braintumor'; selects the palette and the number of classes.
        plot (bool, optional): if True, show the image with matplotlib and
            return None instead of the array.

    Returns:
        (np.ndarray, optional): (M, N, 3) float image with values in [0, 1]
        when ``plot`` is False. Labels outside ``range(n_classes)`` keep their
        raw value divided by 255, as before.

    Raises:
        NotImplementedError: for an unknown ``dataset``.
    """
    if dataset in ('pascal', 'coco'):
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    elif dataset == 'deepfashion':
        n_classes = 13
        label_colours = get_deepfashion_labels()
    elif dataset == 'braintumor':
        n_classes = 3
        label_colours = get_braintumor_labels()
    else:
        raise NotImplementedError

    label_mask = np.asarray(label_mask)
    # Build the colour planes in float64. Copying label_mask's own dtype (the
    # previous behaviour) truncated palette values for bool/int8 masks,
    # e.g. 128 stored into a bool array became True -> 1/255.
    r = label_mask.astype(np.float64)
    g = r.copy()
    b = r.copy()
    for ll in range(n_classes):
        hit = label_mask == ll
        r[hit] = label_colours[ll, 0]
        g[hit] = label_colours[ll, 1]
        b[hit] = label_colours[ll, 2]
    rgb = np.stack([r, g, b], axis=-1) / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb
59 |
60 |
def encode_segmap(mask):
    """Convert a PASCAL colour-coded label image into a class-index map.

    Args:
        mask (np.ndarray): (M, N, 3) image whose pixel colours are entries of
            the PASCAL palette.

    Returns:
        (np.ndarray): (M, N) integer class map; pixels whose colour is not in
        the palette stay 0.
    """
    rgb = mask.astype(int)
    class_map = np.zeros(rgb.shape[:2], dtype=np.int16)
    for class_idx, colour in enumerate(get_pascal_labels()):
        matches = np.all(rgb == colour, axis=-1)
        class_map[np.where(matches)[:2]] = class_idx
    return class_map.astype(int)
76 |
77 |
def get_cityscapes_labels():
    """Return the 19-colour Cityscapes palette as a (19, 3) integer array."""
    palette = [
        (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
        (107, 142, 35), (152, 251, 152), (0, 130, 180), (220, 20, 60),
        (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
        (0, 80, 100), (0, 0, 230), (119, 11, 32),
    ]
    return np.array(palette)
99 |
100 |
def get_pascal_labels():
    """Return the PASCAL VOC colour palette for its 21 classes.

    The palette is generated with the standard VOC scheme: the bits of the
    class index are dealt round-robin to the red, green and blue channels,
    filling each channel from its most significant bit downwards.

    Returns:
        np.ndarray with dimensions (21, 3)
    """
    palette = []
    for class_idx in range(21):
        red = green = blue = 0
        code = class_idx
        for shift in range(7, -1, -1):
            red |= (code & 1) << shift
            green |= ((code >> 1) & 1) << shift
            blue |= ((code >> 2) & 1) << shift
            code >>= 3
        palette.append([red, green, blue])
    return np.asarray(palette)
112 |
113 |
def get_deepfashion_labels():
    """Return the DeepFashion palette as a (19, 3) integer array.

    The rows are the same colours as the Cityscapes palette; ``decode_segmap``
    uses the first 13 of them for this dataset.
    """
    return np.array([
        [128, 64, 128], [244, 35, 232], [70, 70, 70],
        [102, 102, 156], [190, 153, 153], [153, 153, 153],
        [250, 170, 30], [220, 220, 0], [107, 142, 35],
        [152, 251, 152], [0, 130, 180], [220, 20, 60],
        [255, 0, 0], [0, 0, 142], [0, 0, 70],
        [0, 60, 100], [0, 80, 100], [0, 0, 230],
        [119, 11, 32],
    ])
135 |
def get_braintumor_labels():
    """Return the 3-colour brain-tumour palette as a (3, 3) integer array."""
    palette = ((128, 64, 128), (244, 35, 232), (70, 70, 70))
    return np.array(palette)
141 |
142 |
143 |
def denormalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel ``(x - mean) / std`` normalization.

    Previously this multiplied/added in place, silently overwriting the
    caller's array and failing on integer arrays; it now returns a new array.

    Args:
        image: array whose last axis holds the channels so that it broadcasts
            against ``mean``/``std`` (e.g. (H, W, 3)).
        mean: per-channel means; the default is the value previously hard-coded here.
        std: per-channel standard deviations; the default is the value previously hard-coded here.

    Returns:
        np.ndarray: ``image * std + mean`` as a new floating-point array.
    """
    return np.asarray(image) * np.asarray(std) + np.asarray(mean)
--------------------------------------------------------------------------------
/FloodsML/flood_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | from osgeo import gdal
5 | import torch
6 | from torchvision import transforms
7 |
class FloodSegmentationDataset(Dataset):
    """Sentinel-1 VH/VV flood-segmentation dataset driven by a ``|``-separated map file.

    Each non-blank map-file line is ``VH|VV|DEM|label`` with paths relative to
    ``base_path``. Only entries whose VH, VV and label rasters open with GDAL
    and have exactly ``target_size`` pixels are kept. The DEM path is stored
    but not loaded.
    """

    def __init__(self, map_file, base_path, split='train', transform=None):
        """
        Args:
            map_file (string): Path to the map file with VH|VV|DEM|label pairs
            base_path (string): Base path to the dataset
            split (string): 'train' or 'val' split (used for logging only)
            transform (callable, optional): called as
                ``transform(image=img, mask=label)`` and expected to return a
                dict with 'image' and 'mask' entries
        """
        self.base_path = base_path
        self.split = split
        self.transform = transform
        # (height, width) every raster must have; square, so axis order is moot today.
        self.target_size = (512, 512)

        # Per-channel mean/std used for z-score normalization of the inputs.
        self.vh_mean = -11.2832
        self.vh_std = 5.6656
        self.vv_mean = -6.9761
        self.vv_std = 4.3096

        self.image_label_pairs = []
        with open(map_file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                vh_path, vv_path, dem_path, label_path = line.strip().split('|')
                # all() short-circuits, so later rasters are only opened if earlier ones pass.
                if all(self._has_target_size(p) for p in (vh_path, vv_path, label_path)):
                    self.image_label_pairs.append([vh_path, vv_path, dem_path, label_path])

        print(f"Found {len(self.image_label_pairs)} image pairs for {split} split")

    def _has_target_size(self, rel_path):
        """Return True if the raster opens with GDAL and is exactly ``target_size``."""
        ds = gdal.Open(os.path.join(self.base_path, rel_path))
        if ds is None:
            return False
        try:
            height, width = self.target_size
            return ds.RasterYSize == height and ds.RasterXSize == width
        finally:
            ds = None  # release the GDAL handle

    def _read_raster(self, rel_path):
        """Read a raster (relative to ``base_path``) into a NumPy array.

        Raises:
            RuntimeError: if GDAL cannot open the file (e.g. it was removed
                after the dataset was indexed).
        """
        full_path = os.path.join(self.base_path, rel_path)
        ds = gdal.Open(full_path)
        if ds is None:
            raise RuntimeError(f"Could not open raster: {full_path}")
        try:
            return ds.ReadAsArray()
        finally:
            ds = None  # release the GDAL handle

    def __len__(self):
        return len(self.image_label_pairs)

    def __getitem__(self, idx):
        """Return a dict with the normalized 2-channel image, binary label and source paths.

        The image stacks VH then VV along axis 0 (so (2, H, W), assuming
        single-band rasters); label values 1 and 2 become 1, everything else 0,
        and the label gets a leading channel axis.
        """
        vh_path, vv_path, dem_path, label_path = self.image_label_pairs[idx]

        vh = (self._read_raster(vh_path) - self.vh_mean) / self.vh_std
        vv = (self._read_raster(vv_path) - self.vv_mean) / self.vv_std
        img = np.stack([vh, vv], axis=0)

        label = self._read_raster(label_path).astype(np.float32)
        label = np.where((label == 1) | (label == 2), 1, 0)

        if self.transform:
            transformed = self.transform(image=img, mask=label)
            img = transformed['image']
            label = transformed['mask']

        # as_tensor accepts NumPy arrays and tensors alike, so transforms that
        # already return tensors no longer break torch.from_numpy.
        img = torch.as_tensor(img).float()
        label = torch.as_tensor(label).float().unsqueeze(0)  # add channel dimension

        return {
            'image': img,
            'label': label,
            'image_path': vh_path,
            'label_path': label_path,
        }
--------------------------------------------------------------------------------
/FloodsML/check_sentinel_values.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from osgeo import gdal
4 | from tqdm import tqdm
5 | from collections import defaultdict
6 |
def _new_channel_stats():
    """Return an empty running-statistics accumulator for one channel."""
    return {'min': float('inf'), 'max': float('-inf'), 'count': 0, 'mean': 0.0, 'm2': 0.0}


def _update_channel_stats(acc, data):
    """Fold the finite pixels of ``data`` into ``acc``.

    Uses the pairwise (Chan et al.) merge of count/mean/sum-of-squared-deviations,
    so the final std is the true pooled pixel std over all files rather than
    an average of per-file stds.
    """
    values = np.asarray(data, dtype=np.float64).ravel()
    values = values[np.isfinite(values)]
    n_new = values.size
    if n_new == 0:
        return
    mean_new = values.mean()
    m2_new = float(((values - mean_new) ** 2).sum())
    acc['min'] = min(acc['min'], float(values.min()))
    acc['max'] = max(acc['max'], float(values.max()))
    n_old = acc['count']
    n_total = n_old + n_new
    delta = mean_new - acc['mean']
    acc['mean'] += delta * n_new / n_total
    acc['m2'] += m2_new + delta * delta * n_old * n_new / n_total
    acc['count'] = n_total


def check_sentinel_values(data_dir='data'):
    """Check value ranges in Sentinel-1 data (VH and VV channels).

    Reads every VH/VV raster listed in ``<data_dir>/train.txt`` and
    ``<data_dir>/val.txt`` (``VH|VV|DEM|label`` per line) and prints the global
    min, max, mean and (population) standard deviation per channel, pooled
    over all pixels. Non-finite pixels (NaN/inf) are ignored so a single bad
    value cannot poison the summary. A file that fails to open or read is
    reported and excluded; it does not prevent the other channel of the same
    pair from being processed.

    Returns:
        None; results are only printed.
    """
    data_dir = Path(data_dir)

    if not data_dir.exists():
        print(f"Error: Data directory not found at {data_dir}")
        return

    stats = {
        'total_files': 0,
        'errors': [],
        'vh_stats': _new_channel_stats(),
        'vv_stats': _new_channel_stats(),
    }

    map_files = [data_dir / 'train.txt', data_dir / 'val.txt']

    for map_file in map_files:
        if not map_file.exists():
            print(f"Warning: Map file not found at {map_file}")
            continue

        try:
            with open(map_file, 'r') as f:
                image_pairs = [line.strip().split('|') for line in f if line.strip()]
        except Exception as e:
            print(f"Error reading map file {map_file}: {str(e)}")
            continue

        for vh_path, vv_path, dem_path, label_path in tqdm(image_pairs, desc=f"Processing {map_file.name}"):
            stats['total_files'] += 1

            for channel, rel_path in (('VH', vh_path), ('VV', vv_path)):
                acc = stats[f"{channel.lower()}_stats"]
                try:
                    ds = gdal.Open(str(data_dir / rel_path))
                    if ds is None:
                        stats['errors'].append(f"Could not open {channel} file: {rel_path}")
                        continue
                    try:
                        data = ds.ReadAsArray()
                    finally:
                        ds = None  # release the GDAL handle
                    _update_channel_stats(acc, data)
                except Exception as e:
                    stats['errors'].append(f"Error processing {channel} file {rel_path}: {str(e)}")

    for key in ('vh_stats', 'vv_stats'):
        acc = stats[key]
        acc['std'] = float(np.sqrt(acc['m2'] / acc['count'])) if acc['count'] else 0.0

    print("\n=== Sentinel-1 Data Statistics ===")
    print(f"Total files processed: {stats['total_files']}")

    print("\nVH Channel Statistics:")
    print(f" Minimum value: {stats['vh_stats']['min']:.4f}")
    print(f" Maximum value: {stats['vh_stats']['max']:.4f}")
    print(f" Mean value: {stats['vh_stats']['mean']:.4f}")
    print(f" Standard deviation: {stats['vh_stats']['std']:.4f}")

    print("\nVV Channel Statistics:")
    print(f" Minimum value: {stats['vv_stats']['min']:.4f}")
    print(f" Maximum value: {stats['vv_stats']['max']:.4f}")
    print(f" Mean value: {stats['vv_stats']['mean']:.4f}")
    print(f" Standard deviation: {stats['vv_stats']['std']:.4f}")

    if stats['errors']:
        print("\n=== Errors and Warnings ===")
        for error in stats['errors']:
            print(f"- {error}")

    if stats['total_files'] == 0:
        print("\nNo files found! Please check:")
        print("1. The data directory exists")
        print("2. The train.txt and val.txt files exist and are readable")
        print("3. The paths in the map files are correct")
        print("4. The files are valid GeoTIFFs")


if __name__ == "__main__":
    check_sentinel_values()
--------------------------------------------------------------------------------
/HAND/mosaic_dem_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import rasterio
5 | from rasterio.merge import merge
6 | from rasterio.warp import transform_bounds
7 | from rasterio.windows import from_bounds
8 | import numpy as np
9 | import time
10 |
def find_dem_files(folder, pattern='copdem.tif'):
    """Recursively find DEM files under ``folder`` whose file name matches ``pattern``.

    Matching is case-insensitive and shell-style (``*``, ``?``, ``[...]``),
    applied to the bare file name. The default keeps the previous behaviour
    of collecting files named ``copdem.tif`` in any letter case; previously
    ``pattern`` was accepted but silently ignored.

    Args:
        folder: root directory to walk; a missing directory yields ``[]``.
        pattern: file-name pattern, e.g. ``'*.tif'``.

    Returns:
        list[str]: matching paths, in ``os.walk`` order.
    """
    import fnmatch

    wanted = pattern.lower()
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(folder)
        for name in names
        if fnmatch.fnmatchcase(name.lower(), wanted)
    ]
20 |
def mosaic_dems(dem_files, out_fn):
    """Mosaic a list of DEM files into one float32 GeoTIFF.

    Args:
        dem_files: non-empty sequence of raster paths. Output metadata (CRS,
            etc.) starts from the first file's; overlaps are resolved by
            ``rasterio.merge.merge``.
        out_fn: output path; band 1 of the merge is written as float32,
            deflate-compressed, 256x256 tiles, nodata -9999.

    Raises:
        ValueError: if ``dem_files`` is empty.
    """
    from contextlib import ExitStack

    if not dem_files:
        raise ValueError("mosaic_dems: no DEM files given")

    nodata = -9999.0
    # ExitStack closes every opened source even if opening a later file or merging fails.
    with ExitStack() as stack:
        sources = [stack.enter_context(rasterio.open(fp)) for fp in dem_files]
        # Fill gaps with the same value the output declares as nodata; without
        # this, merge falls back to the first input's nodata value, which need
        # not be -9999.
        mosaic, out_trans = merge(sources, nodata=nodata)
        out_meta = sources[0].meta.copy()

    out_meta.update({
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": out_trans,
        "count": 1,
        "dtype": "float32",
        "nodata": nodata,
        "compress": "deflate",
        "tiled": True,
        "blockxsize": 256,
        "blockysize": 256,
    })
    with rasterio.open(out_fn, "w", **out_meta) as dest:
        dest.write(mosaic[0].astype(np.float32), 1)
43 |
def crop_to_reference_with_buffer(in_fn, ref_fn, out_fn, buffer_m=0):
    """Resample band 1 of ``in_fn`` onto the exact grid of ``ref_fn``.

    The output has the reference raster's CRS, transform, width and height,
    is float32 with nodata -9999, deflate-compressed in 256x256 tiles, and is
    filled by bilinear reprojection; pixels not covered by the input stay
    -9999.

    Note:
        ``buffer_m`` is currently NOT applied; the output always matches the
        reference extent exactly, and a warning is emitted when a non-zero
        buffer is requested. The parameter is kept for backward compatibility.
        Enabling a real buffer would change the output shape, which code that
        stacks this raster with others on the reference grid may not expect.
    """
    import warnings
    from rasterio.warp import reproject, Resampling

    if buffer_m:
        warnings.warn(
            f"crop_to_reference_with_buffer: buffer_m={buffer_m} is ignored; "
            "output matches the reference grid exactly.",
            stacklevel=2,
        )

    nodata = -9999.0
    with rasterio.open(ref_fn) as ref:
        ref_crs = ref.crs
        ref_transform = ref.transform
        ref_height = ref.height
        ref_width = ref.width

    with rasterio.open(in_fn) as src:
        dst_array = np.full((ref_height, ref_width), nodata, dtype=np.float32)
        out_profile = src.profile.copy()
        out_profile.update({
            "driver": "GTiff",
            "height": ref_height,
            "width": ref_width,
            "transform": ref_transform,
            "crs": ref_crs,
            "count": 1,
            "dtype": "float32",
            "nodata": nodata,
            "compress": "deflate",
            "tiled": True,
            "blockxsize": 256,
            "blockysize": 256,
        })
        reproject(
            source=rasterio.band(src, 1),
            destination=dst_array,
            src_transform=src.transform,
            src_crs=src.crs,
            src_nodata=src.nodata,
            dst_transform=ref_transform,
            dst_crs=ref_crs,
            dst_nodata=nodata,
            resampling=Resampling.bilinear,
        )

    with rasterio.open(out_fn, "w", **out_profile) as dest:
        dest.write(dst_array, 1)
94 |
if __name__ == "__main__":
    import sys
    import tempfile

    parser = argparse.ArgumentParser(description="Mosaic DEMs and crop to reference raster with buffer.")

    parser.add_argument("--dem_folder", type=str, required=True, help="Folder containing DEM files (searches recursively)")
    parser.add_argument("--reference", type=str, required=True, help="Reference raster to crop to")
    parser.add_argument("--out_cropped", type=str, required=True, help="Output path for cropped DEM GeoTIFF")
    parser.add_argument("--buffer", type=float, default=5000, help="Buffer size in meters (default: 5000)")

    args = parser.parse_args()

    print("Finding DEM files...")

    dem_files = find_dem_files(args.dem_folder)
    if not dem_files:
        print(f"No DEM files found in {args.dem_folder}")
        sys.exit(1)

    print(f"Found {len(dem_files)} DEM files. Creating mosaic...")
    # Reserve a unique temporary path. The handle is closed immediately so
    # rasterio can (re)create the file itself; the finally block below removes
    # it even if mosaicking or cropping fails.
    with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
        tmp_mosaic = tmp.name
    try:
        t0 = time.time()
        mosaic_dems(dem_files, tmp_mosaic)
        t1 = time.time()
        crop_to_reference_with_buffer(tmp_mosaic, args.reference, args.out_cropped, buffer_m=args.buffer)
        t2 = time.time()
    finally:
        if os.path.exists(tmp_mosaic):
            os.remove(tmp_mosaic)
            print("Deleted intermediate mosaic file")

    print(f"Mosaic step elapsed: {t1-t0:.2f} s")
    print(f"Crop/reproject step elapsed: {t2-t1:.2f} s")
    print(f"Total elapsed: {t2-t0:.2f} s")

    print("Done.")
--------------------------------------------------------------------------------
/HAND/compute_hand.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from whitebox.whitebox_tools import WhiteboxTools
4 | import numpy as np
5 | import rasterio
6 | import time
7 |
def compute_threshold_map(dem_filled_fn, out_thr_fn, s_lo_pct=10, s_hi_pct=90, thr_hilly=500, thr_flat=10000):
    """
    Write a per-pixel flow-accumulation threshold map derived from DEM slope.

    Slope (degrees) is computed with ``np.gradient`` using the raster's pixel
    size. Slopes at or below the ``s_lo_pct`` percentile map to ``thr_flat``,
    slopes at or above the ``s_hi_pct`` percentile map to ``thr_hilly``, with
    linear interpolation in between. If the two percentiles coincide, every
    valid pixel gets the midpoint of the two thresholds. Invalid pixels are
    written as -9999 (the output nodata).

    Args:
        dem_filled_fn: Path to filled DEM (GeoTIFF)
        out_thr_fn: Output threshold map path (GeoTIFF)
        s_lo_pct: Lower slope percentile (default 10)
        s_hi_pct: Upper slope percentile (default 90)
        thr_hilly: Minimum threshold, used in steep terrain (default 500)
        thr_flat: Maximum threshold, used in flat terrain (default 10000)

    Raises:
        ValueError: if no pixel has a finite slope.
    """
    with rasterio.open(dem_filled_fn) as ds:
        dem = ds.read(1).astype("float64")
        nodata = ds.nodata
        res_x = abs(ds.transform.a)
        res_y = abs(ds.transform.e)
        out_meta = ds.meta.copy()

    invalid = np.isnan(dem) if nodata is None else (dem == nodata)
    dem = np.where(invalid, np.nan, dem)

    d_row, d_col = np.gradient(dem, res_y, res_x)
    slope_deg = np.degrees(np.arctan(np.hypot(d_col, d_row)))
    slope_deg[invalid] = np.nan

    valid = np.isfinite(slope_deg)
    if not valid.any():
        raise ValueError("No valid slope pixels.")

    s_lo, s_hi = np.percentile(slope_deg[valid], (s_lo_pct, s_hi_pct))
    if s_hi > s_lo:
        thr_map = np.interp(slope_deg, [s_lo, s_hi], [thr_flat, thr_hilly])
    else:
        thr_map = np.full_like(slope_deg, (thr_hilly + thr_flat) / 2.0, dtype="float64")
    thr_map = np.clip(thr_map, thr_hilly, thr_flat)
    thr_map[~valid] = np.nan

    thr_meta = dict(out_meta, count=1, dtype="float32", nodata=-9999.0)
    thr_write = thr_map.astype("float32")
    thr_write[~np.isfinite(thr_write)] = thr_meta["nodata"]
    with rasterio.open(out_thr_fn, "w", **thr_meta) as dst:
        dst.write(thr_write, 1)
47 |
48 |
def compute_thresholded_streams(fac_fn, thr_map_fn, out_streams_fn):
    """
    Build a binary stream raster where flow accumulation reaches the local threshold.

    A pixel is a stream (1) when both inputs are valid there and
    ``fac >= threshold``; all other pixels are 0, which is also declared as
    the output nodata value.

    Args:
        fac_fn: Path to flow accumulation raster (GeoTIFF).
        thr_map_fn: Path to threshold map raster (GeoTIFF), e.g. written by
            compute_threshold_map with nodata -9999.
        out_streams_fn: Output streams raster path (GeoTIFF, uint8).
    """
    def _read_valid(ds):
        # Band 1 as float64 with the declared nodata value replaced by NaN.
        arr = ds.read(1).astype("float64")
        if ds.nodata is not None:
            arr[arr == ds.nodata] = np.nan
        return arr

    with rasterio.open(fac_fn) as fac_ds, rasterio.open(thr_map_fn) as thr_ds:
        fac = _read_valid(fac_ds)
        # The threshold map's nodata (-9999) must be masked as well: otherwise
        # `fac >= -9999` is always true and every nodata pixel of the threshold
        # map with a valid accumulation would be marked as a stream.
        thr_map = _read_valid(thr_ds)
        fac_meta = fac_ds.meta.copy()
    fac_meta.update(driver="GTiff", count=1, dtype="uint8", nodata=0)

    valid = np.isfinite(fac) & np.isfinite(thr_map)
    streams = (valid & (fac >= thr_map)).astype("uint8")
    with rasterio.open(out_streams_fn, "w", **fac_meta) as dst:
        dst.write(streams, 1)
70 |
71 |
def compute_hand_and_streams(dem_fn, out_hand_fn, out_streams_fn):
    """
    Compute HAND (Height Above Nearest Drainage) and a stream network from a DEM.

    Steps: set DEM nodata to -9999, fill depressions, D8 flow directions,
    D8 flow accumulation (cells), slope-adaptive threshold map, thresholded
    streams, and WhiteboxTools' ElevationAboveStream for HAND.

    Args:
        dem_fn: Input DEM path (GeoTIFF).
        out_hand_fn: Output HAND raster path (GeoTIFF).
        out_streams_fn: Output streams raster path (GeoTIFF).

    Raises:
        RuntimeError: If a WhiteboxTools step reports a non-zero exit code.

    Intermediate rasters go to a temporary directory created next to
    ``out_hand_fn``; it is removed on return, including when a step fails.
    """
    import tempfile

    # Absolute paths: WhiteboxTools resolves bare file names against its
    # working directory, which is pointed at the temporary directory below.
    dem_fn = os.path.abspath(dem_fn)
    out_hand_fn = os.path.abspath(out_hand_fn)
    out_streams_fn = os.path.abspath(out_streams_fn)
    out_dir = os.path.dirname(out_hand_fn)
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(os.path.dirname(out_streams_fn), exist_ok=True)

    wbt = WhiteboxTools()
    wbt.verbose = False

    def _check(step, ret):
        # The WhiteboxTools wrappers return the tool's exit code instead of raising.
        if ret != 0:
            raise RuntimeError(f"WhiteboxTools step '{step}' failed with exit code {ret}")

    with tempfile.TemporaryDirectory(prefix="hand_tmp_", dir=out_dir) as tmp_dir:
        wbt.work_dir = tmp_dir

        print("Setting no-data values in DEM")
        dem_nd_fn = os.path.join(tmp_dir, "dem_mosaic.tif")
        _check("set_nodata_value",
               wbt.set_nodata_value(i=dem_fn, output=dem_nd_fn, back_value=-9999))

        print("Filling depressions in DEM")
        dem_filled = os.path.join(tmp_dir, "dem_filled.tif")
        _check("fill_depressions", wbt.fill_depressions(dem_nd_fn, dem_filled))

        print("Computing D8 flow directions")
        fdr_fn = os.path.join(tmp_dir, "fdr.tif")
        _check("d8_pointer", wbt.d8_pointer(dem_filled, fdr_fn))

        print("Computing D8 flow accumulation")
        fac_fn = os.path.join(tmp_dir, "fac.tif")
        # Reuse the pointer grid computed above rather than re-deriving it from the DEM.
        _check("d8_flow_accumulation",
               wbt.d8_flow_accumulation(fdr_fn, fac_fn, out_type="cells", pntr=True))

        print("Computing per-pixel threshold map")
        thr_map_fn = os.path.join(tmp_dir, "thr_map.tif")
        compute_threshold_map(dem_filled, thr_map_fn)

        # Final outputs go to the caller-supplied paths (not into the temp dir).
        print("Computing per-pixel thresholded streams")
        compute_thresholded_streams(fac_fn, thr_map_fn, out_streams_fn)

        print("Computing HAND")
        _check("elevation_above_stream",
               wbt.elevation_above_stream(dem_filled, out_streams_fn, out_hand_fn))

        print(f"HAND saved to {out_hand_fn}")
        print(f"Streams saved to {out_streams_fn}")
        print('Cleaning up intermediate files...')
    # Leaving the `with` block deletes the temporary directory and its contents.
125 |
126 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute HAND and extract streams from DEM using WhiteboxTools.")
    parser.add_argument("--dem", type=str, required=True, help="Input DEM file (GeoTIFF)")
    parser.add_argument("--out_hand", type=str, required=True, help="Output HAND file (GeoTIFF)")
    parser.add_argument("--out_streams", type=str, required=True, help="Output streams file (GeoTIFF)")
    args = parser.parse_args()

    # Start processing
    print("HAND and streams computation...")
    # perf_counter is monotonic, so the elapsed time is unaffected by wall-clock adjustments.
    t0 = time.perf_counter()
    compute_hand_and_streams(args.dem, args.out_hand, args.out_streams)
    print(f"Total elapsed: {time.perf_counter() - t0:.2f} s")
    print('Done.')
141 |
--------------------------------------------------------------------------------
/preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import geopandas as gpd
4 | import rasterio
5 | from shapely import wkt
6 | from shapely.geometry import shape, box
7 | from rasterio.mask import mask
8 | from rasterio.warp import reproject, Resampling
9 | from rasterio.features import geometry_mask
10 | from rasterio import Affine
11 |
12 |
# Step 1: Create AOI layer in raster CRS
def create_aoi_layer(s1_path, aoi_wkt):
    """
    Create an AOI GeoDataFrame aligned to the CRS of the raster.

    :param s1_path: path to Sentinel-1 raster (only its CRS is read)
    :param aoi_wkt: AOI polygon in WKT format (WGS84 / EPSG:4326)
    :return: single-row AOI GeoDataFrame in the raster CRS
    :raises ValueError: if the raster has no CRS to reproject the AOI into
    """
    with rasterio.open(s1_path) as src:
        raster_crs = src.crs

    if raster_crs is None:
        # Without this check, to_crs(None) fails with an unrelated-looking error.
        raise ValueError(f"Raster has no CRS, cannot align AOI to it: {s1_path}")

    gdf_aoi = gpd.GeoDataFrame(geometry=[wkt.loads(aoi_wkt)], crs="EPSG:4326")

    # Reproject AOI to raster CRS if necessary
    if gdf_aoi.crs != raster_crs:
        gdf_aoi = gdf_aoi.to_crs(raster_crs)

    return gdf_aoi
33 |
34 |
# Step 2: Crop Sentinel-1 image (VV and VH)
def crop_s1(s1_path, gdf_aoi):
    """
    Crop a Sentinel-1 raster to the AOI extent; the result stays in memory.

    The window is the AOI's bounding box (``crop=True``); pixels inside it
    but outside the polygon get rasterio.mask's fill value (the raster's
    nodata, or 0 when none is set).

    :param s1_path: path to Sentinel-1 raster
    :param gdf_aoi: AOI as GeoDataFrame (reprojected to the raster CRS if needed)
    :return: cropped image (numpy array, (bands, rows, cols)) and metadata (dict)
    """
    with rasterio.open(s1_path) as src:
        # Reproject AOI if CRS does not match
        aoi_mask = gdf_aoi
        if src.crs != gdf_aoi.crs:
            aoi_mask = gdf_aoi.to_crs(src.crs)

        out_image, out_transform = mask(src, aoi_mask.geometry, crop=True)

        # Metadata describing the cropped window
        out_meta = src.meta.copy()
        out_meta.update({
            "driver": "GTiff",
            "height": out_image.shape[1],
            "width": out_image.shape[2],
            "transform": out_transform
        })

    return out_image, out_meta
67 |
68 |
# Step 3: Crop DEM raster and align to Sentinel-1 grid
def crop_dem(dem_path, gdf_aoi, s1_meta):
    """
    Crop DEM raster to AOI and reproject it onto the Sentinel-1 grid.

    The DEM is bilinearly resampled onto the S1 grid (CRS, transform, shape).
    Pixels outside the AOI polygon, outside the DEM's coverage, or at the
    DEM's nodata value are NaN, and NaN is declared as the output nodata.
    Integer DEMs are promoted to float32 so NaN can be represented (casting
    NaN to an integer dtype yields undefined values). The result is also
    written next to the input as ``<dem>_cropped<ext>``.

    :param dem_path: path to DEM raster
    :param gdf_aoi: AOI as GeoDataFrame
    :param s1_meta: Sentinel-1 metadata dict ("crs", "transform", "height", "width")
    :return: cropped DEM image (2-D numpy array) and metadata (dict)
    """
    s1_crs = s1_meta["crs"]
    s1_transform = s1_meta["transform"]
    s1_shape = (s1_meta["height"], s1_meta["width"])

    # Reproject DEM onto Sentinel-1 grid, in a floating-point buffer pre-filled
    # with NaN so uncovered / source-nodata pixels end up as NaN.
    with rasterio.open(dem_path) as dem_src:
        src_dtype = np.dtype(dem_src.dtypes[0])
        out_dtype = src_dtype if np.issubdtype(src_dtype, np.floating) else np.dtype("float32")
        out_array = np.full(s1_shape, np.nan, dtype=out_dtype)
        reproject(
            source=rasterio.band(dem_src, 1),
            destination=out_array,
            src_transform=dem_src.transform,
            src_crs=dem_src.crs,
            dst_transform=s1_transform,
            dst_crs=s1_crs,
            dst_nodata=np.nan,
            resampling=Resampling.bilinear,
        )

    # Clip DEM to AOI mask
    aoi_mask = gdf_aoi
    if gdf_aoi.crs != s1_crs:
        aoi_mask = gdf_aoi.to_crs(s1_crs)

    mask_geom = geometry_mask(
        geometries=aoi_mask.geometry,
        transform=s1_transform,
        invert=True,  # True inside the AOI
        out_shape=s1_shape
    )

    out_image = np.where(mask_geom, out_array, np.nan).astype(out_dtype)

    out_meta = {
        "driver": "GTiff",
        "height": out_image.shape[0],
        "width": out_image.shape[1],
        "count": 1,
        "dtype": str(out_dtype),
        "crs": s1_crs,
        "transform": s1_transform,
        "nodata": np.nan,
    }

    # Save cropped DEM to disk
    base, ext = os.path.splitext(dem_path)
    output_path = f"{base}_cropped{ext}"
    with rasterio.open(output_path, "w", **out_meta) as dest:
        dest.write(out_image, 1)

    return out_image, out_meta
128 |
129 |
# Step 4: Generate tiles from raster held in memory
def tile_from_memory(image, meta, tile_size, suffix, working_dir):
    """
    Split a raster held in memory into square tiles and save each to disk.

    Tiles follow a regular ``tile_size`` grid. The last row/column of tiles is
    shifted back to end at the raster edge, overlapping its neighbour, so
    tiles stay full-size. If the raster is smaller than ``tile_size`` along an
    axis, the tile spans that whole axis. Tiles are numbered row-major from 1
    and written as ``tile_{id:03d}_0_{suffix}.tif`` in ``working_dir``.

    :param image: numpy array, (rows, cols) or (bands, rows, cols)
    :param meta: rasterio metadata dict for the raster (must contain "transform")
    :param tile_size: tile size in pixels
    :param suffix: suffix for naming tiles
    :param working_dir: directory where tiles will be saved (created if missing)
    :return: dict of {tile_name: (tile_data, tile_meta)}
    """
    os.makedirs(working_dir, exist_ok=True)

    # Ensure 3D shape (bands, rows, cols)
    if image.ndim == 2:
        image = image[np.newaxis, :, :]

    height, width = image.shape[1], image.shape[2]
    tiles = {}
    tile_id = 1

    for row in range(0, height, tile_size):
        # Shift edge tiles back inside the raster; clamp at 0 so a raster
        # smaller than one tile never gets a negative offset (which would wrap
        # the slice and shift the tile's georeferencing).
        row_off = max(0, min(row, height - tile_size))
        for col in range(0, width, tile_size):
            col_off = max(0, min(col, width - tile_size))

            window_data = image[:, row_off:row_off + tile_size, col_off:col_off + tile_size]
            # NOTE: empty (all-NaN) tiles are kept; filter them here if needed.

            tile_meta = {
                **meta,
                "count": window_data.shape[0],
                "height": window_data.shape[1],
                "width": window_data.shape[2],
                "transform": meta["transform"] * Affine.translation(col_off, row_off),
            }

            tile_name = f"tile_{tile_id:03d}_0_{suffix}"
            out_path = os.path.join(working_dir, f"{tile_name}.tif")

            with rasterio.open(out_path, "w", **tile_meta) as dst:
                dst.write(window_data)

            tiles[tile_name] = (window_data, tile_meta)
            tile_id += 1

    return tiles
186 |
--------------------------------------------------------------------------------
/FloodsML/train_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | from torchmetrics import JaccardIndex
6 | from pathlib import Path
7 | from tqdm import tqdm
8 | import wandb
9 | import random
10 | from osgeo import gdal
11 | from utils.training_utils import (
12 | MaskedDiceLoss, MaskedBCELoss, setup_wandb_run, log_summary_metrics, log_metrics,
13 | save_checkpoint, create_save_directory,
14 | train_epoch, validate_epoch
15 | )
16 | import time
17 | import os
18 |
19 | # Set GDAL to use exceptions
20 | gdal.UseExceptions()
21 |
22 |
23 |
def train_model(model, train_loader, val_loader, config):
    """Training loop for flood segmentation model.

    Per-epoch work is delegated to ``train_epoch`` / ``validate_epoch``.
    The LR is warmed up linearly for ``warmup_epochs`` epochs, then reduced
    on validation-loss plateaus. Early stopping monitors the validation loss,
    and a checkpoint is saved on every improvement and every
    ``checkpoint_freq`` epochs.

    Args:
        model: torch.nn.Module to train (moved to CUDA when available).
        train_loader: DataLoader yielding training batches.
        val_loader: DataLoader yielding validation batches.
        config: dict with required keys 'learning_rate', 'weight_decay',
            'warmup_epochs', 'epochs', 'checkpoint_freq', 'save_dir' and
            optional 'plateau_factor' (0.5), 'plateau_patience' (5),
            'early_stopping_patience' (10), 'use_wandb' (False); it is also
            forwarded to train_epoch / validate_epoch.

    Returns:
        The lowest validation loss observed.
    """
    start_time = time.time()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Initialize loss functions and metrics
    criterion_bce = MaskedBCELoss()
    criterion_dice = MaskedDiceLoss()
    jaccard = JaccardIndex(task='binary', num_classes=2).to(device)

    # Initialize optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Linear warmup from 10% to 100% of the base LR. LinearLR applies
    # start_factor as soon as it is constructed, so with warmup_epochs == 0
    # (never stepped) the LR would stay at 10% for the whole run; only create
    # it when a warmup is actually requested.
    warmup_epochs = config['warmup_epochs']
    warmup_scheduler = None
    if warmup_epochs > 0:
        warmup_scheduler = optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=warmup_epochs
        )

    # After warmup: reduce the LR when the validation loss stops improving
    plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.get('plateau_factor', 0.5),
        patience=config.get('plateau_patience', 5),
        verbose=True
    )

    # Setup directories and wandb
    save_dir = create_save_directory(config['save_dir'])
    if config.get('use_wandb', False) and not wandb.run:  # Only initialize if not already initialized
        setup_wandb_run(config, "UNet")

    best_val_loss = float('inf')
    early_stopping_counter = 0
    final_train_iou = 0.0
    final_val_iou = 0.0
    epochs_completed = 0  # stays defined even when config['epochs'] == 0

    for epoch in tqdm(range(config['epochs'])):
        train_metrics = train_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            criterion_bce=criterion_bce,
            criterion_dice=criterion_dice,
            jaccard=jaccard,
            config=config,
            device=device
        )

        val_metrics = validate_epoch(
            model=model,
            loader=val_loader,
            criterion_bce=criterion_bce,
            criterion_dice=criterion_dice,
            jaccard=jaccard,
            config=config,
            device=device
        )

        # Record this epoch's results before any early-stopping break so the
        # summary reflects the last epoch that actually ran.
        epochs_completed = epoch + 1
        final_train_iou = train_metrics['train_iou']
        final_val_iou = val_metrics['val_iou']

        # Update learning rate schedulers
        if warmup_scheduler is not None and epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            plateau_scheduler.step(val_metrics['val_loss'])

        if config.get('use_wandb', False):
            log_metrics({**train_metrics, **val_metrics}, epoch)

        # Early stopping check (checkpoint on every improvement)
        if val_metrics['val_loss'] < best_val_loss:
            best_val_loss = val_metrics['val_loss']
            early_stopping_counter = 0
            save_checkpoint(model, optimizer, epoch, val_metrics, save_dir)
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= config.get('early_stopping_patience', 10):
                print('Early stopping triggered')
                break

        if (epoch + 1) % config['checkpoint_freq'] == 0:
            save_checkpoint(model, optimizer, epoch, val_metrics, save_dir)

    # Log final summary metrics
    if config.get('use_wandb', False):
        training_duration = time.time() - start_time
        summary_metrics = {
            'train_iou': final_train_iou,
            'val_iou': final_val_iou,
            'best_val_loss': best_val_loss,
            'training_duration': training_duration,
            'total_epochs': epochs_completed
        }
        log_summary_metrics(summary_metrics, config)

    return best_val_loss
134 |
# Example usage:
if __name__ == "__main__":
    from flood_dataset import FloodSegmentationDataset
    from models.unet import UNet

    # Create data directory if it doesn't exist
    data_dir = Path('data')
    data_dir.mkdir(parents=True, exist_ok=True)

    # Training configuration
    config = {
        'learning_rate': 1e-4,
        'weight_decay': 1e-4,
        'epochs': 10,
        'checkpoint_freq': 5,
        'save_dir': 'checkpoints/unet',
        'use_wandb': True,
        'early_stopping_patience': 100,
        'warmup_epochs': 3,  # Number of epochs for warmup
        'model': {
            'n_channels': 2,  # VH and VV channels
            'n_classes': 1,
            'bilinear': True
        },
        'batch_size': 32,
        'bce_weight': 0.5,
        'dice_weight': 0.5
    }

    # Create datasets and dataloaders
    train_dataset = FloodSegmentationDataset(
        map_file=str(data_dir / 'train.txt'),
        base_path=str(data_dir),
        split='train'
    )

    val_dataset = FloodSegmentationDataset(
        map_file=str(data_dir / 'val.txt'),
        base_path=str(data_dir),
        split='val'
    )

    # Up to 8 workers per GPU, capped by the CPU core count (0 on CPU-only hosts)
    num_workers = min(8 * torch.cuda.device_count(), os.cpu_count() or 4)

    # persistent_workers / prefetch_factor are only valid with worker
    # processes; DataLoader raises ValueError for them when num_workers == 0.
    # Pinned memory only helps host->GPU copies.
    loader_kwargs = {
        'batch_size': config['batch_size'],
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
    }
    if num_workers > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Initialize model
    model = UNet(**config['model'])

    # Train the model
    train_model(model, train_loader, val_loader, config)
--------------------------------------------------------------------------------
/FloodsML/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import sys
4 | import numpy as np
5 | from osgeo import gdal
6 | import torch
7 | from models.unet import UNet
8 |
def load_tiff(filepath: Path) -> tuple[np.ndarray | None, gdal.Dataset | None]:
    """Open a raster; return its pixel array and the open GDAL dataset (for geo-info).

    The dataset stays open until the caller drops its reference. On failure
    an error is printed to stderr and ``(None, None)`` is returned. This holds
    both when GDAL exceptions are disabled (gdal.Open returns None) and when
    they are enabled (gdal.Open raises RuntimeError).
    """
    try:
        ds = gdal.Open(str(filepath))
    except RuntimeError:
        ds = None
    if ds is None:
        print(f"Error: Could not open TIFF file: {filepath}", file=sys.stderr)
        return None, None
    return ds.ReadAsArray(), ds
17 |
def save_tiff(data: np.ndarray, output_path: Path, ref_ds: gdal.Dataset):
    """Write ``data`` as a Byte GeoTIFF georeferenced like ``ref_ds``.

    A 2-D array (rows, cols) becomes a single band; a 3-D array is treated as
    (bands, rows, cols). Any other rank, or failure to create the file, is
    reported on stderr and nothing is written.
    """
    if data.ndim not in (2, 3):
        print(f"Error: Unexpected data shape for saving: {data.shape}", file=sys.stderr)
        return

    # View 2-D input as a single-band stack so both cases share one write loop.
    layers = data[np.newaxis, :, :] if data.ndim == 2 else data
    n_layers, n_rows, n_cols = layers.shape

    gtiff = gdal.GetDriverByName('GTiff')
    out_ds = gtiff.Create(str(output_path), n_cols, n_rows, n_layers, gdal.GDT_Byte)
    if out_ds is None:
        print(f"Error: Could not create output file: {output_path}", file=sys.stderr)
        return

    out_ds.SetGeoTransform(ref_ds.GetGeoTransform())
    out_ds.SetProjection(ref_ds.GetProjection())

    for band_no, layer in enumerate(layers, start=1):
        band = out_ds.GetRasterBand(band_no)
        band.WriteArray(layer)
        band.FlushCache()

    # Dropping the last reference closes the dataset and flushes it to disk.
    out_ds = None
51 |
52 |
def main():
    """CLI entry point: run the UNet on every VV/VH/DEM tile group in a directory.

    For every ``<base>01_Sentinel1_vv.tif`` in ``input_dir``, the matching
    ``<base>01_Sentinel1_vh.tif`` and ``<base>02_DEM.tif`` must exist with
    the same shape. VV and VH are standardised with fixed mean/std constants,
    stacked as (VV, VH) and passed through the model. The sigmoid output is
    thresholded at 0.5 and saved as ``<base>pred.tif`` in ``output_dir``,
    georeferenced like the VV input. The process exits with status 1 if the
    input directory or model file is missing, or if the model fails to load.
    """
    parser = argparse.ArgumentParser(
        description="Load VV/VH/DEM TIFFs, run inference, save prediction masks."
    )
    parser.add_argument("input_dir", type=str, help="Path to the input directory containing TIFF files.")
    parser.add_argument("output_dir", type=str, help="Path to the output directory for prediction masks.")
    parser.add_argument("model_path", type=str, help="Path to the pre-trained model checkpoint (.pth file).")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use ('cuda' or 'cpu').")
    # NOTE(review): the input tensor below always has 2 channels (VV, VH); a
    # model built with --n_channels 3 will fail at the forward pass.
    parser.add_argument("--n_channels", type=int, default=2, help="Number of input channels for the model (e.g., 3 for VV+VH+DEM).")
    parser.add_argument("--n_classes", type=int, default=1, help="Number of output classes (e.g., 1 for binary segmentation).")
    # Add bilinear argument matching the example config
    parser.add_argument("--bilinear", action='store_true', default=True, help="Use bilinear upsampling in UNet.")  # Default True as per example
    parser.add_argument("--no-bilinear", action='store_false', dest='bilinear', help="Use transposed convolutions instead of bilinear upsampling.")

    args = parser.parse_args()

    input_path = Path(args.input_dir).resolve()
    output_path = Path(args.output_dir).resolve()
    model_path = Path(args.model_path).resolve()
    device = torch.device(args.device)

    if not input_path.is_dir():
        print(f"Error: Input directory not found: {input_path}", file=sys.stderr)
        sys.exit(1)
    if not model_path.is_file():
        print(f"Error: Model file not found: {model_path}", file=sys.stderr)
        sys.exit(1)
    output_path.mkdir(parents=True, exist_ok=True)  # Create output directory

    print(f"Using device: {device}")
    print(f"Loading model from: {model_path}")

    try:
        model = UNet(n_channels=args.n_channels, n_classes=args.n_classes, bilinear=args.bilinear)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)

        # Accept both full training checkpoints and bare state dicts.
        model_state_dict = checkpoint.get('model_state_dict', checkpoint)
        model.load_state_dict(model_state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded successfully ({args.n_channels} inputs, {args.n_classes} outputs, bilinear={args.bilinear}).")
    except Exception as e:
        print(f"Error loading model: {e}", file=sys.stderr)
        sys.exit(1)

    # Per-channel standardisation constants; presumably the statistics used
    # during training - TODO confirm they match the training pipeline.
    vh_mean = -11.2832
    vh_std = 5.6656
    vv_mean = -6.9761
    vv_std = 4.3096

    vv_suffix = '01_Sentinel1_vv.tif'
    vh_suffix = '01_Sentinel1_vh.tif'
    dem_suffix = '02_DEM.tif'

    print(f"Scanning for '*vv.tif' files in: {input_path}")
    vv_files = sorted(input_path.glob('*vv.tif'))

    if not vv_files:
        print("No '*vv.tif' files found.")
        sys.exit(0)

    processed_count = 0
    skipped_count = 0
    for vv_filepath in vv_files:
        if not vv_filepath.name.endswith(vv_suffix):
            # '*vv.tif' also matches names outside the tile naming scheme;
            # slicing those would produce a wrong base name.
            print(f"\nSkipping {vv_filepath.name}: name does not end with '{vv_suffix}'.")
            skipped_count += 1
            continue

        base_name = vv_filepath.name[:-len(vv_suffix)]
        vh_filename = base_name + vh_suffix
        dem_filename = base_name + dem_suffix

        vh_filepath = input_path / vh_filename
        dem_filepath = input_path / dem_filename

        if not (vh_filepath.is_file() and dem_filepath.is_file()):
            # Warning if counterparts are missing
            print(f"\nSkipping {vv_filepath.name}: Missing counterparts.")
            if not vh_filepath.is_file(): print(f"  Missing: {vh_filename}")
            if not dem_filepath.is_file(): print(f"  Missing: {dem_filename}")
            skipped_count += 1
            continue

        print(f"\nProcessing group based on: {vv_filepath.name}")

        vv_data, vv_ds = load_tiff(vv_filepath)
        vh_data, vh_ds = load_tiff(vh_filepath)
        # The DEM is only used to validate the group's shape; it is not fed
        # to the model (the input is VV and VH only).
        dem_data, dem_ds = load_tiff(dem_filepath)

        # --- Data Integrity Check ---
        if (vv_data is not None and vh_data is not None and dem_data is not None and
                vv_data.shape == vh_data.shape and vv_data.shape == dem_data.shape):

            print(f"  Loaded VV: {vv_data.shape}, VH: {vh_data.shape}, DEM: {dem_data.shape}")

            # --- Preprocessing ---
            vv_norm = (vv_data.astype(np.float32) - vv_mean) / vv_std
            vh_norm = (vh_data.astype(np.float32) - vh_mean) / vh_std

            input_image = np.stack([vv_norm, vh_norm], axis=0)
            input_tensor = torch.from_numpy(input_image).unsqueeze(0).to(device)  # add batch dim

            # --- Inference ---
            with torch.no_grad():  # Disable gradient calculation
                output_tensor = model(input_tensor)

            pred_prob = torch.sigmoid(output_tensor)
            pred_binary = (pred_prob > 0.5).squeeze().cpu().numpy().astype(np.uint8)

            # --- Save Output ---
            output_filename = output_path / (base_name + 'pred.tif')
            print(f"  Saving prediction to: {output_filename}")
            # Use geo-info from one of the inputs (e.g., VV)
            save_tiff(pred_binary, output_filename, vv_ds)
            processed_count += 1

        else:
            print(f"  Skipping group: Loading error or shape mismatch.")
            print(f"    VV: {'OK' if vv_data is not None else 'Fail'}, VH: {'OK' if vh_data is not None else 'Fail'}, DEM: {'OK' if dem_data is not None else 'Fail'}")
            if vv_data is not None and vh_data is not None and dem_data is not None:
                print(f"    Shapes: VV={vv_data.shape}, VH={vh_data.shape}, DEM={dem_data.shape}")
            skipped_count += 1

        # Drop dataset references so GDAL closes the files before the next group.
        vv_ds = vh_ds = dem_ds = None

    print(f"\nFinished. Processed {processed_count} groups, skipped {skipped_count} groups.")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/FloodsML/sweep_review.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import wandb\n",
10 | "import pandas as pd\n",
11 | "import seaborn as sns\n",
12 | "import matplotlib.pyplot as plt"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stdout",
22 | "output_type": "stream",
23 | "text": [
24 | "🏆 Top 5 Runs by Validation IoU:\n",
25 | "\n",
26 | " run_id val_iou best_val_loss learning_rate weight_decay \\\n",
27 | "19 d2l878my 0.450551 0.546182 0.000016 0.000182 \n",
28 | "8 sfek7m88 0.450266 0.540528 0.000018 0.000717 \n",
29 | "1 ij6bpizs 0.449752 0.570179 0.000029 0.000011 \n",
30 | "0 kfoyx1x6 0.449238 0.578700 0.000013 0.000080 \n",
31 | "17 vbzyhodf 0.445841 0.511774 0.000146 0.000010 \n",
32 | "\n",
33 | " warmup_epochs plateau_factor plateau_patience batch_size num_workers \\\n",
34 | "19 4 0.2 7 16 8 \n",
35 | "8 3 0.7 5 16 4 \n",
36 | "1 4 0.5 5 32 16 \n",
37 | "0 5 0.7 5 16 16 \n",
38 | "17 3 0.5 3 32 16 \n",
39 | "\n",
40 | " prefetch_factor \n",
41 | "19 1 \n",
42 | "8 1 \n",
43 | "1 1 \n",
44 | "0 2 \n",
45 | "17 1 \n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "# Set your sweep details\n",
51 | "ENTITY = \"mracic\"\n",
52 | "PROJECT = \"flood-segmentation\"\n",
53 | "SWEEP_ID = \"sk0uzqz0\"\n",
54 | "\n",
55 | "# Initialize API\n",
56 | "api = wandb.Api()\n",
57 | "\n",
58 | "# Load sweep\n",
59 | "sweep = api.sweep(f\"{ENTITY}/{PROJECT}/{SWEEP_ID}\")\n",
60 | "runs = sweep.runs\n",
61 | "\n",
62 | "# Convert to DataFrame\n",
63 | "records = []\n",
64 | "for run in runs:\n",
65 | " if run.state != \"finished\":\n",
66 | " continue\n",
67 | " summary = run.summary\n",
68 | " config = run.config\n",
69 | " records.append({\n",
70 | " \"run_id\": run.id,\n",
71 | " \"val_iou\": summary.get(\"val_iou\"),\n",
72 | " \"best_val_loss\": summary.get(\"best_val_loss\"),\n",
73 | " \"learning_rate\": config.get(\"learning_rate\"),\n",
74 | " \"weight_decay\": config.get(\"weight_decay\"),\n",
75 | " \"warmup_epochs\": config.get(\"warmup_epochs\"),\n",
76 | " \"plateau_factor\": config.get(\"plateau_factor\"),\n",
77 | " \"plateau_patience\": config.get(\"plateau_patience\"),\n",
78 | " \"batch_size\": config.get(\"batch_size\"),\n",
79 | " \"num_workers\": config.get(\"num_workers\"),\n",
80 | " \"prefetch_factor\": config.get(\"prefetch_factor\"),\n",
81 | " })\n",
82 | "\n",
83 | "df = pd.DataFrame(records)\n",
84 | "\n",
85 | "# Drop rows with missing values\n",
86 | "df = df.dropna(subset=[\"val_iou\"])\n",
87 | "\n",
88 | "# Sort by best val_iou\n",
89 | "df_sorted = df.sort_values(by=\"val_iou\", ascending=False)\n",
90 | "\n",
91 | "# Show top 5 runs\n",
92 | "print(\"🏆 Top 5 Runs by Validation IoU:\\n\")\n",
93 | "print(df_sorted.head())\n",
94 | "\n",
95 | "# Plot\n"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "SWEEP_ID = 'jlcywgkr'\n",
105 | "sweep = api.sweep(f\"{ENTITY}/{PROJECT}/{SWEEP_ID}\")"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "name": "stdout",
115 | "output_type": "stream",
116 | "text": [
117 | "🏆 Top 5 Runs by Validation IoU:\n",
118 | "\n",
119 | " run_id val_iou best_val_loss learning_rate weight_decay warmup_epochs plateau_factor plateau_patience batch_size num_workers prefetch_factor dropout_rate bce_weight dice_weight\n",
120 | "otpnzi48 0.481520 0.508222 0.000008 0.000236 5 0.2 9 8 6 2 0.0 0.5 0.5\n",
121 | "az3w121w 0.469485 0.161217 0.000010 0.000348 5 0.3 9 8 6 1 0.0 1.0 0.0\n",
122 | "kl9io4im 0.469250 0.173235 0.000009 0.000191 5 0.3 9 8 6 2 0.0 1.0 0.0\n",
123 | "ivryz770 0.466911 0.400988 0.000011 0.000322 5 0.2 9 16 6 2 0.0 0.7 0.3\n",
124 | "3t8pofh1 0.459811 0.656754 0.000020 0.000234 4 0.2 9 24 8 1 0.0 0.3 0.7\n"
125 | ]
126 | }
127 | ],
128 | "source": [
129 | "runs = sweep.runs\n",
130 | "\n",
131 | "# Extract metrics and config for each run\n",
132 | "records = []\n",
133 | "for run in runs:\n",
134 | " if run.state != \"finished\":\n",
135 | " continue\n",
136 | " try:\n",
137 | " val_iou = run.summary.get(\"val_iou\")\n",
138 | " best_val_loss = run.summary.get(\"best_val_loss\")\n",
139 | " config = run.config\n",
140 | " \n",
141 | " records.append({\n",
142 | " \"run_id\": run.id,\n",
143 | " \"val_iou\": val_iou,\n",
144 | " \"best_val_loss\": best_val_loss,\n",
145 | " \"learning_rate\": config.get(\"learning_rate\"),\n",
146 | " \"weight_decay\": config.get(\"weight_decay\"),\n",
147 | " \"warmup_epochs\": config.get(\"warmup_epochs\"),\n",
148 | " \"plateau_factor\": config.get(\"plateau_factor\"),\n",
149 | " \"plateau_patience\": config.get(\"plateau_patience\"),\n",
150 | " \"batch_size\": config.get(\"batch_size\"),\n",
151 | " \"num_workers\": config.get(\"num_workers\"),\n",
152 | " \"prefetch_factor\": config.get(\"prefetch_factor\"),\n",
153 | " \"dropout_rate\": config.get(\"dropout_rate\"),\n",
154 | " \"bce_weight\": config.get(\"bce_weight\"),\n",
155 | " \"dice_weight\": 1.0 - config.get(\"bce_weight\") if config.get(\"bce_weight\") is not None else None,\n",
156 | " })\n",
157 | " except Exception as e:\n",
158 | " print(f\"Skipping run {run.id} due to error: {e}\")\n",
159 | "\n",
160 | "# Create DataFrame\n",
161 | "df = pd.DataFrame(records)\n",
162 | "\n",
163 | "# Sort by val_iou (descending)\n",
164 | "df = df.sort_values(by=\"val_iou\", ascending=False)\n",
165 | "\n",
166 | "# Select top 5\n",
167 | "top_5 = df.head(5)\n",
168 | "\n",
169 | "# Display\n",
170 | "print(\"🏆 Top 5 Runs by Validation IoU:\\n\")\n",
171 | "print(top_5.to_string(index=False))"
172 | ]
173 | }
174 | ],
175 | "metadata": {
176 | "kernelspec": {
177 | "display_name": "poplave",
178 | "language": "python",
179 | "name": "python3"
180 | },
181 | "language_info": {
182 | "codemirror_mode": {
183 | "name": "ipython",
184 | "version": 3
185 | },
186 | "file_extension": ".py",
187 | "mimetype": "text/x-python",
188 | "name": "python",
189 | "nbconvert_exporter": "python",
190 | "pygments_lexer": "ipython3",
191 | "version": "3.10.16"
192 | }
193 | },
194 | "nbformat": 4,
195 | "nbformat_minor": 2
196 | }
197 |
--------------------------------------------------------------------------------
/FloodsML/mg_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "id": "56dba08f-bad7-44bf-872f-2dc87d008a0f",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stdout",
11 | "output_type": "stream",
12 | "text": [
13 | "Hello, World!\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "print(\"Hello, World!\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 10,
24 | "id": "df6c268f-46dd-498d-a4e6-8d0aadc12495",
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stdout",
29 | "output_type": "stream",
30 | "text": [
31 | "Using device: cpu\n",
32 | "Loading model from: /app/models/best_model.pth\n",
33 | "Model loaded successfully (2 inputs, 1 outputs, bilinear=True).\n",
34 | "Scanning for '*vv.tif' files in: /app/sample_data/512_images_EMSR708_AOI01_DEL_PRODUCT\n",
35 | "\n",
36 | "Processing group based on: tile_001_0_01_Sentinel1_vv.tif\n",
37 | "/usr/lib/python3/dist-packages/osgeo/gdal.py:606: FutureWarning: Neither gdal.UseExceptions() nor gdal.DontUseExceptions() has been explicitly called. In GDAL 4.0, exceptions will be enabled by default.\n",
38 | " warnings.warn(\n",
39 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
40 | " Saving prediction to: /app/sample_data/512_output_labels/tile_001_0_pred.tif\n",
41 | "\n",
42 | "Processing group based on: tile_002_0_01_Sentinel1_vv.tif\n",
43 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
44 | " Saving prediction to: /app/sample_data/512_output_labels/tile_002_0_pred.tif\n",
45 | "\n",
46 | "Processing group based on: tile_003_0_01_Sentinel1_vv.tif\n",
47 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
48 | " Saving prediction to: /app/sample_data/512_output_labels/tile_003_0_pred.tif\n",
49 | "\n",
50 | "Processing group based on: tile_004_0_01_Sentinel1_vv.tif\n",
51 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
52 | " Saving prediction to: /app/sample_data/512_output_labels/tile_004_0_pred.tif\n",
53 | "\n",
54 | "Processing group based on: tile_005_0_01_Sentinel1_vv.tif\n",
55 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
56 | " Saving prediction to: /app/sample_data/512_output_labels/tile_005_0_pred.tif\n",
57 | "\n",
58 | "Processing group based on: tile_006_0_01_Sentinel1_vv.tif\n",
59 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
60 | " Saving prediction to: /app/sample_data/512_output_labels/tile_006_0_pred.tif\n",
61 | "\n",
62 | "Processing group based on: tile_007_0_01_Sentinel1_vv.tif\n",
63 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
64 | " Saving prediction to: /app/sample_data/512_output_labels/tile_007_0_pred.tif\n",
65 | "\n",
66 | "Processing group based on: tile_008_0_01_Sentinel1_vv.tif\n",
67 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
68 | " Saving prediction to: /app/sample_data/512_output_labels/tile_008_0_pred.tif\n",
69 | "\n",
70 | "Processing group based on: tile_009_0_01_Sentinel1_vv.tif\n",
71 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
72 | " Saving prediction to: /app/sample_data/512_output_labels/tile_009_0_pred.tif\n",
73 | "\n",
74 | "Processing group based on: tile_010_0_01_Sentinel1_vv.tif\n",
75 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
76 | " Saving prediction to: /app/sample_data/512_output_labels/tile_010_0_pred.tif\n",
77 | "\n",
78 | "Processing group based on: tile_011_0_01_Sentinel1_vv.tif\n",
79 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
80 | " Saving prediction to: /app/sample_data/512_output_labels/tile_011_0_pred.tif\n",
81 | "\n",
82 | "Processing group based on: tile_012_0_01_Sentinel1_vv.tif\n",
83 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
84 | " Saving prediction to: /app/sample_data/512_output_labels/tile_012_0_pred.tif\n",
85 | "\n",
86 | "Processing group based on: tile_013_0_01_Sentinel1_vv.tif\n",
87 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
88 | " Saving prediction to: /app/sample_data/512_output_labels/tile_013_0_pred.tif\n",
89 | "\n",
90 | "Processing group based on: tile_014_0_01_Sentinel1_vv.tif\n",
91 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
92 | " Saving prediction to: /app/sample_data/512_output_labels/tile_014_0_pred.tif\n",
93 | "\n",
94 | "Processing group based on: tile_015_0_01_Sentinel1_vv.tif\n",
95 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
96 | " Saving prediction to: /app/sample_data/512_output_labels/tile_015_0_pred.tif\n",
97 | "\n",
98 | "Processing group based on: tile_016_0_01_Sentinel1_vv.tif\n",
99 | " Loaded VV: (512, 512), VH: (512, 512), DEM: (512, 512)\n",
100 | " Saving prediction to: /app/sample_data/512_output_labels/tile_016_0_pred.tif\n",
101 | "\n",
102 | "Finished. Processed 16 groups, skipped 0 groups.\n"
103 | ]
104 | }
105 | ],
106 | "source": []
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 6,
111 | "id": "675edde6-6e87-47c5-8694-bf814195b452",
112 | "metadata": {},
113 | "outputs": [
114 | {
115 | "name": "stdout",
116 | "output_type": "stream",
117 | "text": [
118 | "GDAL version (library): 3120000\n",
119 | "GDAL version (Python binding): 3.12.0dev-adbc_bigquery\n"
120 | ]
121 | }
122 | ],
123 | "source": [
124 | "from osgeo import gdal\n",
125 | "print(\"GDAL version (library):\", gdal.VersionInfo())\n",
126 | "print(\"GDAL version (Python binding):\", gdal.__version__)\n"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 11,
132 | "id": "b79493fd-9edc-4ae4-9d6b-6e07824dd04c",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "#!pip install \"numpy<2\" --force-reinstall --break-system-packages\n"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "92937c78-6418-4190-8057-7bd538a3e63e",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": []
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "id": "37fedb26-2a7d-4cd5-92b3-3f1195b02d55",
151 | "metadata": {},
152 | "outputs": [],
153 | "source": []
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "id": "7f3e74cc-295c-4136-bc19-3597e67ab4ba",
159 | "metadata": {},
160 | "outputs": [],
161 | "source": []
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "457f7a30-1778-41ba-86db-e8ca9a78f7c4",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": []
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 14,
174 | "id": "e575d7cd-10e0-4be8-acd2-27f747c762af",
175 | "metadata": {},
176 | "outputs": [
177 | {
178 | "name": "stdout",
179 | "output_type": "stream",
180 | "text": [
181 | "Narejen enoten raster: outpit_file.tif\n"
182 | ]
183 | }
184 | ],
185 | "source": [
186 | "!python inference.py \"sample_data/512_images_EMSR708_AOI01_DEL_PRODUCT\" \"sample_data/512_output_labels\" \"models/best_model.pth\""
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": 17,
192 | "id": "0dde0937-966f-4bda-8576-a19d5dacf298",
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "name": "stdout",
197 | "output_type": "stream",
198 | "text": [
199 | "✅ Merged raster created: sample_data/output.tif\n"
200 | ]
201 | }
202 | ],
203 | "source": [
204 | "!python merge_tifs.py --input \"sample_data/512_images_EMSR708_AOI01_DEL_PRODUCT\" --output \"sample_data/output.tif\""
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "id": "71bff668-398d-4158-a2a6-8d3e74b9a0f7",
211 | "metadata": {},
212 | "outputs": [],
213 | "source": []
214 | }
215 | ],
216 | "metadata": {
217 | "kernelspec": {
218 | "display_name": "Python 3 (ipykernel)",
219 | "language": "python",
220 | "name": "python3"
221 | },
222 | "language_info": {
223 | "codemirror_mode": {
224 | "name": "ipython",
225 | "version": 3
226 | },
227 | "file_extension": ".py",
228 | "mimetype": "text/x-python",
229 | "name": "python",
230 | "nbconvert_exporter": "python",
231 | "pygments_lexer": "ipython3",
232 | "version": "3.12.3"
233 | }
234 | },
235 | "nbformat": 4,
236 | "nbformat_minor": 5
237 | }
238 |
--------------------------------------------------------------------------------
/FloodsML/trainers/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | from data_generators.data_generator import initialize_data_loader
7 | from models.sync_batchnorm.replicate import patch_replication_callback
8 | from models.deeplab import DeepLab
9 | from losses.loss import SegmentationLosses
10 | from utils.calculate_weights import calculate_weigths_labels
11 | from utils.lr_scheduler import LR_Scheduler
12 | from utils.saver import Saver
13 | from utils.summaries import TensorboardSummary
14 | from utils.metrics import Evaluator
15 | import torch
16 | import yaml
17 |
18 |
class Trainer(object):
    """Train and validate a DeepLab segmentation model driven by a nested config dict.

    Construction builds the data loaders, model, SGD optimizer, loss, evaluator and
    LR scheduler from ``config`` and, if requested, restores a checkpoint. Call
    :meth:`training` and :meth:`validation` once per epoch.

    Config keys read here include ``config['training'][...]``,
    ``config['network'][...]``, ``config['dataset'][...]`` and
    ``config['image']['out_stride']``.
    """

    def __init__(self, config):
        self.config = config
        self.best_pred = 0.0
        training_cfg = config['training']
        network_cfg = config['network']

        # Checkpoint saver (also persists the experiment config) and TensorBoard writer.
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(training_cfg['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()

        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)

        model = DeepLab(num_classes=self.nclass,
                        backbone=network_cfg['backbone'],
                        output_stride=config['image']['out_stride'],
                        sync_bn=network_cfg['sync_bn'],
                        freeze_bn=network_cfg['freeze_bn'])

        # Two parameter groups: get_1x_lr_params() at the base LR and
        # get_10x_lr_params() at ten times the base LR.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': training_cfg['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': training_cfg['lr'] * 10}]
        optimizer = torch.optim.SGD(train_params,
                                    momentum=training_cfg['momentum'],
                                    weight_decay=training_cfg['weight_decay'],
                                    nesterov=training_cfg['nesterov'])

        # Optional class-balanced loss weights, loaded from
        # <base_path>/<dataset_name>_classes_weights.npy when that file exists,
        # otherwise computed from the training loader.
        weight = None
        if training_cfg['use_balanced_weights']:
            dataset_cfg = config['dataset']
            classes_weights_path = os.path.join(dataset_cfg['base_path'],
                                                dataset_cfg['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(config, dataset_cfg['dataset_name'],
                                                  self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))

        self.criterion = SegmentationLosses(weight=weight, cuda=network_cfg['use_cuda']).build_loss(
            mode=training_cfg['loss_type'])
        self.model, self.optimizer = model, optimizer

        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(training_cfg['lr_scheduler'], training_cfg['lr'],
                                      training_cfg['epochs'], len(self.train_loader))

        # Multi-GPU wrapping; note this adds a "module." prefix to state_dict keys.
        if network_cfg['use_cuda']:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        init_cfg = training_cfg['weights_initialization']
        if init_cfg['use_pretrained_weights']:
            self._resume(init_cfg['restore_from'])

    def _resume(self, restore_from):
        """Restore model/optimizer state, start epoch and best mIoU from a checkpoint file.

        Raises:
            RuntimeError: if ``restore_from`` is not an existing file.
        """
        if not os.path.isfile(restore_from):
            raise RuntimeError("=> no checkpoint found at '{}'" .format(restore_from))

        if self.config['network']['use_cuda']:
            checkpoint = torch.load(restore_from)
        else:
            # Remap tensors saved on *any* CUDA device to CPU (not only cuda:0).
            checkpoint = torch.load(restore_from, map_location='cpu')

        self.config['training']['start_epoch'] = checkpoint['epoch']
        self._load_model_state(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(restore_from, checkpoint['epoch']))

    def _load_model_state(self, state_dict):
        """Load ``state_dict`` into the model, reconciling the DataParallel ``module.`` prefix.

        Checkpoints written by this class store ``self.model.state_dict()``, whose keys
        carry a ``module.`` prefix only when the model was wrapped in DataParallel
        (``use_cuda``). This lets a CUDA checkpoint resume on CPU and vice versa.
        """
        from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

        target = self.model
        if isinstance(target, torch.nn.DataParallel):
            if not any(key.startswith('module.') for key in state_dict):
                # Checkpoint came from an unwrapped model: load into the wrapped module.
                target = target.module
        else:
            # Checkpoint came from a DataParallel model: strip the prefix (no-op if absent).
            consume_prefix_in_state_dict_if_present(state_dict, 'module.')
        target.load_state_dict(state_dict)

    def training(self, epoch):
        """Run one training epoch, log losses to TensorBoard and save ``checkpoint_last.pth.tar``."""
        train_loss = 0.0
        num_images = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualise about 10 batches per epoch; max(...) avoids a modulo-by-zero
        # when the loader yields fewer than 10 batches.
        vis_every = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            train_loss += loss_value
            num_images += image.shape[0]
            global_step = i + num_img_tr * epoch
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_value, global_step)

            if i % vis_every == 0:
                self.summary.visualize_image(self.writer, self.config['dataset']['dataset_name'],
                                             image, target, output, global_step)

        # Epoch loss is the sum of per-batch losses.
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        # Always keep the latest state so training can resume.
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best=False, filename='checkpoint_last.pth.tar')

        # When training on a subset, draw a new subset for the next epoch.
        if self.config['training']['train_on_subset']['enabled']:
            self.train_loader.dataset.shuffle_dataset()

    def validation(self, epoch):
        """Evaluate on the validation loader, log metrics, and save
        ``checkpoint_best.pth.tar`` whenever mIoU improves on ``self.best_pred``."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            # Predicted class = argmax over axis 1 of the model output.
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            self.evaluator.add_batch(target.cpu().numpy(), pred)

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        # Model selection criterion is mIoU.
        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best=True, filename='checkpoint_best.pth.tar')
--------------------------------------------------------------------------------
/FloodsML/0_analyze_labels.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from osgeo import gdal
4 | import matplotlib.pyplot as plt
5 | from tqdm import tqdm
6 | import pandas as pd
7 | from collections import defaultdict
8 |
def analyze_labels(data_dir='data'):
    """Analyze the distribution of flood labels in a tiled dataset.

    Reads every ``*_FloodMask.tif`` under ``<data_dir>/512_labels/<EMSR folder>/``,
    treating label values 1 and 2 as flooded. It prints pixel-value and flood-ratio
    statistics and writes a per-EMSR summary to ``<data_dir>/label_distribution.csv``.
    It also writes a 2x2 figure to ``<data_dir>/label_distribution.png``.

    Args:
        data_dir: Dataset root containing the ``512_labels`` directory.

    Returns:
        None. Returns early (after printing a message) when the labels directory is
        missing or when no tile contains flood pixels.
    """
    data_dir = Path(data_dir)
    labels_dir = data_dir / '512_labels'

    if not labels_dir.exists():
        print(f"Error: Labels directory not found at {labels_dir}")
        return

    # Label values counted as flood.
    flood_values = (1, 2)
    # A tile with flood pixels is "flooded" when strictly more than this fraction of
    # its pixels are flood; the printed labels refer to this as "2%".
    flooded_tile_ratio = 0.02
    # Flood-ratio cut-offs for the "discard tiles below X%" analysis.
    thresholds = [0.01, 0.02, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]

    stats = {
        'total_tiles': 0,            # tiles with at least one flood pixel
        'flooded_tiles': 0,
        'non_flooded_tiles': 0,
        'flood_ratio': 0,
        'tiles_per_emsr': defaultdict(int),
        'flooded_per_emsr': defaultdict(int),
        'errors': [],
        'value_counts': defaultdict(int),   # pixel count per label value (non-empty tiles)
        'flood_ratios': [],                 # per-tile flood ratio, for the histogram
        'total_pixels': 0,                  # pixels across all non-empty tiles
        'threshold_stats': defaultdict(lambda: {'tiles': 0, 'pixels': defaultdict(int)}),
    }

    # Sorted so the processing order (and the order of reported errors) is deterministic.
    emsr_folders = sorted(f for f in labels_dir.iterdir() if f.is_dir())

    for emsr_folder in tqdm(emsr_folders, desc="Processing EMSR folders"):
        flood_files = sorted(emsr_folder.glob('*_FloodMask.tif'))
        if not flood_files:
            stats['errors'].append(f"No FloodMask files found in {emsr_folder.name}")
            continue

        for flood_file in flood_files:
            try:
                ds = gdal.Open(str(flood_file))
                if ds is None:
                    # gdal.Open may signal failure by returning None when GDAL
                    # exceptions are not enabled; make that an explicit error.
                    raise RuntimeError("GDAL could not open the file")
                label = ds.ReadAsArray()
                ds = None  # release the dataset handle
                if label is None:
                    raise RuntimeError("GDAL returned no pixel data")
            except Exception as e:
                stats['errors'].append(f"Error reading {flood_file}: {str(e)}")
                continue

            # Skip tiles that only contain 0 values
            if np.all(label == 0):
                continue

            unique_vals, counts = np.unique(label, return_counts=True)
            for val, count in zip(unique_vals, counts):
                stats['value_counts'][int(val)] += int(count)
            stats['total_pixels'] += label.size

            flooded_pixels = np.count_nonzero(np.isin(label, flood_values))
            # Only tiles that have at least some flooded pixels enter the tile stats.
            if flooded_pixels == 0:
                continue
            flood_ratio = flooded_pixels / label.size

            stats['total_tiles'] += 1
            stats['tiles_per_emsr'][emsr_folder.name] += 1
            stats['flood_ratios'].append(flood_ratio)

            for threshold in thresholds:
                if flood_ratio > threshold:
                    bucket = stats['threshold_stats'][threshold]
                    bucket['tiles'] += 1
                    for val, count in zip(unique_vals, counts):
                        bucket['pixels'][int(val)] += int(count)

            if flood_ratio > flooded_tile_ratio:
                stats['flooded_tiles'] += 1
                stats['flooded_per_emsr'][emsr_folder.name] += 1
            else:
                stats['non_flooded_tiles'] += 1

    # Share of flood-containing tiles that count as "flooded".
    if stats['total_tiles'] > 0:
        stats['flood_ratio'] = stats['flooded_tiles'] / stats['total_tiles']

    print("\n=== Label Distribution Summary ===")
    print(f"Total tiles with flood pixels: {stats['total_tiles']}")
    print(f"Flooded tiles (>2% flood): {stats['flooded_tiles']}")
    print(f"Non-flooded tiles (<2% flood): {stats['non_flooded_tiles']}")
    print(f"Overall flood ratio: {stats['flood_ratio']:.2%}")
    print("\nValue Distribution:")
    for val in sorted(stats['value_counts']):
        count = stats['value_counts'][val]
        percentage = (count / stats['total_pixels']) * 100
        print(f"Value {val}: {count:,} pixels ({percentage:.2f}% of total)")

    print("\nDistribution after discarding tiles with less than X% flood pixels:")
    for threshold in thresholds:
        bucket = stats['threshold_stats'][threshold]
        threshold_pixels = sum(bucket['pixels'].values())
        if threshold_pixels > 0:
            flood_pixels = bucket['pixels'][1] + bucket['pixels'][2]
            print(f"\nThreshold {threshold*100:.1f}%:")
            print(f"Remaining tiles: {bucket['tiles']} "
                  f"({bucket['tiles']/stats['total_tiles']*100:.1f}% of total)")
            print(f"Flood pixel ratio: {flood_pixels/threshold_pixels*100:.2f}%")

    if stats['errors']:
        print("\n=== Errors and Warnings ===")
        for error in stats['errors']:
            print(f"- {error}")

    if stats['total_tiles'] == 0:
        print("\nNo valid tiles found! Please check:")
        print("1. The directory structure is correct")
        print("2. The FloodMask files are valid GeoTIFFs")
        return

    # Per-EMSR summary table.
    emsr_stats = [
        {
            'EMSR': emsr,
            'Total Tiles': n_tiles,
            'Flooded Tiles': stats['flooded_per_emsr'][emsr],
            'Flood Ratio': stats['flooded_per_emsr'][emsr] / n_tiles if n_tiles > 0 else 0,
        }
        for emsr, n_tiles in stats['tiles_per_emsr'].items()
    ]

    if emsr_stats:
        df = pd.DataFrame(emsr_stats).sort_values('Total Tiles', ascending=False)
        output_file = data_dir / 'label_distribution.csv'
        df.to_csv(output_file, index=False)
        print(f"\nDetailed statistics saved to {output_file}")

    fig = plt.figure(figsize=(15, 12))

    # Plot 1: flooded vs non-flooded tiles
    plt.subplot(2, 2, 1)
    plt.pie([stats['flooded_tiles'], stats['non_flooded_tiles']],
            labels=[f'Flooded (>2%)\n{stats["flooded_tiles"]} tiles',
                    f'Non-flooded (<2%)\n{stats["non_flooded_tiles"]} tiles'],
            autopct='%.1f%%')
    plt.title('Overall Label Distribution')

    # Plot 2: histogram of per-tile flood ratios
    plt.subplot(2, 2, 2)
    plt.hist(stats['flood_ratios'], bins=20, edgecolor='black')
    plt.title('Distribution of Flood Ratios')
    plt.xlabel('Flood Ratio')
    plt.ylabel('Number of Tiles')
    plt.axvline(x=flooded_tile_ratio, color='r', linestyle='--', label='2% Threshold')
    plt.legend()

    # Plot 3: tiles kept per threshold
    plt.subplot(2, 2, 3)
    threshold_tiles = [stats['threshold_stats'][t]['tiles'] for t in thresholds]
    plt.plot(np.array(thresholds) * 100, threshold_tiles, 'b-o')
    plt.title('Number of Tiles vs Threshold')
    plt.xlabel('Threshold (%)')
    plt.ylabel('Number of Tiles')
    plt.grid(True)

    # Plot 4: flood pixel ratio among kept tiles per threshold
    plt.subplot(2, 2, 4)
    pixel_flood_ratios = []
    for t in thresholds:
        pixels = stats['threshold_stats'][t]['pixels']
        n_pixels = sum(pixels.values())
        pixel_flood_ratios.append((pixels[1] + pixels[2]) / n_pixels * 100 if n_pixels > 0 else 0)

    plt.plot(np.array(thresholds) * 100, pixel_flood_ratios, 'g-o')
    plt.title('Flood Pixel Ratio vs Threshold')
    plt.xlabel('Threshold (%)')
    plt.ylabel('Flood Pixel Ratio (%)')
    plt.grid(True)

    fig.tight_layout()
    plot_file = data_dir / 'label_distribution.png'
    fig.savefig(plot_file)
    plt.close(fig)  # pyplot keeps figures alive until explicitly closed
    print(f"Visualization saved to {plot_file}")


if __name__ == "__main__":
    analyze_labels()
--------------------------------------------------------------------------------
/FloodsML/0_create_splits.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from osgeo import gdal
4 | import random
5 | from tqdm import tqdm
6 | import sys
7 |
def _read_raster(path):
    """Return the pixel array of the raster at ``path`` read with GDAL.

    Raises:
        RuntimeError: if GDAL cannot open the file or returns no data.
    """
    ds = gdal.Open(str(path))
    if ds is None:
        raise RuntimeError(f"GDAL could not open {path}")
    array = ds.ReadAsArray()
    ds = None  # release the dataset handle
    if array is None:
        raise RuntimeError(f"GDAL returned no pixel data for {path}")
    return array


def create_filtered_splits(data_dir='data', train_ratio=0.8, flood_threshold=0.01,
                           seed=42, debug_emsr=None):
    """
    Create train and validation splits with filtering criteria:
    - Exclude tiles whose label is all 0, or all 0 after masking out pixels where VV == 0
    - Exclude tiles with less than ``flood_threshold`` flood pixels (values 1 and 2)

    Writes ``<data_dir>/train.txt`` and ``<data_dir>/val.txt``. Each line holds
    ``vh|vv|dem|label`` paths relative to ``data_dir``.

    Args:
        data_dir (str): Path to data directory (must contain 512_images and 512_labels)
        train_ratio (float): Ratio of training data (default: 0.8)
        flood_threshold (float): Minimum ratio of flood pixels (default: 0.01)
        seed (int): Seed for the shuffle before splitting (default: 42)
        debug_emsr (str | None): EMSR folder name for which to print per-tile debug
            information (default: None, no debug output)
    """
    data_dir = Path(data_dir)
    images_dir = data_dir / '512_images'
    labels_dir = data_dir / '512_labels'

    if not images_dir.exists():
        print(f"Error: Images directory not found at {images_dir}", flush=True)
        return
    if not labels_dir.exists():
        print(f"Error: Labels directory not found at {labels_dir}", flush=True)
        return

    stats = {
        'total_pairs': 0,
        'filtered_pairs': 0,
        'skipped_zero_only': 0,
        'skipped_low_flood': 0,
        'pairs_per_emsr': {}
    }
    pairs = []

    # Sort folders and tile names so the pair list, and therefore the seeded shuffle,
    # is identical across runs. iterdir() order is filesystem-dependent, and set
    # iteration order of str changes with Python's hash randomisation.
    emsr_folders = sorted(f for f in images_dir.iterdir() if f.is_dir())

    pbar = tqdm(emsr_folders, desc="Processing EMSR folders", file=sys.stdout)
    for emsr_folder in pbar:
        pbar.set_description(f"Processing {emsr_folder.name}")
        # Compare the folder *name*; comparing the Path object to a str is always False.
        debug = debug_emsr is not None and emsr_folder.name == debug_emsr

        label_folder = labels_dir / emsr_folder.name
        if not label_folder.exists():
            print(f"Warning: Label folder missing for {emsr_folder.name}", flush=True)
            continue

        # Unique base names, e.g. 'tile_001_1_01' from 'tile_001_1_01_Sentinel1_vh.tif'
        base_names = sorted({vh.stem.replace('_Sentinel1_vh', '')
                             for vh in emsr_folder.glob('*_Sentinel1_vh.tif')})
        if not base_names:
            print(f"Warning: No valid image files found in {emsr_folder.name}", flush=True)
            continue

        emsr_pairs = 0
        for base_name in base_names:
            # e.g. 'tile_001_1_01' -> ['tile', '001', '1', '01']
            parts = base_name.split('_')
            if len(parts) < 4:
                print(f"Warning: Invalid base name format: {base_name}", flush=True)
                continue
            tile_num, suffix_num = parts[1], parts[2]
            if debug:
                print(f"Working on {emsr_folder.name}", flush=True)

            vh_file = emsr_folder / f"{base_name}_Sentinel1_vh.tif"
            vv_file = emsr_folder / f"{base_name}_Sentinel1_vv.tif"
            # DEM files use _02 instead of _01
            dem_file = emsr_folder / f"tile_{tile_num}_{suffix_num}_02_DEM.tif"

            # Sorted so the choice among several matching masks is deterministic.
            label_files = sorted(label_folder.glob(f"tile_{tile_num}_{suffix_num}_*_FloodMask.tif"))
            if not label_files:
                print(f"Warning: No FloodMask file found for {base_name}", flush=True)
                continue
            label_file = label_files[0]

            missing_files = [f.name for f in (vh_file, vv_file, dem_file, label_file) if not f.exists()]
            if missing_files:
                print(f"Warning: Missing files for {base_name}: {', '.join(missing_files)}", flush=True)
                continue

            try:
                label = _read_raster(label_file)
            except Exception as e:
                print(f"Error reading {label_file}: {str(e)}", flush=True)
                continue

            if np.all(label == 0):
                stats['skipped_zero_only'] += 1
                continue

            # Flood ratio denominator: the label's full (unpadded) pixel count.
            total_pixels = label.size

            try:
                vv_band = _read_raster(vv_file)

                if vv_band.shape != label.shape:
                    pad_height = label.shape[0] - vv_band.shape[0]
                    pad_width = label.shape[1] - vv_band.shape[1]
                    # Zero-pad whichever array is smaller along each axis so both reach
                    # the same shape. np.pad rejects negative widths, so each axis is
                    # handled separately (the shapes may differ in opposite directions).
                    if pad_height < 0 or pad_width < 0:
                        label = np.pad(
                            label,
                            ((0, max(-pad_height, 0)), (0, max(-pad_width, 0))),
                            mode='constant',
                            constant_values=0
                        )
                        print(f"Error: vv_file dimensions are larger than label {label.shape}", flush=True)
                    if pad_height > 0 or pad_width > 0:
                        # Padded VV pixels are 0 and are therefore masked out below.
                        vv_band = np.pad(
                            vv_band,
                            ((0, max(pad_height, 0)), (0, max(pad_width, 0))),
                            mode='constant',
                            constant_values=0
                        )

                # Ignore label pixels where VV has no data (value 0).
                masked_label = np.where(vv_band != 0, label, 0)
                if np.all(masked_label == 0):
                    stats['skipped_zero_only'] += 1
                    continue

                flood_pixels = np.count_nonzero(np.isin(masked_label, (1, 2)))
                flood_ratio = flood_pixels / total_pixels
                if flood_ratio < flood_threshold:
                    stats['skipped_low_flood'] += 1
                    continue
            except Exception as e:
                print(f"Error processing {emsr_folder}/{base_name} with vv_file: {str(e)}", flush=True)
                continue

            if debug:
                print(f"{emsr_folder.name} flood ratio: {flood_ratio}", flush=True)
                print(flood_pixels, total_pixels, flood_ratio, flush=True)

            # Store as a quadruplet of relative paths: vh|vv|dem|label
            rel_paths = (f.relative_to(data_dir) for f in (vh_file, vv_file, dem_file, label_file))
            pairs.append('|'.join(str(p) for p in rel_paths))
            emsr_pairs += 1

        if debug:
            print(f"{emsr_folder.name} pairs: {emsr_pairs}", flush=True)
        stats['pairs_per_emsr'][emsr_folder.name] = emsr_pairs
        stats['total_pairs'] += len(base_names)

    if not pairs:
        print("Error: No valid pairs found!", flush=True)
        return

    # A dedicated RNG yields the same order as random.seed(seed); random.shuffle(pairs)
    # without reseeding the global random module as a side effect.
    random.Random(seed).shuffle(pairs)

    train_size = int(train_ratio * len(pairs))
    train_pairs = pairs[:train_size]
    val_pairs = pairs[train_size:]

    train_file = data_dir / 'train.txt'
    val_file = data_dir / 'val.txt'
    with open(train_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(train_pairs))
    with open(val_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(val_pairs))

    print("\n=== Split Creation Summary ===", flush=True)
    print(f"Total image pairs found: {stats['total_pairs']}", flush=True)
    print(f"Pairs skipped (zero-only): {stats['skipped_zero_only']}", flush=True)
    print(f"Pairs skipped (low flood): {stats['skipped_low_flood']}", flush=True)
    print(f"Valid pairs kept: {len(pairs)}", flush=True)
    print(f"Training pairs: {len(train_pairs)}", flush=True)
    print(f"Validation pairs: {len(val_pairs)}", flush=True)

    print("\nPairs per EMSR:", flush=True)
    for emsr, count in sorted(stats['pairs_per_emsr'].items(), key=lambda x: x[1], reverse=True):
        print(f"{emsr}: {count} pairs", flush=True)


if __name__ == "__main__":
    create_filtered_splits()
--------------------------------------------------------------------------------
/FloodsML/utils/training_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import wandb
4 | from pathlib import Path
5 | from torch.optim import Optimizer
6 | from torch.utils.data import DataLoader
7 | from torchmetrics import JaccardIndex
8 | from typing import Dict, Any
9 | from osgeo import gdal
10 |
11 | gdal.UseExceptions()
12 |
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from raw logits.

    The logits go through a sigmoid once; all elements of the batch are pooled
    into one flat vector and ``1 - dice`` is returned, where

        dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    Args:
        smooth: Additive smoothing term; keeps the ratio defined (and equal to
            1) when both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def _dice_loss_from_probs(self, probs, targets):
        """Return ``1 - dice`` for probabilities ``probs`` and binary ``targets``."""
        # reshape (not view) so non-contiguous inputs are accepted too.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1).float()

        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return 1 - dice

    def forward(self, predictions, targets):
        return self._dice_loss_from_probs(torch.sigmoid(predictions), targets)


class MaskedDiceLoss(DiceLoss):
    """Dice loss that ignores pixels where either of the first two image channels is 0.

    ``images`` is expected as ``[batch_size, 2, height, width]`` (VH / VV bands);
    a pixel is valid only when both channels are non-zero. ``predictions`` and
    ``targets`` are expected as ``[batch_size, 1, height, width]`` so the
    channel-expanded mask lines up with them.
    """

    def forward(self, predictions, targets, images):
        # Sigmoid is applied exactly once here; the shared helper works on
        # probabilities, so there is no second sigmoid in the parent path.
        probs = torch.sigmoid(predictions)

        # True where both bands are non-zero (valid pixels).
        valid = (images[:, 0] != 0) & (images[:, 1] != 0)  # [batch_size, height, width]
        valid = valid.unsqueeze(1).to(probs.dtype)  # add channel dimension

        # Masking the probabilities (not the logits) makes excluded pixels
        # contribute exactly 0 to every sum instead of sigmoid(0) = 0.5.
        return self._dice_loss_from_probs(probs * valid, targets.float() * valid)
43 |
class MaskedBCELoss(nn.BCEWithLogitsLoss):
    """BCE-with-logits loss restricted to pixels where both image bands are non-zero.

    ``images`` is expected as ``[batch_size, 2, height, width]`` (VH / VV bands);
    pixels where either band is 0 are excluded from the loss entirely. The
    parent's ``weight``, ``pos_weight`` and ``reduction`` settings are honoured:

    * ``'mean'`` averages over valid pixels only (0 if a batch has none),
    * ``'sum'`` sums over valid pixels,
    * ``'none'`` returns the per-pixel loss with invalid pixels set to 0.
    """

    def forward(self, predictions, targets, images):
        # True where both bands are non-zero (valid pixels).
        valid = (images[:, 0] != 0) & (images[:, 1] != 0)  # [batch_size, height, width]
        valid = valid.unsqueeze(1)  # add channel dimension

        per_pixel = nn.functional.binary_cross_entropy_with_logits(
            predictions,
            targets,
            weight=self.weight,
            pos_weight=self.pos_weight,
            reduction='none',
        )
        valid = valid.expand_as(per_pixel).to(per_pixel.dtype)
        # Zeroing the per-pixel loss (rather than the logits/targets) keeps
        # excluded pixels from adding log(2) each and from diluting the mean.
        masked = per_pixel * valid

        if self.reduction == 'none':
            return masked
        total = masked.sum()
        if self.reduction == 'sum':
            return total
        # 'mean': divide by the number of valid pixels; clamp avoids 0/0.
        return total / valid.sum().clamp(min=1.0)
57 |
def log_summary_metrics(metrics, config):
    """Send end-of-training summary numbers to wandb.

    Does nothing unless ``config['use_wandb']`` is truthy. Metrics missing from
    ``metrics`` fall back to neutral values (0 for IoU/duration/epochs, +inf for
    the best validation loss). The values are logged only while a wandb run is
    active.
    """
    if not config.get('use_wandb', False):
        return

    # (logged key, key in `metrics`, fallback when missing)
    fields = (
        ("final_train_iou", 'train_iou', 0.0),
        ("val_iou", 'val_iou', 0.0),
        ("best_val_loss", 'best_val_loss', float('inf')),
        ("training_duration", 'training_duration', 0.0),
        ("total_epochs", 'total_epochs', 0),
    )
    summary = {out_key: metrics.get(src_key, fallback) for out_key, src_key, fallback in fields}

    if wandb.run:
        wandb.log(summary)
75 |
def setup_wandb_run(config, model_type):
    """Start a wandb run with a group name and tags derived from the model setup.

    Args:
        config: Training configuration. Always needs ``learning_rate`` and
            ``batch_size``. UNet additionally needs ``config['model']['n_channels']``
            and ``config['model']['bilinear']``. The transformer needs
            ``config['model']['image_size']``. Any other model type is treated as
            DeepLab and needs ``backbone`` and ``output_stride``.
        model_type: ``"unet"`` or ``"transformersegmentation"`` (case-insensitive);
            anything else is treated as DeepLab.

    Side effects: calls ``wandb.init`` (project ``flood-segmentation``) and
    registers min/max summaries for the loss, IoU and learning-rate series.
    """
    kind = model_type.lower()

    # Group runs by architecture and batch size.
    if kind == "unet":
        group_name = f"unet_{config['model']['n_channels']}ch_{config['batch_size']}bs"
    elif kind == "transformersegmentation":
        group_name = f"transformer_{config['batch_size']}bs"
    else:  # DeepLab
        group_name = f"deeplab_{config['backbone']}_{config['batch_size']}bs"

    # Common tags
    tags = [
        f"lr_{config['learning_rate']}",
        f"bs_{config['batch_size']}",
        kind,
    ]

    # Model-specific tags
    if kind == "unet":
        tags.extend([
            "bilinear" if config['model']['bilinear'] else "transpose",
            f"channels_{config['model']['n_channels']}",
        ])
    elif kind == "transformersegmentation":
        tags.extend([
            "segformer",
            f"image_size_{config['model']['image_size']}",
        ])
    else:  # DeepLab
        tags.extend([
            f"backbone_{config['backbone']}",
            f"output_stride_{config['output_stride']}",
        ])

    wandb.init(
        project="flood-segmentation",
        config=config,
        group=group_name,
        tags=tags,
        name=f"{model_type}_lr{config['learning_rate']}_bs{config['batch_size']}"
    )

    # Summaries must be registered under the keys that are actually logged:
    # log_metrics() in this module logs train/total_loss, val/total_loss and
    # train/learning_rate, so those names are registered explicitly.
    for key, summary in (
        ("train/total_loss", "min"),
        ("train/iou", "max"),
        ("val/total_loss", "min"),
        ("val/iou", "max"),
        ("train/learning_rate", "min"),
        # Names registered previously; kept in case other scripts log them.
        ("train/loss", "min"),
        ("val/loss", "min"),
        ("learning_rate", "min"),
    ):
        wandb.define_metric(key, summary=summary)
124 |
def _epoch_log_payload(metrics, epoch):
    """Translate a merged train/val metrics dict into the wandb key layout (no I/O).

    ``metrics`` must contain ``total_loss``, ``train_iou``, ``val_loss``,
    ``val_iou`` and ``learning_rate``; the component losses are optional and
    default to 0. Dice scores are reported as ``1 - dice_loss``.
    """
    # train_epoch() in this module stores the training Dice loss as
    # 'dice_loss'. Fall back to it when 'train_dice_loss' is absent; otherwise
    # train/dice_score would always be logged as 1.
    train_dice_loss = metrics.get('train_dice_loss', metrics.get('dice_loss', 0))

    return {
        # Loss components
        "train/total_loss": metrics['total_loss'],  # Combined weighted loss
        "train/bce_loss": metrics.get('bce_loss', 0),
        "train/dice_loss": metrics.get('dice_loss', 0),

        # Training metrics
        "train/iou": metrics['train_iou'],
        "train/dice_score": 1 - train_dice_loss,

        # Validation metrics
        "val/total_loss": metrics['val_loss'],  # Combined weighted loss for validation
        "val/iou": metrics['val_iou'],
        "val/dice_score": 1 - metrics.get('val_dice_loss', 0),

        # Learning rate
        "train/learning_rate": metrics['learning_rate'],

        "epoch": epoch,
    }


def log_metrics(metrics, epoch):
    """Log one epoch of train/val metrics to the active wandb run.

    Unified metric logging for both models; a no-op when no wandb run is active.
    """
    if not wandb.run:
        return

    wandb.log(_epoch_log_payload(metrics, epoch))
152 |
def save_checkpoint(model, optimizer, epoch, metrics, save_dir):
    """Persist training state to ``save_dir`` (created if missing).

    Always (over)writes ``best_model.pth`` with the epoch, model/optimizer
    state dicts, the metrics dict, and two flattened headline values:
    ``best_val_loss`` (``metrics['val_loss']``, +inf if absent) and ``val_iou``
    (0.0 if absent). When ``epoch + 1`` is a multiple of 5 it also writes
    ``checkpoint_epoch_{epoch+1}.pth`` with the same state but without the two
    flattened values.

    Args:
        model: PyTorch model
        optimizer: PyTorch optimizer
        epoch: Current (0-based) epoch number
        metrics: Dictionary of metrics
        save_dir: Directory to save the checkpoint(s)
    """
    target_dir = Path(save_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    common_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
    }

    best_state = dict(common_state)
    best_state['best_val_loss'] = metrics.get('val_loss', float('inf'))
    best_state['val_iou'] = metrics.get('val_iou', 0.0)
    torch.save(best_state, target_dir / 'best_model.pth')

    # Periodic snapshot every 5 epochs, numbered 1-based.
    epoch_number = epoch + 1
    if epoch_number % 5 == 0:
        torch.save(common_state, target_dir / f'checkpoint_epoch_{epoch_number}.pth')
187 |
def create_save_directory(save_dir):
    """Make sure the checkpoint directory exists and hand it back as a ``Path``.

    Missing parent directories are created; an already existing directory is fine.
    """
    directory = Path(save_dir)
    directory.mkdir(exist_ok=True, parents=True)
    return directory
195 |
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: Optimizer,
    criterion_bce: nn.Module,
    criterion_dice: nn.Module,
    jaccard: "JaccardIndex",
    config: Dict[str, Any],
    device: torch.device
) -> Dict[str, float]:
    """Run one optimisation pass over ``loader`` (used for both UNet and DeepLab).

    Each batch is a dict with ``'image'`` and ``'label'`` tensors. Both criteria
    are called as ``criterion(outputs, target, images)`` so they can mask out
    no-data pixels. The optimised loss is
    ``config['bce_weight'] * bce + config['dice_weight'] * dice`` (0.5 / 0.5 by
    default). IoU is computed on predictions thresholded at sigmoid > 0.5.

    Returns:
        Mean ``total_loss``, ``bce_loss``, ``dice_loss`` and ``train_iou`` over
        the batches seen, plus ``learning_rate`` (the optimizer's first
        param-group LR read before the epoch). An empty loader yields zeros
        instead of raising ``ZeroDivisionError``.
    """
    model.train()

    # Loop-invariant settings, read once.
    learning_rate = optimizer.param_groups[0]['lr']
    bce_weight = config.get('bce_weight', 0.5)
    dice_weight = config.get('dice_weight', 0.5)

    sums = {'total_loss': 0.0, 'bce_loss': 0.0, 'dice_loss': 0.0, 'train_iou': 0.0}
    num_batches = 0

    for batch in loader:
        images = batch['image'].to(device)
        masks = batch['label'].to(device)

        outputs = model(images)

        # Calculate losses (images are passed so the criteria can mask no-data pixels)
        loss_bce = criterion_bce(outputs, masks.float(), images)
        loss_dice = criterion_dice(outputs, masks, images)
        loss = bce_weight * loss_bce + dice_weight * loss_dice

        # Calculate metrics
        preds = (torch.sigmoid(outputs) > 0.5).float()
        batch_iou = jaccard(preds, masks)

        sums['total_loss'] += loss.item()
        sums['bce_loss'] += loss_bce.item()
        sums['dice_loss'] += loss_dice.item()
        sums['train_iou'] += batch_iou.item()
        num_batches += 1

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Average over the batches actually seen; max() guards an empty loader.
    denominator = max(num_batches, 1)
    metrics = {key: value / denominator for key, value in sums.items()}
    metrics['learning_rate'] = learning_rate  # not averaged
    return metrics
254 |
def validate_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion_bce: nn.Module,
    criterion_dice: nn.Module,
    jaccard: "JaccardIndex",
    config: Dict[str, Any],
    device: torch.device
) -> Dict[str, float]:
    """Evaluate ``model`` on ``loader`` without gradients (used for both UNet and DeepLab).

    Uses the same weighted BCE + Dice loss and thresholded IoU as training.
    The model is left in eval mode.

    Returns:
        Mean ``val_loss``, ``val_iou`` and ``val_dice_loss`` over the batches
        seen, plus ``learning_rate`` copied from ``config['learning_rate']``
        (0.0 if absent; no optimizer is available here). An empty loader yields
        zeros instead of raising ``ZeroDivisionError``.
    """
    model.eval()

    # Loop-invariant settings, read once.
    bce_weight = config.get('bce_weight', 0.5)
    dice_weight = config.get('dice_weight', 0.5)

    sums = {'val_loss': 0.0, 'val_iou': 0.0, 'val_dice_loss': 0.0}
    num_batches = 0

    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            masks = batch['label'].to(device)

            outputs = model(images)

            # Calculate losses (images are passed so the criteria can mask no-data pixels)
            loss_bce = criterion_bce(outputs, masks.float(), images)
            loss_dice = criterion_dice(outputs, masks, images)
            loss = bce_weight * loss_bce + dice_weight * loss_dice

            # Calculate metrics
            preds = (torch.sigmoid(outputs) > 0.5).float()
            batch_iou = jaccard(preds, masks)

            sums['val_loss'] += loss.item()
            sums['val_iou'] += batch_iou.item()
            sums['val_dice_loss'] += loss_dice.item()
            num_batches += 1

    # Average over the batches actually seen; max() guards an empty loader.
    denominator = max(num_batches, 1)
    metrics = {key: value / denominator for key, value in sums.items()}
    metrics['learning_rate'] = config.get('learning_rate', 0.0)  # not averaged
    return metrics
--------------------------------------------------------------------------------
/FloodsML/optimize_parameters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | from pathlib import Path
6 | import wandb
7 | import os
8 | import psutil
9 | import time
10 | from train_unet import train_model
11 | from flood_dataset import FloodSegmentationDataset
12 | from models.unet import UNet
13 |
# W&B layout: one base project; this optimisation phase is identified by the
# run group and the sweep name (both set to OPTIMIZATION_PHASE).
PROJECT_BASE = 'flood-segmentation'
OPTIMIZATION_PHASE = 'data-loading-opt'  # Change this for different optimization phases
PROJECT_NAME = "/".join((PROJECT_BASE, OPTIMIZATION_PHASE))
18 |
def get_gpu_memory_usage():
    """Return PyTorch's currently allocated CUDA memory in GiB, or 0 without CUDA."""
    if not torch.cuda.is_available():
        return 0
    bytes_per_gib = 1024 ** 3
    return torch.cuda.memory_allocated() / bytes_per_gib
24 |
def create_sweep_config():
    """Return the W&B sweep definition for this optimisation phase.

    Bayesian search maximising ``val_iou`` over scheduler hyper-parameters
    (learning rate, weight decay, warm-up epochs, plateau factor/patience) and
    data-loading settings (batch size, worker count, prefetch factor).
    """
    def log_uniform(low, high):
        return {'distribution': 'log_uniform_values', 'min': low, 'max': high}

    def choices(*options):
        return {'values': list(options)}

    parameters = {
        # Learning rate parameters
        'learning_rate': log_uniform(1e-5, 1e-3),
        'weight_decay': log_uniform(1e-5, 1e-3),
        'warmup_epochs': choices(1, 2, 3, 4, 5),
        'plateau_factor': choices(0.1, 0.2, 0.5, 0.7),
        'plateau_patience': choices(3, 5, 7, 10),

        # Data loading parameters
        'batch_size': choices(16, 32, 40),
        'num_workers': choices(2, 4, 8, 16),
        'prefetch_factor': choices(1, 2, 4),
    }

    return {
        'method': 'bayes',  # Bayesian optimization
        'metric': {'name': 'val_iou', 'goal': 'maximize'},
        'name': OPTIMIZATION_PHASE,  # Name of the sweep
        'project': PROJECT_BASE,  # Project name
        'parameters': parameters,
    }
70 |
def _log_trial_failure(error, error_type):
    """Record a failed trial on the active wandb run (no-op without a run).

    Logs the error text and type, the trial's sweep config, and current
    GPU/CPU/RAM usage.
    """
    if not wandb.run:
        return
    wandb.log({
        'error': str(error),
        'error_type': error_type,
        'failed_config': wandb.config,
        'system/final_gpu_memory': get_gpu_memory_usage(),
        'system/final_cpu_percent': psutil.cpu_percent(),
        'system/final_memory_percent': psutil.virtual_memory().percent
    })


def _make_loader(dataset, config, shuffle):
    """Build a DataLoader from the trial's batch-size / worker / prefetch settings."""
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config['num_workers'],
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=config['prefetch_factor']
    )


def train_sweep():
    """Run one sweep trial; intended as the ``function`` passed to ``wandb.agent``.

    Reads the sampled hyper-parameters from ``wandb.config``, trains a 2-channel
    UNet on ``data/train.txt`` / ``data/val.txt`` via ``train_model`` and stores
    ``val_iou`` and ``best_val_loss`` in the run summary.

    Raises:
        Any exception from setup or training, after logging it to wandb. The
        wandb run is always finished.
    """
    # Initialize wandb with project structure first
    if not wandb.run:  # Only initialize if not already initialized
        wandb.init(project=PROJECT_BASE, group=OPTIMIZATION_PHASE)

    try:
        # Get hyperparameters from sweep
        config = {
            'learning_rate': wandb.config.learning_rate,
            'weight_decay': wandb.config.weight_decay,
            'epochs': 60,
            'checkpoint_freq': 5,
            'save_dir': f'checkpoints/{OPTIMIZATION_PHASE}/{wandb.run.id}',
            'use_wandb': True,
            'early_stopping_patience': 10,
            'warmup_epochs': wandb.config.warmup_epochs,
            'model': {
                'n_channels': 2,
                'n_classes': 1,
                'bilinear': True  # Fixed to True for now
            },
            'batch_size': wandb.config.batch_size,
            'bce_weight': 0.5,
            'dice_weight': 0.5,
            'plateau_factor': wandb.config.plateau_factor,
            'plateau_patience': wandb.config.plateau_patience,
            'num_workers': wandb.config.num_workers,
            'prefetch_factor': wandb.config.prefetch_factor,
            'optimization_phase': OPTIMIZATION_PHASE  # Track which phase this run belongs to
        }

        # Log system information immediately after wandb initialization
        if wandb.run:
            wandb.log({
                'system/cpu_count': os.cpu_count(),
                'system/gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
                'system/gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None',
                'system/memory_total': psutil.virtual_memory().total / 1024**3,
                'system/memory_available': psutil.virtual_memory().available / 1024**3,
                'system/gpu_memory_reserved': torch.cuda.memory_reserved() / 1024**3,
                'system/max_gpu_memory_allocated': torch.cuda.max_memory_allocated() / 1024**3,
                'config': config  # Log the full config
            })

        # Create data directory if it doesn't exist
        data_dir = Path('data')
        data_dir.mkdir(parents=True, exist_ok=True)

        # Create datasets
        train_dataset = FloodSegmentationDataset(
            map_file=str(data_dir / 'train.txt'),
            base_path=str(data_dir),
            split='train'
        )
        val_dataset = FloodSegmentationDataset(
            map_file=str(data_dir / 'val.txt'),
            base_path=str(data_dir),
            split='val'
        )

        # Create dataloaders with sweep parameters
        train_loader = _make_loader(train_dataset, config, shuffle=True)
        val_loader = _make_loader(val_dataset, config, shuffle=False)

        model = UNet(**config['model'])
        best_val_loss = train_model(model, train_loader, val_loader, config)

        if wandb.run:
            # Log final system metrics
            wandb.log({
                'system/final_gpu_memory': get_gpu_memory_usage(),
                'system/final_cpu_percent': psutil.cpu_percent(),
                'system/final_memory_percent': psutil.virtual_memory().percent
            })

            # Set the validation IoU as a summary metric for the sweep
            wandb.run.summary.update({
                'val_iou': wandb.run.summary.get('val_iou', 0.0),
                'best_val_loss': best_val_loss
            })

    except Exception as e:
        # OOM errors are tagged explicitly; everything else by its class name.
        if isinstance(e, torch.cuda.OutOfMemoryError):
            error_type = 'OutOfMemoryError'
        else:
            error_type = type(e).__name__
        _log_trial_failure(e, error_type)
        raise  # bare raise re-raises with the original traceback
    finally:
        # Ensure wandb is properly closed
        if wandb.run:
            wandb.finish()
202 |
def run_sweep_with_retries(sweep_id, max_retries=2, total_trials=20):
    """Run sweep trials one at a time until ``total_trials`` have succeeded.

    Each trial is a ``wandb.agent(sweep_id, function=train_sweep, count=1)``
    call. A trial that raises counts as failed. The loop stops early once
    ``max_retries`` failures happen back to back; a successful trial resets
    that streak.

    NOTE(review): wandb.agent may handle exceptions from the trial function
    itself; confirm which failures actually reach this handler.

    Returns:
        ``(successful_trials, failed_trials)``, where ``failed_trials`` is the
        total number of failures.
    """
    successful_trials = 0
    failed_trials = 0
    consecutive_failures = 0

    while successful_trials < total_trials:
        try:
            # Run a single trial
            wandb.agent(sweep_id, function=train_sweep, count=1)
        except Exception as e:
            failed_trials += 1
            consecutive_failures += 1
            print(f"Trial {successful_trials + failed_trials} failed with error: {str(e)}")

            if consecutive_failures >= max_retries:
                print(f"Too many consecutive failures ({consecutive_failures}). Stopping sweep.")
                break

            print(f"Retrying... (Attempt {consecutive_failures}/{max_retries})")
            continue

        successful_trials += 1
        consecutive_failures = 0  # the streak is broken by a success
        print(f"Successfully completed trial {successful_trials}/{total_trials}")

    print(f"\nSweep completed with {successful_trials} successful trials and {failed_trials} failed trials")
    return successful_trials, failed_trials
230 |
def print_optimal_parameters(sweep_id):
    """Print the best run's validation IoU and hyper-parameters for *sweep_id*.

    Waits 5 s so results can sync, then queries the W&B public API. If the
    sweep or its best run cannot be retrieved, a hint with the sweep URL and
    the error is printed instead of raising.
    """
    api = wandb.Api()

    # Add a delay to allow for synchronization
    time.sleep(5)

    try:
        # Sweeps are addressed as "[entity/]project/sweep_id". create_sweep_config
        # puts the sweep in project PROJECT_BASE; the phase is only the sweep's
        # *name*, so it must not appear in the path.
        sweep = api.sweep(f"{PROJECT_BASE}/{sweep_id}")

        best_run = sweep.best_run()
        if best_run is None:
            raise ValueError("the sweep has no runs to rank yet")

        # Hyper-parameters come from the best run's stored config. The global
        # wandb.config is not that run's configuration.
        params = best_run.config

        print("\nOptimal Parameters Found:")
        print("-" * 50)
        print(f"Best Validation IoU: {best_run.summary['val_iou']:.4f}")
        print("\nLearning Rate Parameters:")
        print(f"Learning Rate: {params['learning_rate']:.2e}")
        print(f"Weight Decay: {params['weight_decay']:.2e}")
        print(f"Warmup Epochs: {params['warmup_epochs']}")
        print(f"Plateau Factor: {params['plateau_factor']}")
        print(f"Plateau Patience: {params['plateau_patience']}")
        print("\nData Loading Parameters:")
        print(f"Batch Size: {params['batch_size']}")
        print(f"Number of Workers: {params['num_workers']}")
        print(f"Prefetch Factor: {params['prefetch_factor']}")
        print("-" * 50)
    except Exception as e:
        print("\nCould not retrieve optimal parameters yet.")
        print("This is normal if the sweep is still running or results haven't synced.")
        print("You can view the results at:")
        print(f"https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}")
        print(f"\nError details: {str(e)}")
265 |
if __name__ == "__main__":
    # Create the sweep configuration and register it with wandb
    sweep_config = create_sweep_config()
    sweep_id = wandb.sweep(sweep_config)

    print(f"\nSweep created with ID: {sweep_id}")
    print(f"View sweep at: https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}")

    # Run sweep with retry mechanism; it prints its own completion summary,
    # so the summary line is not repeated here.
    successful_trials, _failed_trials = run_sweep_with_retries(sweep_id)

    if successful_trials > 0:
        print("Attempting to retrieve optimal parameters...")
        print_optimal_parameters(sweep_id)
    else:
        print("\nNo successful trials completed. Cannot determine optimal parameters.")
--------------------------------------------------------------------------------
/FloodsML/models/backbone/xception.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 |
def fixed_padding(inputs, kernel_size, dilation):
    """Zero-pad the last two dims of ``inputs`` so a stride-1 conv keeps their size.

    The total padding equals the dilated kernel extent
    ``kernel_size + (kernel_size - 1) * (dilation - 1)`` minus one. When that
    total is odd, the extra pixel goes on the bottom/right side.
    """
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
    return padded_inputs


class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: depthwise conv -> BatchNorm -> 1x1 pointwise conv.

    The depthwise conv uses padding=0; padding is applied explicitly with
    ``fixed_padding`` in ``forward``. With ``stride=1`` the spatial size is
    therefore preserved for any kernel size / dilation.

    Args:
        inplanes: input channels (also the depthwise conv's groups).
        planes: output channels of the pointwise conv.
        BatchNorm: normalisation layer class, called as ``BatchNorm(inplanes)``.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
                               groups=inplanes, bias=bias)
        self.bn = BatchNorm(inplanes)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
        x = self.conv1(x)
        x = self.bn(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    """Xception residual block.

    Main path: ``reps`` ReLU -> SeparableConv2d -> BatchNorm units. The channel
    count grows from ``inplanes`` to ``planes`` in the first unit when
    ``grow_first`` is set, otherwise in the last unit. One extra unit is then
    appended: a stride-2 unit when ``stride != 1``, or a stride-1 unit when
    ``stride == 1 and is_last``.

    The main path is added to the block input, which is first projected with a
    1x1 conv + BatchNorm when the channel count or stride changes.

    Args:
        start_with_relu: if False, the leading ReLU of the main path is dropped.
    """

    def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
                 start_with_relu=True, grow_first=True, is_last=False):
        super(Block, self).__init__()

        if planes != inplanes or stride != 1:
            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = BatchNorm(planes)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = inplanes
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
            rep.append(BatchNorm(planes))
            filters = planes

        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
            rep.append(BatchNorm(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
            rep.append(BatchNorm(planes))

        if stride != 1:
            rep.append(self.relu)
            rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
            rep.append(BatchNorm(planes))

        if stride == 1 and is_last:
            rep.append(self.relu)
            rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
            rep.append(BatchNorm(planes))

        if not start_with_relu:
            rep = rep[1:]
        else:
            # The shared ReLU is in-place. As the first op of the main path it
            # would overwrite ``inp``, which forward() reads again for the
            # skip/residual branch. Use a separate out-of-place ReLU here
            # (ReLU has no parameters, so state_dict keys are unchanged).
            rep[0] = nn.ReLU(inplace=False)

        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x = x + skip

        return x
92 |
93 |
class AlignedXception(nn.Module):
    """
    Modified Aligned Xception backbone.

    Expects 3-channel input (``conv1`` is ``Conv2d(3, 32, ...)``). ``forward``
    returns ``(features, low_level_feat)``:

    * ``features`` is the 2048-channel output of the exit flow.
    * ``low_level_feat`` is the 128-channel output of ``block1`` after a ReLU,
      at 1/4 of the input resolution.

    Args:
        output_stride: 16 or 8. With 8, ``block3`` does not downsample and the
            middle/exit flows use larger dilations. Any other value raises
            NotImplementedError.
        BatchNorm: normalisation layer class, called as ``BatchNorm(channels)``.
        pretrained: if True, load weights with ``_load_pretrained_model``
            (downloads a checkpoint).
    """

    # Middle-flow blocks are registered as block4 ... block19.
    _MIDDLE_BLOCK_IDS = range(4, 20)

    def __init__(self, output_stride, BatchNorm,
                 pretrained=True):
        super(AlignedXception, self).__init__()

        if output_stride == 16:
            entry_block3_stride = 2
            middle_block_dilation = 1
            exit_block_dilations = (1, 2)
        elif output_stride == 8:
            entry_block3_stride = 1
            middle_block_dilation = 2
            exit_block_dilations = (2, 4)
        else:
            raise NotImplementedError

        # Entry flow
        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
        self.bn1 = BatchNorm(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
        self.bn2 = BatchNorm(64)

        self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
        self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
                            grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow: 16 identical residual blocks. The attribute names
        # (block4 ... block19) and registration order match the explicit
        # definitions, so state_dict keys and init order are unchanged.
        for idx in self._MIDDLE_BLOCK_IDS:
            setattr(self, f'block{idx}',
                    Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                          BatchNorm=BatchNorm, start_with_relu=True, grow_first=True))

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)

        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
        self.bn3 = BatchNorm(1536)

        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
        self.bn4 = BatchNorm(1536)

        self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
        self.bn5 = BatchNorm(2048)

        # Init weights
        self._init_weight()

        # Load pretrained model
        if pretrained:
            self._load_pretrained_model()

    def forward(self, x):
        # Entry flow
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        # add relu here
        x = self.relu(x)
        low_level_feat = x
        x = self.block2(x)
        x = self.block3(x)

        # Middle flow
        for idx in self._MIDDLE_BLOCK_IDS:
            x = getattr(self, f'block{idx}')(x)

        # Exit flow
        x = self.block20(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)

        return x, low_level_feat

    def _init_weight(self):
        """He-normal init for every Conv2d (std = sqrt(2 / (kh * kw * out_channels))); BN weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        """Copy weights from the Xception checkpoint at the cadene pretrainedmodels URL.

        Checkpoint keys are remapped onto this deeper, aligned variant:

        * the checkpoint's ``block11`` is replicated into block11 ... block19;
        * ``block12`` becomes block20;
        * ``bn3`` seeds bn3 and bn4, ``conv4`` becomes conv5, ``bn4`` becomes bn5.

        Checkpoint keys this model does not have are ignored.
        """
        pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
        model_dict = {}
        state_dict = self.state_dict()

        for k, v in pretrain_dict.items():
            # Membership must be tested against the model's own keys. Testing
            # against `model_dict` (empty at this point) skipped every key, so
            # no pretrained weight was ever loaded.
            if k not in state_dict:
                continue
            if 'pointwise' in k and v.dim() == 2:
                # Presumably 1x1 pointwise kernels stored as [out, in]; the
                # Conv2d here needs [out, in, 1, 1].
                v = v.unsqueeze(-1).unsqueeze(-1)
            if k.startswith('block11'):
                for idx in range(11, 20):
                    model_dict[k.replace('block11', f'block{idx}')] = v
            elif k.startswith('block12'):
                model_dict[k.replace('block12', 'block20')] = v
            elif k.startswith('bn3'):
                model_dict[k] = v
                model_dict[k.replace('bn3', 'bn4')] = v
            elif k.startswith('conv4'):
                model_dict[k.replace('conv4', 'conv5')] = v
            elif k.startswith('bn4'):
                model_dict[k.replace('bn4', 'bn5')] = v
            else:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
279 |
280 |
281 |
if __name__ == "__main__":
    # Smoke test: build the backbone (pretrained=True fetches the checkpoint
    # used by _load_pretrained_model) and print the shapes of the deep
    # features and the low-level features for a 512x512 RGB input.
    model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
    sample = torch.rand(1, 3, 512, 512)  # avoid shadowing the builtin `input`
    output, low_level_feat = model(sample)
    print(output.size())
    print(low_level_feat.size())
--------------------------------------------------------------------------------
/FloodsML/optimize_parameters_extended.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | from pathlib import Path
6 | import wandb
7 | import os
8 | import psutil
9 | import time
10 | from train_unet import train_model
11 | from flood_dataset import FloodSegmentationDataset
12 | from models.unet import UNet
13 | import pandas as pd
14 |
# W&B layout: one base project; this optimisation phase is identified by the
# sweep name (set to OPTIMIZATION_PHASE).
PROJECT_BASE = 'flood-segmentation'
OPTIMIZATION_PHASE = 'extended-optimization'  # New phase for extended optimization
PROJECT_NAME = "/".join((PROJECT_BASE, OPTIMIZATION_PHASE))
19 |
def get_gpu_memory_usage():
    """Return PyTorch's currently allocated CUDA memory in GiB, or 0 without CUDA."""
    if not torch.cuda.is_available():
        return 0
    bytes_per_gib = 1024 ** 3
    return torch.cuda.memory_allocated() / bytes_per_gib
25 |
def get_best_parameters_from_previous_sweep(sweep_name):
    """Return the best run's hyper-parameters from the sweep named *sweep_name*.

    Searches the sweeps of project PROJECT_BASE (API default entity) for the
    first one whose name matches. Values are read from the best run's stored
    config, i.e. the values the sweep assigned to that run.

    Returns:
        Dict with ``learning_rate``, ``weight_decay``, ``warmup_epochs``,
        ``plateau_factor``, ``plateau_patience``, ``batch_size``,
        ``num_workers`` and ``prefetch_factor``.

    Raises:
        ValueError: if no sweep has that name, or the sweep has no best run yet.
        KeyError: if the best run's config lacks one of the expected keys.
    """
    api = wandb.Api()

    # Get all sweeps in the project and pick the one with the requested name
    sweeps = api.project(PROJECT_BASE).sweeps()
    previous_sweep = next((sweep for sweep in sweeps if sweep.name == sweep_name), None)

    if previous_sweep is None:
        raise ValueError(f"Could not find sweep with name '{sweep_name}'. Please make sure the previous optimization phase has completed.")

    best_run = previous_sweep.best_run()
    if best_run is None:
        raise ValueError(f"Sweep '{sweep_name}' has no runs to take parameters from yet.")

    parameter_names = (
        'learning_rate',
        'weight_decay',
        'warmup_epochs',
        'plateau_factor',
        'plateau_patience',
        'batch_size',
        'num_workers',
        'prefetch_factor',
    )
    return {name: best_run.config[name] for name in parameter_names}
61 |
def create_sweep_config(previous_sweep_name='data-loading-opt', best_params=None):
    """Build the W&B sweep configuration for the extended optimization phase.

    The search space is centred on the best parameters of a previous sweep:

    * learning_rate / weight_decay: log-uniform between 0.5x and 2x the best value;
    * discrete knobs: ``best - step``, ``best`` and ``best + step``, with the
      outer two clamped to fixed bounds; duplicates created by clamping are
      removed so no candidate is over-weighted;
    * dropout_rate and bce_weight: fixed grids (Dice weight = 1 - bce_weight).

    Args:
        previous_sweep_name: Name of the earlier sweep whose best run seeds the
            search space. Ignored when ``best_params`` is given.
        best_params: Optional dict with the keys returned by
            ``get_best_parameters_from_previous_sweep``; lets callers provide the
            starting point directly without querying W&B.

    Returns:
        dict suitable for ``wandb.sweep``.
    """
    if best_params is None:
        best_params = get_best_parameters_from_previous_sweep(previous_sweep_name)

    def log_range(key):
        # Log-uniform search between half and double the previous best value.
        best = best_params[key]
        return {
            'distribution': 'log_uniform_values',
            'min': best * 0.5,
            'max': best * 2.0,
        }

    def neighbours(key, step, lower, upper):
        # [best - step, best, best + step] with the outer values clamped to
        # [lower, upper]. Clamping can make values coincide (e.g. best == lower),
        # so duplicates are dropped; float steps are rounded to strip artefacts
        # such as 0.2 + 0.1 == 0.30000000000000004.
        best = best_params[key]
        candidates = [max(lower, best - step), best, min(upper, best + step)]
        if isinstance(step, float):
            candidates = [round(c, 10) for c in candidates]
        return {'values': sorted(set(candidates))}

    return {
        'method': 'bayes',  # Bayesian optimization
        'metric': {
            'name': 'val_iou',
            'goal': 'maximize',
        },
        'name': OPTIMIZATION_PHASE,
        'project': PROJECT_BASE,
        'parameters': {
            # Learning-rate schedule, refined around the previous best values
            'learning_rate': log_range('learning_rate'),
            'weight_decay': log_range('weight_decay'),
            'warmup_epochs': neighbours('warmup_epochs', 1, 1, 10),
            'plateau_factor': neighbours('plateau_factor', 0.1, 0.1, 0.8),
            'plateau_patience': neighbours('plateau_patience', 2, 3, 15),

            # Data loading
            'batch_size': neighbours('batch_size', 8, 8, 50),
            'num_workers': neighbours('num_workers', 2, 1, 32),
            'prefetch_factor': neighbours('prefetch_factor', 1, 1, 8),

            # Parameters new to the extended phase
            'dropout_rate': {
                'values': [0.0, 0.1, 0.2],
            },
            'bce_weight': {
                'values': [0, 0.3, 0.5, 0.7, 1],  # Dice weight will be 1 - bce_weight
            },
        },
    }
130 |
def _system_snapshot():
    """Current GPU memory (GiB), CPU % and RAM %, keyed for end-of-trial W&B logging."""
    return {
        'system/final_gpu_memory': get_gpu_memory_usage(),
        'system/final_cpu_percent': psutil.cpu_percent(),
        'system/final_memory_percent': psutil.virtual_memory().percent,
    }


def _make_loader(dataset, config, shuffle):
    """DataLoader configured with the trial's batch size and worker settings."""
    # PyTorch only accepts persistent_workers / prefetch_factor with num_workers > 0.
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config['num_workers'],
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=config['prefetch_factor'],
    )


def train_sweep():
    """Run a single sweep trial; intended as the ``function`` for ``wandb.agent``.

    Builds the training config from ``wandb.config``, logs host information,
    trains a UNet on data/train.txt / data/val.txt via ``train_model`` and logs
    a final system snapshot. On failure the error, the failing config and a
    system snapshot are logged before the exception is re-raised. The W&B run
    is always finished.
    """
    wandb.init(project=PROJECT_BASE, group=OPTIMIZATION_PHASE)

    try:
        hp = wandb.config
        config = {
            'learning_rate': hp.learning_rate,
            'weight_decay': hp.weight_decay,
            'epochs': 50,  # Increased for better convergence
            'checkpoint_freq': 5,
            'save_dir': f'checkpoints/{OPTIMIZATION_PHASE}/{wandb.run.id}',
            'use_wandb': True,
            'early_stopping_patience': 15,  # Increased patience
            'warmup_epochs': hp.warmup_epochs,
            'model': {
                'n_channels': 2,
                'n_classes': 1,
                'bilinear': True,
                'dropout_rate': hp.dropout_rate,
            },
            'batch_size': hp.batch_size,
            'bce_weight': hp.bce_weight,
            'dice_weight': 1.0 - hp.bce_weight,  # loss weights always sum to 1.0
            'plateau_factor': hp.plateau_factor,
            'plateau_patience': hp.plateau_patience,
            'num_workers': hp.num_workers,
            'prefetch_factor': hp.prefetch_factor,
            'optimization_phase': OPTIMIZATION_PHASE,
        }

        # Host information for comparing trials across machines
        wandb.log({
            'system/cpu_count': os.cpu_count(),
            'system/gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
            'system/gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None',
            'system/memory_total': psutil.virtual_memory().total / 1024**3,
            'system/memory_available': psutil.virtual_memory().available / 1024**3,
            'config': dict(config),  # plain dict for JSON serialization
        })

        data_dir = Path('data')
        data_dir.mkdir(parents=True, exist_ok=True)

        train_dataset = FloodSegmentationDataset(
            map_file=str(data_dir / 'train.txt'),
            base_path=str(data_dir),
            split='train',
        )
        val_dataset = FloodSegmentationDataset(
            map_file=str(data_dir / 'val.txt'),
            base_path=str(data_dir),
            split='val',
        )

        train_loader = _make_loader(train_dataset, config, shuffle=True)
        val_loader = _make_loader(val_dataset, config, shuffle=False)

        model = UNet(**config['model'])
        train_model(model, train_loader, val_loader, config)

        wandb.log(_system_snapshot())

    except Exception as e:
        failure = {
            'error': str(e),
            'error_type': type(e).__name__,
            'failed_config': dict(wandb.config),  # plain dict for JSON serialization
        }
        # Reporting is best effort: a logging problem must not hide the real error.
        try:
            failure.update(_system_snapshot())
            wandb.log(failure)
        except Exception as log_error:
            print(f"Warning: could not log failure details to W&B: {log_error}")
        raise
    finally:
        if wandb.run:
            wandb.finish()
236 |
def run_sweep_with_retries(sweep_id, max_retries=3, total_trials=30):
    """Run sweep trials one at a time until ``total_trials`` have succeeded.

    Each iteration asks ``wandb.agent`` for exactly one trial. The loop gives up
    after ``max_retries`` *consecutive* failures; a success resets that count.
    NOTE(review): a trial is only counted as failed when ``wandb.agent`` itself
    raises; the agent may catch exceptions raised inside ``train_sweep``, so
    check run states in the W&B UI as well.

    When the loop ends, the sweep is marked FINISHED on a best-effort basis.

    Args:
        sweep_id: ID returned by ``wandb.sweep`` (no entity/project prefix).
        max_retries: Consecutive failures tolerated before stopping.
        total_trials: Number of successful trials to run.

    Returns:
        (successful_trials, failed_trials), where failed_trials is the total
        number of failures over the whole loop.
    """
    successful_trials = 0
    failed_trials = 0
    consecutive_failures = 0

    while successful_trials < total_trials:
        try:
            wandb.agent(sweep_id, function=train_sweep, project=PROJECT_BASE, count=1)
        except Exception as e:
            failed_trials += 1
            consecutive_failures += 1
            print(f"Trial {successful_trials + failed_trials} failed with error: {str(e)}")

            if consecutive_failures >= max_retries:
                print(f"Too many consecutive failures ({consecutive_failures}). Stopping sweep.")
                break

            print(f"Retrying... (Attempt {consecutive_failures}/{max_retries})")
        else:
            successful_trials += 1
            consecutive_failures = 0
            print(f"Successfully completed trial {successful_trials}/{total_trials}")

    # Stop the sweep so no other agent keeps pulling trials. The path is
    # "<project>/<sweep_id>" (the API fills in the default entity); failing to
    # update the state must not discard the trial counts.
    try:
        sweep = wandb.Api().sweep(f"{PROJECT_BASE}/{sweep_id}")
        sweep.state = "FINISHED"
        sweep.update()
    except Exception as e:
        print(f"Warning: could not mark sweep {sweep_id} as finished: {e}")

    print(f"\nSweep completed with {successful_trials} successful trials and {failed_trials} failed trials")
    return successful_trials, failed_trials
265 |
def print_optimal_parameters(sweep_id, sync_delay=5):
    """Print the hyperparameters of the best run of sweep ``sweep_id``.

    Values come from the best run's stored config (``best_run.config``), not
    from the global ``wandb.config``: every trial finishes its run, so no run
    is active when this is called. The Dice weight is not a sweep parameter;
    when the run config lacks it, it is derived as ``1 - bce_weight``, which is
    how ``train_sweep`` builds it.

    Args:
        sweep_id: ID returned by ``wandb.sweep``.
        sync_delay: Seconds to wait first, giving W&B time to sync results.

    Problems (results not synced, missing keys, no best run) are reported on
    stdout rather than raised.
    """
    time.sleep(sync_delay)  # Allow for synchronization

    try:
        sweep = wandb.Api().sweep(f"{PROJECT_BASE}/{sweep_id}")
        best_run = sweep.best_run()
        if best_run is None:
            raise ValueError("the sweep has no best run yet")

        cfg = best_run.config
        best_iou = best_run.summary['val_iou']
        dice_weight = cfg['dice_weight'] if 'dice_weight' in cfg else 1.0 - cfg['bce_weight']

        # Format everything before printing so a missing key does not leave
        # half a report on the console.
        lines = [
            "\nOptimal Parameters Found:",
            "-" * 50,
            f"Best Validation IoU: {best_iou:.4f}",
            "\nLearning Rate Parameters:",
            f"Learning Rate: {cfg['learning_rate']:.2e}",
            f"Weight Decay: {cfg['weight_decay']:.2e}",
            f"Warmup Epochs: {cfg['warmup_epochs']}",
            f"Plateau Factor: {cfg['plateau_factor']}",
            f"Plateau Patience: {cfg['plateau_patience']}",
            "\nData Loading Parameters:",
            f"Batch Size: {cfg['batch_size']}",
            f"Number of Workers: {cfg['num_workers']}",
            f"Prefetch Factor: {cfg['prefetch_factor']}",
            "\nAdditional Parameters:",
            f"Dropout Rate: {cfg['dropout_rate']}",
            f"BCE Weight: {cfg['bce_weight']}",
            f"Dice Weight: {dice_weight}",
            "-" * 50,
        ]
        for line in lines:
            print(line)
    except Exception as e:
        print("\nCould not retrieve optimal parameters yet.")
        print("This is normal if the sweep is still running or results haven't synced.")
        print("You can view the results at:")
        print(f"https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}")
        print(f"\nError details: {str(e)}")
300 |
def main():
    """Create the extended sweep, run its trials and print the best parameters found."""
    sweep_id = wandb.sweep(create_sweep_config())

    print(f"\nExtended sweep created with ID: {sweep_id}")
    print(f"View sweep at: https://wandb.ai/mracic/flood-segmentation/sweeps/{sweep_id}")

    # Trials run one at a time so that failed ones can be retried.
    successful_trials, failed_trials = run_sweep_with_retries(sweep_id)

    if successful_trials <= 0:
        print("\nNo successful trials completed. Cannot determine optimal parameters.")
        return

    print(f"\nExtended sweep completed with {successful_trials} successful trials and {failed_trials} failed trials")
    print("Attempting to retrieve optimal parameters...")
    print_optimal_parameters(sweep_id)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/FloodsML/train_inference_full.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | from pathlib import Path
6 | import wandb
7 | import os
8 | import psutil
9 | import time
10 | import numpy as np
11 | from tqdm import tqdm
12 | from torchmetrics import JaccardIndex
13 | import matplotlib.pyplot as plt
14 | from models.unet import UNet
15 | from flood_dataset import FloodSegmentationDataset
16 |
# W&B project (the same one the sweep scripts use) and the name of this single
# training + inference run; EXPERIMENT_NAME also names the local checkpoint
# and prediction folders.
PROJECT_BASE = 'flood-segmentation'
EXPERIMENT_NAME = 'single-optimized-run'
# NOTE(review): not referenced by the code visible in this file.
PROJECT_NAME = '/'.join((PROJECT_BASE, EXPERIMENT_NAME))
21 |
def get_gpu_memory_usage():
    """Report PyTorch's allocated CUDA memory on the current device in GiB (0 without CUDA)."""
    gib = 1024 ** 3
    return torch.cuda.memory_allocated() / gib if torch.cuda.is_available() else 0
27 |
def _make_combined_loss(bce_weight, dice_weight):
    """Return loss(logits, target) = bce_weight * BCE-with-logits + dice_weight * soft Dice."""
    bce = nn.BCEWithLogitsLoss()
    smooth = 1.0  # avoids 0/0 when prediction and target are both empty

    def combined_loss(pred, target):
        bce_loss = bce(pred, target)

        # Soft Dice over the whole batch, computed on probabilities
        pred_sigmoid = torch.sigmoid(pred)
        intersection = (pred_sigmoid * target).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (pred_sigmoid.sum() + target.sum() + smooth)

        return bce_weight * bce_loss + dice_weight * dice_loss

    return combined_loss


def _run_epoch(model, loader, loss_fn, jaccard, device, optimizer=None, desc=None):
    """One pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates.

    Returns (mean batch loss, mean batch IoU); an empty loader yields (0.0, 0.0).
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_iou = 0.0

    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for batch in pbar:
            images = batch['image'].to(device)
            masks = batch['label'].to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            if training:
                loss.backward()
                optimizer.step()

            preds = (torch.sigmoid(outputs) > 0.5).float()
            batch_loss = loss.item()
            batch_iou = jaccard(preds, masks).item()
            total_loss += batch_loss
            total_iou += batch_iou

            if training:
                pbar.set_postfix({'loss': batch_loss, 'iou': batch_iou})

    n_batches = max(len(loader), 1)
    return total_loss / n_batches, total_iou / n_batches


def _save_checkpoint(path, epoch, model, optimizer, metrics, config):
    """Write model/optimizer state, the epoch's metrics and the config to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        **metrics,  # train_loss, val_loss, train_iou, val_iou
        'config': config,
    }, path)


def train_model(model, train_loader, val_loader, config, device):
    """Train ``model`` with AdamW, linear LR warmup, plateau decay and early stopping.

    Args:
        model: segmentation network returning one logit map per image; assumed
            to match the shape of batch['label'] (TODO confirm: (N, 1, H, W)).
        train_loader, val_loader: iterables of dicts with 'image' and 'label'
            tensors.
        config: dict; reads 'bce_weight', 'dice_weight', 'learning_rate',
            'weight_decay', 'warmup_epochs', 'plateau_factor',
            'plateau_patience', 'save_dir', 'epochs', 'checkpoint_freq',
            'early_stopping_patience' and optionally 'use_wandb'.
        device: torch device used for training.

    Side effects:
        Writes save_dir/best_model.pth whenever the mean validation IoU
        improves, save_dir/checkpoint_epoch_<k>.pth every 'checkpoint_freq'
        epochs, and logs per-epoch metrics to W&B when 'use_wandb' is set.

    Returns:
        (best_val_iou, best_epoch); best_epoch is 1-based, 0 if IoU never rose
        above 0.
    """
    model = model.to(device)
    combined_loss = _make_combined_loss(config['bce_weight'], config['dice_weight'])

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
    )

    warmup_epochs = config['warmup_epochs']

    def lr_lambda(epoch):
        # LambdaLR applies lr_lambda(0) at construction, so "epoch + 1" makes
        # the first epoch train at lr / warmup_epochs instead of lr * 0.
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 1.0

    warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    # `verbose` is deprecated for LR schedulers; the LR is printed and logged
    # every epoch instead.
    plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config['plateau_factor'],
        patience=config['plateau_patience'],
    )

    jaccard = JaccardIndex(task='binary').to(device)

    save_dir = Path(config['save_dir'])
    save_dir.mkdir(parents=True, exist_ok=True)

    use_wandb = config.get('use_wandb', False)
    if use_wandb:
        # Reuse a run the caller already started instead of initializing again.
        if wandb.run is None:
            wandb.init(project=PROJECT_BASE, name=EXPERIMENT_NAME)
        wandb.config.update(config)

    best_val_iou = 0.0
    best_epoch = 0
    no_improvement_count = 0

    for epoch in range(config['epochs']):
        avg_train_loss, avg_train_iou = _run_epoch(
            model, train_loader, combined_loss, jaccard, device,
            optimizer=optimizer, desc=f"Epoch {epoch+1}/{config['epochs']}",
        )
        avg_val_loss, avg_val_iou = _run_epoch(
            model, val_loader, combined_loss, jaccard, device, desc="Validation",
        )

        # Both schedulers write param_group['lr']; a LambdaLR step would
        # overwrite any plateau reduction made during warmup, so plateau
        # tracking starts only after warmup.
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            plateau_scheduler.step(avg_val_loss)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {avg_train_loss:.4f}, Train IoU: {avg_train_iou:.4f}, Val Loss: {avg_val_loss:.4f}, Val IoU: {avg_val_iou:.4f}, LR: {current_lr:.2e}")

        if use_wandb:
            wandb.log({
                'epoch': epoch + 1,
                'train_loss': avg_train_loss,
                'train_iou': avg_train_iou,
                'val_loss': avg_val_loss,
                'val_iou': avg_val_iou,
                'learning_rate': current_lr,
                'gpu_memory': get_gpu_memory_usage(),
                'cpu_percent': psutil.cpu_percent(),
                'memory_percent': psutil.virtual_memory().percent,
            })

        metrics = {
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'train_iou': avg_train_iou,
            'val_iou': avg_val_iou,
        }

        # Model selection is based on validation IoU, not loss.
        if avg_val_iou > best_val_iou:
            best_val_iou = avg_val_iou
            best_epoch = epoch + 1
            no_improvement_count = 0
            _save_checkpoint(save_dir / 'best_model.pth', epoch + 1, model, optimizer, metrics, config)
            print(f"Saved best model with validation IoU: {best_val_iou:.4f}")
        else:
            no_improvement_count += 1

        if (epoch + 1) % config['checkpoint_freq'] == 0:
            _save_checkpoint(save_dir / f'checkpoint_epoch_{epoch+1}.pth', epoch + 1, model, optimizer, metrics, config)

        if no_improvement_count >= config['early_stopping_patience']:
            print(f"Early stopping triggered. No improvement for {no_improvement_count} epochs.")
            break

    print(f"Training completed. Best validation IoU: {best_val_iou:.4f} at epoch {best_epoch}")
    return best_val_iou, best_epoch
207 |
def load_model(checkpoint_path, model, device):
    """Load the weights stored in a training checkpoint into ``model``.

    Args:
        checkpoint_path: path to a ``.pth`` file holding at least
            'model_state_dict', 'epoch' and 'val_iou' (as written by train_model).
        model: module whose architecture matches the checkpoint.
        device: ``map_location`` used when loading the tensors.

    Returns:
        (model, epoch, val_iou): the same ``model`` object with the weights
        loaded, plus the epoch and validation IoU stored in the checkpoint.
    """
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered file cannot run code (it is the default from PyTorch 2.6 on).
    # train_model's checkpoints hold only state dicts, numbers and a plain
    # config dict, which this mode accepts.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint['epoch'], checkpoint['val_iou']
213 |
def _masked_sample_iou(pred, mask, valid, metric):
    """IoU of one sample restricted to ``valid`` pixels (1.0 if both empty there, 0.0 if no valid pixels)."""
    if not valid.any():
        return 0.0
    pred_valid = pred[valid]
    mask_valid = mask[valid]
    if mask_valid.sum() == 0 and pred_valid.sum() == 0:
        return 1.0
    return metric(pred_valid, mask_valid).item()


def _display_image(image):
    """Turn a (C, H, W) tensor into (array, cmap) suitable for ``imshow``."""
    img = image.cpu().numpy().transpose(1, 2, 0)  # H, W, C
    if img.shape[2] != 2:
        return np.squeeze(img), 'gray'  # fallback for other channel counts

    rgb = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.float32)
    rgb[:, :, 0] = img[:, :, 0]  # VV channel to Red
    rgb[:, :, 1] = img[:, :, 1]  # VH channel to Green
    # Min-max stretch of the two data channels: imshow clips RGB floats outside
    # [0, 1], so e.g. negative backscatter values (if the data is in dB;
    # unverified) would otherwise render black.
    lo, hi = rgb[:, :, :2].min(), rgb[:, :, :2].max()
    if hi > lo:
        rgb[:, :, :2] = (rgb[:, :, :2] - lo) / (hi - lo)
    return rgb, None


def _prediction_figure_name(image_path, batch_idx, sample_idx):
    """'<parent folder>_<file stem>' for ``image_path``, or a batch/sample fallback."""
    parts = image_path.replace('\\', '/').split('/')
    if len(parts) < 2:
        return f"batch_{batch_idx}_sample_{sample_idx}"
    # e.g. "EMSR756_AOI03_DEL_PRODUCT" + "tile_006_0_01_Sentinel1_vh"
    return f"{parts[-2]}_{Path(parts[-1]).stem}"


def _save_prediction_figure(path, image, mask, pred, sample_iou):
    """Save an input / ground truth / prediction panel to ``path``."""
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    try:
        display, cmap = _display_image(image)
        ax1.imshow(display, cmap=cmap)
        ax1.set_title('Original Image')
        ax1.axis('off')

        ax2.imshow(mask.cpu().numpy().squeeze(), cmap='gray')
        ax2.set_title('Ground Truth')
        ax2.axis('off')

        ax3.imshow(pred.cpu().numpy().squeeze(), cmap='gray')
        ax3.set_title(f'Prediction (IoU: {sample_iou:.3f})')
        ax3.axis('off')

        fig.savefig(path)
    finally:
        plt.close(fig)  # always release the figure, even if saving fails


def predict(model, dataloader, device, save_dir=None):
    """
    Run inference on the dataloader

    Pixels where every input channel is 0 are treated as no-data and excluded
    from all IoU computations.

    Args:
        model: Trained PyTorch model
        dataloader: iterable of dicts with 'image' and 'label' tensors (plus
            'image_path' strings when save_dir is given)
        device: torch device
        save_dir: Optional directory to save per-sample prediction figures,
            named '<IoU*100, zero-padded>_prediction_<folder>_<stem>.png'

    Returns:
        metrics: dict with 'mean_iou', the average over batches of the batch
        IoU on valid pixels (1.0 for a batch where prediction and label are
        both empty); batches without valid pixels are skipped.
    """
    model.eval()
    # Separate metric objects so per-sample scoring never touches batch scoring.
    batch_metric = JaccardIndex(task='binary').to(device)
    sample_metric = JaccardIndex(task='binary').to(device)
    total_iou_metric = 0.0
    num_batches = 0

    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader, desc='Predicting')):
            images = batch['image'].to(device)
            masks = batch['label'].to(device)

            zero_mask = (images == 0).all(dim=1, keepdim=True)  # (N, 1, H, W)
            preds_binary = (torch.sigmoid(model(images)) > 0.5).float()
            valid_pixels_mask = ~zero_mask.expand_as(preds_binary)

            if valid_pixels_mask.any():
                preds_masked = preds_binary[valid_pixels_mask]
                masks_masked = masks[valid_pixels_mask]
                if masks_masked.sum() == 0 and preds_masked.sum() == 0:
                    batch_iou = 1.0  # both empty in the valid region: perfect agreement
                else:
                    batch_iou = batch_metric(preds_masked, masks_masked).item()
                total_iou_metric += batch_iou
                num_batches += 1

            if not save_dir:
                continue

            image_paths = batch['image_path']
            for i in range(images.shape[0]):
                single_pred = preds_binary[i]
                single_mask = masks[i]
                single_valid = ~zero_mask[i].expand_as(single_pred)

                try:
                    sample_iou = _masked_sample_iou(single_pred, single_mask, single_valid, sample_metric)
                except Exception as e:
                    print(f"Warning: Could not compute IoU for sample {i} in batch {idx}. Error: {e}")
                    sample_iou = 0.0

                new_name = _prediction_figure_name(image_paths[i], idx, i)
                figure_path = save_dir / f'{str(int(sample_iou*100)).zfill(3)}_prediction_{new_name}.png'
                _save_prediction_figure(figure_path, images[i], single_mask, single_pred, sample_iou)

    mean_iou = total_iou_metric / num_batches if num_batches > 0 else 0.0
    return {'mean_iou': mean_iou}
351 |
def _select_device():
    """Prefer cuda:1 as before, fall back to cuda:0 on single-GPU hosts, else CPU."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    # A hard-coded 'cuda:1' fails with "invalid device ordinal" on one-GPU machines.
    return torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')


def _make_loader(dataset, config, shuffle, persistent=True):
    """DataLoader with the run's batch size and worker settings.

    ``persistent=False`` builds a plain loader (no persistent workers, default
    prefetching), as used for the one-off test pass.
    """
    extra = {}
    if persistent:
        extra = {'persistent_workers': True, 'prefetch_factor': config['prefetch_factor']}
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config['num_workers'],
        pin_memory=True,
        **extra,
    )


def main():
    """Train a UNet with the tuned hyperparameters, then evaluate the best checkpoint.

    All three splits come from data/full_list.txt (selected through the
    dataset's ``split`` argument). Checkpoints go to checkpointsDice/<EXPERIMENT_NAME>,
    prediction figures to predictionsDice/<EXPERIMENT_NAME>. When 'use_wandb'
    is set, the run is always finished, even if training fails.
    """
    dropout_rate = 0.0
    # The most promising parameters from the top sweep runs
    config = {
        'learning_rate': 0.000008,
        'weight_decay': 0.000236,
        'warmup_epochs': 5,
        'plateau_factor': 0.2,
        'plateau_patience': 9,
        'batch_size': 16,
        'num_workers': 6,
        'prefetch_factor': 2,
        'dropout_rate': dropout_rate,
        'bce_weight': 0.2,
        'dice_weight': 0.8,
        'epochs': 200,
        'checkpoint_freq': 5,
        'save_dir': f'checkpointsDice/{EXPERIMENT_NAME}',
        'use_wandb': True,
        'early_stopping_patience': 15,
        'model': {
            'n_channels': 2,
            'n_classes': 1,
            'bilinear': True,
            'dropout_rate': dropout_rate,  # same value as the top-level entry
        },
    }

    if config['use_wandb']:
        wandb.init(project=PROJECT_BASE, name=EXPERIMENT_NAME)
        wandb.config.update(config)
        wandb.log({
            'system/cpu_count': os.cpu_count(),
            'system/gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
            'system/gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None',
            'system/memory_total': psutil.virtual_memory().total / 1024**3,
            'system/memory_available': psutil.virtual_memory().available / 1024**3,
        })

    try:
        data_dir = Path('data')
        data_dir.mkdir(parents=True, exist_ok=True)

        train_dataset = FloodSegmentationDataset(
            map_file=str(data_dir / 'full_list.txt'),
            base_path=str(data_dir),
            split='train',
        )
        val_dataset = FloodSegmentationDataset(
            map_file=str(data_dir / 'full_list.txt'),
            base_path=str(data_dir),
            split='val',
        )

        train_loader = _make_loader(train_dataset, config, shuffle=True)
        val_loader = _make_loader(val_dataset, config, shuffle=False)

        model = UNet(**config['model'])

        device = _select_device()
        print(f"Using device: {device}")

        best_val_iou, best_epoch = train_model(model, train_loader, val_loader, config, device)
        print(f"\nTraining completed. Best validation IoU: {best_val_iou:.4f} at epoch {best_epoch}")

        checkpoint_path = Path(config['save_dir']) / 'best_model.pth'
        if not checkpoint_path.exists():
            print(f"Warning: Could not find checkpoint at {checkpoint_path}")
            return

        print(f"\nRunning inference with best model from {checkpoint_path}")
        model, epoch, val_iou = load_model(checkpoint_path, model, device)
        model = model.to(device)
        print(f"Loaded model from epoch {epoch} with validation IoU: {val_iou:.4f}")

        # Held-out test split of the same list file
        test_dataset = FloodSegmentationDataset(
            map_file=str(data_dir / 'full_list.txt'),
            base_path=str(data_dir),
            split='test',
        )
        test_loader = _make_loader(test_dataset, config, shuffle=False, persistent=False)

        save_dir = f'predictionsDice/{EXPERIMENT_NAME}'
        metrics = predict(model, test_loader, device, save_dir)
        print(f"Inference completed. Mean IoU: {metrics['mean_iou']:.4f}")

        if config['use_wandb']:
            wandb.log({
                'final_val_iou': best_val_iou,
                'test_iou': metrics['mean_iou'],
                'best_epoch': best_epoch,
                'system/final_gpu_memory': get_gpu_memory_usage(),
                'system/final_cpu_percent': psutil.cpu_percent(),
                'system/final_memory_percent': psutil.virtual_memory().percent,
            })
    finally:
        if config['use_wandb'] and wandb.run:
            wandb.finish()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------