├── .gitignore ├── start.sh ├── .vscode └── settings.json ├── .dockerignore ├── requirements.txt ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── Dockerfile ├── custom_datasets.py ├── License ├── inference.py ├── README.md ├── datasets.py ├── main.py ├── utils.py └── models.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.ipynb_checkpoints/ 3 | *.mat 4 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | { python -m visdom.server & } 2>/dev/null 2 | /bin/bash 3 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.formatting.provider": "black" 3 | } -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | Datasets/ 2 | checkpoints/ 3 | __pycache__/ 4 | *.pyc 5 | .git/ 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.10.0 2 | spectral==0.19 3 | scipy>=0.19.0 4 | tqdm>=4.15.0 5 | visdom>=0.1.5 6 | seaborn>=0.8 7 | scikit-learn>=0.19.0 8 | scikit-image>=0.13.1 9 | torch>=0.4.0 10 | matplotlib>=2.0.2 11 | torchsummary>=1.5 12 | joblib==0.14.1 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior (e.g. the command that you used). 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Desktop (please complete the following information):** 20 | - OS: [e.g. Linux/Windows] 21 | - CUDA : yes/no 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.1-cudnn7-runtime-ubuntu16.04 2 | MAINTAINER Nicolas Audebert (nicolas.audebert@onera.fr) 3 | 4 | RUN apt-get update && apt-get install -y --no-install-recommends \ 5 | #build-essential \ 6 | #cmake \ 7 | #git \ 8 | curl \ 9 | bzip2 \ 10 | #vim \ 11 | ca-certificates \ 12 | #libjpeg-dev \ 13 | #libpng-dev \ 14 | libgl1-mesa-glx &&\ 15 | rm -rf /var/lib/apt/lists/* 16 | # (libGL is for matplotlib/seaborn) 17 | 18 | WORKDIR /workspace/DeepHyperX/ 19 | RUN mkdir -p Datasets 20 | COPY . . 21 | RUN curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 22 | chmod +x ~/miniconda.sh && \ 23 | ~/miniconda.sh -b -p /opt/conda && \ 24 | rm ~/miniconda.sh 25 | #&& \ 26 | #/opt/conda/bin/conda install numpy pyyaml scipy ipython mkl mkl-include && \ 27 | #/opt/conda/bin/conda install -c pytorch magma-cuda90 && \ 28 | #/opt/conda/bin/conda clean -ya 29 | ENV PATH /opt/conda/bin:$PATH 30 | #RUN pip install http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl 31 | RUN pip install --no-cache-dir -r requirements.txt 32 | #RUN python main.py --download KSC Botswana PaviaU PaviaC IndianPines 33 | 34 | EXPOSE 8097 35 | 36 | ENTRYPOINT ["sh", "start.sh"] 37 | -------------------------------------------------------------------------------- /custom_datasets.py: -------------------------------------------------------------------------------- 1 | from utils import open_file 2 | import numpy as np 3 | 4 | CUSTOM_DATASETS_CONFIG = { 5 | "DFC2018_HSI": { 6 | "img": "2018_IEEE_GRSS_DFC_HSI_TR.HDR", 7 | "gt": "2018_IEEE_GRSS_DFC_GT_TR.tif", 8 | "download": False, 9 | "loader": lambda folder: dfc2018_loader(folder), 10 | } 11 | } 12 | 13 | 14 | def dfc2018_loader(folder): 15 | img = open_file(folder + "2018_IEEE_GRSS_DFC_HSI_TR.HDR")[:, :, :-2] 16 | gt = open_file(folder + "2018_IEEE_GRSS_DFC_GT_TR.tif") 17 | gt = gt.astype("uint8") 18 | 19 | rgb_bands = (47, 31, 15) 20 | 21 | label_values = [ 22 | "Unclassified", 23 | "Healthy grass", 24 | "Stressed grass", 25 | "Artificial turf", 26 | "Evergreen trees", 27 | "Deciduous trees", 28 | "Bare earth", 29 | "Water", 30 | "Residential buildings", 31 | "Non-residential buildings", 32 | "Roads", 33 | "Sidewalks", 34 | "Crosswalks", 35 | "Major thoroughfares", 36 | "Highways", 37 | "Railways", 38 | "Paved parking lots", 39 | "Unpaved parking lots", 40 | "Cars", 41 | "Trains", 42 | "Stadium seats", 43 | ] 44 | ignored_labels = [0] 45 | palette = None 46 | return img, gt, rgb_bands, ignored_labels, label_values, palette 47 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | # License information 2 | 3 | Code for the DeepHyperX toolbox is dual licensed depending on applications, research or commercial. 4 | 5 | --- 6 | 7 | ## COMMERCIAL PURPOSES 8 | 9 | Please contact the ONERA [www.onera.fr/en/contact-us](www.onera.fr/en/contact-us) for additional information or directly the authors Nicolas Audebert or Bertrand Le Saux. 
10 | 11 | --- 12 | 13 | ## RESEARCH AND NON COMMERCIAL PURPOSES 14 | 15 | #### Code license 16 | 17 | For research and non commercial purposes, all the code and documentation is released under the GPLv3 license: 18 | 19 | Copyright (c) 2018 ONERA and IRISA, Nicolas Audebert, Bertrand Le Saux, Sébastien Lefèvre. 20 | 21 | This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 22 | 23 | PLEASE ACKNOWLEDGE THE ORIGINAL AUTHORS AND PUBLICATION ACCORDING TO THE REPOSITORY github.com/nshaud/DeepHyperx OR IF NOT AVAILABLE: 24 | Nicolas Audebert, Bertrand Le Saux and Sébastien Lefèvre 25 | "Deep Learning for Classification of Hyperspectral Data: A comparative review", 26 | IEEE Geosciences and Remote Sensing Magazine, 2019. 27 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # Python 2/3 compatiblity 2 | from __future__ import print_function 3 | from __future__ import division 4 | import joblib 5 | import os 6 | from utils import convert_to_color_, convert_from_color_, get_device 7 | from datasets import open_file 8 | from models import get_model, test 9 | import numpy as np 10 | import seaborn as sns 11 | from skimage import io 12 | import argparse 13 | import torch 14 | 15 | # Test options 16 | parser = argparse.ArgumentParser( 17 | description="Run deep learning experiments on" " various hyperspectral datasets" 18 | ) 19 | parser.add_argument( 20 | "--model", 21 | type=str, 22 | default=None, 23 | help="Model to train. Available:\n" 24 | "SVM (linear), " 25 | "SVM_grid (grid search on linear, poly and RBF kernels), " 26 | "baseline (fully connected NN), " 27 | "hu (1D CNN), " 28 | "hamida (3D CNN + 1D classifier), " 29 | "lee (3D FCN), " 30 | "chen (3D CNN), " 31 | "li (3D CNN), " 32 | "he (3D CNN), " 33 | "luo (3D CNN), " 34 | "sharma (2D CNN), " 35 | "boulch (1D semi-supervised CNN), " 36 | "liu (3D semi-supervised CNN), " 37 | "mou (1D RNN)", 38 | ) 39 | parser.add_argument( 40 | "--cuda", 41 | type=int, 42 | default=-1, 43 | help="Specify CUDA device (defaults to -1, which learns on CPU)", 44 | ) 45 | parser.add_argument( 46 | "--checkpoint", 47 | type=str, 48 | default=None, 49 | help="Weights to use for initialization, e.g. 
a checkpoint", 50 | ) 51 | 52 | group_test = parser.add_argument_group("Test") 53 | group_test.add_argument( 54 | "--test_stride", 55 | type=int, 56 | default=1, 57 | help="Sliding window step stride during inference (default = 1)", 58 | ) 59 | group_test.add_argument( 60 | "--image", 61 | type=str, 62 | default=None, 63 | nargs="?", 64 | help="Path to an image on which to run inference.", 65 | ) 66 | group_test.add_argument( 67 | "--only_test", 68 | type=str, 69 | default=None, 70 | nargs="?", 71 | help="Choose the data on which to test the trained algorithm ", 72 | ) 73 | group_test.add_argument( 74 | "--mat", 75 | type=str, 76 | default=None, 77 | nargs="?", 78 | help="In case of a .mat file, define the variable to call inside the file", 79 | ) 80 | group_test.add_argument( 81 | "--n_classes", 82 | type=int, 83 | default=None, 84 | nargs="?", 85 | help="When using a trained algorithm, specified the number of classes of this algorithm", 86 | ) 87 | # Training options 88 | group_train = parser.add_argument_group("Model") 89 | group_train.add_argument( 90 | "--patch_size", 91 | type=int, 92 | help="Size of the spatial neighbourhood (optional, if " 93 | "absent will be set by the model)", 94 | ) 95 | group_train.add_argument( 96 | "--batch_size", 97 | type=int, 98 | help="Batch size (optional, if absent will be set by the model", 99 | ) 100 | 101 | args = parser.parse_args() 102 | CUDA_DEVICE = get_device(args.cuda) 103 | MODEL = args.model 104 | # Testing file 105 | MAT = args.mat 106 | N_CLASSES = args.n_classes 107 | INFERENCE = args.image 108 | TEST_STRIDE = args.test_stride 109 | CHECKPOINT = args.checkpoint 110 | 111 | img_filename = os.path.basename(INFERENCE) 112 | basename = MODEL + img_filename 113 | dirname = os.path.dirname(INFERENCE) 114 | 115 | img = open_file(INFERENCE) 116 | if MAT is not None: 117 | img = img[MAT] 118 | # Normalization 119 | img = np.asarray(img, dtype="float32") 120 | img = (img - np.min(img)) / (np.max(img) - np.min(img)) 121 | N_BANDS = img.shape[-1] 122 | hyperparams = vars(args) 123 | hyperparams.update( 124 | { 125 | "n_classes": N_CLASSES, 126 | "n_bands": N_BANDS, 127 | "device": CUDA_DEVICE, 128 | "ignored_labels": [0], 129 | } 130 | ) 131 | hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None) 132 | 133 | palette = {0: (0, 0, 0)} 134 | for k, color in enumerate(sns.color_palette("hls", N_CLASSES)): 135 | palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype="uint8")) 136 | invert_palette = {v: k for k, v in palette.items()} 137 | 138 | 139 | def convert_to_color(x): 140 | return convert_to_color_(x, palette=palette) 141 | 142 | 143 | def convert_from_color(x): 144 | return convert_from_color_(x, palette=invert_palette) 145 | 146 | 147 | if MODEL in ["SVM", "SVM_grid", "SGD", "nearest"]: 148 | model = joblib.load(CHECKPOINT) 149 | w, h = img.shape[:2] 150 | X = img.reshape((w * h, N_BANDS)) 151 | prediction = model.predict(X) 152 | prediction = prediction.reshape(img.shape[:2]) 153 | else: 154 | model, _, _, hyperparams = get_model(MODEL, **hyperparams) 155 | model.load_state_dict(torch.load(CHECKPOINT)) 156 | probabilities = test(model, img, hyperparams) 157 | prediction = np.argmax(probabilities, axis=-1) 158 | 159 | filename = dirname + "/" + basename + ".tif" 160 | io.imsave(filename, prediction) 161 | basename = "color_" + basename 162 | filename = dirname + "/" + basename + ".tif" 163 | io.imsave(filename, convert_to_color(prediction)) 164 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepHyperX 2 | 3 | A Python tool to perform deep learning experiments on various hyperspectral datasets. 4 | 5 | ![https://www.onera.fr/en/research/information-processing-and-systems-domain](https://www.onera.fr/sites/default/files/logo-onera.png|height=60) 6 | 7 | ![https://www-obelix.irisa.fr/](https://www.irisa.fr/sites/all/themes/irisa_theme/logoIRISA-web.png|width=60) 8 | 9 | ## Reference 10 | 11 | This toolbox was used for our review paper in Geoscience and Remote Sensing Magazine : 12 | > N. Audebert, B. Le Saux and S. Lefevre, "*Deep Learning for Classification of Hyperspectral Data: A Comparative Review*," in IEEE Geoscience and Remote Sensing Magazine, vol. 7, no. 2, pp. 159-173, June 2019. 13 | 14 | Bibtex format : 15 | 16 | > @article{8738045, 17 | author={N. {Audebert} and B. {Le Saux} and S. {Lefèvre}}, 18 | journal={IEEE Geoscience and Remote Sensing Magazine}, 19 | title={Deep Learning for Classification of Hyperspectral Data: A Comparative Review}, 20 | year={2019}, 21 | volume={7}, 22 | number={2}, 23 | pages={159-173}, 24 | doi={10.1109/MGRS.2019.2912563}, 25 | ISSN={2373-7468}, 26 | month={June},} 27 | 28 | ## Requirements 29 | 30 | This tool is compatible with Python 2.7 and Python 3.5+. 31 | 32 | It is based on the [PyTorch](http://pytorch.org/) deep learning and GPU computing framework and use the [Visdom](https://github.com/facebookresearch/visdom) visualization server. 33 | 34 | ## Setup 35 | 36 | The easiest way to install this code is to create a [Python virtual environment](https://docs.python.org/3/tutorial/venv.html) and to install dependencies using: 37 | `pip install -r requirements.txt` 38 | 39 | (on Windows you should use `pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html`) 40 | 41 | ### Docker 42 | 43 | Alternatively, it is possible to run the [Docker](https://www.docker.com/community-edition) image. 44 | 45 | Grab the image using: 46 | ``` 47 | docker pull registry.gitlab.inria.fr/naudeber/deephyperx:preview 48 | ``` 49 | 50 | And then run the image using: 51 | ``` 52 | docker run -p 9999:8097 -ti --rm -v `pwd`:/workspace/DeepHyperX/ registry.gitlab.inria.fr/naudeber/deephyperx:preview 53 | ``` 54 | 55 | This command: 56 | * starts a Docker container using the image `registry.gitlab.inria.fr/naudeber/deephyperx:preview` 57 | * starts an interactive shell session `-ti` 58 | * mounts the current folder in the `/workspace/DeepHyperX/` path of the container 59 | * binds the local port 9999 to the container port 8097 (for Visdom) 60 | * removes the container `--rm` when the user has finished. 61 | 62 | All data and products are stored in the current folder. 63 | 64 | Users can build the Docker image locally using the `Dockerfile` using the command `docker build .`. 65 | 66 | ## Hyperspectral datasets 67 | 68 | Several public hyperspectral datasets are available on the [UPV/EHU](http://www.ehu.eus/ccwintco/index.php?title=Hyperspectral_Remote_Sensing_Scenes) wiki. Users can download those beforehand or let the tool download them. The default dataset folder is `./Datasets/`, although this can be modified at runtime using the `--folder` arg. 
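The datasets supported for automatic download (listed below) can also be fetched ahead of time with the `--download` flag of `main.py`, which downloads the requested datasets into the dataset folder and then exits, for example:

```
python main.py --download PaviaU IndianPines --folder ./Datasets/
```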
69 | 70 | At this time, the tool automatically downloads the following public datasets: 71 | * Pavia University 72 | * Pavia Center 73 | * Kennedy Space Center 74 | * Indian Pines 75 | * Botswana 76 | 77 | The [Data Fusion Contest 2018 hyperspectral dataset]() is also preconfigured, although users need to download it on the [DASE](http://dase.ticinumaerospace.com/) website and store it in the dataset folder under `DFC2018_HSI`. 78 | 79 | An example dataset folder has the following structure: 80 | ``` 81 | Datasets 82 | ├── Botswana 83 | │   ├── Botswana_gt.mat 84 | │   └── Botswana.mat 85 | ├── DFC2018_HSI 86 | │   ├── 2018_IEEE_GRSS_DFC_GT_TR.tif 87 | │   ├── 2018_IEEE_GRSS_DFC_HSI_TR 88 | │   ├── 2018_IEEE_GRSS_DFC_HSI_TR.aux.xml 89 | │   ├── 2018_IEEE_GRSS_DFC_HSI_TR.HDR 90 | ├── IndianPines 91 | │   ├── Indian_pines_corrected.mat 92 | │   ├── Indian_pines_gt.mat 93 | ├── KSC 94 | │   ├── KSC_gt.mat 95 | │   └── KSC.mat 96 | ├── PaviaC 97 | │   ├── Pavia_gt.mat 98 | │   └── Pavia.mat 99 | └── PaviaU 100 | ├── PaviaU_gt.mat 101 | └── PaviaU.mat 102 | ``` 103 | 104 | ### Adding a new dataset 105 | 106 | Adding a custom dataset can be done by modifying the `custom_datasets.py` file. Developers should add a new entry to the `CUSTOM_DATASETS_CONFIG` variable and define a specific data loader for their use case. 107 | 108 | ## Models 109 | 110 | Currently, this tool implements several SVM variants from the [scikit-learn](http://scikit-learn.org/stable/) library and many state-of-the-art deep networks implemented in PyTorch. 111 | * SVM (linear, RBF and poly kernels with grid search) 112 | * SGD (linear SVM using stochastic gradient descent for fast optimization) 113 | * baseline neural network (4 fully connected layers with dropout) 114 | * 1D CNN ([Deep Convolutional Neural Networks for Hyperspectral Image Classification, Hu et al., Journal of Sensors 2015](https://www.hindawi.com/journals/js/2015/258619/)) 115 | * Semi-supervised 1D CNN ([Autoencodeurs pour la visualisation d'images hyperspectrales, Boulch et al., GRETSI 2017](https://delta-onera.github.io/publication/2017-GRETSI)) 116 | * 2D CNN ([Hyperspectral CNN for Image Classification & Band Selection, with Application to Face Recognition, Sharma et al, technical report 2018](https://lirias.kuleuven.be/bitstream/123456789/566754/1/4166_final.pdf)) 117 | * Semi-supervised 2D CNN ([A semi-supervised Convolutional Neural Network for Hyperspectral Image Classification, Liu et al, Remote Sensing Letters 2017](https://www.tandfonline.com/doi/abs/10.1080/2150704X.2017.1331053)) 118 | * 3D CNN ([3-D Deep Learning Approach for Remote Sensing Image Classification, Hamida et al., TGRS 2018](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8344565)) 119 | * 3D FCN ([Contextual Deep CNN Based Hyperspectral Classification, Lee and Kwon, IGARSS 2016](https://arxiv.org/abs/1604.03519)) 120 | * 3D CNN ([Deep Feature Extraction and Classification of Hyperspectral Images Based on Convolutional Neural Networks, Chen et al., TGRS 2016](http://elib.dlr.de/106352/2/CNN.pdf)) 121 | * 3D CNN ([Spectral–Spatial Classification of Hyperspectral Imagery with 3D Convolutional Neural Network, Li et al., Remote Sensing 2017](http://www.mdpi.com/2072-4292/9/1/67)) 122 | * 3D CNN ([HSI-CNN: A Novel Convolution Neural Network for Hyperspectral Image, Luo et al, ICPR 2018](https://arxiv.org/abs/1802.10478)) 123 | * Multi-scale 3D CNN ([Multi-scale 3D Deep Convolutional Neural Network for Hyperspectral Image Classification, He et al, ICIP 
2017](https://ieeexplore.ieee.org/document/8297014/)) 124 | 125 | ### Adding a new model 126 | 127 | Adding a custom deep network can be done by modifying the `models.py` file. This involves creating a new class for the custom deep network and altering the `get_model` function. 128 | 129 | ## Usage 130 | 131 | Start a Visdom server: 132 | `python -m visdom.server` 133 | and go to [`http://localhost:8097`](http://localhost:8097) to see the visualizations (or [`http://localhost:9999`](http://localhost:9999) if you use Docker). 134 | 135 | Then, run the script `main.py`. 136 | 137 | The most useful arguments are: 138 | * `--model` to specify the model (e.g. 'svm', 'nn', 'hamida', 'lee', 'chen', 'li'), 139 | * `--dataset` to specify which dataset to use (e.g. 'PaviaC', 'PaviaU', 'IndianPines', 'KSC', 'Botswana'), 140 | * the `--cuda` switch to run the neural nets on GPU. The tool falls back to the CPU if this switch is not specified. 141 | 142 | There are more parameters that can be used to control the behaviour of the tool more finely. See `python main.py -h` for more information. 143 | 144 | Examples: 145 | * `python main.py --model SVM --dataset IndianPines --training_sample 0.3` 146 | This trains an SVM on the Indian Pines dataset, using 30% of the samples for training and the rest for testing. Results are displayed in the visdom panel. 147 | * `python main.py --model nn --dataset PaviaU --training_sample 0.1 --cuda` 148 | This runs a basic 4-layer fully connected neural network on the Pavia University dataset on GPU, using 10% of the samples for training. 149 | * `python main.py --model hamida --dataset PaviaU --training_sample 0.5 --patch_size 7 --epoch 50 --cuda` 150 | This runs the 3D CNN from Hamida et al. on the Pavia University dataset on GPU, with a patch size of 7, using 50% of the samples for training and optimizing for 50 epochs. 151 | 152 | [![Say Thanks!](https://img.shields.io/badge/Say%20Thanks-!-1EAEDB.svg)](https://saythanks.io/to/nshaud) 153 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This file contains the PyTorch dataset for hyperspectral images and 4 | related helpers.
5 | """ 6 | import spectral 7 | import numpy as np 8 | import torch 9 | import torch.utils 10 | import torch.utils.data 11 | import os 12 | from tqdm import tqdm 13 | 14 | try: 15 | # Python 3 16 | from urllib.request import urlretrieve 17 | except ImportError: 18 | # Python 2 19 | from urllib import urlretrieve 20 | 21 | from utils import open_file 22 | 23 | DATASETS_CONFIG = { 24 | "PaviaC": { 25 | "urls": [ 26 | "http://www.ehu.eus/ccwintco/uploads/e/e3/Pavia.mat", 27 | "http://www.ehu.eus/ccwintco/uploads/5/53/Pavia_gt.mat", 28 | ], 29 | "img": "Pavia.mat", 30 | "gt": "Pavia_gt.mat", 31 | }, 32 | "Salinas": { 33 | "urls": [ 34 | "http://www.ehu.eus/ccwintco/uploads/a/a3/Salinas_corrected.mat", 35 | "http://www.ehu.eus/ccwintco/uploads/f/fa/Salinas_gt.mat", 36 | ], 37 | "img": "Salinas_corrected.mat", 38 | "gt": "Salinas_gt.mat", 39 | }, 40 | "PaviaU": { 41 | "urls": [ 42 | "http://www.ehu.eus/ccwintco/uploads/e/ee/PaviaU.mat", 43 | "http://www.ehu.eus/ccwintco/uploads/5/50/PaviaU_gt.mat", 44 | ], 45 | "img": "PaviaU.mat", 46 | "gt": "PaviaU_gt.mat", 47 | }, 48 | "KSC": { 49 | "urls": [ 50 | "http://www.ehu.es/ccwintco/uploads/2/26/KSC.mat", 51 | "http://www.ehu.es/ccwintco/uploads/a/a6/KSC_gt.mat", 52 | ], 53 | "img": "KSC.mat", 54 | "gt": "KSC_gt.mat", 55 | }, 56 | "IndianPines": { 57 | "urls": [ 58 | "http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat", 59 | "http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat", 60 | ], 61 | "img": "Indian_pines_corrected.mat", 62 | "gt": "Indian_pines_gt.mat", 63 | }, 64 | "Botswana": { 65 | "urls": [ 66 | "http://www.ehu.es/ccwintco/uploads/7/72/Botswana.mat", 67 | "http://www.ehu.es/ccwintco/uploads/5/58/Botswana_gt.mat", 68 | ], 69 | "img": "Botswana.mat", 70 | "gt": "Botswana_gt.mat", 71 | }, 72 | } 73 | 74 | try: 75 | from custom_datasets import CUSTOM_DATASETS_CONFIG 76 | 77 | DATASETS_CONFIG.update(CUSTOM_DATASETS_CONFIG) 78 | except ImportError: 79 | pass 80 | 81 | 82 | class TqdmUpTo(tqdm): 83 | """Provides `update_to(n)` which uses `tqdm.update(delta_n)`.""" 84 | 85 | def update_to(self, b=1, bsize=1, tsize=None): 86 | """ 87 | b : int, optional 88 | Number of blocks transferred so far [default: 1]. 89 | bsize : int, optional 90 | Size of each block (in tqdm units) [default: 1]. 91 | tsize : int, optional 92 | Total size (in tqdm units). If [default: None] remains unchanged. 93 | """ 94 | if tsize is not None: 95 | self.total = tsize 96 | self.update(b * bsize - self.n) # will also set self.n = b * bsize 97 | 98 | 99 | def get_dataset(dataset_name, target_folder="./", datasets=DATASETS_CONFIG): 100 | """Gets the dataset specified by name and return the related components. 
101 | Args: 102 | dataset_name: string with the name of the dataset 103 | target_folder (optional): folder to store the datasets, defaults to ./ 104 | datasets (optional): dataset configuration dictionary, defaults to prebuilt one 105 | Returns: 106 | img: 3D hyperspectral image (WxHxB) 107 | gt: 2D int array of labels 108 | label_values: list of class names 109 | ignored_labels: list of int classes to ignore 110 | rgb_bands: int tuple that correspond to red, green and blue bands 111 | """ 112 | palette = None 113 | 114 | if dataset_name not in datasets.keys(): 115 | raise ValueError("{} dataset is unknown.".format(dataset_name)) 116 | 117 | dataset = datasets[dataset_name] 118 | 119 | folder = target_folder + datasets[dataset_name].get("folder", dataset_name + "/") 120 | if dataset.get("download", True): 121 | # Download the dataset if is not present 122 | if not os.path.isdir(folder): 123 | os.makedirs(folder) 124 | for url in datasets[dataset_name]["urls"]: 125 | # download the files 126 | filename = url.split("/")[-1] 127 | if not os.path.exists(folder + filename): 128 | with TqdmUpTo( 129 | unit="B", 130 | unit_scale=True, 131 | miniters=1, 132 | desc="Downloading {}".format(filename), 133 | ) as t: 134 | urlretrieve(url, filename=folder + filename, reporthook=t.update_to) 135 | elif not os.path.isdir(folder): 136 | print("WARNING: {} is not downloadable.".format(dataset_name)) 137 | 138 | if dataset_name == "PaviaC": 139 | # Load the image 140 | img = open_file(folder + "Pavia.mat")["pavia"] 141 | 142 | rgb_bands = (55, 41, 12) 143 | 144 | gt = open_file(folder + "Pavia_gt.mat")["pavia_gt"] 145 | 146 | label_values = [ 147 | "Undefined", 148 | "Water", 149 | "Trees", 150 | "Asphalt", 151 | "Self-Blocking Bricks", 152 | "Bitumen", 153 | "Tiles", 154 | "Shadows", 155 | "Meadows", 156 | "Bare Soil", 157 | ] 158 | 159 | ignored_labels = [0] 160 | 161 | elif dataset_name == "PaviaU": 162 | # Load the image 163 | img = open_file(folder + "PaviaU.mat")["paviaU"] 164 | 165 | rgb_bands = (55, 41, 12) 166 | 167 | gt = open_file(folder + "PaviaU_gt.mat")["paviaU_gt"] 168 | 169 | label_values = [ 170 | "Undefined", 171 | "Asphalt", 172 | "Meadows", 173 | "Gravel", 174 | "Trees", 175 | "Painted metal sheets", 176 | "Bare Soil", 177 | "Bitumen", 178 | "Self-Blocking Bricks", 179 | "Shadows", 180 | ] 181 | 182 | ignored_labels = [0] 183 | 184 | elif dataset_name == "Salinas": 185 | img = open_file(folder + "Salinas_corrected.mat")["salinas_corrected"] 186 | 187 | rgb_bands = (43, 21, 11) # AVIRIS sensor 188 | 189 | gt = open_file(folder + "Salinas_gt.mat")["salinas_gt"] 190 | 191 | label_values = [ 192 | "Undefined", 193 | "Brocoli_green_weeds_1", 194 | "Brocoli_green_weeds_2", 195 | "Fallow", 196 | "Fallow_rough_plow", 197 | "Fallow_smooth", 198 | "Stubble", 199 | "Celery", 200 | "Grapes_untrained", 201 | "Soil_vinyard_develop", 202 | "Corn_senesced_green_weeds", 203 | "Lettuce_romaine_4wk", 204 | "Lettuce_romaine_5wk", 205 | "Lettuce_romaine_6wk", 206 | "Lettuce_romaine_7wk", 207 | "Vinyard_untrained", 208 | "Vinyard_vertical_trellis", 209 | ] 210 | 211 | ignored_labels = [0] 212 | 213 | elif dataset_name == "IndianPines": 214 | # Load the image 215 | img = open_file(folder + "Indian_pines_corrected.mat") 216 | img = img["indian_pines_corrected"] 217 | 218 | rgb_bands = (43, 21, 11) # AVIRIS sensor 219 | 220 | gt = open_file(folder + "Indian_pines_gt.mat")["indian_pines_gt"] 221 | label_values = [ 222 | "Undefined", 223 | "Alfalfa", 224 | "Corn-notill", 225 | "Corn-mintill", 226 | "Corn", 227 | 
"Grass-pasture", 228 | "Grass-trees", 229 | "Grass-pasture-mowed", 230 | "Hay-windrowed", 231 | "Oats", 232 | "Soybean-notill", 233 | "Soybean-mintill", 234 | "Soybean-clean", 235 | "Wheat", 236 | "Woods", 237 | "Buildings-Grass-Trees-Drives", 238 | "Stone-Steel-Towers", 239 | ] 240 | 241 | ignored_labels = [0] 242 | 243 | elif dataset_name == "Botswana": 244 | # Load the image 245 | img = open_file(folder + "Botswana.mat")["Botswana"] 246 | 247 | rgb_bands = (75, 33, 15) 248 | 249 | gt = open_file(folder + "Botswana_gt.mat")["Botswana_gt"] 250 | label_values = [ 251 | "Undefined", 252 | "Water", 253 | "Hippo grass", 254 | "Floodplain grasses 1", 255 | "Floodplain grasses 2", 256 | "Reeds", 257 | "Riparian", 258 | "Firescar", 259 | "Island interior", 260 | "Acacia woodlands", 261 | "Acacia shrublands", 262 | "Acacia grasslands", 263 | "Short mopane", 264 | "Mixed mopane", 265 | "Exposed soils", 266 | ] 267 | 268 | ignored_labels = [0] 269 | 270 | elif dataset_name == "KSC": 271 | # Load the image 272 | img = open_file(folder + "KSC.mat")["KSC"] 273 | 274 | rgb_bands = (43, 21, 11) # AVIRIS sensor 275 | 276 | gt = open_file(folder + "KSC_gt.mat")["KSC_gt"] 277 | label_values = [ 278 | "Undefined", 279 | "Scrub", 280 | "Willow swamp", 281 | "Cabbage palm hammock", 282 | "Cabbage palm/oak hammock", 283 | "Slash pine", 284 | "Oak/broadleaf hammock", 285 | "Hardwood swamp", 286 | "Graminoid marsh", 287 | "Spartina marsh", 288 | "Cattail marsh", 289 | "Salt marsh", 290 | "Mud flats", 291 | "Wate", 292 | ] 293 | 294 | ignored_labels = [0] 295 | else: 296 | # Custom dataset 297 | ( 298 | img, 299 | gt, 300 | rgb_bands, 301 | ignored_labels, 302 | label_values, 303 | palette, 304 | ) = CUSTOM_DATASETS_CONFIG[dataset_name]["loader"](folder) 305 | 306 | # Filter NaN out 307 | nan_mask = np.isnan(img.sum(axis=-1)) 308 | if np.count_nonzero(nan_mask) > 0: 309 | print( 310 | "Warning: NaN have been found in the data. It is preferable to remove them beforehand. Learning on NaN data is disabled." 
311 | ) 312 | img[nan_mask] = 0 313 | gt[nan_mask] = 0 314 | ignored_labels.append(0) 315 | 316 | ignored_labels = list(set(ignored_labels)) 317 | # Normalization 318 | img = np.asarray(img, dtype="float32") 319 | img = (img - np.min(img)) / (np.max(img) - np.min(img)) 320 | return img, gt, label_values, ignored_labels, rgb_bands, palette 321 | 322 | 323 | class HyperX(torch.utils.data.Dataset): 324 | """ Generic class for a hyperspectral scene """ 325 | 326 | def __init__(self, data, gt, **hyperparams): 327 | """ 328 | Args: 329 | data: 3D hyperspectral image 330 | gt: 2D array of labels 331 | patch_size: int, size of the spatial neighbourhood 332 | center_pixel: bool, set to True to consider only the label of the 333 | center pixel 334 | data_augmentation: bool, set to True to perform random flips 335 | supervision: 'full' or 'semi' supervised algorithms 336 | """ 337 | super(HyperX, self).__init__() 338 | self.data = data 339 | self.label = gt 340 | self.name = hyperparams["dataset"] 341 | self.patch_size = hyperparams["patch_size"] 342 | self.ignored_labels = set(hyperparams["ignored_labels"]) 343 | self.flip_augmentation = hyperparams["flip_augmentation"] 344 | self.radiation_augmentation = hyperparams["radiation_augmentation"] 345 | self.mixture_augmentation = hyperparams["mixture_augmentation"] 346 | self.center_pixel = hyperparams["center_pixel"] 347 | supervision = hyperparams["supervision"] 348 | # Fully supervised : use all pixels with label not ignored 349 | if supervision == "full": 350 | mask = np.ones_like(gt) 351 | for l in self.ignored_labels: 352 | mask[gt == l] = 0 353 | # Semi-supervised : use all pixels, except padding 354 | elif supervision == "semi": 355 | mask = np.ones_like(gt) 356 | x_pos, y_pos = np.nonzero(mask) 357 | p = self.patch_size // 2 358 | self.indices = np.array( 359 | [ 360 | (x, y) 361 | for x, y in zip(x_pos, y_pos) 362 | if x > p and x < data.shape[0] - p and y > p and y < data.shape[1] - p 363 | ] 364 | ) 365 | self.labels = [self.label[x, y] for x, y in self.indices] 366 | np.random.shuffle(self.indices) 367 | 368 | @staticmethod 369 | def flip(*arrays): 370 | horizontal = np.random.random() > 0.5 371 | vertical = np.random.random() > 0.5 372 | if horizontal: 373 | arrays = [np.fliplr(arr) for arr in arrays] 374 | if vertical: 375 | arrays = [np.flipud(arr) for arr in arrays] 376 | return arrays 377 | 378 | @staticmethod 379 | def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1 / 25): 380 | alpha = np.random.uniform(*alpha_range) 381 | noise = np.random.normal(loc=0.0, scale=1.0, size=data.shape) 382 | return alpha * data + beta * noise 383 | 384 | def mixture_noise(self, data, label, beta=1 / 25): 385 | alpha1, alpha2 = np.random.uniform(0.01, 1.0, size=2) 386 | noise = np.random.normal(loc=0.0, scale=1.0, size=data.shape) 387 | data2 = np.zeros_like(data) 388 | for idx, value in np.ndenumerate(label): 389 | if value not in self.ignored_labels: 390 | l_indices = np.nonzero(self.labels == value)[0] 391 | l_indice = np.random.choice(l_indices) 392 | assert self.labels[l_indice] == value 393 | x, y = self.indices[l_indice] 394 | data2[idx] = self.data[x, y] 395 | return (alpha1 * data + alpha2 * data2) / (alpha1 + alpha2) + beta * noise 396 | 397 | def __len__(self): 398 | return len(self.indices) 399 | 400 | def __getitem__(self, i): 401 | x, y = self.indices[i] 402 | x1, y1 = x - self.patch_size // 2, y - self.patch_size // 2 403 | x2, y2 = x1 + self.patch_size, y1 + self.patch_size 404 | 405 | data = self.data[x1:x2, y1:y2] 406 | label = 
self.label[x1:x2, y1:y2] 407 | 408 | if self.flip_augmentation and self.patch_size > 1: 409 | # Perform data augmentation (only on 2D patches) 410 | data, label = self.flip(data, label) 411 | if self.radiation_augmentation and np.random.random() < 0.1: 412 | data = self.radiation_noise(data) 413 | if self.mixture_augmentation and np.random.random() < 0.2: 414 | data = self.mixture_noise(data, label) 415 | 416 | # Copy the data into numpy arrays (PyTorch doesn't like numpy views) 417 | data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype="float32") 418 | label = np.asarray(np.copy(label), dtype="int64") 419 | 420 | # Load the data into PyTorch tensors 421 | data = torch.from_numpy(data) 422 | label = torch.from_numpy(label) 423 | # Extract the center label if needed 424 | if self.center_pixel and self.patch_size > 1: 425 | label = label[self.patch_size // 2, self.patch_size // 2] 426 | # Remove unused dimensions when we work with invidual spectrums 427 | elif self.patch_size == 1: 428 | data = data[:, 0, 0] 429 | label = label[0, 0] 430 | 431 | # Add a fourth dimension for 3D CNN 432 | if self.patch_size > 1: 433 | # Make 4D data ((Batch x) Planes x Channels x Width x Height) 434 | data = data.unsqueeze(0) 435 | return data, label 436 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | DEEP LEARNING FOR HYPERSPECTRAL DATA. 4 | 5 | This script allows the user to run several deep models (and SVM baselines) 6 | against various hyperspectral datasets. It is designed to quickly benchmark 7 | state-of-the-art CNNs on various public hyperspectral datasets. 8 | 9 | This code is released under the GPLv3 license for non-commercial and research 10 | purposes only. 11 | For commercial use, please contact the authors. 12 | """ 13 | # Python 2/3 compatiblity 14 | from __future__ import print_function 15 | from __future__ import division 16 | 17 | # Torch 18 | import torch 19 | import torch.utils.data as data 20 | from torchsummary import summary 21 | 22 | # Numpy, scipy, scikit-image, spectral 23 | import numpy as np 24 | import sklearn.svm 25 | import sklearn.model_selection 26 | from skimage import io 27 | 28 | # Visualization 29 | import seaborn as sns 30 | import visdom 31 | 32 | import os 33 | from utils import ( 34 | metrics, 35 | convert_to_color_, 36 | convert_from_color_, 37 | display_dataset, 38 | display_predictions, 39 | explore_spectrums, 40 | plot_spectrums, 41 | sample_gt, 42 | build_dataset, 43 | show_results, 44 | compute_imf_weights, 45 | get_device, 46 | ) 47 | from datasets import get_dataset, HyperX, open_file, DATASETS_CONFIG 48 | from models import get_model, train, test, save_model 49 | 50 | import argparse 51 | 52 | dataset_names = [ 53 | v["name"] if "name" in v.keys() else k for k, v in DATASETS_CONFIG.items() 54 | ] 55 | 56 | # Argument parser for CLI interaction 57 | parser = argparse.ArgumentParser( 58 | description="Run deep learning experiments on" " various hyperspectral datasets" 59 | ) 60 | parser.add_argument( 61 | "--dataset", type=str, default=None, choices=dataset_names, help="Dataset to use." 62 | ) 63 | parser.add_argument( 64 | "--model", 65 | type=str, 66 | default=None, 67 | help="Model to train. 
Available:\n" 68 | "SVM (linear), " 69 | "SVM_grid (grid search on linear, poly and RBF kernels), " 70 | "baseline (fully connected NN), " 71 | "hu (1D CNN), " 72 | "hamida (3D CNN + 1D classifier), " 73 | "lee (3D FCN), " 74 | "chen (3D CNN), " 75 | "li (3D CNN), " 76 | "he (3D CNN), " 77 | "luo (3D CNN), " 78 | "sharma (2D CNN), " 79 | "boulch (1D semi-supervised CNN), " 80 | "liu (3D semi-supervised CNN), " 81 | "mou (1D RNN)", 82 | ) 83 | parser.add_argument( 84 | "--folder", 85 | type=str, 86 | help="Folder where to store the " 87 | "datasets (defaults to the current working directory).", 88 | default="./Datasets/", 89 | ) 90 | parser.add_argument( 91 | "--cuda", 92 | type=int, 93 | default=-1, 94 | help="Specify CUDA device (defaults to -1, which learns on CPU)", 95 | ) 96 | parser.add_argument("--runs", type=int, default=1, help="Number of runs (default: 1)") 97 | parser.add_argument( 98 | "--restore", 99 | type=str, 100 | default=None, 101 | help="Weights to use for initialization, e.g. a checkpoint", 102 | ) 103 | 104 | # Dataset options 105 | group_dataset = parser.add_argument_group("Dataset") 106 | group_dataset.add_argument( 107 | "--training_sample", 108 | type=float, 109 | default=10, 110 | help="Percentage of samples to use for training (default: 10%%)", 111 | ) 112 | group_dataset.add_argument( 113 | "--sampling_mode", 114 | type=str, 115 | help="Sampling mode" " (random sampling or disjoint, default: random)", 116 | default="random", 117 | ) 118 | group_dataset.add_argument( 119 | "--train_set", 120 | type=str, 121 | default=None, 122 | help="Path to the train ground truth (optional, this " 123 | "supersedes the --sampling_mode option)", 124 | ) 125 | group_dataset.add_argument( 126 | "--test_set", 127 | type=str, 128 | default=None, 129 | help="Path to the test set (optional, by default " 130 | "the test_set is the entire ground truth minus the training)", 131 | ) 132 | # Training options 133 | group_train = parser.add_argument_group("Training") 134 | group_train.add_argument( 135 | "--epoch", 136 | type=int, 137 | help="Training epochs (optional, if" " absent will be set by the model)", 138 | ) 139 | group_train.add_argument( 140 | "--patch_size", 141 | type=int, 142 | help="Size of the spatial neighbourhood (optional, if " 143 | "absent will be set by the model)", 144 | ) 145 | group_train.add_argument( 146 | "--lr", type=float, help="Learning rate, set by the model if not specified." 
147 | ) 148 | group_train.add_argument( 149 | "--class_balancing", 150 | action="store_true", 151 | help="Inverse median frequency class balancing (default = False)", 152 | ) 153 | group_train.add_argument( 154 | "--batch_size", 155 | type=int, 156 | help="Batch size (optional, if absent will be set by the model", 157 | ) 158 | group_train.add_argument( 159 | "--test_stride", 160 | type=int, 161 | default=1, 162 | help="Sliding window step stride during inference (default = 1)", 163 | ) 164 | # Data augmentation parameters 165 | group_da = parser.add_argument_group("Data augmentation") 166 | group_da.add_argument( 167 | "--flip_augmentation", action="store_true", help="Random flips (if patch_size > 1)" 168 | ) 169 | group_da.add_argument( 170 | "--radiation_augmentation", 171 | action="store_true", 172 | help="Random radiation noise (illumination)", 173 | ) 174 | group_da.add_argument( 175 | "--mixture_augmentation", action="store_true", help="Random mixes between spectra" 176 | ) 177 | 178 | parser.add_argument( 179 | "--with_exploration", action="store_true", help="See data exploration visualization" 180 | ) 181 | parser.add_argument( 182 | "--download", 183 | type=str, 184 | default=None, 185 | nargs="+", 186 | choices=dataset_names, 187 | help="Download the specified datasets and quits.", 188 | ) 189 | 190 | 191 | args = parser.parse_args() 192 | 193 | CUDA_DEVICE = get_device(args.cuda) 194 | 195 | # % of training samples 196 | SAMPLE_PERCENTAGE = args.training_sample 197 | # Data augmentation ? 198 | FLIP_AUGMENTATION = args.flip_augmentation 199 | RADIATION_AUGMENTATION = args.radiation_augmentation 200 | MIXTURE_AUGMENTATION = args.mixture_augmentation 201 | # Dataset name 202 | DATASET = args.dataset 203 | # Model name 204 | MODEL = args.model 205 | # Number of runs (for cross-validation) 206 | N_RUNS = args.runs 207 | # Spatial context size (number of neighbours in each spatial direction) 208 | PATCH_SIZE = args.patch_size 209 | # Add some visualization of the spectra ? 210 | DATAVIZ = args.with_exploration 211 | # Target folder to store/download/load the datasets 212 | FOLDER = args.folder 213 | # Number of epochs to run 214 | EPOCH = args.epoch 215 | # Sampling mode, e.g random sampling 216 | SAMPLING_MODE = args.sampling_mode 217 | # Pre-computed weights to restore 218 | CHECKPOINT = args.restore 219 | # Learning rate for the SGD 220 | LEARNING_RATE = args.lr 221 | # Automated class balancing 222 | CLASS_BALANCING = args.class_balancing 223 | # Training ground truth file 224 | TRAIN_GT = args.train_set 225 | # Testing ground truth file 226 | TEST_GT = args.test_set 227 | TEST_STRIDE = args.test_stride 228 | 229 | if args.download is not None and len(args.download) > 0: 230 | for dataset in args.download: 231 | get_dataset(dataset, target_folder=FOLDER) 232 | quit() 233 | 234 | viz = visdom.Visdom(env=DATASET + " " + MODEL) 235 | if not viz.check_connection: 236 | print("Visdom is not connected. 
Did you run 'python -m visdom.server' ?") 237 | 238 | 239 | hyperparams = vars(args) 240 | # Load the dataset 241 | img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER) 242 | # Number of classes 243 | N_CLASSES = len(LABEL_VALUES) 244 | # Number of bands (last dimension of the image tensor) 245 | N_BANDS = img.shape[-1] 246 | 247 | # Parameters for the SVM grid search 248 | SVM_GRID_PARAMS = [ 249 | {"kernel": ["rbf"], "gamma": [1e-1, 1e-2, 1e-3], "C": [1, 10, 100, 1000]}, 250 | {"kernel": ["linear"], "C": [0.1, 1, 10, 100, 1000]}, 251 | {"kernel": ["poly"], "degree": [3], "gamma": [1e-1, 1e-2, 1e-3]}, 252 | ] 253 | 254 | if palette is None: 255 | # Generate color palette 256 | palette = {0: (0, 0, 0)} 257 | for k, color in enumerate(sns.color_palette("hls", len(LABEL_VALUES) - 1)): 258 | palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype="uint8")) 259 | invert_palette = {v: k for k, v in palette.items()} 260 | 261 | 262 | def convert_to_color(x): 263 | return convert_to_color_(x, palette=palette) 264 | 265 | 266 | def convert_from_color(x): 267 | return convert_from_color_(x, palette=invert_palette) 268 | 269 | 270 | # Instantiate the experiment based on predefined networks 271 | hyperparams.update( 272 | { 273 | "n_classes": N_CLASSES, 274 | "n_bands": N_BANDS, 275 | "ignored_labels": IGNORED_LABELS, 276 | "device": CUDA_DEVICE, 277 | } 278 | ) 279 | hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None) 280 | 281 | # Show the image and the ground truth 282 | display_dataset(img, gt, RGB_BANDS, LABEL_VALUES, palette, viz) 283 | color_gt = convert_to_color(gt) 284 | 285 | if DATAVIZ: 286 | # Data exploration : compute and show the mean spectrums 287 | mean_spectrums = explore_spectrums( 288 | img, gt, LABEL_VALUES, viz, ignored_labels=IGNORED_LABELS 289 | ) 290 | plot_spectrums(mean_spectrums, viz, title="Mean spectrum/class") 291 | 292 | results = [] 293 | # run the experiment several times 294 | for run in range(N_RUNS): 295 | if TRAIN_GT is not None and TEST_GT is not None: 296 | train_gt = open_file(TRAIN_GT) 297 | test_gt = open_file(TEST_GT) 298 | elif TRAIN_GT is not None: 299 | train_gt = open_file(TRAIN_GT) 300 | test_gt = np.copy(gt) 301 | w, h = test_gt.shape 302 | test_gt[(train_gt > 0)[:w, :h]] = 0 303 | elif TEST_GT is not None: 304 | test_gt = open_file(TEST_GT) 305 | else: 306 | # Sample random training spectra 307 | train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE) 308 | print( 309 | "{} samples selected (over {})".format( 310 | np.count_nonzero(train_gt), np.count_nonzero(gt) 311 | ) 312 | ) 313 | print( 314 | "Running an experiment with the {} model".format(MODEL), 315 | "run {}/{}".format(run + 1, N_RUNS), 316 | ) 317 | 318 | display_predictions(convert_to_color(train_gt), viz, caption="Train ground truth") 319 | display_predictions(convert_to_color(test_gt), viz, caption="Test ground truth") 320 | 321 | if MODEL == "SVM_grid": 322 | print("Running a grid search SVM") 323 | # Grid search SVM (linear and RBF) 324 | X_train, y_train = build_dataset(img, train_gt, ignored_labels=IGNORED_LABELS) 325 | class_weight = "balanced" if CLASS_BALANCING else None 326 | clf = sklearn.svm.SVC(class_weight=class_weight) 327 | clf = sklearn.model_selection.GridSearchCV( 328 | clf, SVM_GRID_PARAMS, verbose=5, n_jobs=4 329 | ) 330 | clf.fit(X_train, y_train) 331 | print("SVM best parameters : {}".format(clf.best_params_)) 332 | prediction = clf.predict(img.reshape(-1, N_BANDS)) 333 | 
save_model(clf, MODEL, DATASET) 334 | prediction = prediction.reshape(img.shape[:2]) 335 | elif MODEL == "SVM": 336 | X_train, y_train = build_dataset(img, train_gt, ignored_labels=IGNORED_LABELS) 337 | class_weight = "balanced" if CLASS_BALANCING else None 338 | clf = sklearn.svm.SVC(class_weight=class_weight) 339 | clf.fit(X_train, y_train) 340 | save_model(clf, MODEL, DATASET) 341 | prediction = clf.predict(img.reshape(-1, N_BANDS)) 342 | prediction = prediction.reshape(img.shape[:2]) 343 | elif MODEL == "SGD": 344 | X_train, y_train = build_dataset(img, train_gt, ignored_labels=IGNORED_LABELS) 345 | X_train, y_train = sklearn.utils.shuffle(X_train, y_train) 346 | scaler = sklearn.preprocessing.StandardScaler() 347 | X_train = scaler.fit_transform(X_train) 348 | class_weight = "balanced" if CLASS_BALANCING else None 349 | clf = sklearn.linear_model.SGDClassifier( 350 | class_weight=class_weight, learning_rate="optimal", tol=1e-3, average=10 351 | ) 352 | clf.fit(X_train, y_train) 353 | save_model(clf, MODEL, DATASET) 354 | prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS))) 355 | prediction = prediction.reshape(img.shape[:2]) 356 | elif MODEL == "nearest": 357 | X_train, y_train = build_dataset(img, train_gt, ignored_labels=IGNORED_LABELS) 358 | X_train, y_train = sklearn.utils.shuffle(X_train, y_train) 359 | class_weight = "balanced" if CLASS_BALANCING else None 360 | clf = sklearn.neighbors.KNeighborsClassifier(weights="distance") 361 | clf = sklearn.model_selection.GridSearchCV( 362 | clf, {"n_neighbors": [1, 3, 5, 10, 20]}, verbose=5, n_jobs=4 363 | ) 364 | clf.fit(X_train, y_train) 365 | clf.fit(X_train, y_train) 366 | save_model(clf, MODEL, DATASET) 367 | prediction = clf.predict(img.reshape(-1, N_BANDS)) 368 | prediction = prediction.reshape(img.shape[:2]) 369 | else: 370 | if CLASS_BALANCING: 371 | weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS) 372 | hyperparams["weights"] = torch.from_numpy(weights) 373 | # Neural network 374 | model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams) 375 | # Split train set in train/val 376 | train_gt, val_gt = sample_gt(train_gt, 0.95, mode="random") 377 | # Generate the dataset 378 | train_dataset = HyperX(img, train_gt, **hyperparams) 379 | train_loader = data.DataLoader( 380 | train_dataset, 381 | batch_size=hyperparams["batch_size"], 382 | # pin_memory=hyperparams['device'], 383 | shuffle=True, 384 | ) 385 | val_dataset = HyperX(img, val_gt, **hyperparams) 386 | val_loader = data.DataLoader( 387 | val_dataset, 388 | # pin_memory=hyperparams['device'], 389 | batch_size=hyperparams["batch_size"], 390 | ) 391 | 392 | print(hyperparams) 393 | print("Network :") 394 | with torch.no_grad(): 395 | for input, _ in train_loader: 396 | break 397 | summary(model.to(hyperparams["device"]), input.size()[1:]) 398 | # We would like to use device=hyperparams['device'] altough we have 399 | # to wait for torchsummary to be fixed first. 
400 | 401 | if CHECKPOINT is not None: 402 | model.load_state_dict(torch.load(CHECKPOINT)) 403 | 404 | try: 405 | train( 406 | model, 407 | optimizer, 408 | loss, 409 | train_loader, 410 | hyperparams["epoch"], 411 | scheduler=hyperparams["scheduler"], 412 | device=hyperparams["device"], 413 | supervision=hyperparams["supervision"], 414 | val_loader=val_loader, 415 | display=viz, 416 | ) 417 | except KeyboardInterrupt: 418 | # Allow the user to stop the training 419 | pass 420 | 421 | probabilities = test(model, img, hyperparams) 422 | prediction = np.argmax(probabilities, axis=-1) 423 | 424 | run_results = metrics( 425 | prediction, 426 | test_gt, 427 | ignored_labels=hyperparams["ignored_labels"], 428 | n_classes=N_CLASSES, 429 | ) 430 | 431 | mask = np.zeros(gt.shape, dtype="bool") 432 | for l in IGNORED_LABELS: 433 | mask[gt == l] = True 434 | prediction[mask] = 0 435 | 436 | color_prediction = convert_to_color(prediction) 437 | display_predictions( 438 | color_prediction, 439 | viz, 440 | gt=convert_to_color(test_gt), 441 | caption="Prediction vs. test ground truth", 442 | ) 443 | 444 | results.append(run_results) 445 | show_results(run_results, viz, label_values=LABEL_VALUES) 446 | 447 | if N_RUNS > 1: 448 | show_results(results, viz, label_values=LABEL_VALUES, agregated=True) 449 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | import numpy as np 4 | from sklearn.metrics import confusion_matrix 5 | import sklearn.model_selection 6 | import seaborn as sns 7 | import itertools 8 | import spectral 9 | import visdom 10 | import matplotlib.pyplot as plt 11 | from scipy import io, misc 12 | import os 13 | import re 14 | import torch 15 | 16 | def get_device(ordinal): 17 | # Use GPU ? 18 | if ordinal < 0: 19 | print("Computation on CPU") 20 | device = torch.device('cpu') 21 | elif torch.cuda.is_available(): 22 | print("Computation on CUDA GPU device {}".format(ordinal)) 23 | device = torch.device('cuda:{}'.format(ordinal)) 24 | else: 25 | print("/!\\ CUDA was requested but is not available! Computation will go on CPU. /!\\") 26 | device = torch.device('cpu') 27 | return device 28 | 29 | 30 | def open_file(dataset): 31 | _, ext = os.path.splitext(dataset) 32 | ext = ext.lower() 33 | if ext == '.mat': 34 | # Load Matlab array 35 | return io.loadmat(dataset) 36 | elif ext == '.tif' or ext == '.tiff': 37 | # Load TIFF file 38 | return misc.imread(dataset) 39 | elif ext == '.hdr': 40 | img = spectral.open_image(dataset) 41 | return img.load() 42 | else: 43 | raise ValueError("Unknown file format: {}".format(ext)) 44 | 45 | def convert_to_color_(arr_2d, palette=None): 46 | """Convert an array of labels to RGB color-encoded image. 47 | 48 | Args: 49 | arr_2d: int 2D array of labels 50 | palette: dict of colors used (label number -> RGB tuple) 51 | 52 | Returns: 53 | arr_3d: int 2D images of color-encoded labels in RGB format 54 | 55 | """ 56 | arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8) 57 | if palette is None: 58 | raise Exception("Unknown color palette") 59 | 60 | for c, i in palette.items(): 61 | m = arr_2d == c 62 | arr_3d[m] = i 63 | 64 | return arr_3d 65 | 66 | 67 | def convert_from_color_(arr_3d, palette=None): 68 | """Convert an RGB-encoded image to grayscale labels. 
69 | 70 | Args: 71 | arr_3d: int 2D image of color-coded labels on 3 channels 72 | palette: dict of colors used (RGB tuple -> label number) 73 | 74 | Returns: 75 | arr_2d: int 2D array of labels 76 | 77 | """ 78 | if palette is None: 79 | raise Exception("Unknown color palette") 80 | 81 | arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) 82 | 83 | for c, i in palette.items(): 84 | m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) 85 | arr_2d[m] = i 86 | 87 | return arr_2d 88 | 89 | 90 | def display_predictions(pred, vis, gt=None, caption=""): 91 | if gt is None: 92 | vis.images([np.transpose(pred, (2, 0, 1))], 93 | opts={'caption': caption}) 94 | else: 95 | vis.images([np.transpose(pred, (2, 0, 1)), 96 | np.transpose(gt, (2, 0, 1))], 97 | nrow=2, 98 | opts={'caption': caption}) 99 | 100 | def display_dataset(img, gt, bands, labels, palette, vis): 101 | """Display the specified dataset. 102 | 103 | Args: 104 | img: 3D hyperspectral image 105 | gt: 2D array labels 106 | bands: tuple of RGB bands to select 107 | labels: list of label class names 108 | palette: dict of colors 109 | display (optional): type of display, if any 110 | 111 | """ 112 | print("Image has dimensions {}x{} and {} channels".format(*img.shape)) 113 | rgb = spectral.get_rgb(img, bands) 114 | rgb /= np.max(rgb) 115 | rgb = np.asarray(255 * rgb, dtype='uint8') 116 | 117 | # Display the RGB composite image 118 | caption = "RGB (bands {}, {}, {})".format(*bands) 119 | # send to visdom server 120 | vis.images([np.transpose(rgb, (2, 0, 1))], 121 | opts={'caption': caption}) 122 | 123 | def explore_spectrums(img, complete_gt, class_names, vis, 124 | ignored_labels=None): 125 | """Plot sampled spectrums with mean + std for each class. 126 | 127 | Args: 128 | img: 3D hyperspectral image 129 | complete_gt: 2D array of labels 130 | class_names: list of class names 131 | ignored_labels (optional): list of labels to ignore 132 | vis : Visdom display 133 | Returns: 134 | mean_spectrums: dict of mean spectrum by class 135 | 136 | """ 137 | mean_spectrums = {} 138 | for c in np.unique(complete_gt): 139 | if c in ignored_labels: 140 | continue 141 | mask = complete_gt == c 142 | class_spectrums = img[mask].reshape(-1, img.shape[-1]) 143 | step = max(1, class_spectrums.shape[0] // 100) 144 | fig = plt.figure() 145 | plt.title(class_names[c]) 146 | # Sample and plot spectrums from the selected class 147 | for spectrum in class_spectrums[::step, :]: 148 | plt.plot(spectrum, alpha=0.25) 149 | mean_spectrum = np.mean(class_spectrums, axis=0) 150 | std_spectrum = np.std(class_spectrums, axis=0) 151 | lower_spectrum = np.maximum(0, mean_spectrum - std_spectrum) 152 | higher_spectrum = mean_spectrum + std_spectrum 153 | 154 | # Plot the mean spectrum with thickness based on std 155 | plt.fill_between(range(len(mean_spectrum)), lower_spectrum, 156 | higher_spectrum, color="#3F5D7D") 157 | plt.plot(mean_spectrum, alpha=1, color="#FFFFFF", lw=2) 158 | vis.matplot(plt) 159 | mean_spectrums[class_names[c]] = mean_spectrum 160 | return mean_spectrums 161 | 162 | 163 | def plot_spectrums(spectrums, vis, title=""): 164 | """Plot the specified dictionary of spectrums. 
165 | 166 | Args: 167 | spectrums: dictionary (name -> spectrum) of spectrums to plot 168 | vis: Visdom display 169 | """ 170 | win = None 171 | for k, v in spectrums.items(): 172 | n_bands = len(v) 173 | update = None if win is None else 'append' 174 | win = vis.line(X=np.arange(n_bands), Y=v, name=k, win=win, update=update, 175 | opts={'title': title}) 176 | 177 | 178 | def build_dataset(mat, gt, ignored_labels=None): 179 | """Create a list of training samples based on an image and a mask. 180 | 181 | Args: 182 | mat: 3D hyperspectral matrix to extract the spectrums from 183 | gt: 2D ground truth 184 | ignored_labels (optional): list of classes to ignore, e.g. 0 to remove 185 | unlabeled pixels 186 | return_indices (optional): bool set to True to return the indices of 187 | the chosen samples 188 | 189 | """ 190 | samples = [] 191 | labels = [] 192 | # Check that image and ground truth have the same 2D dimensions 193 | assert mat.shape[:2] == gt.shape[:2] 194 | 195 | for label in np.unique(gt): 196 | if label in ignored_labels: 197 | continue 198 | else: 199 | indices = np.nonzero(gt == label) 200 | samples += list(mat[indices]) 201 | labels += len(indices[0]) * [label] 202 | return np.asarray(samples), np.asarray(labels) 203 | 204 | 205 | def get_random_pos(img, window_shape): 206 | """ Return the corners of a random window in the input image 207 | 208 | Args: 209 | img: 2D (or more) image, e.g. RGB or grayscale image 210 | window_shape: (width, height) tuple of the window 211 | 212 | Returns: 213 | xmin, xmax, ymin, ymax: tuple of the corners of the window 214 | 215 | """ 216 | w, h = window_shape 217 | W, H = img.shape[:2] 218 | x1 = random.randint(0, W - w - 1) 219 | x2 = x1 + w 220 | y1 = random.randint(0, H - h - 1) 221 | y2 = y1 + h 222 | return x1, x2, y1, y2 223 | 224 | 225 | def padding_image(image, patch_size=None, mode="symmetric", constant_values=0): 226 | """Padding an input image. 227 | Modified at 2020.11.16. If you find any issues, please email at mengxue_zhang@hhu.edu.cn with details. 228 | 229 | Args: 230 | image: 2D+ image with a shape of [h, w, ...], 231 | The array to pad 232 | patch_size: optional, a list include two integers, default is [1, 1] for pure spectra algorithm, 233 | The patch size of the algorithm 234 | mode: optional, str or function, default is "symmetric", 235 | Including 'constant', 'reflect', 'symmetric', more details see np.pad() 236 | constant_values: optional, sequence or scalar, default is 0, 237 | Used in 'constant'. The values to set the padded values for each axis 238 | Returns: 239 | padded_image with a shape of [h + patch_size[0] // 2 * 2, w + patch_size[1] // 2 * 2, ...] 240 | 241 | """ 242 | if patch_size is None: 243 | patch_size = [1, 1] 244 | h = patch_size[0] // 2 245 | w = patch_size[1] // 2 246 | pad_width = [[h, h], [w, w]] 247 | [pad_width.append([0, 0]) for i in image.shape[2:]] 248 | padded_image = np.pad(image, pad_width, mode=mode, constant_values=constant_values) 249 | return padded_image 250 | 251 | 252 | def sliding_window(image, step=10, window_size=(20, 20), with_data=True): 253 | """Sliding window generator over an input image. 254 | 255 | Args: 256 | image: 2D+ image to slide the window on, e.g. 
RGB or hyperspectral 257 | step: int stride of the sliding window 258 | window_size: int tuple, width and height of the window 259 | with_data (optional): bool set to True to return both the data and the 260 | corner indices 261 | Yields: 262 | ([data], x, y, w, h) where x and y are the top-left corner of the 263 | window, (w,h) the window size 264 | 265 | """ 266 | # slide a window across the image 267 | w, h = window_size 268 | W, H = image.shape[:2] 269 | offset_w = (W - w) % step 270 | offset_h = (H - h) % step 271 | """ 272 | Compensate one for the stop value of range(...). because this function does not include the stop value. 273 | Two examples are listed as follows. 274 | When step = 1, supposing w = h = 3, W = H = 7, and step = 1. 275 | Then offset_w = 0, offset_h = 0. 276 | In this case, the x should have been ranged from 0 to 4 (4-6 is the last window), 277 | i.e., x is in range(0, 5) while W (7) - w (3) + offset_w (0) + 1 = 5. Plus one ! 278 | Range(0, 5, 1) equals [0, 1, 2, 3, 4]. 279 | 280 | When step = 2, supposing w = h = 3, W = H = 8, and step = 2. 281 | Then offset_w = 1, offset_h = 1. 282 | In this case, x is in [0, 2, 4] while W (8) - w (3) + offset_w (1) + 1 = 6. Plus one ! 283 | Range(0, 6, 2) equals [0, 2, 4]/ 284 | 285 | Same reason to H, h, offset_h, and y. 286 | """ 287 | for x in range(0, W - w + offset_w + 1, step): 288 | if x + w > W: 289 | x = W - w 290 | for y in range(0, H - h + offset_h + 1, step): 291 | if y + h > H: 292 | y = H - h 293 | if with_data: 294 | yield image[x:x + w, y:y + h], x, y, w, h 295 | else: 296 | yield x, y, w, h 297 | 298 | 299 | def count_sliding_window(top, step=10, window_size=(20, 20)): 300 | """ Count the number of windows in an image. 301 | 302 | Args: 303 | image: 2D+ image to slide the window on, e.g. RGB or hyperspectral, ... 304 | step: int stride of the sliding window 305 | window_size: int tuple, width and height of the window 306 | Returns: 307 | int number of windows 308 | """ 309 | sw = sliding_window(top, step, window_size, with_data=False) 310 | return sum(1 for _ in sw) 311 | 312 | 313 | def grouper(n, iterable): 314 | """ Browse an iterable by grouping n elements by n elements. 315 | 316 | Args: 317 | n: int, size of the groups 318 | iterable: the iterable to Browse 319 | Yields: 320 | chunk of n elements from the iterable 321 | 322 | """ 323 | it = iter(iterable) 324 | while True: 325 | chunk = tuple(itertools.islice(it, n)) 326 | if not chunk: 327 | return 328 | yield chunk 329 | 330 | 331 | def metrics(prediction, target, ignored_labels=[], n_classes=None): 332 | """Compute and print metrics (accuracy, confusion matrix and F1 scores). 333 | 334 | Args: 335 | prediction: list of predicted labels 336 | target: list of target labels 337 | ignored_labels (optional): list of labels to ignore, e.g. 
0 for undef 338 | n_classes (optional): number of classes, max(target) by default 339 | Returns: 340 | accuracy, F1 score by class, confusion matrix 341 | """ 342 | ignored_mask = np.zeros(target.shape[:2], dtype=np.bool) 343 | for l in ignored_labels: 344 | ignored_mask[target == l] = True 345 | ignored_mask = ~ignored_mask 346 | target = target[ignored_mask] 347 | prediction = prediction[ignored_mask] 348 | 349 | results = {} 350 | 351 | n_classes = np.max(target) + 1 if n_classes is None else n_classes 352 | 353 | cm = confusion_matrix( 354 | target, 355 | prediction, 356 | labels=range(n_classes)) 357 | 358 | results["Confusion matrix"] = cm 359 | 360 | # Compute global accuracy 361 | total = np.sum(cm) 362 | accuracy = sum([cm[x][x] for x in range(len(cm))]) 363 | accuracy *= 100 / float(total) 364 | 365 | results["Accuracy"] = accuracy 366 | 367 | # Compute F1 score 368 | F1scores = np.zeros(len(cm)) 369 | for i in range(len(cm)): 370 | try: 371 | F1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i])) 372 | except ZeroDivisionError: 373 | F1 = 0. 374 | F1scores[i] = F1 375 | 376 | results["F1 scores"] = F1scores 377 | 378 | # Compute kappa coefficient 379 | pa = np.trace(cm) / float(total) 380 | pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / \ 381 | float(total * total) 382 | kappa = (pa - pe) / (1 - pe) 383 | results["Kappa"] = kappa 384 | 385 | return results 386 | 387 | 388 | def show_results(results, vis, label_values=None, agregated=False): 389 | text = "" 390 | 391 | if agregated: 392 | accuracies = [r["Accuracy"] for r in results] 393 | kappas = [r["Kappa"] for r in results] 394 | F1_scores = [r["F1 scores"] for r in results] 395 | 396 | F1_scores_mean = np.mean(F1_scores, axis=0) 397 | F1_scores_std = np.std(F1_scores, axis=0) 398 | cm = np.mean([r["Confusion matrix"] for r in results], axis=0) 399 | text += "Agregated results :\n" 400 | else: 401 | cm = results["Confusion matrix"] 402 | accuracy = results["Accuracy"] 403 | F1scores = results["F1 scores"] 404 | kappa = results["Kappa"] 405 | 406 | vis.heatmap(cm, opts={'title': "Confusion matrix", 407 | 'marginbottom': 150, 408 | 'marginleft': 150, 409 | 'width': 500, 410 | 'height': 500, 411 | 'rownames': label_values, 'columnnames': label_values}) 412 | text += "Confusion matrix :\n" 413 | text += str(cm) 414 | text += "---\n" 415 | 416 | if agregated: 417 | text += ("Accuracy: {:.03f} +- {:.03f}\n".format(np.mean(accuracies), 418 | np.std(accuracies))) 419 | else: 420 | text += "Accuracy : {:.03f}%\n".format(accuracy) 421 | text += "---\n" 422 | 423 | text += "F1 scores :\n" 424 | if agregated: 425 | for label, score, std in zip(label_values, F1_scores_mean, 426 | F1_scores_std): 427 | text += "\t{}: {:.03f} +- {:.03f}\n".format(label, score, std) 428 | else: 429 | for label, score in zip(label_values, F1scores): 430 | text += "\t{}: {:.03f}\n".format(label, score) 431 | text += "---\n" 432 | 433 | if agregated: 434 | text += ("Kappa: {:.03f} +- {:.03f}\n".format(np.mean(kappas), 435 | np.std(kappas))) 436 | else: 437 | text += "Kappa: {:.03f}\n".format(kappa) 438 | 439 | vis.text(text.replace('\n', '
')) 440 | print(text) 441 | 442 | 443 | def sample_gt(gt, train_size, mode='random'): 444 | """Extract a fixed percentage of samples from an array of labels. 445 | 446 | Args: 447 | gt: a 2D array of int labels 448 | percentage: [0, 1] float 449 | Returns: 450 | train_gt, test_gt: 2D arrays of int labels 451 | 452 | """ 453 | indices = np.nonzero(gt) 454 | X = list(zip(*indices)) # x,y features 455 | y = gt[indices].ravel() # classes 456 | train_gt = np.zeros_like(gt) 457 | test_gt = np.zeros_like(gt) 458 | if train_size > 1: 459 | train_size = int(train_size) 460 | 461 | if mode == 'random': 462 | train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y) 463 | train_indices = [list(t) for t in zip(*train_indices)] 464 | test_indices = [list(t) for t in zip(*test_indices)] 465 | train_gt[train_indices] = gt[train_indices] 466 | test_gt[test_indices] = gt[test_indices] 467 | elif mode == 'fixed': 468 | print("Sampling {} with train size = {}".format(mode, train_size)) 469 | train_indices, test_indices = [], [] 470 | for c in np.unique(gt): 471 | if c == 0: 472 | continue 473 | indices = np.nonzero(gt == c) 474 | X = list(zip(*indices)) # x,y features 475 | 476 | train, test = sklearn.model_selection.train_test_split(X, train_size=train_size) 477 | train_indices += train 478 | test_indices += test 479 | train_indices = [list(t) for t in zip(*train_indices)] 480 | test_indices = [list(t) for t in zip(*test_indices)] 481 | train_gt[train_indices] = gt[train_indices] 482 | test_gt[test_indices] = gt[test_indices] 483 | 484 | elif mode == 'disjoint': 485 | train_gt = np.copy(gt) 486 | test_gt = np.copy(gt) 487 | for c in np.unique(gt): 488 | mask = gt == c 489 | for x in range(gt.shape[0]): 490 | first_half_count = np.count_nonzero(mask[:x, :]) 491 | second_half_count = np.count_nonzero(mask[x:, :]) 492 | try: 493 | ratio = first_half_count / (first_half_count + second_half_count) 494 | if ratio > 0.9 * train_size: 495 | break 496 | except ZeroDivisionError: 497 | continue 498 | mask[:x, :] = 0 499 | train_gt[mask] = 0 500 | 501 | test_gt[train_gt > 0] = 0 502 | else: 503 | raise ValueError("{} sampling is not implemented yet.".format(mode)) 504 | return train_gt, test_gt 505 | 506 | 507 | def compute_imf_weights(ground_truth, n_classes=None, ignored_classes=[]): 508 | """ Compute inverse median frequency weights for class balancing. 509 | 510 | For each class i, it computes its frequency f_i, i.e the ratio between 511 | the number of pixels from class i and the total number of pixels. 512 | 513 | Then, it computes the median m of all frequencies. For each class the 514 | associated weight is m/f_i. 
515 | 516 | Args: 517 | ground_truth: the annotations array 518 | n_classes: number of classes (optional, defaults to max(ground_truth)) 519 | ignored_classes: id of classes to ignore (optional) 520 | Returns: 521 | numpy array with the IMF coefficients 522 | """ 523 | n_classes = np.max(ground_truth) if n_classes is None else n_classes 524 | weights = np.zeros(n_classes) 525 | frequencies = np.zeros(n_classes) 526 | 527 | for c in range(0, n_classes): 528 | if c in ignored_classes: 529 | continue 530 | frequencies[c] = np.count_nonzero(ground_truth == c) 531 | 532 | # Normalize the pixel counts to obtain frequencies 533 | frequencies /= np.sum(frequencies) 534 | # Obtain the median on non-zero frequencies 535 | idx = np.nonzero(frequencies) 536 | median = np.median(frequencies[idx]) 537 | weights[idx] = median / frequencies[idx] 538 | weights[frequencies == 0] = 0. 539 | return weights 540 | 541 | def camel_to_snake(name): 542 | s = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) 543 | return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s).lower() 544 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | import torch.optim as optim 7 | from torch.nn import init 8 | 9 | # utils 10 | import math 11 | import os 12 | import datetime 13 | import numpy as np 14 | import joblib 15 | 16 | from tqdm import tqdm 17 | from utils import grouper, sliding_window, count_sliding_window, camel_to_snake 18 | 19 | 20 | def get_model(name, **kwargs): 21 | """ 22 | Instantiate and obtain a model with adequate hyperparameters 23 | 24 | Args: 25 | name: string of the model name 26 | kwargs: hyperparameters 27 | Returns: 28 | model: PyTorch network 29 | optimizer: PyTorch optimizer 30 | criterion: PyTorch loss Function 31 | kwargs: hyperparameters with sane defaults 32 | """ 33 | device = kwargs.setdefault("device", torch.device("cpu")) 34 | n_classes = kwargs["n_classes"] 35 | n_bands = kwargs["n_bands"] 36 | weights = torch.ones(n_classes) 37 | weights[torch.LongTensor(kwargs["ignored_labels"])] = 0.0 38 | weights = weights.to(device) 39 | weights = kwargs.setdefault("weights", weights) 40 | 41 | if name == "nn": 42 | kwargs.setdefault("patch_size", 1) 43 | center_pixel = True 44 | model = Baseline(n_bands, n_classes, kwargs.setdefault("dropout", False)) 45 | lr = kwargs.setdefault("learning_rate", 0.0001) 46 | optimizer = optim.Adam(model.parameters(), lr=lr) 47 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 48 | kwargs.setdefault("epoch", 100) 49 | kwargs.setdefault("batch_size", 100) 50 | elif name == "hamida": 51 | patch_size = kwargs.setdefault("patch_size", 5) 52 | center_pixel = True 53 | model = HamidaEtAl(n_bands, n_classes, patch_size=patch_size) 54 | lr = kwargs.setdefault("learning_rate", 0.01) 55 | optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005) 56 | kwargs.setdefault("batch_size", 100) 57 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 58 | elif name == "lee": 59 | kwargs.setdefault("epoch", 200) 60 | patch_size = kwargs.setdefault("patch_size", 5) 61 | center_pixel = False 62 | model = LeeEtAl(n_bands, n_classes) 63 | lr = kwargs.setdefault("learning_rate", 0.001) 64 | optimizer = optim.Adam(model.parameters(), lr=lr) 65 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 66 | elif name == "chen": 67 | patch_size 
= kwargs.setdefault("patch_size", 27) 68 | center_pixel = True 69 | model = ChenEtAl(n_bands, n_classes, patch_size=patch_size) 70 | lr = kwargs.setdefault("learning_rate", 0.003) 71 | optimizer = optim.SGD(model.parameters(), lr=lr) 72 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 73 | kwargs.setdefault("epoch", 400) 74 | kwargs.setdefault("batch_size", 100) 75 | elif name == "li": 76 | patch_size = kwargs.setdefault("patch_size", 5) 77 | center_pixel = True 78 | model = LiEtAl(n_bands, n_classes, n_planes=16, patch_size=patch_size) 79 | lr = kwargs.setdefault("learning_rate", 0.01) 80 | optimizer = optim.SGD( 81 | model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005 82 | ) 83 | epoch = kwargs.setdefault("epoch", 200) 84 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 85 | # kwargs.setdefault('scheduler', optim.lr_scheduler.MultiStepLR(optimizer, milestones=[epoch // 2, (5 * epoch) // 6], gamma=0.1)) 86 | elif name == "hu": 87 | kwargs.setdefault("patch_size", 1) 88 | center_pixel = True 89 | model = HuEtAl(n_bands, n_classes) 90 | # From what I infer from the paper (Eq.7 and Algorithm 1), it is standard SGD with lr = 0.01 91 | lr = kwargs.setdefault("learning_rate", 0.01) 92 | optimizer = optim.SGD(model.parameters(), lr=lr) 93 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 94 | kwargs.setdefault("epoch", 100) 95 | kwargs.setdefault("batch_size", 100) 96 | elif name == "he": 97 | # We train our model by AdaGrad [18] algorithm, in which 98 | # the base learning rate is 0.01. In addition, we set the batch 99 | # as 40, weight decay as 0.01 for all the layers 100 | # The input of our network is the HSI 3D patch in the size of 7×7×Band 101 | kwargs.setdefault("patch_size", 7) 102 | kwargs.setdefault("batch_size", 40) 103 | lr = kwargs.setdefault("learning_rate", 0.01) 104 | center_pixel = True 105 | model = HeEtAl(n_bands, n_classes, patch_size=kwargs["patch_size"]) 106 | # For Adagrad, we need to load the model on GPU before creating the optimizer 107 | model = model.to(device) 108 | optimizer = optim.Adagrad(model.parameters(), lr=lr, weight_decay=0.01) 109 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 110 | elif name == "luo": 111 | # All the experiments are settled by the learning rate of 0.1, 112 | # the decay term of 0.09 and batch size of 100. 113 | kwargs.setdefault("patch_size", 3) 114 | kwargs.setdefault("batch_size", 100) 115 | lr = kwargs.setdefault("learning_rate", 0.1) 116 | center_pixel = True 117 | model = LuoEtAl(n_bands, n_classes, patch_size=kwargs["patch_size"]) 118 | optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.09) 119 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 120 | elif name == "sharma": 121 | # We train our S-CNN from scratch using stochastic gradient descent with 122 | # momentum set to 0.9, weight decay of 0.0005, and with a batch size 123 | # of 60. We initialize an equal learning rate for all trainable layers 124 | # to 0.05, which is manually decreased by a factor of 10 when the validation 125 | # error stopped decreasing. Prior to the termination the learning rate was 126 | # reduced two times at 15th and 25th epoch. [...] 
127 | # We trained the network for 30 epochs 128 | kwargs.setdefault("batch_size", 60) 129 | epoch = kwargs.setdefault("epoch", 30) 130 | lr = kwargs.setdefault("lr", 0.05) 131 | center_pixel = True 132 | # We assume patch_size = 64 133 | kwargs.setdefault("patch_size", 64) 134 | model = SharmaEtAl(n_bands, n_classes, patch_size=kwargs["patch_size"]) 135 | optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005) 136 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 137 | kwargs.setdefault( 138 | "scheduler", 139 | optim.lr_scheduler.MultiStepLR( 140 | optimizer, milestones=[epoch // 2, (5 * epoch) // 6], gamma=0.1 141 | ), 142 | ) 143 | elif name == "liu": 144 | kwargs["supervision"] = "semi" 145 | # "The learning rate is set to 0.001 empirically. The number of epochs is set to be 40." 146 | kwargs.setdefault("epoch", 40) 147 | lr = kwargs.setdefault("lr", 0.001) 148 | center_pixel = True 149 | patch_size = kwargs.setdefault("patch_size", 9) 150 | model = LiuEtAl(n_bands, n_classes, patch_size) 151 | optimizer = optim.SGD(model.parameters(), lr=lr) 152 | # "The unsupervised cost is the squared error of the difference" 153 | criterion = ( 154 | nn.CrossEntropyLoss(weight=kwargs["weights"]), 155 | lambda rec, data: F.mse_loss( 156 | rec, data[:, :, :, patch_size // 2, patch_size // 2].squeeze() 157 | ), 158 | ) 159 | elif name == "boulch": 160 | kwargs["supervision"] = "semi" 161 | kwargs.setdefault("patch_size", 1) 162 | kwargs.setdefault("epoch", 100) 163 | lr = kwargs.setdefault("lr", 0.001) 164 | center_pixel = True 165 | model = BoulchEtAl(n_bands, n_classes) 166 | optimizer = optim.SGD(model.parameters(), lr=lr) 167 | criterion = ( 168 | nn.CrossEntropyLoss(weight=kwargs["weights"]), 169 | lambda rec, data: F.mse_loss(rec, data.squeeze()), 170 | ) 171 | elif name == "mou": 172 | kwargs.setdefault("patch_size", 1) 173 | center_pixel = True 174 | kwargs.setdefault("epoch", 100) 175 | # "The RNN was trained with the Adadelta algorithm [...] 
We made use of a 176 | # fairly high learning rate of 1.0 instead of the relatively low 177 | # default of 0.002 to train the network" 178 | lr = kwargs.setdefault("lr", 1.0) 179 | model = MouEtAl(n_bands, n_classes) 180 | # For Adadelta, we need to load the model on GPU before creating the optimizer 181 | model = model.to(device) 182 | optimizer = optim.Adadelta(model.parameters(), lr=lr) 183 | criterion = nn.CrossEntropyLoss(weight=kwargs["weights"]) 184 | else: 185 | raise KeyError("{} model is unknown.".format(name)) 186 | 187 | model = model.to(device) 188 | epoch = kwargs.setdefault("epoch", 100) 189 | kwargs.setdefault( 190 | "scheduler", 191 | optim.lr_scheduler.ReduceLROnPlateau( 192 | optimizer, factor=0.1, patience=epoch // 4, verbose=True 193 | ), 194 | ) 195 | # kwargs.setdefault('scheduler', None) 196 | kwargs.setdefault("batch_size", 100) 197 | kwargs.setdefault("supervision", "full") 198 | kwargs.setdefault("flip_augmentation", False) 199 | kwargs.setdefault("radiation_augmentation", False) 200 | kwargs.setdefault("mixture_augmentation", False) 201 | kwargs["center_pixel"] = center_pixel 202 | return model, optimizer, criterion, kwargs 203 | 204 | 205 | class Baseline(nn.Module): 206 | """ 207 | Baseline network 208 | """ 209 | 210 | @staticmethod 211 | def weight_init(m): 212 | if isinstance(m, nn.Linear): 213 | init.kaiming_normal_(m.weight) 214 | init.zeros_(m.bias) 215 | 216 | def __init__(self, input_channels, n_classes, dropout=False): 217 | super(Baseline, self).__init__() 218 | self.use_dropout = dropout 219 | if dropout: 220 | self.dropout = nn.Dropout(p=0.5) 221 | 222 | self.fc1 = nn.Linear(input_channels, 2048) 223 | self.fc2 = nn.Linear(2048, 4096) 224 | self.fc3 = nn.Linear(4096, 2048) 225 | self.fc4 = nn.Linear(2048, n_classes) 226 | 227 | self.apply(self.weight_init) 228 | 229 | def forward(self, x): 230 | x = F.relu(self.fc1(x)) 231 | if self.use_dropout: 232 | x = self.dropout(x) 233 | x = F.relu(self.fc2(x)) 234 | if self.use_dropout: 235 | x = self.dropout(x) 236 | x = F.relu(self.fc3(x)) 237 | if self.use_dropout: 238 | x = self.dropout(x) 239 | x = self.fc4(x) 240 | return x 241 | 242 | 243 | class HuEtAl(nn.Module): 244 | """ 245 | Deep Convolutional Neural Networks for Hyperspectral Image Classification 246 | Wei Hu, Yangyu Huang, Li Wei, Fan Zhang and Hengchao Li 247 | Journal of Sensors, Volume 2015 (2015) 248 | https://www.hindawi.com/journals/js/2015/258619/ 249 | """ 250 | 251 | @staticmethod 252 | def weight_init(m): 253 | # [All the trainable parameters in our CNN should be initialized to 254 | # be a random value between −0.05 and 0.05.] 
255 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d): 256 | init.uniform_(m.weight, -0.05, 0.05) 257 | init.zeros_(m.bias) 258 | 259 | def _get_final_flattened_size(self): 260 | with torch.no_grad(): 261 | x = torch.zeros(1, 1, self.input_channels) 262 | x = self.pool(self.conv(x)) 263 | return x.numel() 264 | 265 | def __init__(self, input_channels, n_classes, kernel_size=None, pool_size=None): 266 | super(HuEtAl, self).__init__() 267 | if kernel_size is None: 268 | # [In our experiments, k1 is better to be [ceil](n1/9)] 269 | kernel_size = math.ceil(input_channels / 9) 270 | if pool_size is None: 271 | # The authors recommand that k2's value is chosen so that the pooled features have 30~40 values 272 | # ceil(kernel_size/5) gives the same values as in the paper so let's assume it's okay 273 | pool_size = math.ceil(kernel_size / 5) 274 | self.input_channels = input_channels 275 | 276 | # [The first hidden convolution layer C1 filters the n1 x 1 input data with 20 kernels of size k1 x 1] 277 | self.conv = nn.Conv1d(1, 20, kernel_size) 278 | self.pool = nn.MaxPool1d(pool_size) 279 | self.features_size = self._get_final_flattened_size() 280 | # [n4 is set to be 100] 281 | self.fc1 = nn.Linear(self.features_size, 100) 282 | self.fc2 = nn.Linear(100, n_classes) 283 | self.apply(self.weight_init) 284 | 285 | def forward(self, x): 286 | # [In our design architecture, we choose the hyperbolic tangent function tanh(u)] 287 | x = x.squeeze(dim=-1).squeeze(dim=-1) 288 | x = x.unsqueeze(1) 289 | x = self.conv(x) 290 | x = torch.tanh(self.pool(x)) 291 | x = x.view(-1, self.features_size) 292 | x = torch.tanh(self.fc1(x)) 293 | x = self.fc2(x) 294 | return x 295 | 296 | 297 | class HamidaEtAl(nn.Module): 298 | """ 299 | 3-D Deep Learning Approach for Remote Sensing Image Classification 300 | Amina Ben Hamida, Alexandre Benoit, Patrick Lambert, Chokri Ben Amar 301 | IEEE TGRS, 2018 302 | https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8344565 303 | """ 304 | 305 | @staticmethod 306 | def weight_init(m): 307 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d): 308 | init.kaiming_normal_(m.weight) 309 | init.zeros_(m.bias) 310 | 311 | def __init__(self, input_channels, n_classes, patch_size=5, dilation=1): 312 | super(HamidaEtAl, self).__init__() 313 | # The first layer is a (3,3,3) kernel sized Conv characterized 314 | # by a stride equal to 1 and number of neurons equal to 20 315 | self.patch_size = patch_size 316 | self.input_channels = input_channels 317 | dilation = (dilation, 1, 1) 318 | 319 | if patch_size == 3: 320 | self.conv1 = nn.Conv3d( 321 | 1, 20, (3, 3, 3), stride=(1, 1, 1), dilation=dilation, padding=1 322 | ) 323 | else: 324 | self.conv1 = nn.Conv3d( 325 | 1, 20, (3, 3, 3), stride=(1, 1, 1), dilation=dilation, padding=0 326 | ) 327 | # Next pooling is applied using a layer identical to the previous one 328 | # with the difference of a 1D kernel size (1,1,3) and a larger stride 329 | # equal to 2 in order to reduce the spectral dimension 330 | self.pool1 = nn.Conv3d( 331 | 20, 20, (3, 1, 1), dilation=dilation, stride=(2, 1, 1), padding=(1, 0, 0) 332 | ) 333 | # Then, a duplicate of the first and second layers is created with 334 | # 35 hidden neurons per layer. 
335 | self.conv2 = nn.Conv3d( 336 | 20, 35, (3, 3, 3), dilation=dilation, stride=(1, 1, 1), padding=(1, 0, 0) 337 | ) 338 | self.pool2 = nn.Conv3d( 339 | 35, 35, (3, 1, 1), dilation=dilation, stride=(2, 1, 1), padding=(1, 0, 0) 340 | ) 341 | # Finally, the 1D spatial dimension is progressively reduced 342 | # thanks to the use of two Conv layers, 35 neurons each, 343 | # with respective kernel sizes of (1,1,3) and (1,1,2) and strides 344 | # respectively equal to (1,1,1) and (1,1,2) 345 | self.conv3 = nn.Conv3d( 346 | 35, 35, (3, 1, 1), dilation=dilation, stride=(1, 1, 1), padding=(1, 0, 0) 347 | ) 348 | self.conv4 = nn.Conv3d( 349 | 35, 35, (2, 1, 1), dilation=dilation, stride=(2, 1, 1), padding=(1, 0, 0) 350 | ) 351 | 352 | # self.dropout = nn.Dropout(p=0.5) 353 | 354 | self.features_size = self._get_final_flattened_size() 355 | # The architecture ends with a fully connected layer where the number 356 | # of neurons is equal to the number of input classes. 357 | self.fc = nn.Linear(self.features_size, n_classes) 358 | 359 | self.apply(self.weight_init) 360 | 361 | def _get_final_flattened_size(self): 362 | with torch.no_grad(): 363 | x = torch.zeros( 364 | (1, 1, self.input_channels, self.patch_size, self.patch_size) 365 | ) 366 | x = self.pool1(self.conv1(x)) 367 | x = self.pool2(self.conv2(x)) 368 | x = self.conv3(x) 369 | x = self.conv4(x) 370 | _, t, c, w, h = x.size() 371 | return t * c * w * h 372 | 373 | def forward(self, x): 374 | x = F.relu(self.conv1(x)) 375 | x = self.pool1(x) 376 | x = F.relu(self.conv2(x)) 377 | x = self.pool2(x) 378 | x = F.relu(self.conv3(x)) 379 | x = F.relu(self.conv4(x)) 380 | x = x.view(-1, self.features_size) 381 | # x = self.dropout(x) 382 | x = self.fc(x) 383 | return x 384 | 385 | 386 | class LeeEtAl(nn.Module): 387 | """ 388 | CONTEXTUAL DEEP CNN BASED HYPERSPECTRAL CLASSIFICATION 389 | Hyungtae Lee and Heesung Kwon 390 | IGARSS 2016 391 | """ 392 | 393 | @staticmethod 394 | def weight_init(m): 395 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d): 396 | init.kaiming_uniform_(m.weight) 397 | init.zeros_(m.bias) 398 | 399 | def __init__(self, in_channels, n_classes): 400 | super(LeeEtAl, self).__init__() 401 | # The first convolutional layer applied to the input hyperspectral 402 | # image uses an inception module that locally convolves the input 403 | # image with two convolutional filters with different sizes 404 | # (1x1xB and 3x3xB where B is the number of spectral bands) 405 | self.conv_3x3 = nn.Conv3d( 406 | 1, 128, (in_channels, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1) 407 | ) 408 | self.conv_1x1 = nn.Conv3d( 409 | 1, 128, (in_channels, 1, 1), stride=(1, 1, 1), padding=0 410 | ) 411 | 412 | # We use two modules from the residual learning approach 413 | # Residual block 1 414 | self.conv1 = nn.Conv2d(256, 128, (1, 1)) 415 | self.conv2 = nn.Conv2d(128, 128, (1, 1)) 416 | self.conv3 = nn.Conv2d(128, 128, (1, 1)) 417 | 418 | # Residual block 2 419 | self.conv4 = nn.Conv2d(128, 128, (1, 1)) 420 | self.conv5 = nn.Conv2d(128, 128, (1, 1)) 421 | 422 | # The layer combination in the last three convolutional layers 423 | # is the same as the fully connected layers of Alexnet 424 | self.conv6 = nn.Conv2d(128, 128, (1, 1)) 425 | self.conv7 = nn.Conv2d(128, 128, (1, 1)) 426 | self.conv8 = nn.Conv2d(128, n_classes, (1, 1)) 427 | 428 | self.lrn1 = nn.LocalResponseNorm(256) 429 | self.lrn2 = nn.LocalResponseNorm(128) 430 | 431 | # The 7 th and 8 th convolutional layers have dropout in training 432 | self.dropout = nn.Dropout(p=0.5) 433 | 434 | 
self.apply(self.weight_init) 435 | 436 | def forward(self, x): 437 | # Inception module 438 | x_3x3 = self.conv_3x3(x) 439 | x_1x1 = self.conv_1x1(x) 440 | x = torch.cat([x_3x3, x_1x1], dim=1) 441 | # Remove the third dimension of the tensor 442 | x = torch.squeeze(x) 443 | 444 | # Local Response Normalization 445 | x = F.relu(self.lrn1(x)) 446 | 447 | # First convolution 448 | x = self.conv1(x) 449 | 450 | # Local Response Normalization 451 | x = F.relu(self.lrn2(x)) 452 | 453 | # First residual block 454 | x_res = F.relu(self.conv2(x)) 455 | x_res = self.conv3(x_res) 456 | x = F.relu(x + x_res) 457 | 458 | # Second residual block 459 | x_res = F.relu(self.conv4(x)) 460 | x_res = self.conv5(x_res) 461 | x = F.relu(x + x_res) 462 | 463 | x = F.relu(self.conv6(x)) 464 | x = self.dropout(x) 465 | x = F.relu(self.conv7(x)) 466 | x = self.dropout(x) 467 | x = self.conv8(x) 468 | return x 469 | 470 | 471 | class ChenEtAl(nn.Module): 472 | """ 473 | DEEP FEATURE EXTRACTION AND CLASSIFICATION OF HYPERSPECTRAL IMAGES BASED ON 474 | CONVOLUTIONAL NEURAL NETWORKS 475 | Yushi Chen, Hanlu Jiang, Chunyang Li, Xiuping Jia and Pedram Ghamisi 476 | IEEE Transactions on Geoscience and Remote Sensing (TGRS), 2017 477 | """ 478 | 479 | @staticmethod 480 | def weight_init(m): 481 | # In the beginning, the weights are randomly initialized 482 | # with standard deviation 0.001 483 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d): 484 | init.normal_(m.weight, std=0.001) 485 | init.zeros_(m.bias) 486 | 487 | def __init__(self, input_channels, n_classes, patch_size=27, n_planes=32): 488 | super(ChenEtAl, self).__init__() 489 | self.input_channels = input_channels 490 | self.n_planes = n_planes 491 | self.patch_size = patch_size 492 | 493 | self.conv1 = nn.Conv3d(1, n_planes, (32, 4, 4)) 494 | self.pool1 = nn.MaxPool3d((1, 2, 2)) 495 | self.conv2 = nn.Conv3d(n_planes, n_planes, (32, 4, 4)) 496 | self.pool2 = nn.MaxPool3d((1, 2, 2)) 497 | self.conv3 = nn.Conv3d(n_planes, n_planes, (32, 4, 4)) 498 | 499 | self.features_size = self._get_final_flattened_size() 500 | 501 | self.fc = nn.Linear(self.features_size, n_classes) 502 | 503 | self.dropout = nn.Dropout(p=0.5) 504 | 505 | self.apply(self.weight_init) 506 | 507 | def _get_final_flattened_size(self): 508 | with torch.no_grad(): 509 | x = torch.zeros( 510 | (1, 1, self.input_channels, self.patch_size, self.patch_size) 511 | ) 512 | x = self.pool1(self.conv1(x)) 513 | x = self.pool2(self.conv2(x)) 514 | x = self.conv3(x) 515 | _, t, c, w, h = x.size() 516 | return t * c * w * h 517 | 518 | def forward(self, x): 519 | x = F.relu(self.conv1(x)) 520 | x = self.pool1(x) 521 | x = self.dropout(x) 522 | x = F.relu(self.conv2(x)) 523 | x = self.pool2(x) 524 | x = self.dropout(x) 525 | x = F.relu(self.conv3(x)) 526 | x = self.dropout(x) 527 | x = x.view(-1, self.features_size) 528 | x = self.fc(x) 529 | return x 530 | 531 | 532 | class LiEtAl(nn.Module): 533 | """ 534 | SPECTRAL–SPATIAL CLASSIFICATION OF HYPERSPECTRAL IMAGERY 535 | WITH 3D CONVOLUTIONAL NEURAL NETWORK 536 | Ying Li, Haokui Zhang and Qiang Shen 537 | MDPI Remote Sensing, 2017 538 | http://www.mdpi.com/2072-4292/9/1/67 539 | """ 540 | 541 | @staticmethod 542 | def weight_init(m): 543 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d): 544 | init.xavier_uniform_(m.weight.data) 545 | init.constant_(m.bias.data, 0) 546 | 547 | def __init__(self, input_channels, n_classes, n_planes=2, patch_size=5): 548 | super(LiEtAl, self).__init__() 549 | self.input_channels = input_channels 550 | self.n_planes 
= n_planes 551 | self.patch_size = patch_size 552 | 553 | # The proposed 3D-CNN model has two 3D convolution layers (C1 and C2) 554 | # and a fully-connected layer (F1) 555 | # we fix the spatial size of the 3D convolution kernels to 3 × 3 556 | # while only slightly varying the spectral depth of the kernels 557 | # for the Pavia University and Indian Pines scenes, those in C1 and C2 558 | # were set to seven and three, respectively 559 | self.conv1 = nn.Conv3d(1, n_planes, (7, 3, 3), padding=(1, 0, 0)) 560 | # the number of kernels in the second convolution layer is set to be 561 | # twice as many as that in the first convolution layer 562 | self.conv2 = nn.Conv3d(n_planes, 2 * n_planes, (3, 3, 3), padding=(1, 0, 0)) 563 | # self.dropout = nn.Dropout(p=0.5) 564 | self.features_size = self._get_final_flattened_size() 565 | 566 | self.fc = nn.Linear(self.features_size, n_classes) 567 | 568 | self.apply(self.weight_init) 569 | 570 | def _get_final_flattened_size(self): 571 | with torch.no_grad(): 572 | x = torch.zeros( 573 | (1, 1, self.input_channels, self.patch_size, self.patch_size) 574 | ) 575 | x = self.conv1(x) 576 | x = self.conv2(x) 577 | _, t, c, w, h = x.size() 578 | return t * c * w * h 579 | 580 | def forward(self, x): 581 | x = F.relu(self.conv1(x)) 582 | x = F.relu(self.conv2(x)) 583 | x = x.view(-1, self.features_size) 584 | # x = self.dropout(x) 585 | x = self.fc(x) 586 | return x 587 | 588 | 589 | class HeEtAl(nn.Module): 590 | """ 591 | MULTI-SCALE 3D DEEP CONVOLUTIONAL NEURAL NETWORK FOR HYPERSPECTRAL 592 | IMAGE CLASSIFICATION 593 | Mingyi He, Bo Li, Huahui Chen 594 | IEEE International Conference on Image Processing (ICIP) 2017 595 | https://ieeexplore.ieee.org/document/8297014/ 596 | """ 597 | 598 | @staticmethod 599 | def weight_init(m): 600 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d): 601 | init.kaiming_uniform(m.weight) 602 | init.zeros_(m.bias) 603 | 604 | def __init__(self, input_channels, n_classes, patch_size=7): 605 | super(HeEtAl, self).__init__() 606 | self.input_channels = input_channels 607 | self.patch_size = patch_size 608 | 609 | self.conv1 = nn.Conv3d(1, 16, (11, 3, 3), stride=(3, 1, 1)) 610 | self.conv2_1 = nn.Conv3d(16, 16, (1, 1, 1), padding=(0, 0, 0)) 611 | self.conv2_2 = nn.Conv3d(16, 16, (3, 1, 1), padding=(1, 0, 0)) 612 | self.conv2_3 = nn.Conv3d(16, 16, (5, 1, 1), padding=(2, 0, 0)) 613 | self.conv2_4 = nn.Conv3d(16, 16, (11, 1, 1), padding=(5, 0, 0)) 614 | self.conv3_1 = nn.Conv3d(16, 16, (1, 1, 1), padding=(0, 0, 0)) 615 | self.conv3_2 = nn.Conv3d(16, 16, (3, 1, 1), padding=(1, 0, 0)) 616 | self.conv3_3 = nn.Conv3d(16, 16, (5, 1, 1), padding=(2, 0, 0)) 617 | self.conv3_4 = nn.Conv3d(16, 16, (11, 1, 1), padding=(5, 0, 0)) 618 | self.conv4 = nn.Conv3d(16, 16, (3, 2, 2)) 619 | self.pooling = nn.MaxPool2d((3, 2, 2), stride=(3, 2, 2)) 620 | # the ratio of dropout is 0.6 in our experiments 621 | self.dropout = nn.Dropout(p=0.6) 622 | 623 | self.features_size = self._get_final_flattened_size() 624 | 625 | self.fc = nn.Linear(self.features_size, n_classes) 626 | 627 | self.apply(self.weight_init) 628 | 629 | def _get_final_flattened_size(self): 630 | with torch.no_grad(): 631 | x = torch.zeros( 632 | (1, 1, self.input_channels, self.patch_size, self.patch_size) 633 | ) 634 | x = self.conv1(x) 635 | x2_1 = self.conv2_1(x) 636 | x2_2 = self.conv2_2(x) 637 | x2_3 = self.conv2_3(x) 638 | x2_4 = self.conv2_4(x) 639 | x = x2_1 + x2_2 + x2_3 + x2_4 640 | x3_1 = self.conv3_1(x) 641 | x3_2 = self.conv3_2(x) 642 | x3_3 = self.conv3_3(x) 643 | x3_4 
= self.conv3_4(x) 644 | x = x3_1 + x3_2 + x3_3 + x3_4 645 | x = self.conv4(x) 646 | _, t, c, w, h = x.size() 647 | return t * c * w * h 648 | 649 | def forward(self, x): 650 | x = F.relu(self.conv1(x)) 651 | x2_1 = self.conv2_1(x) 652 | x2_2 = self.conv2_2(x) 653 | x2_3 = self.conv2_3(x) 654 | x2_4 = self.conv2_4(x) 655 | x = x2_1 + x2_2 + x2_3 + x2_4 656 | x = F.relu(x) 657 | x3_1 = self.conv3_1(x) 658 | x3_2 = self.conv3_2(x) 659 | x3_3 = self.conv3_3(x) 660 | x3_4 = self.conv3_4(x) 661 | x = x3_1 + x3_2 + x3_3 + x3_4 662 | x = F.relu(x) 663 | x = F.relu(self.conv4(x)) 664 | x = x.view(-1, self.features_size) 665 | x = self.dropout(x) 666 | x = self.fc(x) 667 | return x 668 | 669 | 670 | class LuoEtAl(nn.Module): 671 | """ 672 | HSI-CNN: A Novel Convolution Neural Network for Hyperspectral Image 673 | Yanan Luo, Jie Zou, Chengfei Yao, Tao Li, Gang Bai 674 | International Conference on Pattern Recognition 2018 675 | """ 676 | 677 | @staticmethod 678 | def weight_init(m): 679 | if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): 680 | init.kaiming_uniform_(m.weight) 681 | init.zeros_(m.bias) 682 | 683 | def __init__(self, input_channels, n_classes, patch_size=3, n_planes=90): 684 | super(LuoEtAl, self).__init__() 685 | self.input_channels = input_channels 686 | self.patch_size = patch_size 687 | self.n_planes = n_planes 688 | 689 | # the 8-neighbor pixels [...] are fed into the Conv1 convolved by n1 kernels 690 | # and s1 stride. Conv1 results are feature vectors each with height of and 691 | # the width is 1. After reshape layer, the feature vectors becomes an image-like 692 | # 2-dimension data. 693 | # Conv2 has 64 kernels size of 3x3, with stride s2. 694 | # After that, the 64 results are drawn into a vector as the input of the fully 695 | # connected layer FC1 which has n4 nodes. 
696 | # In the four datasets, the kernel height nk1 is 24 and stride s1, s2 is 9 and 1 697 | self.conv1 = nn.Conv3d(1, 90, (24, 3, 3), padding=0, stride=(9, 1, 1)) 698 | self.conv2 = nn.Conv2d(1, 64, (3, 3), stride=(1, 1)) 699 | 700 | self.features_size = self._get_final_flattened_size() 701 | 702 | self.fc1 = nn.Linear(self.features_size, 1024) 703 | self.fc2 = nn.Linear(1024, n_classes) 704 | 705 | self.apply(self.weight_init) 706 | 707 | def _get_final_flattened_size(self): 708 | with torch.no_grad(): 709 | x = torch.zeros( 710 | (1, 1, self.input_channels, self.patch_size, self.patch_size) 711 | ) 712 | x = self.conv1(x) 713 | b = x.size(0) 714 | x = x.view(b, 1, -1, self.n_planes) 715 | x = self.conv2(x) 716 | _, c, w, h = x.size() 717 | return c * w * h 718 | 719 | def forward(self, x): 720 | x = F.relu(self.conv1(x)) 721 | b = x.size(0) 722 | x = x.view(b, 1, -1, self.n_planes) 723 | x = F.relu(self.conv2(x)) 724 | x = x.view(-1, self.features_size) 725 | x = F.relu(self.fc1(x)) 726 | x = self.fc2(x) 727 | return x 728 | 729 | 730 | class SharmaEtAl(nn.Module): 731 | """ 732 | HYPERSPECTRAL CNN FOR IMAGE CLASSIFICATION & BAND SELECTION, WITH APPLICATION 733 | TO FACE RECOGNITION 734 | Vivek Sharma, Ali Diba, Tinne Tuytelaars, Luc Van Gool 735 | Technical Report, KU Leuven/ETH Zürich 736 | """ 737 | 738 | @staticmethod 739 | def weight_init(m): 740 | if isinstance(m, (nn.Linear, nn.Conv3d)): 741 | init.kaiming_normal_(m.weight) 742 | init.zeros_(m.bias) 743 | 744 | def __init__(self, input_channels, n_classes, patch_size=64): 745 | super(SharmaEtAl, self).__init__() 746 | self.input_channels = input_channels 747 | self.patch_size = patch_size 748 | 749 | # An input image of size 263x263 pixels is fed to conv1 750 | # with 96 kernels of size 6x6x96 with a stride of 2 pixels 751 | self.conv1 = nn.Conv3d(1, 96, (input_channels, 6, 6), stride=(1, 2, 2)) 752 | self.conv1_bn = nn.BatchNorm3d(96) 753 | self.pool1 = nn.MaxPool3d((1, 2, 2)) 754 | # 256 kernels of size 3x3x256 with a stride of 2 pixels 755 | self.conv2 = nn.Conv3d(1, 256, (96, 3, 3), stride=(1, 2, 2)) 756 | self.conv2_bn = nn.BatchNorm3d(256) 757 | self.pool2 = nn.MaxPool3d((1, 2, 2)) 758 | # 512 kernels of size 3x3x512 with a stride of 1 pixel 759 | self.conv3 = nn.Conv3d(1, 512, (256, 3, 3), stride=(1, 1, 1)) 760 | # Considering those large kernel values, I assume they actually merge the 761 | # 3D tensors at each step 762 | 763 | self.features_size = self._get_final_flattened_size() 764 | 765 | # The fc1 has 1024 outputs, where dropout was applied after 766 | # fc1 with a rate of 0.5 767 | self.fc1 = nn.Linear(self.features_size, 1024) 768 | self.dropout = nn.Dropout(p=0.5) 769 | self.fc2 = nn.Linear(1024, n_classes) 770 | 771 | self.apply(self.weight_init) 772 | 773 | def _get_final_flattened_size(self): 774 | with torch.no_grad(): 775 | x = torch.zeros( 776 | (1, 1, self.input_channels, self.patch_size, self.patch_size) 777 | ) 778 | x = F.relu(self.conv1_bn(self.conv1(x))) 779 | x = self.pool1(x) 780 | print(x.size()) 781 | b, t, c, w, h = x.size() 782 | x = x.view(b, 1, t * c, w, h) 783 | x = F.relu(self.conv2_bn(self.conv2(x))) 784 | x = self.pool2(x) 785 | print(x.size()) 786 | b, t, c, w, h = x.size() 787 | x = x.view(b, 1, t * c, w, h) 788 | x = F.relu(self.conv3(x)) 789 | print(x.size()) 790 | _, t, c, w, h = x.size() 791 | return t * c * w * h 792 | 793 | def forward(self, x): 794 | x = F.relu(self.conv1_bn(self.conv1(x))) 795 | x = self.pool1(x) 796 | b, t, c, w, h = x.size() 797 | x = x.view(b, 1, t * c, w, h) 
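# Note (descriptive, added for clarity): the Conv3d kernels span the full spectral depth of
# their input, so each output has depth 1; the view() just before folds the resulting feature
# maps back into a single band-like axis so that the next Conv3d can consume them as a volume.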
798 | x = F.relu(self.conv2_bn(self.conv2(x))) 799 | x = self.pool2(x) 800 | b, t, c, w, h = x.size() 801 | x = x.view(b, 1, t * c, w, h) 802 | x = F.relu(self.conv3(x)) 803 | x = x.view(-1, self.features_size) 804 | x = self.fc1(x) 805 | x = self.dropout(x) 806 | x = self.fc2(x) 807 | return x 808 | 809 | 810 | class LiuEtAl(nn.Module): 811 | """ 812 | A semi-supervised convolutional neural network for hyperspectral image classification 813 | Bing Liu, Xuchu Yu, Pengqiang Zhang, Xiong Tan, Anzhu Yu, Zhixiang Xue 814 | Remote Sensing Letters, 2017 815 | """ 816 | 817 | @staticmethod 818 | def weight_init(m): 819 | if isinstance(m, (nn.Linear, nn.Conv2d)): 820 | init.kaiming_normal_(m.weight) 821 | init.zeros_(m.bias) 822 | 823 | def __init__(self, input_channels, n_classes, patch_size=9): 824 | super(LiuEtAl, self).__init__() 825 | self.input_channels = input_channels 826 | self.patch_size = patch_size 827 | self.aux_loss_weight = 1 828 | 829 | # "W1 is a 3x3xB1 kernel [...] B1 is the number of the output bands for the convolutional 830 | # "and pooling layer" -> actually 3x3 2D convolutions with B1 outputs 831 | # "the value of B1 is set to be 80" 832 | self.conv1 = nn.Conv2d(input_channels, 80, (3, 3)) 833 | self.pool1 = nn.MaxPool2d((2, 2)) 834 | self.conv1_bn = nn.BatchNorm2d(80) 835 | 836 | self.features_sizes = self._get_sizes() 837 | 838 | self.fc_enc = nn.Linear(self.features_sizes[2], n_classes) 839 | 840 | # Decoder 841 | self.fc1_dec = nn.Linear(self.features_sizes[2], self.features_sizes[2]) 842 | self.fc1_dec_bn = nn.BatchNorm1d(self.features_sizes[2]) 843 | self.fc2_dec = nn.Linear(self.features_sizes[2], self.features_sizes[1]) 844 | self.fc2_dec_bn = nn.BatchNorm1d(self.features_sizes[1]) 845 | self.fc3_dec = nn.Linear(self.features_sizes[1], self.features_sizes[0]) 846 | self.fc3_dec_bn = nn.BatchNorm1d(self.features_sizes[0]) 847 | self.fc4_dec = nn.Linear(self.features_sizes[0], input_channels) 848 | 849 | self.apply(self.weight_init) 850 | 851 | def _get_sizes(self): 852 | x = torch.zeros((1, self.input_channels, self.patch_size, self.patch_size)) 853 | x = F.relu(self.conv1_bn(self.conv1(x))) 854 | _, c, w, h = x.size() 855 | size0 = c * w * h 856 | 857 | x = self.pool1(x) 858 | _, c, w, h = x.size() 859 | size1 = c * w * h 860 | 861 | x = self.conv1_bn(x) 862 | _, c, w, h = x.size() 863 | size2 = c * w * h 864 | 865 | return size0, size1, size2 866 | 867 | def forward(self, x): 868 | x = x.squeeze() 869 | x_conv1 = self.conv1_bn(self.conv1(x)) 870 | x = x_conv1 871 | x_pool1 = self.pool1(x) 872 | x = x_pool1 873 | x_enc = F.relu(x).view(-1, self.features_sizes[2]) 874 | x = x_enc 875 | 876 | x_classif = self.fc_enc(x) 877 | 878 | # x = F.relu(self.fc1_dec_bn(self.fc1_dec(x) + x_enc)) 879 | x = F.relu(self.fc1_dec(x)) 880 | x = F.relu( 881 | self.fc2_dec_bn(self.fc2_dec(x) + x_pool1.view(-1, self.features_sizes[1])) 882 | ) 883 | x = F.relu( 884 | self.fc3_dec_bn(self.fc3_dec(x) + x_conv1.view(-1, self.features_sizes[0])) 885 | ) 886 | x = self.fc4_dec(x) 887 | return x_classif, x 888 | 889 | 890 | class BoulchEtAl(nn.Module): 891 | """ 892 | Autoencodeurs pour la visualisation d'images hyperspectrales 893 | A.Boulch, N. Audebert, D. 
Dubucq 894 | GRETSI 2017 895 | """ 896 | 897 | @staticmethod 898 | def weight_init(m): 899 | if isinstance(m, (nn.Linear, nn.Conv1d)): 900 | init.kaiming_normal_(m.weight) 901 | init.zeros_(m.bias) 902 | 903 | def __init__(self, input_channels, n_classes, planes=16): 904 | super(BoulchEtAl, self).__init__() 905 | self.input_channels = input_channels 906 | self.aux_loss_weight = 0.1 907 | 908 | encoder_modules = [] 909 | n = input_channels 910 | with torch.no_grad(): 911 | x = torch.zeros((10, 1, self.input_channels)) 912 | print(x.size()) 913 | while n > 1: 914 | print("---------- {} ---------".format(n)) 915 | if n == input_channels: 916 | p1, p2 = 1, 2 * planes 917 | elif n == input_channels // 2: 918 | p1, p2 = 2 * planes, planes 919 | else: 920 | p1, p2 = planes, planes 921 | encoder_modules.append(nn.Conv1d(p1, p2, 3, padding=1)) 922 | x = encoder_modules[-1](x) 923 | print(x.size()) 924 | encoder_modules.append(nn.MaxPool1d(2)) 925 | x = encoder_modules[-1](x) 926 | print(x.size()) 927 | encoder_modules.append(nn.ReLU(inplace=True)) 928 | x = encoder_modules[-1](x) 929 | print(x.size()) 930 | encoder_modules.append(nn.BatchNorm1d(p2)) 931 | x = encoder_modules[-1](x) 932 | print(x.size()) 933 | n = n // 2 934 | 935 | encoder_modules.append(nn.Conv1d(planes, 3, 3, padding=1)) 936 | encoder_modules.append(nn.Tanh()) 937 | self.encoder = nn.Sequential(*encoder_modules) 938 | self.features_sizes = self._get_sizes() 939 | 940 | self.classifier = nn.Linear(self.features_sizes, n_classes) 941 | self.regressor = nn.Linear(self.features_sizes, input_channels) 942 | self.apply(self.weight_init) 943 | 944 | def _get_sizes(self): 945 | with torch.no_grad(): 946 | x = torch.zeros((10, 1, self.input_channels)) 947 | x = self.encoder(x) 948 | _, c, w = x.size() 949 | return c * w 950 | 951 | def forward(self, x): 952 | x = x.unsqueeze(1) 953 | x = self.encoder(x) 954 | x = x.view(-1, self.features_sizes) 955 | x_classif = self.classifier(x) 956 | x = self.regressor(x) 957 | return x_classif, x 958 | 959 | 960 | class MouEtAl(nn.Module): 961 | """ 962 | Deep recurrent neural networks for hyperspectral image classification 963 | Lichao Mou, Pedram Ghamisi, Xiao Xang Zhu 964 | https://ieeexplore.ieee.org/document/7914752/ 965 | """ 966 | 967 | @staticmethod 968 | def weight_init(m): 969 | # All weight matrices in our RNN and bias vectors are initialized with a uniform distribution, and the values of these weight matrices and bias vectors are initialized in the range [−0.1,0.1] 970 | if isinstance(m, (nn.Linear, nn.GRU)): 971 | init.uniform_(m.weight.data, -0.1, 0.1) 972 | init.uniform_(m.bias.data, -0.1, 0.1) 973 | 974 | def __init__(self, input_channels, n_classes): 975 | # The proposed network model uses a single recurrent layer that adopts our modified GRUs of size 64 with sigmoid gate activation and PRetanh activation functions for hidden representations 976 | super(MouEtAl, self).__init__() 977 | self.input_channels = input_channels 978 | self.gru = nn.GRU(1, 64, 1, bidirectional=False) # TODO: try to change this ? 
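# Note (descriptive, added for clarity): the GRU consumes the spectrum band by band
# (sequence length = number of bands, one feature per step, hidden size 64); forward()
# concatenates the hidden state of every band, hence the 64 * input_channels sizes used below.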
979 | self.gru_bn = nn.BatchNorm1d(64 * input_channels) 980 | self.tanh = nn.Tanh() 981 | self.fc = nn.Linear(64 * input_channels, n_classes) 982 | 983 | def forward(self, x): 984 | x = x.squeeze() 985 | x = x.unsqueeze(0) 986 | # x is in 1, N, C but we expect C, N, 1 for GRU layer 987 | x = x.permute(2, 1, 0) 988 | x = self.gru(x)[0] 989 | # x is in C, N, 64, we permute back 990 | x = x.permute(1, 2, 0).contiguous() 991 | x = x.view(x.size(0), -1) 992 | x = self.gru_bn(x) 993 | x = self.tanh(x) 994 | x = self.fc(x) 995 | return x 996 | 997 | 998 | def train( 999 | net, 1000 | optimizer, 1001 | criterion, 1002 | data_loader, 1003 | epoch, 1004 | scheduler=None, 1005 | display_iter=100, 1006 | device=torch.device("cpu"), 1007 | display=None, 1008 | val_loader=None, 1009 | supervision="full", 1010 | ): 1011 | """ 1012 | Training loop to optimize a network for several epochs and a specified loss 1013 | 1014 | Args: 1015 | net: a PyTorch model 1016 | optimizer: a PyTorch optimizer 1017 | data_loader: a PyTorch dataset loader 1018 | epoch: int specifying the number of training epochs 1019 | criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss 1020 | device (optional): torch device to use (defaults to CPU) 1021 | display_iter (optional): number of iterations before refreshing the 1022 | display (False/None to switch off). 1023 | scheduler (optional): PyTorch scheduler 1024 | val_loader (optional): validation dataset 1025 | supervision (optional): 'full' or 'semi' 1026 | """ 1027 | 1028 | if criterion is None: 1029 | raise Exception("Missing criterion. You must specify a loss function.") 1030 | 1031 | net.to(device) 1032 | 1033 | save_epoch = epoch // 20 if epoch > 20 else 1 1034 | 1035 | losses = np.zeros(1000000) 1036 | mean_losses = np.zeros(100000000) 1037 | iter_ = 1 1038 | loss_win, val_win = None, None 1039 | val_accuracies = [] 1040 | 1041 | for e in tqdm(range(1, epoch + 1), desc="Training the network"): 1042 | # Set the network to training mode 1043 | net.train() 1044 | avg_loss = 0.0 1045 | 1046 | # Run the training loop for one epoch 1047 | for batch_idx, (data, target) in tqdm( 1048 | enumerate(data_loader), total=len(data_loader) 1049 | ): 1050 | # Load the data into the GPU if required 1051 | data, target = data.to(device), target.to(device) 1052 | 1053 | optimizer.zero_grad() 1054 | if supervision == "full": 1055 | output = net(data) 1056 | loss = criterion(output, target) 1057 | elif supervision == "semi": 1058 | outs = net(data) 1059 | output, rec = outs 1060 | loss = criterion[0](output, target) + net.aux_loss_weight * criterion[ 1061 | 1 1062 | ](rec, data) 1063 | else: 1064 | raise ValueError( 1065 | 'supervision mode "{}" is unknown.'.format(supervision) 1066 | ) 1067 | loss.backward() 1068 | optimizer.step() 1069 | 1070 | avg_loss += loss.item() 1071 | losses[iter_] = loss.item() 1072 | mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100) : iter_ + 1]) 1073 | 1074 | if display_iter and iter_ % display_iter == 0: 1075 | string = "Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}" 1076 | string = string.format( 1077 | e, 1078 | epoch, 1079 | batch_idx * len(data), 1080 | len(data) * len(data_loader), 1081 | 100.0 * batch_idx / len(data_loader), 1082 | mean_losses[iter_], 1083 | ) 1084 | update = None if loss_win is None else "append" 1085 | loss_win = display.line( 1086 | X=np.arange(iter_ - display_iter, iter_), 1087 | Y=mean_losses[iter_ - display_iter : iter_], 1088 | win=loss_win, 1089 | update=update, 1090 | opts={ 1091 | "title": "Training loss", 
1092 | "xlabel": "Iterations", 1093 | "ylabel": "Loss", 1094 | }, 1095 | ) 1096 | tqdm.write(string) 1097 | 1098 | if len(val_accuracies) > 0: 1099 | val_win = display.line( 1100 | Y=np.array(val_accuracies), 1101 | X=np.arange(len(val_accuracies)), 1102 | win=val_win, 1103 | opts={ 1104 | "title": "Validation accuracy", 1105 | "xlabel": "Epochs", 1106 | "ylabel": "Accuracy", 1107 | }, 1108 | ) 1109 | iter_ += 1 1110 | del (data, target, loss, output) 1111 | 1112 | # Update the scheduler 1113 | avg_loss /= len(data_loader) 1114 | if val_loader is not None: 1115 | val_acc = val(net, val_loader, device=device, supervision=supervision) 1116 | val_accuracies.append(val_acc) 1117 | metric = -val_acc 1118 | else: 1119 | metric = avg_loss 1120 | 1121 | if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): 1122 | scheduler.step(metric) 1123 | elif scheduler is not None: 1124 | scheduler.step() 1125 | 1126 | # Save the weights 1127 | if e % save_epoch == 0: 1128 | save_model( 1129 | net, 1130 | camel_to_snake(str(net.__class__.__name__)), 1131 | data_loader.dataset.name, 1132 | epoch=e, 1133 | metric=abs(metric), 1134 | ) 1135 | 1136 | 1137 | def save_model(model, model_name, dataset_name, **kwargs): 1138 | model_dir = "./checkpoints/" + model_name + "/" + dataset_name + "/" 1139 | """ 1140 | Using strftime in case it triggers exceptions on windows 10 system 1141 | """ 1142 | time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 1143 | if not os.path.isdir(model_dir): 1144 | os.makedirs(model_dir, exist_ok=True) 1145 | if isinstance(model, torch.nn.Module): 1146 | filename = time_str + "_epoch{epoch}_{metric:.2f}".format( 1147 | **kwargs 1148 | ) 1149 | tqdm.write("Saving neural network weights in {}".format(filename)) 1150 | torch.save(model.state_dict(), model_dir + filename + ".pth") 1151 | else: 1152 | filename = time_str 1153 | tqdm.write("Saving model params in {}".format(filename)) 1154 | joblib.dump(model, model_dir + filename + ".pkl") 1155 | 1156 | 1157 | def test(net, img, hyperparams): 1158 | """ 1159 | Test a model on a specific image 1160 | """ 1161 | net.eval() 1162 | patch_size = hyperparams["patch_size"] 1163 | center_pixel = hyperparams["center_pixel"] 1164 | batch_size, device = hyperparams["batch_size"], hyperparams["device"] 1165 | n_classes = hyperparams["n_classes"] 1166 | 1167 | kwargs = { 1168 | "step": hyperparams["test_stride"], 1169 | "window_size": (patch_size, patch_size), 1170 | } 1171 | probs = np.zeros(img.shape[:2] + (n_classes,)) 1172 | 1173 | iterations = count_sliding_window(img, **kwargs) // batch_size 1174 | for batch in tqdm( 1175 | grouper(batch_size, sliding_window(img, **kwargs)), 1176 | total=(iterations), 1177 | desc="Inference on the image", 1178 | ): 1179 | with torch.no_grad(): 1180 | if patch_size == 1: 1181 | data = [b[0][0, 0] for b in batch] 1182 | data = np.copy(data) 1183 | data = torch.from_numpy(data) 1184 | else: 1185 | data = [b[0] for b in batch] 1186 | data = np.copy(data) 1187 | data = data.transpose(0, 3, 1, 2) 1188 | data = torch.from_numpy(data) 1189 | data = data.unsqueeze(1) 1190 | 1191 | indices = [b[1:] for b in batch] 1192 | data = data.to(device) 1193 | output = net(data) 1194 | if isinstance(output, tuple): 1195 | output = output[0] 1196 | output = output.to("cpu") 1197 | 1198 | if patch_size == 1 or center_pixel: 1199 | output = output.numpy() 1200 | else: 1201 | output = np.transpose(output.numpy(), (0, 2, 3, 1)) 1202 | for (x, y, w, h), out in zip(indices, output): 1203 | if center_pixel: 1204 | probs[x + 
w // 2, y + h // 2] += out 1205 | else: 1206 | probs[x : x + w, y : y + h] += out 1207 | return probs 1208 | 1209 | 1210 | def val(net, data_loader, device="cpu", supervision="full"): 1211 | # TODO : fix me using metrics() 1212 | accuracy, total = 0.0, 0.0 1213 | ignored_labels = data_loader.dataset.ignored_labels 1214 | for batch_idx, (data, target) in enumerate(data_loader): 1215 | with torch.no_grad(): 1216 | # Load the data into the GPU if required 1217 | data, target = data.to(device), target.to(device) 1218 | if supervision == "full": 1219 | output = net(data) 1220 | elif supervision == "semi": 1221 | outs = net(data) 1222 | output, rec = outs 1223 | _, output = torch.max(output, dim=1) 1224 | for pred, lbl in zip(output.view(-1), target.view(-1)): 1225 | # Only score pixels whose ground-truth label is not ignored 1226 | if lbl.item() in ignored_labels: 1227 | continue 1228 | accuracy += pred.item() == lbl.item() 1229 | total += 1 1230 | return accuracy / total if total > 0 else 0.0 1231 | --------------------------------------------------------------------------------
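
Usage sketch (illustrative, not a file from the repository): the sliding-window helpers from utils.py on a synthetic array, assuming only the dependencies listed in requirements.txt.

import numpy as np
from utils import sliding_window, count_sliding_window, grouper

arr = np.zeros((8, 8, 3))
# With W = H = 8, w = h = 3 and step = 2, the top-left corners per axis are
# 0, 2, 4 and then 5 (the last window is clamped so that it still fits),
# i.e. 4 x 4 = 16 windows in total.
corners = list(sliding_window(arr, step=2, window_size=(3, 3), with_data=False))
assert len(corners) == count_sliding_window(arr, step=2, window_size=(3, 3)) == 16

# grouper() batches any iterable; here it yields 4 chunks of 4 corner tuples.
batches = list(grouper(4, corners))
assert len(batches) == 4 and all(len(b) == 4 for b in batches)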
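
Usage sketch (illustrative): a toy sanity check for metrics() and compute_imf_weights(); label 0 plays the role of the unlabelled class and n_classes is passed explicitly.

import numpy as np
from utils import metrics, compute_imf_weights

gt = np.array([[0, 1, 1], [2, 2, 1], [1, 2, 0]])
pred = np.array([[1, 1, 2], [2, 2, 1], [1, 1, 0]])

results = metrics(pred, gt, ignored_labels=[0], n_classes=3)
# Only the 7 labelled pixels are scored; 5 of them match, so
# results["Accuracy"] is 100 * 5 / 7 and the confusion matrix is 3 x 3.

weights = compute_imf_weights(gt, n_classes=3, ignored_classes=[0])
# Class 2 is rarer than class 1, so its weight is above 1 (class 0 gets 0);
# the weights can be turned into a tensor and passed as the weight argument
# of torch.nn.CrossEntropyLoss for class balancing.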
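
Usage sketch (illustrative): wiring sample_gt(), get_model() and test() together on a synthetic cube; the scene size, band count and test_stride below are placeholder values and the network is left untrained, so the reported accuracy is only that of a random baseline.

import numpy as np
import torch

from utils import sample_gt, metrics
from models import get_model, test

# Synthetic 32 x 32 scene with 100 bands and labels in {0, ..., 5}; 0 is unlabelled.
img = np.random.rand(32, 32, 100).astype("float32")
gt = np.random.randint(0, 6, size=(32, 32))

train_gt, test_gt = sample_gt(gt, train_size=0.2, mode="random")

model, optimizer, criterion, hyperparams = get_model(
    "nn",  # the Baseline fully-connected network
    n_classes=6,
    n_bands=100,
    ignored_labels=[0],
    device=torch.device("cpu"),
)
hyperparams["test_stride"] = 1  # set by hand here; not part of get_model()'s defaults

# A real run would first optimize the network with train() and a data loader
# over the training pixels; here we only run sliding-window inference.
probs = test(model, img, hyperparams)    # (32, 32, 6) per-class scores
prediction = np.argmax(probs, axis=-1)   # (32, 32) predicted labels
results = metrics(prediction, test_gt, ignored_labels=[0], n_classes=6)
print(results["Accuracy"])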