├── .gitignore
├── .idea
│   ├── ContUrbanCD.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── configs
│   ├── base.yaml
│   ├── conturbancd_sn7.yaml
│   ├── conturbancd_tscd.yaml
│   ├── conturbancd_wusu.yaml
│   └── debug.yaml
├── eval.py
├── figures
│   └── overview.jpg
├── inference.py
├── metadata_conturbancd.json
├── model
│   ├── model.py
│   ├── modules.py
│   └── unet.py
├── requirements.txt
├── train.py
└── utils
    ├── datasets.py
    ├── evaluation.py
    ├── experiment_manager.py
    ├── helpers.py
    └── parsers.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | wandb/
107 | output/
108 |
--------------------------------------------------------------------------------
/.idea/ContUrbanCD.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Continuous Urban Change Detection from Satellite Image Time Series
2 |
3 | This repository contains the official code for the following paper:
4 |
5 | S. Hafner, H. Fang, H. Azizpour and Y. Ban, "Continuous Urban Change Detection from Satellite Image Time Series with Temporal Feature Refinement and Multi-Task Integration," *(accepted to TGRS)*.
6 |
7 | [arXiv:2406.17458](https://arxiv.org/abs/2406.17458)
8 |
9 |
10 |
11 | ![Overview of the proposed method](figures/overview.jpg)
12 |
13 |
14 |
15 |
16 | # Setup
17 |
18 | This section shows you how to set up the dataset and the virtual environment.
19 |
20 | ## Dataset
21 |
22 | The SpaceNet 7 dataset is available from this [link](https://spacenet.ai/sn7-challenge/). More information can be found in the [SpaceNet 7 dataset paper](https://openaccess.thecvf.com/content/CVPR2021/html/Van_Etten_The_Multi-Temporal_Urban_Development_SpaceNet_Dataset_CVPR_2021_paper.html).
23 |
24 | We use a metadata file (`metadata_conturbancd.json`) that can be downloaded from [here](https://drive.google.com/file/d/1wzRZ9iS2lOu24OArkG5n5tQaNrRQ6As6/view?usp=drive_link) or from this repository. The metadata file should be placed in the root directory of the SpaceNet 7 dataset (see below).
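
If you prefer the command line, the metadata file can also be fetched with `gdown` (optional helper, not part of this repo; the file ID is taken from the Google Drive link above):

```bash
pip install gdown
gdown 1wzRZ9iS2lOu24OArkG5n5tQaNrRQ6As6 -O spacenet7/metadata_conturbancd.json
```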
25 |
26 | We also generated raster labels using the code in [this paper](https://doi.org/10.1109/IGARSS46834.2022.9883982). The labels can be obtained from [this link](https://drive.google.com/file/d/1ZHcZ0qfcymBJ4_hHcpDz6UpwvOiRZccB/view?usp=sharing) and should be placed in the respective study site folder.
27 |
28 | The dataset directory should look like this:
29 |
30 | ```
31 | $ SpaceNet 7 dataset directory
32 | spacenet7 # -d should point to this directory
33 | ├── metadata_conturbancd.json # Download this file and place it in the dataset directory
34 | └── train
35 | ├── L15-0331E-1257N_1327_3160_13
36 | ...
37 | └── L15-1848E-0793N_7394_5018_13
38 | ├── images_masked
39 | ├── labels
40 | └── labels_raster # Generated using the vector data in labels
41 | ```
42 |
43 |
44 |
45 | ## Environment
46 |
47 | 1. Clone this repository
48 |
49 | ```bash
50 | git clone https://github.com/SebastianHafner/ContUrbanCD.git
51 | cd ContUrbanCD
52 | ```
53 |
54 | 2. Setup a virtual environment
55 |
56 | We use conda to set up the environment:
57 | ```bash
58 | conda create -n conturbancd python=3.9.7
59 | conda activate conturbancd
60 | ```
61 |
62 | 3. Install the dependencies
63 |
64 | Install pytorch according to the [official guide](https://pytorch.org/get-started/locally/).
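
For reference, the versions pinned in `requirements.txt` are `torch==1.10.0`, `torchvision==0.11.1`, and `torchaudio==0.10.0`; a default install of those versions would look like the following (pick the command for your CUDA version from the official guide if you need a specific build):

```bash
pip install torch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0
```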
65 |
66 | Install other dependencies using the `requirements.txt` file. Note that not all libraries specified in the file are required to run this code.
67 | ```bash
68 | pip install -r requirements.txt
69 | ```
70 |
71 |
72 |
73 | # Running our code
74 |
75 | This section provides all the instructions to run our code. If you do not want to train your own models, we also provide [our model weights](https://drive.google.com/drive/folders/1GA4_GM4li-K8gpCltFjM0x0W46k7iol5?usp=sharing) (trained using this code base).
76 |
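If you use the pre-trained weights, note that `eval.py` and `inference.py` load the checkpoint from `<output directory>/weights/<experiment name>.pt` (see `load_model` in `model/model.py`); the experiment name presumably matches the config name, so a sketch of the expected layout is:

```bash
# hypothetical checkpoint file name; adjust to the file you actually downloaded
mkdir -p output/weights
mv conturbancd_sn7.pt output/weights/conturbancd_sn7.pt
```
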
77 | All scripts (`train.py`, `eval.py`, `inference.py`) require three arguments:
78 | - The config file is specified using `-c`. This repo includes dataset-specific configs for the proposed method (see the config example below).
79 | - The output directory is specified using `-o`. We use this directory to store model weights and evaluation and inference outputs.
80 | - The dataset directory is specified using `-d`. This directory points to the root folder of the dataset.
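
The name passed to `-c` presumably resolves to `configs/<name>.yaml`. The dataset-specific configs inherit from `configs/base.yaml` through the `_BASE_` key, so a custom experiment only needs to override the values it changes. A minimal sketch (hypothetical config file, following the style of the existing configs):

```yaml
# configs/my_experiment.yaml (hypothetical), run with -c my_experiment
_BASE_: "base.yaml"
SEED: 4

TRAINER:
  BATCH_SIZE: 8
```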
81 |
82 |
83 | ## Training
84 |
85 | To train our network on SpaceNet 7, run the following script:
86 | ```bash
87 | python train.py -c conturbancd_sn7 -o output -d spacenet7
88 | ```
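
Training logs losses and validation metrics to Weights & Biases (project `ContUrbanCD`, see `train.py`), so make sure you are logged in before starting a run; with the debug config (`-c debug`) logging is disabled:

```bash
wandb login  # only needed once
```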
89 |
90 |
91 | ## Evaluation
92 |
93 | To calculate the evaluation metrics, run this script:
94 | ```bash
95 | python eval.py -c conturbancd_sn7 -o output -d spacenet7
96 | ```
97 |
98 | The script outputs a `.json` file containing the accuracy metrics (F1 score and IoU) for the segmentation and change detection tasks.
99 |
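A minimal sketch of how to read the result, assuming the file name follows `<config name>_<edge type>.json` as written by `eval.py`:

```python
import json

# e.g. produced by the command above with the default dense edge setting
with open('output/evaluation/conturbancd_sn7_dense.json') as f:
    metrics = json.load(f)

# one entry per task ('seg_cont', 'seg_fl', 'ch_cont', 'ch_fl'), each with 'f1' and 'iou'
print(metrics['ch_cont']['f1'])
```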
100 |
101 | ## Inference
102 |
103 | To produce multi-temporal building segmentation outputs from a satellite image time series, run the following script:
104 |
105 | ```bash
106 | python inference.py -c conturbancd_sn7 -o output -d spacenet7
107 | ```
108 |
109 | Note: The edge setting of the MRF module can be selected with `-e` (degenerate, adjacent, cyclic, or dense); dense is the default.
110 |
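For example, to run inference with the cyclic edge setting instead of the default:

```bash
python inference.py -c conturbancd_sn7 -o output -d spacenet7 -e cyclic
```
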
111 | The resulting file (`.npy`) is of dimension T x H x W, with the first dimension denoting the temporal dimension.
112 |
113 | ```python
114 | seg_sits = np.load(output_file)  # T x H x W
115 |
116 | # Building segmentations for second image in the time series
117 | seg_img_t2 = seg_sits[1]
118 |
119 | # Changes between first and last images of the time series
120 | ch_first_last = np.not_equal(seg_sits[0], seg_sits[-1])
121 | ```
122 |
123 | # Credits
124 |
125 | If you find this work useful, please consider citing:
126 |
127 |
128 |
129 | ```bibtex
130 | @article{hafner2024continuous,
131 | title={Continuous Urban Change Detection from Satellite Image Time Series with Temporal Feature Refinement and Multi-Task Integration},
132 | author={Hafner, Sebastian and Fang, Heng and Azizpour, Hossein and Ban, Yifang},
133 | journal={arXiv preprint arXiv:2406.17458},
134 | year={2024}
135 | }
136 | ```
137 |
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | SEED: 1
2 | DEBUG: False
3 | LOG_FREQ: 100
4 |
5 | TRAINER:
6 | LR: 1e-4
7 | BATCH_SIZE: 8
8 | EPOCHS: 100
9 | OPTIMIZER: 'adamw'
10 | PATIENCE: 10
11 | LAMBDA: 0
12 | LR_SCHEDULER: 'linear'
13 |
14 | MODEL:
15 | TYPE: 'mtunetformer'
16 | IN_CHANNELS: 3
17 | OUT_CHANNELS: 1
18 | LOSS_TYPE: 'PowerJaccardLoss'
19 | EDGE_TYPE: 'dense'
20 | TOPOLOGY: [64, 128, 256, 512,]
21 | DISABLE_OUTCONV: False
22 | TRANSFORMER: True
23 | TRANSFORMER_PARAMS:
24 | N_LAYERS: 2
25 | N_HEADS: 2
26 | D_MODEL: 192
27 | PATCH_SIZE: 8
28 | ACTIVATION: 'gelu'
29 | SPATIAL_ATTENTION_SIZE: 1
30 | ADJACENT_CHANGES: True
31 |
32 | DATALOADER:
33 | NUM_WORKER: 4
34 | SHUFFLE: True
35 | MODE: 'all' # 'first_last' or 'all'
36 | SENSOR: 'planetscope'
37 | INCLUDE_ALPHA: False
38 | TRAINING_MULTIPLIER: 100
39 | PAD_BORDERS: True
40 | TIMESERIES_LENGTH: 5
41 | INCLUDE_CHANGE_LABEL: False
42 | EVAL_TRAIN_THRESHOLD: 0.75
43 | I_SPLIT: 768
44 | J_SPLIT: 512
45 |
46 | AUGMENTATION:
47 | CROP_SIZE: 64
48 | IMAGE_OVERSAMPLING_TYPE: 'change' # [none, change, semantic]
49 | RANDOM_FLIP: True
50 | RANDOM_ROTATE: True
51 | COLOR_BLUR: True
52 | COLOR_JITTER: True
53 |
54 | DATASET:
55 | NAME: 'sn7'
56 | TRAIN_IDS: [
57 | 'L15-0331E-1257N_1327_3160_13',
58 | 'L15-0357E-1223N_1429_3296_13',
59 | 'L15-0358E-1220N_1433_3310_13',
60 | 'L15-0361E-1300N_1446_2989_13',
61 | 'L15-0434E-1218N_1736_3318_13',
62 | 'L15-0487E-1246N_1950_3207_13',
63 | 'L15-0506E-1204N_2027_3374_13',
64 | 'L15-0544E-1228N_2176_3279_13',
65 | 'L15-0577E-1243N_2309_3217_13',
66 | 'L15-0586E-1127N_2345_3680_13',
67 | 'L15-0595E-1278N_2383_3079_13',
68 | 'L15-0614E-0946N_2459_4406_13',
69 | 'L15-0683E-1006N_2732_4164_13',
70 | 'L15-0760E-0887N_3041_4643_13',
71 | 'L15-0924E-1108N_3699_3757_13',
72 | 'L15-0977E-1187N_3911_3441_13',
73 | 'L15-1025E-1366N_4102_2726_13',
74 | 'L15-1049E-1370N_4196_2710_13',
75 | 'L15-1172E-1306N_4688_2967_13',
76 | 'L15-1185E-0935N_4742_4450_13',
77 | 'L15-1203E-1203N_4815_3378_13',
78 | 'L15-1204E-1202N_4816_3380_13',
79 | 'L15-1204E-1204N_4819_3372_13',
80 | 'L15-1209E-1113N_4838_3737_13',
81 | 'L15-1210E-1025N_4840_4088_13',
82 | 'L15-1289E-1169N_5156_3514_13',
83 | 'L15-1296E-1198N_5184_3399_13',
84 | 'L15-1298E-1322N_5193_2903_13',
85 | 'L15-1335E-1166N_5342_3524_13',
86 | 'L15-1389E-1284N_5557_3054_13',
87 | 'L15-1479E-1101N_5916_3785_13',
88 | 'L15-1481E-1119N_5927_3715_13',
89 | 'L15-1538E-1163N_6154_3539_13',
90 | 'L15-1615E-1205N_6460_3370_13',
91 | 'L15-1617E-1207N_6468_3360_13',
92 | 'L15-1669E-1153N_6678_3579_13',
93 | 'L15-1669E-1160N_6678_3548_13',
94 | 'L15-1691E-1211N_6764_3347_13',
95 | 'L15-1703E-1219N_6813_3313_13',
96 | 'L15-1716E-1211N_6864_3345_13',
97 | 'L15-0368E-1245N_1474_3210_13',
98 | 'L15-0457E-1135N_1831_3648_13',
99 | 'L15-0571E-1075N_2287_3888_13',
100 | 'L15-1014E-1375N_4056_2688_13',
101 | 'L15-1138E-1216N_4553_3325_13',
102 | 'L15-1439E-1134N_5759_3655_13',
103 | 'L15-1669E-1160N_6679_3549_13',
104 | 'L15-1672E-1207N_6691_3363_13',
105 | 'L15-1709E-1112N_6838_3742_13',
106 | 'L15-1748E-1247N_6993_3202_13',
107 | 'L15-0387E-1276N_1549_3087_13',
108 | 'L15-0566E-1185N_2265_3451_13',
109 | 'L15-0632E-0892N_2528_4620_13',
110 | 'L15-1015E-1062N_4061_3941_13',
111 | 'L15-1200E-0847N_4802_4803_13',
112 | 'L15-1276E-1107N_5105_3761_13',
113 | 'L15-1438E-1134N_5753_3655_13',
114 | 'L15-1615E-1206N_6460_3366_13',
115 | 'L15-1690E-1211N_6763_3346_13',
116 | 'L15-1848E-0793N_7394_5018_13',
117 | ]
118 | TEST_IDS: [
119 | 'L15-0369E-1244N_1479_3214_13',
120 | 'L15-0391E-1219N_1567_3314_13',
121 | 'L15-0509E-1108N_2037_3758_13',
122 | 'L15-0571E-1302N_2284_2983_13',
123 | 'L15-0697E-0874N_2789_4694_13',
124 | 'L15-0744E-0927N_2979_4481_13',
125 | 'L15-1031E-1300N_4127_2991_13',
126 | 'L15-1129E-0819N_4517_4915_13',
127 | 'L15-1203E-1203N_4815_3379_13',
128 | 'L15-1213E-1238N_4852_3239_13',
129 | 'L15-1249E-1167N_4999_3521_13',
130 | 'L15-1281E-1035N_5125_4049_13',
131 | 'L15-1438E-1227N_5753_3282_13',
132 | 'L15-1546E-1154N_6186_3574_13',
133 | 'L15-1615E-1205N_6461_3368_13',
134 | 'L15-1630E-0988N_6522_4239_13',
135 | 'L15-1666E-1189N_6665_3433_13',
136 | 'L15-1670E-1159N_6681_3552_13',
137 | 'L15-1690E-1210N_6762_3348_13',
138 | 'L15-1749E-1266N_6997_3126_13',
139 | ]
--------------------------------------------------------------------------------
/configs/conturbancd_sn7.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base.yaml"
2 | SEED: 2
3 |
4 | TRAINER:
5 | LR: 1e-4
6 | BATCH_SIZE: 16
7 |
8 | MODEL:
9 | EDGE_TYPE: 'dense'
10 | TYPE: 'mtunetformer'
11 |
12 | DATASET:
13 | NAME: 'sn7'
--------------------------------------------------------------------------------
/configs/conturbancd_tscd.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base.yaml"
2 | SEED: 3
3 |
4 | TRAINER:
5 | LR: 5e-5
6 | BATCH_SIZE: 16
7 |
8 | MODEL:
9 | EDGE_TYPE: 'dense'
10 | TYPE: 'mtunetformer'
11 | IN_CHANNELS: 3
12 |
13 | DATASET:
14 | NAME: 'tscd'
15 |
16 | DATALOADER:
17 | TIMESERIES_LENGTH: 4
--------------------------------------------------------------------------------
/configs/conturbancd_wusu.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base.yaml"
2 | SEED: 2
3 |
4 | TRAINER:
5 | LR: 1e-4
6 | BATCH_SIZE: 16
7 |
8 | MODEL:
9 | EDGE_TYPE: 'cyclic'
10 | TYPE: 'mtunetformer'
11 | IN_CHANNELS: 4
12 |
13 | DATASET:
14 | NAME: 'wusu'
15 |
16 | DATALOADER:
17 | TIMESERIES_LENGTH: 3
--------------------------------------------------------------------------------
/configs/debug.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "conturbancd_sn7.yaml"
2 | DEBUG: True
3 |
4 | TRAINER:
5 | BATCH_SIZE: 2
6 |
7 | DATALOADER:
8 | TRAINING_MULTIPLIER: 1
9 |
10 | AUGMENTATION:
11 | CROP_SIZE: 64
12 | IMAGE_OVERSAMPLING_TYPE: 'none'
13 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import datasets, parsers, experiment_manager, helpers, evaluation
4 | from utils.experiment_manager import CfgNode
5 | from model import model
6 |
7 | from pathlib import Path
8 |
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 |
11 |
12 | def assessment(cfg: CfgNode, edge_type: str = 'dense', run_type: str = 'test'):
13 | print(cfg.NAME)
14 | net = model.load_model(cfg, device)
15 | m = evaluation.run_quantitative_evaluation(net, cfg, device, run_type, enable_mti=True, mti_edge_setting=edge_type)
16 |
17 | data = {}
18 | for attr in ['seg_cont', 'seg_fl', 'ch_cont', 'ch_fl']:
19 | f1 = evaluation.f1_score(getattr(m, f'TP_{attr}'), getattr(m, f'FP_{attr}'), getattr(m, f'FN_{attr}'))
20 | iou = evaluation.iou(getattr(m, f'TP_{attr}'), getattr(m, f'FP_{attr}'), getattr(m, f'FN_{attr}'))
21 | data[attr] = {'f1': f1, 'iou': iou}
22 | eval_folder = Path(cfg.PATHS.OUTPUT) / 'evaluation'
23 | eval_folder.mkdir(exist_ok=True)
24 | helpers.write_json(eval_folder / f'{cfg.NAME}_{edge_type}.json', data)
25 |
26 |
27 | if __name__ == '__main__':
28 | args = parsers.inference_argument_parser().parse_known_args()[0]
29 | cfg = experiment_manager.setup_cfg(args)
30 | assessment(cfg, edge_type=args.edge_type)
31 |
32 |
--------------------------------------------------------------------------------
/figures/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SebastianHafner/ContUrbanCD/4c3c6501b444880ebc8eb21bcceea6dcf738c46b/figures/overview.jpg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import parsers, experiment_manager, helpers, datasets
3 | from utils.experiment_manager import CfgNode
4 | from pathlib import Path
5 | import numpy as np
6 | from model import model
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 | EPS = 1e-6
10 |
11 |
12 | def inference(cfg: CfgNode, edge_type: str = 'dense', run_type: str = 'test'):
13 | print(cfg.NAME)
14 | net = model.load_model(cfg, device)
15 | net.eval()
16 |
17 | tile_size = cfg.AUGMENTATION.CROP_SIZE
18 | edges = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, edge_type)
19 |
20 | pred_folder = Path(cfg.PATHS.OUTPUT) / 'inference' / cfg.NAME
21 | pred_folder.mkdir(exist_ok=True, parents=True)
22 |
23 | for aoi_id in list(cfg.DATASET.TRAIN_IDS):
24 | print(aoi_id)
25 |
26 | ds = datasets.create_eval_dataset(cfg, run_type, site=aoi_id, tiling=tile_size)
27 | o_seg = np.empty((1, ds.T, 1, ds.m, ds.n), dtype=np.uint8)
28 |
29 | for index in range(len(ds)):
30 | item = ds.__getitem__(index)
31 | x, y_seg = item['x'].to(device).unsqueeze(0), item['y'].to(device).unsqueeze(0)
32 | i, j = item['i'], item['j']
33 | o_seg_tile = net.module.inference(x, edges)
34 | o_seg[:, :, :, i:i + tile_size, j:j + tile_size] = o_seg_tile
35 |
36 | np.save(pred_folder / f'{cfg.NAME}_{aoi_id}.npy', o_seg)
37 |
38 |
39 | if __name__ == '__main__':
40 | args = parsers.inference_argument_parser().parse_known_args()[0]
41 | cfg = experiment_manager.setup_cfg(args)
42 | inference(cfg, edge_type=args.edge_type)
43 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | import einops
5 |
6 | from typing import Tuple, Sequence
7 | from pathlib import Path
8 |
9 | from utils.experiment_manager import CfgNode
10 |
11 | from model import unet, modules
12 |
13 |
14 | class ContUrbanCDModel(nn.Module):
15 | def __init__(self, cfg: CfgNode):
16 | super(ContUrbanCDModel, self).__init__()
17 |
18 | # Attributes
19 | self.cfg = cfg
20 | self.c = cfg.MODEL.IN_CHANNELS
21 | self.d_out = cfg.MODEL.OUT_CHANNELS
22 | self.h = self.w = cfg.AUGMENTATION.CROP_SIZE
23 | self.t = cfg.DATALOADER.TIMESERIES_LENGTH
24 | self.topology = cfg.MODEL.TOPOLOGY
25 |
26 | # ConvNet layers
27 | self.inc = unet.InConv(self.c, self.topology[0], unet.DoubleConv)
28 | self.encoder = unet.Encoder(cfg)
29 |
30 | self.decoder_seg = unet.Decoder(cfg)
31 | self.outc_seg = unet.OutConv(self.topology[0], self.d_out)
32 |
33 | self.decoder_ch = unet.Decoder(cfg)
34 | self.outc_ch = unet.OutConv(self.topology[0], self.d_out)
35 |
36 | # Temporal feature refinement (TFR) modules
37 | tfr_modules = []
38 | transformer_dims = [self.topology[-1]] + list(self.topology[::-1])
39 | for i, d_model in enumerate(transformer_dims):
40 | tfr_module = modules.TFRModule(
41 | t=self.t,
42 | d_model=d_model,
43 | n_heads=cfg.MODEL.TRANSFORMER_PARAMS.N_HEADS,
44 | d_hid=self.topology[0] * 4,
45 | activation=cfg.MODEL.TRANSFORMER_PARAMS.ACTIVATION,
46 | n_layers=cfg.MODEL.TRANSFORMER_PARAMS.N_LAYERS
47 | )
48 | tfr_modules.append(tfr_module)
49 | self.tfr_modules = nn.ModuleList(tfr_modules)
50 |
51 | # Change feature (CF) module
52 | self.cf_module = modules.CFModule()
53 |
54 | # Multi-task integration (MTI) module
55 | self.mti_module = modules.MTIModule()
56 |
57 | def forward(self, x: Tensor, edges: Sequence[Tuple[int, int]]) -> Tuple[Tensor, Tensor]:
58 | B, T, _, H, W = x.size()
59 |
60 | # Feature extraction with Siamese ConvNet encoder
61 | x = einops.rearrange(x, 'b t c h w -> (b t) c h w')
62 | features = self.encoder(self.inc(x))
63 | features = [einops.rearrange(f_s, '(b t) f h w -> b t f h w', b=B) for f_s in features]
64 |
65 | # Temporal feature refinement with TFR modules
66 | for i, tfr_module in enumerate(self.tfr_modules):
67 | f = features[i] # Feature maps at scale s
68 | # Feature refinement with self-attention
69 | f_refined = tfr_module(f)
70 | features[i] = f_refined
71 |
72 | # Change feature maps
73 | features_ch = self.cf_module(features, edges)
74 | features_ch = [einops.rearrange(f, 'n b c h w -> (b n) c h w') for f in features_ch]
75 |
76 | # Building segmentation
77 | features_seg = [einops.rearrange(f, 'b t c h w -> (b t) c h w') for f in features]
78 | logits_seg = self.outc_seg(self.decoder_seg(features_seg))
79 | logits_seg = einops.rearrange(logits_seg, '(b t) c h w -> b t c h w', b=B)
80 |
81 | logits_ch = self.outc_ch(self.decoder_ch(features_ch))
82 | logits_ch = einops.rearrange(logits_ch, '(b n) c h w -> b n c h w', n=len(edges))
83 |
84 | return logits_ch, logits_seg
85 |
86 | def inference(self, x: Tensor, edges: Sequence[Tuple[int, int]]) -> Tensor:
87 | logits_ch, logits_seg = self.forward(x, edges)
88 | o_ch = torch.sigmoid(logits_ch).detach()
89 | o_seg = torch.sigmoid(logits_seg).detach()
90 | o_seg = self.mti_module(o_ch, o_seg, edges)
91 | return o_seg
92 |
93 |
94 | def init_model(cfg: CfgNode) -> nn.Module:
95 | net = ContUrbanCDModel(cfg)
96 | return torch.nn.DataParallel(net)
97 |
98 |
99 | def save_model(network: nn.Module, epoch: float, cfg: CfgNode):
100 | save_file = Path(cfg.PATHS.OUTPUT) / 'weights' / f'{cfg.NAME}.pt'
101 | save_file.parent.mkdir(exist_ok=True)
102 | checkpoint = {
103 | 'epoch': epoch,
104 | 'weights': network.state_dict(),
105 | }
106 | torch.save(checkpoint, save_file)
107 |
108 |
109 | def load_model(cfg: CfgNode, device: torch.device) -> nn.Module:
110 | net = init_model(cfg)
111 | net.to(device)
112 | net_file = Path(cfg.PATHS.OUTPUT) / 'weights' / f'{cfg.NAME}.pt'
113 | checkpoint = torch.load(net_file, map_location=device)
114 | net.load_state_dict(checkpoint['weights'])
115 | return net
116 |
117 |
118 | def power_jaccard_loss(input: Tensor, target: Tensor, disable_sigmoid: bool = False) -> Tensor:
119 | input_sigmoid = torch.sigmoid(input) if not disable_sigmoid else input
120 | eps = 1e-6
121 |
122 | iflat = input_sigmoid.flatten()
123 | tflat = target.flatten()
124 | intersection = (iflat * tflat).sum()
125 | denom = (iflat ** 2 + tflat ** 2).sum() - (iflat * tflat).sum() + eps
126 |
127 | return 1 - (intersection / denom)
128 |
129 |
--------------------------------------------------------------------------------
/model/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | import einops
5 |
6 | import numpy as np
7 |
8 | # https://pgmpy.org/models/markovnetwork.html
9 | from pgmpy.models import MarkovNetwork
10 | from pgmpy.factors.discrete import DiscreteFactor
11 | from pgmpy.inference import BeliefPropagation
12 |
13 | from typing import Tuple, Sequence, Callable
14 |
15 | from joblib import Parallel, delayed
16 |
17 | import os
18 | os.environ["PYTHONWARNINGS"] = "ignore"
19 |
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 |
23 | class TFRModule(nn.Module):
24 | def __init__(self, t: int, d_model: int, n_heads: int, d_hid: int, activation: str, n_layers: int):
25 | super().__init__()
26 | # Generate relative temporal encodings
27 | self.register_buffer('temporal_encodings', self.get_relative_encodings(t, d_model), persistent=False)
28 |
29 | # Define a transformer encoder layer
30 | encoder_layer = nn.TransformerEncoderLayer(
31 | d_model=d_model, nhead=n_heads,
32 | dim_feedforward=d_hid, batch_first=True,
33 | activation=activation
34 | )
35 | # Create module
36 | self.temporal_feature_refinement = nn.TransformerEncoder(encoder_layer, n_layers)
37 |
38 | def forward(self, features: Tensor) -> Tensor:
39 | B, T, D, H, W = features.size()
40 |
41 | # Reshape to tokens
42 | tokens = einops.rearrange(features, 'B T D H W -> (B H W) T D')
43 |
44 | # Adding relative temporal encodings
45 | tokens = tokens + self.temporal_encodings.repeat(B * H * W, 1, 1)
46 |
47 | # Feature refinement with self-attention
48 | features_hat = self.temporal_feature_refinement(tokens)
49 |
50 | # Reshape to original shape
51 | features_hat = einops.rearrange(features_hat, '(B H W) T D -> B T D H W', B=B, H=H)
52 |
53 | return features_hat
54 |
55 | @staticmethod
56 | def get_relative_encodings(sequence_length, d):
57 | result = torch.ones(sequence_length, d)
58 | for i in range(sequence_length):
59 | for j in range(d):
60 | result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
61 | return result
62 |
63 |
64 | class CFModule(nn.Module):
65 | def __init__(self):
66 | super().__init__()
67 |
68 | @staticmethod
69 | def forward(features: Sequence[Tensor], edges: Sequence[Tuple[int, int]]) -> Sequence[Tensor]:
70 | # compute urban change detection features
71 | features_ch = []
72 | for feature in features:
73 | B, T, _, H, W = feature.size()
74 | feature_ch = []
75 | for t1, t2 in edges:
76 | feature_ch.append(feature[:, t2] - feature[:, t1])
77 | # n: number of combinations
78 | feature_ch = torch.stack(feature_ch)
79 | features_ch.append(feature_ch)
80 | return features_ch
81 |
82 |
83 | class MTIModule(nn.Module):
84 | def __init__(self):
85 | super().__init__()
86 |
87 | def forward(self, o_ch: Tensor, o_seg: Tensor, edges: Sequence[Tuple[int, int]]) -> Tensor:
88 |
89 | B, T, _, H, W = o_seg.size()
90 |
91 | # Get processing function by defining the Markov network
92 | process_pixel = self.markov_network(T, edges)
93 |
94 | # Reshape
95 | o_ch = einops.rearrange(o_ch, 'B N C H W -> (B H W) N C').cpu().numpy()
96 | o_seg = einops.rearrange(o_seg, 'B T C H W -> (B H W) T C').cpu().numpy()
97 |
98 | # Find optimal building time series using Markov network
99 | o_seg_mrf = Parallel(n_jobs=-1)(delayed(process_pixel)(p_seg, p_ch) for p_seg, p_ch, in zip(o_seg, o_ch))
100 |
101 | # Reshape
102 | o_seg_mrf = torch.Tensor(o_seg_mrf)
103 | o_seg_mrf = einops.rearrange(o_seg_mrf, '(B H W) (T C) -> B T C H W', H=H, W=W, T=T)
104 |
105 | return o_seg_mrf
106 |
107 | @staticmethod
108 | def markov_network(n: int, edges: Sequence[Tuple[int, int]]) -> Callable:
109 | # Define a function to process a single pixel
110 | def process_pixel(y_hat_seg: Sequence[float], y_hat_ch: Sequence[float]):
111 | model = MarkovNetwork()
112 |
113 | for t in range(len(y_hat_seg)):
114 | model.add_node(f'N{t}')
115 | # Cardinality: number of potential values (i.e., 2: 0/1)
116 | # Potential Values for node: P(urban=True), P(urban=False)
117 | urban_value = float(y_hat_seg[t])
118 | factor = DiscreteFactor([f'N{t}'], cardinality=[2], values=[1 - urban_value, urban_value])
119 | model.add_factors(factor)
120 |
121 | # add adjacent edges w/o potentials
122 | for t in range(n - 1):
123 | model.add_edge(f'N{t}', f'N{t + 1}')
124 |
125 | # add edges with potentials
126 | for i, (t1, t2) in enumerate(edges):
127 | model.add_edge(f'N{t1}', f'N{t2}')
128 | # [P(A=False, B=False), P(A=False, B=True), P(A=True, B=False), P(A=True, B=True)]
129 | change_value = float(y_hat_ch[i])
130 | edge_values = [1 - change_value, change_value, change_value, 1 - change_value]
131 |
132 | factor = DiscreteFactor([f'N{t1}', f'N{t2}'], cardinality=[2, 2], values=edge_values)
133 | model.add_factors(factor)
134 |
135 | # Create an instance of BeliefPropagation algorithm
136 | bp = BeliefPropagation(model)
137 |
138 | # Compute the most probable state of the MRF
139 | state = bp.map_query()
140 | states_list = [state[f'N{t}'] for t in range(n)]
141 | return states_list
142 |
143 | return process_pixel
144 |
--------------------------------------------------------------------------------
/model/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | import torch.nn.functional as F
5 |
6 | import einops
7 | from collections import OrderedDict
8 | from utils.experiment_manager import CfgNode
9 |
10 |
11 | class UNet(nn.Module):
12 | def __init__(self, cfg: CfgNode):
13 | super(UNet, self).__init__()
14 | self.cfg = cfg
15 |
16 | n_channels = cfg.MODEL.IN_CHANNELS
17 | n_classes = cfg.MODEL.OUT_CHANNELS
18 | topology = cfg.MODEL.TOPOLOGY
19 |
20 | self.inc = InConv(n_channels, topology[0], DoubleConv)
21 | self.encoder = Encoder(cfg)
22 | self.decoder = Decoder(cfg)
23 | self.outc = OutConv(topology[0], n_classes)
24 | self.disable_outc = cfg.MODEL.DISABLE_OUTCONV
25 |
26 | def forward(self, x: Tensor) -> Tensor:
27 | B, T, _, H, W = x.size()
28 | x = einops.rearrange(x, 'b t c h w -> (b t) c h w')
29 | out = self.outc(self.decoder(self.encoder(self.inc(x))))
30 | out = einops.rearrange(out, '(b t) c h w -> b t c h w', b=B)
31 | return out
32 |
33 |
34 | class Encoder(nn.Module):
35 | def __init__(self, cfg: CfgNode):
36 | super(Encoder, self).__init__()
37 |
38 | self.cfg = cfg
39 | topology = cfg.MODEL.TOPOLOGY
40 |
41 | # Variable scale
42 | down_topo = topology
43 | down_dict = OrderedDict()
44 | n_layers = len(down_topo)
45 |
46 | # Downward layers
47 | for idx in range(n_layers):
48 | is_not_last_layer = idx != n_layers - 1
49 | in_dim = down_topo[idx]
50 | out_dim = down_topo[idx + 1] if is_not_last_layer else down_topo[idx] # last layer
51 | layer = Down(in_dim, out_dim, DoubleConv)
52 | down_dict[f'down{idx + 1}'] = layer
53 | self.down_seq = nn.ModuleDict(down_dict)
54 |
55 | def forward(self, x1: Tensor) -> list:
56 |
57 | inputs = [x1]
58 | # Downward U:
59 | for layer in self.down_seq.values():
60 | out = layer(inputs[-1])
61 | inputs.append(out)
62 |
63 | inputs.reverse()
64 | return inputs
65 |
66 |
67 | class Decoder(nn.Module):
68 | def __init__(self, cfg: CfgNode):
69 | super(Decoder, self).__init__()
70 | self.cfg = cfg
71 |
72 | topology = cfg.MODEL.TOPOLOGY
73 |
74 | # Variable scale
75 | n_layers = len(topology)
76 | up_topo = [topology[0]] # topography upwards
77 | up_dict = OrderedDict()
78 |
79 | for idx in range(n_layers):
80 | is_not_last_layer = idx != n_layers - 1
81 | out_dim = topology[idx + 1] if is_not_last_layer else topology[idx] # last layer
82 | up_topo.append(out_dim)
83 |
84 | # Upward layers
85 | for idx in reversed(range(n_layers)):
86 | is_not_last_layer = idx != 0
87 | x1_idx = idx
88 | x2_idx = idx - 1 if is_not_last_layer else idx
89 | in_dim = up_topo[x1_idx] * 2
90 | out_dim = up_topo[x2_idx]
91 | layer = Up(in_dim, out_dim, DoubleConv)
92 | up_dict[f'up{idx + 1}'] = layer
93 |
94 | self.up_seq = nn.ModuleDict(up_dict)
95 |
96 | def forward(self, features: list) -> Tensor:
97 |
98 | x1 = features.pop(0)
99 | for idx, layer in enumerate(self.up_seq.values()):
100 | x2 = features[idx]
101 | x1 = layer(x1, x2) # x1 for next up layer
102 |
103 | return x1
104 |
105 |
106 | class DoubleConv(nn.Module):
107 | '''(conv => BN => ReLU) * 2'''
108 |
109 | def __init__(self, in_ch, out_ch):
110 | super(DoubleConv, self).__init__()
111 | self.conv = nn.Sequential(
112 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
113 | nn.BatchNorm2d(out_ch),
114 | nn.ReLU(inplace=True),
115 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
116 | nn.BatchNorm2d(out_ch),
117 | nn.ReLU(inplace=True)
118 | )
119 |
120 | def forward(self, x):
121 | x = self.conv(x)
122 | return x
123 |
124 |
125 | class InConv(nn.Module):
126 | def __init__(self, in_ch, out_ch, conv_block):
127 | super(InConv, self).__init__()
128 | self.conv = conv_block(in_ch, out_ch)
129 |
130 | def forward(self, x):
131 | x = self.conv(x)
132 | return x
133 |
134 |
135 | class Down(nn.Module):
136 | def __init__(self, in_ch, out_ch, conv_block):
137 | super(Down, self).__init__()
138 |
139 | self.mpconv = nn.Sequential(
140 | nn.MaxPool2d(2),
141 | conv_block(in_ch, out_ch)
142 | )
143 |
144 | def forward(self, x):
145 | x = self.mpconv(x)
146 | return x
147 |
148 |
149 | class up_conv(nn.Module):
150 | def __init__(self, ch_in, ch_out):
151 | super(up_conv, self).__init__()
152 | self.up = nn.Sequential(
153 | nn.Upsample(scale_factor=2),
154 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
155 | nn.BatchNorm2d(ch_out),
156 | nn.ReLU(inplace=True)
157 | )
158 |
159 | def forward(self, x):
160 | x = self.up(x)
161 | return x
162 |
163 |
164 | class Up(nn.Module):
165 | def __init__(self, in_ch, out_ch, conv_block):
166 | super(Up, self).__init__()
167 |
168 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
169 | self.conv = conv_block(in_ch, out_ch)
170 |
171 | def forward(self, x1, x2):
172 | x1 = self.up(x1)
173 |
174 | # input is CHW
175 | diffY = x2.detach().size()[2] - x1.detach().size()[2]
176 | diffX = x2.detach().size()[3] - x1.detach().size()[3]
177 |
178 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
179 |
180 | # for padding issues, see
181 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
182 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
183 |
184 | x = torch.cat([x2, x1], dim=1)
185 | x = self.conv(x)
186 | return x
187 |
188 |
189 | class OutConv(nn.Module):
190 | def __init__(self, in_ch, out_ch):
191 | super(OutConv, self).__init__()
192 | self.conv = nn.Conv2d(in_ch, out_ch, 1)
193 |
194 | def forward(self, x):
195 | x = self.conv(x)
196 | return x
197 |
198 |
199 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | affine==2.3.0
2 | argh==0.26.2
3 | attrs==21.2.0
4 | certifi==2021.10.8
5 | charset-normalizer==2.0.8
6 | click==8.0.3
7 | click-plugins==1.1.1
8 | cligj==0.7.2
9 | configparser==5.1.0
10 | cycler==0.11.0
11 | docker-pycreds==0.4.0
12 | einops==0.6.0
13 | filelock==3.11.0
14 | fonttools==4.28.4
15 | fsspec==2023.5.0
16 | fvcore==0.1.5.post20211023
17 | gin-config==0.5.0
18 | gitdb==4.0.9
19 | GitPython==3.1.24
20 | huggingface-hub==0.15.1
21 | idna==3.3
22 | imagecodecs==2022.8.8
23 | iopath==0.1.9
24 | joblib==1.3.1
25 | kiwisolver==1.3.2
26 | lightning-utilities==0.10.0
27 | matplotlib==3.5.1
28 | mkl-fft==1.3.1
29 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work
30 | mkl-service==2.4.0
31 | networkx==3.1
32 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634095651905/work
33 | olefile @ file:///Users/ktietz/demo/mc3/conda-bld/olefile_1629805411829/work
34 | opencv-python==4.5.5.62
35 | opt-einsum==3.3.0
36 | packaging==21.3
37 | pandas==2.0.3
38 | pathtools==0.1.2
39 | patsy==0.5.3
40 | pgmpy==0.1.23
41 | Pillow==8.4.0
42 | portalocker==2.3.2
43 | promise==2.3
44 | protobuf==3.19.1
45 | psutil==5.8.0
46 | pyparsing==3.0.6
47 | python-dateutil==2.8.2
48 | pytz==2023.3
49 | PyYAML==6.0
50 | rasterio==1.2.10
51 | regex==2023.6.3
52 | requests==2.26.0
53 | scikit-learn==1.3.0
54 | scipy==1.7.3
55 | sentry-sdk==1.5.0
56 | shortuuid==1.0.8
57 | six @ file:///tmp/build/80754af9/six_1623709665295/work
58 | smmap==5.0.0
59 | snuggs==1.4.7
60 | statsmodels==0.14.0
61 | subprocess32==3.5.4
62 | tabulate==0.8.9
63 | termcolor==1.1.0
64 | threadpoolctl==3.2.0
65 | tifffile==2022.8.12
66 | timm==0.6.13
67 | tokenizers==0.13.3
68 | torch==1.10.0
69 | torchaudio==0.10.0
70 | torcheval==0.0.7
71 | torchmetrics==1.2.0
72 | torchvision==0.11.1
73 | tqdm==4.62.3
74 | transformers==4.29.2
75 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work
76 | tzdata==2023.3
77 | urllib3==1.26.7
78 | wandb==0.12.7
79 | yacs==0.1.8
80 | yaspin==2.1.0
81 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import timeit
4 |
5 | import torch
6 | from torch import optim
7 | from torch.utils import data as torch_data
8 |
9 | import wandb
10 | import numpy as np
11 |
12 | from utils import datasets, evaluation, experiment_manager, parsers, helpers
13 | from model import model
14 | from utils.experiment_manager import CfgNode
15 |
16 |
17 | def run_training(cfg: CfgNode):
18 | net = model.init_model(cfg)
19 | net.to(device)
20 |
21 | optimizer = optim.AdamW(net.parameters(), lr=cfg.TRAINER.LR, weight_decay=0.01)
22 |
23 | def lambda_rule(e: int):
24 | lr_l = 1.0 - e / float(cfg.TRAINER.EPOCHS - 1)
25 | return lr_l
26 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
27 |
28 | criterion = model.power_jaccard_loss
29 |
30 | # reset the generators
31 | dataset = datasets.create_train_dataset(cfg, run_type='train')
32 | print(dataset)
33 |
34 | dataloader_kwargs = {
35 | 'batch_size': cfg.TRAINER.BATCH_SIZE,
36 | 'num_workers': 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKER,
37 | 'shuffle': cfg.DATALOADER.SHUFFLE,
38 | 'drop_last': True,
39 | 'pin_memory': True,
40 | }
41 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs)
42 |
43 | edges = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, cfg.MODEL.EDGE_TYPE)
44 |
45 | # unpacking cfg
46 | epochs = cfg.TRAINER.EPOCHS
47 | steps_per_epoch = len(dataloader)
48 |
49 | # tracking variables
50 | global_step = epoch_float = 0
51 |
52 | # early stopping
53 | best_f1_val = 0
54 | trigger_times = 0
55 | stop_training = False
56 |
57 | for epoch in range(1, epochs + 1):
58 | print(f'Starting epoch {epoch}/{epochs}.')
59 | wandb.log({'lr': scheduler.get_last_lr()[-1] if scheduler is not None else cfg.TRAINER.LR, 'epoch': epoch})
60 | start = timeit.default_timer()
61 | loss_seg_set, loss_ch_set, loss_set = [], [], []
62 |
63 | for i, batch in enumerate(dataloader):
64 |
65 | net.train()
66 | optimizer.zero_grad()
67 |
68 | x, y_seg = batch['x'].to(device), batch['y'].to(device)
69 | logits_ch, logits_seg = net(x, edges)
70 |
71 | y_ch = helpers.get_ch(y_seg, edges)
72 |
73 | loss_seg = criterion(logits_seg, y_seg)
74 | loss_ch = criterion(logits_ch, y_ch)
75 |
76 | loss = loss_seg + loss_ch
77 | loss.backward()
78 | optimizer.step()
79 |
80 | loss_seg_set.append(loss_seg.item())
81 | loss_ch_set.append(loss_ch.item())
82 | loss_set.append(loss.item())
83 |
84 | global_step += 1
85 | epoch_float = global_step / steps_per_epoch
86 |
87 | if global_step % cfg.LOG_FREQ == 0:
88 | print(f'Logging step {global_step} (epoch {epoch_float:.2f}).')
89 |
90 | # logging
91 | time = timeit.default_timer() - start
92 | wandb.log({
93 | 'loss_seg': np.mean(loss_seg_set),
94 | 'loss_ch': np.mean(loss_ch_set),
95 | 'loss': np.mean(loss_set),
96 | 'time': time,
97 | 'step': global_step,
98 | 'epoch': epoch_float,
99 | })
100 | start = timeit.default_timer()
101 | loss_seg_set, loss_ch_set, loss_set = [], [], []
102 | # end of batch
103 |
104 | assert (epoch == epoch_float)
105 | print(f'epoch float {epoch_float} (step {global_step}) - epoch {epoch}')
106 | if scheduler is not None:
107 | scheduler.step()
108 | # evaluation at the end of an epoch
109 | f1_val = evaluation.model_evaluation(net, cfg, device, 'val', epoch_float, global_step)
110 |
111 | if f1_val <= best_f1_val:
112 | trigger_times += 1
113 | if trigger_times > cfg.TRAINER.PATIENCE:
114 | stop_training = True
115 | else:
116 | best_f1_val = f1_val
117 | wandb.log({
118 | 'best val f1': best_f1_val,
119 | 'step': global_step,
120 | 'epoch': epoch_float,
121 | })
122 | print(f'saving network (F1 {f1_val:.3f})', flush=True)
123 | model.save_model(net, epoch, cfg)
124 | trigger_times = 0
125 |
126 | if stop_training:
127 | break
128 |
129 | net = model.load_model(cfg, device)
130 | _ = evaluation.model_evaluation(net, cfg, device, 'test', epoch_float, global_step)
131 |
132 |
133 | if __name__ == '__main__':
134 | args = parsers.training_argument_parser().parse_known_args()[0]
135 | cfg = experiment_manager.setup_cfg(args)
136 |
137 | # make training deterministic
138 | torch.manual_seed(cfg.SEED)
139 | np.random.seed(cfg.SEED)
140 | torch.backends.cudnn.deterministic = True
141 | torch.backends.cudnn.benchmark = False
142 |
143 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144 |
145 | print('=== Running on device:', device)
146 |
147 | wandb.init(
148 | name=cfg.NAME,
149 | config=cfg,
150 | project='ContUrbanCD',
151 | mode='online' if not cfg.DEBUG else 'disabled',
152 | )
153 |
154 | try:
155 | run_training(cfg)
156 | except KeyboardInterrupt:
157 | try:
158 | sys.exit(0)
159 | except SystemExit:
160 | os._exit(0)
161 |
--------------------------------------------------------------------------------
/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from abc import abstractmethod
4 | import numpy as np
5 | import multiprocessing
6 | import tifffile
7 | from utils import experiment_manager, helpers
8 | from utils.experiment_manager import CfgNode
9 | import cv2
10 | import torch
11 | from torchvision import transforms
12 | import numpy as np
13 | from scipy.ndimage import gaussian_filter
14 |
15 |
16 | def create_train_dataset(cfg: CfgNode, run_type: str) -> torch.utils.data.Dataset:
17 | if cfg.DATASET.NAME == 'sn7':
18 | return TrainSpaceNet7Dataset(cfg, run_type=run_type)
19 | elif cfg.DATASET.NAME == 'wusu':
20 | return TrainWUSUDataset(cfg, run_type=run_type)
21 | else:
22 | raise Exception('Unknown train dataset!')
23 |
24 |
25 | def create_eval_dataset(cfg: CfgNode, run_type: str, site: str = None, tiling: int = None) -> torch.utils.data.Dataset:
26 | if cfg.DATASET.NAME == 'sn7':
27 | return EvalSpaceNet7Dataset(cfg, run_type, aoi_id=site, tiling=tiling)
28 | elif cfg.DATASET.NAME == 'wusu':
29 | if run_type == 'test':
30 | return EvalTestWUSUDataset(cfg, site=site)
31 | else:
32 | return EvalWUSUDataset(cfg, run_type=run_type, tiling=tiling)
33 | else:
34 | raise Exception('Unknown eval dataset!')
35 |
36 |
37 | class AbstractSpaceNet7Dataset(torch.utils.data.Dataset):
38 |
39 | def __init__(self, cfg: CfgNode):
40 | super().__init__()
41 | self.cfg = cfg
42 | self.root_path = Path(cfg.PATHS.DATASET)
43 | self.name = cfg.DATASET.NAME
44 |
45 | self.include_alpha = cfg.DATALOADER.INCLUDE_ALPHA
46 | self.pad = cfg.DATALOADER.PAD_BORDERS
47 |
48 | @abstractmethod
49 | def __getitem__(self, index: int) -> dict:
50 | pass
51 |
52 | @abstractmethod
53 | def __len__(self) -> int:
54 | pass
55 |
56 | def load_planet_mosaic(self, aoi_id: str, dataset: str, year: int, month: int) -> np.ndarray:
57 | folder = self.root_path / dataset / aoi_id / 'images_masked'
58 | file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}.tif'
59 | img = tifffile.imread(str(file))
60 | img = img / 255
61 | # 4th band (last one) is the alpha band
62 | if not self.include_alpha:
63 | img = img[:, :, :-1]
64 | m, n, _ = img.shape
65 | if self.pad and (m != 1024 or n != 1024):
66 | # https://www.geeksforgeeks.org/python-opencv-cv2-copymakeborder-method/
67 | img = cv2.copyMakeBorder(img, 0, 1024 - m, 0, 1024 - n, borderType=cv2.BORDER_REPLICATE)
68 | return img.astype(np.float32)
69 |
70 | def load_building_label(self, aoi_id: str, year: int, month: int) -> np.ndarray:
71 | folder = self.root_path / 'train' / aoi_id / 'labels_raster'
72 | file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_Buildings.tif'
73 | label = tifffile.imread(str(file))
74 | m, n = label.shape
75 | if self.pad and (m != 1024 or n != 1024):
76 | label = cv2.copyMakeBorder(label, 0, 1024 - m, 0, 1024 - n, borderType=cv2.BORDER_REPLICATE)
77 | label = label[:, :, None]
78 | label = label > 0
79 | return label.astype(np.float32)
80 |
81 | def load_change_label(self, aoi_id: str, year_t1: int, month_t1: int, year_t2: int, month_t2) -> np.ndarray:
82 | building_t1 = self.load_building_label(aoi_id, year_t1, month_t1)
83 | building_t2 = self.load_building_label(aoi_id, year_t2, month_t2)
84 | change = np.logical_and(building_t1 == 0, building_t2 == 1)
85 | return change.astype(np.float32)
86 |
87 | def load_mask(self, aoi_id: str, year: int, month: int) -> np.ndarray:
88 | folder = self.root_path / 'train' / aoi_id / 'labels_raster'
89 | file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_mask.tif'
90 | mask = tifffile.imread(str(file))
91 | return mask.astype(np.int8)
92 |
93 | def get_aoi_ids(self) -> list:
94 | return list(set([s['aoi_id'] for s in self.samples]))
95 |
96 | def __len__(self):
97 | return self.length
98 |
99 | def __str__(self):
100 | return f'Dataset with {self.length} samples.'
101 |
102 |
103 | class AbstractWUSUDataset(torch.utils.data.Dataset):
104 |
105 | def __init__(self, cfg: experiment_manager.CfgNode):
106 | super().__init__()
107 | self.cfg = cfg
108 | self.root_path = Path(cfg.PATHS.DATASET)
109 | self.timestamps = [15, 16, 18]
110 | self.T = 3
111 | self.name = cfg.DATASET.NAME
112 |
113 | @abstractmethod
114 | def __getitem__(self, index: int) -> dict:
115 | pass
116 |
117 | @abstractmethod
118 | def __len__(self) -> int:
119 | pass
120 |
121 | def load_gf2_img(self, site: str, dataset: str, index: int, year: int) -> np.ndarray:
122 | file = self.root_path / dataset / site / 'imgs' / f'{site}{year}_{index}.tif'
123 | img = tifffile.imread(str(file))
124 | img = img / 255
125 | return img.astype(np.float32)
126 |
127 | def load_lulc_label(self, site: str, dataset: str, index: int, year: int) -> np.ndarray:
128 | file = self.root_path / dataset / site / 'class' / f'{site}{year}_{index}.tif'
129 | lulc_label = tifffile.imread(str(file))
130 | lulc_label = lulc_label[:, :, None]
131 | return lulc_label
132 |
133 | def load_building_label(self, site: str, dataset: str, index: int, year: int) -> np.ndarray:
134 | lulc = self.load_lulc_label(site, dataset, index, year)
135 | buildings = np.logical_or(lulc == 2, lulc == 3)
136 | return buildings.astype(np.float32)
137 |
138 | def load_building_change_label(self, site: str, dataset: str, index: int, year_t1: int, year_t2: int) -> np.ndarray:
139 | buildings_t1 = self.load_building_label(site, dataset, index, year_t1)
140 | buildings_t2 = self.load_building_label(site, dataset, index, year_t2)
141 | change = np.logical_and(buildings_t1 == 0, buildings_t2 == 1)
142 | return change.astype(np.float32)
143 |
144 |
145 | class TrainWUSUDataset(AbstractWUSUDataset):
146 |
147 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, no_augmentations: bool = False):
148 | super().__init__(cfg)
149 |
150 | # handling transformations of data
151 | self.no_augmentations = no_augmentations
152 | self.transform = compose_transformations(cfg, no_augmentations)
153 |
154 | self.dataset = 'test' if run_type == 'test' else 'train'
155 | if run_type == 'test':
156 | self.samples = helpers.load_json(self.root_path / f'samples_test.json')
157 | elif run_type == 'val' or run_type == 'train':
158 | self.samples = helpers.load_json(self.root_path / f'samples_train.json')
159 | self.samples = [s for s in self.samples if s['split'] == run_type]
160 | else:
161 | raise Exception('Unknown run type!')
162 |
163 | manager = multiprocessing.Manager()
164 | self.samples = manager.list(self.samples)
165 |
166 | self.length = len(self.samples)
167 |
168 | def __getitem__(self, index):
169 |
170 | sample = self.samples[index]
171 | site, index = sample['site'], sample['index']
172 |
173 | images = [self.load_gf2_img(site, self.dataset, index, year) for year in self.timestamps]
174 | labels = [self.load_building_label(site, self.dataset, index, year) for year in self.timestamps]
175 |
176 | images, labels = self.transform((np.stack(images), np.stack(labels)))
177 |
178 | item = {
179 | 'x': images,
180 | 'y': labels,
181 | 'site': site,
182 | 'index': index,
183 | }
184 | return item
185 |
186 | def __len__(self):
187 | return self.length
188 |
189 | def __str__(self):
190 | return f'Train {self.name} dataset with {self.length} samples.'
191 |
192 |
193 | class EvalWUSUDataset(AbstractWUSUDataset):
194 |
195 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, tiling: int = None, add_padding: bool = False,
196 | index: int = None):
197 | super().__init__(cfg)
198 |
199 | self.tiling = tiling
200 | self.add_padding = add_padding
201 |
202 | # handling transformations of data
203 | self.transform = compose_transformations(cfg, no_augmentations=True)
204 |
205 | self.dataset = 'test' if run_type == 'test' else 'train'
206 | if run_type == 'test':
207 | samples = helpers.load_json(self.root_path / f'samples_test.json')
208 | elif run_type == 'val' or run_type == 'train':
209 | samples = helpers.load_json(self.root_path / f'samples_train.json')
210 | samples = [s for s in samples if s['split'] == run_type]
211 | else:
212 | raise Exception('Unknown run type!')
213 |
214 | if index is not None:
215 | samples = [s for s in samples if s['index'] == index]
216 |
217 | if tiling is None:
218 | self.tiling = 512
219 |
220 | self.samples = []
221 | for sample in samples:
222 | for i in range(0, 512, self.tiling):
223 | for j in range(0, 512, self.tiling):
224 | tile_sample = {
225 | 'site': sample['site'],
226 | 'index': sample['index'],
227 | 'split': sample['split'],
228 | 'i': i,
229 | 'j': j,
230 | }
231 | self.samples.append(tile_sample)
232 |
233 | manager = multiprocessing.Manager()
234 | self.samples = manager.list(self.samples)
235 |
236 | self.length = len(self.samples)
237 |
238 | def __getitem__(self, index):
239 |
240 | sample = self.samples[index]
241 | site, index, i, j = sample['site'], sample['index'], sample['i'], sample['j']
242 |
243 | images = [self.load_gf2_img(site, self.dataset, index, year) for year in self.timestamps]
244 | images = np.stack(images)
245 | if self.add_padding:
246 | images = np.pad(images, ((0, 0), (self.tiling, self.tiling), (self.tiling, self.tiling), (0, 0)),
247 | mode='reflect')
248 | i_min, j_min = i, j
249 | i_max, j_max = i + 3 * self.tiling, j + 3 * self.tiling
250 | images = images[:, i_min:i_max, j_min:j_max]
251 | else:
252 | images = images[:, i:i + self.tiling, j:j + self.tiling]
253 |
254 | labels = [self.load_building_label(site, self.dataset, index, year) for year in self.timestamps]
255 | labels = np.stack(labels)[:, i:i + self.tiling, j:j + self.tiling]
256 |
257 | images, labels = self.transform((images, labels))
258 |
259 | item = {
260 | 'x': images,
261 | 'y': labels,
262 | 'site': site,
263 | 'index': index,
264 | 'i': i,
265 | 'j': j,
266 | }
267 |
268 | return item
269 |
270 | def __len__(self):
271 | return self.length
272 |
273 | def __str__(self):
274 | return f'Eval {self.name} dataset with {self.length} samples.'
275 |
276 |
277 | class EvalTestWUSUDataset(AbstractWUSUDataset):
278 |
279 | def __init__(self, cfg: experiment_manager.CfgNode, site: str = None):
280 | super().__init__(cfg)
281 |
282 | metadata_file = self.root_path / 'metadata_test.json'
283 | self.metadata = helpers.load_json(metadata_file)
284 |
285 | self.sites = ['JA', 'HS'] if site is None else [site]
286 |
287 | self.transform = compose_transformations(cfg, no_augmentations=True)
288 | self.crop_size = cfg.AUGMENTATION.CROP_SIZE
289 |
290 | self.dataset = 'test'
291 | self.samples = []
292 | for site in self.sites:
293 | tiles = self.metadata[site]['tiles']
294 | tile_size = int(self.metadata[site]['tile_size'] - self.metadata[site]['overlap'])
295 | for tile in tiles:
296 | i_tile, j_tile = tile['i_tile'], tile['j_tile']
297 | m_max, n_max = tile_size, tile_size
298 | if tile['edge_tile']:
299 | m_edge_tile, n_edge_tile = self.metadata[site]['m_edge_tile'], self.metadata[site]['n_edge_tile']
300 | m_edge_tile = m_edge_tile - m_edge_tile % self.crop_size
301 | n_edge_tile = n_edge_tile - n_edge_tile % self.crop_size
302 | if tile['row_end']:
303 | m_max = m_edge_tile
304 | if tile['col_end']:
305 | n_max = n_edge_tile
306 |
307 | for i in range(0, m_max, self.crop_size):
308 | for j in range(0, n_max, self.crop_size):
309 | tile_sample = {
310 | 'site': site,
311 | 'index': tile['index'],
312 | 'split': 'test',
313 | 'i_crop': i,
314 | 'j_crop': j,
315 | 'i_tile': i_tile,
316 | 'j_tile': j_tile,
317 | }
318 | self.samples.append(tile_sample)
319 |
320 | manager = multiprocessing.Manager()
321 | self.samples = manager.list(self.samples)
322 | self.metadata = manager.dict(self.metadata)
323 | self.length = len(self.samples)
324 |
325 | def __getitem__(self, sample_index):
326 |
327 | sample = self.samples[sample_index]
328 | site, index = sample['site'], sample['index']
329 |
330 | images = np.stack([self.load_gf2_img(site, self.dataset, index, year) for year in self.timestamps])
331 | i_crop, j_crop = sample['i_crop'], sample['j_crop']
332 | images = images[:, i_crop:i_crop + self.crop_size, j_crop:j_crop + self.crop_size]
333 |
334 | labels = np.stack([self.load_building_label(site, self.dataset, index, year) for year in self.timestamps])
335 | labels = labels[:, i_crop:i_crop + self.crop_size, j_crop:j_crop + self.crop_size]
336 |
337 | images, labels = self.transform((images, labels))
338 |
339 | i_tile, j_tile = sample['i_tile'], sample['j_tile']
340 | tile_size = int(self.metadata[site]['tile_size'] - self.metadata[site]['overlap'])
341 | i_img, j_img = i_tile * tile_size + i_crop, j_tile * tile_size + j_crop
342 |
343 | item = {
344 | 'x': images,
345 | 'y': labels,
346 | 'site': site,
347 | 'index': index,
348 | 'i_img': i_img,
349 | 'j_img': j_img,
350 | }
351 |
352 | return item
353 |
354 | def get_img_dims(self, site: str) -> tuple:
355 | tile_size = int(self.metadata[site]['tile_size'] - self.metadata[site]['overlap'])
356 | m_tile, n_tile = self.metadata[site]['m_tile'], self.metadata[site]['n_tile']
357 |
358 | m_edge_tile, n_edge_tile = self.metadata[site]['m_edge_tile'], self.metadata[site]['n_edge_tile']
359 | m_edge_tile = m_edge_tile - m_edge_tile % self.crop_size
360 | n_edge_tile = n_edge_tile - n_edge_tile % self.crop_size
361 |
362 | m_img = (m_tile - 1) * tile_size + m_edge_tile
363 | n_img = (n_tile - 1) * tile_size + n_edge_tile
364 | return m_img, n_img
365 |
366 | def __len__(self):
367 | return self.length
368 |
369 | def __str__(self):
370 | return f'Eval {self.name} dataset with {self.length} samples.'
371 |
372 |
373 | class TrainSpaceNet7Dataset(AbstractSpaceNet7Dataset):
374 |
375 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, no_augmentations: bool = False,
376 | disable_multiplier: bool = False):
377 | super().__init__(cfg)
378 |
379 | self.T = cfg.DATALOADER.TIMESERIES_LENGTH
380 | self.include_change_label = cfg.DATALOADER.INCLUDE_CHANGE_LABEL
381 |
382 | # handling transformations of data
383 | self.no_augmentations = no_augmentations
384 | self.transform = compose_transformations(cfg, no_augmentations)
385 |
386 | self.metadata = helpers.load_json(self.root_path / f'metadata_siamesessl.json')
387 | self.aoi_ids = list(cfg.DATASET.TRAIN_IDS)
388 | assert (len(self.aoi_ids) == 60)
389 |
390 | # split
391 | self.run_type = run_type
392 | self.i_split = cfg.DATALOADER.I_SPLIT
393 | self.j_split = cfg.DATALOADER.J_SPLIT
394 |
395 | if not disable_multiplier:
396 | self.aoi_ids = self.aoi_ids * cfg.DATALOADER.TRAINING_MULTIPLIER
397 |
398 | manager = multiprocessing.Manager()
399 | self.aoi_ids = manager.list(self.aoi_ids)
400 | self.metadata = manager.dict(self.metadata)
401 |
402 | self.length = len(self.aoi_ids)
403 |
404 | def __getitem__(self, index):
405 |
406 | aoi_id = self.aoi_ids[index]
407 |
408 | timestamps = [ts for ts in self.metadata[aoi_id] if not ts['mask']]
409 |
410 | t_values = sorted(np.random.randint(0, len(timestamps), size=self.T))
411 | timestamps = sorted([timestamps[t] for t in t_values], key=lambda ts: int(ts['year']) * 12 + int(ts['month']))
412 |
413 | images = [self.load_planet_mosaic(aoi_id, ts['dataset'], ts['year'], ts['month']) for ts in timestamps]
414 | labels = [self.load_building_label(aoi_id, ts['year'], ts['month']) for ts in timestamps]
415 | images = [self.apply_split(img) for img in images]
416 | labels = [self.apply_split(label) for label in labels]
417 |
418 | images, labels = self.transform((np.stack(images), np.stack(labels)))
419 |
420 | item = {
421 | 'x': images,
422 | 'y': labels,
423 | 'aoi_id': aoi_id,
424 | 'dates': [(int(ts['year']), int(ts['month'])) for ts in timestamps],
425 | }
426 |
427 | if self.include_change_label:
428 | labels_ch = []
429 | for t in range(len(timestamps) - 1):
430 | labels_ch.append(torch.ne(labels[t + 1], labels[t]))
431 | labels_ch.append(torch.ne(labels[-1], labels[0]))
432 | item['y_ch'] = torch.stack(labels_ch)
433 |
434 | return item
435 |
436 | def apply_split(self, img: np.ndarray):
437 | if self.run_type == 'train':
438 | return img[:self.i_split]
439 | elif self.run_type == 'val':
440 | return img[self.i_split:, :self.j_split]
441 | elif self.run_type == 'test':
442 | return img[self.i_split:, self.j_split:]
443 | else:
444 | raise Exception('Unknown split!')
445 |
446 | def __len__(self):
447 | return self.length
448 |
449 | def __str__(self):
450 | return f'Train {self.name} dataset with {self.length} samples.'
451 |
452 |
453 | class EvalSpaceNet7Dataset(AbstractSpaceNet7Dataset):
454 |
455 | def __init__(self, cfg: experiment_manager.CfgNode, run_type: str, tiling: int = None, aoi_id: str = None,
456 | add_padding: bool = False):
457 | super().__init__(cfg)
458 |
459 | self.T = cfg.DATALOADER.TIMESERIES_LENGTH
460 | self.include_change_label = cfg.DATALOADER.INCLUDE_CHANGE_LABEL
461 | self.tiling = tiling if tiling is not None else 1024
462 | self.eval_train_threshold = cfg.DATALOADER.EVAL_TRAIN_THRESHOLD
463 | self.add_padding = add_padding
464 |
465 | # handling transformations of data
466 | self.transform = compose_transformations(cfg, no_augmentations=True)
467 |
468 | self.metadata = helpers.load_json(self.root_path / 'metadata_conturbancd.json')
469 |
470 | if aoi_id is None:
471 | self.aoi_ids = list(cfg.DATASET.TRAIN_IDS)
472 | assert (len(self.aoi_ids) == 60)
473 | else:
474 | self.aoi_ids = [aoi_id]
475 |
476 | # split
477 | self.run_type = run_type
478 | self.i_split = cfg.DATALOADER.I_SPLIT
479 | self.j_split = cfg.DATALOADER.J_SPLIT
480 |
481 | self.min_m, self.max_m = 0, 1024
482 | self.min_n, self.max_n = 0, 1024
483 | if run_type == 'train':
484 | self.max_m = self.i_split
485 | self.m = self.max_m
486 | else:
487 | assert (run_type == 'val' or run_type == 'test')
488 | self.min_m = self.i_split
489 | if run_type == 'val':
490 | self.max_n = self.j_split
491 | if run_type == 'test':
492 | self.min_n = self.j_split
493 |
494 | self.samples = []
495 | for aoi_id in self.aoi_ids:
496 | for i in range(self.min_m, self.max_m, self.tiling):
497 | for j in range(self.min_n, self.max_n, self.tiling):
498 | self.samples.append((aoi_id, (i, j)))
499 |
500 | self.m, self.n = self.max_m - self.min_m, self.max_n - self.min_n
501 |
502 | manager = multiprocessing.Manager()
503 | self.aoi_ids = manager.list(self.aoi_ids)
504 | self.metadata = manager.dict(self.metadata)
505 |
506 | self.length = len(self.samples)
507 |
508 | def __getitem__(self, index):
509 |
510 | aoi_id, (i, j) = self.samples[index]
511 |
512 | timestamps = [ts for ts in self.metadata[aoi_id] if not ts['mask']]
513 | t_values = list(np.linspace(0, len(timestamps), self.T, endpoint=False, dtype=int))
514 | timestamps = sorted([timestamps[t] for t in t_values], key=lambda ts: int(ts['year']) * 12 + int(ts['month']))
515 |
516 | images = [self.load_planet_mosaic(ts['aoi_id'], ts['dataset'], ts['year'], ts['month']) for ts in timestamps]
517 | images = np.stack(images)
518 | if self.add_padding:
519 | # images = np.pad(images, ((0, 0), (self.tiling, self.tiling), (self.tiling, self.tiling), (0, 0)),
520 | # mode='constant', constant_values=0)
521 | images = np.pad(images, ((0, 0), (self.tiling, self.tiling), (self.tiling, self.tiling), (0, 0)),
522 | mode='reflect')
523 | i_min, j_min = i, j
524 | i_max, j_max = i + 3 * self.tiling, j + 3 * self.tiling
525 | images = images[:, i_min:i_max, j_min:j_max]
526 | else:
527 | images = images[:, i:i + self.tiling, j:j + self.tiling]
528 |
529 | labels = [self.load_building_label(aoi_id, ts['year'], ts['month']) for ts in timestamps]
530 | labels = np.stack(labels)[:, i:i + self.tiling, j:j + self.tiling]
531 |
532 | images, labels = self.transform((images, labels))
533 |
534 | item = {
535 | 'x': images,
536 | 'y': labels,
537 | 'aoi_id': aoi_id,
538 | 'i': i - self.min_m,
539 | 'j': j - self.min_n,
540 | 'dates': [(int(ts['year']), int(ts['month'])) for ts in timestamps],
541 | }
542 |
543 | return item
544 |
545 | def __len__(self):
546 | return self.length
547 |
548 | def __str__(self):
549 | return f'Eval {self.name} dataset with {self.length} samples.'
550 |
551 |
552 | def compose_transformations(cfg, no_augmentations: bool):
553 | if no_augmentations:
554 | return transforms.Compose([Numpy2Torch()])
555 |
556 | transformations = []
557 |
558 | # cropping
559 | if cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'none':
560 | transformations.append(UniformCrop(cfg.AUGMENTATION.CROP_SIZE))
561 | elif cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'change':
562 | transformations.append(ImportanceRandomCrop(cfg.AUGMENTATION.CROP_SIZE, 'change'))
563 | elif cfg.AUGMENTATION.IMAGE_OVERSAMPLING_TYPE == 'semantic':
564 | transformations.append(ImportanceRandomCrop(cfg.AUGMENTATION.CROP_SIZE, 'semantic'))
565 | else:
566 | raise Exception('Unknown oversampling type!')
567 |
568 | if cfg.AUGMENTATION.RANDOM_FLIP:
569 | transformations.append(RandomFlip())
570 |
571 | if cfg.AUGMENTATION.RANDOM_ROTATE:
572 | transformations.append(RandomRotate())
573 |
574 | if cfg.AUGMENTATION.COLOR_BLUR:
575 | transformations.append(RandomColorBlur())
576 |
577 | transformations.append(Numpy2Torch())
578 |
579 | if cfg.AUGMENTATION.COLOR_JITTER:
580 | transformations.append(RandomColorJitter(n_bands=cfg.MODEL.IN_CHANNELS))
581 |
582 | return transforms.Compose(transformations)
583 |
584 |
585 | class Numpy2Torch(object):
586 | def __call__(self, args):
587 | images, labels = args
588 | images_tensor = torch.Tensor(images).permute(0, 3, 1, 2)
589 | labels_tensor = torch.Tensor(labels).permute(0, 3, 1, 2)
590 | return images_tensor, labels_tensor
591 |
592 |
593 | class RandomFlip(object):
594 | def __call__(self, args):
595 | images, labels = args
596 | horizontal_flip = np.random.choice([True, False])
597 | vertical_flip = np.random.choice([True, False])
598 |
599 | if horizontal_flip:
600 | images = np.flip(images, axis=2)
601 | labels = np.flip(labels, axis=2)
602 |
603 | if vertical_flip:
604 | images = np.flip(images, axis=1)
605 | labels = np.flip(labels, axis=1)
606 |
607 | images = images.copy()
608 | labels = labels.copy()
609 |
610 | return images, labels
611 |
612 |
613 | class RandomRotate(object):
614 | def __call__(self, args):
615 | images, labels = args
616 | k = np.random.randint(1, 4)  # 1, 2, or 3 quarter turns (90 degrees each)
617 | images = np.rot90(images, k, axes=(1, 2)).copy()
618 | labels = np.rot90(labels, k, axes=(1, 2)).copy()
619 | return images, labels
620 |
621 |
622 | class RandomColorBlur(object):
623 | def __call__(self, args):
624 | images, labels = args
625 | for t in range(images.shape[0]):
626 | blurred_image = gaussian_filter(images[t], sigma=np.random.rand() / 2)
627 | images[t] = blurred_image
628 | return images, labels
629 |
630 |
631 | class RandomColorJitter(object):
632 | def __init__(self, brightness: float = 0.3, contrast: float = 0.3, saturation: float = 0.3, hue: float = 0.3,
633 | n_bands: int = 3):
634 | self.cj = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
635 | self.n_bands = n_bands
636 |
637 | def __call__(self, args):
638 | images, labels = args
639 | for t in range(images.shape[0]):
640 | if self.n_bands <= 3:
641 | images[t] = self.cj(images[t])
642 | else:
643 | # Separate the bands
644 | bands = [images[t, i] for i in range(self.n_bands)]
645 |
646 | # Apply ColorJitter to each band
647 | jittered_bands = [self.cj(band.unsqueeze(0)) for band in bands]
648 |
649 | # Stack the bands back together
650 | jittered_image = torch.cat(jittered_bands, dim=0)
651 | images[t] = jittered_image
652 | return images, labels
653 |
654 |
655 | # Performs uniform cropping on images
656 | class UniformCrop(object):
657 | def __init__(self, crop_size: int):
658 | self.crop_size = crop_size
659 |
660 | def random_crop(self, args):
661 | images, labels = args
662 | _, height, width, _ = labels.shape
663 | crop_limit_x = width - self.crop_size
664 | crop_limit_y = height - self.crop_size
665 | x = np.random.randint(0, crop_limit_x + 1)  # inclusive upper bound; avoids an error when the image equals the crop size
666 | y = np.random.randint(0, crop_limit_y + 1)
667 |
668 | images_crop = images[:, y:y + self.crop_size, x:x + self.crop_size]
669 | labels_crop = labels[:, y:y + self.crop_size, x:x + self.crop_size]
670 | return images_crop, labels_crop
671 |
672 | def __call__(self, args):
673 | images, labels = self.random_crop(args)
674 | return images, labels
675 |
676 |
677 | class ImportanceRandomCrop(UniformCrop):
678 | def __init__(self, crop_size: int, oversampling_type: str):
679 | super().__init__(crop_size)
680 | self.oversampling_type = oversampling_type
681 |
682 | def __call__(self, args):
683 |
684 | sample_size = 20  # number of candidate crops to draw
685 | balancing_factor = 5  # smoothing term so crops without change/buildings keep a non-zero weight
686 |
687 | random_crops = [self.random_crop(args) for _ in range(sample_size)]
688 |
689 | if self.oversampling_type == 'change':
690 | crop_weights = np.array(
691 | [np.not_equal(crop_label[-1], crop_label[0]).sum() for _, crop_label in random_crops]
692 | ) + balancing_factor
693 | elif self.oversampling_type == 'semantic':
694 | crop_weights = np.array([crop_label.sum() for _, crop_label in random_crops]) + balancing_factor
695 | else:
696 | raise Exception('Unknown oversampling type!')
697 |
698 | crop_weights = crop_weights / crop_weights.sum()
699 |
700 | sample_idx = np.random.choice(sample_size, p=crop_weights)
701 | img, label = random_crops[sample_idx]
702 |
703 | return img, label
704 |
--------------------------------------------------------------------------------
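The transform classes above are chained with torchvision's transforms.Compose by compose_transformations according to the cfg.AUGMENTATION flags. Below is a minimal sketch (not part of the repository; shapes are purely illustrative) of applying such a chain directly to dummy arrays shaped (T, H, W, C) for images and (T, H, W, 1) for labels:

import numpy as np
from torchvision import transforms
from utils.datasets import Numpy2Torch, RandomFlip, RandomRotate, UniformCrop

# Dummy multi-temporal input: 4 timestamps of 128x128 RGB images and binary labels.
images = np.random.rand(4, 128, 128, 3).astype(np.float32)
labels = (np.random.rand(4, 128, 128, 1) > 0.5).astype(np.float32)

pipeline = transforms.Compose([
    UniformCrop(64),   # one random 64x64 crop shared across all timestamps
    RandomFlip(),      # random horizontal/vertical flips
    RandomRotate(),    # random multiple of 90 degrees
    Numpy2Torch(),     # to tensors of shape (T, C, H, W)
])

x, y = pipeline((images, labels))
print(x.shape, y.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4, 1, 64, 64])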
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import torch
4 | from torch.utils import data as torch_data
5 | from torch import Tensor
6 |
7 | from utils import datasets, helpers
8 | import wandb
9 |
10 | EPS = 10e-05  # small constant to avoid division by zero
11 |
12 |
13 | class AbstractMeasurer(abc.ABC):
14 | def __init__(self, threshold: float = 0.5, name: str = None):
15 |
16 | self.threshold = threshold
17 | self.name = name
18 |
19 | # urban mapping
20 | self.TP_seg_cont = self.TN_seg_cont = self.FP_seg_cont = self.FN_seg_cont = 0
21 | self.TP_seg_fl = self.TN_seg_fl = self.FP_seg_fl = self.FN_seg_fl = 0
22 |
23 | # urban change | ch_cont -> continuous change | ch_fl -> first-to-last change
24 | self.TP_ch_cont = self.TN_ch_cont = self.FP_ch_cont = self.FN_ch_cont = 0
25 | self.TP_ch_fl = self.TN_ch_fl = self.FP_ch_fl = self.FN_ch_fl = 0
26 |
27 | def add_sample(self, *args, **kwargs):
28 | raise NotImplementedError("add_sample method must be implemented in the subclass.")
29 |
30 | def _update_metrics(self, y: Tensor, y_hat: Tensor, attr_name: str, mask: Tensor = None):
31 | y = y.bool()
32 | y_hat = y_hat > self.threshold
33 |
34 | tp_attr = f'TP_{attr_name}'
35 | tn_attr = f'TN_{attr_name}'
36 | fp_attr = f'FP_{attr_name}'
37 | fn_attr = f'FN_{attr_name}'
38 |
39 | tp = (y & y_hat).float()
40 | tn = (~y & ~y_hat).float()
41 | fp = (y_hat & ~y).float()
42 | fn = (~y_hat & y).float()
43 |
44 | if mask is not None:
45 | tp[mask] = float('nan')
46 | tn[mask] = float('nan')
47 | fp[mask] = float('nan')
48 | fn[mask] = float('nan')
49 |
50 | setattr(self, tp_attr, getattr(self, tp_attr) + torch.nansum(tp).float().item())
51 | setattr(self, tn_attr, getattr(self, tn_attr) + torch.nansum(tn).float().item())
52 | setattr(self, fp_attr, getattr(self, fp_attr) + torch.nansum(fp).float().item())
53 | setattr(self, fn_attr, getattr(self, fn_attr) + torch.nansum(fn).float().item())
54 |
55 |
56 | class MultiTaskMeasurer(AbstractMeasurer):
57 | def __init__(self, threshold: float = 0.5, name: str = None):
58 | super().__init__(threshold, name)
59 |
60 | def add_sample(self, y_seg: Tensor, y_hat_seg: Tensor, y_ch: Tensor, y_hat_ch: Tensor, mask: Tensor = None):
61 |
62 | # urban mapping
63 | if y_seg is not None:
64 | self._update_metrics(y_seg, y_hat_seg, 'seg_cont', mask)
65 | self._update_metrics(y_seg[:, [0, -1]], y_hat_seg[:, [0, -1]], 'seg_fl', mask)
66 |
67 | # urban change
68 | if y_hat_ch.size(1) > 1:
69 | self._update_metrics(y_ch[:, :-1], y_hat_ch[:, :-1], 'ch_cont', mask)
70 | self._update_metrics(y_ch[:, -1], y_hat_ch[:, -1], 'ch_fl', mask)
71 |
72 |
73 | def run_quantitative_evaluation(net, cfg, device, run_type: str, enable_mti: bool = False,
74 | mti_edge_setting: str = 'dense') -> MultiTaskMeasurer:
75 | tile_size = cfg.AUGMENTATION.CROP_SIZE
76 | ds = datasets.create_eval_dataset(cfg, run_type, tiling=tile_size)
77 |
78 | net.to(device)
79 | net.eval()
80 |
81 | m = MultiTaskMeasurer()
82 | edges_cyclic = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, 'cyclic')
83 | edges_mti = helpers.get_edges(cfg.DATALOADER.TIMESERIES_LENGTH, mti_edge_setting)
84 |
85 | batch_size = 1 if enable_mti else cfg.TRAINER.BATCH_SIZE
86 | dataloader = torch_data.DataLoader(ds, batch_size=batch_size, num_workers=0, shuffle=False, drop_last=False)
87 |
88 | for step, item in enumerate(dataloader):
89 | x, y_seg = item['x'].to(device), item['y']
90 | y_ch = helpers.get_ch(y_seg, edges_cyclic)
91 | with torch.no_grad():
92 | if enable_mti:
93 | o_seg = net.module.inference(x, edges_mti)
94 | o_ch = helpers.get_ch(o_seg, edges_cyclic)
95 | else:
96 | logits_ch, logits_seg = net(x, edges_cyclic)
97 | o_ch, o_seg = torch.sigmoid(logits_ch).detach(), torch.sigmoid(logits_seg).detach()
98 | m.add_sample(y_seg.cpu(), o_seg.cpu(), y_ch.cpu(), o_ch.cpu())
99 |
100 | return m
101 |
102 |
103 | def model_evaluation(net, cfg, device, run_type: str, epoch: float, step: int) -> float:
104 | m = run_quantitative_evaluation(net, cfg, device, run_type)
105 |
106 | f1_seg_cont = f1_score(m.TP_seg_cont, m.FP_seg_cont, m.FN_seg_cont)
107 | f1_seg_fl = f1_score(m.TP_seg_fl, m.FP_seg_fl, m.FN_seg_fl)
108 | f1_ch_cont = f1_score(m.TP_ch_cont, m.FP_ch_cont, m.FN_ch_cont)
109 | f1_ch_fl = f1_score(m.TP_ch_fl, m.FP_ch_fl, m.FN_ch_fl)
110 | f1 = (f1_seg_cont + f1_seg_fl + f1_ch_cont + f1_ch_fl) / 4
111 |
112 | wandb.log({
113 | f'{run_type} f1': f1,
114 | f'{run_type} f1 seg cont': f1_seg_cont,
115 | f'{run_type} f1 seg fl': f1_seg_fl,
116 | f'{run_type} f1 ch cont': f1_ch_cont,
117 | f'{run_type} f1 ch fl': f1_ch_fl,
118 | 'step': step, 'epoch': epoch,
119 | })
120 |
121 | return f1
122 |
123 |
124 | def precision(tp: int, fp: int) -> float:
125 | return tp / (tp + fp + EPS)
126 |
127 |
128 | def recall(tp: int, fn: int) -> float:
129 | return tp / (tp + fn + EPS)
130 |
131 |
132 | def f1_score(tp: int, fp: int, fn: int) -> float:
133 | p = precision(tp, fp)
134 | r = recall(tp, fn)
135 | return (2 * p * r) / (p + r + EPS)
136 |
137 |
138 | def iou(tp: int, fp: int, fn: int) -> float:
139 | return tp / (tp + fp + fn + EPS)
140 |
--------------------------------------------------------------------------------
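As a quick sanity check, the counter-based measurer above can be exercised on random tensors; with cyclic edges a series of T timestamps yields T change maps (T-1 adjacent pairs plus the first-to-last pair). A minimal sketch, with purely illustrative shapes:

import torch
from utils.evaluation import MultiTaskMeasurer, f1_score

B, T, H, W = 2, 4, 32, 32
y_seg = torch.randint(0, 2, (B, T, 1, H, W)).float()   # ground-truth building maps
y_hat_seg = torch.rand(B, T, 1, H, W)                   # predicted probabilities
y_ch = torch.randint(0, 2, (B, T, 1, H, W)).float()     # change labels for the T cyclic edges
y_hat_ch = torch.rand(B, T, 1, H, W)                    # predicted change probabilities

m = MultiTaskMeasurer(threshold=0.5)
m.add_sample(y_seg, y_hat_seg, y_ch, y_hat_ch)

# Continuous-segmentation F1 from the accumulated TP/FP/FN counts.
print(f1_score(m.TP_seg_cont, m.FP_seg_cont, m.FN_seg_cont))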
/utils/experiment_manager.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from argparse import ArgumentParser
4 | from tabulate import tabulate
5 | from collections import OrderedDict
6 | import yaml
7 | from fvcore.common.config import CfgNode as _CfgNode
8 | from pathlib import Path
9 |
10 |
11 | class CfgNode(_CfgNode):
12 | """
13 | The same as `fvcore.common.config.CfgNode`, but different in:
14 |
15 | 1. Use unsafe yaml loading by default.
16 | Note that this may lead to arbitrary code execution: you must not
17 | load a config file from untrusted sources before manually inspecting
18 | the content of the file.
19 | 2. Support config versioning.
20 | When attempting to merge an old config, it will convert the old config automatically.
21 |
22 | """
23 |
24 | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
25 | # Always allow merging new configs
26 | self.__dict__[CfgNode.NEW_ALLOWED] = True
27 | super(CfgNode, self).__init__(init_dict, key_list, True)
28 |
29 | # Note that the default value of allow_unsafe is changed to True
30 | def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
31 | loaded_cfg = _CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
32 | loaded_cfg = type(self)(loaded_cfg)
33 |
34 | # defaults.py needs to import CfgNode
35 | self.merge_from_other_cfg(loaded_cfg)
36 |
37 |
38 | def new_config():
39 | '''
40 | Creates a new config populated with the default top-level nodes.
41 | :return: a CfgNode ready to be merged with a config file
42 | '''
43 |
44 | C = CfgNode()
45 |
46 | C.CONFIG_DIR = 'config/'
47 |
48 | C.PATHS = CfgNode()
49 | C.TRAINER = CfgNode()
50 | C.MODEL = CfgNode()
51 | C.DATALOADER = CfgNode()
52 | C.AUGMENTATIONS = CfgNode()
53 | C.CONSISTENCY_TRAINER = CfgNode()
54 | C.DATASETS = CfgNode()
55 |
56 | return C.clone()
57 |
58 |
59 | def setup_cfg(args):
60 | cfg = new_config()
61 | cfg.merge_from_file(f'configs/{args.config_file}.yaml')
62 | cfg.merge_from_list(args.opts)
63 | cfg.NAME = args.config_file
64 | cfg.PATHS.ROOT = str(Path.cwd())
65 | assert (Path(args.output_dir).exists())
66 | cfg.PATHS.OUTPUT = args.output_dir
67 | assert (Path(args.dataset_dir).exists())
68 | cfg.PATHS.DATASET = args.dataset_dir
69 | return cfg
70 |
71 |
72 | def setup_cfg_manual(config_name: str, root_dir: Path, output_dir: Path, dataset_dir: Path):
73 | cfg = new_config()
74 | cfg.merge_from_file(root_dir / f'configs/{config_name}.yaml')
75 | cfg.NAME = config_name
76 | cfg.PATHS.ROOT = str(Path.cwd())
77 | assert output_dir.exists()
78 | cfg.PATHS.OUTPUT = str(output_dir)
79 | assert dataset_dir.exists()
80 | cfg.PATHS.DATASET = str(dataset_dir)
81 | return cfg
82 |
83 |
84 | # loading cfg
85 | def load_cfg(config_name: str):
86 | cfg = new_config()
87 | cfg_file = Path.cwd() / 'configs' / f'{config_name}.yaml'
88 | cfg.merge_from_file(str(cfg_file))
89 | cfg.NAME = config_name
90 | return cfg
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
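Configs are plain YAML files under configs/ that get merged into the node tree created above. A minimal sketch of loading one outside the CLI flow with setup_cfg_manual; the output and dataset directories below are placeholders and must already exist, since they are asserted:

from pathlib import Path
from utils.experiment_manager import setup_cfg_manual

cfg = setup_cfg_manual(
    config_name='conturbancd_sn7',        # one of the files in configs/ (without .yaml)
    root_dir=Path.cwd(),
    output_dir=Path('/path/to/output'),   # placeholder; must exist
    dataset_dir=Path('/path/to/dataset'), # placeholder; must exist
)
print(cfg.NAME, cfg.PATHS.OUTPUT, cfg.PATHS.DATASET)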
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from itertools import combinations
4 | from pathlib import Path
5 | from utils.experiment_manager import CfgNode
6 | import json
7 | from typing import Sequence, Tuple
8 |
9 |
10 | def get_aoi_ids(cfg: CfgNode) -> Sequence[str]:
11 | aoi_ids = list(cfg.DATASET.TRAIN_IDS)
12 | return aoi_ids
13 |
14 |
15 | def get_edges(n: int, edge_type: str) -> Sequence[Tuple[int, int]]:
16 | # edges are the timestamp combinations
17 | if edge_type == 'adjacent':
18 | edges = [(t1, t1 + 1) for t1 in range(n - 1)]
19 | elif edge_type == 'cyclic':
20 | edges = [(t1, t1 + 1) for t1 in range(n - 1)]
21 | edges.append((0, n - 1))
22 | elif edge_type == 'dense':
23 | edges = list(combinations(range(n), 2))
24 | elif edge_type == 'firstlast':
25 | edges = [(0, n - 1)]
26 | else:
27 | raise Exception('Unknown edge type!')
28 | return edges
29 |
30 |
31 | def get_ch(seg: Tensor, edges: Sequence[Tuple[int, int]]) -> Tensor:
32 | ch = [torch.ne(seg[:, t1], seg[:, t2]) for t1, t2 in edges]
33 | ch = torch.stack(ch).transpose(0, 1)
34 | return ch
35 |
36 |
37 | def load_json(file: Path) -> dict:
38 | with open(str(file)) as f:
39 | d = json.load(f)
40 | return d
41 |
42 |
43 | def write_json(file: Path, data: dict) -> None:
44 | with open(str(file), 'w', encoding='utf-8') as f:
45 | json.dump(data, f, ensure_ascii=False, indent=4)
46 |
--------------------------------------------------------------------------------
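The edge helpers define which timestamp pairs a change map is computed for. An illustrative sketch of the four edge settings for a four-timestamp series, and of get_ch applied to a random segmentation stack:

import torch
from utils.helpers import get_edges, get_ch

print(get_edges(4, 'adjacent'))   # [(0, 1), (1, 2), (2, 3)]
print(get_edges(4, 'cyclic'))     # [(0, 1), (1, 2), (2, 3), (0, 3)]
print(get_edges(4, 'firstlast'))  # [(0, 3)]
print(get_edges(4, 'dense'))      # all 6 unordered pairs of 4 timestamps

seg = torch.randint(0, 2, (2, 4, 1, 16, 16))   # (B, T, 1, H, W) binary maps
ch = get_ch(seg, get_edges(4, 'cyclic'))       # one change map per edge
print(ch.shape)                                 # torch.Size([2, 4, 1, 16, 16])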
/utils/parsers.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def training_argument_parser():
5 | # https://docs.python.org/3/library/argparse.html#the-add-argument-method
6 | parser = argparse.ArgumentParser(description="Experiment Args")
7 | parser.add_argument('-c', "--config-file", dest='config_file', required=True, help="name of the config file in configs/ (without the .yaml extension)")
8 | parser.add_argument('-o', "--output-dir", dest='output_dir', required=True, help="path to output directory")
9 | parser.add_argument('-d', "--dataset-dir", dest='dataset_dir', default="", required=True,
10 | help="path to dataset directory")
11 | parser.add_argument(
12 | "opts",
13 | help="Modify config options using the command-line",
14 | default=None,
15 | nargs=argparse.REMAINDER,
16 | )
17 | return parser
18 |
19 |
20 | def inference_argument_parser():
21 | # https://docs.python.org/3/library/argparse.html#the-add-argument-method
22 | parser = argparse.ArgumentParser(description="Experiment Args")
23 | parser.add_argument('-c', "--config-file", dest='config_file', required=True, help="name of the config file in configs/ (without the .yaml extension)")
24 | parser.add_argument('-e', "--edge-type", dest='edge_type', default='dense', help="MRF edge type (adjacent, cyclic, dense, or firstlast)")
25 | parser.add_argument('-o', "--output-dir", dest='output_dir', required=True, help="path to output directory")
26 | parser.add_argument('-d', "--dataset-dir", dest='dataset_dir', default="", required=True,
27 | help="path to dataset directory")
28 |
29 | parser.add_argument(
30 | "opts",
31 | help="Modify config options using the command-line",
32 | default=None,
33 | nargs=argparse.REMAINDER,
34 | )
35 | return parser
36 |
37 |
38 | def str2bool(v):
39 | if isinstance(v, bool):
40 | return v
41 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
42 | return True
43 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
44 | return False
45 | else:
46 | raise argparse.ArgumentTypeError('Boolean value expected.')
47 |
--------------------------------------------------------------------------------
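A minimal sketch (placeholder paths) of how the training arguments are typically supplied. Trailing KEY VALUE pairs are collected into args.opts and later applied to the config with cfg.merge_from_list in setup_cfg:

from utils.parsers import training_argument_parser

args = training_argument_parser().parse_args([
    '-c', 'conturbancd_sn7',        # resolved to configs/conturbancd_sn7.yaml
    '-o', '/path/to/output',        # placeholder output directory
    '-d', '/path/to/dataset',       # placeholder dataset directory
    'TRAINER.BATCH_SIZE', '8',      # trailing key/value pairs land in args.opts
])
print(args.config_file, args.opts)  # conturbancd_sn7 ['TRAINER.BATCH_SIZE', '8']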