├── .gitignore ├── .idea ├── ContUrbanCD.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── README.md ├── configs ├── base.yaml ├── conturbancd_sn7.yaml ├── conturbancd_tscd.yaml ├── conturbancd_wusu.yaml └── debug.yaml ├── eval.py ├── figures └── overview.jpg ├── inference.py ├── metadata_conturbancd.json ├── model ├── model.py ├── modules.py └── unet.py ├── requirements.txt ├── train.py └── utils ├── datasets.py ├── evaluation.py ├── experiment_manager.py ├── helpers.py └── parsers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | wandb/ 107 | output/ 108 | -------------------------------------------------------------------------------- /.idea/ContUrbanCD.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- 
/.idea/workspace.xml: -------------------------------------------------------------------------------- (IDE workspace file; XML content not preserved in this listing) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continuous Urban Change Detection from Satellite Image Time Series 2 | 3 | This repository contains the official code for the following paper: 4 | 5 | S. Hafner, H. Fang, H. Azizpour and Y. Ban, "Continuous Urban Change Detection from Satellite Image Time Series with Temporal Feature Refinement and Multi-Task Integration," *(accepted to TGRS)*. 6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-2406.17458-b31b1b.svg)](https://arxiv.org/abs/2406.17458) 8 | 9 | 10 | 11 | ![overview](figures/overview.jpg) 12 | 13 | 14 | 15 | 16 | # Setup 17 | 18 | This section shows you how to set up the dataset and the virtual environment. 19 | 20 | ## Dataset 21 | 22 | The SpaceNet 7 dataset is available from this [link](https://spacenet.ai/sn7-challenge/). More information can be found in the [SpaceNet 7 dataset paper](https://openaccess.thecvf.com/content/CVPR2021/html/Van_Etten_The_Multi-Temporal_Urban_Development_SpaceNet_Dataset_CVPR_2021_paper.html). 23 | 24 | We use a metadata file (`metadata_conturbancd.json`) that can be downloaded from [here](https://drive.google.com/file/d/1wzRZ9iS2lOu24OArkG5n5tQaNrRQ6As6/view?usp=drive_link) or from this repository. The metadata file should be placed in the root directory of the SpaceNet 7 dataset (see below). 25 | 26 | We also generated raster labels using the code in [this paper](https://doi.org/10.1109/IGARSS46834.2022.9883982). The labels can be obtained from [this link](https://drive.google.com/file/d/1ZHcZ0qfcymBJ4_hHcpDz6UpwvOiRZccB/view?usp=sharing) and should be placed in the respective study site folder. 27 | 28 | The dataset directory should look like this: 29 | 30 | ``` 31 | $ SpaceNet 7 dataset directory 32 | spacenet7 # -d should point to this directory 33 | ├── metadata_conturbancd.json # Download this file and place it in the dataset directory 34 | └── train 35 | ├── L15-0331E-1257N_1327_3160_13 36 | ... 37 | └── L15-1848E-0793N_7394_5018_13 38 | ├── images_masked 39 | ├── labels 40 | └── labels_raster # Generated using the vector data in labels 41 | ``` 42 | 43 | 44 | 45 | ## Environment 46 | 47 | 1. Clone this repository 48 | 49 | ```bash 50 | git clone https://github.com/SebastianHafner/ContUrbanCD.git 51 | cd ContUrbanCD 52 | ``` 53 | 54 | 2. Set up a virtual environment 55 | 56 | We use conda to set up the environment: 57 | ```bash 58 | conda create -n conturbancd python=3.9.7 59 | conda activate conturbancd 60 | ``` 61 | 62 | 3. Install the dependencies 63 | 64 | Install PyTorch according to the [official guide](https://pytorch.org/get-started/locally/).
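For reference, a minimal install of the PyTorch versions pinned in `requirements.txt` could look like the following (an illustration only; pick the build that matches your CUDA setup from the official guide):

```bash
pip install torch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0
```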
65 | 66 | Install other dependencies using the `requirements.txt` file. Note that not all libraries specified in the file are required to run this code. 67 | ```bash 68 | pip install -r requirements.txt 69 | ``` 70 | 71 | 72 | 73 | # Running our code 74 | 75 | This section provides all the instructions to run our code. If you do not want to train your own models, we also provide [our model weights](https://drive.google.com/drive/folders/1GA4_GM4li-K8gpCltFjM0x0W46k7iol5?usp=sharing) (trained using this code base). 76 | 77 | All scripts (`train.py`, `eval.py`, `inference.py`) require three arguments: 78 | - The config file is specified using `-c`. This repo includes dataset-specific configs for the proposed method. 79 | - The output directory is specified using `-o`. We use this directory to store model weights as well as evaluation and inference outputs. 80 | - The dataset directory is specified using `-d`. This directory points to the root folder of the dataset. 81 | 82 | 83 | ## Training 84 | 85 | To train our network on SpaceNet 7, run the following script: 86 | ```bash 87 | python train.py -c conturbancd_sn7 -o output -d spacenet7 88 | ``` 89 | 90 | 91 | ## Evaluation 92 | 93 | To calculate the evaluation metrics, run this script: 94 | ```bash 95 | python eval.py -c conturbancd_sn7 -o output -d spacenet7 96 | ``` 97 | 98 | The script outputs a `.json` file containing the accuracy metrics (F1 score and IoU) for the segmentation and change detection tasks. 99 | 100 | 101 | ## Inference 102 | 103 | To produce multi-temporal building segmentation outputs from a satellite image time series, run the following script: 104 | 105 | ```bash 106 | python inference.py -c conturbancd_sn7 -o output -d spacenet7 107 | ``` 108 | 109 | Note: The edge setting of the MRF module can be selected with `-e` (`degenerate`, `adjacent`, `cyclic`, or `dense`); `dense` is the default. 110 | 111 | The resulting file (`.npy`) is of dimension T x H x W with the first dimension denoting the temporal dimension.
112 | 113 | ```python 114 | seg_sits = np.load(output_file) # T x H x X 115 | 116 | # Building segmentations for second image in the time series 117 | seg_img_t2 = seg_sits[1] 118 | 119 | # Changes between first and last images of the time series 120 | ch_first_last = np.not_equal(seg_sits[0], seg_sits[-1]) 121 | ``` 122 | 123 | # Credits 124 | 125 | If you find this work useful, please consider citing: 126 | 127 | 128 | 129 | ```bibtex 130 | @article{hafner2024continuous, 131 | title={Continuous Urban Change Detection from Satellite Image Time Series with Temporal Feature Refinement and Multi-Task Integration}, 132 | author={Hafner, Sebastian and Fang, Heng and Azizpour, Hossein and Ban, Yifang}, 133 | journal={arXiv preprint arXiv:2406.17458}, 134 | year={2024} 135 | } 136 | ``` 137 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | SEED: 1 2 | DEBUG: False 3 | LOG_FREQ: 100 4 | 5 | TRAINER: 6 | LR: 1e-4 7 | BATCH_SIZE: 8 8 | EPOCHS: 100 9 | OPTIMIZER: 'adamw' 10 | PATIENCE: 10 11 | LAMBDA: 0 12 | LR_SCHEDULER: 'linear' 13 | 14 | MODEL: 15 | TYPE: 'mtunetformer' 16 | IN_CHANNELS: 3 17 | OUT_CHANNELS: 1 18 | LOSS_TYPE: 'PowerJaccardLoss' 19 | EDGE_TYPE: 'dense' 20 | TOPOLOGY: [64, 128, 256, 512,] 21 | DISABLE_OUTCONV: False 22 | TRANSFORMER: True 23 | TRANSFORMER_PARAMS: 24 | N_LAYERS: 2 25 | N_HEADS: 2 26 | D_MODEL: 192 27 | PATCH_SIZE: 8 28 | ACTIVATION: 'gelu' 29 | SPATIAL_ATTENTION_SIZE: 1 30 | ADJACENT_CHANGES: True 31 | 32 | DATALOADER: 33 | NUM_WORKER: 4 34 | SHUFFLE: True 35 | MODE: 'all' # 'first_last' or 'all' 36 | SENSOR: 'planetscope' 37 | INCLUDE_ALPHA: False 38 | TRAINING_MULTIPLIER: 100 39 | PAD_BORDERS: True 40 | TIMESERIES_LENGTH: 5 41 | INCLUDE_CHANGE_LABEL: False 42 | EVAL_TRAIN_THRESHOLD: 0.75 43 | I_SPLIT: 768 44 | J_SPLIT: 512 45 | 46 | AUGMENTATION: 47 | CROP_SIZE: 64 48 | IMAGE_OVERSAMPLING_TYPE: 'change' # [none, change, semantic] 49 | RANDOM_FLIP: True 50 | RANDOM_ROTATE: True 51 | COLOR_BLUR: True 52 | COLOR_JITTER: True 53 | 54 | DATASET: 55 | NAME: 'sn7' 56 | TRAIN_IDS: [ 57 | 'L15-0331E-1257N_1327_3160_13', 58 | 'L15-0357E-1223N_1429_3296_13', 59 | 'L15-0358E-1220N_1433_3310_13', 60 | 'L15-0361E-1300N_1446_2989_13', 61 | 'L15-0434E-1218N_1736_3318_13', 62 | 'L15-0487E-1246N_1950_3207_13', 63 | 'L15-0506E-1204N_2027_3374_13', 64 | 'L15-0544E-1228N_2176_3279_13', 65 | 'L15-0577E-1243N_2309_3217_13', 66 | 'L15-0586E-1127N_2345_3680_13', 67 | 'L15-0595E-1278N_2383_3079_13', 68 | 'L15-0614E-0946N_2459_4406_13', 69 | 'L15-0683E-1006N_2732_4164_13', 70 | 'L15-0760E-0887N_3041_4643_13', 71 | 'L15-0924E-1108N_3699_3757_13', 72 | 'L15-0977E-1187N_3911_3441_13', 73 | 'L15-1025E-1366N_4102_2726_13', 74 | 'L15-1049E-1370N_4196_2710_13', 75 | 'L15-1172E-1306N_4688_2967_13', 76 | 'L15-1185E-0935N_4742_4450_13', 77 | 'L15-1203E-1203N_4815_3378_13', 78 | 'L15-1204E-1202N_4816_3380_13', 79 | 'L15-1204E-1204N_4819_3372_13', 80 | 'L15-1209E-1113N_4838_3737_13', 81 | 'L15-1210E-1025N_4840_4088_13', 82 | 'L15-1289E-1169N_5156_3514_13', 83 | 'L15-1296E-1198N_5184_3399_13', 84 | 'L15-1298E-1322N_5193_2903_13', 85 | 'L15-1335E-1166N_5342_3524_13', 86 | 'L15-1389E-1284N_5557_3054_13', 87 | 'L15-1479E-1101N_5916_3785_13', 88 | 'L15-1481E-1119N_5927_3715_13', 89 | 'L15-1538E-1163N_6154_3539_13', 90 | 'L15-1615E-1205N_6460_3370_13', 91 | 'L15-1617E-1207N_6468_3360_13', 92 | 'L15-1669E-1153N_6678_3579_13', 93 | 'L15-1669E-1160N_6678_3548_13', 94 | 
'L15-1691E-1211N_6764_3347_13', 95 | 'L15-1703E-1219N_6813_3313_13', 96 | 'L15-1716E-1211N_6864_3345_13', 97 | 'L15-0368E-1245N_1474_3210_13', 98 | 'L15-0457E-1135N_1831_3648_13', 99 | 'L15-0571E-1075N_2287_3888_13', 100 | 'L15-1014E-1375N_4056_2688_13', 101 | 'L15-1138E-1216N_4553_3325_13', 102 | 'L15-1439E-1134N_5759_3655_13', 103 | 'L15-1669E-1160N_6679_3549_13', 104 | 'L15-1672E-1207N_6691_3363_13', 105 | 'L15-1709E-1112N_6838_3742_13', 106 | 'L15-1748E-1247N_6993_3202_13', 107 | 'L15-0387E-1276N_1549_3087_13', 108 | 'L15-0566E-1185N_2265_3451_13', 109 | 'L15-0632E-0892N_2528_4620_13', 110 | 'L15-1015E-1062N_4061_3941_13', 111 | 'L15-1200E-0847N_4802_4803_13', 112 | 'L15-1276E-1107N_5105_3761_13', 113 | 'L15-1438E-1134N_5753_3655_13', 114 | 'L15-1615E-1206N_6460_3366_13', 115 | 'L15-1690E-1211N_6763_3346_13', 116 | 'L15-1848E-0793N_7394_5018_13', 117 | ] 118 | TEST_IDS: [ 119 | 'L15-0369E-1244N_1479_3214_13', 120 | 'L15-0391E-1219N_1567_3314_13', 121 | 'L15-0509E-1108N_2037_3758_13', 122 | 'L15-0571E-1302N_2284_2983_13', 123 | 'L15-0697E-0874N_2789_4694_13', 124 | 'L15-0744E-0927N_2979_4481_13', 125 | 'L15-1031E-1300N_4127_2991_13', 126 | 'L15-1129E-0819N_4517_4915_13', 127 | 'L15-1203E-1203N_4815_3379_13', 128 | 'L15-1213E-1238N_4852_3239_13', 129 | 'L15-1249E-1167N_4999_3521_13', 130 | 'L15-1281E-1035N_5125_4049_13', 131 | 'L15-1438E-1227N_5753_3282_13', 132 | 'L15-1546E-1154N_6186_3574_13', 133 | 'L15-1615E-1205N_6461_3368_13', 134 | 'L15-1630E-0988N_6522_4239_13', 135 | 'L15-1666E-1189N_6665_3433_13', 136 | 'L15-1670E-1159N_6681_3552_13', 137 | 'L15-1690E-1210N_6762_3348_13', 138 | 'L15-1749E-1266N_6997_3126_13', 139 | ] -------------------------------------------------------------------------------- /configs/conturbancd_sn7.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | SEED: 2 3 | 4 | TRAINER: 5 | LR: 1e-4 6 | BATCH_SIZE: 16 7 | 8 | MODEL: 9 | EDGE_TYPE: 'dense' 10 | TYPE: 'mtunetformer' 11 | 12 | DATASET: 13 | NAME: 'sn7' -------------------------------------------------------------------------------- /configs/conturbancd_tscd.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | SEED: 3 3 | 4 | TRAINER: 5 | LR: 5e-5 6 | BATCH_SIZE: 16 7 | 8 | MODEL: 9 | EDGE_TYPE: 'dense' 10 | TYPE: 'mtunetformer' 11 | IN_CHANNELS: 3 12 | 13 | DATASET: 14 | NAME: 'tscd' 15 | 16 | DATALOADER: 17 | TIMESERIES_LENGTH: 4 -------------------------------------------------------------------------------- /configs/conturbancd_wusu.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | SEED: 2 3 | 4 | TRAINER: 5 | LR: 1e-4 6 | BATCH_SIZE: 16 7 | 8 | MODEL: 9 | EDGE_TYPE: 'cyclic' 10 | TYPE: 'mtunetformer' 11 | IN_CHANNELS: 4 12 | 13 | DATASET: 14 | NAME: 'wusu' 15 | 16 | DATALOADER: 17 | TIMESERIES_LENGTH: 3 -------------------------------------------------------------------------------- /configs/debug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "conturbancd_sn7.yaml" 2 | DEBUG: True 3 | 4 | TRAINER: 5 | BATCH_SIZE: 2 6 | 7 | DATALOADER: 8 | TRAINING_MULTIPLIER: 1 9 | 10 | AUGMENTATION: 11 | CROP_SIZE: 64 12 | IMAGE_OVERSAMPLING_TYPE: 'none' 13 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils import 
datasets, parsers, experiment_manager, helpers, evaluation 4 | from utils.experiment_manager import CfgNode 5 | from model import model 6 | 7 | from pathlib import Path 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | def assessment(cfg: CfgNode, edge_type: str = 'dense', run_type: str = 'test'): 13 | print(cfg.NAME) 14 | net = model.load_model(cfg, device) 15 | m = evaluation.run_quantitative_evaluation(net, cfg, device, run_type, enable_mti=True, mti_edge_setting=edge_type) 16 | 17 | data = {} 18 | for attr in ['seg_cont', 'seg_fl', 'ch_cont', 'ch_fl']: 19 | f1 = evaluation.f1_score(getattr(m, f'TP_{attr}'), getattr(m, f'FP_{attr}'), getattr(m, f'FN_{attr}')) 20 | iou = evaluation.iou(getattr(m, f'TP_{attr}'), getattr(m, f'FP_{attr}'), getattr(m, f'FN_{attr}')) 21 | data[attr] = {'f1': f1, 'iou': iou} 22 | eval_folder = Path(cfg.PATHS.OUTPUT) / 'evaluation' 23 | eval_folder.mkdir(exist_ok=True) 24 | helpers.write_json(eval_folder / f'{cfg.NAME}_{edge_type}.json', data) 25 | 26 | 27 | if __name__ == '__main__': 28 | args = parsers.inference_argument_parser().parse_known_args()[0] 29 | cfg = experiment_manager.setup_cfg(args) 30 | assessment(cfg, edge_type=args.edge_type) 31 | 32 | -------------------------------------------------------------------------------- /figures/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SebastianHafner/ContUrbanCD/4c3c6501b444880ebc8eb21bcceea6dcf738c46b/figures/overview.jpg -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import parsers, experiment_manager, helpers, datasets 3 | from utils.experiment_manager import CfgNode 4 | from pathlib import Path 5 | import numpy as np 6 | from model import model 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | EPS = 1e-6 10 | 11 | 12 | def inference(cfg: CfgNode, edge_type: str = 'dense', run_type: str = 'test'): 13 | print(cfg.NAME) 14 | net = model.load_model(cfg, device) 15 | net.eval() 16 | 17 | tile_size = cfg.AUGMENTATION.CROP_SIZE 18 | edges = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, edge_type) 19 | 20 | pred_folder = Path(cfg.PATHS.OUTPUT) / 'inference' / cfg.NAME 21 | pred_folder.mkdir(exist_ok=True) 22 | 23 | for aoi_id in list(cfg.DATASET.TRAIN_IDS): 24 | print(aoi_id) 25 | 26 | ds = datasets.create_eval_dataset(cfg, run_type, site=aoi_id, tiling=tile_size) 27 | o_seg = np.empty((1, ds.T, 1, ds.m, ds.n), dtype=np.uint8) 28 | 29 | for index in range(len(ds)): 30 | item = ds.__getitem__(index) 31 | x, y_seg = item['x'].to(device).unsqueeze(0), item['y'].to(device).unsqueeze(0) 32 | i, j = item['i'], item['j'] 33 | o_seg_tile = net.module.inference(x, edges) 34 | o_seg[:, :, :, i:i + tile_size, j:j + tile_size] = o_seg_tile 35 | 36 | np.save(pred_folder / f'{cfg.NAME}_{aoi_id}.npy', o_seg) 37 | 38 | 39 | if __name__ == '__main__': 40 | args = parsers.inference_argument_parser().parse_known_args()[0] 41 | cfg = experiment_manager.setup_cfg(args) 42 | inference(cfg, edge_type=args.edge_type) 43 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import einops 5 | 6 | from typing import 
Tuple, Sequence 7 | from pathlib import Path 8 | 9 | from utils.experiment_manager import CfgNode 10 | 11 | from model import unet, modules 12 | 13 | 14 | class ContUrbanCDModel(nn.Module): 15 | def __init__(self, cfg: CfgNode): 16 | super(ContUrbanCDModel, self).__init__() 17 | 18 | # Attributes 19 | self.cfg = cfg 20 | self.c = cfg.MODEL.IN_CHANNELS 21 | self.d_out = cfg.MODEL.OUT_CHANNELS 22 | self.h = self.w = cfg.AUGMENTATION.CROP_SIZE 23 | self.t = cfg.DATALOADER.TIMESERIES_LENGTH 24 | self.topology = cfg.MODEL.TOPOLOGY 25 | 26 | # ConvNet layers 27 | self.inc = unet.InConv(self.c, self.topology[0], unet.DoubleConv) 28 | self.encoder = unet.Encoder(cfg) 29 | 30 | self.decoder_seg = unet.Decoder(cfg) 31 | self.outc_seg = unet.OutConv(self.topology[0], self.d_out) 32 | 33 | self.decoder_ch = unet.Decoder(cfg) 34 | self.outc_ch = unet.OutConv(self.topology[0], self.d_out) 35 | 36 | # Temporal feature refinement (TFR) modules 37 | tfr_modules = [] 38 | transformer_dims = [self.topology[-1]] + list(self.topology[::-1]) 39 | for i, d_model in enumerate(transformer_dims): 40 | tfr_module = modules.TFRModule( 41 | t=self.t, 42 | d_model=d_model, 43 | n_heads=cfg.MODEL.TRANSFORMER_PARAMS.N_HEADS, 44 | d_hid=self.topology[0] * 4, 45 | activation=cfg.MODEL.TRANSFORMER_PARAMS.ACTIVATION, 46 | n_layers=cfg.MODEL.TRANSFORMER_PARAMS.N_LAYERS 47 | ) 48 | tfr_modules.append(tfr_module) 49 | self.tfr_modules = nn.ModuleList(tfr_modules) 50 | 51 | # Change feature (CF) module 52 | self.cf_module = modules.CFModule() 53 | 54 | # Multi-task integration (MTI) module 55 | self.mti_module = modules.MTIModule() 56 | 57 | def forward(self, x: Tensor, edges: Sequence[Tuple[int, int]]) -> Tuple[Tensor, Tensor]: 58 | B, T, _, H, W = x.size() 59 | 60 | # Feature extraction with Siamese ConvNet encoder 61 | x = einops.rearrange(x, 'b t c h w -> (b t) c h w') 62 | features = self.encoder(self.inc(x)) 63 | features = [einops.rearrange(f_s, '(b t) f h w -> b t f h w', b=B) for f_s in features] 64 | 65 | # Temporal feature refinement with TFR modules 66 | for i, tfr_module in enumerate(self.tfr_modules): 67 | f = features[i] # Feature maps at scale s 68 | # Feature refinement with self-attention 69 | f_refined = tfr_module(f) 70 | features[i] = f_refined 71 | 72 | # Change feature maps 73 | features_ch = self.cf_module(features, edges) 74 | features_ch = [einops.rearrange(f, 'n b c h w -> (b n) c h w') for f in features_ch] 75 | 76 | # Building segmentation 77 | features_seg = [einops.rearrange(f, 'b t c h w -> (b t) c h w') for f in features] 78 | logits_seg = self.outc_seg(self.decoder_seg(features_seg)) 79 | logits_seg = einops.rearrange(logits_seg, '(b t) c h w -> b t c h w', b=B) 80 | 81 | logits_ch = self.outc_ch(self.decoder_ch(features_ch)) 82 | logits_ch = einops.rearrange(logits_ch, '(b n) c h w -> b n c h w', n=len(edges)) 83 | 84 | return logits_ch, logits_seg 85 | 86 | def inference(self, x: Tensor, edges: Sequence[Tuple[int, int]]) -> Tensor: 87 | logits_ch, logits_seg = self.forward(x, edges) 88 | o_ch = torch.sigmoid(logits_ch).detach() 89 | o_seg = torch.sigmoid(logits_seg).detach() 90 | o_seg = self.mti_module(o_ch, o_seg, edges) 91 | return o_seg 92 | 93 | 94 | def init_model(cfg: CfgNode) -> nn.Module: 95 | net = ContUrbanCDModel(cfg) 96 | return torch.nn.DataParallel(net) 97 | 98 | 99 | def save_model(network: nn.Module, epoch: float, cfg: CfgNode): 100 | save_file = Path(cfg.PATHS.OUTPUT) / 'weights' / f'{cfg.NAME}.pt' 101 | save_file.parent.mkdir(exist_ok=True) 102 | checkpoint = { 103 | 
'epoch': epoch, 104 | 'weights': network.state_dict(), 105 | } 106 | torch.save(checkpoint, save_file) 107 | 108 | 109 | def load_model(cfg: CfgNode, device: torch.device) -> nn.Module: 110 | net = init_model(cfg) 111 | net.to(device) 112 | net_file = Path(cfg.PATHS.OUTPUT) / 'weights' / f'{cfg.NAME}.pt' 113 | checkpoint = torch.load(net_file, map_location=device) 114 | net.load_state_dict(checkpoint['weights']) 115 | return net 116 | 117 | 118 | def power_jaccard_loss(input: Tensor, target: Tensor, disable_sigmoid: bool = False) -> Tensor: 119 | input_sigmoid = torch.sigmoid(input) if not disable_sigmoid else input 120 | eps = 1e-6 121 | 122 | iflat = input_sigmoid.flatten() 123 | tflat = target.flatten() 124 | intersection = (iflat * tflat).sum() 125 | denom = (iflat ** 2 + tflat ** 2).sum() - (iflat * tflat).sum() + eps 126 | 127 | return 1 - (intersection / denom) 128 | 129 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import einops 5 | 6 | import numpy as np 7 | 8 | # https://pgmpy.org/models/markovnetwork.html 9 | from pgmpy.models import MarkovNetwork 10 | from pgmpy.factors.discrete import DiscreteFactor 11 | from pgmpy.inference import BeliefPropagation 12 | 13 | from typing import Tuple, Sequence, Callable 14 | 15 | from joblib import Parallel, delayed 16 | 17 | import os 18 | os.environ["PYTHONWARNINGS"] = "ignore" 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | class TFRModule(nn.Module): 24 | def __init__(self, t: int, d_model: int, n_heads: int, d_hid: int, activation: str, n_layers: int): 25 | super().__init__() 26 | # Generate relative temporal encodings 27 | self.register_buffer('temporal_encodings', self.get_relative_encodings(t, d_model), persistent=False) 28 | 29 | # Define a transformer encoder layer 30 | encoder_layer = nn.TransformerEncoderLayer( 31 | d_model=d_model, nhead=n_heads, 32 | dim_feedforward=d_hid, batch_first=True, 33 | activation=activation 34 | ) 35 | # Create module 36 | self.temporal_feature_refinement = nn.TransformerEncoder(encoder_layer, n_layers) 37 | 38 | def forward(self, features: Tensor) -> Tensor: 39 | B, T, D, H, W = features.size() 40 | 41 | # Reshape to tokens 42 | tokens = einops.rearrange(features, 'B T D H W -> (B H W) T D') 43 | 44 | # Adding relative temporal encodings 45 | tokens = tokens + self.temporal_encodings.repeat(B * H * W, 1, 1) 46 | 47 | # Feature refinement with self-attention 48 | features_hat = self.temporal_feature_refinement(tokens) 49 | 50 | # Reshape to original shape 51 | features_hat = einops.rearrange(features_hat, '(B H W) T D -> B T D H W', B=B, H=H) 52 | 53 | return features_hat 54 | 55 | @staticmethod 56 | def get_relative_encodings(sequence_length, d): 57 | result = torch.ones(sequence_length, d) 58 | for i in range(sequence_length): 59 | for j in range(d): 60 | result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d))) 61 | return result 62 | 63 | 64 | class CFModule(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | 68 | @staticmethod 69 | def forward(features: Sequence[Tensor], edges: Sequence[Tuple[int, int]]) -> Sequence[Tensor]: 70 | # compute urban change detection features 71 | features_ch = [] 72 | for feature in features: 73 | B, T, _, H, W = feature.size() 74 | feature_ch = [] 75 | for 
t1, t2 in edges: 76 | feature_ch.append(feature[:, t2] - feature[:, t1]) 77 | # n: number of combinations 78 | feature_ch = torch.stack(feature_ch) 79 | features_ch.append(feature_ch) 80 | return features_ch 81 | 82 | 83 | class MTIModule(nn.Module): 84 | def __init__(self): 85 | super().__init__() 86 | 87 | def forward(self, o_ch: Tensor, o_seg: Tensor, edges: Sequence[Tuple[int, int]]) -> Tensor: 88 | 89 | B, T, _, H, W = o_seg.size() 90 | 91 | # Get processing function by defining the Markov network 92 | process_pixel = self.markov_network(T, edges) 93 | 94 | # Reshape 95 | o_ch = einops.rearrange(o_ch, 'B N C H W -> (B H W) N C').cpu().numpy() 96 | o_seg = einops.rearrange(o_seg, 'B T C H W -> (B H W) T C').cpu().numpy() 97 | 98 | # Find optimal building time series using Markov network 99 | o_seg_mrf = Parallel(n_jobs=-1)(delayed(process_pixel)(p_seg, p_ch) for p_seg, p_ch, in zip(o_seg, o_ch)) 100 | 101 | # Reshape 102 | o_seg_mrf = torch.Tensor(o_seg_mrf) 103 | o_seg_mrf = einops.rearrange(o_seg_mrf, '(B H W) (T C) -> B T C H W', H=H, W=W, T=T) 104 | 105 | return o_seg_mrf 106 | 107 | @staticmethod 108 | def markov_network(n: int, edges: Sequence[Tuple[int, int]]) -> Callable: 109 | # Define a function to process a single pixel 110 | def process_pixel(y_hat_seg: Sequence[float], y_hat_ch: Sequence[float]): 111 | model = MarkovNetwork() 112 | 113 | for t in range(len(y_hat_seg)): 114 | model.add_node(f'N{t}') 115 | # Cardinality: number of potential values (i.e., 2: 0/1) 116 | # Potential Values for node: P(urban=True), P(urban=False) 117 | urban_value = float(y_hat_seg[t]) 118 | factor = DiscreteFactor([f'N{t}'], cardinality=[2], values=[1 - urban_value, urban_value]) 119 | model.add_factors(factor) 120 | 121 | # add adjacent edges w/o potentials 122 | for t in range(n - 1): 123 | model.add_edge(f'N{t}', f'N{t + 1}') 124 | 125 | # add edges with potentials 126 | for i, (t1, t2) in enumerate(edges): 127 | model.add_edge(f'N{t1}', f'N{t2}') 128 | # [P(A=False, B=False), P(A=False, B=True), P(A=True, B=False), P(A=True, B=True)] 129 | change_value = float(y_hat_ch[i]) 130 | edge_values = [1 - change_value, change_value, change_value, 1 - change_value] 131 | 132 | factor = DiscreteFactor([f'N{t1}', f'N{t2}'], cardinality=[2, 2], values=edge_values) 133 | model.add_factors(factor) 134 | 135 | # Create an instance of BeliefPropagation algorithm 136 | bp = BeliefPropagation(model) 137 | 138 | # Compute the most probable state of the MRF 139 | state = bp.map_query() 140 | states_list = [state[f'N{t}'] for t in range(n)] 141 | return states_list 142 | 143 | return process_pixel 144 | -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | 6 | import einops 7 | from collections import OrderedDict 8 | from utils.experiment_manager import CfgNode 9 | 10 | 11 | class UNet(nn.Module): 12 | def __init__(self, cfg: CfgNode): 13 | super(UNet, self).__init__() 14 | self.cfg = cfg 15 | 16 | n_channels = cfg.MODEL.IN_CHANNELS 17 | n_classes = cfg.MODEL.OUT_CHANNELS 18 | topology = cfg.MODEL.TOPOLOGY 19 | 20 | self.inc = InConv(n_channels, topology[0], DoubleConv) 21 | self.encoder = Encoder(cfg) 22 | self.decoder = Decoder(cfg) 23 | self.outc = OutConv(topology[0], n_classes) 24 | self.disable_outc = cfg.MODEL.DISABLE_OUTCONV 25 | 26 | def forward(self, x: Tensor) -> 
Tensor: 27 | B, T, _, H, W = x.size() 28 | x = einops.rearrange(x, 'b t c h w -> (b t) c h w') 29 | out = self.outc(self.decoder(self.encoder(self.inc(x)))) 30 | out = einops.rearrange(out, '(b t) c h w -> b t c h w', b=B) 31 | return out 32 | 33 | 34 | class Encoder(nn.Module): 35 | def __init__(self, cfg: CfgNode): 36 | super(Encoder, self).__init__() 37 | 38 | self.cfg = cfg 39 | topology = cfg.MODEL.TOPOLOGY 40 | 41 | # Variable scale 42 | down_topo = topology 43 | down_dict = OrderedDict() 44 | n_layers = len(down_topo) 45 | 46 | # Downward layers 47 | for idx in range(n_layers): 48 | is_not_last_layer = idx != n_layers - 1 49 | in_dim = down_topo[idx] 50 | out_dim = down_topo[idx + 1] if is_not_last_layer else down_topo[idx] # last layer 51 | layer = Down(in_dim, out_dim, DoubleConv) 52 | down_dict[f'down{idx + 1}'] = layer 53 | self.down_seq = nn.ModuleDict(down_dict) 54 | 55 | def forward(self, x1: Tensor) -> list: 56 | 57 | inputs = [x1] 58 | # Downward U: 59 | for layer in self.down_seq.values(): 60 | out = layer(inputs[-1]) 61 | inputs.append(out) 62 | 63 | inputs.reverse() 64 | return inputs 65 | 66 | 67 | class Decoder(nn.Module): 68 | def __init__(self, cfg: CfgNode): 69 | super(Decoder, self).__init__() 70 | self.cfg = cfg 71 | 72 | topology = cfg.MODEL.TOPOLOGY 73 | 74 | # Variable scale 75 | n_layers = len(topology) 76 | up_topo = [topology[0]] # topography upwards 77 | up_dict = OrderedDict() 78 | 79 | for idx in range(n_layers): 80 | is_not_last_layer = idx != n_layers - 1 81 | out_dim = topology[idx + 1] if is_not_last_layer else topology[idx] # last layer 82 | up_topo.append(out_dim) 83 | 84 | # Upward layers 85 | for idx in reversed(range(n_layers)): 86 | is_not_last_layer = idx != 0 87 | x1_idx = idx 88 | x2_idx = idx - 1 if is_not_last_layer else idx 89 | in_dim = up_topo[x1_idx] * 2 90 | out_dim = up_topo[x2_idx] 91 | layer = Up(in_dim, out_dim, DoubleConv) 92 | up_dict[f'up{idx + 1}'] = layer 93 | 94 | self.up_seq = nn.ModuleDict(up_dict) 95 | 96 | def forward(self, features: list) -> Tensor: 97 | 98 | x1 = features.pop(0) 99 | for idx, layer in enumerate(self.up_seq.values()): 100 | x2 = features[idx] 101 | x1 = layer(x1, x2) # x1 for next up layer 102 | 103 | return x1 104 | 105 | 106 | class DoubleConv(nn.Module): 107 | '''(conv => BN => ReLU) * 2''' 108 | 109 | def __init__(self, in_ch, out_ch): 110 | super(DoubleConv, self).__init__() 111 | self.conv = nn.Sequential( 112 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 113 | nn.BatchNorm2d(out_ch), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 116 | nn.BatchNorm2d(out_ch), 117 | nn.ReLU(inplace=True) 118 | ) 119 | 120 | def forward(self, x): 121 | x = self.conv(x) 122 | return x 123 | 124 | 125 | class InConv(nn.Module): 126 | def __init__(self, in_ch, out_ch, conv_block): 127 | super(InConv, self).__init__() 128 | self.conv = conv_block(in_ch, out_ch) 129 | 130 | def forward(self, x): 131 | x = self.conv(x) 132 | return x 133 | 134 | 135 | class Down(nn.Module): 136 | def __init__(self, in_ch, out_ch, conv_block): 137 | super(Down, self).__init__() 138 | 139 | self.mpconv = nn.Sequential( 140 | nn.MaxPool2d(2), 141 | conv_block(in_ch, out_ch) 142 | ) 143 | 144 | def forward(self, x): 145 | x = self.mpconv(x) 146 | return x 147 | 148 | 149 | class up_conv(nn.Module): 150 | def __init__(self, ch_in, ch_out): 151 | super(up_conv, self).__init__() 152 | self.up = nn.Sequential( 153 | nn.Upsample(scale_factor=2), 154 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 
155 | nn.BatchNorm2d(ch_out), 156 | nn.ReLU(inplace=True) 157 | ) 158 | 159 | def forward(self, x): 160 | x = self.up(x) 161 | return x 162 | 163 | 164 | class Up(nn.Module): 165 | def __init__(self, in_ch, out_ch, conv_block): 166 | super(Up, self).__init__() 167 | 168 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 169 | self.conv = conv_block(in_ch, out_ch) 170 | 171 | def forward(self, x1, x2): 172 | x1 = self.up(x1) 173 | 174 | # input is CHW 175 | diffY = x2.detach().size()[2] - x1.detach().size()[2] 176 | diffX = x2.detach().size()[3] - x1.detach().size()[3] 177 | 178 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) 179 | 180 | # for padding issues, see 181 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 182 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 183 | 184 | x = torch.cat([x2, x1], dim=1) 185 | x = self.conv(x) 186 | return x 187 | 188 | 189 | class OutConv(nn.Module): 190 | def __init__(self, in_ch, out_ch): 191 | super(OutConv, self).__init__() 192 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 193 | 194 | def forward(self, x): 195 | x = self.conv(x) 196 | return x 197 | 198 | 199 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | affine==2.3.0 2 | argh==0.26.2 3 | attrs==21.2.0 4 | certifi==2021.10.8 5 | charset-normalizer==2.0.8 6 | click==8.0.3 7 | click-plugins==1.1.1 8 | cligj==0.7.2 9 | configparser==5.1.0 10 | cycler==0.11.0 11 | docker-pycreds==0.4.0 12 | einops==0.6.0 13 | filelock==3.11.0 14 | fonttools==4.28.4 15 | fsspec==2023.5.0 16 | fvcore==0.1.5.post20211023 17 | gin-config==0.5.0 18 | gitdb==4.0.9 19 | GitPython==3.1.24 20 | huggingface-hub==0.15.1 21 | idna==3.3 22 | imagecodecs==2022.8.8 23 | iopath==0.1.9 24 | joblib==1.3.1 25 | kiwisolver==1.3.2 26 | lightning-utilities==0.10.0 27 | matplotlib==3.5.1 28 | mkl-fft==1.3.1 29 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work 30 | mkl-service==2.4.0 31 | networkx==3.1 32 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634095651905/work 33 | olefile @ file:///Users/ktietz/demo/mc3/conda-bld/olefile_1629805411829/work 34 | opencv-python==4.5.5.62 35 | opt-einsum==3.3.0 36 | packaging==21.3 37 | pandas==2.0.3 38 | pathtools==0.1.2 39 | patsy==0.5.3 40 | pgmpy==0.1.23 41 | Pillow==8.4.0 42 | portalocker==2.3.2 43 | promise==2.3 44 | protobuf==3.19.1 45 | psutil==5.8.0 46 | pyparsing==3.0.6 47 | python-dateutil==2.8.2 48 | pytz==2023.3 49 | PyYAML==6.0 50 | rasterio==1.2.10 51 | regex==2023.6.3 52 | requests==2.26.0 53 | scikit-learn==1.3.0 54 | scipy==1.7.3 55 | sentry-sdk==1.5.0 56 | shortuuid==1.0.8 57 | six @ file:///tmp/build/80754af9/six_1623709665295/work 58 | smmap==5.0.0 59 | snuggs==1.4.7 60 | statsmodels==0.14.0 61 | subprocess32==3.5.4 62 | tabulate==0.8.9 63 | termcolor==1.1.0 64 | threadpoolctl==3.2.0 65 | tifffile==2022.8.12 66 | timm==0.6.13 67 | tokenizers==0.13.3 68 | torch==1.10.0 69 | torchaudio==0.10.0 70 | torcheval==0.0.7 71 | torchmetrics==1.2.0 72 | torchvision==0.11.1 73 | tqdm==4.62.3 74 | transformers==4.29.2 75 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work 76 | tzdata==2023.3 77 | urllib3==1.26.7 78 | wandb==0.12.7 79 | yacs==0.1.8 80 | yaspin==2.1.0 81 | 
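Both `train.py` (next) and `inference.py` build the list of temporal comparison pairs with `helpers.get_edges(timeseries_length, edge_type)`; `utils/helpers.py` itself is not reproduced in this listing. As a rough, hypothetical sketch of what the four edge settings exposed via `-e` could mean (the actual definitions live in `utils/helpers.py` and may differ):

```python
from itertools import combinations
from typing import List, Tuple


def get_edges_sketch(t: int, edge_type: str) -> List[Tuple[int, int]]:
    """Hypothetical illustration of the four MRF edge settings (not the repository's implementation)."""
    if edge_type == 'degenerate':
        return [(0, t - 1)]  # a single first-to-last comparison
    if edge_type == 'adjacent':
        return [(i, i + 1) for i in range(t - 1)]  # consecutive timestamps only
    if edge_type == 'cyclic':
        return [(i, i + 1) for i in range(t - 1)] + [(0, t - 1)]  # adjacent pairs plus a first-to-last edge
    if edge_type == 'dense':
        return list(combinations(range(t), 2))  # every pair (t1, t2) with t1 < t2
    raise ValueError(f'Unknown edge type: {edge_type}')


# Example: a dense graph over a 5-image time series yields 10 comparison pairs.
print(get_edges_sketch(5, 'dense'))
```

Each returned pair (t1, t2) corresponds to one change comparison in the code below: `CFModule` differences the features of those two timestamps, and `MTIModule` attaches a pairwise factor between the matching nodes of the Markov network.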
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import timeit 4 | 5 | import torch 6 | from torch import optim 7 | from torch.utils import data as torch_data 8 | 9 | import wandb 10 | import numpy as np 11 | 12 | from utils import datasets, evaluation, experiment_manager, parsers, helpers 13 | from model import model 14 | from utils.experiment_manager import CfgNode 15 | 16 | 17 | def run_training(cfg: CfgNode): 18 | net = model.init_model(cfg) 19 | net.to(device) 20 | 21 | optimizer = optim.AdamW(net.parameters(), lr=cfg.TRAINER.LR, weight_decay=0.01) 22 | 23 | def lambda_rule(e: int): 24 | lr_l = 1.0 - e / float(cfg.TRAINER.EPOCHS - 1) 25 | return lr_l 26 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 27 | 28 | criterion = model.power_jaccard_loss 29 | 30 | # reset the generators 31 | dataset = datasets.create_train_dataset(cfg, run_type='train') 32 | print(dataset) 33 | 34 | dataloader_kwargs = { 35 | 'batch_size': cfg.TRAINER.BATCH_SIZE, 36 | 'num_workers': 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKER, 37 | 'shuffle': cfg.DATALOADER.SHUFFLE, 38 | 'drop_last': True, 39 | 'pin_memory': True, 40 | } 41 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 42 | 43 | edges = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, cfg.MODEL.EDGE_TYPE) 44 | 45 | # unpacking cfg 46 | epochs = cfg.TRAINER.EPOCHS 47 | steps_per_epoch = len(dataloader) 48 | 49 | # tracking variables 50 | global_step = epoch_float = 0 51 | 52 | # early stopping 53 | best_f1_val = 0 54 | trigger_times = 0 55 | stop_training = False 56 | 57 | for epoch in range(1, epochs + 1): 58 | print(f'Starting epoch {epoch}/{epochs}.') 59 | wandb.log({'lr': scheduler.get_last_lr()[-1] if scheduler is not None else cfg.TRAINER.LR, 'epoch': epoch}) 60 | start = timeit.default_timer() 61 | loss_seg_set, loss_ch_set, loss_set = [], [], [] 62 | 63 | for i, batch in enumerate(dataloader): 64 | 65 | net.train() 66 | optimizer.zero_grad() 67 | 68 | x, y_seg = batch['x'].to(device), batch['y'].to(device) 69 | logits_ch, logits_seg = net(x, edges) 70 | 71 | y_ch = helpers.get_ch(y_seg, edges) 72 | 73 | loss_seg = criterion(logits_seg, y_seg) 74 | loss_ch = criterion(logits_ch, y_ch) 75 | 76 | loss = loss_seg + loss_ch 77 | loss.backward() 78 | optimizer.step() 79 | 80 | loss_seg_set.append(loss_seg.item()) 81 | loss_ch_set.append(loss_ch.item()) 82 | loss_set.append(loss.item()) 83 | 84 | global_step += 1 85 | epoch_float = global_step / steps_per_epoch 86 | 87 | if global_step % cfg.LOG_FREQ == 0: 88 | print(f'Logging step {global_step} (epoch {epoch_float:.2f}).') 89 | 90 | # logging 91 | time = timeit.default_timer() - start 92 | wandb.log({ 93 | 'loss_seg': np.mean(loss_seg_set), 94 | 'loss_ch': np.mean(loss_ch_set), 95 | 'loss': np.mean(loss_set), 96 | 'time': time, 97 | 'step': global_step, 98 | 'epoch': epoch_float, 99 | }) 100 | start = timeit.default_timer() 101 | loss_seg_set, loss_ch_set, loss_set = [], [], [] 102 | # end of batch 103 | 104 | assert (epoch == epoch_float) 105 | print(f'epoch float {epoch_float} (step {global_step}) - epoch {epoch}') 106 | if scheduler is not None: 107 | scheduler.step() 108 | # evaluation at the end of an epoch 109 | f1_val = evaluation.model_evaluation(net, cfg, device, 'val', epoch_float, global_step) 110 | 111 | if f1_val <= best_f1_val: 112 | trigger_times += 1 113 | if trigger_times > 
cfg.TRAINER.PATIENCE: 114 | stop_training = True 115 | else: 116 | best_f1_val = f1_val 117 | wandb.log({ 118 | 'best val f1': best_f1_val, 119 | 'step': global_step, 120 | 'epoch': epoch_float, 121 | }) 122 | print(f'saving network (F1 {f1_val:.3f})', flush=True) 123 | model.save_model(net, epoch, cfg) 124 | trigger_times = 0 125 | 126 | if stop_training: 127 | break 128 | 129 | net = model.load_model(cfg, device) 130 | _ = evaluation.model_evaluation(net, cfg, device, 'test', epoch_float, global_step) 131 | 132 | 133 | if __name__ == '__main__': 134 | args = parsers.training_argument_parser().parse_known_args()[0] 135 | cfg = experiment_manager.setup_cfg(args) 136 | 137 | # make training deterministic 138 | torch.manual_seed(cfg.SEED) 139 | np.random.seed(cfg.SEED) 140 | torch.backends.cudnn.deterministic = True 141 | torch.backends.cudnn.benchmark = False 142 | 143 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 144 | 145 | print('=== Runnning on device: p', device) 146 | 147 | wandb.init( 148 | name=cfg.NAME, 149 | config=cfg, 150 | project='ContUrbanCD', 151 | mode='online' if not cfg.DEBUG else 'disabled', 152 | ) 153 | 154 | try: 155 | run_training(cfg) 156 | except KeyboardInterrupt: 157 | try: 158 | sys.exit(0) 159 | except SystemExit: 160 | os._exit(0) 161 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from abc import abstractmethod 4 | import numpy as np 5 | import multiprocessing 6 | import tifffile 7 | from utils import experiment_manager, helpers 8 | from utils.experiment_manager import CfgNode 9 | import cv2 10 | import torch 11 | from torchvision import transforms 12 | import numpy as np 13 | from scipy.ndimage import gaussian_filter 14 | 15 | 16 | def create_train_dataset(cfg: CfgNode, run_type: str) -> torch.utils.data.Dataset: 17 | if cfg.DATASET.NAME == 'sn7': 18 | return TrainSpaceNet7Dataset(cfg, run_type=run_type) 19 | elif cfg.DATASET.NAME == 'wusu': 20 | return TrainWUSUDataset(cfg, run_type=run_type) 21 | else: 22 | raise Exception('Unknown train dataset!') 23 | 24 | 25 | def create_eval_dataset(cfg: CfgNode, run_type: str, site: str = None, tiling: int = None) -> torch.utils.data.Dataset: 26 | if cfg.DATASET.NAME == 'sn7': 27 | return EvalSpaceNet7Dataset(cfg, run_type, aoi_id=site, tiling=tiling) 28 | elif cfg.DATASET.NAME == 'wusu': 29 | if run_type == 'test': 30 | return EvalTestWUSUDataset(cfg, site=site) 31 | else: 32 | return EvalWUSUDataset(cfg, run_type=run_type, tiling=tiling) 33 | else: 34 | raise Exception('Unknown train dataset!') 35 | 36 | 37 | class AbstractSpaceNet7Dataset(torch.utils.data.Dataset): 38 | 39 | def __init__(self, cfg: CfgNode): 40 | super().__init__() 41 | self.cfg = cfg 42 | self.root_path = Path(cfg.PATHS.DATASET) 43 | self.name = cfg.DATASET.NAME 44 | 45 | self.include_alpha = cfg.DATALOADER.INCLUDE_ALPHA 46 | self.pad = cfg.DATALOADER.PAD_BORDERS 47 | 48 | @abstractmethod 49 | def __getitem__(self, index: int) -> dict: 50 | pass 51 | 52 | @abstractmethod 53 | def __len__(self) -> int: 54 | pass 55 | 56 | def load_planet_mosaic(self, aoi_id: str, dataset: str, year: int, month: int) -> np.ndarray: 57 | folder = self.root_path / dataset / aoi_id / 'images_masked' 58 | file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}.tif' 59 | img = tifffile.imread(str(file)) 60 | img = img / 255 61 | # 4th band (last oen) 
is alpha band 62 | if not self.include_alpha: 63 | img = img[:, :, :-1] 64 | m, n, _ = img.shape 65 | if self.pad and (m != 1024 or n != 1024): 66 | # https://www.geeksforgeeks.org/python-opencv-cv2-copymakeborder-method/ 67 | img = cv2.copyMakeBorder(img, 0, 1024 - m, 0, 1024 - n, borderType=cv2.BORDER_REPLICATE) 68 | return img.astype(np.float32) 69 | 70 | def load_building_label(self, aoi_id: str, year: int, month: int) -> np.ndarray: 71 | folder = self.root_path / 'train' / aoi_id / 'labels_raster' 72 | file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_Buildings.tif' 73 | label = tifffile.imread(str(file)) 74 | m, n = label.shape 75 | if self.pad and (m != 1024 or n != 1024): 76 | label = cv2.copyMakeBorder(label, 0, 1024 - m, 0, 1024 - n, borderType=cv2.BORDER_REPLICATE) 77 | label = label[:, :, None] 78 | label = label > 0 79 | return label.astype(np.float32) 80 | 81 | def load_change_label(self, aoi_id: str, year_t1: int, month_t1: int, year_t2: int, month_t2) -> np.ndarray: 82 | building_t1 = self.load_building_label(aoi_id, year_t1, month_t1) 83 | building_t2 = self.load_building_label(aoi_id, year_t2, month_t2) 84 | change = np.logical_and(building_t1 == 0, building_t2 == 1) 85 | return change.astype(np.float32) 86 | 87 | def load_mask(self, aoi_id: str, year: int, month: int) -> np.ndarray: 88 | folder = self.root_path / 'train' / aoi_id / 'labels_raster' 89 | file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_mask.tif' 90 | mask = tifffile.imread(str(file)) 91 | return mask.astype(np.int8) 92 | 93 | def get_aoi_ids(self) -> list: 94 | return list(set([s['aoi_id'] for s in self.samples])) 95 | 96 | def __len__(self): 97 | return self.length 98 | 99 | def __str__(self): 100 | return f'Dataset with {self.length} samples.' 
101 | 102 | 103 | class AbstractWUSUDataset(torch.utils.data.Dataset): 104 | 105 | def __init__(self, cfg: experiment_manager.CfgNode): 106 | super().__init__() 107 | self.cfg = cfg 108 | self.root_path = Path(cfg.PATHS.DATASET) 109 | self.timestamps = [15, 16, 18] 110 | self.T = 3 111 | self.name = cfg.DATASET.NAME 112 | 113 | @abstractmethod 114 | def __getitem__(self, index: int) -> dict: 115 | pass 116 | 117 | @abstractmethod 118 | def __len__(self) -> int: 119 | pass 120 | 121 | def load_gf2_img(self, site: str, dataset: str, index: int, year: int) -> np.ndarray: 122 | file = self.root_path / dataset / site / 'imgs' / f'{site}{year}_{index}.tif' 123 | img = tifffile.imread(str(file)) 124 | img = img / 255 125 | return img.astype(np.float32) 126 | 127 | def load_lulc_label(self, site: str, dataset: str, index: int, year: int) -> np.ndarray: 128 | file = self.root_path / dataset / site / 'class' / f'{site}{year}_{index}.tif' 129 | lulc_label = tifffile.imread(str(file)) 130 | lulc_label = lulc_label[:, :, None] 131 | return lulc_label 132 | 133 | def load_building_label(self, site: str, dataset: str, index: int, year: int) -> np.ndarray: 134 | lulc = self.load_lulc_label(site, dataset, index, year) 135 | buildings = np.logical_or(lulc == 2, lulc == 3) 136 | return buildings.astype(np.float32) 137 | 138 | def load_building_change_label(self, site: str, dataset: str, index: int, year_t1: int, year_t2: int) -> np.ndarray: 139 | buildings_t1 = self.load_building_label(site, dataset, index, year_t1) 140 | buildings_t2 = self.load_building_label(site, dataset, index, year_t2) 141 | change = np.logical_and(buildings_t1 == 0, buildings_t2 == 1) 142 | return change.astype(np.float32) 143 | 144 | 145 | class TrainWUSUDataset(AbstractWUSUDataset): 146 | 147 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, no_augmentations: bool = False): 148 | super().__init__(cfg) 149 | 150 | # handling transformations of data 151 | self.no_augmentations = no_augmentations 152 | self.transform = compose_transformations(cfg, no_augmentations) 153 | 154 | self.dataset = 'test' if run_type == 'test' else 'train' 155 | if run_type == 'test': 156 | self.samples = helpers.load_json(self.root_path / f'samples_test.json') 157 | elif run_type == 'val' or run_type == 'train': 158 | self.samples = helpers.load_json(self.root_path / f'samples_train.json') 159 | self.samples = [s for s in self.samples if s['split'] == run_type] 160 | else: 161 | raise Exception('Unkown run type!') 162 | 163 | manager = multiprocessing.Manager() 164 | self.samples = manager.list(self.samples) 165 | 166 | self.length = len(self.samples) 167 | 168 | def __getitem__(self, index): 169 | 170 | sample = self.samples[index] 171 | site, index = sample['site'], sample['index'] 172 | 173 | images = [self.load_gf2_img(site, self.dataset, index, year) for year in self.timestamps] 174 | labels = [self.load_building_label(site, self.dataset, index, year) for year in self.timestamps] 175 | 176 | images, labels = self.transform((np.stack(images), np.stack(labels))) 177 | 178 | item = { 179 | 'x': images, 180 | 'y': labels, 181 | 'site': site, 182 | 'index': index, 183 | } 184 | return item 185 | 186 | def __len__(self): 187 | return self.length 188 | 189 | def __str__(self): 190 | return f'Train {self.name} dataset with {self.length} samples.' 
191 | 192 | 193 | class EvalWUSUDataset(AbstractWUSUDataset): 194 | 195 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, tiling: int = None, add_padding: bool = False, 196 | index: int = None): 197 | super().__init__(cfg) 198 | 199 | self.tiling = tiling 200 | self.add_padding = add_padding 201 | 202 | # handling transformations of data 203 | self.transform = compose_transformations(cfg, no_augmentations=True) 204 | 205 | self.dataset = 'test' if run_type == 'test' else 'train' 206 | if run_type == 'test': 207 | samples = helpers.load_json(self.root_path / f'samples_test.json') 208 | elif run_type == 'val' or run_type == 'train': 209 | samples = helpers.load_json(self.root_path / f'samples_train.json') 210 | samples = [s for s in samples if s['split'] == run_type] 211 | else: 212 | raise Exception('Unkown run type!') 213 | 214 | if index is not None: 215 | samples = [s for s in samples if s['index'] == index] 216 | 217 | if tiling is None: 218 | self.tiling = 512 219 | 220 | self.samples = [] 221 | for sample in samples: 222 | for i in range(0, 512, self.tiling): 223 | for j in range(0, 512, self.tiling): 224 | tile_sample = { 225 | 'site': sample['site'], 226 | 'index': sample['index'], 227 | 'split': sample['split'], 228 | 'i': i, 229 | 'j': j, 230 | } 231 | self.samples.append(tile_sample) 232 | 233 | manager = multiprocessing.Manager() 234 | self.samples = manager.list(self.samples) 235 | 236 | self.length = len(self.samples) 237 | 238 | def __getitem__(self, index): 239 | 240 | sample = self.samples[index] 241 | site, index, i, j = sample['site'], sample['index'], sample['i'], sample['j'] 242 | 243 | images = [self.load_gf2_img(site, self.dataset, index, year) for year in self.timestamps] 244 | images = np.stack(images) 245 | if self.add_padding: 246 | images = np.pad(images, ((0, 0), (self.tiling, self.tiling), (self.tiling, self.tiling), (0, 0)), 247 | mode='reflect') 248 | i_min, j_min = i, j 249 | i_max, j_max = i + 3 * self.tiling, j + 3 * self.tiling 250 | images = images[:, i_min:i_max, j_min:j_max] 251 | else: 252 | images = images[:, i:i + self.tiling, j:j + self.tiling] 253 | 254 | labels = [self.load_building_label(site, self.dataset, index, year) for year in self.timestamps] 255 | labels = np.stack(labels)[:, i:i + self.tiling, j:j + self.tiling] 256 | 257 | images, labels = self.transform((images, labels)) 258 | 259 | item = { 260 | 'x': images, 261 | 'y': labels, 262 | 'site': site, 263 | 'index': index, 264 | 'i': i, 265 | 'j': j, 266 | } 267 | 268 | return item 269 | 270 | def __len__(self): 271 | return self.length 272 | 273 | def __str__(self): 274 | return f'Eval {self.name} dataset with {self.length} samples.' 
275 | 276 | 277 | class EvalTestWUSUDataset(AbstractWUSUDataset): 278 | 279 | def __init__(self, cfg: experiment_manager.CfgNode, site: str = None): 280 | super().__init__(cfg) 281 | 282 | metadata_file = self.root_path / 'metadata_test.json' 283 | self.metadata = helpers.load_json(metadata_file) 284 | 285 | self.sites = ['JA', 'HS'] if site is None else [site] 286 | 287 | self.transform = compose_transformations(cfg, no_augmentations=True) 288 | self.crop_size = cfg.AUGMENTATION.CROP_SIZE 289 | 290 | self.dataset = 'test' 291 | self.samples = [] 292 | for site in self.sites: 293 | tiles = self.metadata[site]['tiles'] 294 | tile_size = int(self.metadata[site]['tile_size'] - self.metadata[site]['overlap']) 295 | for tile in tiles: 296 | i_tile, j_tile = tile['i_tile'], tile['j_tile'] 297 | m_max, n_max = tile_size, tile_size 298 | if tile['edge_tile']: 299 | m_edge_tile, n_edge_tile = self.metadata[site]['m_edge_tile'], self.metadata[site]['n_edge_tile'] 300 | m_edge_tile = m_edge_tile - m_edge_tile % self.crop_size 301 | n_edge_tile = n_edge_tile - n_edge_tile % self.crop_size 302 | if tile['row_end']: 303 | m_max = m_edge_tile 304 | if tile['col_end']: 305 | n_max = n_edge_tile 306 | 307 | for i in range(0, m_max, self.crop_size): 308 | for j in range(0, n_max, self.crop_size): 309 | tile_sample = { 310 | 'site': site, 311 | 'index': tile['index'], 312 | 'split': 'test', 313 | 'i_crop': i, 314 | 'j_crop': j, 315 | 'i_tile': i_tile, 316 | 'j_tile': j_tile, 317 | } 318 | self.samples.append(tile_sample) 319 | 320 | manager = multiprocessing.Manager() 321 | self.samples = manager.list(self.samples) 322 | self.metadata = manager.dict(self.metadata) 323 | self.length = len(self.samples) 324 | 325 | def __getitem__(self, sample_index): 326 | 327 | sample = self.samples[sample_index] 328 | site, index = sample['site'], sample['index'] 329 | 330 | images = np.stack([self.load_gf2_img(site, self.dataset, index, year) for year in self.timestamps]) 331 | i_crop, j_crop = sample['i_crop'], sample['j_crop'] 332 | images = images[:, i_crop:i_crop + self.crop_size, j_crop:j_crop + self.crop_size] 333 | 334 | labels = np.stack([self.load_building_label(site, self.dataset, index, year) for year in self.timestamps]) 335 | labels = labels[:, i_crop:i_crop + self.crop_size, j_crop:j_crop + self.crop_size] 336 | 337 | images, labels = self.transform((images, labels)) 338 | 339 | i_tile, j_tile = sample['i_tile'], sample['j_tile'] 340 | tile_size = int(self.metadata[site]['tile_size'] - self.metadata[site]['overlap']) 341 | i_img, j_img = i_tile * tile_size + i_crop, j_tile * tile_size + j_crop 342 | 343 | item = { 344 | 'x': images, 345 | 'y': labels, 346 | 'site': site, 347 | 'index': index, 348 | 'i_img': i_img, 349 | 'j_img': j_img, 350 | } 351 | 352 | return item 353 | 354 | def get_img_dims(self, site: str) -> tuple: 355 | tile_size = int(self.metadata[site]['tile_size'] - self.metadata[site]['overlap']) 356 | m_tile, n_tile = self.metadata[site]['m_tile'], self.metadata[site]['n_tile'] 357 | 358 | m_edge_tile, n_edge_tile = self.metadata[site]['m_edge_tile'], self.metadata[site]['n_edge_tile'] 359 | m_edge_tile = m_edge_tile - m_edge_tile % self.crop_size 360 | n_edge_tile = n_edge_tile - n_edge_tile % self.crop_size 361 | 362 | m_img = (m_tile - 1) * tile_size + m_edge_tile 363 | n_img = (n_tile - 1) * tile_size + n_edge_tile 364 | return m_img, n_img 365 | 366 | def __len__(self): 367 | return self.length 368 | 369 | def __str__(self): 370 | return f'Eval {self.name} dataset with {self.length} 
samples.' 371 | 372 | 373 | class TrainSpaceNet7Dataset(AbstractSpaceNet7Dataset): 374 | 375 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, no_augmentations: bool = False, 376 | disable_multiplier: bool = False): 377 | super().__init__(cfg) 378 | 379 | self.T = cfg.DATALOADER.TIMESERIES_LENGTH 380 | self.include_change_label = cfg.DATALOADER.INCLUDE_CHANGE_LABEL 381 | 382 | # handling transformations of data 383 | self.no_augmentations = no_augmentations 384 | self.transform = compose_transformations(cfg, no_augmentations) 385 | 386 | self.metadata = helpers.load_json(self.root_path / f'metadata_siamesessl.json') 387 | self.aoi_ids = list(cfg.DATASET.TRAIN_IDS) 388 | assert (len(self.aoi_ids) == 60) 389 | 390 | # split 391 | self.run_type = run_type 392 | self.i_split = cfg.DATALOADER.I_SPLIT 393 | self.j_split = cfg.DATALOADER.J_SPLIT 394 | 395 | if not disable_multiplier: 396 | self.aoi_ids = self.aoi_ids * cfg.DATALOADER.TRAINING_MULTIPLIER 397 | 398 | manager = multiprocessing.Manager() 399 | self.aoi_ids = manager.list(self.aoi_ids) 400 | self.metadata = manager.dict(self.metadata) 401 | 402 | self.length = len(self.aoi_ids) 403 | 404 | def __getitem__(self, index): 405 | 406 | aoi_id = self.aoi_ids[index] 407 | 408 | timestamps = [ts for ts in self.metadata[aoi_id] if not ts['mask']] 409 | 410 | t_values = sorted(np.random.randint(0, len(timestamps), size=self.T)) 411 | timestamps = sorted([timestamps[t] for t in t_values], key=lambda ts: int(ts['year']) * 12 + int(ts['month'])) 412 | 413 | images = [self.load_planet_mosaic(aoi_id, ts['dataset'], ts['year'], ts['month']) for ts in timestamps] 414 | labels = [self.load_building_label(aoi_id, ts['year'], ts['month']) for ts in timestamps] 415 | images = [self.apply_split(img) for img in images] 416 | labels = [self.apply_split(label) for label in labels] 417 | 418 | images, labels = self.transform((np.stack(images), np.stack(labels))) 419 | 420 | item = { 421 | 'x': images, 422 | 'y': labels, 423 | 'aoi_id': aoi_id, 424 | 'dates': [(int(ts['year']), int(ts['month'])) for ts in timestamps], 425 | } 426 | 427 | if self.include_change_label: 428 | labels_ch = [] 429 | for t in range(len(timestamps) - 1): 430 | labels_ch.append(torch.ne(labels[t + 1], labels[t])) 431 | labels_ch.append(torch.ne(labels[-1], labels[0])) 432 | item['y_ch'] = torch.stack(labels_ch) 433 | 434 | return item 435 | 436 | def apply_split(self, img: np.ndarray): 437 | if self.run_type == 'train': 438 | return img[:self.i_split] 439 | elif self.run_type == 'val': 440 | return img[self.i_split:, :self.j_split] 441 | elif self.run_type == 'test': 442 | return img[self.i_split:, self.j_split:] 443 | else: 444 | raise Exception('Unkown split!') 445 | 446 | def __len__(self): 447 | return self.length 448 | 449 | def __str__(self): 450 | return f'Train {self.name} dataset with {self.length} samples.' 
451 | 452 | 453 | class EvalSpaceNet7Dataset(AbstractSpaceNet7Dataset): 454 | 455 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, tiling: int = None, aoi_id: str = None, 456 | add_padding: bool = False): 457 | super().__init__(cfg) 458 | 459 | self.T = cfg.DATALOADER.TIMESERIES_LENGTH 460 | self.include_change_label = cfg.DATALOADER.INCLUDE_CHANGE_LABEL 461 | self.tiling = tiling if tiling is not None else 1024 462 | self.eval_train_threshold = cfg.DATALOADER.EVAL_TRAIN_THRESHOLD 463 | self.add_padding = add_padding 464 | 465 | # handling transformations of data 466 | self.transform = compose_transformations(cfg, no_augmentations=True) 467 | 468 | self.metadata = helpers.load_json(self.root_path / f'metadata_conturbancd.json') 469 | 470 | if aoi_id is None: 471 | self.aoi_ids = list(cfg.DATASET.TRAIN_IDS) 472 | assert (len(self.aoi_ids) == 60) 473 | else: 474 | self.aoi_ids = [aoi_id] 475 | 476 | # split 477 | self.run_type = run_type 478 | self.i_split = cfg.DATALOADER.I_SPLIT 479 | self.j_split = cfg.DATALOADER.J_SPLIT 480 | 481 | self.min_m, self.max_m = 0, 1024 482 | self.min_n, self.max_n = 0, 1024 483 | if run_type == 'train': 484 | self.max_m = self.i_split 485 | self.m = self.max_m 486 | else: 487 | assert (run_type == 'val' or run_type == 'test') 488 | self.min_m = self.i_split 489 | if run_type == 'val': 490 | self.max_n = self.j_split 491 | if run_type == 'test': 492 | self.min_n = self.j_split 493 | 494 | self.samples = [] 495 | for aoi_id in self.aoi_ids: 496 | for i in range(self.min_m, self.max_m, self.tiling): 497 | for j in range(self.min_n, self.max_n, self.tiling): 498 | self.samples.append((aoi_id, (i, j))) 499 | 500 | self.m, self.n = self.max_m - self.min_m, self.max_n - self.min_n 501 | 502 | manager = multiprocessing.Manager() 503 | self.aoi_ids = manager.list(self.aoi_ids) 504 | self.metadata = manager.dict(self.metadata) 505 | 506 | self.length = len(self.samples) 507 | 508 | def __getitem__(self, index): 509 | 510 | aoi_id, (i, j) = self.samples[index] 511 | 512 | timestamps = [ts for ts in self.metadata[aoi_id] if not ts['mask']] 513 | t_values = list(np.linspace(0, len(timestamps), self.T, endpoint=False, dtype=int)) 514 | timestamps = sorted([timestamps[t] for t in t_values], key=lambda ts: int(ts['year']) * 12 + int(ts['month'])) 515 | 516 | images = [self.load_planet_mosaic(ts['aoi_id'], ts['dataset'], ts['year'], ts['month']) for ts in timestamps] 517 | images = np.stack(images) 518 | if self.add_padding: 519 | # images = np.pad(images, ((0, 0), (self.tiling, self.tiling), (self.tiling, self.tiling), (0, 0)), 520 | # mode='constant', constant_values=0) 521 | images = np.pad(images, ((0, 0), (self.tiling, self.tiling), (self.tiling, self.tiling), (0, 0)), 522 | mode='reflect') 523 | i_min, j_min = i, j 524 | i_max, j_max = i + 3 * self.tiling, j + 3 * self.tiling 525 | images = images[:, i_min:i_max, j_min:j_max] 526 | else: 527 | images = images[:, i:i + self.tiling, j:j + self.tiling] 528 | 529 | labels = [self.load_building_label(aoi_id, ts['year'], ts['month']) for ts in timestamps] 530 | labels = np.stack(labels)[:, i:i + self.tiling, j:j + self.tiling] 531 | 532 | images, labels = self.transform((images, labels)) 533 | 534 | item = { 535 | 'x': images, 536 | 'y': labels, 537 | 'aoi_id': aoi_id, 538 | 'i': i - self.min_m, 539 | 'j': j - self.min_n, 540 | 'dates': [(int(ts['year']), int(ts['month'])) for ts in timestamps], 541 | } 542 | 543 | return item 544 | 545 | def __len__(self): 546 | return self.length 547 | 548 | def 
__str__(self): 549 | return f'Eval {self.name} dataset with {self.length} samples.' 550 | 551 | 552 | def compose_transformations(cfg, no_augmentations: bool): 553 | if no_augmentations: 554 | return transforms.Compose([Numpy2Torch()]) 555 | 556 | transformations = [] 557 | 558 | # cropping 559 | if cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'none': 560 | transformations.append(UniformCrop(cfg.AUGMENTATION.CROP_SIZE)) 561 | elif cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'change': 562 | transformations.append(ImportanceRandomCrop(cfg.AUGMENTATION.CROP_SIZE, 'change')) 563 | elif cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'semantic': 564 | transformations.append(ImportanceRandomCrop(cfg.AUGMENTATION.CROP_SIZE, 'semantic')) 565 | else: 566 | raise Exception('Unknown oversampling type!') 567 | 568 | if cfg.AUGMENTATION.RANDOM_FLIP: 569 | transformations.append(RandomFlip()) 570 | 571 | if cfg.AUGMENTATION.RANDOM_ROTATE: 572 | transformations.append(RandomRotate()) 573 | 574 | if cfg.AUGMENTATION.COLOR_BLUR: 575 | transformations.append(RandomColorBlur()) 576 | 577 | transformations.append(Numpy2Torch()) 578 | 579 | if cfg.AUGMENTATION.COLOR_JITTER: 580 | transformations.append(RandomColorJitter(n_bands=cfg.MODEL.IN_CHANNELS)) 581 | 582 | return transforms.Compose(transformations) 583 | 584 | 585 | class Numpy2Torch(object): 586 | def __call__(self, args): 587 | images, labels = args 588 | images_tensor = torch.Tensor(images).permute(0, 3, 1, 2) 589 | labels_tensor = torch.Tensor(labels).permute(0, 3, 1, 2) 590 | return images_tensor, labels_tensor 591 | 592 | 593 | class RandomFlip(object): 594 | def __call__(self, args): 595 | images, labels = args 596 | horizontal_flip = np.random.choice([True, False]) 597 | vertical_flip = np.random.choice([True, False]) 598 | 599 | if horizontal_flip: 600 | images = np.flip(images, axis=2) 601 | labels = np.flip(labels, axis=2) 602 | 603 | if vertical_flip: 604 | images = np.flip(images, axis=1) 605 | labels = np.flip(labels, axis=1) 606 | 607 | images = images.copy() 608 | labels = labels.copy() 609 | 610 | return images, labels 611 | 612 | 613 | class RandomRotate(object): 614 | def __call__(self, args): 615 | images, labels = args 616 | k = np.random.randint(1, 4) # number of 90 degree rotations (1, 2, or 3) 617 | images = np.rot90(images, k, axes=(1, 2)).copy() 618 | labels = np.rot90(labels, k, axes=(1, 2)).copy() 619 | return images, labels 620 | 621 | 622 | class RandomColorBlur(object): 623 | def __call__(self, args): 624 | images, labels = args 625 | for t in range(images.shape[0]): 626 | blurred_image = gaussian_filter(images[t], sigma=np.random.rand() / 2) 627 | images[t] = blurred_image 628 | return images, labels 629 | 630 | 631 | class RandomColorJitter(object): 632 | def __init__(self, brightness: float = 0.3, contrast: float = 0.3, saturation: float = 0.3, hue: float = 0.3, 633 | n_bands: int = 3): 634 | self.cj = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 635 | self.n_bands = n_bands 636 | 637 | def __call__(self, args): 638 | images, labels = args 639 | for t in range(images.shape[0]): 640 | if self.n_bands <= 3: 641 | images[t] = self.cj(images[t]) 642 | else: 643 | # Separate the bands 644 | bands = [images[t, i] for i in range(self.n_bands)] 645 | 646 | # Apply ColorJitter to each band 647 | jittered_bands = [self.cj(band.unsqueeze(0)) for band in bands] 648 | 649 | # Stack the bands back together 650 | jittered_image = torch.cat(jittered_bands, dim=0) 651 | images[t] = jittered_image 652 | 
return images, labels 653 | 654 | 655 | # Performs uniform cropping on images 656 | class UniformCrop(object): 657 | def __init__(self, crop_size: int): 658 | self.crop_size = crop_size 659 | 660 | def random_crop(self, args): 661 | images, labels = args 662 | _, height, width, _ = labels.shape 663 | crop_limit_x = width - self.crop_size 664 | crop_limit_y = height - self.crop_size 665 | x = np.random.randint(0, crop_limit_x) 666 | y = np.random.randint(0, crop_limit_y) 667 | 668 | images_crop = images[:, y:y + self.crop_size, x:x + self.crop_size] 669 | labels_crop = labels[:, y:y + self.crop_size, x:x + self.crop_size] 670 | return images_crop, labels_crop 671 | 672 | def __call__(self, args): 673 | images, labels = self.random_crop(args) 674 | return images, labels 675 | 676 | 677 | class ImportanceRandomCrop(UniformCrop): 678 | def __init__(self, crop_size: int, oversampling_type: str): 679 | super().__init__(crop_size) 680 | self.oversampling_type = oversampling_type 681 | 682 | def __call__(self, args): 683 | 684 | sample_size = 20  # number of candidate crops drawn before importance sampling 685 | balancing_factor = 5  # keeps crops without change/built-up pixels at a non-zero probability 686 | 687 | random_crops = [self.random_crop(args) for _ in range(sample_size)] 688 | 689 | if self.oversampling_type == 'change': 690 | crop_weights = np.array( 691 | [np.not_equal(crop_label[-1], crop_label[0]).sum() for _, crop_label in random_crops] 692 | ) + balancing_factor 693 | elif self.oversampling_type == 'semantic': 694 | crop_weights = np.array([crop_label.sum() for _, crop_label in random_crops]) + balancing_factor 695 | else: 696 | raise Exception('Unknown oversampling type!') 697 | 698 | crop_weights = crop_weights / crop_weights.sum() 699 | 700 | sample_idx = np.random.choice(sample_size, p=crop_weights) 701 | img, label = random_crops[sample_idx] 702 | 703 | return img, label 704 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | from torch.utils import data as torch_data 5 | from torch import Tensor 6 | 7 | from utils import datasets, helpers 8 | import wandb 9 | 10 | EPS = 10e-05 11 | 12 | 13 | class AbstractMeasurer(abc.ABC): 14 | def __init__(self, threshold: float = 0.5, name: str = None): 15 | 16 | self.threshold = threshold 17 | self.name = name 18 | 19 | # urban mapping 20 | self.TP_seg_cont = self.TN_seg_cont = self.FP_seg_cont = self.FN_seg_cont = 0 21 | self.TP_seg_fl = self.TN_seg_fl = self.FP_seg_fl = self.FN_seg_fl = 0 22 | 23 | # urban change | ch_cont -> continuous change | ch_fl -> first to last change 24 | self.TP_ch_cont = self.TN_ch_cont = self.FP_ch_cont = self.FN_ch_cont = 0 25 | self.TP_ch_fl = self.TN_ch_fl = self.FP_ch_fl = self.FN_ch_fl = 0 26 | 27 | def add_sample(self, *args, **kwargs): 28 | raise NotImplementedError("add_sample method must be implemented in the subclass.") 29 | 30 | def _update_metrics(self, y: Tensor, y_hat: Tensor, attr_name: str, mask: Tensor = None): 31 | y = y.bool() 32 | y_hat = y_hat > self.threshold 33 | 34 | tp_attr = f'TP_{attr_name}' 35 | tn_attr = f'TN_{attr_name}' 36 | fp_attr = f'FP_{attr_name}' 37 | fn_attr = f'FN_{attr_name}' 38 | 39 | tp = (y & y_hat).float() 40 | tn = (~y & ~y_hat).float() 41 | fp = (y_hat & ~y).float() 42 | fn = (~y_hat & y).float() 43 | 44 | if mask is not None: 45 | tp[mask] = float('nan') 46 | tn[mask] = float('nan') 47 | fp[mask] = float('nan') 48 | fn[mask] = float('nan') 49 | 50 | setattr(self, tp_attr, getattr(self, tp_attr) + torch.nansum(tp).float().item())
51 | setattr(self, tn_attr, getattr(self, tn_attr) + torch.nansum(tn).float().item()) 52 | setattr(self, fp_attr, getattr(self, fp_attr) + torch.nansum(fp).float().item()) 53 | setattr(self, fn_attr, getattr(self, fn_attr) + torch.nansum(fn).float().item()) 54 | 55 | 56 | class MultiTaskMeasurer(AbstractMeasurer): 57 | def __init__(self, threshold: float = 0.5, name: str = None): 58 | super().__init__(threshold, name) 59 | 60 | def add_sample(self, y_seg: Tensor, y_hat_seg: Tensor, y_ch: Tensor, y_hat_ch: Tensor, mask: Tensor = None): 61 | 62 | # urban mapping 63 | if y_seg is not None: 64 | self._update_metrics(y_seg, y_hat_seg, 'seg_cont', mask) 65 | self._update_metrics(y_seg[:, [0, -1]], y_hat_seg[:, [0, -1]], 'seg_fl', mask) 66 | 67 | # urban change 68 | if y_hat_ch.size(1) > 1: 69 | self._update_metrics(y_ch[:, :-1], y_hat_ch[:, :-1], 'ch_cont', mask) 70 | self._update_metrics(y_ch[:, -1], y_hat_ch[:, -1], 'ch_fl', mask) 71 | 72 | 73 | def run_quantitative_evaluation(net, cfg, device, run_type: str, enable_mti: bool = False, 74 | mti_edge_setting: str = 'dense') -> MultiTaskMeasurer: 75 | tile_size = cfg.AUGMENTATION.CROP_SIZE 76 | ds = datasets.create_eval_dataset(cfg, run_type, tiling=tile_size) 77 | 78 | net.to(device) 79 | net.eval() 80 | 81 | m = MultiTaskMeasurer() 82 | edges_cyclic = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, 'cyclic') 83 | edges_mti = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, mti_edge_setting) 84 | 85 | batch_size = 1 if enable_mti else cfg.TRAINER.BATCH_SIZE 86 | dataloader = torch_data.DataLoader(ds, batch_size=batch_size, num_workers=0, shuffle=False, drop_last=False) 87 | 88 | for step, item in enumerate(dataloader): 89 | x, y_seg = item['x'].to(device), item['y'] 90 | y_ch = helpers.get_ch(y_seg, edges_cyclic) 91 | with torch.no_grad(): 92 | if enable_mti: 93 | o_seg = net.module.inference(x, edges_mti) 94 | o_ch = helpers.get_ch(o_seg, edges_cyclic) 95 | else: 96 | logits_ch, logits_seg = net(x, edges_cyclic) 97 | o_ch, o_seg = torch.sigmoid(logits_ch).detach(), torch.sigmoid(logits_seg).detach() 98 | m.add_sample(y_seg.cpu(), o_seg.cpu(), y_ch.cpu(), o_ch.cpu()) 99 | 100 | return m 101 | 102 | 103 | def model_evaluation(net, cfg, device, run_type: str, epoch: float, step: int) -> float: 104 | m = run_quantitative_evaluation(net, cfg, device, run_type) 105 | 106 | f1_seg_cont = f1_score(m.TP_seg_cont, m.FP_seg_cont, m.FN_seg_cont) 107 | f1_seg_fl = f1_score(m.TP_seg_fl, m.FP_seg_fl, m.FN_seg_fl) 108 | f1_ch_cont = f1_score(m.TP_ch_cont, m.FP_ch_cont, m.FN_ch_cont) 109 | f1_ch_fl = f1_score(m.TP_ch_fl, m.FP_ch_fl, m.FN_ch_fl) 110 | f1 = (f1_seg_cont + f1_seg_fl + f1_ch_cont + f1_ch_fl) / 4 111 | 112 | wandb.log({ 113 | f'{run_type} f1': f1, 114 | f'{run_type} f1 seg cont': f1_seg_cont, 115 | f'{run_type} f1 seg fl': f1_seg_fl, 116 | f'{run_type} f1 ch cont': f1_ch_cont, 117 | f'{run_type} f1 ch fl': f1_ch_fl, 118 | 'step': step, 'epoch': epoch, 119 | }) 120 | 121 | return f1 122 | 123 | 124 | def precision(tp: int, fp: int) -> float: 125 | return tp / (tp + fp + EPS) 126 | 127 | 128 | def recall(tp: int, fn: int) -> float: 129 | return tp / (tp + fn + EPS) 130 | 131 | 132 | def f1_score(tp: int, fp: int, fn: int) -> float: 133 | p = precision(tp, fp) 134 | r = recall(tp, fn) 135 | return (2 * p * r) / (p + r + EPS) 136 | 137 | 138 | def iou(tp: int, fp: int, fn: int) -> float: 139 | return tp / (tp + fp + fn + EPS) 140 | -------------------------------------------------------------------------------- /utils/experiment_manager.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from argparse import ArgumentParser 4 | from tabulate import tabulate 5 | from collections import OrderedDict 6 | import yaml 7 | from fvcore.common.config import CfgNode as _CfgNode 8 | from pathlib import Path 9 | 10 | 11 | class CfgNode(_CfgNode): 12 | """ 13 | The same as `fvcore.common.config.CfgNode`, but different in: 14 | 15 | 1. Use unsafe yaml loading by default. 16 | Note that this may lead to arbitrary code execution: you must not 17 | load a config file from untrusted sources before manually inspecting 18 | the content of the file. 19 | 2. Support config versioning. 20 | When attempting to merge an old config, it will convert the old config automatically. 21 | 22 | """ 23 | 24 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 25 | # Always allow merging new configs 26 | self.__dict__[CfgNode.NEW_ALLOWED] = True 27 | super(CfgNode, self).__init__(init_dict, key_list, True) 28 | 29 | # Note that the default value of allow_unsafe is changed to True 30 | def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None: 31 | loaded_cfg = _CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) 32 | loaded_cfg = type(self)(loaded_cfg) 33 | 34 | # defaults.py needs to import CfgNode 35 | self.merge_from_other_cfg(loaded_cfg) 36 | 37 | 38 | def new_config(): 39 | ''' 40 | Creates a new config based on the default config file 41 | :return: 42 | ''' 43 | 44 | C = CfgNode() 45 | 46 | C.CONFIG_DIR = 'config/' 47 | 48 | C.PATHS = CfgNode() 49 | C.TRAINER = CfgNode() 50 | C.MODEL = CfgNode() 51 | C.DATALOADER = CfgNode() 52 | C.AUGMENTATIONS = CfgNode() 53 | C.CONSISTENCY_TRAINER = CfgNode() 54 | C.DATASETS = CfgNode() 55 | 56 | return C.clone() 57 | 58 | 59 | def setup_cfg(args): 60 | cfg = new_config() 61 | cfg.merge_from_file(f'configs/{args.config_file}.yaml') 62 | cfg.merge_from_list(args.opts) 63 | cfg.NAME = args.config_file 64 | cfg.PATHS.ROOT = str(Path.cwd()) 65 | assert (Path(args.output_dir).exists()) 66 | cfg.PATHS.OUTPUT = args.output_dir 67 | assert (Path(args.dataset_dir).exists()) 68 | cfg.PATHS.DATASET = args.dataset_dir 69 | return cfg 70 | 71 | 72 | def setup_cfg_manual(config_name: str, root_dir: Path, output_dir: Path, dataset_dir: Path): 73 | cfg = new_config() 74 | cfg.merge_from_file(root_dir / f'configs/{config_name}.yaml') 75 | cfg.NAME = config_name 76 | cfg.PATHS.ROOT = str(Path.cwd()) 77 | assert output_dir.exists() 78 | cfg.PATHS.OUTPUT = str(output_dir) 79 | assert dataset_dir.exists() 80 | cfg.PATHS.DATASET = str(dataset_dir) 81 | return cfg 82 | 83 | 84 | # loading cfg 85 | def load_cfg(config_name: str): 86 | cfg = new_config() 87 | cfg_file = Path.cwd() / 'configs' / f'{config_name}.yaml' 88 | cfg.merge_from_file(str(cfg_file)) 89 | cfg.NAME = config_name 90 | return cfg 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from itertools import combinations 4 | from pathlib import Path 5 | from utils.experiment_manager import CfgNode 6 | import json 7 | from typing import Sequence, Tuple 8 | 9 | 10 | def get_aoi_ids(cfg: CfgNode) -> Sequence[str]: 11 | aoi_ids = list(cfg.DATASET.TRAIN_IDS) 12 | return aoi_ids 13 | 14 | 15 | def get_edges(n: int, edge_type: str) -> Sequence[Tuple[int, int]]: 
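    # For example, the edge settings handled below yield, for n = 4:
    #   'adjacent'  -> [(0, 1), (1, 2), (2, 3)]
    #   'cyclic'    -> [(0, 1), (1, 2), (2, 3), (0, 3)]
    #   'dense'     -> [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
    #   'firstlast' -> [(0, 3)]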
16 | # edges are the timestamp combinations 17 | if edge_type == 'adjacent': 18 | edges = [(t1, t1 + 1) for t1 in range(n - 1)] 19 | elif edge_type == 'cyclic': 20 | edges = [(t1, t1 + 1) for t1 in range(n - 1)] 21 | edges.append((0, n - 1)) 22 | elif edge_type == 'dense': 23 | edges = list(combinations(range(n), 2)) 24 | elif edge_type == 'firstlast': 25 | edges = [(0, n - 1)] 26 | else: 27 | raise Exception('Unknown edge type!') 28 | return edges 29 | 30 | 31 | def get_ch(seg: Tensor, edges: Sequence[Tuple[int, int]]) -> Tensor: 32 | ch = [torch.ne(seg[:, t1], seg[:, t2]) for t1, t2 in edges] 33 | ch = torch.stack(ch).transpose(0, 1) 34 | return ch 35 | 36 | 37 | def load_json(file: Path) -> dict: 38 | with open(str(file)) as f: 39 | d = json.load(f) 40 | return d 41 | 42 | 43 | def write_json(file: Path, data: dict) -> None: 44 | with open(str(file), 'w', encoding='utf-8') as f: 45 | json.dump(data, f, ensure_ascii=False, indent=4) 46 | -------------------------------------------------------------------------------- /utils/parsers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def training_argument_parser(): 5 | # https://docs.python.org/3/library/argparse.html#the-add-argument-method 6 | parser = argparse.ArgumentParser(description="Experiment Args") 7 | parser.add_argument('-c', "--config-file", dest='config_file', required=True, help="path to config file") 8 | parser.add_argument('-o', "--output-dir", dest='output_dir', required=True, help="path to output directory") 9 | parser.add_argument('-d', "--dataset-dir", dest='dataset_dir', default="", required=True, 10 | help="path to dataset directory") 11 | parser.add_argument( 12 | "opts", 13 | help="Modify config options using the command-line", 14 | default=None, 15 | nargs=argparse.REMAINDER, 16 | ) 17 | return parser 18 | 19 | 20 | def inference_argument_parser(): 21 | # https://docs.python.org/3/library/argparse.html#the-add-argument-method 22 | parser = argparse.ArgumentParser(description="Experiment Args") 23 | parser.add_argument('-c', "--config-file", dest='config_file', required=True, help="path to config file") 24 | parser.add_argument('-e', "--edge-type", dest='edge_type', default='dense', help="mrf edge type") 25 | parser.add_argument('-o', "--output-dir", dest='output_dir', required=True, help="path to output directory") 26 | parser.add_argument('-d', "--dataset-dir", dest='dataset_dir', default="", required=True, 27 | help="path to dataset directory") 28 | 29 | parser.add_argument( 30 | "opts", 31 | help="Modify config options using the command-line", 32 | default=None, 33 | nargs=argparse.REMAINDER, 34 | ) 35 | return parser 36 | 37 | 38 | def str2bool(v): 39 | if isinstance(v, bool): 40 | return v 41 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 42 | return True 43 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 44 | return False 45 | else: 46 | raise argparse.ArgumentTypeError('Boolean value expected.') 47 | --------------------------------------------------------------------------------
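The utilities above are meant to be wired together by the training and evaluation entry points (e.g. `train.py`). The snippet below is a minimal sketch of that wiring, not a reproduction of those scripts: the argument values, the device, and the model construction step are assumptions for illustration only.

```python
# Minimal usage sketch (illustrative; paths, config name, and model setup are assumed).
from utils import datasets, evaluation, experiment_manager, parsers

# Parse the standard experiment arguments (-c config, -o output dir, -d dataset dir).
args = parsers.training_argument_parser().parse_args()
cfg = experiment_manager.setup_cfg(args)  # merges configs/<config>.yaml with any CLI overrides

# Training samples are drawn from the spatially split SpaceNet 7 AOIs.
train_ds = datasets.TrainSpaceNet7Dataset(cfg, run_type='train')
print(train_ds)  # 'Train ... dataset with N samples.'

# A trained network (see model/model.py) could then be scored on the validation split, e.g.:
# f1 = evaluation.model_evaluation(net, cfg, device, run_type='val', epoch=0, step=0)
```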