├── .dockerignore ├── .gitignore ├── Dockerfile ├── Dockerfile.tensorboard ├── LICENSE ├── README.md ├── config └── so_match │ └── experiments │ ├── backbone_asl1_wml.json │ ├── goodness_hm.json │ ├── goodness_opt_vgg.json │ └── goodness_sar_vgg.json ├── create_goodness_dataset.py ├── datasets ├── __init__.py ├── csv_ua_dataset.py ├── sen12_dataset.py └── urban_atlas_dataset.py ├── environment.yml ├── logger └── tensorboard_logger.py ├── losses ├── __init__.py └── functional.py ├── metrics ├── categorical_accuracy_one_hot.py └── custom_metric.py ├── models ├── corr_feature_net.py ├── goodness_net.py └── outlier_reduction_net.py ├── optimizers └── __init__.py ├── samplers ├── __init__.py └── round_robin_batch_sampler.py ├── schedulers ├── __init__.py ├── functional.py └── smith_1cycle_lr.py ├── test.py ├── tools ├── __init__.py ├── compare_all.py ├── create_file_list.py ├── dfc_sen12ms_dataset.py ├── find_feature_points.py ├── lr_loss_plot.py ├── plot_results.py ├── process_full_scene_results.py ├── process_psnet_results.py └── urban_atlas_helpers.py ├── train.py ├── trainer ├── default_trainer.py └── wur_hypercol_trainer.py └── utils ├── __init__.py ├── augmentation.py ├── basic_cache.py ├── experiment.py ├── factory.py ├── geo_tools.py ├── helpers.py └── modules.py /.dockerignore: -------------------------------------------------------------------------------- 1 | results/ 2 | results_semisupervised/ 3 | data/ 4 | *.tar 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # OSX Files 107 | .DS_Store 108 | 109 | # Built docs 110 | /docs/build 111 | 112 | data/* 113 | results/ 114 | notebooks_isprs/* -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime 2 | LABEL maintainer="Lloyd Hughes " 3 | 4 | ############################################################################## 5 | # Upgrade conda, pip and apt-get 6 | ############################################################################## 7 | #RUN conda update conda -y --quiet 8 | #RUN conda install -y pip 9 | #RUN pip install --upgrade pip 10 | RUN apt-get update 11 | 12 | RUN apt-get install -y libfftw3-dev libsm6 libxext6 libxrender-dev 13 | 14 | # RUN apt-get install -y libgl1-mesa-glx 15 | 16 | ENV LC_ALL=C.UTF-8 17 | ENV LANG=C.UTF-8 18 | 19 | #COPY environment.yml /root/environment.yml 20 | #RUN conda update -y -n base -c defaults conda 21 | # RUN conda create -n custom python=3.7 numpy scipy scikit-learn 22 | # RUN conda env update -f /root/environment.yml 23 | 24 | #RUN echo "source activate custom" > ~/.bashrc 25 | #ENV PATH /opt/conda/envs/env/bin:$PATH 26 | 27 | # RUN /bin/bash -c "source activate custom && conda install -y pytorch-nightly -c pytorch" 28 | # RUN /bin/bash -c "source activate custom && conda install -y pytorch=1.1 cuda90 torchvision -c pytorch" 29 | RUN conda install -y -c menpo opencv3 30 | RUN conda install -y dask=0.19.3 scikit-learn 31 | RUN conda install -y -c conda-forge geopandas=0.5.0 libspatialite=4.3.0a libspatialindex=1.9.0 32 | RUN pip install pytorch-ignite==0.2.0 rasterio==1.0.8 tensorboardX torchsummary dotmap==1.3.4 pandas==0.23.4 pyproj==2.1.0 geojson utm==0.4.2 fiona==1.8.0 33 | RUN pip install six numpy scipy Pillow matplotlib scikit-image opencv-python imageio 34 | RUN pip install imgaug visdom 35 | 36 | WORKDIR /src 37 | #RUN pip uninstall -y apex || : 38 | RUN git clone https://github.com/NVIDIA/apex.git 39 | WORKDIR /src/apex 40 | RUN python setup.py install 41 | #RUN /bin/bash -c "source activate custom && python setup.py install" 42 | 43 | 44 | COPY . 
/src 45 | WORKDIR /src 46 | 47 | ENTRYPOINT ["python"] 48 | 49 | # FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime 50 | # LABEL maintainer="Lloyd Hughes " 51 | 52 | # ############################################################################## 53 | # # Upgrade conda, pip and apt-get 54 | # ############################################################################## 55 | # RUN conda update conda -y --quiet 56 | # RUN pip install --upgrade pip 57 | # RUN apt-get update 58 | 59 | # RUN apt-get install -y libfftw3-dev 60 | 61 | # ENV LC_ALL=C.UTF-8 62 | # ENV LANG=C.UTF-8 63 | 64 | # COPY environment.yml /root/environment.yml 65 | # # RUN conda update -y -n base -c defaults conda 66 | # RUN conda env create -f /root/environment.yml -n custom 67 | # # RUN conda env update -f /root/environment.yml 68 | 69 | # RUN echo "source activate custom" > ~/.bashrc 70 | # ENV PATH /opt/conda/envs/env/bin:$PATH 71 | 72 | # # RUN /bin/bash -c "source activate custom && conda install -y pytorch-nightly -c pytorch" 73 | # RUN /bin/bash -c "source activate custom && conda install -y pytorch=1.1 cuda90 torchvision -c pytorch" 74 | 75 | # WORKDIR /src 76 | # RUN pip uninstall -y apex || : 77 | # RUN git clone https://github.com/NVIDIA/apex.git 78 | # WORKDIR /src/apex 79 | # RUN /bin/bash -c "source activate custom && python setup.py install" 80 | 81 | # COPY . /src 82 | # WORKDIR /src 83 | 84 | # ENTRYPOINT ["/opt/conda/envs/custom/bin/python"] 85 | -------------------------------------------------------------------------------- /Dockerfile.tensorboard: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow 2 | RUN mkdir /runs 3 | WORKDIR "/runs" 4 | CMD ["tensorboard", "--logdir=/runs"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Lloyd Hughes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SOMatch 2 | A Framework for Deep Learning-based Sparse SAR-Optical Image Matching [[Paper](https://www.sciencedirect.com/science/article/pii/S0924271620302598?via%3Dihub)] 3 | 4 | ## Building the Docker Image 5 | `docker build -t somatch:latest .` 6 | 7 | ## Datasets and Dataloaders 8 | 9 | ## Training the Matching Network 10 | `docker run -it --rm --runtime=nvidia -v <path/to/data>:/src/data/ -v <path/to/results>:/src/results -e CUDA_VISIBLE_DEVICES=0 --ipc=host somatch:latest train.py --config config/so_match/experiments/backbone_asl1_wml.json` (replace `<path/to/data>` and `<path/to/results>` with absolute paths on the host) 11 | 12 | ## Creating the Goodness Network and ORN dataset 13 | 14 | ## Training the Goodness Networks 15 | 16 | ## Training the Outlier Reduction Network (ORN) 17 | 18 | ## Using the Trained Networks 19 | 20 | ## Citing this work 21 | If you make use of this code, or generate any results with it, please cite the corresponding paper: 22 | 23 | > Hughes, L. H., Marcos, D., Lobry, S., Tuia, D., & Schmitt, M. (2020). A deep learning framework for matching of SAR and optical imagery. ISPRS Journal of Photogrammetry and Remote Sensing, 169, 166-179. 24 | 25 | -------------------------------------------------------------------------------- /config/so_match/experiments/backbone_asl1_wml.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "backbone_asl1_wml", 3 | "desc": "Hypercol matching network", 4 | "result_dir": "results/", 5 | "device": "cuda", 6 | "device_ids": [0], 7 | "epochs": 50, 8 | "nsave": 8, 9 | "resume_from": -1, 10 | "overwrite": true, 11 | "save_freq": 1, 12 | "log_freq": 100, 13 | "verbosity": 0, 14 | "seed": 999, 15 | "logger": { 16 | "TensorboardLogger": { 17 | "log_every": 1, 18 | "log_params": false, 19 | "log_grads": false, 20 | "log_images": true 21 | } 22 | }, 23 | "trainer": "WURHypercolTrainer", 24 | "trainer_config": { 25 | "search_domain": "A", 26 | "fp16": false, 27 | "loss_weights": { 28 | "match": 1, 29 | "spatial_softmax": 1e-4, 30 | "heatmap_l1": 1e-7 31 | } 32 | }, 33 | "loss": { 34 | "Lmatch": ["mse_loss_weighted"] 35 | }, 36 | "datasets": { 37 | "train": { 38 | "type": "UrbanAtlasDatasetSiameseTriplet", 39 | "base_dir": "data", 40 | "cities": [ 41 | "StaraZagora", 42 | "Marseille", 43 | "Faro", 44 | "Bristol", 45 | "Rzeszow", 46 | "Heraklion", 47 | "Wirral", 48 | "Lisbon", 49 | "LeHavre", 50 | "Athen", 51 | "Aveiro", 52 | "Braga" 53 | ], 54 | "batch_size": 16, 55 | "cache_dir": "cache/train", 56 | "cache_size": 50000, 57 | "shuffle": true, 58 | "augment": true, 59 | "workers": 4, 60 | "normalize": {}, 61 | "crop_a": 256, 62 | "crop_b": 128, 63 | "noise": false, 64 | "perc_supervised": 100, 65 | "single_domain": false, 66 | "stretch_contrast": true 67 | }, 68 | "validation": { 69 | "type": "UrbanAtlasDatasetSiameseTriplet", 70 | "base_dir": "data", 71 | "cities": ["Exeter", "Lincoln", "Kalisz"], 72 | "batch_size": 16, 73 | "cache_dir": "cache/validation", 74 | "cache_size": 10000, 75 | "shuffle": false, 76 | "augment": false, 77 | "workers": 4, 78 | "normalize": {}, 79 | "crop_a": 256, 80 | "crop_b": 128, 81 | "noise": false, 82 | "perc_supervised": 100, 83 | "single_domain": false, 84 | "stretch_contrast": true 85 | }, 86 | "test": { 87 | "type": "UrbanAtlasDatasetSiameseTriplet", 88 | "base_dir": "data", 89 | "cities": [ 90 | "Varna", 91 | "Sofia", 92 | "Valetta", 93 | "Kattowitz", 94 | "Portsmouth", 95 | 
"London", 96 | "Leeds", 97 | "Volos" 98 | ], 99 | "batch_size": 32, 100 | "cache_dir": "cache/test", 101 | "cache_size": 10000, 102 | "shuffle": false, 103 | "augment": false, 104 | "workers": 4, 105 | "normalize": {}, 106 | "crop_a": 256, 107 | "crop_b": 128, 108 | "noise": false, 109 | "perc_supervised": 100, 110 | "stretch_contrast": true 111 | } 112 | }, 113 | "optimizer": { 114 | "Fts": { 115 | "models": ["FtsA", "FtsB"], 116 | "Adam": { 117 | "lr": 1e-4, 118 | "betas": [0.9, 0.999], 119 | "weight_decay": 0 120 | } 121 | } 122 | }, 123 | "_scheduler": { 124 | "AE_A": { 125 | "smith_1cycle": { 126 | "scheme": "batch", 127 | "max_lr": 2e-4, 128 | "min_lr": 2e-5, 129 | "anneal_div": 1, 130 | "total_iter": 20000 131 | } 132 | }, 133 | "AE_B": { 134 | "smith_1cycle": { 135 | "scheme": "batch", 136 | "max_lr": 2e-4, 137 | "min_lr": 2e-5, 138 | "anneal_div": 1, 139 | "total_iter": 20000 140 | } 141 | } 142 | }, 143 | "_monitor": { 144 | "score": "loss", 145 | "scale": -1, 146 | "early_stopping": true, 147 | "patience": 30, 148 | "save_score": "loss", 149 | "save_scale": -1 150 | }, 151 | "model": { 152 | "FtsA": { 153 | "CorrelationFeatureNet": { 154 | "column_depth": 256, 155 | "normalize": true, 156 | "no_relu": true, 157 | "attention": true, 158 | "return_attn": false, 159 | "attn_act": "tanh" 160 | } 161 | }, 162 | "FtsB": { 163 | "CorrelationFeatureNet": { 164 | "column_depth": 256, 165 | "normalize": true, 166 | "no_relu": true, 167 | "attention": true, 168 | "return_attn": false, 169 | "attn_act": "tanh" 170 | } 171 | } 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /config/so_match/experiments/goodness_hm.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "goodness_hm", 3 | "desc": "Hypercol matching network", 4 | "result_dir": "results/", 5 | "device": "cuda", 6 | "device_ids": [0], 7 | "epochs": 500, 8 | "nsave": 8, 9 | "resume_from": -1, 10 | "overwrite": true, 11 | "save_freq": 1, 12 | "log_freq": 100, 13 | "verbosity": 0, 14 | "seed": 123, 15 | "logger": { 16 | "TensorboardLogger": { 17 | "log_every": 1, 18 | "log_params": false, 19 | "log_grads": false, 20 | "log_images": true 21 | } 22 | }, 23 | "trainer": "DefaultTrainer", 24 | "trainer_config": { 25 | }, 26 | "loss": { 27 | "Loss": ["binary_cross_entropy_with_logits"] 28 | }, 29 | "datasets": { 30 | "train": { 31 | "type": "CSVUADataset", 32 | "base_dir": "data", 33 | "csv": "checkpoint_50_extracted_dset_train.csv", 34 | "domain": "hm", 35 | "balance": true, 36 | "thresh_loss": [2, 1.2], 37 | "thresh_l2": [1, 2.5], 38 | "batch_size": 64, 39 | "shuffle": true, 40 | "augment": true, 41 | "workers": 4 42 | }, 43 | "validation": { 44 | "type": "CSVUADataset", 45 | "base_dir": "data", 46 | "csv": "checkpoint_50_extracted_dset_validation.csv", 47 | "domain": "hm", 48 | "balance": true, 49 | "thresh_loss": [2, 1.2], 50 | "thresh_l2": [1, 2.5], 51 | "batch_size": 64, 52 | "shuffle": false, 53 | "augment": false, 54 | "workers": 1 55 | }, 56 | "test": { 57 | "type": "CSVUADataset", 58 | "base_dir": "data", 59 | "csv": "checkpoint_50_extracted_dset_test.csv", 60 | "domain": "hm", 61 | "balance": true, 62 | "thresh_loss": [2, 1.2], 63 | "thresh_l2": [1, 2.5], 64 | "batch_size": 64, 65 | "shuffle": false, 66 | "augment": false, 67 | "workers": 1 68 | } 69 | }, 70 | "optimizer": { 71 | "Optim": { 72 | "models": ["Model"], 73 | "Adam": { 74 | "lr": 0.0001, 75 | "betas": [0.9, 0.999], 76 | "weight_decay": 1e-6 77 | } 78 | } 79 | }, 
80 | "_scheduler": { 81 | "Optim": { 82 | "smith_1cycle": { 83 | "scheme": "batch", 84 | "max_lr": 1e-4, 85 | "min_lr": 5e-6, 86 | "anneal_div": 1, 87 | "total_iter": 20000 88 | } 89 | } 90 | }, 91 | "_monitor": { 92 | "score": "loss", 93 | "scale": -1, 94 | "early_stopping": true, 95 | "patience": 30, 96 | "save_score": "loss", 97 | "save_scale": -1 98 | }, 99 | "model": { 100 | "Model": { 101 | "ORN": { 102 | "classes": 1, 103 | "padding": false 104 | } 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /config/so_match/experiments/goodness_opt_vgg.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "goodness_opt_vgg", 3 | "desc": "Hypercol matching network", 4 | "result_dir": "results/", 5 | "device": "cuda", 6 | "device_ids": [0], 7 | "epochs": 1000, 8 | "nsave": 8, 9 | "resume_from": -1, 10 | "overwrite": true, 11 | "save_freq": 1, 12 | "log_freq": 100, 13 | "verbosity": 0, 14 | "seed": 123, 15 | "logger": { 16 | "TensorboardLogger": { 17 | "log_every": 1, 18 | "log_params": false, 19 | "log_grads": false, 20 | "log_images": true 21 | } 22 | }, 23 | "trainer": "DefaultTrainer", 24 | "trainer_config": { 25 | }, 26 | "loss": { 27 | "Loss": ["binary_cross_entropy_with_logits"] 28 | }, 29 | "datasets": { 30 | "train": { 31 | "type": "CSVUADataset", 32 | "base_dir": "data", 33 | "csv": "checkpoint_50_extracted_dset_train.csv", 34 | "domain": "opt_crop", 35 | "balance": true, 36 | "thresh_loss": [2, 1.2], 37 | "thresh_l2": [1, 2.5], 38 | "batch_size": 16, 39 | "shuffle": true, 40 | "augment": true, 41 | "workers": 4 42 | }, 43 | "validation": { 44 | "type": "CSVUADataset", 45 | "base_dir": "data", 46 | "csv": "checkpoint_50_extracted_dset_validation.csv", 47 | "domain": "opt_crop", 48 | "balance": true, 49 | "thresh_loss": [2, 1.2], 50 | "thresh_l2": [1, 2.5], 51 | "batch_size": 16, 52 | "shuffle": false, 53 | "augment": false, 54 | "workers": 1 55 | }, 56 | "test": { 57 | "type": "CSVUADataset", 58 | "base_dir": "data", 59 | "csv": "checkpoint_50_extracted_dset_test.csv", 60 | "domain": "opt_crop", 61 | "balance": true, 62 | "thresh_loss": [2, 1.2], 63 | "thresh_l2": [1, 2.5], 64 | "batch_size": 64, 65 | "shuffle": false, 66 | "augment": false, 67 | "workers": 1 68 | } 69 | }, 70 | "optimizer": { 71 | "Optim": { 72 | "models": ["Model"], 73 | "Adam": { 74 | "lr": 9e-4, 75 | "betas": [0.9, 0.999], 76 | "weight_decay": 0 77 | } 78 | } 79 | }, 80 | "_scheduler": { 81 | "Optim": { 82 | "smith_1cycle": { 83 | "scheme": "batch", 84 | "max_lr": 1e-4, 85 | "min_lr": 5e-6, 86 | "anneal_div": 1, 87 | "total_iter": 20000 88 | } 89 | } 90 | }, 91 | "_monitor": { 92 | "score": "loss", 93 | "scale": -1, 94 | "early_stopping": true, 95 | "patience": 30, 96 | "save_score": "loss", 97 | "save_scale": -1 98 | }, 99 | "model": { 100 | "Model": { 101 | "VGGBasedGoodnessNet": { 102 | "classes": 1, 103 | "padding": false, 104 | "pooling": "avg" 105 | } 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /config/so_match/experiments/goodness_sar_vgg.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "goodness_sar_vgg", 3 | "desc": "Hypercol matching network", 4 | "result_dir": "results/", 5 | "device": "cuda", 6 | "device_ids": [0], 7 | "epochs": 1000, 8 | "nsave": 8, 9 | "resume_from": -1, 10 | "overwrite": true, 11 | "save_freq": 1, 12 | "log_freq": 100, 13 | "verbosity": 0, 14 | "seed": 
123, 15 | "logger": { 16 | "TensorboardLogger": { 17 | "log_every": 1, 18 | "log_params": false, 19 | "log_grads": false, 20 | "log_images": true 21 | } 22 | }, 23 | "trainer": "DefaultTrainer", 24 | "trainer_config": { 25 | }, 26 | "loss": { 27 | "Loss": ["binary_cross_entropy_with_logits"] 28 | }, 29 | "datasets": { 30 | "train": { 31 | "type": "CSVUADataset", 32 | "base_dir": "data", 33 | "csv": "checkpoint_50_extracted_dset_train.csv", 34 | "domain": "sar_crop", 35 | "balance": true, 36 | "thresh_loss": [2, 1.2], 37 | "thresh_l2": [1, 2.5], 38 | "batch_size": 16, 39 | "shuffle": true, 40 | "augment": true, 41 | "workers": 4 42 | }, 43 | "validation": { 44 | "type": "CSVUADataset", 45 | "base_dir": "data", 46 | "csv": "checkpoint_50_extracted_dset_validation.csv", 47 | "domain": "sar_crop", 48 | "balance": true, 49 | "thresh_loss": [2, 1.2], 50 | "thresh_l2": [1, 2.5], 51 | "batch_size": 16, 52 | "shuffle": false, 53 | "augment": false, 54 | "workers": 1 55 | }, 56 | "test": { 57 | "type": "CSVUADataset", 58 | "base_dir": "data", 59 | "csv": "checkpoint_50_extracted_dset_test.csv", 60 | "domain": "sar_crop", 61 | "balance": true, 62 | "thresh_loss": [2, 1.2], 63 | "thresh_l2": [1, 2.5], 64 | "batch_size": 64, 65 | "shuffle": false, 66 | "augment": false, 67 | "workers": 1 68 | } 69 | }, 70 | "optimizer": { 71 | "Optim": { 72 | "models": ["Model"], 73 | "Adam": { 74 | "lr": 1e-4, 75 | "betas": [0.9, 0.999], 76 | "weight_decay": 0 77 | } 78 | } 79 | }, 80 | "_scheduler": { 81 | "Optim": { 82 | "smith_1cycle": { 83 | "scheme": "batch", 84 | "max_lr": 1e-4, 85 | "min_lr": 5e-6, 86 | "anneal_div": 1, 87 | "total_iter": 20000 88 | } 89 | } 90 | }, 91 | "_monitor": { 92 | "score": "loss", 93 | "scale": -1, 94 | "early_stopping": true, 95 | "patience": 30, 96 | "save_score": "loss", 97 | "save_scale": -1 98 | }, 99 | "model": { 100 | "Model": { 101 | "VGGBasedGoodnessNet": { 102 | "classes": 1, 103 | "padding": false, 104 | "pooling": "avg" 105 | } 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /create_goodness_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from argparse import ArgumentParser 4 | 5 | import pandas as pd 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from ignite.engine.engine import Engine, State, Events 11 | from ignite.handlers import ModelCheckpoint, EarlyStopping 12 | from ignite._utils import convert_tensor 13 | 14 | from utils import Experiment 15 | from utils.factory import * 16 | from utils.helpers import static_vars 17 | 18 | import matplotlib.pyplot as plt 19 | import numpy as np 20 | import scipy.io as sio 21 | 22 | from tqdm import tqdm 23 | 24 | logging.basicConfig(level=logging.INFO, format='') 25 | logger = logging.getLogger() 26 | 27 | def save_image(tensor, fname, cmap=plt.cm.jet): 28 | data = tensor.to("cpu").numpy().squeeze(0).squeeze(0) 29 | plt.imsave(fname, data, cmap=cmap) 30 | 31 | def save_numpy(tensor, fname): 32 | data = tensor.to("cpu").numpy() 33 | np.save(fname, data) 34 | 35 | def main(config, dataset="test"): 36 | assert validate_config(config), "ERROR: Config file is invalid. Please see log for details." 
37 | 38 | logger.info("INFO: {}".format(config.toDict())) 39 | 40 | if config.device == "cpu" and torch.cuda.is_available(): 41 | logger.warning("WARNING: Not using the GPU") 42 | 43 | if "cuda" in config.device: 44 | config.device = "cuda" 45 | 46 | assert dataset in config.datasets, "ERROR: No '{}' dataset is specified in the config. Don't know how to proceed.".format(dataset) 47 | 48 | logger.info("INFO: Creating datasets and dataloaders...") 49 | 50 | config.datasets[dataset].update({'shuffle': False, 'augment': False, 'workers': 1}) 51 | config.datasets[dataset].update({'batch_size': 1, "named": True, "return_all": True}) 52 | 53 | # Create the dataset to run inference on 54 | dset_test = create_dataset(config.datasets[dataset]) 55 | 56 | loader_test = get_data_loader(dset_test, config.datasets[dataset]) 57 | 58 | logger.info("INFO: Running inference on {} samples".format(len(dset_test))) 59 | 60 | cp_paths = None 61 | last_epoch = 0 62 | checkpoint_dir = config.result_dir 63 | if 'checkpoint' in config: 64 | checkpoint_dir = config.checkpoint_dir if 'checkpoint_dir' in config else config.result_path 65 | cp_paths, last_epoch = config.get_checkpoints(path=checkpoint_dir, tag=config.checkpoint) 66 | print(f"Found checkpoint {cp_paths} for epoch {last_epoch}") 67 | 68 | if "DetA" in cp_paths: 69 | del cp_paths["DetA"] 70 | del cp_paths["DetB"] 71 | 72 | if "Match" in cp_paths: 73 | del cp_paths["Match"] 74 | 75 | models = {} 76 | for name, model in config.model.items(): 77 | if name in ["DetA", "DetB", "Match"]: 78 | continue 79 | 80 | logger.info("INFO: Building the {} model".format(name)) 81 | models[name] = build_model(model) 82 | 83 | # Load the checkpoint 84 | if name in cp_paths: 85 | models[name].load_state_dict( torch.load( cp_paths[name] ) ) 86 | logger.info("INFO: Loaded model {} checkpoint {}".format(name, cp_paths[name])) 87 | 88 | models[name].to(config.device) 89 | print(models[name]) 90 | 91 | if 'debug' in config and config.debug is True: 92 | print("*********** {} ************".format(name)) 93 | for p_name, param in models[name].named_parameters(): 94 | if param.requires_grad: 95 | print(p_name, param.data) 96 | 97 | losses = {} 98 | for name, fcns in config.loss.items(): 99 | losses[name] = [] 100 | for l in fcns: 101 | losses[name].append( get_loss(l) ) 102 | assert losses[name][-1], "Loss function {} for {} could not be found, please check your config".format(l, name) 103 | 104 | exp_logger = None 105 | if 'logger' in config: 106 | logger.info("INFO: Initialising the experiment logger") 107 | exp_logger = get_experiment_logger(config.result_path, config.logger) 108 | 109 | logger.info("INFO: Creating training manager and configuring callbacks") 110 | trainer = get_trainer(models, None, losses, None, config) 111 | 112 | evaluator_engine = Engine(trainer.evaluate) 113 | 114 | trainer.attach("test_loader", loader_test) 115 | trainer.attach("evaluation_engine", evaluator_engine) 116 | 117 | logger.info("INFO: Starting inference...") 118 | 119 | results = [] 120 | 121 | save_path = os.path.join(config.checkpoint_dir, f"extracted_{last_epoch}", dataset) 122 | os.makedirs(save_path, exist_ok=True) 123 | 124 | with torch.no_grad(): 125 | for i, (xs, ys, names) in enumerate(tqdm(loader_test)): 126 | batch = (xs, ys) 127 | try: 128 | entity = { 129 | "wkt": names["WKT"][0], 130 | "city": names["city"][0], 131 | "shift_x": names["p_match"][0].to("cpu").numpy()[0], 132 | "shift_y": names["p_match"][1].to("cpu").numpy()[0] 133 | } 134 | 135 | filename = "{}_{}".format(names["city"], names["WKT"]) 136 | except KeyError: # names has no UrbanAtlas keys, so fall back to SEN12-style season/scene/patch naming 137 | entity = { 138 | "season": names["season"][0], 139 | "scene": names["scene"][0], 140 | "patch": names["patch"][0], 141 | "shift_x": names["p_match"][0].to("cpu").numpy()[0], 142 | "shift_y": names["p_match"][1].to("cpu").numpy()[0] 143 | } 144 | 145 | filename = "{}_{}_{}".format(names["season"], names["scene"], names["patch"]) 146 | 147 | imgs, hms, y, fts, dets = trainer.infer_batch(batch) 148 | (search_img, template_img, template_hard, search_hard) = imgs 149 | (heatmap_neg, heatmap_neg_raw, heatmap_hneg, heatmap_hneg_raw) = hms 150 | (y_a, y_b, y_bhn) = fts 151 | 152 | d_l2 = trainer.l2_shift_loss(heatmap_hneg, y[0], device="cuda") 153 | 154 | d_target = trainer.weighted_binary_cross_entropy(heatmap_hneg.detach(), y[0], device="cuda", reduction="none") 155 | d_target = torch.mean(d_target, dim=[1,2,3]) 156 | d_target = -torch.log(d_target) 157 | 158 | # Save heatmaps 159 | save_numpy( heatmap_hneg_raw, os.path.join(save_path, f"{filename}_hm.npy") ) 160 | save_numpy( search_img, os.path.join(save_path, f"{filename}_sar.npy") ) 161 | save_numpy( search_hard, os.path.join(save_path, f"{filename}_sar_crop.npy") ) 162 | save_numpy( template_img, os.path.join(save_path, f"{filename}_opt.npy") ) 163 | save_numpy( template_hard, os.path.join(save_path, f"{filename}_opt_crop.npy") ) 164 | save_numpy( y[0], os.path.join(save_path, f"{filename}_gt.npy") ) 165 | 166 | entity.update({ 167 | "l2": d_l2.to("cpu").numpy()[0], 168 | "nlog_match_loss": d_target.to("cpu").numpy()[0] 169 | }) 170 | 171 | results.append(entity) 172 | 173 | if i % 1000 == 0: 174 | df = pd.DataFrame.from_dict(results) 175 | df.to_csv(os.path.join(config.checkpoint_dir, "checkpoint_{}_extracted_dset_{}.csv".format(last_epoch, dataset)), index=None) 176 | 177 | df = pd.DataFrame.from_dict(results) 178 | df.to_csv(os.path.join(config.checkpoint_dir, "checkpoint_{}_extracted_dset_{}.csv".format(last_epoch, dataset)), index=None) 179 | config.save() 180 | 181 | if __name__ == "__main__": 182 | parser = ArgumentParser() 183 | parser.add_argument('-c', '--config', default=None, type=str, required=True, help='config file path (default: None)') 184 | parser.add_argument('--checkpoint', default=None, type=str, help='Checkpoint tag to reload') 185 | parser.add_argument('--checkpoint_dir', default=None, type=str, help='Checkpoint directory to reload') 186 | parser.add_argument('--dataset', default="test", type=str, help="Which dataset to test on") 187 | args = parser.parse_args() 188 | 189 | OVERLOADABLE = ['checkpoint', 'epochs', 'checkpoint_dir', 'resume_from'] 190 | 191 | overloaded = {} 192 | for k, v in vars(args).items(): 193 | if (k in OVERLOADABLE) and (v is not None): 194 | overloaded[k] = v 195 | 196 | config = Experiment.load_from_path(args.config, overloaded) 197 | 198 | print(config.checkpoint) 199 | 200 | assert config, "Config could not be loaded."
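# Example invocation (the checkpoint tag and results directory below are illustrative placeholders, not repo-verified defaults):
#   python create_goodness_dataset.py -c config/so_match/experiments/backbone_asl1_wml.json \
#       --checkpoint <tag> --checkpoint_dir <results_dir> --dataset train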
201 | 202 | main(config, args.dataset) 203 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/system123/SOMatch/6f10cf28f506998a5e430ccd3faab3076fe350d5/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/csv_ua_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchvision import transforms 4 | from torch.utils.data import Dataset 5 | from functools import partial 6 | from skimage.io import imread 7 | from glob import glob 8 | from skimage import exposure, img_as_float, util 9 | from utils.augmentation import Augmentation, cropCenter, toGrayscale, cropCorner, cutout 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import os 14 | 15 | AUG_PROBS = { 16 | "fliplr": 0.5, 17 | "flipud": 0.5, 18 | "scale": 0.1, 19 | "scale_px": (1.1, 1.1), 20 | "translate": 0, 21 | "translate_perc": (0.0, 0.0), 22 | "rotate": 0, 23 | "rotate_angle": (-5, 5), 24 | "contrast": 0.7, 25 | "dropout": 0.8 26 | } 27 | 28 | class CSVUADataset(Dataset): 29 | def __init__(self, config): 30 | super().__init__() 31 | 32 | self.domain = config.domain if isinstance(config.domain, str) else "opt_crop" 33 | self.balance = config.balance if isinstance(config.balance, bool) else False 34 | self.thresh_loss = config.thresh_loss if 'thresh_loss' in config else [0, 12] 35 | self.thresh_l2 = config.thresh_l2 if 'thresh_l2' in config else [1, 2.5] 36 | self.named = config.named if isinstance(config.named, bool) else False 37 | self.normed = config.normed if isinstance(config.normed, bool) else True 38 | 39 | self.base_dir = config.base_dir 40 | self.df = pd.read_csv(os.path.join(self.base_dir, config.csv)) 41 | 42 | dataset_name = os.path.splitext(os.path.basename(config.csv))[0].rsplit("_", 1)[1] 43 | self.img_dir = os.path.join(self.base_dir, dataset_name) 44 | 45 | func = [] 46 | 47 | if config.augment: 48 | # If the value is truthy, fall back to the default augmentation parameters - this keeps things backwards compatible 49 | if config.augment is True or len(config.augment) == 0: 50 | config.augment = AUG_PROBS.copy() 51 | 52 | self.augmentor = Augmentation(probs=config.augment) 53 | else: 54 | self.augmentor = None 55 | 56 | func.append(transforms.ToTensor()) 57 | self.transforms = transforms.Compose(func) 58 | 59 | self._label_and_prune(self.thresh_l2[0], self.thresh_loss[0], self.thresh_l2[1], self.thresh_loss[1]) 60 | 61 | def _label_and_prune(self, l2_pos=1, loss_pos=2.2, l2_neg=2.5, loss_neg=1.2): 62 | self.df["label"] = np.nan 63 | # Label positive and negative samples via the L2 and matching-loss thresholds 64 | self.df.loc[(self.df.l2 <= l2_pos) & (self.df.nlog_match_loss >= loss_pos), "label"] = 1 65 | self.df.loc[(self.df.l2 >= l2_neg) & (self.df.nlog_match_loss <= loss_neg), "label"] = 0 66 | 67 | # Remove all unlabeled points 68 | self.df.dropna(axis=0, inplace=True) 69 | 70 | if self.balance: 71 | limit = min( sum(self.df["label"] == 0), sum(self.df["label"] == 1) ) 72 | limited_df = self.df.groupby("label").apply( lambda x: x.sample(n=limit) ) 73 | limited_df.reset_index(drop=True, inplace=True) 74 | self.df = limited_df.sample(frac=1).reset_index(drop=True) 75 | 76 | def _get_filepath(self, row, img="sar"): 77 | return f"{self.img_dir}/['{row.city}']_['{row.wkt}']_{img}.npy" 78 | 79 | def _load_image(self, row, domain=None): 80 | data = np.load(self._get_filepath(row, img=domain))[0,] 81 | # Put in HxWxC format so data augmentation works 82 | return np.ascontiguousarray(data.transpose((1,2,0))) 83 | 84 | def normalize(self, img): 85 | return (img - img.min())/(img.ptp() + 1e-6) 86 | 87 | def _get_raw_triplet(self, row, crop=False): 88 | suffix = "_crop" if crop else "" 89 | opt = (self.transforms(self._load_image(row, f"opt{suffix}")).numpy().transpose((1,2,0))*255).astype(np.uint8) 90 | sar = (self.normalize(self.transforms(self._load_image(row, f"sar{suffix}")).numpy().transpose((1,2,0)))*255).astype(np.uint8) 91 | y = np.ones_like(sar) * row.label 92 | return sar, opt, y, {"sar": f"{row.city}_{row.name}_sar.png", "opt": f"{row.city}_{row.name}_opt.png", "label": row.label} 93 | 94 | def __len__(self): 95 | return len(self.df) 96 | 97 | def __getitem__(self, index): 98 | row = self.df.iloc[index] 99 | x = self._load_image(row, self.domain) 100 | 101 | name = {"WKT": row.wkt, "city": row.city} 102 | 103 | if self.augmentor: 104 | self.augmentor.refresh_random_state() 105 | x = self.augmentor(x) 106 | 107 | if "sar" in self.domain and self.normed: 108 | x = self.normalize(x) 109 | 110 | if "hm" in self.domain and self.normed: 111 | x = self.normalize(x) 112 | 113 | x = self.transforms(x.copy()).float() 114 | 115 | y = np.array([row.label]) 116 | 117 | if self.named: 118 | return x, y, name 119 | else: 120 | return x, y 121 | 122 | -------------------------------------------------------------------------------- /datasets/sen12_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import Dataset 3 | import torch 4 | import operator 5 | 6 | import torch.nn.functional as F 7 | 8 | from skimage.io import imread 9 | from skimage import exposure, img_as_float 10 | from utils.augmentation import Augmentation, cropCenter, toGrayscale 11 | from utils.visualisation_helpers import plot_side_by_side 12 | from utils.helpers import load_file_list 13 | 14 | from torch.utils.data import SubsetRandomSampler 15 | from samplers.round_robin_batch_sampler import RoundRobinBatchSampler 16 | from itertools import chain, repeat 17 | from utils.basic_cache import BasicCache 18 | from tools.dfc_sen12ms_dataset import DFCSEN12MSDataset, S1Bands, S2Bands 19 | 20 | 21 | import numpy as np 22 | import pandas as pd 23 | import os 24 | 25 | AUG_PROBS = { 26 | "fliplr": 0.4, 27 | "flipud": 0, 28 | "scale": 0, 29 | "scale_px": (1.0, 1.0), 30 | "translate": 0, 31 | "translate_perc": (0.0, 0.0), 32 | "rotate": 0, 33 | "rotate_angle": (-5, 5) 34 | } 35 | 36 | def null_norm(x): 37 | return x 38 | 39 | class SEN12Dataset(Dataset): 40 | def __init__(self, config): 41 | super(SEN12Dataset, self).__init__() 42 | 43 | self.crop_size = config.crop if isinstance(config.crop, (int, float)) else None 44 | self.named = config.named if isinstance(config.named, bool) else False 45 | self.hist_norm = config.hist_norm if isinstance(config.hist_norm, bool) else True 46 | 47 | func = [] 48 | 49 | if config.augment: 50 | # If the value is truthy, fall back to the default augmentation parameters - this keeps things backwards compatible 51 | if config.augment is True or len(config.augment) == 0: 52 | config.augment = AUG_PROBS.copy() 53 | 54 | self.augmentor = Augmentation(probs=config.augment) 55 | func.append(self.augmentor) 56 | else: 57 | self.augmentor = None 58 | 59 | func.append(transforms.Lambda(lambda img: self._preprocess(img, self.crop_size))) 60 | func.append(transforms.ToTensor()) 61 | 62 | self.transforms = transforms.Compose(func) 63 | 64 | if "sar" in config.normalize: 65 | self.sar_norm = transforms.Normalize(mean=[config.normalize.sar[0]], std=[config.normalize.sar[1]]) 66 | else: 67 | self.sar_norm = null_norm 68 | 69 | if "opt" in config.normalize: 70 | self.opt_norm = transforms.Normalize(mean=[config.normalize.opt[0]], std=[config.normalize.opt[1]]) 71 | else: 72 | self.opt_norm = null_norm 73 | 74 | self.cache_dir = config.cache_dir if isinstance(config.cache_dir, str) else None 75 | self.cache_size = config.cache_size if isinstance(config.cache_size, (int, float)) else 0 76 | 77 | if self.cache_dir is not None: 78 | self.cache = BasicCache(self.cache_dir, size=self.cache_size, scheme="fill", clear=False, overwrite=False) 79 | else: 80 | self.cache = None 81 | 82 | self.sar = load_file_list(config.base_dir, config.data_path_supervised[0]) 83 | self.opt = load_file_list(config.base_dir, config.data_path_supervised[1]) 84 | self.labels = np.loadtxt(config.data_path_labels) 85 | 86 | self.limit_supervised = config.limit_supervised if isinstance(config.limit_supervised, int) else -1 87 | if self.limit_supervised > 0 and self.limit_supervised < len(self.sar[0]): 88 | idxs = range(self.limit_supervised) 89 | self.sar[0] = [self.sar[0][i] for i in idxs] 90 | self.opt[0] = [self.opt[0][i] for i in idxs] 91 | self.labels = self.labels[idxs] 92 | 93 | self.noise = config.noise if isinstance(config.noise, bool) else False 94 | 95 | self._get_scenes(seasons=["winter"]) 96 | 97 | def _get_scenes(self, seasons=["winter"]): 98 | scenes = [] 99 | for s, o, l in zip(self.sar, self.opt, self.labels): 100 | if l == 0: 101 | continue 102 | 103 | scenes.append({ 104 | "sar_path": s, 105 | "opt_path": o, 106 | "scene": os.path.splitext(os.path.basename(s))[0].rsplit("_", 1)[0] 107 | }) 108 | 109 | self.df = pd.DataFrame.from_dict(scenes) 110 | self.df = self.df.sort_values("scene").reset_index() 111 | 112 | def _preprocess(self, x, crop=None, stack=False): 113 | x = toGrayscale(x) 114 | 115 | if crop: 116 | x = cropCenter(x, (crop, crop)) 117 | 118 | return x 119 | 120 | def __len__(self): 121 | # Every SAR patch can pair with every optical patch in the same scene, hence n**2 pairs per scene 122 | return np.sum(self.df.groupby("scene").sar.nunique().values**2) 123 | 124 | # TODO: Add hard negative mining as a third dataset option. 
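# Note on _load_and_label (below): it returns the SAR/optical patch pair as single-channel
# float images (rescaled to [0, 1] when hist_norm is enabled), together with the match label
# and both file names, so __getitem__ only has to apply the transform pipeline and the
# per-domain normalisation.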
125 | def _load_and_label(self, index): 126 | 127 | img_sar = img_as_float(imread(self.sar[0][index], as_gray=True, plugin="pil")) 128 | img_opt = img_as_float(imread(self.opt[0][index], as_gray=True, plugin="pil")) 129 | 130 | # Rescale the image to be between 0 and 1 - otherwise normalisation won't work later 131 | if self.hist_norm: 132 | img_sar = exposure.rescale_intensity(img_sar, out_range=(0, 1), in_range='dtype') 133 | img_opt = exposure.rescale_intensity(img_opt, out_range=(0, 1), in_range='dtype') 134 | 135 | if len(img_sar.shape) < 3: 136 | img_sar = np.expand_dims(img_sar, axis=2) 137 | 138 | if len(img_opt.shape) < 3: 139 | img_opt = np.expand_dims(img_opt, axis=2) 140 | 141 | name_sar = os.path.basename(self.sar[0][index]) 142 | name_opt = os.path.basename(self.opt[0][index]) 143 | 144 | y = self.labels[index].astype(np.float) 145 | 146 | return img_sar, img_opt, y, {"name_a": name_sar, "name_b": name_opt} 147 | 148 | def __getitem__(self, index): 149 | # Fix the random state so we get the same transformations 150 | if self.augmentor: 151 | self.augmentor.refresh_random_state() 152 | 153 | img_sar, img_opt, y, names = self._load_and_label(index) 154 | 155 | img_sar = self.sar_norm( self.transforms(img_sar).float() ) 156 | img_opt = self.opt_norm( self.transforms(img_opt).float() ) 157 | 158 | if self.noise: 159 | img_sar = img_sar + 0.01*torch.randn_like(img_sar) + img_sar.mean() 160 | img_opt = img_opt + 0.01*torch.randn_like(img_opt) + img_opt.mean() 161 | 162 | if self.named: 163 | return (img_sar, img_opt), y, names 164 | else: 165 | return (img_sar, img_opt), y 166 | 167 | # def get_batch_sampler(self, batch_size): 168 | # super_idxs = range(len(self.opt[0]), len(self.opt[0]) + len(self.opt[1])) 169 | # unsuper_idxs = range(0, len(self.opt[0])) 170 | 171 | # super_idxs = list(chain.from_iterable( repeat( tuple(super_idxs), self.n) )) 172 | 173 | # self.unsupervised_sampler = SubsetRandomSampler( unsuper_idxs ) 174 | # self.supervised_sampler = SubsetRandomSampler( super_idxs ) 175 | # self.batch_sampler = RoundRobinBatchSampler([self.unsupervised_sampler, self.supervised_sampler], batch_size=batch_size) 176 | 177 | # return self.batch_sampler 178 | 179 | def cropCenterT(img, bounding, shift=(0,0,0,0)): 180 | imshape = [x+y*2 for x,y in zip(img.shape, shift)] 181 | bounding = list(bounding) 182 | start = tuple(map(lambda a, da: a//2-da//2, imshape, bounding)) 183 | end = tuple(map(operator.add, start, bounding)) 184 | slices = tuple(map(slice, start, end)) 185 | return img[slices] 186 | 187 | def get_gt_heatmap(shift_x=0, shift_y=0, w=64, h=64, device="cpu", full_size=False): 188 | x = int(w//2 + shift_x) 189 | y = int(h//2 + shift_y) 190 | hm = np.zeros((h, w)) 191 | hm[y, x] = 1 192 | gt_hm = torch.Tensor(hm[np.newaxis, np.newaxis, :, :]).to(device) 193 | gt_hm = cropCenterT(gt_hm, (1,1,h,w)) 194 | if not full_size: 195 | gt_hm = F.max_pool2d(gt_hm, 3, stride=2) 196 | gt_hm = F.interpolate(gt_hm, size=(h//2+1, w//2+1), align_corners=False, mode='bilinear') 197 | gt_hm.div_(gt_hm.max()) 198 | return gt_hm.numpy()[0,] 199 | 200 | class SEN12DatasetHeatmap(Dataset): 201 | SPLITS = { 202 | "train": {"summer": (0, 0.5), "spring": (0, 1), "autumn": (0, 1)}, 203 | "val": {"summer": (0.5, 1)}, 204 | "test": {"winter": (0, 1)} 205 | } 206 | 207 | 208 | def __init__(self, config): 209 | super().__init__() 210 | 211 | self.crop_size = config.crop if isinstance(config.crop, (int, float)) else None 212 | self.crop_size_a = config.crop_a if isinstance(config.crop_a, (int, float)) 
else None 213 | self.crop_size_b = config.crop_b if isinstance(config.crop_b, (int, float)) else None 214 | self.named = config.named if isinstance(config.named, bool) else False 215 | self.return_all = config.return_all if isinstance(config.return_all, bool) else False 216 | self.stretch_contrast = config.stretch_contrast if isinstance(config.stretch_contrast, bool) else False 217 | self.full_size = config.full_size if isinstance(config.full_size, bool) else False 218 | 219 | self.cache_dir = config.cache_dir if isinstance(config.cache_dir, str) else None 220 | self.cache_size = config.cache_size if isinstance(config.cache_size, (int, float)) else 0 221 | 222 | if self.cache_dir is not None: 223 | self.cache = BasicCache(self.cache_dir, size=self.cache_size, scheme="fill", clear=False, overwrite=False) 224 | else: 225 | self.cache = None 226 | 227 | func = [] 228 | if config.augment: 229 | # If it is true like then just use the default augmentation parameters - this keeps things backwards compatible 230 | if config.augment is True or len(config.augment) == 0: 231 | config.augment = AUG_PROBS.copy() 232 | 233 | self.augmentor = Augmentation(probs=config.augment) 234 | else: 235 | self.augmentor = None 236 | 237 | func.append(transforms.ToTensor()) 238 | self.transforms = transforms.Compose(func) 239 | 240 | self.base_dir = config.base_dir 241 | # Only read the corresponding patches 242 | self.files = np.unique(load_file_list(None, os.path.join(config.base_dir, config.filelist))) 243 | self.dataset = self._make_dataset() 244 | 245 | def _make_dataset(self): 246 | dataset = [] 247 | 248 | for f in self.files: 249 | f = os.path.basename(f) 250 | sar = os.path.join(self.base_dir, "sar", f) 251 | opt = os.path.join(self.base_dir, "opt", f) 252 | _, season, scene, patch = os.path.splitext(f)[0].split("_") 253 | 254 | dataset.append( (sar, opt, season, scene, patch) ) 255 | 256 | return dataset 257 | 258 | 259 | # self.sen12 = DFCSEN12MSDataset(config.base_dir) 260 | # self.split = config.split if isinstance(config.split, str) else "train" 261 | # self.dataset = self._get_split(self.split) 262 | 263 | # def _get_split(self, split): 264 | # dataset = [] 265 | 266 | # for season, ratio in self.SPLITS[split].items(): 267 | # scene_ids = self.sen12.get_scene_ids(season) 268 | # start, end = int(ratio[0]*len(scene_ids)), int(ratio[1]*len(scene_ids)) 269 | 270 | # scene_ids = scene_ids[start:end] 271 | # for scene_id in scene_ids: 272 | # patch_ids = get_patch_ids(season, scene_id) 273 | # items = [(season, scene_id, patch_id) for patch_id in patch_ids] 274 | # dataset.extend(items) 275 | 276 | # return dataset 277 | 278 | def get_gt_heatmap(self, shift_x=0, shift_y=0, w=64, h=64, sigma=1): 279 | x = int(w//2 + shift_x) 280 | y = int(h//2 + shift_y) 281 | 282 | hm = np.zeros((h, w)) 283 | hm[y, x] = 1 284 | 285 | return hm[np.newaxis, :, :] 286 | 287 | def __len__(self): 288 | return len(self.dataset) 289 | 290 | def _cache_key(self, index): 291 | _, _, season, scene, patch = self.dataset[index] 292 | return f"{season}{scene}{patch}" 293 | 294 | def _try_cache(self, index): 295 | if self.cache is not None: 296 | # Try get data for the point 297 | data = self.cache[self._cache_key(index)] 298 | 299 | if data is not None: 300 | # Hack as we have a dict in a 0-d numpy array 301 | data["INFO"] = data["INFO"].item() 302 | 303 | return data 304 | 305 | def _load_and_label(self, index): 306 | data = self._try_cache(index) 307 | 308 | if not data: 309 | data = {"SAR": None, "OPT": None, "INFO": None} 310 | 311 
| sar, opt, season, scene, patch = self.dataset[index] 312 | # data["SAR"], data["OPT"], _ = self.sen12.get_s1_s2_pair(season, scene, patch, s1_bands=S1Bands.VV, s2_bands=S2Bands.RGB) 313 | 314 | data["SAR"] = img_as_float(imread(sar, as_gray=True, plugin="pil")) 315 | data["OPT"] = img_as_float(imread(opt, as_gray=True, plugin="pil")) 316 | data["INFO"] = {"season": season, "scene": scene, "patch": patch} 317 | 318 | if self.cache is not None: 319 | self.cache[self._cache_key(index)] = data 320 | 321 | return data["SAR"], data["OPT"], data["INFO"] 322 | 323 | def __getitem__(self, index): 324 | if self.augmentor: 325 | self.augmentor.refresh_random_state() 326 | 327 | img_sar, img_opt, img_info = self._load_and_label(index) 328 | 329 | if self.augmentor is not None: 330 | img_sar = self.augmentor(img_sar) 331 | img_opt = self.augmentor(img_opt) 332 | 333 | assert self.crop_size_a <= img_sar.shape[1], "The input image is too small to crop" 334 | assert self.crop_size_b <= img_opt.shape[1], "The input image is too small to crop" 335 | 336 | if self.full_size: 337 | fa_sz = self.crop_size_a 338 | fb_sz = self.crop_size_b 339 | else: 340 | fa_sz = (self.crop_size_a - 6)//2 - 1 341 | fb_sz = (self.crop_size_b - 6)//2 - 1 342 | 343 | hm_size = np.abs(fa_sz - fb_sz) + 1 344 | 345 | # We are already centred, so we can only shift by half of the radius (hence the / 4) 346 | max_shift = min(fa_sz//4, fb_sz//4) 347 | shift_x = (2*np.random.randint(2) - 1)*(np.random.randint(max_shift) + 1) 348 | shift_y = (2*np.random.randint(2) - 1)*(np.random.randint(max_shift) + 1) 349 | 350 | if self.crop_size_a > self.crop_size_b: 351 | if img_sar.shape[1] - self.crop_size_a > 0: 352 | # Also ensure we don't shift the keypoint out of the search region 353 | max_shift = min((fa_sz - fb_sz)//4, max_shift) 354 | max_shift_x = min((fa_sz - fb_sz)//4 - shift_x//2, max_shift) 355 | max_shift_y = min((fa_sz - fb_sz)//4 - shift_y//2, max_shift) 356 | shift_x_s = (2*np.random.randint(2) - 1)*(np.random.randint(max_shift_x)) 357 | shift_y_s = (2*np.random.randint(2) - 1)*(np.random.randint(max_shift_y)) 358 | else: 359 | shift_x_s = 0 360 | shift_y_s = 0 361 | 362 | search_img = np.ascontiguousarray(cropCenter(img_sar, (self.crop_size_a, self.crop_size_a), (shift_x_s, shift_y_s))) 363 | template_img = np.ascontiguousarray(cropCenter(img_opt, (self.crop_size_a, self.crop_size_a), (shift_x_s, shift_y_s))) 364 | search_hard = np.ascontiguousarray(cropCenter(img_sar, (self.crop_size_b, self.crop_size_b), (shift_x, shift_y))) 365 | template_hard = np.ascontiguousarray(cropCenter(img_opt, (self.crop_size_b, self.crop_size_b), (shift_x, shift_y))) 366 | 367 | if self.stretch_contrast: 368 | search_img = (search_img - search_img.min())/(search_img.ptp()) 369 | 370 | search_img = self.transforms(search_img).float() 371 | template_img = self.transforms(template_img).float() 372 | search_hard = self.transforms(search_hard).float() 373 | template_hard = self.transforms(template_hard).float() 374 | else: 375 | if img_opt.shape[1] - self.crop_size_b > 0: 376 | # Also ensure we don't shift the keypoint out of the search region 377 | max_shift_x = min((fb_sz - fa_sz)//4 - shift_x//2, max_shift) 378 | max_shift_y = min((fb_sz - fa_sz)//4 - shift_y//2, max_shift) 379 | shift_x_s = (2*np.random.randint(2) - 1)*(np.random.randint(max_shift_x)) 380 | shift_y_s = (2*np.random.randint(2) - 1)*(np.random.randint(max_shift_y)) 381 | else: 382 | shift_x_s = 0 383 | shift_y_s = 0 384 | 385 | search_img = cropCenter(img_opt, (self.crop_size_b, 
self.crop_size_b), (shift_x_s, shift_y_s)) 386 | template_img = cropCenter(img_sar, (self.crop_size_b, self.crop_size_b), (shift_x_s, shift_y_s)) 387 | search_hard = cropCenter(img_opt, (self.crop_size_a, self.crop_size_a), (shift_x, shift_y)) 388 | template_hard = cropCenter(img_sar, (self.crop_size_a, self.crop_size_a), (shift_x, shift_y)) 389 | 390 | if self.stretch_contrast: 391 | template_img = (template_img - template_img.min())/(template_img.ptp()) 392 | 393 | search_img = self.transforms(search_img).float() 394 | template_img = self.transforms(template_img).float() 395 | search_hard = self.transforms(search_hard).float() 396 | template_hard = self.transforms(template_hard).float() 397 | 398 | # print(f"a: {shift_x} b: {shift_y} hm: {shift_x_s} ca: {shift_y_s}") 399 | # This is dependant on the Model!!!!!!!!!! We should move this there 400 | # print("WARNING THIS IS DEPENDANT ON THE MODEL") 401 | shift_x = shift_x - shift_x_s 402 | shift_y = shift_y - shift_y_s 403 | 404 | if self.full_size: 405 | scale = 1 406 | else: 407 | scale = ((1 - 6/search_img.shape[1])*(3/2)) 408 | 409 | # y_p = self.get_gt_heatmap(w=hm_size, h=hm_size, sigma=None) 410 | y_hn = self.get_gt_heatmap(shift_x=shift_x//scale, shift_y=shift_y//scale, w=hm_size, h=hm_size, sigma=None) 411 | # y_hn = get_gt_heatmap(shift_x, shift_y, hm_size, hm_size, full_size=False) 412 | # print(f"HEATMAP: {y_hn.shape} {shift_x} {shift_y}") 413 | 414 | # y_p = y_p/y_p.max() 415 | y_hn = y_hn/y_hn.max() 416 | 417 | # y = np.array([0, shift_x, shift_y], dtype=np.float32) 418 | 419 | if self.return_all: 420 | imgs = (search_img, template_img, template_hard, search_hard) 421 | else: 422 | imgs = (search_img, template_img, template_hard) 423 | 424 | if self.named: 425 | cx = int(hm_size//2 + shift_x//scale) 426 | cy = int(hm_size//2 + shift_y//scale) 427 | img_info.update({ 428 | "p_match": (cx, cy), 429 | "shift": (shift_x, shift_y) 430 | }) 431 | 432 | return imgs, y_hn, img_info 433 | else: 434 | return imgs, y_hn -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: custom 2 | channels: 3 | - menpo 4 | - conda-forge 5 | - pytorch 6 | - defaults 7 | dependencies: 8 | - asn1crypto=0.24.0 9 | - astroid=2.2.5 10 | - atk=2.25.90 11 | - boost-cpp=1.68.0 12 | - bzip2=1.0.6 13 | - c-ares=1.15.0 14 | - ca-certificates=2019.3.9 15 | - certifi=2019.3.9 16 | - cffi=1.12.3 17 | - cfitsio=3.430 18 | - chardet=3.0.4 19 | - cloudpickle=0.6.0 20 | - cmake=3.14.4 21 | - cryptography=2.6.1 22 | - cryptography-vectors=2.6.1 23 | - cudatoolkit=9.0 24 | - curl=7.64.1 25 | - cycler=0.10.0 26 | - cython=0.29.9 27 | - cytoolz=0.9.0.1 28 | - dask-core=0.19.3 29 | - dbus=1.13.6 30 | - decorator=4.3.0 31 | - descartes=1.1.0 32 | - expat=2.2.5 33 | - ffmpeg=4.1.3 34 | - fontconfig=2.13.1 35 | - freetype=2.10.0 36 | - freexl=1.0.5 37 | - gdal=2.4.1 38 | - gdk-pixbuf=2.36.12 39 | - geojson=2.4.1 40 | - geopandas=0.5.0 41 | - geos=3.7.1 42 | - geotiff=1.5.1 43 | - gettext=0.19.8.1 44 | - giflib=5.1.7 45 | - glib=2.58.3 46 | - gmp=6.1.2 47 | - gnutls=3.6.5 48 | - gobject-introspection=1.58.2 49 | - graphite2=1.3.13 50 | - gst-plugins-base=1.14.4 51 | - gstreamer=1.14.4 52 | - gtk2=2.24.31 53 | - harfbuzz=2.4.0 54 | - hdf4=4.2.13 55 | - hdf5=1.10.4 56 | - icu=58.2 57 | - idna=2.8 58 | - imageio=2.5.0 59 | - intel-openmp=2019.0 60 | - isort=4.3.20 61 | - jasper=1.900.1 62 | - joblib=0.13.2 63 | - jpeg=9c 64 | - json-c=0.13.1 65 | - 
kealib=1.4.10 66 | - kiwisolver=1.1.0 67 | - krb5=1.16.3 68 | - lame=3.100 69 | - lazy-object-proxy=1.4.1 70 | - libblas=3.8.0 71 | - libcblas=3.8.0 72 | - libcurl=7.64.1 73 | - libdap4=3.19.1 74 | - libedit=3.1.20170329 75 | - libffi=3.2.1 76 | - libgcc=7.2.0 77 | - libgcc-ng=8.2.0 78 | - libgdal=2.4.1 79 | - libgfortran=3.0.0 80 | - libgfortran-ng=7.3.0 81 | - libiconv=1.15 82 | - libkml=1.3.0 83 | - liblapack=3.8.0 84 | - liblapacke=3.8.0 85 | - libnetcdf=4.6.2 86 | - libpng=1.6.37 87 | - libpq=11.3 88 | - libprotobuf=3.6.1 89 | - libsodium=1.0.16 90 | - libspatialindex=1.9.0 91 | - libspatialite=4.3.0a 92 | - libssh2=1.8.2 93 | - libstdcxx-ng=8.2.0 94 | - libtiff=4.0.10 95 | - libuuid=2.32.1 96 | - libuv=1.29.1 97 | - libwebp=1.0.2 98 | - libxcb=1.13 99 | - libxml2=2.9.9 100 | - mapclassify=2.0.1 101 | - markdown=2.6.11 102 | - matplotlib=3.1.0 103 | - matplotlib-base=3.1.0 104 | - mccabe=0.6.1 105 | - mkl=2019.3 106 | - ncurses=6.1 107 | - nettle=3.4.1 108 | - networkx=2.2 109 | - ninja=1.8.2 110 | - numpy=1.16.3 111 | - numpy-base=1.16.3 112 | - olefile=0.46 113 | - opencv3=3.1.0 114 | - openh264=1.8.0 115 | - openjpeg=2.3.1 116 | - openssl=1.1.1b 117 | - pango=1.40.14 118 | - pcre=8.41 119 | - pillow=6.0.0 120 | - pip=19.1 121 | - pixman=0.34.0 122 | - poppler=0.67.0 123 | - poppler-data=0.4.9 124 | - postgresql=11.3 125 | - proj4=6.0.0 126 | - pthread-stubs=0.4 127 | - pycparser=2.19 128 | - pylint=2.3.1 129 | - pyopenssl=19.0.0 130 | - pyparsing=2.2.2 131 | - pysocks=1.7.0 132 | - python=3.6.7 133 | - python-dateutil=2.7.3 134 | - pytorch=1.1.0 135 | - pywavelets=1.0.3 136 | - pyzmq=18.0.1 137 | - qt=5.9.7 138 | - readline=7.0 139 | - requests=2.22.0 140 | - rhash=1.3.6 141 | - rtree=0.8.3 142 | - scikit-image=0.15.0 143 | - scipy=1.3.0 144 | - setuptools=41.0.1 145 | - shapely=1.6.4 146 | - sip=4.19.8 147 | - six=1.12.0 148 | - tk=8.6.9 149 | - toolz=0.9.0 150 | - torchvision=0.3.0 151 | - tornado=6.0.2 152 | - typed-ast=1.3.5 153 | - tzcode=2018g 154 | - unzip=6.0 155 | - urllib3=1.24.3 156 | - websocket-client=0.56.0 157 | - wheel=0.33.4 158 | - wrapt=1.11.1 159 | - x264=1!152.20180806 160 | - xerces-c=3.2.2 161 | - xorg-kbproto=1.0.7 162 | - xorg-libice=1.0.9 163 | - xorg-libsm=1.2.3 164 | - xorg-libx11=1.6.7 165 | - xorg-libxau=1.0.9 166 | - xorg-libxdmcp=1.1.3 167 | - xorg-libxext=1.3.4 168 | - xorg-libxrender=0.9.10 169 | - xorg-libxt=1.1.5 170 | - xorg-renderproto=0.11.1 171 | - xorg-xextproto=7.3.0 172 | - xorg-xproto=7.0.31 173 | - xz=5.2.4 174 | - yaml=0.1.7 175 | - zeromq=4.2.5 176 | - zlib=1.2.11 177 | - zstd=1.3.3 178 | - pip: 179 | - absl-py==0.6.1 180 | - affine==2.2.1 181 | - argh==0.26.2 182 | - astor==0.7.1 183 | - attrs==18.2.0 184 | - beautifulsoup4==4.6.3 185 | - bleach==1.5.0 186 | - bolt-python==0.7.1 187 | - boto==2.49.0 188 | - boto3==1.9.43 189 | - botocore==1.12.43 190 | - bs4==0.0.1 191 | - click==7.0 192 | - click-plugins==1.0.4 193 | - cligj==0.5.0 194 | - docutils==0.14 195 | - dotmap==1.2.39 196 | - enum34==1.1.6 197 | - eo-learn==0.3.3 198 | - eo-learn-core==0.3.2 199 | - eo-learn-coregistration==0.3.2 200 | - eo-learn-features==0.3.3 201 | - eo-learn-geometry==0.3.3 202 | - eo-learn-io==0.3.3 203 | - eo-learn-mask==0.3.2 204 | - eo-learn-ml-tools==0.3.2 205 | - fiona==1.8.0 206 | - flask==1.0.2 207 | - flask-cors==3.0.7 208 | - gast==0.2.0 209 | - geomet==0.2.0.post2 210 | - gevent==1.3.7 211 | - gitdb2==2.0.5 212 | - gitpython==2.1.11 213 | - gputil==1.3.0 214 | - gql==0.1.0 215 | - graphql-core==2.1 216 | - greenlet==0.4.15 217 | - 
grpcio==1.16.1 218 | - h5py==2.9.0rc1 219 | - html2text==2018.1.9 220 | - html5lib==0.9999999 221 | - imgaug==0.2.6 222 | - itsdangerous==1.1.0 223 | - jinja2==2.10 224 | - jmespath==0.9.3 225 | - lightgbm==2.2.2 226 | - llvmlite==0.26.0 227 | - markupsafe==1.1.0 228 | - munch==2.3.2 229 | - nibabel==2.3.1 230 | - numba==0.41.0 231 | - nvidia-ml-py3==7.352.0 232 | - opencv-contrib-python==3.4.2.16 233 | - opencv-contrib-python-headless==3.4.3.18 234 | - pandas==0.23.4 235 | - pathtools==0.1.2 236 | - phasepack==1.5 237 | - promise==2.2.1 238 | - protobuf==3.6.1 239 | - psutil==5.4.8 240 | - pydot==1.2.4 241 | - pyfftw==0.10.4 242 | - pygit==0.1 243 | - pygments==2.2.0 244 | - pyproj==2.1.0 245 | - pytorch-ignite==0.2.0 246 | - pytz==2018.5 247 | - pyyaml==4.2b4 248 | - rasterio==1.0.8 249 | - rx==1.6.1 250 | - s2cloudless==1.2.1 251 | - s3transfer==0.1.13 252 | - seaborn==0.9.0 253 | - sentinelhub==2.4.5 254 | - sentinelsat==0.13 255 | - sentry-sdk==0.6.9 256 | - sh==1.12.14 257 | - shortuuid==0.5.0 258 | - smmap2==2.0.5 259 | - snuggs==1.4.2 260 | - subprocess32==3.5.3 261 | - tensorboardx==1.4 262 | - tensorflow==1.6.0 263 | - termcolor==1.1.0 264 | - thunder-python==1.4.2 265 | - thunder-registration==1.0.1 266 | - tifffile==2018.11.6 267 | - torchsummary==1.5.1 268 | - tqdm==4.27.0 269 | - utm==0.4.2 270 | - wandb==0.6.33 271 | - watchdog==0.9.0 272 | - werkzeug==0.14.1 273 | - visdom 274 | 275 | -------------------------------------------------------------------------------- /logger/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import numpy as np 5 | import torchvision.utils as vutils 6 | 7 | from tensorboardX import SummaryWriter 8 | from datetime import datetime 9 | from utils.helpers import get_learning_rate 10 | 11 | class TensorboardLogger: 12 | 13 | def __init__(self, log_every=10, log_params=False, log_dir=None, log_images=False, log_grads=False, **kwargs): 14 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 15 | self.log_dir = os.path.join(log_dir, "runs", current_time) 16 | self.writer = SummaryWriter(log_dir=self.log_dir) 17 | 18 | self.counters = {"evaluate": 0, "train": 0, "test": 0} 19 | self.epochs = {"evaluate": 0, "train": 0, "test": 0} 20 | self.log_every = log_every 21 | self.log_params = log_params if isinstance(log_params, bool) else False 22 | self.log_images = log_images if isinstance(log_images, bool) else False 23 | self.log_grads = log_grads if isinstance(log_grads, bool) else False 24 | 25 | print(f"Logger: Log parameters={log_params}, Log gradients={log_grads}") 26 | 27 | # def state_dict(self): 28 | # state = {} 29 | # state['counters'] = self.counters 30 | # state['epochs'] = self.epochs 31 | # return {'state': state} 32 | 33 | def fast_forward(self, last_epoch=0, step_per_epoch=0): 34 | step = (last_epoch+1)*step_per_epoch 35 | self.counters = {"evaluate": step, "train": step, "test": step} 36 | self.epochs = {"evaluate": last_epoch+1, "train": last_epoch+1, "test": last_epoch+1} 37 | 38 | def teardown(self): 39 | self.writer.export_scalars_to_json(os.path.join(self.log_dir, "all_scalars.json")) 40 | self.writer.close() 41 | 42 | def add_embedding(self, features, images, phase="train", stage="epoch"): 43 | step = self.epochs[phase] if stage == "epoch" else self.counters[phase] 44 | self.writer.add_embedding(features, label_img=images, global_step=step) 45 | 46 | def _plot_metrics(self, metrics, phase, step): 47 | for m_name, m_val in metrics.items(): 
48 | self.writer.add_scalar("{}/{}".format(phase, m_name), m_val, step)
49 |
50 | def log_gradients(self, tag, model, phase="train", log_every=1000):
51 | if (self.log_grads is True) and (self.counters[phase] % log_every == 0):
52 | for name, param in model.named_parameters():
53 | if param.grad is not None:
54 | self.writer.add_histogram("{}_{}".format(tag, name), param.grad.data.cpu().numpy(), self.counters[phase])
55 |
56 | def log_preactivations(self, module, phase="train"):
57 | classname = module.__class__.__name__
58 |
59 | def _log_preactivations(input, output):
60 | self.writer.add_histogram("{}_{}".format(classname, "forward"), output.data.cpu().numpy(), self.counters[phase])
61 |
62 | if classname.find('Conv') != -1 or classname.find('Linear') != -1:
63 | module.register_forward_hook(_log_preactivations)
64 |
65 | def log_image_grid(self, name, images, phase="train", normalize=True):
66 | if self.log_images is True:
67 | x_rg = vutils.make_grid(images, normalize=normalize, scale_each=True)
68 | self.writer.add_image(name, x_rg, self.counters[phase])
69 |
70 | # Method missing: assume the call is intended for the underlying SummaryWriter and delegate it
71 | def __getattr__(self, method_name):
72 | log_fn = getattr(self.writer, method_name, None)
73 |
74 | if log_fn:
75 | return log_fn
76 | else:
77 | raise AttributeError(method_name)
78 |
79 | def log_iteration(self, engine, phase="train", models=None, optims=None):
80 | # other_metrics = {}
81 | if optims:
82 | for name, optim in optims.items():
83 | lr = get_learning_rate(optim)[0]
84 | self.writer.add_scalar("{}/{}_lr".format(phase, name), lr, self.counters[phase])
85 |
86 | if self.counters[phase] % self.log_every == 0:
87 | self._plot_metrics(engine.state.metrics, phase, self.counters[phase])
88 | # self._plot_metrics(other_metrics, phase, self.counters[phase])
89 |
90 | self.counters[phase] += 1
91 |
92 | def log_epoch(self, engine, phase="train", models=None, optims=None):
93 | self._plot_metrics(engine.state.metrics, phase, self.counters[phase])
94 |
95 | if phase == "train" and self.log_params is True:
96 | for m_name, model in models.items():
97 | for name, param in model.named_parameters():
98 | self.writer.add_histogram("{}_{}".format(m_name, name), param.data.cpu().numpy(), self.epochs[phase])
99 |
100 | if phase == "evaluate":
101 | self.epochs[phase] += 1
102 | else:
103 | self.epochs[phase] = engine.state.epoch
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/system123/SOMatch/6f10cf28f506998a5e430ccd3faab3076fe350d5/losses/__init__.py
--------------------------------------------------------------------------------
/losses/functional.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | EPS = 1e-4
6 |
7 | def mse_loss_weighted(input, target, reduction="mean", pos_weight=1.0): # weight applied to positive-class errors; 1.0 gives plain MSE
8 | negs = input[target == 0]
9 | pos = input[target == 1]
10 |
11 | loss = torch.zeros_like(target, device=target.device)
12 | loss[target == 1] = pos_weight*F.mse_loss(pos, target[target == 1], reduction="none")
13 | loss[target == 0] = F.mse_loss(negs, target[target == 0], reduction="none")
14 |
15 | if reduction == "mean":
16 | loss = loss.mean()
17 | # loss /= (pos_weight*target[target == 1].size().numel() + target[target == 0].size().numel())
18 | elif reduction == "sum":
19 | loss
= loss.sum() 20 | 21 | return loss 22 | -------------------------------------------------------------------------------- /metrics/categorical_accuracy_one_hot.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | 5 | from ignite.metrics.metric import Metric 6 | from ignite.exceptions import NotComputableError 7 | 8 | 9 | class CategoricalAccuracyOneHot(Metric): 10 | """ 11 | Calculates the categorical accuracy. 12 | `update` must receive output of the form (y_pred, y). 13 | `y_pred` must be in the following shape (batch_size, num_categories, ...) 14 | `y` must be in the following shape (batch_size, num_categories, ...) 15 | """ 16 | def reset(self): 17 | self._num_correct = 0 18 | self._num_examples = 0 19 | 20 | def update(self, output): 21 | y_pred, y = output 22 | _, indices = torch.max(y_pred, dim=1) 23 | _, labels = torch.max(y, dim=1) 24 | correct = torch.eq(indices, labels).view(-1) 25 | self._num_correct += torch.sum(correct).item() 26 | self._num_examples += correct.shape[0] 27 | 28 | def compute(self): 29 | if self._num_examples == 0: 30 | raise NotComputableError('CategoricalAccuracy must have at least one example before it can be computed') 31 | return self._num_correct / self._num_examples 32 | -------------------------------------------------------------------------------- /metrics/custom_metric.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from ignite.exceptions import NotComputableError 4 | from ignite.metrics.metric import Metric 5 | 6 | class CustomLoss(Metric): 7 | """ 8 | Calculates the average loss according to the passed loss_fn. 9 | `loss_fn` must return the average loss over all observations in the batch. 10 | `update` must receive output of the form (y_pred, y). 
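A minimal usage sketch (tensors are hypothetical; any loss_fn that returns the batch-average loss works):

    metric = CustomLoss(loss_fn=F.mse_loss)
    metric.reset()
    metric.update((y_pred, y))    # once per batch
    avg_loss = metric.compute()   # average over all examples seen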
11 | """ 12 | def __init__(self, loss_fn, output_transform=lambda x: x): 13 | super(CustomLoss, self).__init__(output_transform) 14 | self._loss_fn = loss_fn 15 | 16 | def reset(self): 17 | self._sum = 0 18 | self._num_examples = 0 19 | 20 | def update(self, output): 21 | y_pred, y = output 22 | average_loss = self._loss_fn(y_pred, y) 23 | assert len(average_loss.shape) == 0, '`loss_fn` did not return the average loss' 24 | self._sum += average_loss.item() * y.shape[0] 25 | self._num_examples += y.shape[0] 26 | 27 | def compute(self): 28 | if self._num_examples == 0: 29 | raise NotComputableError( 30 | 'Loss must have at least one example before it can be computed') 31 | return self._sum / self._num_examples 32 | -------------------------------------------------------------------------------- /models/corr_feature_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | import numpy as np 6 | 7 | from torch import nn 8 | from torchvision import models 9 | 10 | from utils.modules import * 11 | 12 | class CorrelationFeatureNet(nn.Module): 13 | def __init__(self, column_depth=1024, normalize=False, return_hypercol=False, no_relu=False, attention=False, return_attn=True, attn_act="sigmoid"): 14 | super().__init__() 15 | self.normalize = normalize 16 | self.column_depth = column_depth 17 | self.return_hypercol = return_hypercol 18 | self.no_relu = no_relu 19 | self.attention = attention 20 | self.return_attn = return_attn 21 | 22 | self.stem = nn.Sequential( 23 | nn.Conv2d(1, 32, kernel_size=7, stride=1, padding=0), #0 24 | nn.BatchNorm2d(32, eps=1e-05, momentum=0.05, affine=True), #1 25 | nn.ReLU(), #2 26 | nn.MaxPool2d(3, 2), #3 27 | nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=0), #4 28 | nn.BatchNorm2d(64, eps=1e-05, momentum=0.05, affine=True), #5 29 | nn.ReLU(), #6 30 | nn.MaxPool2d(3, 2), #7 31 | nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=0), #8 32 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), #9 33 | nn.ReLU(), #10 34 | nn.MaxPool2d(3, 2), #11 35 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0), #12 36 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), #13 37 | nn.ReLU(), #14 38 | nn.MaxPool2d(3, 2) #15 39 | ) 40 | self.build_from = [3, 7, 11, 15] 41 | 42 | squish = [nn.Conv2d(352, 256, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(), 44 | nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0), 45 | nn.ReLU(), 46 | nn.Conv2d(512, column_depth, kernel_size=1, stride=1, padding=0)] 47 | 48 | if not self.no_relu: 49 | squish.append(nn.ReLU()) 50 | 51 | self.squash = nn.Sequential(*squish) 52 | 53 | if self.attention: 54 | # self.ch_attn = ChannelAttention(self.column_depth) 55 | self.sp_attn = SpatialAttention(kernel_size=3, activation=attn_act) 56 | 57 | self.init_weights() 58 | 59 | def init_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 63 | elif isinstance(m, nn.Linear): 64 | init_range = 1.0 / math.sqrt(m.weight.shape[1]) 65 | nn.init.uniform_(m.weight, -init_range, init_range) 66 | 67 | def _build_hypercolumn(self, ft_maps): 68 | # Get the output size we want 69 | size = ft_maps[0].shape[2:] 70 | 71 | stack = [] 72 | 73 | for ft in ft_maps: 74 | stack.append( F.interpolate(ft, size=size, mode='bilinear', align_corners=True) ) 75 | 76 | # Stack the tensors in the channel dimension 77 | return torch.cat(stack, 
dim=1) 78 | 79 | # Convolve two tensors together 80 | def correlation_map(self, tensor_a, tensor_b): 81 | b, c, h, w = tensor_a.shape 82 | _, _, h1, w1 = tensor_b.shape 83 | h1 = h - h1 + 1 84 | w1 = w - w1 + 1 85 | tensor_a = tensor_a.view(1, b*c, h, w) 86 | heatmap = F.conv2d(tensor_a, tensor_b, groups=b).view(b, 1, h1, w1) 87 | 88 | return heatmap 89 | 90 | def forward(self, x): 91 | y = x 92 | 93 | outputs = [] 94 | for i, block in enumerate(self.stem): 95 | y = block(y) 96 | if i in self.build_from: 97 | outputs.append(y) 98 | 99 | hypercol = self._build_hypercolumn(outputs) 100 | 101 | hypercol_reduced = self.squash(hypercol) 102 | 103 | if self.attention: 104 | # hypercol_attn = self.ch_attn(hypercol_reduced) * hypercol_reduced 105 | spatial_attn = self.sp_attn(hypercol_reduced) 106 | hypercol_reduced = spatial_attn * hypercol_reduced 107 | 108 | if self.normalize: 109 | hypercol_reduced = F.normalize(hypercol_reduced, p=2, dim=1) 110 | 111 | if self.return_hypercol: 112 | return hypercol_reduced, hypercol 113 | elif self.attention and self.return_attn: 114 | return hypercol_reduced, spatial_attn 115 | else: 116 | return hypercol_reduced 117 | -------------------------------------------------------------------------------- /models/goodness_net.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from torch import nn 8 | from skimage.feature import corner_peaks, peak_local_max 9 | from torchvision import models 10 | 11 | class VGGBasedGoodnessNet(nn.Module): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__() 14 | 15 | self.pooling = kwargs["pooling"] if "pooling" in kwargs else "max" 16 | 17 | self.leg_a = self._make_siamese_leg() 18 | 19 | self.conv1 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1) 20 | self.bn1 = nn.BatchNorm2d(128) 21 | self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1) 22 | self.bn2 = nn.BatchNorm2d(64) 23 | 24 | self.fc1 = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0) 25 | self.dropout = nn.Dropout2d(p=0.5) 26 | self.fc2 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0) 27 | 28 | def fine_tune(self): 29 | for param in self.leg_a.parameters(): 30 | param.requires_grad = False 31 | 32 | def _make_siamese_leg(self): 33 | def _make_vgg_block(layers=[], input_depth=1, depth=32, kernel_size=3, padding=True, max_pool=True): 34 | pad = kernel_size//2 if padding else 0 35 | 36 | layers.append( nn.Conv2d(input_depth, depth, kernel_size=kernel_size, padding=pad) ) 37 | layers.append( nn.ReLU() ) 38 | layers.append( nn.BatchNorm2d(depth) ) 39 | layers.append( nn.Conv2d(depth, depth, kernel_size=kernel_size, padding=pad) ) 40 | layers.append( nn.ReLU() ) 41 | layers.append( nn.BatchNorm2d(depth) ) 42 | if max_pool: 43 | layers.append( nn.MaxPool2d(2) ) 44 | 45 | return layers 46 | 47 | layers = [] 48 | layers = _make_vgg_block(layers=layers, input_depth=1, depth=32, kernel_size=3, padding=True, max_pool=True) 49 | layers = _make_vgg_block(layers=layers, input_depth=32, depth=64, kernel_size=3, padding=True, max_pool=True) 50 | layers = _make_vgg_block(layers=layers, input_depth=64, depth=128, kernel_size=3, padding=True, max_pool=True) 51 | layers = _make_vgg_block(layers=layers, input_depth=128, depth=128, kernel_size=3, padding=True, max_pool=False) 52 | 53 | layers.append( nn.Dropout2d(p=0.25) ) 54 | 55 | return nn.Sequential(*layers) 56 | 57 | def spatial_softnms(self, heatmap, 
soft_local_max_size=3):
58 | b = heatmap.size(0)
59 | pad = soft_local_max_size//2
60 |
61 | heatmap = torch.sigmoid(heatmap)
62 |
63 | max_per_sample = torch.max(heatmap.view(b, -1), dim=1)[0]
64 |
65 | exp = torch.exp(heatmap / max_per_sample.view(b, 1, 1, 1))
66 |
67 | sum_exp = (
68 | soft_local_max_size ** 2 *
69 | F.avg_pool2d(
70 | F.pad(exp, [pad] * 4, mode='constant', value=1.),
71 | soft_local_max_size, stride=1
72 | )
73 | )
74 | local_max_score = exp / sum_exp
75 |
76 | return local_max_score
77 |
78 | # Fuse two goodness maps together
79 | def fuse_goodness(self, gdo, gds, reduce="max", pool=False):
80 | gdo_p, gds_p = gdo, gds # fall back to the unpooled maps when pool=False
81 | if pool:
82 | gdo_p = F.avg_pool2d(gdo, 4, stride=1, padding=2)
83 | gds_p = F.avg_pool2d(gds, 4, stride=1, padding=2)
84 | gdo_p = F.interpolate(gdo_p, size=gdo.shape[2:], align_corners=True, mode='bilinear')
85 | gds_p = F.interpolate(gds_p, size=gds.shape[2:], align_corners=True, mode='bilinear')
86 |
87 | if reduce == "mean":
88 | return (gdo_p + gds_p)/2
89 | elif reduce == "min":
90 | return torch.min(gdo_p, gds_p)
91 | elif reduce == "max":
92 | return torch.max(gdo_p, gds_p)
93 |
94 | def extract_good_points(self, heatmap, input_shape, exclude_border=64, nms_k=5, peak_k=5):
95 | hm = F.pad(heatmap, (2,2,2,2), "constant", heatmap.min())
96 | hm = F.interpolate(hm, size=input_shape, align_corners=None, mode='bilinear')
97 |
98 | # Get rid of the edge effects
99 | nms = self.spatial_softnms(hm, nms_k)[:1,:1,nms_k:-nms_k,nms_k:-nms_k]
100 | # Pad the edge back in
101 | nms = F.pad(nms, (nms_k, nms_k, nms_k, nms_k), "constant", nms.min())
102 | nms = (nms - nms.min())/(nms.max() - nms.min())
103 | nms = hm*nms
104 |
105 | nms_np = nms.cpu().numpy()[0,0,]
106 |
107 | pnts = peak_local_max(nms_np, footprint=np.full((peak_k,peak_k),True), exclude_border=exclude_border,
108 | indices=True, threshold_abs=0.9)
109 | fltrd = []
110 | hm_pad = np.zeros_like(nms_np)
111 |
112 | for p in pnts:
113 | if np.any(p < exclude_border) or np.any(p > (np.array(input_shape) - exclude_border)):
114 | continue
115 | fltrd.append(p)
116 | hm_pad[p[0],p[1]] = 1
117 |
118 | return np.array(fltrd), nms, hm_pad
119 |
120 | def forward(self, x1, pool=True):
121 | x1 = self.leg_a(x1)
122 |
123 | fts = self.conv1(x1)
124 | fts = self.bn1( F.relu(fts) )
125 | fts = self.conv2(fts)
126 | fts = self.bn2( F.relu(fts) )
127 |
128 | fts = self.fc1(fts)
129 | fts = self.dropout(fts)
130 | fts = F.relu(fts)
131 | fts = self.fc2(fts)
132 |
133 | if pool:
134 | if self.pooling == "max":
135 | fts = F.adaptive_max_pool2d(fts, 1)
136 | else:
137 | fts = F.adaptive_avg_pool2d(fts, 1)
138 |
139 | fts = fts.view(fts.size(0), -1)
140 |
141 | return fts
--------------------------------------------------------------------------------
/models/outlier_reduction_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch import nn
5 | from torchvision import models
6 | import math
7 |
8 | class ORN(nn.Module):
9 | def __init__(self, classes=1, padding=False):
10 | super().__init__()
11 | self.classes = classes
12 |
13 | if padding:
14 | pad = [3, 2, 2, 1]
15 | else:
16 | pad = [0, 0, 0, 0]
17 |
18 | self.stem = nn.Sequential(
19 | nn.Conv2d(1, 32, kernel_size=7, stride=1, padding=pad[0]), #0
20 | nn.InstanceNorm2d(32, eps=1e-05, momentum=0.05, affine=True), #1
21 | nn.ReLU(), #2
22 | nn.MaxPool2d(3, 2), #3b
23 | nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=pad[1]), #4
24 | nn.BatchNorm2d(64, eps=1e-05,
momentum=0.05, affine=True), #5 25 | nn.ReLU(), #6 26 | nn.MaxPool2d(3, 2), #7c 27 | nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=pad[2]), #8 28 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), #9 29 | nn.ReLU(), #10 30 | nn.MaxPool2d(3, 2), #11d 31 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=pad[3]), #12 32 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.05, affine=True), #13 33 | nn.ReLU() #14 34 | ) 35 | 36 | self.head = nn.Sequential( 37 | nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0), 38 | nn.ReLU(), 39 | nn.Conv2d(128, self.classes, kernel_size=1, stride=1, padding=0) 40 | ) 41 | 42 | self.init_weights() 43 | 44 | def init_weights(self): 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 48 | elif isinstance(m, nn.Linear): 49 | init_range = 1.0 / math.sqrt(m.weight.shape[1]) 50 | nn.init.uniform_(m.weight, -init_range, init_range) 51 | 52 | def rescale_heatmap(self, heatmap, shape=(129, 129), pad=True, norm=False): 53 | if norm: 54 | heatmap = (heatmap - heatmap.min())/(heatmap.max() - heatmap.min()) 55 | 56 | scaled_heatmap = F.interpolate(heatmap, size=shape, align_corners=True, mode='bilinear') 57 | 58 | if pad: 59 | pad_x = shape[1]//2 60 | pad_y = shape[0]//2 61 | scaled_heatmap = F.pad(scaled_heatmap, (pad_y, pad_y-1, pad_x, pad_x-1), "constant", heatmap.min()) 62 | 63 | return scaled_heatmap 64 | 65 | def forward(self, x, pool=True): 66 | y = self.stem(x) 67 | y = self.head(y) 68 | 69 | if pool: 70 | y = F.adaptive_avg_pool2d(y, 1) 71 | y = y.view(y.shape[0], -1) 72 | 73 | return y 74 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/system123/SOMatch/6f10cf28f506998a5e430ccd3faab3076fe350d5/optimizers/__init__.py -------------------------------------------------------------------------------- /samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/system123/SOMatch/6f10cf28f506998a5e430ccd3faab3076fe350d5/samplers/__init__.py -------------------------------------------------------------------------------- /samplers/round_robin_batch_sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler, BatchSampler 2 | from torch._six import int_classes as _int_classes 3 | from itertools import islice, cycle 4 | 5 | # Samples batches in an alternating manner when there are different sections to a dataset, drop_last is True by default 6 | # If samplers are of different length then the round robin will only occur while there are sufficient samples remaining, after which it will return from the longer sampler 7 | # list(RoundRobinBatchSampler([SubsetRandomSampler(range(10,20)), SubsetRandomSampler(range(100,120))], batch_size=3)) 8 | # [[15, 10, 11], [119, 103, 104], [14, 17, 16], [107, 109, 102], [18, 12, 19], [112, 105, 114], [117, 108, 100], [111, 116, 101], [118, 113, 110]] 9 | class RoundRobinBatchSampler(Sampler): 10 | def __init__(self, samplers, batch_size): 11 | for sampler in samplers: 12 | if not isinstance(sampler, Sampler): 13 | raise ValueError("sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)) 14 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or batch_size <= 0: 15 | raise 
ValueError("batch_size should be a positive integeral value, but got batch_size={}".format(batch_size)) 16 | 17 | self.samplers = samplers 18 | self.batch_size = batch_size 19 | self.batch_samplers = [BatchSampler(sampler, self.batch_size, True) for sampler in self.samplers] 20 | 21 | def __iter__(self): 22 | num_active = len(self.batch_samplers) 23 | nexts = cycle(iter(it).__next__ for it in self.batch_samplers) 24 | 25 | while num_active: 26 | try: 27 | for next in nexts: 28 | yield next() 29 | except StopIteration: 30 | num_active -= 1 31 | nexts = cycle(islice(nexts, num_active)) 32 | 33 | def __len__(self): 34 | return sum([len(sampler) // self.batch_size for sampler in self.samplers]) -------------------------------------------------------------------------------- /schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/system123/SOMatch/6f10cf28f506998a5e430ccd3faab3076fe350d5/schedulers/__init__.py -------------------------------------------------------------------------------- /schedulers/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | def custom_scheduler(factor=0.1, patience=10, min_lr=1e-5): 5 | return lambda i: 1e-3 6 | 7 | def smith_1cycle(total_iter=1000, max_lr=1e-2, min_lr=1e-3, anneal_pct=10, anneal_div=100): 8 | anneal_iter = math.ceil(total_iter * anneal_pct/100.0) 9 | total_iter -= anneal_iter 10 | 11 | step_size = float(total_iter//2) 12 | 13 | piecewise_linear = lambda y0, y1, x, dx: (y1 - y0) * abs(float(x)/dx) + y0 14 | triangle = lambda y0, y1, x, dx: (y1 - y0) * (1 - abs(float(x)/dx - 1)) + y0 15 | 16 | def next(itr): 17 | if itr <= total_iter: 18 | lr = triangle(min_lr, max_lr, itr, step_size) 19 | elif itr - total_iter <= anneal_iter: 20 | lr = piecewise_linear(min_lr, min_lr/anneal_div, itr - total_iter, anneal_iter) 21 | else: 22 | lr = min_lr/anneal_div 23 | 24 | return lr 25 | 26 | return next 27 | 28 | def lr_finder(min_lr=1e-6, max_lr=1, lr_multiplier=1.1, iter_div=1): 29 | max_itr = int(np.log(max_lr/min_lr)/np.log(lr_multiplier)) 30 | 31 | def next(itr): 32 | i = min(itr//iter_div, max_itr) 33 | lr = min_lr * (lr_multiplier**i) 34 | return lr 35 | 36 | return next 37 | 38 | def cyclical_lr(step_sz=2000, min_lr=0.001, max_lr=1, mode='triangular', scale_func=None, scale_md='cycles', gamma=1.): 39 | """implements a cyclical learning rate policy (CLR). 40 | Notes: the learning rate of optimizer should be 1 41 | 42 | Parameters: 43 | ---------- 44 | mode : str, optional 45 | one of {triangular, triangular2, exp_range}. 46 | scale_md : str, optional 47 | {'cycles', 'iterations'}. 48 | gamma : float, optional 49 | constant in 'exp_range' scaling function: gamma**(cycle iterations) 50 | 51 | Examples: 52 | -------- 53 | >>> # the learning rate of optimizer should be 1 54 | >>> optimizer = optim.SGD(model.parameters(), lr=1.) 55 | >>> step_size = 2*len(train_loader) 56 | >>> clr = cyclical_lr(step_size, min_lr=0.001, max_lr=0.005) 57 | >>> scheduler = lr_scheduler.LambdaLR(optimizer, [clr]) 58 | >>> # some other operations 59 | >>> scheduler.step() 60 | >>> optimizer.step() 61 | """ 62 | if scale_func == None: 63 | if mode == 'triangular': 64 | scale_fn = lambda x: 1. 
65 | scale_mode = 'cycles' 66 | elif mode == 'triangular2': 67 | scale_fn = lambda x: 1 / (2.**(x - 1)) 68 | scale_mode = 'cycles' 69 | elif mode == 'exp_range': 70 | scale_fn = lambda x: gamma**(x) 71 | scale_mode = 'iterations' 72 | else: 73 | raise ValueError(f'The {mode} is not valid value!') 74 | else: 75 | scale_fn = scale_func 76 | scale_mode = scale_md 77 | 78 | lr_lambda = lambda iters: min_lr + (max_lr - min_lr) * rel_val(iters, step_sz, scale_mode) 79 | 80 | def rel_val(iteration, stepsize, mode): 81 | cycle = math.floor(1 + iteration / (2 * stepsize)) 82 | x = abs(iteration / stepsize - 2 * cycle + 1) 83 | if mode == 'cycles': 84 | return max(0, (1 - x)) * scale_fn(cycle) 85 | elif mode == 'iterations': 86 | return max(0, (1 - x)) * scale_fn(iteration) 87 | else: 88 | raise ValueError(f'The {scale_mode} is not valid value!') 89 | 90 | return lr_lambda 91 | 92 | # if __name__=="__main__": 93 | # fn = smith_1cycle(total_iter=100, max_lr=1e-2, min_lr=1e-3, anneal_pct=50, anneal_div=100) 94 | # lrs = [fn(i) for i in range(0, 105)] 95 | # print(lrs) -------------------------------------------------------------------------------- /schedulers/smith_1cycle_lr.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | 5 | class Smith1CycleLR: 6 | def __init__(self, optimizer, total_iter=1000, max_lr=1e-2, min_lr=1e-3, anneal_pct=10, anneal_lr_div=100, max_momentum=0.95, min_momentum=0.85, momentum_name='momentum'): 7 | 8 | if not isinstance(optimizer, Optimizer): 9 | raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__)) 10 | 11 | self.optimizer = optimizer 12 | 13 | self.momentum_name = momentum_name 14 | 15 | self.lr_fcn = self.smith_1cycle(total_iter, max_lr, min_lr, anneal_pct, anneal_lr_div) 16 | self.mom_fcn = self.smith_1cycle(total_iter, min_momentum, max_momentum, anneal_pct, 1) 17 | 18 | self.last_epoch = -1 19 | 20 | self.step(self.last_epoch + 1) 21 | 22 | def smith_1cycle(self, total_iter=1000, max_lr=1e-2, min_lr=1e-3, anneal_pct=10, anneal_div=100): 23 | anneal_iter = math.ceil(total_iter * anneal_pct/100.0) 24 | total_iter -= anneal_iter 25 | 26 | step_size = float(total_iter//2) 27 | 28 | piecewise_linear = lambda y0, y1, x, dx: (y1 - y0) * float(x)/dx + y0 29 | triangle = lambda y0, y1, x, dx: (y1 - y0) * (1 - abs(float(x)/dx - 1)) + y0 30 | 31 | def next(itr): 32 | if itr <= total_iter: 33 | lr = triangle(min_lr, max_lr, itr, step_size) 34 | else: 35 | lr = piecewise_linear(min_lr, min_lr/anneal_div, itr - total_iter, anneal_iter) 36 | 37 | return lr 38 | 39 | return next 40 | 41 | def state_dict(self): 42 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 43 | 44 | def load_state_dict(self, state_dict): 45 | self.__dict__.update(state_dict) 46 | 47 | def get_lr(self): 48 | return self.lr_fcn(self.last_epoch) 49 | 50 | def get_momentum(self): 51 | return self.mom_fcn(self.last_epoch) 52 | 53 | def step(self, epoch=None): 54 | if epoch is None: 55 | epoch = self.last_epoch + 1 56 | 57 | self.last_epoch = epoch 58 | lr = self.get_lr() 59 | momentum = self.get_momentum() 60 | 61 | for param_group in self.optimizer.param_groups: 62 | param_group['lr'] = lr 63 | 64 | if self.momentum_name in param_group: 65 | param_group[self.momentum_name] = momentum -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 
| import os
2 | import logging
3 | from argparse import ArgumentParser
4 |
5 | import pandas as pd
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 |
10 | from ignite.engine.engine import Engine, State, Events
11 | from ignite.handlers import ModelCheckpoint, EarlyStopping
12 | from ignite._utils import convert_tensor
13 |
14 | from utils import Experiment
15 | from utils.factory import *
16 | from utils.helpers import static_vars
17 |
18 | import matplotlib.pyplot as plt
19 | import numpy as np
20 | import scipy.io as sio
21 |
22 | from tqdm import tqdm
23 | from utils.helpers import BinaryClassificationMeter, accuracy
24 |
25 | logging.basicConfig(level=logging.INFO, format='')
26 | logger = logging.getLogger()
27 |
28 | def save_image(tensor, fname, cmap=plt.cm.jet):
29 | data = tensor.to("cpu").numpy().squeeze(0).squeeze(0)
30 | plt.imsave(fname, data, cmap=cmap)
31 |
32 | def save_numpy(tensor, fname):
33 | data = tensor.to("cpu").numpy()
34 | np.save(fname, data)
35 |
36 | def main(config, dataset="test"):
37 | assert validate_config(config), "ERROR: Config file is invalid. Please see log for details."
38 |
39 | logger.info("INFO: {}".format(config.toDict()))
40 |
41 | if config.device == "cpu" and torch.cuda.is_available():
42 | logger.warning("WARNING: Not using the GPU")
43 |
44 | if "cuda" in config.device:
45 | config.device = "cuda"
46 |
47 | assert dataset in config.datasets, "ERROR: No test dataset is specified in the config. Don't know how to proceed."
48 |
49 | logger.info("INFO: Creating datasets and dataloaders...")
50 |
51 | config.datasets[dataset].update({'shuffle': False, 'augment': False, 'workers': 1})
52 | config.datasets[dataset].update({'batch_size': 1, "named": True})
53 |
54 | meter = BinaryClassificationMeter()
55 |
56 | # Create the test dataset
57 | dset_test = create_dataset(config.datasets[dataset])
58 |
59 | loader_test = get_data_loader(dset_test, config.datasets[dataset])
60 |
61 | logger.info("INFO: Running inference on {} samples".format(len(dset_test)))
62 |
63 | cp_paths = None
64 | last_epoch = 0
65 | checkpoint_dir = config.result_dir
66 | if 'checkpoint' in config:
67 | checkpoint_dir = config.checkpoint_dir if 'checkpoint_dir' in config else config.result_path
68 | cp_paths, last_epoch = config.get_checkpoints(path=checkpoint_dir, tag=config.checkpoint)
69 | print(f"Found checkpoint {cp_paths} for epoch {last_epoch}")
70 |
71 | models = {}
72 | for name, model in config.model.items():
73 | logger.info("INFO: Building the {} model".format(name))
74 | models[name] = build_model(model)
75 |
76 | # Load the checkpoint, if one was found
77 | if cp_paths and name in cp_paths:
78 | models[name].load_state_dict( torch.load( cp_paths[name] ) )
79 | logger.info("INFO: Loaded model {} checkpoint {}".format(name, cp_paths[name]))
80 |
81 | models[name].to(config.device)
82 | print(models[name])
83 |
84 | if 'debug' in config and config.debug is True:
85 | print("*********** {} ************".format(name))
86 | for p_name, param in models[name].named_parameters():
87 | if param.requires_grad:
88 | print(p_name, param.data)
89 |
90 | losses = {}
91 | for name, fcns in config.loss.items():
92 | losses[name] = []
93 | for l in fcns:
94 | losses[name].append( get_loss(l) )
95 | assert losses[name][-1], "Loss function {} for {} could not be found, please check your config".format(l, name)
96 |
97 | exp_logger = None
98 | if 'logger' in config:
99 | logger.info("INFO: Initialising the experiment logger")
100 | exp_logger = get_experiment_logger(config.result_path,
config.logger) 101 | 102 | logger.info("INFO: Creating training manager and configuring callbacks") 103 | trainer = get_trainer(models, None, losses, None, config) 104 | 105 | evaluator_engine = Engine(trainer.evaluate) 106 | 107 | trainer.attach("test_loader", loader_test) 108 | trainer.attach("evaluation_engine", evaluator_engine) 109 | 110 | logger.info("INFO: Starting inference...") 111 | 112 | results = [] 113 | 114 | save_path = os.path.join(config.checkpoint_dir, f"inference_{last_epoch}", dataset) 115 | os.makedirs(save_path, exist_ok=True) 116 | 117 | with torch.no_grad(): 118 | for i, (xs, ys, names) in enumerate(tqdm(loader_test)): 119 | batch = (xs, ys) 120 | 121 | entity = { 122 | "wkt": names["WKT"][0], 123 | "city": names["city"][0] 124 | } 125 | 126 | filename = "{}_{}".format(names["city"], names["WKT"]) 127 | 128 | xs, ys, y_pred = trainer.infer_batch(batch) 129 | 130 | ys = ys[0] 131 | 132 | loss = trainer.loss_fn(y_pred, ys) 133 | 134 | meter.update(torch.sigmoid(y_pred).to("cpu"), ys.to("cpu")) 135 | 136 | entity["loss"] = loss.to("cpu").numpy() 137 | entity["y"] = ys.to("cpu").numpy()[0][0] 138 | entity["y_pred"] = torch.sigmoid(y_pred).to("cpu").numpy()[0][0] 139 | 140 | # Save heatmaps 141 | save_image( xs[0], os.path.join(save_path, f"{filename}_x.png"), plt.cm.gray ) 142 | 143 | results.append(entity) 144 | 145 | if i % 1000 == 0: 146 | df = pd.DataFrame.from_dict(results) 147 | df.to_csv(os.path.join(config.checkpoint_dir, "checkpoint_{}_inference_bce_dset_{}.csv".format(last_epoch, dataset)) , index=None) 148 | 149 | print(f"Accuracy: {meter.acc} Precision: {meter.pre} Recall: {meter.rec}") 150 | 151 | df = pd.DataFrame.from_dict(results) 152 | df.to_csv(os.path.join(config.checkpoint_dir, "checkpoint_{}_inference_bce_dset_{}.csv".format(last_epoch, dataset)) , index=None) 153 | config.save() 154 | 155 | if __name__ == "__main__": 156 | parser = ArgumentParser() 157 | parser.add_argument('-c', '--config', default=None, type=str, required=True, help='config file path (default: None)') 158 | parser.add_argument('--checkpoint', default=None, type=str, help='Checkpoint tag to reload') 159 | parser.add_argument('--checkpoint_dir', default=None, type=str, help='Checkpoint directory to reload') 160 | parser.add_argument('--dataset', default="test", type=str, help="Which dataset to test on") 161 | args = parser.parse_args() 162 | 163 | OVERLOADABLE = ['checkpoint', 'epochs', 'checkpoint_dir', 'resume_from'] 164 | 165 | overloaded = {} 166 | for k, v in vars(args).items(): 167 | if (k in OVERLOADABLE) and (v is not None): 168 | overloaded[k] = v 169 | 170 | config = Experiment.load_from_path(args.config, overloaded) 171 | 172 | print(config.checkpoint) 173 | 174 | assert config, "Config could not be loaded." 
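# A typical invocation (checkpoint tag is hypothetical; the flags are defined above):
#   python test.py -c config/so_match/experiments/goodness_hm.json --checkpoint best --dataset test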
175 | 176 | main(config, args.dataset) 177 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/system123/SOMatch/6f10cf28f506998a5e430ccd3faab3076fe350d5/tools/__init__.py -------------------------------------------------------------------------------- /tools/compare_all.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | from scipy.spatial.distance import cdist 8 | from argparse import ArgumentParser 9 | from sklearn.manifold import TSNE 10 | 11 | parser = ArgumentParser() 12 | parser.add_argument("src", type=str, help="File of descriptors and image names") 13 | parser.add_argument("--normalise", action="store_true", help="Normalise the descriptors") 14 | parser.add_argument("--metric", type=str, default="euclidean", help="Distance function to use [euclidean, matching, cosine, correlation, hamming]") 15 | args = parser.parse_args() 16 | 17 | def normalise(x): 18 | return x/np.linalg.norm(x) 19 | 20 | def string2ndarray(x, dtype=np.float32): 21 | # Remove BS which pandas adds to numpy array string 22 | x = x.replace("\n","").replace("[","").replace("]","").replace(",","") 23 | x = re.sub('\s+', ' ', x).strip().split(" ") 24 | return np.asfarray(x, dtype) 25 | 26 | def extract_img_id(x): 27 | return x.rsplit('_', 1)[0] 28 | 29 | df = pd.read_csv(args.src) 30 | 31 | # Get the data back into the format we want 32 | df['opt_id'] = df['opt'].apply(extract_img_id) 33 | df['sar_id'] = df['sar'].apply(extract_img_id) 34 | 35 | # indices = [i for i, s in enumerate(mylist) if 'aa' in s] 36 | 37 | df['z_sar'] = df['z_sar'].apply(string2ndarray) 38 | df['z_opt'] = df['z_opt'].apply(string2ndarray) 39 | 40 | dfnm = df.copy() 41 | 42 | df = df.loc[df['sar'] == df['opt']] 43 | dfnm = dfnm.loc[dfnm['sar'] != dfnm['opt']] 44 | df = df.reset_index(drop=True) 45 | dfnm = dfnm.reset_index(drop=True) 46 | 47 | z_sar = np.stack(df['z_sar'].values) 48 | z_opt = np.stack(df['z_opt'].values) 49 | 50 | z_sarnm = np.stack(dfnm['z_sar'].values) 51 | z_optnm = np.stack(dfnm['z_opt'].values) 52 | 53 | if args.normalise: 54 | z_sar = np.apply_along_axis(normalise ,1 , z_sar) 55 | z_opt = np.apply_along_axis(normalise ,1 , z_opt) 56 | z_sarnm = np.apply_along_axis(normalise ,1 , z_sarnm) 57 | z_optnm = np.apply_along_axis(normalise ,1 , z_optnm) 58 | 59 | dists = cdist(z_sar, z_opt, metric=args.metric) 60 | 61 | plt.imshow(dists, cmap="jet") 62 | plt.show() 63 | 64 | idxs = np.zeros(dists.shape[0]) 65 | 66 | for i, row in enumerate(dists): 67 | order = np.argsort(row) 68 | idx = np.argwhere(order == i)[0] 69 | idxs[i] = idx 70 | 71 | top_n = np.zeros(dists.shape[0]) 72 | 73 | for i in range(25): 74 | top_n[i] = np.sum(idxs < i+1) 75 | print(f"Top {i+1}: {np.round(top_n[i]/len(idxs)*100, 2)}") 76 | 77 | import code 78 | code.interact(local=locals()) 79 | 80 | MAX = 100 81 | 82 | Zo_2d = TSNE(n_components=2).fit_transform(z_opt[:MAX]) 83 | Zs_2d = TSNE(n_components=2).fit_transform(z_sar[:MAX]) 84 | # https://www.kaggle.com/gaborvecsei/plants-t-sne 85 | from matplotlib import pyplot as plt 86 | plt.figure(figsize=(6, 5)) 87 | colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'w', 'orange', 'purple'] 88 | for i, (sar, opt) in enumerate(zip(Zs_2d, Zo_2d)): 89 | plt.scatter(sar[0], sar[1], c=colors[i%10], 
label="s_{}".format(df["sar"].values[i])) 90 | plt.scatter(opt[0], opt[1], c=colors[i%10], label="o_{}".format(df["opt"].values[i])) 91 | 92 | plt.legend() 93 | plt.show() 94 | 95 | # for i, c, label in zip(target_ids, colors, df["opt"].values): 96 | # plt.scatter(Z_2d[i, 0], Z_2d[i, 1], c=c, label=label) 97 | # plt.legend() 98 | # plt.show() 99 | 100 | # import code 101 | # code.interact(local=locals()) 102 | 103 | # print(df.head()) -------------------------------------------------------------------------------- /tools/create_file_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from sklearn.utils import shuffle 5 | import argparse 6 | import glob 7 | import shutil 8 | from tqdm import tqdm 9 | 10 | from skimage import io 11 | 12 | def get_image_collection(src_dir, ext=None): 13 | if ext is None: 14 | ext = "" 15 | else: 16 | ext = ".{}".format(ext) 17 | 18 | files = io.ImageCollection(os.path.join(src_dir, "*{}".format(ext)), conserve_memory=False) 19 | 20 | return(files) 21 | 22 | def make_dataset_lists(df, dest, ltype, no_negs=False): 23 | # parent, _ = os.path.split(dest.rstrip(os.sep)) 24 | parent = dest 25 | la_name = os.path.join(parent, "list.{}.opt.txt".format(ltype)) 26 | lb_name = os.path.join(parent, "list.{}.sar.txt".format(ltype)) 27 | 28 | f_a = df['files_a'].values.tolist() 29 | f_b = df['files_b'].values.tolist() 30 | f_c = df['files_c'].values.tolist() 31 | 32 | # Shuffle the file lists to make the dataset more diverse 33 | # f_a_shuf, f_c = f_a, f_c 34 | # f_mirror, f_orig = shuffle(f_mirror, f_orig) 35 | 36 | if no_negs: 37 | zip_a = zip(f_a) 38 | zip_b = zip(f_b) 39 | else: 40 | zip_a = zip(f_a, f_a) 41 | zip_b = zip(f_b, f_c) 42 | 43 | # Doing it this way ensures the lists stay balanced, 1 pos + 1 neg ... 
44 | # Otherwise, if we created the list and then shuffled it, the dataset could become unbalanced
45 | list_a = [val for pair in zip_a for val in pair]
46 | list_b = [val for pair in zip_b for val in pair]
47 |
48 | with open(la_name, "w") as f:
49 | list_a = map(lambda x: x + '\n', list_a)
50 | f.writelines(list_a)
51 |
52 | with open(lb_name, "w") as f:
53 | list_b = map(lambda x: x + '\n', list_b)
54 | f.writelines(list_b)
55 |
56 | if __name__=="__main__":
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument("src_a", help="The optical dataset")
59 | parser.add_argument("src_b", help="The sar dataset")
60 | parser.add_argument("dest", help="Base directory to save file lists to")
61 | parser.add_argument("--src_c", default=None, help="Negative dataset if not created randomly")
62 | parser.add_argument("--type", default="train", help="Additional file list identifier")
63 | parser.add_argument("--ext", default="png", help="File extension")
64 | parser.add_argument("--no_negs", action="store_true", help="Don't create negative pairs, just match the files as they are in the folders")
65 | args = parser.parse_args()
66 |
67 | files_a = get_image_collection(args.src_a, ext=args.ext).files
68 | files_b = get_image_collection(args.src_b, ext=args.ext).files
69 | print(len(files_b))
70 | df = pd.DataFrame.from_dict({'files_a': files_a})
71 | df['files_b'] = files_b
72 |
73 | if args.src_c:
74 | files_c = get_image_collection(args.src_c, ext=args.ext).files
75 | print(len(files_c))
76 | df['files_c'] = files_c
77 | else:
78 | df['files_c'] = shuffle(files_b)
79 |
80 | print("# matching negative items {}".format(len(df.loc[df['files_b'] == df['files_c']])))
81 |
82 | make_dataset_lists(df, args.dest, args.type, args.no_negs)
83 |
--------------------------------------------------------------------------------
/tools/dfc_sen12ms_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Routines for loading the SEN12MS dataset of corresponding Sentinel-1, Sentinel-2
3 | and simplified IGBP landcover for the 2020 IEEE GRSS Data Fusion Contest.
4 |
5 | The SEN12MS class is meant to provide a set of helper routines for loading individual
6 | image patches as well as triplets of patches from the dataset. These routines can easily
7 | be wrapped or extended for use with many Deep Learning frameworks or as standalone helper
8 | methods. For an example use case please see the "main" routine at the end of this file.
9 |
10 | NOTE: Some folder/file existence and validity checks are implemented but it is
11 | by no means complete.
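A minimal usage sketch (the path and IDs are illustrative; any valid scene/patch pair works):

    dataset = DFCSEN12MSDataset("/path/to/SEN12MS")
    spring_ids = dataset.get_season_ids(Seasons.SPRING)
    s1, s2, bounds = dataset.get_s1_s2_pair(Seasons.SPRING, 8, spring_ids[8][0])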
12 | 13 | Author: Lloyd Hughes (lloyd.hughes@tum.de) 14 | """ 15 | 16 | import os 17 | import rasterio 18 | 19 | import numpy as np 20 | 21 | from enum import Enum 22 | from glob import glob 23 | 24 | class S1Bands(Enum): 25 | VV = 1 26 | VH = 2 27 | ALL = [VV, VH] 28 | NONE = None 29 | 30 | 31 | class S2Bands(Enum): 32 | B01 = aerosol = 1 33 | B02 = blue = 2 34 | B03 = green = 3 35 | B04 = red = 4 36 | B05 = re1 = 5 37 | B06 = re2 = 6 38 | B07 = re3 = 7 39 | B08 = nir1 = 8 40 | B08A = nir2 = 9 41 | B09 = vapor = 10 42 | B10 = cirrus = 11 43 | B11 = swir1 = 12 44 | B12 = swir2 = 13 45 | ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12] 46 | RGB = [B04, B03, B02] 47 | NONE = None 48 | 49 | 50 | class LCBands(Enum): 51 | LC = lc = 0 52 | DFC = dfc = 1 53 | ALL = [DFC] 54 | NONE = None 55 | 56 | 57 | class Seasons(Enum): 58 | SPRING = "ROIs1158_spring" 59 | SUMMER = "ROIs1868_summer" 60 | FALL = "ROIs1970_fall" 61 | WINTER = "ROIs2017_winter" 62 | TESTSET = "ROIs0000_test" 63 | VALSET = "ROIs0000_validation" 64 | TEST = [TESTSET] 65 | VALIDATION = [VALSET] 66 | TRAIN = [SPRING, SUMMER, FALL, WINTER] 67 | ALL = [SPRING, SUMMER, FALL, WINTER, VALIDATION, TEST] 68 | 69 | 70 | class Sensor(Enum): 71 | s1 = "s1" 72 | s2 = "s2" 73 | lc = "lc" 74 | dfc = "dfc" 75 | 76 | # Remapping IGBP classes to simplified DFC classes 77 | IGBP2DFC = np.array([0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10]) 78 | 79 | # Note: The order in which you request the bands is the same order they will be returned in. 80 | class DFCSEN12MSDataset: 81 | def __init__(self, base_dir): 82 | self.base_dir = base_dir 83 | 84 | if not os.path.exists(self.base_dir): 85 | raise Exception("The specified base_dir for SEN12MS dataset does not exist") 86 | 87 | 88 | def get_scene_ids(self, season): 89 | """ 90 | Returns a list of scene ids for a specific season. 91 | """ 92 | 93 | season = Seasons(season).value 94 | path = os.path.join(self.base_dir, season) 95 | 96 | if not os.path.exists(path): 97 | raise NameError("Could not find season {} in base directory {}".format(season, self.base_dir)) 98 | 99 | scene_list = [os.path.basename(s) for s in glob(os.path.join(path, "*"))] 100 | scene_list = [int(s.split('_')[1]) for s in scene_list] 101 | return set(scene_list) 102 | 103 | 104 | def get_patch_ids(self, season, scene_id, sensor=Sensor.s1, ext="tif"): 105 | """ 106 | Returns a list of patch ids for a specific scene within a specific season 107 | """ 108 | season = Seasons(season).value 109 | path = os.path.join(self.base_dir, season, f"{sensor.value}_{scene_id}") 110 | 111 | if not os.path.exists(path): 112 | raise NameError("Could not find scene {} within season {}".format(scene_id, season)) 113 | 114 | patch_ids = [os.path.splitext(os.path.basename(p))[0] for p in glob(os.path.join(path, f"*.{ext}"))] 115 | patch_ids = [int(p.rsplit("_", 1)[1].split("p")[1]) for p in patch_ids] 116 | 117 | return patch_ids 118 | 119 | 120 | def get_season_ids(self, season): 121 | """ 122 | Return a dict of scene ids and their corresponding patch ids. 
123 | key => scene_ids, value => list of patch_ids
124 | """
125 | season = Seasons(season).value
126 | ids = {}
127 | scene_ids = self.get_scene_ids(season)
128 |
129 | for sid in scene_ids:
130 | ids[sid] = self.get_patch_ids(season, sid)
131 |
132 | return ids
133 |
134 |
135 | def get_patch(self, season, scene_id, patch_id, bands, ext="tif"):
136 | """
137 | Returns raster data and image bounds for the defined bands of a specific patch
138 | This method only loads a single patch from a single sensor as defined by the bands specified
139 | """
140 | season = Seasons(season).value
141 | sensor = None
142 |
143 | if not bands:
144 | return None, None
145 |
146 | if isinstance(bands, (list, tuple)):
147 | b = bands[0]
148 | else:
149 | b = bands
150 |
151 | if isinstance(b, S1Bands):
152 | sensor = Sensor.s1.value
153 | bandEnum = S1Bands
154 | elif isinstance(b, S2Bands):
155 | sensor = Sensor.s2.value
156 | bandEnum = S2Bands
157 | elif isinstance(b, LCBands):
158 | if LCBands(bands) == LCBands.LC:
159 | sensor = Sensor.lc.value
160 | else:
161 | sensor = Sensor.dfc.value
162 |
163 | bands = LCBands(1)
164 | bandEnum = LCBands
165 | else:
166 | raise Exception("Invalid bands specified")
167 |
168 | if isinstance(bands, (list, tuple)):
169 | bands = [b.value for b in bands]
170 | else:
171 | bands = bandEnum(bands).value
172 |
173 | scene = "{}_{}".format(sensor, scene_id)
174 | filename = "{}_{}_p{}.{}".format(season, scene, patch_id, ext)
175 | patch_path = os.path.join(self.base_dir, season, scene, filename)
176 |
177 | with rasterio.open(patch_path) as patch:
178 | data = patch.read(bands)
179 | bounds = patch.bounds
180 |
181 | # Remap IGBP to DFC bands
182 | if sensor == "lc":
183 | data = IGBP2DFC[data]
184 |
185 | if len(data.shape) == 2:
186 | data = np.expand_dims(data, axis=0)
187 |
188 | return data, bounds
189 |
190 | def get_s1_s2_lc_dfc_quad(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, lc_bands=LCBands.ALL, dfc_bands=LCBands.NONE):
191 | """
192 | Returns a quadruple of patches. S1, S2, LC and DFC as well as the geo-bounds of the patch. If the number of bands is NONE
193 | then a None value will be returned instead of image data
194 | """
195 |
196 | s1, bounds1 = self.get_patch(season, scene_id, patch_id, s1_bands)
197 | s2, bounds2 = self.get_patch(season, scene_id, patch_id, s2_bands)
198 | lc, bounds3 = self.get_patch(season, scene_id, patch_id, lc_bands)
199 | dfc, bounds4 = self.get_patch(season, scene_id, patch_id, dfc_bands)
200 |
201 | bounds = next(filter(None, [bounds1, bounds2, bounds3, bounds4]), None)
202 |
203 | return s1, s2, lc, dfc, bounds
204 |
205 | def get_s1_s2_pair(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL):
206 | """
207 | Returns a pair of patches, S1 and S2, as well as the geo-bounds of the patch.
If the number of bands is NONE 208 | then a None value will be returned instead of image data 209 | """ 210 | 211 | s1, bounds1 = self.get_patch(season, scene_id, patch_id, s1_bands) 212 | s2, bounds2 = self.get_patch(season, scene_id, patch_id, s2_bands) 213 | 214 | bounds = next(filter(None, [bounds1, bounds2]), None) 215 | 216 | return s1, s2, bounds 217 | 218 | def get_quad_stack(self, season, scene_ids=None, patch_ids=None, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, lc_bands=LCBands.ALL, dfc_bands=LCBands.NONE): 219 | """ 220 | Returns a triplet of numpy arrays with dimensions D, B, W, H where D is the number of patches specified 221 | using scene_ids and patch_ids and B is the number of bands for S1, S2 or LC 222 | """ 223 | season = Seasons(season) 224 | scene_list = [] 225 | patch_list = [] 226 | bounds = [] 227 | s1_data = [] 228 | s2_data = [] 229 | lc_data = [] 230 | dfc_data = [] 231 | 232 | # This is due to the fact that not all patch ids are available in all scenes 233 | # And not all scenes exist in all seasons 234 | if isinstance(scene_ids, list) and isinstance(patch_ids, list): 235 | raise Exception("Only scene_ids or patch_ids can be a list, not both.") 236 | 237 | if scene_ids is None: 238 | scene_list = self.get_scene_ids(season) 239 | else: 240 | try: 241 | scene_list.extend(scene_ids) 242 | except TypeError: 243 | scene_list.append(scene_ids) 244 | 245 | if patch_ids is not None: 246 | try: 247 | patch_list.extend(patch_ids) 248 | except TypeError: 249 | patch_list.append(patch_ids) 250 | 251 | for sid in scene_list: 252 | if patch_ids is None: 253 | patch_list = self.get_patch_ids(season, sid) 254 | 255 | for pid in patch_list: 256 | s1, s2, lc, dfc, bound = self.get_s1_s2_lc_dfc_quad(season, sid, pid, s1_bands, s2_bands, lc_bands, dfc_bands) 257 | s1_data.append(s1) 258 | s2_data.append(s2) 259 | lc_data.append(lc) 260 | dfc_data.append(dfc) 261 | bounds.append(bound) 262 | 263 | return np.stack(s1_data, axis=0), np.stack(s2_data, axis=0), np.stack(lc_data, axis=0), np.stack(dfc_data, axis=0), bounds 264 | 265 | # This documents some example usage of the dataset handler. 266 | # To use the Seasons.TEST and Seasons.VALIDATION sets, they need to be in the same folder as the SEN12MS dataset. 
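# The helper methods above assume the standard SEN12MS directory layout under base_dir, e.g. (illustrative):
#   <base_dir>/ROIs1158_spring/s1_8/ROIs1158_spring_s1_8_p1.tif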
267 | if __name__ == "__main__": 268 | from argparse import ArgumentParser 269 | 270 | parse = ArgumentParser() 271 | parse.add_argument('src', type=str, help="Base directory of SEN12MS dataset") 272 | args = parse.parse_args() 273 | 274 | # Load the dataset specifying the base directory 275 | sen12ms = DFCSEN12MSDataset(args.src) 276 | 277 | # Get the scene IDs for a single season 278 | spring_ids = sen12ms.get_season_ids(Seasons.SPRING) 279 | cnt_patches = sum([len(pids) for pids in spring_ids.values()]) 280 | print("Spring: {} scenes with a total of {} patches".format(len(spring_ids), cnt_patches)) 281 | 282 | # Let's get all the scene IDs for the Training dataset 283 | patch_cnt = 0 284 | for s in Seasons.TEST.value: 285 | test_ids = sen12ms.get_season_ids(s) 286 | patch_cnt += sum([len(pids) for pids in test_ids.values()]) 287 | 288 | print("There are a total of {} patches in the Test set".format(patch_cnt)) 289 | 290 | # Load the RGB bands of the first S2 patch in scene 8 291 | SCENE_ID = 8 292 | s2_rgb_patch, bounds = sen12ms.get_patch(Seasons.SPRING, SCENE_ID, spring_ids[SCENE_ID][0], bands=S2Bands.RGB) 293 | 294 | print("S2 RGB: {} Bounds: {}".format(s2_rgb_patch.shape, bounds)) 295 | print("\n") 296 | 297 | # Load a quadruplet of patches from the first three scenes of the Validation set - all S1 bands, NDVI S2 bands, the low resolution LC band and the high resolution DFC LC band 298 | validation_ids = sen12ms.get_season_ids(Seasons.VALSET) 299 | for i, (scene_id, patch_ids) in enumerate(validation_ids.items()): 300 | if i >= 3: 301 | break 302 | 303 | s1, s2, lc, dfc, bounds = sen12ms.get_s1_s2_lc_dfc_quad(Seasons.TESTSET, scene_id, patch_ids[0], s1_bands=S1Bands.ALL, 304 | s2_bands=[S2Bands.red, S2Bands.nir1], lc_bands=LCBands.LC, dfc_bands=LCBands.DFC) 305 | 306 | print(f"Scene: {scene_id}, S1: {s1.shape}, S2: {s2.shape}, LC: {lc.shape}, DFC: {dfc.shape}, Bounds: {bounds}") 307 | 308 | print("\n") 309 | 310 | # Load all bands of all patches in a specified scene (scene 106) 311 | s1, s2, lc, dfc, _ = sen12ms.get_quad_stack(Seasons.SPRING, 106, s1_bands=S1Bands.ALL, 312 | s2_bands=S2Bands.ALL, lc_bands=LCBands.ALL, dfc_bands=LCBands.DFC) 313 | 314 | print(f"Scene: 106, S1: {s1.shape}, S2: {s2.shape}, LC: {lc.shape}") 315 | -------------------------------------------------------------------------------- /tools/find_feature_points.py: -------------------------------------------------------------------------------- 1 | import rasterio 2 | import cv2 3 | import os 4 | import geojson 5 | import zipfile 6 | import tarfile 7 | import pyproj 8 | 9 | import numpy as np 10 | import geopandas as gpd 11 | import pandas as pd 12 | import skimage as ski 13 | 14 | from skimage.feature import peak_local_max, corner_peaks, corner_harris 15 | from argparse import ArgumentParser 16 | from rasterio.windows import get_data_window 17 | from geojson import Feature, Point, FeatureCollection 18 | from tqdm import tqdm 19 | 20 | from rasterio.windows import Window 21 | from rasterio.vrt import WarpedVRT 22 | from rasterio.crs import CRS 23 | from rasterio.warp import calculate_default_transform 24 | from rasterio.enums import Resampling 25 | 26 | class FakeCVFpt: 27 | def __init__(self, xy): 28 | self.size = 1 29 | self.response = 1 30 | self.pt = (xy[1], xy[0]) 31 | 32 | def create_virtual_warped_raster(raster, projection="EPSG:3035", resolution=2.5): 33 | epsg_code = int(projection.split(':')[1]) 34 | 35 | # Broken PRISM images 36 | if raster.crs is None: 37 | raster._crs = CRS.from_epsg(epsg_code) 38 | 
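# Nothing to do if the raster is already in the requested CRS.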
39 | if raster.crs is not None and raster.crs.to_epsg() == epsg_code: 40 | return raster 41 | 42 | dst_transform, dst_width, dst_height = calculate_default_transform(raster.crs, projection, raster.width, raster.height, *raster.bounds, resolution=(resolution, resolution)) 43 | vrt_opts = { 44 | 'resampling': Resampling.nearest, 45 | 'crs': projection, 46 | 'transform': dst_transform, 47 | 'height': dst_height, 48 | 'width': dst_width, 49 | } 50 | 51 | return WarpedVRT(raster, **vrt_opts) 52 | 53 | def load_from_geotif(src, band, roi=None): 54 | if roi is None: 55 | roi = get_data_window(src.read(band, masked=True)) 56 | 57 | img = src.read(band, window=roi) 58 | width = src.width 59 | height = src.height 60 | transform = src.transform 61 | 62 | return img, (width, height), transform 63 | 64 | def find_keypoints(img, scheme="SURF", radius=None): 65 | if scheme == "SURF": 66 | detector = cv2.xfeatures2d.SURF_create(hessianThreshold=400, nOctaves=4, nOctaveLayers=3, extended=False, upright=True) 67 | elif scheme == "SIFT": 68 | detector = cv2.xfeatures2d.SIFT_create(nOctaveLayers=3, sigma=1.3) 69 | elif scheme == "BRISK": 70 | detector = cv2.BRISK_create(thresh=30, octaves=3) 71 | elif scheme == "ORB": 72 | detector = cv2.ORB_create(nfeatures=10000) 73 | 74 | if scheme not in ["HARRIS"]: 75 | kps = detector.detect(img, None) 76 | else: 77 | cnrs = corner_peaks(corner_harris(img), min_distance=radius) 78 | kps = [FakeCVFpt(xy) for xy in cnrs] 79 | 80 | return kps 81 | 82 | def superimpose_keypoints(img, fpts): 83 | img2 = cv2.drawKeypoints(img, fpts, None, (255,0,0), 4) 84 | return img2 85 | 86 | # pseudo non-maximal suppression of the keypoints based on selecting the maximal point within a radius of r 87 | def keypoints_nms(img, keypoints, r=64): 88 | feature_map = np.zeros_like(img).astype(np.float64) 89 | 90 | for img_pt, kpt in keypoints.items(): 91 | feature_map[img_pt] = kpt["response"] 92 | 93 | # Find the peaks in the original feature map and ensure they are seperated by at least r+1 pixels 94 | coords = corner_peaks(feature_map, min_distance=r, exclude_border=True) 95 | 96 | # Rebuild the keypoint list with all features from the NMS selected features 97 | keypoint_list = {} 98 | for pt in coords: 99 | keypoint_list[tuple(pt)] = keypoints[tuple(pt)] 100 | 101 | return keypoint_list 102 | 103 | # Convert openCV keypoints to world cordinates 104 | def cv_keypoints_to_world(kpts, tif, epsg="EPSG:4326"): 105 | keypoints = {} 106 | 107 | for kp in tqdm(kpts): 108 | x, y = kp.pt 109 | (lon, lat) = tif.xy(y, x) 110 | 111 | img_pt = (int(y), int(x)) 112 | 113 | if img_pt not in keypoints or keypoints[img_pt]["response"] < kp.response: 114 | keypoints[img_pt] = {"pt_w": (lon, lat), "pt_i": (int(x), int(y)), "size": kp.size, "response": kp.response} 115 | 116 | if epsg is not None: 117 | dest_proj = pyproj.Proj(init=epsg) 118 | pt_wgs84 = pyproj.transform(tif.crs.to_proj4(), dest_proj, lon, lat) 119 | keypoints[img_pt][epsg] = (pt_wgs84[0], pt_wgs84[1]) 120 | 121 | return keypoints 122 | 123 | def keypoints_to_geojson(keypoints, geom="pt_w"): 124 | point_list = [] 125 | 126 | for _, kpt in keypoints.items(): 127 | point_list.append( Feature( geometry=Point(kpt[geom]), properties=kpt ) ) 128 | 129 | return FeatureCollection(point_list) 130 | 131 | def open_raster(src, temp_dest=None, name_prefix=None): 132 | _, ext = os.path.splitext(src) 133 | raster = None 134 | fmem = None 135 | 136 | if ext == ".tif": 137 | raster = create_virtual_warped_raster( rasterio.open(src) ) 138 | elif ext == ".zip": 
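# Zipped products: grab the first .tif in the archive; read it into an in-memory rasterio MemoryFile, or extract it to temp_dest (optionally renamed with name_prefix) and open it from disk.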
139 | archive = zipfile.ZipFile(src, 'r')
140 | tifs = [fname for fname in archive.infolist() if os.path.splitext(fname.filename)[1] == '.tif']
141 |
142 | if len(tifs) > 0:
143 | if temp_dest is None:
144 | f = archive.open(tifs[0], 'r')
145 | fmem = rasterio.io.MemoryFile(f.read()) # hold the archive member in an in-memory rasterio file
146 | raster = fmem.open()
147 | else:
148 | tifs[0].filename = os.path.basename(tifs[0].filename)
149 | if name_prefix is not None:
150 | tifs[0].filename = "{}_{}".format(name_prefix, tifs[0].filename)
151 |
152 | archive.extract(tifs[0], path=temp_dest)
153 | raster = rasterio.open(os.path.join(temp_dest, tifs[0].filename))
154 |
155 | elif ext in ['.tar', '.gz', '.bz2', '.xz']:
156 | archive = tarfile.open(src, 'r:*') # tarfile.open handles the compressed variants transparently
157 | tifs = []
158 | for member in archive.getmembers():
159 | # TSX Data archive contains numerous tif files so ensure we only open the actual image data
160 | if 'IMAGEDATA' in member.name and os.path.splitext(member.name)[1] == '.tif':
161 | tifs.append(member)
162 |
163 | if len(tifs) > 0:
164 | if temp_dest is None:
165 | f = archive.extractfile(tifs[0])
166 | fmem = rasterio.io.MemoryFile(f.read())
167 | raster = fmem.open()
168 | else:
169 | tifs[0].name = os.path.basename(tifs[0].name)
170 | if name_prefix is not None:
171 | tifs[0].name = "{}_{}".format(name_prefix, tifs[0].name)
172 |
173 | archive.extract(tifs[0], path=temp_dest)
174 | raster = rasterio.open(os.path.join(temp_dest, tifs[0].name))
175 |
176 | return raster, fmem
177 |
178 | if __name__ == "__main__":
179 | parser = ArgumentParser()
180 | parser.add_argument("src", type=str, help="GeoTif to find feature points in")
181 | parser.add_argument("-o", "--output", action="store_true", help="Save features to CSV file and as geojson")
182 | parser.add_argument("-p", "--plot", action="store_true", help="Plot feature points")
183 | parser.add_argument("-b", "--band", type=int, default=1, help="Band in which to detect feature points")
184 | parser.add_argument("-r", "--radius", type=int, default=63, help="Radius for non-maximal suppression of keypoints")
185 | parser.add_argument("-s", "--scheme", type=str, default="SURF", help="Feature point detection scheme")
186 | parser.add_argument("-c", "--cut_patches", type=int, default=-1, help="Cut patches of specified size centered around the detected feature point")
187 | parser.add_argument("-g", "--geometry", type=str, default="pt_w", help="The geometry to use in the geojson object. 'pt_w' is raster projection, 'pt_i' is the image coords, 'epsg' the proj specified with -e is used.")
188 | parser.add_argument("-e", "--epsg", type=str, default=None, help="Specify an epsg code to reproject all points to.
By default no reprojection happens")
189 | 
190 | # parser.add_argument("-roi", "--roi", nargs=4, default=[0, 0, -1, -1], help="Geo-coords of bounding box in which to find feature points")
191 | args = parser.parse_args()
192 | 
193 | # args.src = "/Volumes/Hades/Documents/Varsity/PhD_Remote_Sensing/Data and Experiments/Athens/PSM_MMC_TP__0002631001.317.1/O.tif"
194 | 
195 | assert os.path.exists(args.src), "File does not exist"
196 | 
197 | if args.epsg is not None:
198 | args.epsg = "EPSG:{}".format(args.epsg)
199 | 
200 | if args.geometry == 'epsg':
201 | args.geometry = args.epsg
202 | 
203 | print(args.src)
204 | 
205 | tif, tmp_file = open_raster(args.src)
206 | img, dims, transform = load_from_geotif(tif, 1)
207 | print(img.shape)
208 | 
209 | kpts = find_keypoints(img, scheme=args.scheme, radius=args.radius)
210 | keypoint_list = cv_keypoints_to_world(kpts, tif, epsg=args.epsg)
211 | print(f"Found {len(keypoint_list)} keypoints")
212 | 
213 | if args.plot:
214 | img2 = superimpose_keypoints(img, kpts)
215 | 
216 | meta = tif.meta.copy()
217 | meta.update({
218 | "count": img2.shape[-1]
219 | })
220 | 
221 | assert len(img2.shape) == 3, "Incorrect image size, expected (H, W, C)"
222 | 
223 | fname = os.path.splitext(args.src)
224 | fname = "{}_keypoints{}".format(*fname)
225 | 
226 | with rasterio.open(fname, 'w', **meta) as dst:
227 | for i in range(img2.shape[-1]):
228 | dst.write(img2[:,:,i], i+1)
229 | 
230 | # Suppress non-maximal keypoints
231 | keypoint_list = keypoints_nms(img, keypoint_list, r=args.radius)
232 | print(f"{len(keypoint_list)} keypoints after NMS")
233 | 
234 | featcol = keypoints_to_geojson(keypoint_list, geom=args.geometry)
235 | 
236 | df = pd.DataFrame(list(keypoint_list.values()))
237 | 
238 | if args.output:
239 | fname = os.path.splitext(args.src)
240 | df.to_csv( "{}_{}.csv".format(fname[0], args.scheme) )
241 | 
242 | print(f"Created GeoJSON with {len(featcol['features'])} keypoints")
243 | 
244 | filename = "{}_{}.geojson".format(fname[0], args.scheme)
245 | 
246 | try:
247 | os.remove(filename)
248 | except OSError:
249 | pass
250 | finally:
251 | with open(filename, 'w') as dest:
252 | geojson.dump(featcol, dest)
253 | 
254 | print(len(keypoint_list))
--------------------------------------------------------------------------------
/tools/lr_loss_plot.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | import math
6 | 
7 | parser = ArgumentParser()
8 | parser.add_argument("lr", type=str, help="Path to LR CSV file from tensorboard")
9 | parser.add_argument("loss", type=str, help="Path to Loss CSV file from tensorboard")
10 | parser.add_argument("--clip", type=float, help="Only plot until a LR of n")
11 | args = parser.parse_args()
12 | 
13 | lr = pd.read_csv(args.lr)['Value'].values
14 | loss = pd.read_csv(args.loss)['Value'].values
15 | 
16 | if args.clip:
17 | lr = lr[np.where(lr <= args.clip)]
18 | 
19 | # Align the LR and loss series if TensorBoard logged them at different frequencies
20 | if len(loss) != len(lr):
21 | l = min(len(loss), len(lr))
22 | r = math.ceil(len(lr)/len(loss))
23 | loss = loss if len(lr) > len(loss) else loss[::r]
24 | lr = lr if len(lr) < len(loss) else lr[::r]
25 | 
26 | loss = loss[:l]
27 | lr = lr[:l]
28 | 
29 | print(f"Loss: {len(loss)} LR: {len(lr)}")
30 | 
31 | plt.semilogx(lr, loss)
32 | plt.show()
33 | 
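# A possible extension (sketch only, not part of the original tool): raw loss curves from
# an LR range test are often too noisy to read, and an exponentially weighted average with
# bias correction makes the minimum easier to spot. "beta" is an assumed smoothing factor.
#
# def smooth(values, beta=0.98):
#     avg, out = 0.0, []
#     for i, v in enumerate(values):
#         avg = beta * avg + (1 - beta) * v
#         out.append(avg / (1 - beta ** (i + 1)))  # bias-corrected running average
#     return np.array(out)
#
# plt.semilogx(lr, smooth(loss))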
--------------------------------------------------------------------------------
/tools/plot_results.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | 
3 | import argparse
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import scipy.io as sio
7 | import os
8 | 
9 | from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
10 | 
11 | one_hot_decode = lambda x: np.argmax(x, axis=1) if (len(x.shape) > 1 and x.shape[-1] > 1) else x.flatten()
12 | one_hot_decode_thresh = lambda x, t: x[:,-1]>=t if (len(x.shape) > 1 and x.shape[-1] > 1) else x.flatten()
13 | select_one_class = lambda x: x[:,-1] if (len(x.shape) > 1 and x.shape[-1] > 1) else x.flatten()
14 | colors = ['b','r','c','g','y','m','k']
15 | 
16 | def print_report(y_true, y_pred, names):
17 | for yt, yp, name in zip(y_true, y_pred, names):
18 | print("Classification Report for {}".format(name))
19 | print( classification_report(yt, one_hot_decode(yp)) )
20 | 
21 | print("Confusion Matrix [TN, FP, FN, TP]")
22 | for yt, yp in zip(y_true, y_pred):
23 | print( confusion_matrix(yt, one_hot_decode(yp)).ravel() )
24 | 
25 | for i, (yt, yp) in enumerate(zip(y_true, y_pred)):
26 | fpr, tpr, thresh = roc_curve(yt, select_one_class(yp))
27 | plt.plot(fpr, tpr, colors[i])
28 | 
29 | auc = roc_auc_score(yt, select_one_class(yp))
30 | 
31 | acc_5fpr = {'fpr': 0, 'acc': 0}
32 | fpr_max_acc = {'fpr': 0, 'acc': 0}
33 | for t in thresh:
34 | tn, fp, fn, tp = confusion_matrix(yt, one_hot_decode_thresh(yp, t)).ravel()
35 | fpr = fp/(fp+tn)
36 | acc = (tp+tn)/(tp+tn+fp+fn)
37 | 
38 | if acc > fpr_max_acc['acc']:
39 | fpr_max_acc['acc'] = acc
40 | fpr_max_acc['fpr'] = fpr
41 | 
42 | if fpr > acc_5fpr['fpr'] and fpr <= 0.05:
43 | acc_5fpr['acc'] = acc
44 | acc_5fpr['fpr'] = fpr
45 | 
46 | print("Max Acc: {}".format(fpr_max_acc))
47 | print("FPR5: {}".format(acc_5fpr))
48 | print("AUC: {}".format(auc))
49 | plt.show()
50 | 
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument("results_paths", type=str, nargs="+", help="Path to the file containing prediction results")
53 | parser.add_argument("--y_pred", default="y_pred", help="Dataset identifier for the predicted results")
54 | parser.add_argument("--y_true", default="y_true", help="Dataset identifier for the ground truth labels")
55 | args = parser.parse_args()
56 | 
57 | results = [sio.loadmat(path) for path in args.results_paths]
58 | names = ["{}_{}".format(i, path) for i, path in enumerate(args.results_paths)]
59 | y_pred = [res[args.y_pred] for res in results]
60 | y_true = [one_hot_decode(res[args.y_true]) for res in results]
61 | 
62 | print_report(y_true, y_pred, names)
63 | 
--------------------------------------------------------------------------------
/tools/process_full_scene_results.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import scipy.io as sio
3 | import numpy as np
4 | from skimage.io import imread, imsave
5 | import matplotlib.pyplot as plt
6 | 
7 | import os
8 | 
9 | from utils import Experiment
10 | from utils.factory import *
11 | from utils.helpers import load_file_list
12 | 
13 | def _len_full_scene(full_scene, patch_size, stride):
14 | ny, nx = full_scene.shape[:2]
15 | 
16 | # Compensate for edges, and stride to get the number of centers
17 | ncy = (ny - 2*(patch_size//2))//stride
18 | ncx = (nx - 2*(patch_size//2))//stride
19 | 
20 | return nx, ny, ncx, ncy
21 | 
22 | def make_heatmap(data, normalize=False):
23 | cmap = plt.cm.jet
24 | 
25 | if normalize:
26 | norm = plt.Normalize(vmin=data.min(),
vmax=data.max())
27 | data = norm(data)
28 | 
29 | image = cmap(data)
30 | return image
31 | 
32 | def main(config):
33 | results = sio.loadmat(config.results_file)['y_pred'][:,1]
34 | full_scene = imread(config["datasets"]["test"].full_scene)
35 | img_list = load_file_list(None, config["datasets"]["test"].data_path)
36 | 
37 | nx, ny, ncx, ncy = _len_full_scene(full_scene, config["datasets"]["test"].patch_size, config["datasets"]["test"].stride)
38 | map_size = ncx*ncy
39 | 
40 | n_map = len(results)//map_size
41 | 
42 | maps = [np.reshape(results[i*map_size:i*map_size+map_size], (ncy, ncx)) for i in range(n_map)]
43 | 
44 | map_dir = os.path.join(config.result_path, 'maps')
45 | os.makedirs(map_dir, exist_ok=True)
46 | 
47 | for i in range(n_map):
48 | image = make_heatmap(maps[i], True)
49 | coords = np.unravel_index(maps[i].argmax(), (ncy, ncx))
50 | print("Max Point {} {:.5f}".format(coords, maps[i][coords]))
51 | plt.imsave(os.path.join(map_dir, "{}.png".format(os.path.basename(img_list[i]))), image)
52 | 
53 | if __name__ == "__main__":
54 | parser = ArgumentParser()
55 | parser.add_argument('-c', '--config', default=None, type=str, required=True, help='config file path (default: None)')
56 | parser.add_argument('-r', '--results', default=None, type=str, required=True, help='Results file to process')
57 | args = parser.parse_args()
58 | 
59 | config = Experiment.load_from_path(args.config)
60 | 
61 | config.results_file = args.results
62 | 
63 | assert config, "Config could not be loaded."
64 | 
65 | main(config)
66 | 
--------------------------------------------------------------------------------
/tools/process_psnet_results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | 
5 | from argparse import ArgumentParser
6 | 
7 | parser = ArgumentParser()
8 | parser.add_argument("csv", help="Result CSV file")
9 | args = parser.parse_args()
10 | 
11 | results = pd.read_csv(args.csv)
12 | 
13 | results["l2"] = np.sqrt( (results.hm_x_max - results.shift_x)**2 + (results.hm_y_max - results.shift_y)**2 )
14 | 
15 | count = []
16 | for t in np.arange(0, results.l2.max(), 0.5):
17 | count.append(np.sum(results.l2 <= t))
18 | 
19 | count = np.array(count)
20 | plt.plot(np.arange(0, results.l2.max(), 0.5), count/len(results))
21 | plt.title("Threshold vs % Successful matches")
22 | plt.show()
23 | 
24 | plt.scatter(results.l2, results.nlog_match_loss)
25 | plt.title("L2 error vs -log(matching loss)")
26 | plt.show()
27 | 
28 | results["nnlog_match_loss"] = results.nlog_match_loss
29 | 
30 | # For each possible matching loss threshold count the number of regions where we managed to match accurately
31 | counts = {1:[], 2:[], 3:[]}
32 | counts2 = {1:[], 2:[], 3:[]}
33 | c = ['r','b','k']
34 | for k in counts.keys():
35 | for t in np.unique(results.nnlog_match_loss):
36 | counts[k].append( np.sum(results.loc[results.nnlog_match_loss >= t].l2 <= k)/len(results.loc[results.nnlog_match_loss >= t]) )
37 | counts2[k].append( np.sum(results.loc[results.nnlog_match_loss >= t].l2 > k)/len(results.loc[results.nnlog_match_loss >= t]) )
38 | 
39 | counts = {k: np.array(v) for k,v in counts.items()}
40 | counts2 = {k: np.array(v) for k,v in counts2.items()}
41 | 
42 | [plt.plot(np.unique(results.nnlog_match_loss), counts[k], c[k-1]) for k in counts.keys()]
43 | plt.title("ROCish")
44 | plt.show()
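# Equivalent vectorised form of the success-rate curve above (sketch only, same semantics
# as the count loop): each row compares every l2 error against one threshold.
#
# thresholds = np.arange(0, results.l2.max(), 0.5)
# success_rate = (results.l2.values[None, :] <= thresholds[:, None]).mean(axis=1)
# plt.plot(thresholds, success_rate)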
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | import signal
5 | import random
6 | from argparse import ArgumentParser
7 | 
8 | import numpy as np
9 | 
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 | 
14 | from torchsummary import summary
15 | 
16 | from ignite.engine.engine import Engine, State, Events
17 | from ignite.handlers import ModelCheckpoint, EarlyStopping
18 | from ignite._utils import convert_tensor
19 | 
20 | from utils import Experiment
21 | from utils.factory import *
22 | 
23 | logging.basicConfig(level=logging.INFO, format='')
24 | logger = logging.getLogger()
25 | 
26 | torch.backends.cudnn.benchmark = True
27 | 
28 | def main(config):
29 | assert validate_config(config), "ERROR: Config file is invalid. Please see log for details."
30 | 
31 | logger.info("INFO: {}".format(config.toDict()))
32 | 
33 | # Set the random number generator seed for torch, as we use their dataloaders this will ensure shuffle is constant
34 | # Remember to seed custom datasets etc with the same seed
35 | if config.seed > 0:
36 | torch.backends.cudnn.deterministic = True
37 | torch.cuda.manual_seed_all(config.seed)
38 | torch.manual_seed(config.seed)
39 | random.seed(config.seed)
40 | np.random.seed(config.seed)
41 | 
42 | if config.device == "cpu" and torch.cuda.is_available():
43 | logger.warning("WARNING: Not using the GPU")
44 | elif config.device == "cuda":
45 | config.device = f"cuda:{config.device_ids[0]}"
46 | 
47 | config.nsave = config.nsave if "nsave" in config else 5
48 | 
49 | logger.info("INFO: Creating datasets and dataloaders...")
50 | # Create the training dataset
51 | dset_train = create_dataset(config.datasets.train)
52 | 
53 | # If the validation config has a parameter called split then we ask the training dset for the validation dataset.
54 | # Note that the train dataset shouldn't shuffle its data in its init in this case, as only in
55 | # get_validation_split do we know how to split the data (unless the shuffling is deterministic).
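# For example (hypothetical config fragment; the key values are illustrative), the
# validation entry only needs to list what differs from train:
#
#   "datasets": {
#       "train":      {"type": "Sen12Dataset", "data_path": "data/train", "shuffle": true},
#       "validation": {"data_path": "data/val", "shuffle": false}
#   }
#
# config_val below starts as a copy of the train config and update()s these overrides onto it.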
56 | train_ids = None
57 | if 'validation' in config.datasets:
58 | # Ensure we have a full config for validation, this means we don't need to specify everything in the config file
59 | # only the differences
60 | config_val = config.datasets.train.copy()
61 | config_val.update(config.datasets.validation)
62 | 
63 | dset_val = create_dataset(config_val)
64 | 
65 | loader_val = get_data_loader(dset_val, config_val)
66 | print("Using validation dataset of {} samples or {} batches".format(len(dset_val), len(loader_val)))
67 | elif 'includes_validation' in config.datasets.train:
68 | train_ids, val_ids = dset_train.get_validation_split(config.datasets.train)
69 | loader_val = get_data_loader(dset_train, config.datasets.train, val_ids)
70 | print("Using validation dataset of {} samples or {} batches".format(len(val_ids), len(loader_val)))
71 | else:
72 | logger.warning("WARNING: No validation dataset was specified")
73 | dset_val = None
74 | loader_val = None
75 | 
76 | loader_train = get_data_loader(dset_train, config.datasets.train, train_ids)
77 | dset_len = len(train_ids) if train_ids is not None else len(dset_train)
78 | print("Using training dataset of {} samples or {} batches".format(dset_len, len(loader_train)))
79 | 
80 | cp_paths = None
81 | last_epoch = 0
82 | if 'checkpoint' in config:
83 | checkpoint_dir = config.checkpoint_dir if 'checkpoint_dir' in config else config.result_path
84 | cp_paths, last_epoch = config.get_checkpoints(path=checkpoint_dir, tag=config.checkpoint)
85 | print("Found checkpoint {} for Epoch {}".format(config.checkpoint, last_epoch))
86 | last_epoch = last_epoch if config.resume_from == -1 else config.resume_from
87 | # config.epochs = config.epochs - last_epoch if last_epoch else config.epochs
88 | 
89 | models = {}
90 | for name, model in config.model.items():
91 | logger.info("INFO: Building the {} model".format(name))
92 | models[name] = build_model(model)
93 | 
94 | # Load the checkpoint
95 | if cp_paths and name in cp_paths:
96 | models[name].load_state_dict( torch.load( cp_paths[name] ) )
97 | logger.info("INFO: Loaded model {} checkpoint {}".format(name, cp_paths[name]))
98 | 
99 | if len(config.device_ids) > 1:
100 | models[name] = nn.DataParallel(models[name], device_ids=config.device_ids)
101 | 
102 | models[name].to(config.device)
103 | print(models[name])
104 | 
105 | if 'debug' in config and config.debug is True:
106 | print("*********** {} ************".format(name))
107 | for pname, param in models[name].named_parameters():
108 | if param.requires_grad:
109 | print(pname, param.data)
110 | 
111 | optimizers = {}
112 | for name, conf in config.optimizer.items():
113 | optim_conf = conf.copy()
114 | del optim_conf["models"]
115 | 
116 | model_params = []
117 | for model_id in conf.models:
118 | model_params.extend( list(filter(lambda p: p.requires_grad, models[model_id].parameters())) )
119 | 
120 | logger.info("INFO: Using {} Optimization for {}".format(list(optim_conf.keys())[0], name))
121 | optimizers[name] = get_optimizer(model_params, optim_conf)
122 | 
123 | # Restoring the optimizer breaks because we do not include all parameters in the optimizer state.
So if we aren't continuing training then just make a new optimizer
124 | if cp_paths and name in cp_paths and 'checkpoint_dir' not in config:
125 | optimizers[name].load_state_dict( torch.load( cp_paths[name] ) )
126 | logger.info("INFO: Loaded {} optimizer checkpoint {}".format(name, cp_paths[name]))
127 | 
128 | for state in optimizers[name].state.values():
129 | for k, v in state.items():
130 | if isinstance(v, torch.Tensor):
131 | state[k] = v.to(config.device)
132 | 
133 | losses = {}
134 | for name, fcns in config.loss.items():
135 | losses[name] = []
136 | for l in fcns:
137 | losses[name].append( get_loss(l) )
138 | assert losses[name][-1], "Loss function {} for {} could not be found, please check your config".format(l, name)
139 | 
140 | exp_logger = get_experiment_logger(config.result_path, config.logger) if 'logger' in config else None
141 | if exp_logger is not None:
142 | logger.info("INFO: Initialising the experiment logger")
143 | if last_epoch > 0:
144 | exp_logger.fast_forward(last_epoch, len(loader_train))
145 | 
146 | logger.info("INFO: Creating training manager and configuring callbacks")
147 | trainer = get_trainer(models, optimizers, losses, exp_logger, config)
148 | 
149 | trainer_engine = Engine(trainer.train)
150 | evaluator_engine = Engine(trainer.evaluate)
151 | 
152 | trainer.attach("train_loader", loader_train)
153 | trainer.attach("validation_loader", loader_val)
154 | trainer.attach("evaluation_engine", evaluator_engine)
155 | trainer.attach("train_engine", trainer_engine)
156 | 
157 | for phase in config.metrics.keys():
158 | if phase == "train": engine = trainer_engine
159 | if phase == "validation": engine = evaluator_engine
160 | 
161 | for name, metric in config.metrics[phase].items():
162 | metric = get_metric(metric)
163 | if metric is not None:
164 | metric.attach(engine, name)
165 | else:
166 | logger.warning("WARNING: Metric {} could not be created for {} phase".format(name, phase))
167 | 
168 | # Register default callbacks to run the validation stage
169 | if loader_val is not None:
170 | if len(loader_train) > 2000:
171 | # Validate 4 times an epoch
172 | num_batch = len(loader_train)//4
173 | 
174 | def validate_run(engine):
175 | if engine.state.iteration % num_batch == 0:
176 | evaluator_engine.run(loader_val)
177 | 
178 | trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, validate_run)
179 | else:
180 | trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: evaluator_engine.run(loader_val))
181 | 
182 | # Initialise the Epoch from the checkpoint - this is a hack because Ignite is dumb
183 | if last_epoch > 0:
184 | def set_epoch(engine, last_epoch):
185 | engine.state.epoch = last_epoch
186 | 
187 | trainer_engine.add_event_handler(Events.STARTED, set_epoch, last_epoch)
188 | 
189 | 
190 | schedulers = {"batch": {}, "epoch": {}}
191 | if 'scheduler' in config:
192 | for sched_name, sched in config.scheduler.items():
193 | if sched_name in optimizers:
194 | logger.info("INFO: Setting up LR scheduler for {}".format(sched_name))
195 | sched_fn, sched_scheme = get_lr_scheduler(optimizers[sched_name], sched)
196 | assert sched_fn, "Learning Rate scheduler for {} could not be found, please check your config".format(sched_name)
197 | assert sched_scheme in ["batch", "epoch"], "ERROR: Invalid scheduler scheme, must be either epoch or batch"
198 | 
199 | schedulers[sched_scheme][sched_name] = sched_fn
200 | 
201 | def epoch_scheduler(engine):
202 | for name, sched in schedulers["epoch"].items():
203 | sched.step()
204 | 
205 | def batch_scheduler(engine):
206 | for name, sched in
schedulers["batch"].items(): 207 | sched.step() 208 | 209 | trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: batch_scheduler(engine)) 210 | trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: epoch_scheduler(engine)) 211 | 212 | if exp_logger is not None: 213 | trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="train", models=models, optims=optimizers) 214 | trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="train", models=models, optims=optimizers) 215 | evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="evaluate", models=models, optims=optimizers) 216 | evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="evaluate", models=models, optims=optimizers) 217 | 218 | if "monitor" in config and config.monitor.early_stopping: 219 | logger.info("INFO: Enabling early stopping, monitoring {}".format(config.monitor.score)) 220 | score_fn = lambda e: config.monitor.scale * e.state.metrics[config.monitor.score] 221 | es_handler = EarlyStopping(patience=config.monitor.patience, score_function=score_fn, trainer=trainer_engine) 222 | evaluator_engine.add_event_handler(Events.COMPLETED, es_handler) 223 | 224 | if "monitor" in config and config.monitor.save_score: 225 | logger.info("INFO: Saving best model based on {}".format(config.monitor.save_score)) 226 | score_fn = lambda e: config.monitor.save_scale * e.state.metrics[config.monitor.save_score] 227 | ch_handler = ModelCheckpoint(config.result_path, 'best_checkpoint', score_function=score_fn, score_name=config.monitor.save_score, n_saved=1, require_empty=False, save_as_state_dict=True) 228 | to_save = dict(models, **optimizers) 229 | evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler, to_save) 230 | 231 | if config.save_freq > 0: 232 | ch_handler = ModelCheckpoint(config.result_path, 'checkpoint', save_interval=config.save_freq, n_saved=config.nsave, require_empty=False, save_as_state_dict=True) 233 | to_save = dict(models, **optimizers) 234 | trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler, to_save) 235 | 236 | # Register custom callbacks with the engines 237 | if check_if_implemented(trainer, "on_iteration_start"): 238 | trainer_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="train") 239 | evaluator_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="evaluate") 240 | if check_if_implemented(trainer, "on_iteration_end"): 241 | trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="train") 242 | evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="evaluate") 243 | if check_if_implemented(trainer, "on_epoch_start"): 244 | trainer_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="train") 245 | evaluator_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="evaluate") 246 | if check_if_implemented(trainer, "on_epoch_end"): 247 | trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="train") 248 | evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="evaluate") 249 | 250 | # Save the config for this experiment to the results directory, once we know the params are good 251 | config.save() 252 | 253 | def signal_handler(sig, frame): 254 | print('You 
pressed Ctrl+C!')
255 | if exp_logger is not None:
256 | exp_logger.teardown()
257 | sys.exit(0)
258 | 
259 | signal.signal(signal.SIGINT, signal_handler)
260 | 
261 | logger.info("INFO: Starting training...")
262 | trainer_engine.run(loader_train, max_epochs=config.epochs)
263 | 
264 | if exp_logger is not None:
265 | exp_logger.teardown()
266 | 
267 | if __name__ == "__main__":
268 | parser = ArgumentParser()
269 | parser.add_argument('-c', '--config', default=None, type=str, required=True, help='config file path (default: None)')
270 | parser.add_argument('--checkpoint', default=None, type=str, help='Checkpoint tag to reload')
271 | parser.add_argument('--checkpoint_dir', default=None, type=str, help='Checkpoint directory to reload')
272 | parser.add_argument('--suffix', default=None, type=str, help='Add to the name')
273 | parser.add_argument('--epochs', default=None, type=int, help='Number of epochs')
274 | parser.add_argument('--resume_from', default=None, type=int, help='Epoch to resume from, allows using checkpoints as initialisation')
275 | args = parser.parse_args()
276 | 
277 | OVERLOADABLE = ['checkpoint', 'epochs', 'checkpoint_dir', 'resume_from']
278 | 
279 | overloaded = {}
280 | for k, v in vars(args).items():
281 | if (k in OVERLOADABLE) and (v is not None):
282 | overloaded[k] = v
283 | 
284 | config = Experiment.load_from_path(args.config, overloaded, args.suffix)
285 | 
286 | assert config, "Config could not be loaded."
287 | 
288 | # Else load the saved config from the results dir or throw an error if one doesn't exist
289 | if len(config.checkpoint) > 0:
290 | logger.warning("WARNING: --config specifies resuming, overriding config with existing experiment config.")
291 | # resume_config = Experiment(config.name, desc=config.desc, result_dir=config.result_dir).load()
292 | # assert resume_config is not None, "No experiment {} exists, cannot resume training".format(config.name)
293 | # config = resume_config
294 | assert config, "Config could not be loaded for resume"
295 | # If we have resume_from in the config but have it < 0 to start a fresh training run then throw an error if the directory already exists
296 | elif config.overwrite is False:
297 | assert not config.exists(), "Results directory {} already exists!
Please specify a new experiment name or remove the old files.".format(config.result_path)
298 | else:
299 | empty_folder(config.result_path)
300 | 
301 | main(config)
302 | 
--------------------------------------------------------------------------------
/trainer/default_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from ignite.engine.engine import Engine, State, Events
4 | from ignite._utils import convert_tensor
5 | 
6 | from utils.helpers import BinaryClassificationMeter, accuracy
7 | 
8 | class DefaultTrainer:
9 | def __init__(self, model, optimizer, loss_fn, logger, config):
10 | self.opts = config.trainer_config
11 | 
12 | self.model = model["Model"]
13 | 
14 | if optimizer:
15 | self.optimizer = optimizer["Optim"]
16 | 
17 | if loss_fn:
18 | self.loss_fn = loss_fn["Loss"][0]
19 | 
20 | self.logger = logger
21 | self.device = config.device
22 | self.log_freq = config.log_freq
23 | self.attached = {}
24 | self.curr_epoch = 0
25 | self.metric = BinaryClassificationMeter()
26 | self.metric_train = BinaryClassificationMeter()
27 | 
28 | def _prepare_batch(self, batch):
29 | xs, ys = batch
30 | 
31 | if isinstance(xs, list):
32 | xs = [convert_tensor(x, self.device).float() for x in xs]
33 | else:
34 | xs = [convert_tensor(xs, self.device).float()]
35 | 
36 | if isinstance(ys, list):
37 | ys = [convert_tensor(y, self.device).float() for y in ys]
38 | else:
39 | ys = [convert_tensor(ys, self.device).float()]
40 | 
41 | return xs, ys
42 | 
43 | def train(self, engine, batch):
44 | self.model.train()
45 | 
46 | curr_step = self.logger.counters["train"]
47 | 
48 | self.optimizer.zero_grad()
49 | 
50 | xs, ys = self._prepare_batch(batch)
51 | y_pred = self.model(*xs)
52 | 
53 | if not (isinstance(y_pred, list) or isinstance(y_pred, tuple)):
54 | ys = ys[0]
55 | 
56 | loss = self.loss_fn(y_pred, ys, pos_weight=torch.Tensor([1.5]).to(self.device))
57 | 
58 | self.logger.add_scalars('train/loss', {'L': loss.item()}, curr_step)
59 | 
60 | if engine.state.iteration % 1000 == 0:
61 | self.logger.log_image_grid("Input", xs[0], "train")
62 | y_img = torch.ones_like(xs[0])*ys.view(ys.size(0),1,1,1)
63 | self.logger.log_image_grid("Label", y_img, "train", normalize=False)
64 | y2_img = torch.ones_like(xs[0])*torch.sigmoid(y_pred).view(y_pred.size(0),1,1,1)
65 | self.logger.log_image_grid("Prediction", y2_img, "train", normalize=False)
66 | 
67 | loss.backward()
68 | self.optimizer.step()
69 | 
70 | return loss.item()
71 | 
72 | def on_epoch_start(self, engine, phase=None):
73 | self.log_batch = True
74 | self.metric_train.reset()
75 | if phase == "train":
76 | self.curr_epoch = engine.state.epoch
77 | 
78 | def on_epoch_end(self, engine, phase=None):
79 | if phase in ["evaluate", "test"]:
80 | metrics = engine.state.metrics
81 | log = ""
82 | for k, v in metrics.items():
83 | log += "{}: {:.2f} ".format(k, v)
84 | 
85 | print("{} Results - Epoch: {} {}".format(phase.capitalize(), self.curr_epoch, log))
86 | 
87 | if phase in ["evaluate"]:
88 | curr_step = self.logger.counters["evaluate"]
89 | self.logger.add_scalars('evaluate/metrics', {'Acc': self.metric.acc, 'Precision': self.metric.pre, 'f1':self.metric.f1, 'Recall': self.metric.rec}, curr_step)
90 | self.metric.reset()
91 | 
92 | def on_iteration_start(self, engine, phase=None):
93 | if phase == "train":
94 | curr_iter = (engine.state.iteration - 1) % len(self.attached["train_loader"]) + 1
95 | if curr_iter % self.log_freq == 0:
96 | print("Epoch[{}] Iteration[{}/{}] Loss:
{:.2f}".format(engine.state.epoch, curr_iter, len(self.attached["train_loader"]), engine.state.output)) 97 | elif phase == "test": 98 | curr_iter = (engine.state.iteration - 1) % len(self.attached["test_loader"]) + 1 99 | if curr_iter % self.log_freq == 0: 100 | print("Iteration[{}/{}]".format(curr_iter, len(self.attached["test_loader"]))) 101 | 102 | def on_iteration_end(self, engine, phase=None): 103 | pass 104 | 105 | def infer_batch(self, batch): 106 | self.model.eval() 107 | 108 | with torch.no_grad(): 109 | xs, ys = self._prepare_batch(batch) 110 | y_pred = self.model(*xs) 111 | 112 | return xs, ys, y_pred 113 | 114 | def evaluate(self, engine, batch): 115 | curr_step = self.logger.counters["evaluate"] 116 | 117 | xs, ys, y_pred = self.infer_batch(batch) 118 | 119 | if not (isinstance(y_pred, list) or isinstance(y_pred, tuple)): 120 | ys = ys[0] 121 | 122 | if self.log_batch: 123 | self.logger.log_image_grid("evInput", xs[0], "evaluate") 124 | y_img = torch.ones_like(xs[0])*ys.view(ys.size(0),1,1,1) 125 | self.logger.log_image_grid("evLabel", y_img, "evaluate", normalize=False) 126 | y2_img = torch.ones_like(xs[0])*torch.sigmoid(y_pred).view(y_pred.size(0),1,1,1) 127 | self.logger.log_image_grid("evPrediction", y2_img, "evaluate", normalize=False) 128 | self.log_batch = False 129 | 130 | loss = self.loss_fn(y_pred, ys) 131 | 132 | self.metric.update(torch.sigmoid(y_pred), ys) 133 | 134 | self.logger.add_scalars('evaluate/loss', {'L': loss.item()}, curr_step) 135 | 136 | return y_pred.float(), ys.float() 137 | 138 | def attach(self, name, obj): 139 | self.attached[name] = obj 140 | -------------------------------------------------------------------------------- /trainer/wur_hypercol_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from ignite.engine.engine import Engine, State, Events 4 | from ignite._utils import convert_tensor 5 | 6 | from torch.nn.utils import clip_grad_value_ 7 | import torch.autograd as autograd 8 | from torch.autograd import Variable 9 | 10 | from apex import amp 11 | from apex.fp16_utils import * 12 | 13 | import numpy as np 14 | 15 | DEFAULT_LOSS_WEIGHTS = { 16 | "match": 1, 17 | "spatial_softmax": 0, 18 | "heatmap_l1": 0, 19 | } 20 | 21 | class WURHypercolTrainer: 22 | def __init__(self, models, optimizers, loss_fns, logger, config): 23 | self.opts = config.trainer_config 24 | self.search_domain = self.opts['search_domain'] if 'search_domain' in self.opts else "A" 25 | self.loss_weights = self.opts['loss_weights'] if 'loss_weights' in self.opts else DEFAULT_LOSS_WEIGHTS.copy() 26 | self.hm_act = F.tanhshrink if 'hm_act' in self.opts else False 27 | self.fp16 = False 28 | self.binarize = False 29 | 30 | self.log_eval_batch = False 31 | 32 | self.FtsA = self._model_precision(models["FtsA"]) 33 | self.FtsB = self._model_precision(models["FtsB"]) 34 | 35 | if optimizers: 36 | if "Fts" in optimizers: 37 | self.optim_fts = self._optim_precision(optimizers["Fts"]) 38 | else: 39 | self.optim_fts = None 40 | 41 | if loss_fns: 42 | self.Lmatch = loss_fns["Lmatch"][0] 43 | 44 | self.device = torch.device(config.device) 45 | self.logger = logger 46 | self.log_freq = config.log_freq 47 | self.attached = {} 48 | self.curr_epoch = 0 49 | self.log_str = "" 50 | 51 | def _freeze(self, model): 52 | for param in model.parameters(): 53 | param.requires_grad = False 54 | 55 | def _optim_precision(self, optim): 56 | if self.fp16: 57 | return FP16_Optimizer(optim, 
dynamic_loss_scale=True)
58 | else:
59 | return optim
60 | 
61 | def _model_precision(self, model):
62 | if self.fp16:
63 | return network_to_half(model)
64 | else:
65 | return model
66 | 
67 | def _prepare_batch(self, batch, non_blocking=True):
68 | xs, ys = batch
69 | 
70 | if isinstance(xs, list):
71 | xs = [convert_tensor(x, self.device, non_blocking=non_blocking).float() for x in xs]
72 | else:
73 | xs = [convert_tensor(xs, self.device, non_blocking=non_blocking).float()]
74 | 
75 | if isinstance(ys, list):
76 | ys = [convert_tensor(y, self.device, non_blocking=non_blocking).float() for y in ys]
77 | else:
78 | ys = [convert_tensor(ys, self.device, non_blocking=non_blocking).float()]
79 | 
80 | if self.fp16:
81 | xs = [x.half() for x in xs]
82 | ys = [y.half() for y in ys]
83 | 
84 | return xs, ys
85 | 
86 | def _zero_grad(self):
87 | if self.optim_fts:
88 | self.optim_fts.zero_grad()
89 | 
90 | def freeze_model(self, model):
91 | model.eval()
92 | for params in model.parameters():
93 | params.requires_grad = False
94 | 
95 | def unfreeze_model(self, model):
96 | model.train()
97 | for params in model.parameters():
98 | params.requires_grad = True
99 | 
100 | def _AdamW(self, optimizer, wd=1e-2):
101 | if getattr(self, 'adamW', False) and isinstance(optimizer, torch.optim.Adam):
102 | for group in optimizer.param_groups:
103 | for param in group['params']:
104 | param.data = param.data.add(-wd * group['lr'], param.data)
105 | 
106 | def l2_shift_loss(self, heatmap, label, device="cuda"):
107 | hm = heatmap.view(heatmap.size(0), -1)
108 | gt = label.view(label.size(0), -1)
109 | 
110 | hm_max, hm_pos = hm.max(1)
111 | gt_max, gt_pos = gt.max(1)
112 | 
113 | hm_pos = torch.Tensor( np.unravel_index(hm_pos.cpu().numpy(), heatmap.shape[2:]) ).transpose(0, 1)
114 | gt_pos = torch.Tensor( np.unravel_index(gt_pos.cpu().numpy(), label.shape[2:]) ).transpose(0, 1)
115 | 
116 | l2 = F.pairwise_distance(hm_pos, gt_pos)
117 | return l2.to(device)
118 | 
119 | # Apply softmax per channel
120 | def spatial_softmax(self, heatmap):
121 | b,c,h,w = heatmap.size()
122 | x = heatmap.view(b, c, -1).transpose(2, 1)
123 | x = F.softmax(x, dim=1)
124 | return x.transpose(2, 1).view(b, c, h, w)
125 | 
126 | # Similar to Neighbourhood Consensus Network Loss
127 | # For the distribution at each layer to be Kronecker Delta
128 | def softmax_localization_loss(self, heatmap, alpha=0):
129 | heatmap = self.spatial_softmax(heatmap)
130 | b,c,h,w = heatmap.size()
131 | 
132 | # Flatten the tensor but keep the channels for adaptions to multiscale later
133 | scores, _ = torch.max(heatmap.view(b, c, -1), dim=2)
134 | return torch.mean(scores) + alpha*heatmap.sum()
135 | 
136 | def weighted_binary_cross_entropy(self, heatmap, labels, thresh=0.8, device="cuda", reduction="mean"):
137 | b, c, h, w = heatmap.shape
138 | weight = torch.sum(labels < thresh).float()/torch.sum(labels >= thresh).float()
139 | 
140 | return self.Lmatch(heatmap, labels, pos_weight=weight.to(device), reduction=reduction)
141 | 
142 | def train(self, engine, batch):
143 | self.FtsA.train()
144 | self.FtsB.train()
145 | 
146 | log_str = ""
147 | curr_step = self.logger.counters["train"]
148 | 
149 | # Extract the a and b image pairs from the batch and whether the pairs match
150 | (search_img, template_img, template_hard), y = self._prepare_batch(batch)
151 | 
152 | self._zero_grad()
153 | 
154 | y_a = self.FtsA(search_img)
155 | y_bhn = self.FtsB(template_hard)
156 | 
157 | heatmap_hneg_raw = self.FtsA.correlation_map(y_a, y_bhn, self.hm_act)
158 | 
159 | if self.loss_weights["spatial_softmax"] > 0:
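# spatial_softmax (defined above) flattens each BxCxHxW map to BxCx(H*W), applies a
# softmax over the spatial positions and reshapes back, so every channel of the
# correlation map becomes a probability distribution over locations.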
160 | heatmap_hneg = self.spatial_softmax(heatmap_hneg_raw)
161 | else:
162 | heatmap_hneg = heatmap_hneg_raw
163 | 
164 | if engine.state.iteration % 1000 == 0:
165 | self.logger.log_image_grid("Search", search_img, "train")
166 | self.logger.log_image_grid("Template_match", template_hard, "train")
167 | self.logger.log_image_grid("Heatmap", heatmap_hneg, "train")
168 | self.logger.log_image_grid("Heatmap_raw", heatmap_hneg_raw, "train")
169 | self.logger.log_image_grid("Ground_Truth", y[0], "train")
170 | 
171 | match_loss = self.weighted_binary_cross_entropy(heatmap_hneg, y[0], device=self.device)
172 | 
173 | self.logger.add_scalars('train/match_loss', {'Lmatch': match_loss.item()}, curr_step)
174 | 
175 | heatmap_l1_loss = self.loss_weights["heatmap_l1"]*heatmap_hneg_raw.norm(p=1)
176 | 
177 | self.logger.add_scalars('train/regularize', {'L1': heatmap_l1_loss}, curr_step)
178 | 
179 | loss = self.loss_weights["match"]*match_loss + heatmap_l1_loss
180 | 
181 | self.logger.add_scalar('train/loss', loss.item(), curr_step)
182 | 
183 | log_str += "Ltotal: {:.5f} \t".format(loss.item())
184 | log_str += "Lmatch: {:.5f} \t".format(match_loss.item())
185 | 
186 | loss.backward()
187 | self.optim_fts.step()
188 | 
189 | self.log_str = log_str
190 | 
191 | return None, None
192 | 
193 | def on_epoch_start(self, engine, phase=None):
194 | if phase == "train":
195 | self.curr_epoch = engine.state.epoch
196 | 
197 | if phase == "evaluate":
198 | self.log_eval_batch = True
199 | 
200 | def on_epoch_end(self, engine, phase=None):
201 | if phase in ["evaluate", "test"]:
202 | metrics = engine.state.metrics
203 | log = ""
204 | for k, v in metrics.items():
205 | log += "{}: {:.5f} ".format(k, v)
206 | 
207 | print("{} Results - Epoch: {} {}".format(phase.capitalize(), self.curr_epoch, log))
208 | 
209 | def on_iteration_start(self, engine, phase=None):
210 | if phase == "train":
211 | curr_iter = (engine.state.iteration - 1) % len(self.attached["train_loader"]) + 1
212 | 
213 | if curr_iter % self.log_freq == 0:
214 | print("Epoch[{}] Iteration[{}/{}] {}".format(engine.state.epoch, curr_iter, len(self.attached["train_loader"]), self.log_str))
215 | 
216 | elif phase == "test":
217 | curr_iter = (engine.state.iteration - 1) % len(self.attached["test_loader"]) + 1
218 | if curr_iter % self.log_freq == 0:
219 | print("Iteration[{}/{}]".format(curr_iter, len(self.attached["test_loader"])))
220 | 
221 | def on_iteration_end(self, engine, phase=None):
222 | pass
223 | 
224 | def infer_batch(self, batch):
225 | self.FtsA.eval()
226 | self.FtsB.eval()
227 | 
228 | with torch.no_grad():
229 | imgs, y = self._prepare_batch(batch)
230 | 
231 | try:
232 | (search_img, template_img, template_hard) = imgs
233 | except ValueError:
234 | (search_img, template_img, template_hard, _) = imgs
235 | 
236 | y_a = self.FtsA(search_img)
237 | y_b = self.FtsB(template_img)
238 | y_bhn = self.FtsB(template_hard)
239 | y_bn = torch.roll(y_bhn, -1, 0)
240 | 
241 | # Create negative examples for matching and easy negative heatmaps
242 | heatmap_neg_raw = self.FtsA.correlation_map(y_a, y_bn, self.hm_act)
243 | heatmap_hneg_raw = self.FtsA.correlation_map(y_a, y_bhn, self.hm_act)
244 | 
245 | if self.loss_weights["spatial_softmax"] > 0:
246 | heatmap_hneg = self.spatial_softmax(heatmap_hneg_raw)
247 | heatmap_neg = self.spatial_softmax(heatmap_neg_raw)
248 | else:
249 | heatmap_hneg = heatmap_hneg_raw
250 | heatmap_neg = heatmap_neg_raw
251 | 
252 | d_a, attn_a = None, None
253 | d_b, attn_b = None, None
254 | d_bhn, attn_bhn = None, None
255 | 
256 | hms = (heatmap_neg,
heatmap_neg_raw, heatmap_hneg, heatmap_hneg_raw) 257 | fts = (y_a, y_b, y_bhn) 258 | dets = (d_a, attn_a, d_b, attn_b, d_bhn, attn_bhn) 259 | 260 | return imgs, hms, y, fts, dets 261 | 262 | def evaluate(self, engine, batch): 263 | curr_step = self.logger.counters["evaluate"] 264 | 265 | imgs, hms, y, fts, dets = self.infer_batch(batch) 266 | (search_img, template_img, template_hard) = imgs 267 | (heatmap_neg, heatmap_neg_raw, heatmap_hneg, heatmap_hneg_raw) = hms 268 | (y_a, y_b, y_bhn) = fts 269 | (d_a, attn_a, d_b, attn_b, d_bhn, attn_bhn) = dets 270 | 271 | match_loss = self.weighted_binary_cross_entropy(heatmap_hneg, y[0], device=self.device) 272 | 273 | self.logger.add_scalars('evaluate/match_loss', {'Lmatch': match_loss.item()}, curr_step) 274 | 275 | heatmap_l1_loss = self.loss_weights["heatmap_l1"]*heatmap_hneg_raw.norm(p=1) 276 | 277 | self.logger.add_scalars('evaluate/regularize', {'L1': heatmap_l1_loss}, curr_step) 278 | 279 | loss = self.loss_weights["match"]*match_loss + heatmap_l1_loss 280 | self.logger.add_scalar('evaluate/loss', loss.item(), curr_step) 281 | 282 | if self.log_eval_batch: 283 | self.logger.log_image_grid("ev_Search", search_img, "evaluate") 284 | self.logger.log_image_grid("ev_Template", template_img, "evaluate") 285 | self.logger.log_image_grid("ev_Template_s", template_hard, "evaluate") 286 | self.logger.log_image_grid("ev_Heatmap", heatmap_hneg, "evaluate") 287 | self.logger.log_image_grid("ev_Ground_Truth", y[0], "evaluate") 288 | self.log_eval_batch = False 289 | 290 | return None, None #(search_img, template_img), (heatmap_pos, heatmap_neg) 291 | 292 | def attach(self, name, obj): 293 | self.attached[name] = obj 294 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .experiment import * 2 | from .modules import * 3 | from .augmentation import * 4 | from .helpers import * 5 | from .basic_cache import * 6 | -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | from imgaug import augmenters as iaa 2 | from skimage.color import rgb2gray 3 | import operator 4 | import numpy as np 5 | 6 | DEFAULT_PROBS = { 7 | "fliplr": 0.5, 8 | "flipud": 0.3, 9 | "scale": 0.1, 10 | "scale_px": (0.98, 1.02), 11 | "translate": 0.15, 12 | "translate_perc": (-0.05, 0.05), 13 | "rotate": 0.2, 14 | "rotate_angle": (-5, 5), 15 | "contrast": 0, 16 | "dropout": 0 17 | } 18 | 19 | import torch 20 | import numpy as np 21 | 22 | def cutout(img, n_holes=1, length=8): 23 | """ 24 | Args: 25 | img (Tensor): Tensor image of size (C, H, W). 26 | Returns: 27 | Tensor: Image with n_holes of dimension length x length cut out of it. 28 | """ 29 | h = img.size(1) 30 | w = img.size(2) 31 | 32 | mask = np.ones((h, w), np.float32) 33 | 34 | for n in range(n_holes): 35 | y = np.random.randint(h) 36 | x = np.random.randint(w) 37 | 38 | y1 = np.clip(y - length // 2, 0, h) 39 | y2 = np.clip(y + length // 2, 0, h) 40 | x1 = np.clip(x - length // 2, 0, w) 41 | x2 = np.clip(x + length // 2, 0, w) 42 | 43 | mask[y1: y2, x1: x2] = 0. 
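# np.clip above keeps each hole inside the image: holes whose sampled centre lies near
# a border are simply truncated rather than wrapped around or resampled.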
44 | 
45 | mask = torch.from_numpy(mask)
46 | mask = mask.expand_as(img)
47 | img = img * mask
48 | 
49 | return img
50 | 
51 | def cropCenter(img, bounding, shift=(0,0)):
52 | imshape = [x+y*2 for x,y in zip(img.shape, shift)]
53 | bounding = list(bounding)
54 | imshape.reverse()
55 | bounding.reverse()
56 | start = tuple(map(lambda a, da: a//2-da//2, imshape, bounding))
57 | end = tuple(map(operator.add, start, bounding))
58 | slices = tuple(map(slice, start, end))
59 | return img[slices]
60 | 
61 | def cropCorner(img, bounding, corner="tl"):
62 | # Corner in (y, x) coordinates
63 | if corner == "tl":
64 | start = (0, 0)
65 | elif corner == "bl":
66 | start = (img.shape[1]-bounding[1], 0)
67 | elif corner == "br":
68 | start = (img.shape[1]-bounding[1], img.shape[0]-bounding[0])
69 | else:
70 | start = (0, img.shape[0]-bounding[0])
71 | end = tuple(map(operator.add, start, bounding))
72 | slices = tuple(map(slice, start, end))
73 | return img[slices]
74 | 
75 | def toGrayscale(img):
76 | if len(img.shape) >= 3 and img.shape[-1] == 3:
77 | img = rgb2gray(img)
78 | 
79 | if len(img.shape) < 3:
80 | img = np.expand_dims(img, axis=2)
81 | 
82 | return img
83 | 
84 | class Augmentation:
85 | 
86 | def __init__(self, probs=DEFAULT_PROBS):
87 | trans = []
88 | if "fliplr" in probs:
89 | trans.append(iaa.Fliplr(probs['fliplr']))
90 | 
91 | if "flipud" in probs:
92 | trans.append(iaa.Flipud(probs['flipud']))
93 | 
94 | if "scale" in probs:
95 | trans.append(iaa.Sometimes(probs["scale"], iaa.Affine(scale={"x": probs['scale_px'], "y": probs['scale_px']})))
96 | 
97 | if "translate" in probs:
98 | trans.append(iaa.Sometimes(probs["translate"], iaa.Affine(translate_percent={"x": probs['translate_perc'], "y": probs['translate_perc']})))
99 | 
100 | if "rotate" in probs:
101 | trans.append(iaa.Sometimes(probs["rotate"], iaa.Affine(rotate=probs["rotate_angle"])))
102 | 
103 | if "contrast" in probs:
104 | # trans.append(iaa.ContrastNormalization((0.9, 1.5), per_channel=probs["contrast"]))
105 | trans.append(iaa.Multiply((0.7, 1.3), per_channel=probs["contrast"]))
106 | 
107 | if "dropout" in probs:
108 | trans.append(iaa.CoarseDropout((0.0, 0.05), size_percent=(0.02, 0.2), per_channel=probs["dropout"]))
109 | 
110 | self.seq = iaa.Sequential(trans)
111 | self.transformer = self.seq
112 | 
113 | def __call__(self, imgs):
114 | if isinstance(imgs, list):
115 | imgs = [self.transformer.augment_image(img) for img in imgs]
116 | else:
117 | imgs = self.transformer.augment_image(imgs)
118 | 
119 | return imgs
120 | 
121 | def refresh_random_state(self):
122 | self.transformer = self.seq.to_deterministic()
123 | return self.transformer
124 | 
--------------------------------------------------------------------------------
/utils/basic_cache.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import shutil
4 | import hashlib
5 | 
6 | from glob import glob
7 | 
8 | # A basic folder cache for Numpy objects
9 | class BasicCache:
10 | # scheme = "lifo", "fifo", "fill"
11 | def __init__(self, cache_dir, size=10000, scheme="fill", clear=True, overwrite=False):
12 | self.basedir = os.path.abspath(cache_dir)
13 | self.size = size
14 | self.scheme = scheme
15 | self.overwrite = overwrite
16 | self.cache = {}
17 | self.times = []
18 | 
19 | if clear:
20 | self.clear()
21 | 
22 | os.makedirs(self.basedir, exist_ok=True)
23 | self._load_existing_cache()
24 | 
25 | def key_hash(self, key):
26 | return hashlib.sha224(str(key).encode()).hexdigest()
27 | 
28 | def
_load_existing_cache(self):
29 | # We only cache numpy objects
30 | flist = glob(os.path.join(self.basedir, "*.npz"))
31 | 
32 | for f in flist:
33 | h = os.path.splitext(os.path.basename(f))[0]
34 | 
35 | if len(self.cache) < self.size:
36 | self.cache[h] = f
37 | self.times.append(h)
38 | 
39 | def _get_filename(self, key):
40 | h = self.key_hash(key)
41 | return os.path.join(self.basedir, "{}.npz".format(h)), h
42 | 
43 | def _prune_cache(self):
44 | # Remove an item from the cache according to scheme
45 | if len(self.cache) >= self.size:
46 | if self.scheme == "lifo":
47 | rm_idx = self.times.pop()
48 | elif self.scheme == "fifo":
49 | rm_idx = self.times.pop(0)
50 | else:
51 | return False
52 | 
53 | rm_file = self.cache[rm_idx]
54 | # Remove the cached file
55 | if os.path.exists(rm_file):
56 | os.remove(rm_file)
57 | 
58 | del self.cache[rm_idx]
59 | return True
60 | 
61 | def clear(self):
62 | if os.path.exists(self.basedir):
63 | shutil.rmtree(self.basedir)
64 | 
65 | def isin(self, key):
66 | _, h = self._get_filename(key)
67 | return h in self.times
68 | 
69 | def __len__(self):
70 | return len(self.cache)
71 | 
72 | def __getitem__(self, key):
73 | _, h = self._get_filename(key)
74 | 
75 | if self.isin(key):
76 | fname = self.cache[h]
77 | data = np.load(fname, allow_pickle=True)
78 | 
79 | if "arr_0" in data.files:
80 | return data["arr_0"]
81 | else:
82 | return {f: data[f] for f in data.files}
83 | 
84 | def __setitem__(self, key, value):
85 | fname, h = self._get_filename(key)
86 | 
87 | # Only add the item if it isn't already in the cache
88 | if (self.overwrite and h in self.times) or self._prune_cache():
89 | if isinstance(value, dict):
90 | np.savez(fname, **value)
91 | else:
92 | np.savez(fname, value)
93 | 
94 | self.cache[h] = fname
95 | 
96 | if h not in self.times:
97 | self.times.append(h)
98 | 
99 | if __name__=="__main__":
100 | cache = BasicCache("tmp_cache", size=100, scheme="fill", clear=True)
101 | 
102 | # Create 120 random objects and cache them (only 100 should cache)
103 | for i in range(0, 120):
104 | cache[i] = np.random.rand(4, 4)
105 | 
106 | for i in range(0, 120):
107 | data = cache[i]
108 | 
109 | if data is not None:
110 | print(f"Retrieved {i} = {data.shape} from cache")
111 | else:
112 | print(f"{i} not in cache")
113 | 
114 | cache2 = BasicCache("tmp_cache", size=100, scheme="fill", clear=False, overwrite=True)
115 | cache[45] = np.random.rand(10, 10)
116 | print(f"Retrieved without overwrite {cache[45].shape}")
117 | cache2[45] = np.random.rand(10, 10)
118 | print(f"Retrieved with overwrite {cache2[45].shape}")
--------------------------------------------------------------------------------
/utils/experiment.py:
--------------------------------------------------------------------------------
1 | from dotmap import DotMap
2 | from glob import glob
3 | from numbers import Number
4 | import os
5 | import json
6 | 
7 | from utils.helpers import extract_numbers
8 | 
9 | def number_ordering(x):
10 | n = extract_numbers(x)
11 | return n[-1] if len(n) > 0 else 0
12 | 
13 | class Experiment(DotMap):
14 | def __init__(self, name, desc="", result_dir="./results", data={}):
15 | super(Experiment, self).__init__(data)
16 | 
17 | self.name = name
18 | self.desc = desc
19 | self.result_dir = result_dir
20 | self.result_path = os.path.join(self.result_dir, self.name)
21 | 
22 | def save(self):
23 | os.makedirs(self.result_path, exist_ok=True)
24 | with open(os.path.join(self.result_path, "config.json"), "w") as f:
25 | f.write(json.dumps(self.toDict(), indent=4,
sort_keys=False)) 26 | 27 | def load(self): 28 | try: 29 | with open(os.path.join(self.result_path, "config.json")) as f: 30 | data = json.load(f) 31 | super(Experiment, self).__init__(data) 32 | return self 33 | except: 34 | return None 35 | 36 | def exists(self): 37 | return os.path.exists(self.result_path) 38 | 39 | # def get_checkpoint_path(self): 40 | # model_path = None 41 | 42 | # if 'checkpoint' in self and len(self.checkpoint) > 0: 43 | # _, ext = os.path.splitext(self.checkpoint) 44 | 45 | # if ext == ".pth": 46 | # model_path = self.checkpoint if os.path.exists(self.checkpoint) else None 47 | # path = os.path.join(self.result_path, self.checkpoint) if model_path is None else model_path 48 | # model_path = path if os.path.exists(path) else None 49 | # elif self.checkpoint == "best": 50 | # cpts = glob(os.path.join(self.result_path, 'best_checkpoint_model*.pth')) 51 | # cpts = sorted(cpts, key=number_ordering) 52 | # if len(cpts) > 0: 53 | # model_path = cpts[-1] 54 | # elif self.checkpoint.isdigit(): 55 | # path = os.path.join(self.result_path, 'checkpoint_model_{}.pth'.format(self.resume_from)) 56 | # model_path = path if os.path.exists(path) else None 57 | 58 | # optim_path = model_path.replace('model','optim') if model_path else None 59 | # return (model_path, optim_path) 60 | 61 | def get_checkpoints(self, path=None, tag="best"): 62 | checkpoints = {} 63 | last_epoch = 0 64 | 65 | path = self.result_path if path is None else path 66 | filelist = glob(os.path.join(path, "*{}*.pth".format(tag))) 67 | 68 | # Clean the filenames so we can use split() to extract parts of them 69 | filelist = [os.path.splitext(fname)[0] for fname in filelist] 70 | epochs = set([p for fname in filelist for p in os.path.basename(fname).split('_') if p.isdigit()]) 71 | 72 | # Find the largest epoch 73 | for e in epochs: 74 | if e.isdigit() and int(e) > last_epoch: 75 | last_epoch = int(e) 76 | 77 | # Ensure we only select the chosen epoch 78 | filelist = [fname for fname in filelist if str(last_epoch) in fname] 79 | 80 | for fname in filelist: 81 | parts = os.path.basename(fname).split('_') 82 | name = parts[len(parts)//2] # Middle element is always the name 83 | checkpoints[name] = fname + ".pth" # Add back the file extension 84 | 85 | return checkpoints, last_epoch 86 | 87 | @staticmethod 88 | def load_from_path(path, overloads=None, suffix=None): 89 | with open(path) as f: 90 | data = json.load(f) 91 | 92 | if overloads: 93 | data.update(overloads) 94 | 95 | data['name'] = data["name"] if suffix is None else "{}_{}".format(data['name'], suffix) 96 | 97 | return Experiment(data['name'], data['desc'], data['result_dir'], data) 98 | 99 | 100 | @staticmethod 101 | def load_by_name(name, conf_dir="./config"): 102 | exp = Experiment(name, result_dir=conf_dir).load() 103 | return(exp) 104 | -------------------------------------------------------------------------------- /utils/factory.py: -------------------------------------------------------------------------------- 1 | from .helpers import * 2 | from torch.utils.data import DataLoader, SubsetRandomSampler 3 | import ignite 4 | import torch.nn.functional as F 5 | import torch.optim.lr_scheduler 6 | 7 | def create_dataset(config): 8 | dataset = get_module('./datasets', config.type) 9 | return dataset(config) or None 10 | 11 | def build_model(config): 12 | ident = list(config.keys())[0] 13 | model = get_module('./models', ident) 14 | return model(**config[ident].toDict()) or None 15 | 16 | def get_optimizer(params, config): 17 | ident = 
list(config.keys())[0]
18 | 
19 | optim = None
20 | try:
21 | optim = get_module('./optimizers', ident)
22 | except:
23 | pass
24 | 
25 | if optim is None:
26 | optim = str_to_class('torch.optim', ident)
27 | 
28 | return optim(params, **config[ident].toDict()) or None
29 | 
30 | def get_data_loader(dset, config, indices=None):
31 | sampler = None
32 | 
33 | if indices is not None:
34 | sampler = SubsetRandomSampler(indices)
35 | config.shuffle = False
36 | 
37 | if getattr(dset, "get_batch_sampler", None):
38 | batch_sampler = dset.get_batch_sampler(config.batch_size)
39 | return DataLoader(dset, num_workers=config.workers, batch_sampler=batch_sampler)
40 | 
41 | return DataLoader(dset, batch_size=config.batch_size, shuffle=config.shuffle, num_workers=config.workers, sampler=sampler)
42 | 
43 | def get_trainer(model, optimizer, loss_fn, exp_logger, config):
44 | trainer = get_module('./trainer', config.trainer)
45 | return trainer(model, optimizer, loss_fn, exp_logger, config) or None
46 | 
47 | def get_experiment_logger(log_dir, config):
48 | ident = list(config.keys())[0]
49 | logger = get_module('./logger', ident)
50 | config = config[ident].toDict()
51 | config["log_dir"] = log_dir
52 | return logger(**config) or None
53 | 
54 | def get_metric(name):
55 | metric = get_if_implemented(ignite.metrics, name)
56 | 
57 | if metric is None:
58 | try:
59 | metric = get_module('./metrics', name)
60 | except:
61 | pass
62 | 
63 | if metric is None:
64 | loss_fcn = get_loss(name)
65 | assert loss_fcn, "No loss function {} was found for use as a metric".format(name)
66 | metric = ignite.metrics.Loss(loss_fcn)
67 | else:
68 | metric = metric()
69 | 
70 | return metric or None
71 | 
72 | def get_loss(loss_fn):
73 | loss = get_if_implemented(F, loss_fn)
74 | if loss is None:
75 | loss = get_function('losses.functional', loss_fn)
76 | return loss
77 | 
78 | def get_lr_scheduler(optimizer, config):
79 | name = list(config.keys())[0]
80 | args = config[name].toDict().copy()
81 | scheme = args["scheme"]
82 | args = copy_and_delete(args, 'scheme')
83 | 
84 | lr_scheduler = get_if_implemented(torch.optim.lr_scheduler, name)
85 | 
86 | if lr_scheduler is None:
87 | try:
88 | lr_scheduler = get_module('./schedulers', name)
89 | except:
90 | pass
91 | 
92 | if lr_scheduler is None:
93 | fcn = get_function('schedulers.functional', name)
94 | assert fcn, "No functional implementation of {} was found".format(name)
95 | fcn_wrapper = fcn(**args)
96 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, [fcn_wrapper])
97 | else:
98 | lr_scheduler = lr_scheduler(optimizer, **args)
99 | 
100 | return lr_scheduler, scheme
101 | 
--------------------------------------------------------------------------------
/utils/geo_tools.py:
--------------------------------------------------------------------------------
1 | import rasterio
2 | import numpy as np
3 | 
4 | from itertools import product
5 | from scipy.spatial.distance import cdist, pdist
6 | from scipy.sparse.csgraph import csgraph_from_dense, depth_first_tree
7 | 
8 | from rasterio.coords import BoundingBox, disjoint_bounds
9 | from rasterio.warp import transform_bounds
10 | from rasterio.windows import Window
11 | from rasterio.plot import show, show_hist
12 | 
13 | def bbox_intersection(bounds1, bounds2):
14 | if disjoint_bounds(bounds1, bounds2):
15 | raise Exception("Bounds are disjoint, no intersection exists")
16 | 
17 | bbox = BoundingBox(
18 | left=max(bounds1.left, bounds2.left),
19 | right=min(bounds1.right, bounds2.right),
20 | top=min(bounds1.top, bounds2.top),
21 | 
bottom=max(bounds1.bottom, bounds2.bottom)
22 | )
23 | 
24 | return bbox
25 | 
26 | def relative_window(base_window, abs_window, strict=True):
27 | window = Window(
28 | col_off=base_window.col_off + abs_window.col_off,
29 | row_off=base_window.row_off + abs_window.row_off,
30 | width=abs_window.width,
31 | height=abs_window.height,
32 | )
33 | 
34 | return window.intersection(base_window) if strict else window
35 | 
36 | def absolute_window(base_window, rel_window, strict=True):
37 | window = Window(
38 | col_off=rel_window.col_off - base_window.col_off,
39 | row_off=rel_window.row_off - base_window.row_off,
40 | width=rel_window.width,
41 | height=rel_window.height,
42 | )
43 | 
44 | return window.intersection(base_window) if strict else window
45 | 
46 | class Raster:
47 | def __init__(self, src_path, bands=None):
48 | self.path = src_path
49 | self.raster = rasterio.open(self.path)
50 | self.bands = bands
51 | 
52 | # These will change when doing set operations
53 | self.width = self.raster.width
54 | self.height = self.raster.height
55 | self.window = Window(0, 0, self.width, self.height)
56 | self.transform = self.raster.transform
57 | self.profile = self.raster.profile.copy()
58 | 
59 | def _update_profile(self):
60 | self.profile.update({
61 | 'height': self.height,
62 | 'width': self.width,
63 | 'transform': self.transform
64 | })
65 | 
66 | # Assumes that bounds are in the same CRS
67 | def _clip_bounds(self, bounds):
68 | if disjoint_bounds(bounds, self.raster.bounds):
69 | raise Exception("Bounds are disjoint, no intersection exists")
70 | 
71 | # Get the new bounds as a window in the original raster
72 | bounds_window = rasterio.windows.from_bounds(*bounds, transform=self.raster.transform)
73 | bounds_window = bounds_window.intersection(self.window)
74 | 
75 | self.window = bounds_window.round_lengths(op='ceil')
76 | self.height = int(self.window.height)
77 | self.width = int(self.window.width)
78 | self.transform = rasterio.windows.transform(self.window, self.transform)
79 | 
80 | self._update_profile()
81 | 
82 | def clip_bounds_by_raster(self, template, intersection=False):
83 | bounds = template.raster.bounds
84 | 
85 | if template.raster.crs != self.raster.crs:
86 | # Make sure bounds are in same coordinate system
87 | bounds = transform_bounds(template.raster.crs, self.raster.crs, *bounds)
88 | 
89 | if intersection:
90 | bounds = bbox_intersection(self.raster.bounds, bounds)
91 | 
92 | self._clip_bounds(bounds)
93 | 
94 | # Same as clip by raster except only valid data regions in both images are kept
95 | def crop_by_raster(self, template):
96 | pass
97 | 
98 | # Takes in a set of patch offsets (relative) and outputs windows for each which generates a full sized patch (if strict)
99 | def _get_patch_windows_from_offsets(self, offsets, size):
100 | abs_window = Window(0, 0, self.width, self.height)
101 | 
102 | for col_off, row_off in offsets:
103 | window = Window(col_off=col_off, row_off=row_off, width=size, height=size).intersection(abs_window)
104 | 
105 | if (window.width != size or window.height != size):
106 | continue
107 | 
108 | transform = rasterio.windows.transform(relative_window(self.window, window, strict=True), self.transform)
109 | yield window, transform
110 | 
111 | # If strict is false then all points will be returned, if stride is specified then it will be enforced at best effort (guarantee no larger overlap, but don't guarantee all points)
112 | # Points must be a numpy array of Mx2 (cols, rows) or (x, y), it is assumed points represent the center pixels of the patch
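# Usage sketch (hypothetical names: "r" is a Raster, "centres" an Nx2 int array of
# (x, y) pixel coordinates):
#
#   for window, transform in r.get_patch_windows_from_imgXY(centres, size=128, stride=64):
#       patch = r.read(1, window=window)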
    # If strict is False then all points will be returned; if stride is
    # specified it is enforced on a best-effort basis (no two selected points
    # are closer than the stride, but not all points are guaranteed to be kept).
    # Points must be a numpy array of Mx2 (cols, rows) or (x, y); points are
    # assumed to be the center pixels of the patches.
    def get_patch_windows_from_imgXY(self, pts, size, stride=None):
        # Convert to top left corner first
        pts = pts - (size//2)

        # Keep only points whose top-left corner lies inside the raster
        idxs = np.logical_and(np.logical_and(np.all(pts >= 0, axis=1), pts[:, 0] <= self.width),
                              pts[:, 1] <= self.height)
        pts = pts[idxs, :]

        if stride:
            selected_pts = np.empty(shape=(0, 2))

            # Sort by x - we'll iterate that way first
            sort = np.argsort(pts[:, 0])
            pts = pts[sort, :]

            # Greedily keep a point only if it is at least `stride` away from
            # every previously selected point
            for pt in pts:
                pt = pt[np.newaxis, :]
                if len(selected_pts) > 0:
                    valid = np.all(cdist(pt, selected_pts) >= stride)
                else:
                    valid = True

                if valid:
                    selected_pts = np.concatenate([selected_pts, pt], axis=0)

            pts = selected_pts

        return self._get_patch_windows_from_offsets(pts, size)

    # Same as above, but using geocoordinates rather than raster image coords
    def get_patch_windows_from_worldXY(self, pts, size, stride=None):
        r, c = rasterio.transform.rowcol(self.transform, xs=pts[:, 0], ys=pts[:, 1])
        ptsXY = np.stack([c, r], axis=1)

        # Clean up invalid points - negative, or outside of the main data window.
        # This isn't strictly needed, but it is more efficient than processing
        # points which we know won't be selected.
        idxs = np.logical_and(np.logical_and(np.all(ptsXY >= 0, axis=1), ptsXY[:, 0] <= self.width),
                              ptsXY[:, 1] <= self.height)
        ptsXY = ptsXY[idxs, :]

        return self.get_patch_windows_from_imgXY(ptsXY, size, stride=stride)

    # Get patch windows with the given size and stride; only full windows are returned
    def get_patch_windows(self, size, stride):
        offsets = product(range(0, self.width, stride), range(0, self.height, stride))
        return self._get_patch_windows_from_offsets(offsets, size)

    def _get_patches(self, bands, windows, strict=True):
        for window, transform in windows:
            data = self.read(bands, window=window, masked=True)

            # If the patch contains nodata values then don't generate it
            if strict and np.ma.is_masked(data):
                continue

            yield data, transform

    # Yields patches from the raster with the given size and stride; strict
    # means no nodata values will exist in the returned patches
    def get_patches(self, bands=1, size=128, stride=64, strict=True):
        return self._get_patches(bands, self.get_patch_windows(size, stride), strict=strict)

    # Not yet implemented
    def get_patches_from_imgXY(self, bands):
        pass

    def read(self, bands=None, window=None, masked=False):
        if window is None:
            window = self.window
        else:
            window = relative_window(self.window, window)

        return self.raster.read(bands, window=window, masked=masked)

    def save(self, dest_path):
        with rasterio.open(dest_path, "w", **self.profile) as sink:
            sink.write(self.read())

    def show(self, band=1, cmap="terrain", window=None, **kwargs):
        show(self.read(band, window=window), cmap=cmap, transform=self.transform, **kwargs)

    def show_hist(self, bands=1, bins=50, **kwargs):
        show_hist(self.read(bands), bins=bins, **kwargs)
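# Usage sketch for world-coordinate patches (hypothetical coords.npy holding
# Nx2 (x, y) points in the raster's CRS); points closer than the stride are
# greedily dropped:
#
#   pts = np.load("coords.npy")
#   for window, transform in raster.get_patch_windows_from_worldXY(pts, size=128, stride=64):
#       data = raster.read(1, window=window, masked=True)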
np.load("/media/Zambezi/SAR_OPT_Data/Munich_Center_17112017/coords.npy") 202 | a.clip_bounds_by_raster(b) 203 | # a.show(window=Window(0,0,128,128), cmap="gray") 204 | # b.show_hist(bins=10, alpha=0.3, lw=0.0, histtype='stepfilled', title='Histogram', masked=True) 205 | print(a.window) 206 | print(a.transform) 207 | # for w,t in b.get_patch_windows(128, 128, strict=True): 208 | # print(f"{w}") 209 | 210 | # for i, (patch, t) in enumerate(a.get_patches(1, size=256, stride=128, strict=True)): 211 | # show(patch, transform=t, cmap="gray", title=i) 212 | import code 213 | code.interact(local=locals()) -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | import inspect 4 | import os 5 | import shutil 6 | import re 7 | 8 | def validate_config(config): 9 | # assert config.device in ["cpu", "cuda"], "Invalid compute device was specified. Only 'cpu' and 'cuda' are supported." 10 | return True 11 | 12 | def count_parameters(model): 13 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 14 | 15 | def freeze_net(model, filter=None, freeze=True): 16 | for p in model.parameters(): 17 | p.requires_grad = not freeze 18 | 19 | def load_file_list(base_dir, path): 20 | img_list = [] 21 | for line in open(path, 'r'): 22 | if base_dir: 23 | img_list.append(os.path.join(base_dir, line.strip())) 24 | else: 25 | img_list.append(line.strip()) 26 | return(img_list) 27 | 28 | def empty_folder(path): 29 | if os.path.exists(path): 30 | for the_file in os.listdir(path): 31 | file_path = os.path.join(path, the_file) 32 | try: 33 | if os.path.isfile(file_path): 34 | os.unlink(file_path) 35 | elif os.path.isdir(file_path): 36 | shutil.rmtree(file_path) 37 | except Exception as e: 38 | pass 39 | 40 | def static_vars(**kwargs): 41 | def decorate(func): 42 | for k in kwargs: 43 | setattr(func, k, kwargs[k]) 44 | return func 45 | return decorate 46 | 47 | def is_float(string): 48 | try: 49 | float(string) 50 | return True 51 | except ValueError: 52 | return False 53 | 54 | def extract_numbers(x): 55 | r = re.compile('(\d+(?:\.\d+)?)') 56 | l = r.split(x) 57 | return [float(y) for y in l if is_float(y)] 58 | 59 | def copy_and_delete(d, key): 60 | copy = d.copy() 61 | del copy[key] 62 | return(copy) 63 | 64 | def get_modules(path): 65 | modules = {} 66 | 67 | for loader, name, is_pkg in pkgutil.walk_packages(path): 68 | module = loader.find_module(name).load_module(name) 69 | for name, value in inspect.getmembers(module): 70 | # Only import classes we defined 71 | if inspect.isclass(value) is False or value.__module__ is not module.__name__: 72 | continue 73 | 74 | modules[name] = value 75 | 76 | return modules 77 | 78 | def get_learning_rate(optimizer): 79 | lr = [] 80 | for param_group in optimizer.param_groups: 81 | lr += [ param_group['lr'] ] 82 | return lr 83 | 84 | def get_function(module, fcn): 85 | try: 86 | fn = str_to_class(module, fcn) 87 | except: 88 | fn = None 89 | return fn 90 | 91 | def get_module(path, name): 92 | modules = get_modules([path]) 93 | assert name in modules.keys(), "Could not find module {}".format(name) 94 | return modules[name] 95 | 96 | def __classname_to_modulename(name): 97 | s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) 98 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() 99 | 100 | def str_to_class(module_name, class_name): 101 | try: 102 | module_ = importlib.import_module(module_name) 103 | 
def str_to_class(module_name, class_name):
    class_ = None
    try:
        module_ = importlib.import_module(module_name)
        class_ = getattr(module_, class_name)
    except AttributeError:
        logging.error('Class does not exist')
    except ImportError:
        logging.error('Module does not exist')
    return class_

def check_if_implemented(obj, fcn):
    op = getattr(obj, fcn, None)
    return callable(op)

def get_if_implemented(obj, fcn):
    op = getattr(obj, fcn, None)
    if not callable(op):
        op = None
    return op

def accuracy(output, target):
    """Computes the accuracy for multiple binary predictions"""
    pred = output >= 0.5
    truth = target >= 0.5
    acc = pred.eq(truth).sum().float() / target.numel()
    return acc


class BinaryClassificationMeter(object):
    """Accumulates binary confusion counts and derives accuracy, precision, recall and F1"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.tp = 0
        self.tn = 0
        self.fp = 0
        self.fn = 0
        self.acc = 0
        self.pre = 0
        self.rec = 0
        self.f1 = 0

    def update(self, output, target):
        pred = (output >= 0.5).float()
        truth = (target >= 0.5).float()
        self.tp += pred.mul(truth).sum(0)
        self.tn += (1 - pred).mul(1 - truth).sum(0)
        self.fp += pred.mul(1 - truth).sum(0)
        self.fn += (1 - pred).mul(truth).sum(0)
        self.acc = (self.tp + self.tn).sum() / (self.tp + self.tn + self.fp + self.fn).sum()
        self.pre = self.tp / (self.tp + self.fp)
        self.rec = self.tp / (self.tp + self.fn)
        self.f1 = (2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn)
        # self.avg_pre = torch.nanmean(self.pre)
        # self.avg_rec = nanmean(self.rec)
        # self.avg_f1 = nanmean(self.f1)
--------------------------------------------------------------------------------
/utils/modules.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.parameter import Parameter
# from torch._jit_internal import weak_module, weak_script_method

class GELU(nn.Module):
    """
    Tanh approximation of the GELU activation; note that BERT uses GELU rather
    than ReLU (see the GELU paper, Section 3.4, last paragraph).
    """
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def conv_bn_act(in_, out_, kernel_size,
                stride=1, groups=1, bias=True,
                eps=1e-3, momentum=0.01):
    return nn.Sequential(
        SamePadConv2d(in_, out_, kernel_size, stride, groups=groups, bias=bias),
        nn.BatchNorm2d(out_, eps, momentum),
        Swish()
    )
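# Quick shape check (sketch): a 3x3, stride-2 block with TF-style "same"
# padding halves the spatial dims via ceiling division:
#
#   block = conv_bn_act(32, 64, kernel_size=3, stride=2, bias=False)
#   y = block(torch.randn(1, 32, 128, 128))   # -> torch.Size([1, 64, 64, 64])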
class SamePadConv2d(nn.Conv2d):
    """
    Conv with TF padding='same'
    https://github.com/pytorch/pytorch/issues/3867#issuecomment-349279036
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True, padding_mode="zeros"):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias, padding_mode)

    def get_pad_odd(self, in_, weight, stride, dilation):
        effective_filter_size_rows = (weight - 1) * dilation + 1
        out_rows = (in_ + stride - 1) // stride
        padding_rows = max(0, (out_rows - 1) * stride + effective_filter_size_rows - in_)
        rows_odd = (padding_rows % 2 != 0)
        return padding_rows, rows_odd

    def forward(self, x):
        padding_rows, rows_odd = self.get_pad_odd(x.shape[2], self.weight.shape[2], self.stride[0], self.dilation[0])
        padding_cols, cols_odd = self.get_pad_odd(x.shape[3], self.weight.shape[3], self.stride[1], self.dilation[1])

        # When the required padding is odd, pad the extra pixel on the
        # right/bottom, matching TensorFlow's behaviour
        if rows_odd or cols_odd:
            x = F.pad(x, [0, int(cols_odd), 0, int(rows_odd)])

        return F.conv2d(x, self.weight, self.bias, self.stride,
                        padding=(padding_rows // 2, padding_cols // 2),
                        dilation=self.dilation, groups=self.groups)


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class SEModule(nn.Module):
    def __init__(self, in_, squeeze_ch):
        super().__init__()
        # Squeeze-and-excitation: global average pool, channel bottleneck,
        # then a per-channel sigmoid gate
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_, squeeze_ch, kernel_size=1, stride=1, padding=0, bias=True),
            Swish(),
            nn.Conv2d(squeeze_ch, in_, kernel_size=1, stride=1, padding=0, bias=True),
        )

    def forward(self, x):
        return x * torch.sigmoid(self.se(x))


class DropConnect(nn.Module):
    def __init__(self, ratio):
        super().__init__()
        self.ratio = 1.0 - ratio

    def forward(self, x):
        if not self.training:
            return x

        # Per-sample Bernoulli keep mask, rescaled by the keep ratio so the
        # expected value of the output matches the input
        random_tensor = self.ratio
        random_tensor += torch.rand([x.shape[0], 1, 1, 1], dtype=torch.float, device=x.device)
        random_tensor.requires_grad_(False)
        return x / self.ratio * random_tensor.floor()
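# Behaviour check (sketch): in training mode each sample in the batch is
# either zeroed or scaled by 1/keep_ratio, so the expectation is preserved:
#
#   dc = DropConnect(0.2)   # keep ratio 0.8
#   dc.train()
#   out = dc(torch.ones(8, 3, 4, 4))   # per-sample values are 0.0 or 1.25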
class MBConv(nn.Module):
    def __init__(self, in_, out_, expand,
                 kernel_size, stride, skip,
                 se_ratio, dc_ratio=0.2):
        super().__init__()
        mid_ = in_ * expand
        self.expand_conv = conv_bn_act(in_, mid_, kernel_size=1, bias=False) if expand != 1 else nn.Identity()

        self.depth_wise_conv = conv_bn_act(mid_, mid_,
                                           kernel_size=kernel_size, stride=stride,
                                           groups=mid_, bias=False)

        self.se = SEModule(mid_, int(in_ * se_ratio)) if se_ratio > 0 else nn.Identity()

        self.project_conv = nn.Sequential(
            SamePadConv2d(mid_, out_, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_, 1e-3, 0.01)
        )

        # The residual connection is only used when the block preserves the
        # spatial resolution and channel count:
        # if _block_args.id_skip:
        #     and all(s == 1 for s in self._block_args.strides)
        #     and self._block_args.input_filters == self._block_args.output_filters:
        self.skip = skip and (stride == 1) and (in_ == out_)

        # DropConnect
        # self.dropconnect = DropConnect(dc_ratio) if dc_ratio > 0 else nn.Identity()
        # The original TF repo does not use drop_rate here:
        # https://github.com/tensorflow/tpu/blob/05f7b15cdf0ae36bac84beb4aef0a09983ce8f66/models/official/efficientnet/efficientnet_model.py#L408
        self.dropconnect = nn.Identity()

    def forward(self, inputs):
        expand = self.expand_conv(inputs)
        x = self.depth_wise_conv(expand)
        x = self.se(x)
        x = self.project_conv(x)
        if self.skip:
            x = self.dropconnect(x)
            x = x + inputs
        return x


class MBBlock(nn.Module):
    def __init__(self, in_, out_, expand, kernel, stride, num_repeat, skip, se_ratio, drop_connect_ratio=0.2):
        super().__init__()
        # Only the first repeat changes resolution/channels; the rest are stride-1
        layers = [MBConv(in_, out_, expand, kernel, stride, skip, se_ratio, drop_connect_ratio)]
        for i in range(1, num_repeat):
            layers.append(MBConv(out_, out_, expand, kernel, 1, skip, se_ratio, drop_connect_ratio))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.weight.shape[1])
                nn.init.uniform_(m.weight, -init_range, init_range)

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7, activation="sigmoid"):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (1, 3, 7), 'kernel size must be 1, 3 or 7'
        padding = {1: 0, 3: 1, 7: 3}

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding[kernel_size], bias=False)
        if activation == "tanh":
            self.activation = nn.Tanh()
        else:
            self.activation = nn.Sigmoid()

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.weight.shape[1])
                nn.init.uniform_(m.weight, -init_range, init_range)

    def forward(self, x):
        # Aggregate channel information via mean and max, then convolve down
        # to a single-channel attention map
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.activation(x)

class SpatialAttentionMMA(nn.Module):
    def __init__(self, kernel_size=7, activation="sigmoid"):
        super(SpatialAttentionMMA, self).__init__()

        assert kernel_size in (1, 3, 7), 'kernel size must be 1, 3 or 7'
        padding = {1: 0, 3: 1, 7: 3}

        self.conv1 = nn.Conv2d(3, 1, kernel_size, padding=padding[kernel_size], bias=False)
        if activation == "tanh":
            self.activation = nn.Tanh()
        else:
            self.activation = nn.Sigmoid()

    def forward(self, x):
        # As above, but with a min-pooled channel summary as well
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        min_out, _ = torch.min(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out, min_out], dim=1)
        x = self.conv1(x)
        return self.activation(x)
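# Usage sketch: CBAM-style attention applied to a feature map, channel gate
# first and spatial gate second (channel count here is illustrative):
#
#   ca, sa = ChannelAttention(64), SpatialAttention(kernel_size=7)
#   feat = torch.randn(2, 64, 32, 32)
#   feat = feat * ca(feat)   # (2, 64, 1, 1) gate, broadcast over H and W
#   feat = feat * sa(feat)   # (2, 1, 32, 32) gate, broadcast over channels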
class SoftDetectionModule(nn.Module):
    def __init__(self, soft_local_max_size=3):
        super(SoftDetectionModule, self).__init__()

        self.soft_local_max_size = soft_local_max_size

        self.pad = self.soft_local_max_size // 2

    def forward(self, batch):
        b = batch.size(0)

        batch = F.relu(batch)

        # Soft local-max score: softmax-style normalisation over each
        # soft_local_max_size x soft_local_max_size neighbourhood
        max_per_sample = torch.max(batch.view(b, -1), dim=1)[0]
        exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1))
        sum_exp = (
            self.soft_local_max_size ** 2 *
            F.avg_pool2d(
                F.pad(exp, [self.pad] * 4, mode='constant', value=1.),
                self.soft_local_max_size, stride=1
            )
        )
        local_max_score = exp / sum_exp

        # Ratio-to-max score across channels
        depth_wise_max = torch.max(batch, dim=1, keepdim=True)[0]
        depth_wise_max_score = batch / depth_wise_max

        all_scores = local_max_score * depth_wise_max_score
        score = torch.max(all_scores, dim=1, keepdim=True)[0]

        # Normalise so the scores of each sample sum to one
        score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1, 1)

        return score

def build_grid(source_size, target_size, device="cuda"):
    k = float(target_size) / float(source_size)
    direct = torch.linspace(0, k, target_size).unsqueeze(0).repeat(target_size, 1).unsqueeze(-1)
    full = torch.cat([direct, direct.transpose(1, 0)], dim=2).unsqueeze(0)
    return full.to(device)

def random_crop_grid(x, grid, device="cuda"):
    delta = x.size(2) - grid.size(1)
    grid = grid.repeat(x.size(0), 1, 1, 1).to(device)
    # Add random shifts by x
    grid[:, :, :, 0] = grid[:, :, :, 0] + torch.FloatTensor(x.size(0)).to(device).random_(0, delta).unsqueeze(-1).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2)
    # Add random shifts by y
    grid[:, :, :, 1] = grid[:, :, :, 1] + torch.FloatTensor(x.size(0)).to(device).random_(0, delta).unsqueeze(-1).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2)
    return grid
--------------------------------------------------------------------------------