├── Figures ├── FigureResults.pdf ├── FigureResultsPNG.png ├── FlowchartDiagram.pdf ├── FlowchartDiagramPNG.png └── __init__ ├── LICENSE ├── README.md ├── requirements.txt └── src ├── __init__ ├── inference.py ├── modeling_SegFormer.py ├── test.py └── train.py /Figures/FigureResults.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESA-PhiLab/Learning_from_Unlabeled_Data_for_Domain_Adaptation_for_Semantic_Segmentation/9351e9d5e8d9e58eff007564fa921b60d3fff3bd/Figures/FigureResults.pdf -------------------------------------------------------------------------------- /Figures/FigureResultsPNG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESA-PhiLab/Learning_from_Unlabeled_Data_for_Domain_Adaptation_for_Semantic_Segmentation/9351e9d5e8d9e58eff007564fa921b60d3fff3bd/Figures/FigureResultsPNG.png -------------------------------------------------------------------------------- /Figures/FlowchartDiagram.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESA-PhiLab/Learning_from_Unlabeled_Data_for_Domain_Adaptation_for_Semantic_Segmentation/9351e9d5e8d9e58eff007564fa921b60d3fff3bd/Figures/FlowchartDiagram.pdf -------------------------------------------------------------------------------- /Figures/FlowchartDiagramPNG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESA-PhiLab/Learning_from_Unlabeled_Data_for_Domain_Adaptation_for_Semantic_Segmentation/9351e9d5e8d9e58eff007564fa921b60d3fff3bd/Figures/FlowchartDiagramPNG.png -------------------------------------------------------------------------------- /Figures/__init__: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ESA-PhiLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning_from_Unlabeled_Data_for_Domain_Adaptation_for_Semantic_Segmentation 2 | Learning from Unlabeled EO Data: Domain Adaptation for Semantic Segmentation 3 | 4 | Paper: [http://arxiv.org/pdf/2404.11299.pdf](http://arxiv.org/pdf/2404.11299.pdf) - arXiv:2404.11299 5 | 6 | GitHub repository of the published paper: "Learning from Unlabelled Data with Transformers: Domain Adaptation for Semantic Segmentation of High Resolution Aerial Images" 7 | 8 | Authors: Nikolaos Dionelis*, Francesco Pro#, Luca Maiano#, Irene Amerini#, Bertrand Le Saux* 9 | 10 | Affiliation: (1) *European Space Agency (ESA), ESRIN, Φ-lab, Italy, and 11 | 12 | (2) #Sapienza University of Rome, Italy 13 | 14 | Related paper that uses our model: [http://arxiv.org/pdf/2404.11302.pdf](http://arxiv.org/pdf/2404.11302.pdf) - arXiv:2404.11302 15 | 16 | Related website: [Click here](http://scholar.google.com/citations?hl=en&user=2UweGHoAAAAJ&view_op=list_works&sortby=pubdate) 17 | 18 | Non-annotated Earth Observation Semantic Segmentation (NEOS) model 19 | 20 | ## Abstract of Paper: 21 | 22 | Data from satellites or aerial vehicles are most of the time unlabelled. Annotating such data accurately is difficult, requires expertise, and is costly in terms of time. Even if Earth Observation (EO) data were correctly labelled, labels might change over time. Learning from unlabelled data within a semi-supervised learning framework for segmentation of aerial images is challenging. To this end, in this paper, we develop a new model for semantic segmentation of unlabelled images, the Non-annotated Earth Observation Semantic Segmentation (NEOS) model. NEOS performs domain adaptation as the target domain does not have ground truth masks. The distribution inconsistencies between the target and source domains are due to differences in acquisition scenes, environmental conditions, sensors, and times. Our model aligns the learned representations of the different domains to make them coincide. The evaluation results show that it is successful and outperforms other models for semantic segmentation of unlabelled data. 23 | 24 | Index Terms - Semantic segmentation, Unlabelled data, Domain adaptation, Aerial images 25 | 26 | ## Flowchart Diagram: 27 | 28 | ![plot](./Figures/FlowchartDiagramPNG.png) 29 | 30 | Figure 1: Flowchart of the proposed model NEOS for semantic segmentation of unlabelled data using domain adaptation. 31 | 32 | ## Brief Discussion about the Model: 33 | 34 | We develop a model for semantic segmentation of unlabelled aerial images, the Non-annotated Earth Observation Semantic Segmentation (NEOS) model. To address the problem of performing accurate semantic segmentation on unlabelled image data, the model NEOS performs domain adaptation. NEOS makes the learned latent representations of the different domains coincide. This is achieved by minimizing a loss function that drives the network to align the latent features of the different domains; an illustrative sketch of this alignment idea is given further below in this README. Our main contribution is the development of a model for semantic segmentation of aerial images that do not have ground truth segmentation masks, performing domain adaptation. 35 | 36 | ## Usage: 37 | 38 | For the evaluation of the proposed model NEOS, we use the dataset CVUSA, which is unlabelled, as well as the datasets Potsdam and Vaihingen, which are labelled. We use the CVUSA aerial images.
In the paper, we also use the labelled dataset CityScapes for the CVUSA street images. 39 | 40 | Tested with: Python 3.10.13 41 | 42 | The code is written in PyTorch. 43 | 44 | We have used the Visual Studio Code editor. 45 | 46 | Download the dataset CVUSA, as well as the datasets Potsdam and Vaihingen. Also, download the dataset CityScapes. 47 | 48 | The image datasets have to be stored in folders so that the training and the testing/inference of the model NEOS can be performed, as described in the paper. 49 | 50 | Example usage: 51 | 52 | Run from the Terminal: 53 | 54 | ``` 55 | cd ./src/ 56 | python train.py 57 | ``` 58 | 59 | ## Further Usage Information: 60 | 61 | This GitHub code repository contains a PyTorch implementation of the model NEOS. 62 | 63 | To install the required libraries, run "pip install -r requirements.txt" from the Terminal. 64 | 65 | On Linux, run: 66 | ``` 67 | git clone https://github.com/ESA-PhiLab/Learning_from_Unlabeled_Data_for_Domain_Adaptation_for_Semantic_Segmentation.git 68 | conda create -n modelNEOS python 69 | conda info --envs 70 | conda activate modelNEOS 71 | pip install --user --requirement requirements.txt 72 | ``` 73 | 74 | ## NEOS Results 75 | 76 | The main results and the key outcomes of the proposed model NEOS can be found in the paper "Learning from Unlabelled Data with Transformers: Domain Adaptation for Semantic Segmentation of High Resolution Aerial Images". 77 | 78 | Results of NEOS using the source code: 79 | 80 | ![plot](./Figures/FigureResultsPNG.png) 81 | 82 | ## Acknowledgement: 83 | 84 | All the acknowledgements, references, and citations for the model NEOS can be found in the paper "Learning from Unlabelled Data with Transformers: Domain Adaptation for Semantic Segmentation of High Resolution Aerial Images". 85 | 86 | ## If you use our code, please cite: 87 | 88 | Nikolaos Dionelis, Francesco Pro, Luca Maiano, Irene Amerini, and Bertrand Le Saux, "Learning from Unlabelled Data with Transformers: Domain Adaptation for Semantic Segmentation of High Resolution Aerial Images," IEEE International Geoscience and Remote Sensing Symposium (IGARSS), 2024. 89 | 90 | ``` 91 | @inproceedings{dionelis2024, 92 | title = "Learning from Unlabelled Data with Transformers: Domain Adaptation for Semantic Segmentation of High Resolution Aerial Images", 93 | author = "Nikolaos Dionelis and Francesco Pro and Luca Maiano and Irene Amerini and Bertrand Le Saux", 94 | booktitle = "IEEE International Geoscience and Remote Sensing Symposium (IGARSS)", 95 | year = 2024 96 | } 97 | ``` 98 | 99 | If you would like to get in touch, please contact: [Nikolaos.Dionelis@esa.int](mailto:Nikolaos.Dionelis@esa.int?subject=[GitHub]).
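
## Illustrative Sketch of the Alignment Idea:

The sketch below is a minimal, illustrative example only of the alignment idea described above (a supervised segmentation loss on the labelled source domain plus a term that pulls the latent features of the source and target domains together). It is not the exact NEOS training code or loss: the use of a Hugging Face SegFormer backbone feature, a mean-feature MSE alignment term, and the 0.1 weight are assumptions made for illustration; the actual loss and its weighting are defined in the paper and in `src/train.py`.

```
# Minimal, illustrative sketch only (not the exact NEOS training code).
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation

model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", num_labels=6)
model.train()

source_images = torch.rand(2, 3, 512, 512)         # labelled source domain (e.g. Potsdam/Vaihingen)
source_masks = torch.randint(0, 6, (2, 512, 512))  # ground truth masks (6 classes)
target_images = torch.rand(2, 3, 512, 512)         # unlabelled target domain (e.g. CVUSA aerial)

# Supervised segmentation loss on the labelled source domain.
seg_loss = model(pixel_values=source_images, labels=source_masks).loss

# Latent features of the two domains (last encoder stage, shape (B, C, H', W')).
src_feats = model(pixel_values=source_images, output_hidden_states=True).hidden_states[-1]
tgt_feats = model(pixel_values=target_images, output_hidden_states=True).hidden_states[-1]

# Alignment term: make the pooled latent representations of the two domains coincide.
align_loss = F.mse_loss(src_feats.mean(dim=(0, 2, 3)), tgt_feats.mean(dim=(0, 2, 3)))

total_loss = seg_loss + 0.1 * align_loss            # 0.1 is an arbitrary example weight
total_loss.backward()
```

Note that `src/inference.py` in this repository also contains a gradient-reversal layer (`ReverseLayerF`) with a small domain classifier, which is an alternative, adversarial way of aligning the latent features of the domains.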
100 | 101 | ## Further Information 102 | 103 | Paper: "Learning from Unlabelled Data with Transformers: Domain Adaptation for Semantic Segmentation of High Resolution Aerial Images" 104 | 105 | Published in: IEEE International Geoscience and Remote Sensing Symposium (IGARSS 2024) - [IGARSS, Athens](http://www.2024.ieeeigarss.org) 106 | 107 | European Space Agency (ESA), ESRIN, Φ-lab, Italy: [ESA Φ-lab](http://philab.esa.int) 108 | 109 | Sapienza University of Rome, Italy: [Sapienza University of Rome](http://www.uniroma1.it/en/pagina-strutturale/home) 110 | 111 | 112 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | affine==2.4.0 3 | aiohttp==3.9.0 4 | aiosignal==1.3.1 5 | albumentations==1.3.1 6 | aniso8601==9.0.1 7 | anyio==4.2.0 8 | argon2-cffi==23.1.0 9 | argon2-cffi-bindings==21.2.0 10 | arrow==1.3.0 11 | asttokens==2.4.1 12 | async-lru==2.0.4 13 | async-timeout==4.0.3 14 | attrs==23.1.0 15 | av==11.0.0 16 | Babel==2.14.0 17 | beautifulsoup4==4.12.2 18 | bleach==6.1.0 19 | blinker==1.7.0 20 | bqplot==0.12.42 21 | branca==0.7.0 22 | Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work 23 | buteo==0.9.54 24 | cachelib==0.9.0 25 | cachetools==5.3.2 26 | certifi @ file:///croot/certifi_1700501669400/work/certifi 27 | cffi @ file:///croot/cffi_1700254295673/work 28 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 29 | click==8.1.7 30 | click-plugins==1.1.1 31 | cligj==0.7.2 32 | colour==0.1.5 33 | comm==0.2.0 34 | contourpy==1.2.0 35 | cryptography @ file:///croot/cryptography_1694444244250/work 36 | cycler==0.12.1 37 | datasets==2.15.0 38 | debugpy==1.8.0 39 | decorator==5.1.1 40 | defusedxml==0.7.1 41 | dill==0.3.7 42 | duckdb==0.9.2 43 | efficientnet-pytorch==0.7.1 44 | exceptiongroup==1.2.0 45 | executing==2.0.1 46 | fastjsonschema==2.19.0 47 | filelock==3.13.1 48 | fiona==1.9.5 49 | Flask==2.3.3 50 | Flask-Caching==2.1.0 51 | Flask-Cors==4.0.0 52 | flask-restx==1.2.0 53 | folium==0.15.0 54 | fonttools==4.45.0 55 | fqdn==1.5.1 56 | frozenlist==1.4.0 57 | fsspec==2023.10.0 58 | GDAL==3.3.2 59 | gdown==4.7.1 60 | geojson==3.1.0 61 | geopandas==0.14.1 62 | google-auth==2.25.2 63 | google-auth-oauthlib==1.2.0 64 | grpcio==1.60.0 65 | h11==0.14.0 66 | huggingface-hub==0.19.4 67 | idna @ file:///croot/idna_1666125576474/work 68 | imageio==2.33.0 69 | importlib-resources==6.1.1 70 | IProgress==0.4 71 | ipyevents==2.0.2 72 | ipyfilechooser==0.6.0 73 | ipykernel==6.27.1 74 | ipyleaflet==0.18.0 75 | ipython==8.17.2 76 | ipytree==0.2.2 77 | ipywidgets==8.1.1 78 | isoduration==20.11.0 79 | itsdangerous==2.1.2 80 | jedi==0.19.1 81 | Jinja2==3.1.2 82 | joblib==1.3.2 83 | json5==0.9.14 84 | jsonpointer==2.4 85 | jsonschema==4.20.0 86 | jsonschema-specifications==2023.11.2 87 | jupyter==1.0.0 88 | jupyter-console==6.6.3 89 | jupyter-events==0.9.0 90 | jupyter-lsp==2.2.1 91 | jupyter_client==8.6.0 92 | jupyter_core==5.5.1 93 | jupyter_server==2.12.1 94 | jupyter_server_terminals==0.5.0 95 | jupyterlab==4.0.9 96 | jupyterlab-widgets==3.0.9 97 | jupyterlab_pygments==0.3.0 98 | jupyterlab_server==2.25.2 99 | kiwisolver==1.4.5 100 | kornia==0.7.0 101 | large-image==1.26.2 102 | large-image-source-rasterio==1.26.2 103 | lazy_loader==0.3 104 | leafmap==0.29.5 105 | llvmlite==0.41.1 106 | localtileserver==0.7.2 107 | Markdown==3.5.1 108 | MarkupSafe==2.1.3 109 | matplotlib==3.8.2 110 | 
matplotlib-inline==0.1.6 111 | mistune==3.0.2 112 | mkl-fft @ file:///croot/mkl_fft_1695058164594/work 113 | mkl-random @ file:///croot/mkl_random_1695059800811/work 114 | mkl-service==2.4.0 115 | multidict==6.0.4 116 | multiprocess==0.70.15 117 | munch==4.0.0 118 | nbclient==0.9.0 119 | nbconvert==7.13.1 120 | nbformat==5.9.2 121 | nest-asyncio==1.5.8 122 | networkx==3.2.1 123 | notebook==7.0.6 124 | notebook_shim==0.2.3 125 | numba==0.58.1 126 | numpy @ file:///croot/numpy_and_numpy_base_1695830428084/work/dist/numpy-1.26.0-cp310-cp310-linux_x86_64.whl#sha256=fc2732718bc9e06a7b702492cb4f5afffe9671083930452d894377bf563464a3 127 | oauthlib==3.2.2 128 | opencv-python==4.8.1.78 129 | opencv-python-headless==4.8.1.78 130 | overrides==7.4.0 131 | packaging==23.2 132 | palettable==3.3.3 133 | pandas==2.1.3 134 | pandocfilters==1.5.0 135 | parso==0.8.3 136 | pexpect==4.8.0 137 | Pillow @ file:///croot/pillow_1696580024257/work 138 | platformdirs==4.1.0 139 | plotly==5.18.0 140 | pretrainedmodels==0.7.4 141 | prometheus-client==0.19.0 142 | prompt-toolkit==3.0.41 143 | protobuf==4.23.4 144 | psutil==5.9.7 145 | ptyprocess==0.7.0 146 | pure-eval==0.2.2 147 | pyarrow==14.0.1 148 | pyarrow-hotfix==0.6 149 | pyasn1==0.5.1 150 | pyasn1-modules==0.3.0 151 | pycocotools==2.0.7 152 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 153 | Pygments==2.17.2 154 | pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work 155 | pyparsing==3.1.1 156 | pyproj==3.6.1 157 | pyrsistent==0.20.0 158 | pyshp==2.3.1 159 | PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work 160 | pystac==1.8.3 161 | pystac-client==0.7.5 162 | python-box==7.1.1 163 | python-dateutil==2.8.2 164 | python-json-logger==2.0.7 165 | pytorch-msssim==1.0.0 166 | pytorch-ssim==0.1 167 | pytz==2023.3.post1 168 | PyYAML==6.0.1 169 | pyzmq==25.1.2 170 | qtconsole==5.5.1 171 | QtPy==2.4.1 172 | qudida==0.0.4 173 | rasterio==1.3.9 174 | referencing==0.32.0 175 | regex==2023.10.3 176 | requests @ file:///croot/requests_1690400202158/work 177 | requests-oauthlib==1.3.1 178 | rfc3339-validator==0.1.4 179 | rfc3986-validator==0.1.1 180 | rpds-py==0.15.2 181 | rsa==4.9 182 | safetensors==0.4.0 183 | scikit-image==0.22.0 184 | scikit-learn==1.3.2 185 | scipy==1.11.4 186 | scooby==0.9.2 187 | seaborn==0.13.0 188 | segment-anything-hq==0.3 189 | segment-anything-py==1.0 190 | segment-geospatial==0.10.2 191 | segmentation-models-pytorch==0.3.3 192 | Send2Trash==1.8.2 193 | server-thread==0.2.0 194 | shapely==2.0.2 195 | six==1.16.0 196 | sniffio==1.3.0 197 | snuggs==1.4.7 198 | soupsieve==2.5 199 | stack-data==0.6.3 200 | tenacity==8.2.3 201 | tensorboard==2.15.1 202 | tensorboard-data-server==0.7.2 203 | terminado==0.18.0 204 | threadpoolctl==3.2.0 205 | tifffile==2023.9.26 206 | timm==0.9.2 207 | tinycss2==1.2.1 208 | tokenizers==0.15.0 209 | tomli==2.0.1 210 | torch==1.12.1 211 | torchaudio==0.12.1 212 | torchsummary==1.5.1 213 | torchvision==0.13.1 214 | tornado==6.4 215 | tqdm==4.66.1 216 | traitlets==5.13.0 217 | traittypes==0.2.1 218 | transformers==4.35.2 219 | types-python-dateutil==2.8.19.14 220 | typing_extensions @ file:///croot/typing_extensions_1690297465030/work 221 | tzdata==2023.3 222 | uri-template==1.3.0 223 | urllib3 @ file:///croot/urllib3_1698257533958/work 224 | uvicorn==0.24.0.post1 225 | wcwidth==0.2.12 226 | webcolors==1.13 227 | webencodings==0.5.1 228 | websocket-client==1.7.0 229 | Werkzeug==2.3.8 230 | whitebox==2.3.1 231 | whiteboxgui==2.3.0 232 | widgetsnbextension==4.0.9 233 | xxhash==3.4.1 234 | 
xyzservices==2023.10.1 235 | yarl==1.9.3 236 | -------------------------------------------------------------------------------- /src/__init__: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # torch.save(model.state_dict(), '/Data/ndionelis/formodels/sgfrmr2ehwh') 5 | # # data_dir = '/Data/ndionelis/bingmap/maindatasetnewnew2' 6 | # # Potsdam 7 | # # data_dir = '/Data/ndionelis/mainpo2' 8 | # # On the dataset Vaihingen 9 | # # data_dir = '/Data/ndionelis/mainva2' 10 | 11 | # # This is to test: torch.save(model.state_dict(), './segformermain30082023') 12 | # # Also: data_dir = '../../CVUSA/bingmap/mainTheDataset' 13 | 14 | # # pip install pytorch_ssim 15 | 16 | import sys 17 | #get_ipython().system('{sys.executable} -m pip install pytorch_ssim') 18 | #get_ipython().system('{sys.executable} -m pip install pytorch_msssim') 19 | 20 | # #!{sys.executable} -m conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia 21 | # #!{sys.executable} -m pip install timm 22 | 23 | # import numpy as np 24 | # from skimage import io 25 | # from glob import glob 26 | # from tqdm import tqdm_notebook as tqdm 27 | # from sklearn.metrics import confusion_matrix 28 | # import random 29 | # import itertools 30 | # # # Matplotlib 31 | # import matplotlib.pyplot as plt 32 | # #get_ipython().run_line_magic('matplotlib', 'inline') 33 | # #from IPython import get_ipython 34 | # #get_ipython().run_line_magic('matplotlib', 'inline') 35 | # #exec(%matplotlib inline) 36 | # # # Torch imports 37 | # import torch 38 | # import torch.nn as nn 39 | # import torch.nn.functional as F 40 | # import torch.utils.data as data 41 | # import torch.optim as optim 42 | # import torch.optim.lr_scheduler 43 | # import torch.nn.init 44 | # from torch.autograd import Variable 45 | # import torchvision.transforms as T 46 | # import albumentations as A 47 | # import segmentation_models_pytorch as smp 48 | # import kornia 49 | 50 | # WINDOW_SIZE = (256, 256) # Patch size 51 | # STRIDE = 32 # # # # Stride for testing 52 | # IN_CHANNELS = 3 # Number of input channels (e.g. 
RGB) 53 | # #FOLDER = "./ISPRS_dataset/" # # Replace with your "/path/to/the/ISPRS/dataset/folder/" 54 | # #FOLDER = "../" 55 | # FOLDER = "/Data/ndionelis/" 56 | # #BATCH_SIZE = 10 # Number of samples in a mini-batch 57 | # #BATCH_SIZE = 64 58 | # #BATCH_SIZE = 10 59 | # BATCH_SIZE = 10 60 | 61 | # LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # # Label names 62 | # N_CLASSES = len(LABELS) # # Number of classes 63 | # #print(N_CLASSES) 64 | 65 | # WEIGHTS = torch.ones(N_CLASSES) # # # Weights for class balancing 66 | # CACHE = True # # Store the dataset in-memory 67 | 68 | # #DATASET = 'Vaihingen' 69 | # DATASET = 'Potsdam' 70 | 71 | # DATASET2 = 'Vaihingen' 72 | 73 | # if DATASET == 'Potsdam': 74 | # MAIN_FOLDER = FOLDER + 'Potsdam/' 75 | # # Uncomment the next line for IRRG data 76 | # # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 77 | # # # For RGB data 78 | # #print(MAIN_FOLDER) 79 | # #sadfszf 80 | # #print(MAIN_FOLDER) 81 | # #asdfklsz 82 | # DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 83 | # LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 84 | # ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 85 | 86 | # elif DATASET == 'Vaihingen': 87 | # MAIN_FOLDER = FOLDER + 'Vaihingen/' 88 | # #print(MAIN_FOLDER) 89 | # #asdfszdf 90 | # DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 91 | # LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 92 | # ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 93 | 94 | # # imports and stuff 95 | import numpy as np 96 | from skimage import io 97 | from glob import glob 98 | from tqdm import tqdm_notebook as tqdm 99 | from sklearn.metrics import confusion_matrix 100 | import random 101 | import itertools 102 | # # Matplotlib 103 | import matplotlib.pyplot as plt 104 | #get_ipython().run_line_magic('matplotlib', 'inline') 105 | # # Torch imports 106 | import torch 107 | import torch.nn as nn 108 | import torch.nn.functional as F 109 | import torch.utils.data as data 110 | import torch.optim as optim 111 | import torch.optim.lr_scheduler 112 | import torch.nn.init 113 | from torch.autograd import Variable 114 | import torchvision.transforms as T 115 | import albumentations as A 116 | import segmentation_models_pytorch as smp 117 | import kornia 118 | 119 | import time 120 | 121 | WINDOW_SIZE = (220, 220) # Patch size 122 | #WINDOW_SIZE = (512, 512) # Patch size 123 | STRIDE = 32 # # # # # Stride for testing 124 | IN_CHANNELS = 3 # Number of input channels (e.g. 
RGB) 125 | #FOLDER = "./ISPRS_dataset/" # Replace with your "/path/to/the/ISPRS/dataset/folder/" 126 | #FOLDER = "../../" 127 | #FOLDER = "/Data/ndionelis/" 128 | FOLDER = '../' 129 | #BATCH_SIZE = 10 # # Number of samples in a mini-batch 130 | #BATCH_SIZE = 64 131 | #BATCH_SIZE = 128 132 | #BATCH_SIZE = 256 133 | BATCH_SIZE = 10 134 | #BATCH_SIZE = 30 135 | 136 | LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # # Label names 137 | N_CLASSES = len(LABELS) # # Number of classes 138 | #print(N_CLASSES) 139 | 140 | WEIGHTS = torch.ones(N_CLASSES) # # # Weights for class balancing 141 | CACHE = True # # Store the dataset in-memory 142 | 143 | DATASET = 'Vaihingen' 144 | #DATASET = 'Potsdam' 145 | 146 | if DATASET == 'Potsdam': 147 | MAIN_FOLDER = FOLDER + 'Potsdam/' 148 | # Uncomment the next line for IRRG data 149 | # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 150 | # For RGB data 151 | DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 152 | LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 153 | ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 154 | elif DATASET == 'Vaihingen': 155 | MAIN_FOLDER = FOLDER + 'Vaihingen/' 156 | DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 157 | LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 158 | ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 159 | 160 | palette = {0 : (255, 255, 255), # Impervious surfaces (white) 161 | 1 : (0, 0, 255), # # Buildings (blue) 162 | 2 : (0, 255, 255), # Low vegetation (cyan) 163 | 3 : (0, 255, 0), # Trees (green) 164 | 4 : (255, 255, 0), # Cars (yellow) 165 | 5 : (255, 0, 0), # Clutter (red) 166 | 6 : (0, 0, 0)} # Undefined (black) 167 | invert_palette = {v: k for k, v in palette.items()} 168 | 169 | def ade_palette(): 170 | """ADE20K palette that maps each class to RGB values. 
""" 171 | return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 172 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 173 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 174 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 175 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 176 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 177 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 178 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 179 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 180 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 181 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 182 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 183 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 184 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 185 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 186 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 187 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 188 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 189 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 190 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 191 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 192 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 193 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 194 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 195 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 196 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 197 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 198 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 199 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], 200 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 201 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 202 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 203 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 204 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 205 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 206 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 207 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 208 | [102, 255, 0], [92, 0, 255]] 209 | 210 | def convert_to_color(arr_2d, palette=palette): 211 | arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8) 212 | for c, i in palette.items(): 213 | m = arr_2d == c 214 | arr_3d[m] = i 215 | return arr_3d 216 | 217 | def convert_from_color(arr_3d, palette=invert_palette): 218 | arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) 219 | for c, i in palette.items(): 220 | m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) 221 | arr_2d[m] = i 222 | return arr_2d 223 | 224 | # img = io.imread('/Data/ndionelis/Vaihingen/top/top_mosaic_09cm_area11.tif') 225 | # fig = plt.figure() 226 | # fig.add_subplot(121) 227 | # plt.imshow(img) 228 | 229 | # gt = io.imread('/Data/ndionelis/Vaihingen/gts_for_participants/top_mosaic_09cm_area11.tif') 230 | # fig.add_subplot(122) 231 | # plt.imshow(gt) 232 | # plt.show() 233 | 234 | # array_gt = convert_from_color(gt) 235 | # print("Ground truth in numerical format has shape ({},{}) : \n".format(*array_gt.shape[:2]), array_gt) 236 | 237 | def get_random_pos(img, window_shape): 238 | w, h = 
window_shape 239 | W, H = img.shape[-2:] 240 | x1 = random.randint(0, W - w - 1) 241 | x2 = x1 + w 242 | y1 = random.randint(0, H - h - 1) 243 | y2 = y1 + h 244 | return x1, x2, y1, y2 245 | 246 | def CrossEntropy2d(input, target, weight=None, size_average=True): 247 | dim = input.dim() 248 | if dim == 2: 249 | return F.cross_entropy(input, target, weight, size_average) 250 | elif dim == 4: 251 | output = input.view(input.size(0),input.size(1), -1) 252 | output = torch.transpose(output,1,2).contiguous() 253 | output = output.view(-1,output.size(2)) 254 | target = target.view(-1) 255 | return F.cross_entropy(output, target,weight, size_average) 256 | else: 257 | raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim)) 258 | 259 | def accuracy(input, target): 260 | return 100 * float(np.count_nonzero(input == target)) / target.size 261 | 262 | def sliding_window(top, step=10, window_size=(20,20)): 263 | for x in range(0, top.shape[0], step): 264 | if x + window_size[0] > top.shape[0]: 265 | x = top.shape[0] - window_size[0] 266 | for y in range(0, top.shape[1], step): 267 | if y + window_size[1] > top.shape[1]: 268 | y = top.shape[1] - window_size[1] 269 | yield x, y, window_size[0], window_size[1] 270 | 271 | def count_sliding_window(top, step=10, window_size=(20,20)): 272 | c = 0 273 | for x in range(0, top.shape[0], step): 274 | if x + window_size[0] > top.shape[0]: 275 | x = top.shape[0] - window_size[0] 276 | for y in range(0, top.shape[1], step): 277 | if y + window_size[1] > top.shape[1]: 278 | y = top.shape[1] - window_size[1] 279 | c += 1 280 | return c 281 | 282 | def grouper(n, iterable): 283 | it = iter(iterable) 284 | while True: 285 | chunk = tuple(itertools.islice(it, n)) 286 | if not chunk: 287 | return 288 | yield chunk 289 | 290 | def metrics(predictions, gts, label_values=LABELS): 291 | cm = confusion_matrix( 292 | gts, 293 | predictions, 294 | range(len(label_values))) 295 | 296 | print("Confusion matrix :") 297 | print(cm) 298 | 299 | total = sum(sum(cm)) 300 | accuracy = sum([cm[x][x] for x in range(len(cm))]) 301 | accuracy *= 100 / float(total) 302 | print("{} pixels processed".format(total)) 303 | print("Total accuracy : {}%".format(accuracy)) 304 | 305 | F1Score = np.zeros(len(label_values)) 306 | for i in range(len(label_values)): 307 | try: 308 | F1Score[i] = 2. 
* cm[i,i] / (np.sum(cm[i,:]) + np.sum(cm[:,i])) 309 | except: 310 | pass 311 | print("F1Score :") 312 | for l_id, score in enumerate(F1Score): 313 | print("{}: {}".format(label_values[l_id], score)) 314 | 315 | total = np.sum(cm) 316 | pa = np.trace(cm) / float(total) 317 | pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total*total) 318 | kappa = (pa - pe) / (1 - pe); 319 | print("Kappa: " + str(kappa)) 320 | 321 | return accuracy 322 | 323 | # The Dataset class 324 | class ISPRS_dataset(torch.utils.data.Dataset): 325 | def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER, 326 | cache=False, augmentation=True): 327 | super(ISPRS_dataset, self).__init__() 328 | self.augmentation = augmentation 329 | self.cache = cache 330 | self.data_files = [DATA_FOLDER.format(id) for id in ids] 331 | self.label_files = [LABEL_FOLDER.format(id) for id in ids] 332 | for f in self.data_files + self.label_files: 333 | if not os.path.isfile(f): 334 | raise KeyError('{} is not a file !'.format(f)) 335 | self.data_cache_ = {} 336 | self.label_cache_ = {} 337 | 338 | def __len__(self): 339 | return 10000 340 | 341 | @classmethod 342 | def data_augmentation(cls, *arrays, flip=True, mirror=True): 343 | will_flip, will_mirror = False, False 344 | #will_rotate = False 345 | #will_rotate2 = False 346 | if flip and random.random() < 0.5: 347 | will_flip = True 348 | if mirror and random.random() < 0.5: 349 | will_mirror = True 350 | 351 | results = [] 352 | for array in arrays: 353 | if will_flip: 354 | if len(array.shape) == 2: 355 | array = array[::-1, :] 356 | else: 357 | array = array[:, ::-1, :] 358 | if will_mirror: 359 | if len(array.shape) == 2: 360 | array = array[:, ::-1] 361 | else: 362 | array = array[:, :, ::-1] 363 | 364 | results.append(np.copy(array)) 365 | 366 | return tuple(results) 367 | 368 | def __getitem__(self, i): 369 | random_idx = random.randint(0, len(self.data_files) - 1) 370 | if random_idx in self.data_cache_.keys(): 371 | data = self.data_cache_[random_idx] 372 | else: 373 | data = 1/255 * np.asarray(io.imread(self.data_files[random_idx]).transpose((2,0,1)), dtype='float32') 374 | if self.cache: 375 | self.data_cache_[random_idx] = data 376 | 377 | if random_idx in self.label_cache_.keys(): 378 | label = self.label_cache_[random_idx] 379 | else: 380 | label = np.asarray(convert_from_color(io.imread(self.label_files[random_idx])), dtype='int64') 381 | if self.cache: 382 | self.label_cache_[random_idx] = label 383 | 384 | x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE) 385 | data_p = data[:, x1:x2,y1:y2] 386 | label_p = label[x1:x2,y1:y2] 387 | 388 | data_p, label_p = self.data_augmentation(data_p, label_p) 389 | 390 | return (torch.from_numpy(data_p), 391 | torch.from_numpy(label_p)) 392 | 393 | 394 | from torch.autograd import Function 395 | #import torch.autograd 396 | 397 | class ReverseLayerF(Function): 398 | @staticmethod 399 | def forward(ctx, x, alpha): 400 | ctx.alpha = alpha 401 | 402 | return x.view_as(x) 403 | 404 | @staticmethod 405 | def backward(ctx, grad_output): 406 | output = grad_output.neg() * ctx.alpha 407 | 408 | return output, None 409 | 410 | class SegNet(nn.Module): 411 | # # SegNet network 412 | @staticmethod 413 | def weight_init(m): 414 | if isinstance(m, nn.Linear): 415 | torch.nn.init.kaiming_normal(m.weight.data) 416 | 417 | def __init__(self, in_channels=IN_CHANNELS, out_channels=N_CLASSES): 418 | super(SegNet, self).__init__() 419 | self.pool = nn.MaxPool2d(2, return_indices=True) 420 | self.unpool = nn.MaxUnpool2d(2) 421 
| 422 | self.conv1_1 = nn.Conv2d(in_channels, 64, 3, padding=1) 423 | self.conv1_1_bn = nn.BatchNorm2d(64) 424 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 425 | self.conv1_2_bn = nn.BatchNorm2d(64) 426 | 427 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 428 | self.conv2_1_bn = nn.BatchNorm2d(128) 429 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 430 | self.conv2_2_bn = nn.BatchNorm2d(128) 431 | 432 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 433 | self.conv3_1_bn = nn.BatchNorm2d(256) 434 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 435 | self.conv3_2_bn = nn.BatchNorm2d(256) 436 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 437 | self.conv3_3_bn = nn.BatchNorm2d(256) 438 | 439 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 440 | self.conv4_1_bn = nn.BatchNorm2d(512) 441 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 442 | self.conv4_2_bn = nn.BatchNorm2d(512) 443 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 444 | self.conv4_3_bn = nn.BatchNorm2d(512) 445 | 446 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 447 | self.conv5_1_bn = nn.BatchNorm2d(512) 448 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 449 | self.conv5_2_bn = nn.BatchNorm2d(512) 450 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 451 | self.conv5_3_bn = nn.BatchNorm2d(512) 452 | 453 | #self.ll00 = nn.Linear(1024, 2) 454 | self.ll00 = nn.Linear(512, 3) 455 | 456 | self.conv5_3_D = nn.Conv2d(512, 512, 3, padding=1) 457 | self.conv5_3_D_bn = nn.BatchNorm2d(512) 458 | self.conv5_2_D = nn.Conv2d(512, 512, 3, padding=1) 459 | self.conv5_2_D_bn = nn.BatchNorm2d(512) 460 | self.conv5_1_D = nn.Conv2d(512, 512, 3, padding=1) 461 | self.conv5_1_D_bn = nn.BatchNorm2d(512) 462 | 463 | self.conv4_3_D = nn.Conv2d(512, 512, 3, padding=1) 464 | self.conv4_3_D_bn = nn.BatchNorm2d(512) 465 | self.conv4_2_D = nn.Conv2d(512, 512, 3, padding=1) 466 | self.conv4_2_D_bn = nn.BatchNorm2d(512) 467 | self.conv4_1_D = nn.Conv2d(512, 256, 3, padding=1) 468 | self.conv4_1_D_bn = nn.BatchNorm2d(256) 469 | 470 | self.conv3_3_D = nn.Conv2d(256, 256, 3, padding=1) 471 | self.conv3_3_D_bn = nn.BatchNorm2d(256) 472 | self.conv3_2_D = nn.Conv2d(256, 256, 3, padding=1) 473 | self.conv3_2_D_bn = nn.BatchNorm2d(256) 474 | self.conv3_1_D = nn.Conv2d(256, 128, 3, padding=1) 475 | self.conv3_1_D_bn = nn.BatchNorm2d(128) 476 | 477 | self.conv2_2_D = nn.Conv2d(128, 128, 3, padding=1) 478 | self.conv2_2_D_bn = nn.BatchNorm2d(128) 479 | self.conv2_1_D = nn.Conv2d(128, 64, 3, padding=1) 480 | self.conv2_1_D_bn = nn.BatchNorm2d(64) 481 | 482 | self.conv1_2_D = nn.Conv2d(64, 64, 3, padding=1) 483 | self.conv1_2_D_bn = nn.BatchNorm2d(64) 484 | self.conv1_1_D = nn.Conv2d(64, out_channels, 3, padding=1) 485 | 486 | self.apply(self.weight_init) 487 | 488 | def forward(self, x, alpha): 489 | x = self.conv1_1_bn(F.relu(self.conv1_1(x))) 490 | x = self.conv1_2_bn(F.relu(self.conv1_2(x))) 491 | x, mask1 = self.pool(x) 492 | 493 | # Encoder block 2 494 | x = self.conv2_1_bn(F.relu(self.conv2_1(x))) 495 | x = self.conv2_2_bn(F.relu(self.conv2_2(x))) 496 | x, mask2 = self.pool(x) 497 | 498 | # Encoder block 3 499 | x = self.conv3_1_bn(F.relu(self.conv3_1(x))) 500 | x = self.conv3_2_bn(F.relu(self.conv3_2(x))) 501 | x = self.conv3_3_bn(F.relu(self.conv3_3(x))) 502 | x, mask3 = self.pool(x) 503 | 504 | # Encoder block 4 505 | x = self.conv4_1_bn(F.relu(self.conv4_1(x))) 506 | x = self.conv4_2_bn(F.relu(self.conv4_2(x))) 507 | x = self.conv4_3_bn(F.relu(self.conv4_3(x))) 508 | x, mask4 = self.pool(x) 509 | 510 | # Encoder block 5 
511 | x = self.conv5_1_bn(F.relu(self.conv5_1(x))) 512 | x = self.conv5_2_bn(F.relu(self.conv5_2(x))) 513 | x = self.conv5_3_bn(F.relu(self.conv5_3(x))) 514 | x, mask5 = self.pool(x) 515 | 516 | #print(x) 517 | #print(x.shape) 518 | 519 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 520 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(2*BATCH_SIZE, -1) 521 | xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 522 | 523 | #xx22 = self.ll00(xx22) 524 | 525 | xx22 = ReverseLayerF.apply(xx22, alpha) 526 | 527 | xx22 = self.ll00(xx22) 528 | 529 | xx22 = F.softmax(xx22) 530 | 531 | x = self.unpool(x, mask5) 532 | x = self.conv5_3_D_bn(F.relu(self.conv5_3_D(x))) 533 | x = self.conv5_2_D_bn(F.relu(self.conv5_2_D(x))) 534 | x = self.conv5_1_D_bn(F.relu(self.conv5_1_D(x))) 535 | 536 | # Decoder block 4 537 | x = self.unpool(x, mask4) 538 | x = self.conv4_3_D_bn(F.relu(self.conv4_3_D(x))) 539 | x = self.conv4_2_D_bn(F.relu(self.conv4_2_D(x))) 540 | x = self.conv4_1_D_bn(F.relu(self.conv4_1_D(x))) 541 | 542 | # Decoder block 3 543 | x = self.unpool(x, mask3) 544 | x = self.conv3_3_D_bn(F.relu(self.conv3_3_D(x))) 545 | x = self.conv3_2_D_bn(F.relu(self.conv3_2_D(x))) 546 | x = self.conv3_1_D_bn(F.relu(self.conv3_1_D(x))) 547 | 548 | # # # Decoder block 2 549 | x = self.unpool(x, mask2) 550 | x = self.conv2_2_D_bn(F.relu(self.conv2_2_D(x))) 551 | x = self.conv2_1_D_bn(F.relu(self.conv2_1_D(x))) 552 | 553 | # Decoder block 1 554 | x = self.unpool(x, mask1) 555 | x = self.conv1_2_D_bn(F.relu(self.conv1_2_D(x))) 556 | x = F.log_softmax(self.conv1_1_D(x)) 557 | #return x 558 | return x, xx22 559 | 560 | def forward2(self, x): 561 | x = self.conv1_1_bn(F.relu(self.conv1_1(x))) 562 | x = self.conv1_2_bn(F.relu(self.conv1_2(x))) 563 | x, mask1 = self.pool(x) 564 | 565 | # Encoder block 2 566 | x = self.conv2_1_bn(F.relu(self.conv2_1(x))) 567 | x = self.conv2_2_bn(F.relu(self.conv2_2(x))) 568 | x, mask2 = self.pool(x) 569 | 570 | # Encoder block 3 571 | x = self.conv3_1_bn(F.relu(self.conv3_1(x))) 572 | x = self.conv3_2_bn(F.relu(self.conv3_2(x))) 573 | x = self.conv3_3_bn(F.relu(self.conv3_3(x))) 574 | x, mask3 = self.pool(x) 575 | 576 | # Encoder block 4 577 | x = self.conv4_1_bn(F.relu(self.conv4_1(x))) 578 | x = self.conv4_2_bn(F.relu(self.conv4_2(x))) 579 | x = self.conv4_3_bn(F.relu(self.conv4_3(x))) 580 | x, mask4 = self.pool(x) 581 | 582 | # Encoder block 5 583 | x = self.conv5_1_bn(F.relu(self.conv5_1(x))) 584 | x = self.conv5_2_bn(F.relu(self.conv5_2(x))) 585 | x = self.conv5_3_bn(F.relu(self.conv5_3(x))) 586 | x, mask5 = self.pool(x) 587 | 588 | #print(x) 589 | #print(x.shape) 590 | 591 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 592 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(2*BATCH_SIZE, -1) 593 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 594 | 595 | #xx22 = self.ll00(xx22) 596 | 597 | #xx22 = ReverseLayerF.apply(xx22, alpha) 598 | 599 | #xx22 = self.ll00(xx22) 600 | 601 | #xx22 = F.softmax(xx22) 602 | 603 | x = self.unpool(x, mask5) 604 | x = self.conv5_3_D_bn(F.relu(self.conv5_3_D(x))) 605 | x = self.conv5_2_D_bn(F.relu(self.conv5_2_D(x))) 606 | x = self.conv5_1_D_bn(F.relu(self.conv5_1_D(x))) 607 | 608 | # Decoder block 4 609 | x = self.unpool(x, mask4) 610 | x = self.conv4_3_D_bn(F.relu(self.conv4_3_D(x))) 611 | x = self.conv4_2_D_bn(F.relu(self.conv4_2_D(x))) 612 | x = self.conv4_1_D_bn(F.relu(self.conv4_1_D(x))) 613 | 614 | # Decoder block 3 615 | x = self.unpool(x, mask3) 616 | x = self.conv3_3_D_bn(F.relu(self.conv3_3_D(x))) 617 | x 
= self.conv3_2_D_bn(F.relu(self.conv3_2_D(x))) 618 | x = self.conv3_1_D_bn(F.relu(self.conv3_1_D(x))) 619 | 620 | # # # Decoder block 2 621 | x = self.unpool(x, mask2) 622 | x = self.conv2_2_D_bn(F.relu(self.conv2_2_D(x))) 623 | x = self.conv2_1_D_bn(F.relu(self.conv2_1_D(x))) 624 | 625 | # Decoder block 1 626 | x = self.unpool(x, mask1) 627 | x = self.conv1_2_D_bn(F.relu(self.conv1_2_D(x))) 628 | x = F.log_softmax(self.conv1_1_D(x)) 629 | return x 630 | #return x, xx22 631 | 632 | net = SegNet() 633 | 634 | import os 635 | try: 636 | from urllib.request import URLopener 637 | except ImportError: 638 | from urllib import URLopener 639 | 640 | # # # Download VGG-16 weights from PyTorch 641 | vgg_url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth' 642 | if not os.path.isfile('./vgg16_bn-6c64b313.pth'): 643 | weights = URLopener().retrieve(vgg_url, './vgg16_bn-6c64b313.pth') 644 | 645 | vgg16_weights = torch.load('./vgg16_bn-6c64b313.pth') 646 | mapped_weights = {} 647 | for k_vgg, k_segnet in zip(vgg16_weights.keys(), net.state_dict().keys()): 648 | if "features" in k_vgg: 649 | mapped_weights[k_segnet] = vgg16_weights[k_vgg] 650 | print("Mapping {} to {}".format(k_vgg, k_segnet)) 651 | 652 | try: 653 | net.load_state_dict(mapped_weights) 654 | print("Loaded VGG-16 weights in SegNet !") 655 | except: 656 | # Ignore missing keys 657 | #pass 658 | pass 659 | 660 | #net.cuda() 661 | 662 | # The model SegFormer 663 | 664 | # # This is to test: torch.save(model.state_dict(), './segformermain30082023') 665 | # # Also: data_dir = '../../CVUSA/bingmap/mainTheDataset' 666 | 667 | from transformers import SegformerForSemanticSegmentation 668 | #import transformers 669 | 670 | # id2label = {0 : 'Impervious surfaces', 671 | # 1 : 'Buildings', 672 | # 2 : 'Low vegetation', 673 | # 3 : 'Trees', 674 | # 4 : 'Cars', 675 | # 5 : 'Clutter'} 676 | # label2id = {v: k for k, v in id2label.items()} 677 | 678 | #from transformers import SegformerForSemanticSegmentation 679 | import json 680 | from huggingface_hub import cached_download, hf_hub_url 681 | 682 | # # load id2label mapping from a JSON on the hub 683 | #repo_id = "datasets/huggingface/label-files" 684 | #filename = "./ade20k-id2label.json" 685 | filename = "./ade20kid2label.json" 686 | #id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) 687 | id2label = json.load(open(filename, "r")) 688 | id2label = {int(k): v for k, v in id2label.items()} 689 | label2id = {v: k for k, v in id2label.items()} 690 | 691 | net = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", 692 | num_labels=6, 693 | id2label=id2label, 694 | label2id=label2id, 695 | alpha=0.0, 696 | ) 697 | 698 | # net = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b5-finetuned-cityscapes-1024-1024", 699 | # num_labels=19, 700 | # id2label=id2label, 701 | # label2id=label2id, 702 | # alpha=0.0, 703 | # ) 704 | 705 | net.cuda() 706 | 707 | # #base_lr = 0.01 708 | # #base_lr = 0.04 709 | # #base_lr = 0.08 710 | # base_lr = 0.01 711 | 712 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 713 | 714 | N_CLASSES = 6 715 | #print(N_CLASSES) 716 | 717 | def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE): 718 | # # # # # # /home/nikolaos/CVUSA/bingmap 719 | 720 | # #print(test_images) 721 | # #print(test_images.size) 722 | 723 | all_preds = [] 724 | all_gts = [] 725 | 726 | # # # Switch the network to inference mode 727 | #net.eval() 728 | 729 | net.eval() 
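    # The loop below performs sliding-window inference over every image found in
    # `data_dir`: overlapping patches of size `window_size` are extracted with step
    # `stride`, grouped into batches, and passed through the SegFormer; the logits
    # are bilinearly upsampled back to the patch size, the per-class scores are
    # accumulated into `pred` at each patch position, and the final segmentation
    # is the per-pixel argmax over the accumulated scores.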
730 | 731 | for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False): 732 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/0000001.jpg'), dtype='float32')) 733 | 734 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/'+str(44514).zfill(7)+'.jpg'), dtype='float32')) 735 | 736 | # data_dir = '/Data/ndionelis/inputtss' 737 | 738 | countertotal = 0 739 | foraverage = 0 740 | foraverage2 = 0 741 | foraverage3 = 0 742 | 743 | #data_dir = '/Data/ndionelis/bingmap/main_dataset/val/1' 744 | 745 | #data_dir = '/Data/ndionelis/bingmap/maindatasetnewnew/val/1' 746 | 747 | #data_dir = '../mubsa54b2bc2h/val/1' 748 | #data_dir = '../mubsa54b2bc2a/val/1' 749 | #data_dir = '/home/ndionelis/mubsa54b2bc2a/val/1' 750 | #data_dir = '/Data/ndionelis/1' 751 | data_dir = '/Data/ndionelis/SegmentationNew/Inputs/InputImages' 752 | #data_dir = '/Data/ndionelis/Folder_CVUSA_Segmentation/Inputs/Inputs' 753 | 754 | # # torch.save(model.state_dict(), '/Data/ndionelis/formodels/sgfrmr2ehwh') 755 | # # data_dir = '/Data/ndionelis/bingmap/maindatasetnewnew2' 756 | # # # Potsdam 757 | # # data_dir = '/Data/ndionelis/mainpo2' 758 | # # On the dataset Vaihingen 759 | # # data_dir = '/Data/ndionelis/mainva2' 760 | 761 | # # This is to test: torch.save(model.state_dict(), './segformermain30082023') 762 | # # Also: data_dir = '../../CVUSA/bingmap/mainTheDataset' 763 | 764 | for file in os.listdir(data_dir): 765 | img = (1 / 255 * np.asarray(io.imread(os.path.join(data_dir, file)), dtype='float32')) 766 | 767 | countertotal += 1 768 | counterr11 = file 769 | 770 | print(counterr11) 771 | print(countertotal) 772 | 773 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/'+str(44515).zfill(7)+'.jpg'), dtype='float32')) 774 | 775 | pred = np.zeros(img.shape[:2] + (N_CLASSES,)) 776 | 777 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/0000019.jpg'), dtype='float32')) 778 | 779 | total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size 780 | for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total, leave=False)): 781 | # # # # # Display in progress results 782 | if i > 0 and total > 10 and i % int(10 * total / 100) == 0: 783 | _pred = np.argmax(pred, axis=-1) 784 | # fig = plt.figure() 785 | # #fig.add_subplot(1,3,1) 786 | # fig.add_subplot(1,2,1) 787 | # plt.imshow(np.asarray(255 * img, dtype='uint8')) 788 | # plt.axis('off') 789 | # #fig.add_subplot(1,3,2) 790 | # fig.add_subplot(1,2,2) 791 | # plt.imshow(convert_to_color(_pred)) 792 | # plt.axis('off') 793 | # #fig.add_subplot(1,3,3) 794 | # #plt.imshow(gt) 795 | # #plt.axis('off') 796 | # #clear_output() 797 | # plt.show() 798 | 799 | # # # # # Build the tensor 800 | #image_patches = [np.copy(img[x:x+w, y:y+h]).transpose((2,0,1)) for x,y,w,h in coords] 801 | import copy 802 | image_patches = [copy.deepcopy(img[x:x+w, y:y+h]).transpose((2,0,1)) for x,y,w,h in coords] 803 | 804 | image_patches = np.asarray(image_patches) 805 | #image_patches = Variable(torch.from_numpy(image_patches).cuda(), volatile=True) 806 | with torch.no_grad(): 807 | #image_patches = Variable(torch.from_numpy(image_patches).cuda(), volatile=True) 808 | image_patches = Variable(torch.from_numpy(image_patches).cuda()) 809 | 810 | # # /home/nikolaos/CVUSA/bingmsap 811 | 812 | # # Do the inference 813 | #outs = net(image_patches) 814 | #outs, _ = net(image_patches, 0.01) 815 | 
#outs = net.forward2(image_patches) 816 | #outs = outs.data.cpu().numpy() 817 | 818 | # # prepare the image for the model 819 | #encoding = feature_extractor(image_patches, return_tensors="pt") 820 | #pixel_values = encoding.pixel_values.to(device) 821 | 822 | #encoding = feature_extractor(image_patches, return_tensors="pt") 823 | #pixel_values = encoding.pixel_values.to(device) 824 | 825 | pixel_values = image_patches.to(device) 826 | 827 | #outputs = model(pixel_values=pixel_values) 828 | 829 | outputs = net(pixel_values=pixel_values) 830 | 831 | logits = outputs.logits.cpu() 832 | upsampled_logits = nn.functional.interpolate(logits, 833 | size=image_patches.shape[3], # # (height, width) 834 | mode='bilinear', 835 | align_corners=False) 836 | #seg = upsampled_logits.argmax(dim=1)[0] 837 | outs = upsampled_logits.detach().cpu().numpy() 838 | 839 | # # /home/nikolaos/CVUSA/bingmsap 840 | 841 | # # # # # Do the inference 842 | #outs = net(image_patches) 843 | #outs, _ = net(image_patches, 0.01) 844 | #outs = net.forward2(image_patches) 845 | 846 | #outs = net.forward2(image_patches) 847 | 848 | # For the Transformer 849 | #outs = outs.data.cpu().numpy() 850 | 851 | # # # Fill in the results array 852 | for out, (x, y, w, h) in zip(outs, coords): 853 | out = out.transpose((1,2,0)) 854 | pred[x:x+w, y:y+h] += out 855 | del(outs) 856 | 857 | #break 858 | 859 | pred = np.argmax(pred, axis=-1) 860 | 861 | #import pdb; pdb.set_trace() 862 | 863 | from samgeo.hq_sam import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff 864 | 865 | sam_kwargs = { 866 | "points_per_side": 32, 867 | "pred_iou_thresh": 0.86, 868 | "stability_score_thresh": 0.92, 869 | "crop_n_layers": 1, 870 | "crop_n_points_downscale_factor": 2, 871 | "min_mask_region_area": 100, 872 | } 873 | 874 | sam = SamGeo( 875 | model_type="vit_h", 876 | checkpoint="sam_vit_h_4b8939.pth", 877 | sam_kwargs=sam_kwargs, 878 | ) 879 | 880 | import cv2 881 | import matplotlib.pyplot as plt 882 | 883 | plt.figure() 884 | #plt.imshow(pixel_values[mainvarloop,:,:,:].permute(1, 2, 0).cpu().numpy()) 885 | plt.imshow(np.asarray(255 * img, dtype='uint8')) 886 | #img = (1 / 255 * np.asarray(io.imread(os.path.join(data_dir, '0044478.jpg')), dtype='float32')) 887 | #plt.imshow(np.asarray(255 * img, dtype='uint8')) 888 | plt.axis('off') 889 | 890 | # sam_eo = SamEO(checkpoint="sam_vit_h_4b8939.pth", 891 | # model_type='vit_h', 892 | # device=device, 893 | # erosion_kernel=(3, 3), 894 | # mask_multiplier=255, 895 | # sam_kwargs=None) 896 | 897 | # pred_tiff_path = 'pred.tiff' 898 | 899 | # #sam_eo.tiff_to_tiff(tms_tiff_path, pred_tiff_path) 900 | # sam_eo.tiff_to_tiff("SAMGeoSAMHQInp5B.png", pred_tiff_path) 901 | 902 | # pred_image = cv2.cvtColor(cv2.imread(pred_tiff_path), cv2.COLOR_BGR2RGB) 903 | 904 | # plt.figure() 905 | # f, axarr = plt.subplots(1,2) 906 | # axarr[0].imshow(image) 907 | # axarr[1].imshow(pred_image) 908 | # plt.show() 909 | 910 | #from torchvision.transforms import functional as F 911 | #F.to_pil_image(image_tensor) 912 | #sam.generate(F.to_pil_image(pixel_values[1,:,:,:].cpu())) 913 | #sam.generate(F.to_pil_image(pixel_values[1,:,:,:].cpu()), output="masks2.png", foreground=True, unique=True) 914 | sam.generate("SAMGeoSAMHQInp5B.png", output="masks2.png", foreground=True, unique=True) 915 | #sam.generate(np.array(pixel_values[1,:,:,:].permute(1, 2, 0).cpu().numpy().astype('uint8')), output="masks2.png", foreground=True, unique=True) 916 | #cv2_image = numpy_image, (1, 2, 0)) 917 | #cv2_image = cv2.cvtColor(cv2_image, 
cv2.COLOR_BGR2RGB) 918 | 919 | #sam.generate(cv2.cvtColor(pixel_values[1,:,:,:].permute(1, 2, 0).cpu().numpy().astype('uint8'), cv2.COLOR_BGR2RGB), output="masks2.png", foreground=True, unique=True) 920 | #((np.asarray(255 * pixel_values, dtype='uint8')) / 255.0) 921 | #pixel_values[0,:,:,:].permute(1, 2, 0).cpu().numpy().astype('uint8') 922 | #sam.generate(np.asarray(pixel_values[0,:,:,:].permute(1, 2, 0).cpu() / 255., dtype='uint8'), output="masks2.png", foreground=True, unique=True) 923 | 924 | #from segment_anything_hq import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor 925 | #mask_generator = SamAutomaticMaskGenerator(sam) # The automatic mask generator 926 | #masks = mask_generator.generate(pixel_values) # # Segment the input image 927 | 928 | #mainimage = sam.show_anns(axis="off", alpha=1, output="SAMGeoSAMHQInp5B.png") 929 | 930 | mainimage = np.asarray(255 * (1 / 255 * np.asarray(io.imread(os.path.join('/Data/ndionelis/SegmentationNew/Outputs/OutputImages', 'output'+file[5:])), dtype='float32')), dtype='uint8') 931 | #mainimage = np.asarray(255 * (1 / 255 * np.asarray(io.imread(os.path.join('/Data/ndionelis/Folder_CVUSA_Segmentation/Outputs/Outputs', 'output'+file[5:])), dtype='float32')), dtype='uint8') 932 | 933 | #pred = mainimage 934 | #mainimage = pred 935 | 936 | #mainimage = torch.zeros_like(torch.from_numpy(mainimage)).numpy() 937 | #pred = torch.zeros_like(torch.from_numpy(pred)).numpy() 938 | 939 | # # # # # Display the result 940 | #clear_output() 941 | fig = plt.figure() 942 | #fig.add_subplot(1,3,1) 943 | #fig.add_subplot(1,2,1) 944 | plt.imshow(np.asarray(255 * img, dtype='uint8')) 945 | plt.axis('off') 946 | #plt.show() 947 | #fig.add_subplot(1,3,2) 948 | #fig.add_subplot(1,2,2) 949 | #print(counterr11) 950 | #print(countertotal) 951 | #counterr11 = counterr11[:-4] 952 | #print(counterr11) 953 | #adfasdflzs 954 | # try: 955 | # #plt.savefig('./inputs/input'+str(counterr11)+'.png', bbox_inches='tight') 956 | # plt.savefig('./ii/input'+str(counterr11)+'.png', bbox_inches='tight') 957 | # except: 958 | # os.mkdir('ii') 959 | # plt.savefig('./ii/input'+str(counterr11)+'.png', bbox_inches='tight') 960 | #plt.savefig('inputim1.png', bbox_inches='tight') 961 | plt.savefig('/Data/ndionelis/theinput'+str(counterr11)+'.png', bbox_inches='tight') 962 | 963 | fig2 = plt.figure() 964 | plt.imshow(convert_to_color(pred)) 965 | plt.axis('off') 966 | #plt.show() 967 | #fig.add_subplot(1,3,3) 968 | #plt.imshow(gt) 969 | #plt.axis('off') 970 | #plt.show() 971 | #plt.savefig('./tesssttt2_tttlll2.png') 972 | #io.imsave('./tesssttt2_tttlll2.png') 973 | #plt.pause(10) 974 | # try: 975 | # #plt.savefig('./outputs/output'+str(counterr11)+'.png', bbox_inches='tight') 976 | # plt.savefig('./oo/output'+str(counterr11)+'.png', bbox_inches='tight') 977 | # except: 978 | # os.mkdir('oo') 979 | # plt.savefig('./oo/output'+str(counterr11)+'.png', bbox_inches='tight') 980 | plt.savefig('/Data/ndionelis/theoutput'+str(counterr11)+'.png', bbox_inches='tight') 981 | #plt.savefig('outputim1.png', bbox_inches='tight') 982 | 983 | fig2 = plt.figure() 984 | plt.imshow(mainimage) 985 | #plt.imshow(convert_to_color(pred)) 986 | plt.axis('off') 987 | #plt.show() 988 | #fig.add_subplot(1,3,3) 989 | #plt.imshow(gt) 990 | #plt.axis('off') 991 | #plt.show() 992 | #plt.savefig('./tesssttt2_tttlll2.png') 993 | #io.imsave('./tesssttt2_tttlll2.png') 994 | #plt.pause(10) 995 | # try: 996 | # #plt.savefig('./outputs/output'+str(counterr11)+'.png', bbox_inches='tight') 997 | # 
plt.savefig('./oo/output'+str(counterr11)+'.png', bbox_inches='tight') 998 | # except: 999 | # os.mkdir('oo') 1000 | # plt.savefig('./oo/output'+str(counterr11)+'.png', bbox_inches='tight') 1001 | plt.savefig('/Data/ndionelis/thetheoutput'+str(counterr11)+'.png', bbox_inches='tight') 1002 | #plt.savefig('outputim1.png', bbox_inches='tight') 1003 | 1004 | #os.chdir('/home/ndionelis/pytorch_unet/segment-anything-eo') 1005 | 1006 | #get_ipython().system('pip install rasterio') 1007 | #get_ipython().system('pip install geopandas') 1008 | 1009 | #get_ipython().system('wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth') 1010 | #get_ipython().system('wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth') 1011 | #get_ipython().system('wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth') 1012 | 1013 | import cv2 1014 | import matplotlib.pyplot as plt 1015 | from sameo import SamEO 1016 | 1017 | # ## Initialize SemEO class 1018 | 1019 | # # Availble SamEO arguments: 1020 | # checkpoint="sam_vit_h_4b8939.pth", 1021 | # model_type='vit_h', 1022 | # device='cpu', 1023 | # erosion_kernel=(3, 3), 1024 | # mask_multiplier=255, 1025 | # sam_kwargs=None 1026 | 1027 | # Availble sam_kwargs: 1028 | # points_per_side: Optional[int] = 32, 1029 | # points_per_batch: int = 64, 1030 | # pred_iou_thresh: float = 0.88, 1031 | # stability_score_thresh: float = 0.95, 1032 | # stability_score_offset: float = 1.0, 1033 | # box_nms_thresh: float = 0.7, 1034 | # crop_n_layers: int = 0, 1035 | # crop_nms_thresh: float = 0.7, 1036 | # crop_overlap_ratio: float = 512 / 1500, 1037 | # crop_n_points_downscale_factor: int = 1, 1038 | # point_grids: Optional[List[np.ndarray]] = None, 1039 | # min_mask_region_area: int = 0, 1040 | # output_mode: str = "binary_mask", 1041 | 1042 | #device = 'cuda:0' 1043 | 1044 | # sam_eo = SamEO(checkpoint="sam_vit_h_4b8939.pth", 1045 | # model_type='vit_h', 1046 | # device=device, 1047 | # erosion_kernel=(3, 3), 1048 | # mask_multiplier=255, 1049 | # sam_kwargs=None) 1050 | 1051 | # sam_eo = SamEO(checkpoint="sam_vit_h_4b8939.pth", 1052 | # model_type='vit_h', 1053 | # device=device, 1054 | # erosion_kernel=(3, 3), 1055 | # mask_multiplier=255, 1056 | # sam_kwargs=None) 1057 | 1058 | sam_eo = SamEO(checkpoint="/Data/ndionelis/sam_vit_h_4b8939.pth", 1059 | model_type='vit_h', 1060 | device=device, 1061 | erosion_kernel=(3, 3), 1062 | mask_multiplier=255, 1063 | sam_kwargs=None) 1064 | 1065 | # ## Download file from Openaerialmap and save it 1066 | tms_source = 'https://tiles.openaerialmap.org/642385491a8878000512126c/0/642385491a8878000512126d/{z}/{x}/{y}' 1067 | #pt1 = (29.676840, -95.369222) 1068 | #pt2 = (29.678559, -95.367314) 1069 | zoom = 20 1070 | #tms_tiff_path = 'test_tms_image.tif' 1071 | #tms_tiff_path = '/home/ndionelis/segmentanything/notebooks/Input2.png' 1072 | tms_tiff_path = '/home/ndionelis/Transformers-Tutorials/SegFormer/Inpuutt2.png' 1073 | 1074 | #image = sam_eo.download_tms_as_tiff(tms_source, pt1, pt2, zoom, tms_tiff_path) 1075 | tiff_image = cv2.cvtColor(cv2.imread(tms_tiff_path), cv2.COLOR_BGR2RGB) 1076 | 1077 | #plt.figure() 1078 | #f, axarr = plt.subplots(1,2) 1079 | #axarr[0].imshow(image) 1080 | #axarr[1].imshow(tiff_image) 1081 | #plt.show() 1082 | 1083 | pred_tiff_path = 'pred.tiff' 1084 | 1085 | #tms_tiff_path = 'test_tms_image.tif' 1086 | #tms_tiff_path = '/home/ndionelis/segmentanything/notebooks/Input1.tif' 1087 | #tms_tiff_path = '/home/ndionelis/segmentanything/notebooks/Input3.png' 1088 | 
#tms_tiff_path = '../../../segmentanything/notebooks/Input3.png' 1089 | #tms_tiff_path = '../segmentanything/notebooks/Inpu5.png' 1090 | #tms_tiff_path = '/home/ndionelis/segmentanything/notebooks/Inpu2.png' 1091 | sam_eo.tiff_to_tiff('/Data/ndionelis/theinput'+str(counterr11)+'.png', pred_tiff_path) 1092 | #sam_eo.tiff_to_tiff(tms_tiff_path, pred_tiff_path) 1093 | 1094 | pred_image = cv2.cvtColor(cv2.imread(pred_tiff_path), cv2.COLOR_BGR2RGB) 1095 | 1096 | plt.figure() 1097 | #f, axarr = plt.plot() 1098 | #f, axarr = plt.subplots(1,2) 1099 | #axarr[0].imshow(image) 1100 | #axarr[1].imshow(pred_image) 1101 | #plt.imshow(cv2.imread(tms_tiff_path)) 1102 | #plt.imshow(pred_image, alpha=0.5) 1103 | plt.imshow(pred_image) 1104 | #plt.show() 1105 | plt.axis('off') 1106 | 1107 | #plt.savefig('SAMGeospatial5.png') 1108 | #plt.savefig('/home/ndionelis/segment-anything-eo/SAMGeospatial5.png', bbox_inches='tight') 1109 | plt.savefig('/Data/ndionelis/SAMGeospatial5.png', bbox_inches='tight') 1110 | 1111 | sam_eo.tiff_to_tiff('/Data/ndionelis/thetheoutput'+str(counterr11)+'.png', pred_tiff_path) 1112 | #sam_eo.tiff_to_tiff(tms_tiff_path, pred_tiff_path) 1113 | 1114 | pred_image2 = pred_image 1115 | 1116 | pred_image = cv2.cvtColor(cv2.imread(pred_tiff_path), cv2.COLOR_BGR2RGB) 1117 | 1118 | plt.figure() 1119 | #f, axarr = plt.plot() 1120 | #f, axarr = plt.subplots(1,2) 1121 | #axarr[0].imshow(image) 1122 | #axarr[1].imshow(pred_image) 1123 | #plt.imshow(cv2.imread(tms_tiff_path)) 1124 | #plt.imshow(pred_image, alpha=0.5) 1125 | plt.imshow(pred_image) 1126 | #plt.show() 1127 | plt.axis('off') 1128 | 1129 | #plt.savefig('SAMGeospatial5.png') 1130 | #plt.savefig('/home/ndionelis/segment-anything-eo/SAMGeospatial5.png', bbox_inches='tight') 1131 | plt.savefig('/Data/ndionelis/SAMSAMGeospatial5.png', bbox_inches='tight') 1132 | 1133 | #pred_image = np.resize(pred_image, np.shape(pred_image2)) 1134 | pred_image2 = np.resize(pred_image2, np.shape(pred_image)) 1135 | 1136 | #criterion = nn.MSELoss() 1137 | #loss = torch.sqrt(criterion(pred_image2, pred_image)) 1138 | #loss = torch.sqrt(criterion(torch.from_numpy(pred_image).float(), torch.from_numpy(pred_image2).float())) 1139 | loss = torch.abs(torch.mean((torch.from_numpy(pred_image).float() - torch.from_numpy(pred_image2).float()))) 1140 | loss /= torch.abs(torch.mean((torch.zeros_like(torch.from_numpy(pred_image).float()) - torch.from_numpy(pred_image2).float()))) 1141 | #loss = torch.sqrt(torch.mean((torch.from_numpy(pred_image).float() - torch.from_numpy(pred_image2).float())**2)) 1142 | #loss /= torch.sqrt(torch.mean((torch.from_numpy(pred_image).float() - torch.from_numpy(pred_image2).float())**2)) 1143 | #loss /= torch.mean(torch.zeros_like(torch.from_numpy(pred_image).float()) - torch.from_numpy(pred_image2).float()) 1144 | #loss /= torch.sqrt(criterion(torch.zeros_like(torch.from_numpy(pred_image).float()), torch.from_numpy(pred_image2).float())) 1145 | #loss /= torch.sqrt(criterion(255*torch.ones_like(torch.from_numpy(pred_image).float()), torch.from_numpy(pred_image2).float())) 1146 | print(loss) 1147 | 1148 | #import pdb; pdb.set_trace() 1149 | 1150 | # # from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM 1151 | # # # X: (N,3,H,W) a batch of non-negative RGB images (0 ~ 255) 1152 | # # # Y: (N,3,H,W) 1153 | 1154 | # # ssim_val = ssim(torch.from_numpy(pred_image).float().unsqueeze(0), torch.from_numpy(pred_image2).float().unsqueeze(0), size_average=True) 1155 | # # #ssim_val = ssim(torch.from_numpy(pred_image2).float().unsqueeze(0), 
torch.from_numpy(pred_image).float().unsqueeze(0), data_range=255, size_average=False) 1156 | # # print(ssim_val.squeeze()) 1157 | 1158 | # # #ssim_val = ssim(torch.zeros_like(torch.from_numpy(pred_image).float()).unsqueeze(0), torch.from_numpy(pred_image2).float().unsqueeze(0), data_range=255, size_average=False) 1159 | # # #print(ssim_val.squeeze()) 1160 | 1161 | # # from CannyEdgePytorch.net_canny import Net 1162 | # # #import imageio 1163 | 1164 | # # def canny(raw_img, use_cuda=False): 1165 | # # img = torch.from_numpy(raw_img.transpose((2, 0, 1))) 1166 | # # batch = torch.stack([img]).float() 1167 | 1168 | # # net = Net(threshold=3.0, use_cuda=use_cuda) 1169 | # # if use_cuda: 1170 | # # net.cuda() 1171 | # # net.eval() 1172 | 1173 | # # data = Variable(batch) 1174 | # # if use_cuda: 1175 | # # data = Variable(batch).cuda() 1176 | 1177 | # # blurred_img, grad_mag, grad_orientation, thin_edges, thresholded, early_threshold = net(data) 1178 | 1179 | # # fig = plt.figure() 1180 | # # #fig.add_subplot(1,2,1) 1181 | # # #plt.imshow((thresholded.data.cpu().numpy()[0, 0] > 0.0).astype(float)) 1182 | # # # # grad_mag.data.cpu().numpy()[0,0] 1183 | # # #plt.imshow(grad_mag.data.cpu().numpy()[0,0]) 1184 | # # plt.imshow(early_threshold.data.cpu().numpy()[0, 0]) 1185 | # # plt.axis('off') 1186 | # # plt.show() 1187 | # # #plt.savefig('./theoutputnew'+str(counterr11)+'.png', bbox_inches='tight') 1188 | # # plt.savefig('/home/ndionelis/pytorch_unet/theoutoutputnew'+str(counterr11)+'.png', bbox_inches='tight') 1189 | 1190 | # # return early_threshold.data.cpu().numpy()[0, 0] 1191 | 1192 | # # #print(np.shape(raw_img)) 1193 | # # #print(np.shape((thresholded.data.cpu().numpy()[0, 0] > 0.0).astype(float))) 1194 | 1195 | # # #imageio.imwrite('gradient_magnitude.png',grad_mag.data.cpu().numpy()[0,0]) 1196 | # # #imageio.imwrite('thin_edges.png', thresholded.data.cpu().numpy()[0, 0]) 1197 | # # #imageio.imwrite('final.png', (thresholded.data.cpu().numpy()[0, 0] > 0.0).astype(float)) 1198 | # # #imageio.imwrite('thresholded.png', early_threshold.data.cpu().numpy()[0, 0]) 1199 | 1200 | # # #criterion = nn.MSELoss() 1201 | # # #loss = torch.sqrt(criterion(a, b)) 1202 | # # loss = torch.abs(torch.mean(a - b)) 1203 | # # #b = canny2(((255*torch.ones_like(torch.from_numpy(mainimage))).numpy() / 255.0), use_cuda=True) 1204 | # # #b = canny2(((torch.zeros_like(torch.from_numpy(mainimage))).numpy() / 255.0), use_cuda=True) 1205 | # # #b = np.resize(b, np.shape(a)) 1206 | # # #b = torch.from_numpy(b).float() 1207 | # # #loss /= torch.sqrt(criterion(a, b)) 1208 | # # #loss /= torch.abs(torch.mean(a - b)) 1209 | # # print(loss) 1210 | 1211 | # all_preds.append(pred) 1212 | # #all_gts.append(gt_e) 1213 | 1214 | # #from scipy.misc import imread, imsave 1215 | # #from scipy.misc import imsave 1216 | # #import torch 1217 | # #from torch.autograd import Variable 1218 | # from CannyEdgePytorch.net_canny import Net 1219 | # #import imageio 1220 | 1221 | # #from CannyEdgePytorch.canny import * 1222 | # #from CannyEdgePytorch.net_canny import * 1223 | 1224 | # import pytorch_ssim 1225 | # #import torch 1226 | # #from torch.autograd import Variable 1227 | 1228 | # #img1 = Variable(torch.rand(1, 1, 256, 256)) 1229 | # #img2 = Variable(torch.rand(1, 1, 256, 256)) 1230 | 1231 | # #if torch.cuda.is_available(): 1232 | # # img1 = img1.cuda() 1233 | # # img2 = img2.cuda() 1234 | 1235 | # criterion = nn.MSELoss() 1236 | # loss = torch.sqrt(criterion(a, b)) 1237 | # print(loss) 1238 | 1239 | foraverage += loss 1240 | 1241 | # from 
pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM 1242 | # # X: (N,3,H,W) a batch of non-negative RGB images (0~255) 1243 | # # Y: (N,3,H,W) 1244 | 1245 | # ssim_val = ssim(a, b, data_range=255, size_average=False) 1246 | # print(ssim_val.squeeze()) 1247 | 1248 | # foraverage2 += ssim_val.squeeze() 1249 | 1250 | # fig3 = plt.figure() 1251 | # plt.imshow(convert_to_color(pred)) 1252 | # plt.axis('off') 1253 | # #plt.show() 1254 | # #fig.add_subplot(1,3,3) 1255 | # #plt.imshow(gt) 1256 | # #plt.axis('off') 1257 | # #plt.show() 1258 | # #plt.savefig('./tesssttt2_tttlll2.png') 1259 | # #io.imsave('./tesssttt2_tttlll2.png') 1260 | # #plt.pause(10) 1261 | # #plt.figtext(0.5, 0.01, str(loss.item())+' '+str(ssim_val.squeeze().item())+' '+str(iou_pytorch(a, b).item()), wrap=True, horizontalalignment='center', fontsize=12) 1262 | # #"{:.2f}".format(round(a, 2)) 1263 | # plt.figtext(0.5, 0.01, "RMSE: "+str("{:.4f}".format(loss.item()))+', SSIM: '+str("{:.4f}".format(ssim_val.squeeze().item()))+', IoU: '+str("{:.4f}".format(iou_pytorch(a, b).item())), wrap=True, horizontalalignment='center', fontsize=12) 1264 | # # try: 1265 | # # plt.savefig('./oo/output'+str(counterr11)+'b.png', bbox_inches='tight') 1266 | # # except: 1267 | # # os.mkdir('oo') 1268 | # # plt.savefig('./oo/output'+str(counterr11)+'b.png', bbox_inches='tight') 1269 | # plt.savefig('./theoutput'+str(counterr11)+'b.png', bbox_inches='tight') 1270 | 1271 | plt.close('all') 1272 | 1273 | torch.save(countertotal, 'countertotal.pt') 1274 | 1275 | clear_output() 1276 | torch.cuda.empty_cache() 1277 | 1278 | #time.sleep(4) 1279 | time.sleep(3) 1280 | 1281 | #clear_output() 1282 | 1283 | #metrics(pred.ravel(), gt_e.ravel()) 1284 | #accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]), np.concatenate([p.ravel() for p in all_gts]).ravel()) 1285 | 1286 | #break 1287 | 1288 | foraverage /= countertotal 1289 | 1290 | print(foraverage) 1291 | 1292 | # print(foraverage) 1293 | # print(foraverage2) 1294 | # print(foraverage3) 1295 | 1296 | #criterion = nn.MSELoss() 1297 | #loss = torch.sqrt(criterion(a, b)) 1298 | #print(loss) 1299 | 1300 | #from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM 1301 | # X: (N,3,H,W) a batch of non-negative RGB images (0~255) 1302 | # Y: (N,3,H,W) 1303 | 1304 | #ssim_val = ssim(a, b, data_range=255, size_average=False) 1305 | #print(ssim_val.squeeze()) 1306 | 1307 | fig3 = plt.figure() 1308 | plt.imshow(convert_to_color(pred)) 1309 | plt.axis('off') 1310 | #plt.show() 1311 | #fig.add_subplot(1,3,3) 1312 | #plt.imshow(gt) 1313 | #plt.axis('off') 1314 | #plt.show() 1315 | #plt.savefig('./tesssttt2_tttlll2.png') 1316 | #io.imsave('./tesssttt2_tttlll2.png') 1317 | #plt.pause(10) 1318 | #plt.figtext(0.5, 0.01, str(loss.item())+' '+str(ssim_val.squeeze().item())+' '+str(iou_pytorch(a, b).item()), wrap=True, horizontalalignment='center', fontsize=12) 1319 | #"{:.2f}".format(round(a, 2)) 1320 | try: 1321 | plt.savefig('./outputtss/output'+str(counterr11)+'b.png', bbox_inches='tight') 1322 | except: 1323 | os.mkdir('outputtss') 1324 | plt.savefig('./outputtss/output'+str(counterr11)+'b.png', bbox_inches='tight') 1325 | 1326 | clear_output() 1327 | torch.cuda.empty_cache() 1328 | 1329 | #clear_output() 1330 | 1331 | #metrics(pred.ravel(), gt_e.ravel()) 1332 | #accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]), np.concatenate([p.ravel() for p in all_gts]).ravel()) 1333 | 1334 | #break 1335 | 1336 | if all: 1337 | return accuracy, all_preds, all_gts 1338 | else: 1339 | return accuracy 1340 | 1341 | 
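# --- Illustrative helper (added for clarity; an assumption, not code from the original script) ---
# The per-tile score accumulated into foraverage above compares the SAM-EO segmentation of the
# input tile with the SAM-EO segmentation of the predicted mask: the absolute mean difference
# between the two, normalised by the same statistic taken against an all-zero reference. A
# self-contained sketch of that computation (the helper name is ours; it reuses the numpy/torch
# imports already present in this script and assumes a non-empty reference segmentation):
def sam_agreement_score(pred_sam, ref_sam):
    """Normalised absolute mean difference between two SAM-EO segmentations."""
    ref_sam = np.resize(ref_sam, np.shape(pred_sam))  # match shapes, as done above
    pred = torch.from_numpy(pred_sam).float()
    ref = torch.from_numpy(ref_sam).float()
    score = torch.abs(torch.mean(pred - ref))
    score = score / torch.abs(torch.mean(torch.zeros_like(pred) - ref))
    return score.item()
# e.g. sam_agreement_score(pred_image, pred_image2) would reproduce the per-tile value printed
# above, before it is averaged over the countertotal evaluated tiles.
# ------------------------------------------------------------------------------------------------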
from IPython.display import clear_output 1342 | 1343 | def train(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch = 5): 1344 | losses = np.zeros(1000000) 1345 | mean_losses = np.zeros(100000000) 1346 | weights = weights.cuda() 1347 | 1348 | criterion = nn.NLLLoss2d(weight=weights) 1349 | 1350 | iter_ = 0 1351 | 1352 | #print(loss) 1353 | 1354 | #train(net, optimizer, 50, scheduler) 1355 | 1356 | #net.load_state_dict(torch.load('./segnet_finale3')) 1357 | 1358 | # SegFormer 1359 | #net.load_state_dict(torch.load('./segformermain30082023')) 1360 | 1361 | # # torch.save(model.state_dict(), '/Data/ndionelis/formodels/sgfrmr2ehwh') 1362 | # # data_dir = '/Data/ndionelis/bingmap/maindatasetnewnew2' 1363 | # # Potsdam 1364 | # # data_dir = '/Data/ndionelis/mainpo2' 1365 | # # On the dataset Vaihingen 1366 | # # data_dir = '/Data/ndionelis/mainva2' 1367 | 1368 | # # net.load_state_dict(torch.load('./sgfrmr2023')) 1369 | 1370 | #clear_output() 1371 | 1372 | #_, all_preds, all_gts = test(net, test_ids, all=True, stride=32) 1373 | 1374 | -------------------------------------------------------------------------------- /src/modeling_SegFormer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # We modify: http://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/segformer/modeling_segformer.py#L746 5 | 6 | # Copyright 2021 NVIDIA The HuggingFace Inc. team. All rights reserved. 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 10 | # # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ The PyTorch SegFormer model. 
""" 17 | 18 | # # We have modified: http://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/segformer/modeling_segformer.py#L746 19 | 20 | import math 21 | from typing import Optional, Tuple, Union 22 | 23 | import torch 24 | import torch.utils.checkpoint 25 | 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | 29 | import torch.nn.functional as F 30 | 31 | from ...activations import ACT2FN 32 | from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput 33 | from ...modeling_utils import PreTrainedModel 34 | from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer 35 | 36 | from ...utils import ( 37 | add_code_sample_docstrings, 38 | add_start_docstrings, 39 | add_start_docstrings_to_model_forward, 40 | logging, 41 | replace_return_docstrings, 42 | ) 43 | 44 | from .configuration_segformer import SegformerConfig 45 | 46 | logger = logging.get_logger(__name__) 47 | 48 | _CONFIG_FOR_DOC = "SegformerConfig" 49 | 50 | _CHECKPOINT_FOR_DOC = "nvidia/mit-b0" 51 | _EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16] 52 | 53 | _IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0" 54 | _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" 55 | 56 | SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ 57 | "nvidia/segformer-b0-finetuned-ade-512-512", 58 | # # See all SegFormer models at https://huggingface.co/models?filter=segformer 59 | ] 60 | 61 | class SegFormerImageClassifierOutput(ImageClassifierOutput): 62 | loss: Optional[torch.FloatTensor] = None 63 | logits: torch.FloatTensor = None 64 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 65 | attentions: Optional[Tuple[torch.FloatTensor]] = None 66 | 67 | def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: 68 | if drop_prob == 0.0 or not training: 69 | return input 70 | keep_prob = 1 - drop_prob 71 | shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 72 | random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) 73 | random_tensor.floor_() # binarize 74 | output = input.div(keep_prob) * random_tensor 75 | return output 76 | 77 | class SegformerDropPath(nn.Module): 78 | def __init__(self, drop_prob: Optional[float] = None) -> None: 79 | super().__init__() 80 | self.drop_prob = drop_prob 81 | 82 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 83 | return drop_path(hidden_states, self.drop_prob, self.training) 84 | 85 | def extra_repr(self) -> str: 86 | return "p={}".format(self.drop_prob) 87 | 88 | class SegformerOverlapPatchEmbeddings(nn.Module): 89 | def __init__(self, patch_size, stride, num_channels, hidden_size): 90 | super().__init__() 91 | self.proj = nn.Conv2d( 92 | num_channels, 93 | hidden_size, 94 | kernel_size=patch_size, 95 | stride=stride, 96 | padding=patch_size // 2, 97 | ) 98 | self.layer_norm = nn.LayerNorm(hidden_size) 99 | 100 | def forward(self, pixel_values): 101 | embeddings = self.proj(pixel_values) 102 | _, _, height, width = embeddings.shape 103 | embeddings = embeddings.flatten(2).transpose(1, 2) 104 | embeddings = self.layer_norm(embeddings) 105 | return embeddings, height, width 106 | 107 | class SegformerEfficientSelfAttention(nn.Module): 108 | def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio): 109 | super().__init__() 110 | self.hidden_size = hidden_size 111 | self.num_attention_heads = num_attention_heads 112 | 113 | if 
self.hidden_size % self.num_attention_heads != 0: 114 | raise ValueError( 115 | f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " 116 | f"heads ({self.num_attention_heads})" 117 | ) 118 | 119 | self.attention_head_size = int(self.hidden_size / self.num_attention_heads) 120 | self.all_head_size = self.num_attention_heads * self.attention_head_size 121 | 122 | self.query = nn.Linear(self.hidden_size, self.all_head_size) 123 | self.key = nn.Linear(self.hidden_size, self.all_head_size) 124 | self.value = nn.Linear(self.hidden_size, self.all_head_size) 125 | 126 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 127 | 128 | self.sr_ratio = sequence_reduction_ratio 129 | if sequence_reduction_ratio > 1: 130 | self.sr = nn.Conv2d( 131 | hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio 132 | ) 133 | self.layer_norm = nn.LayerNorm(hidden_size) 134 | 135 | def transpose_for_scores(self, hidden_states): 136 | new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 137 | hidden_states = hidden_states.view(new_shape) 138 | return hidden_states.permute(0, 2, 1, 3) 139 | 140 | def forward( 141 | self, 142 | hidden_states, 143 | height, 144 | width, 145 | output_attentions=False, 146 | ): 147 | query_layer = self.transpose_for_scores(self.query(hidden_states)) 148 | 149 | if self.sr_ratio > 1: 150 | batch_size, seq_len, num_channels = hidden_states.shape 151 | hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) 152 | hidden_states = self.sr(hidden_states) 153 | hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) 154 | hidden_states = self.layer_norm(hidden_states) 155 | 156 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 157 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 158 | 159 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 160 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 161 | 162 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 163 | 164 | attention_probs = self.dropout(attention_probs) 165 | 166 | context_layer = torch.matmul(attention_probs, value_layer) 167 | 168 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 169 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 170 | context_layer = context_layer.view(new_context_layer_shape) 171 | 172 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 173 | 174 | return outputs 175 | 176 | class SegformerSelfOutput(nn.Module): 177 | def __init__(self, config, hidden_size): 178 | super().__init__() 179 | self.dense = nn.Linear(hidden_size, hidden_size) 180 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 181 | 182 | def forward(self, hidden_states, input_tensor): 183 | hidden_states = self.dense(hidden_states) 184 | hidden_states = self.dropout(hidden_states) 185 | return hidden_states 186 | 187 | class SegformerAttention(nn.Module): 188 | def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio): 189 | super().__init__() 190 | self.self = SegformerEfficientSelfAttention( 191 | config=config, 192 | hidden_size=hidden_size, 193 | num_attention_heads=num_attention_heads, 194 | sequence_reduction_ratio=sequence_reduction_ratio, 195 | ) 196 | self.output = SegformerSelfOutput(config, hidden_size=hidden_size) 197 | 
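        # Bookkeeping for prune_heads(): records which attention heads have already been removed
        # so that repeated pruning resolves to the correct remaining head indices.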
self.pruned_heads = set() 198 | 199 | def prune_heads(self, heads): 200 | if len(heads) == 0: 201 | return 202 | heads, index = find_pruneable_heads_and_indices( 203 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 204 | ) 205 | 206 | self.self.query = prune_linear_layer(self.self.query, index) 207 | self.self.key = prune_linear_layer(self.self.key, index) 208 | self.self.value = prune_linear_layer(self.self.value, index) 209 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 210 | 211 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 212 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 213 | self.pruned_heads = self.pruned_heads.union(heads) 214 | 215 | def forward(self, hidden_states, height, width, output_attentions=False): 216 | self_outputs = self.self(hidden_states, height, width, output_attentions) 217 | 218 | attention_output = self.output(self_outputs[0], hidden_states) 219 | outputs = (attention_output,) + self_outputs[1:] 220 | return outputs 221 | 222 | class SegformerDWConv(nn.Module): 223 | def __init__(self, dim=768): 224 | super().__init__() 225 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 226 | 227 | def forward(self, hidden_states, height, width): 228 | batch_size, seq_len, num_channels = hidden_states.shape 229 | hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width) 230 | hidden_states = self.dwconv(hidden_states) 231 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 232 | 233 | return hidden_states 234 | 235 | class SegformerMixFFN(nn.Module): 236 | def __init__(self, config, in_features, hidden_features=None, out_features=None): 237 | super().__init__() 238 | out_features = out_features or in_features 239 | self.dense1 = nn.Linear(in_features, hidden_features) 240 | self.dwconv = SegformerDWConv(hidden_features) 241 | if isinstance(config.hidden_act, str): 242 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 243 | else: 244 | self.intermediate_act_fn = config.hidden_act 245 | self.dense2 = nn.Linear(hidden_features, out_features) 246 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 247 | 248 | def forward(self, hidden_states, height, width): 249 | hidden_states = self.dense1(hidden_states) 250 | hidden_states = self.dwconv(hidden_states, height, width) 251 | hidden_states = self.intermediate_act_fn(hidden_states) 252 | hidden_states = self.dropout(hidden_states) 253 | hidden_states = self.dense2(hidden_states) 254 | hidden_states = self.dropout(hidden_states) 255 | return hidden_states 256 | 257 | class SegformerLayer(nn.Module): 258 | def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio): 259 | super().__init__() 260 | self.layer_norm_1 = nn.LayerNorm(hidden_size) 261 | self.attention = SegformerAttention( 262 | config, 263 | hidden_size=hidden_size, 264 | num_attention_heads=num_attention_heads, 265 | sequence_reduction_ratio=sequence_reduction_ratio, 266 | ) 267 | self.drop_path = SegformerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() 268 | self.layer_norm_2 = nn.LayerNorm(hidden_size) 269 | mlp_hidden_size = int(hidden_size * mlp_ratio) 270 | self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size) 271 | 272 | def forward(self, hidden_states, height, width, output_attentions=False): 273 | self_attention_outputs = self.attention( 274 | 
self.layer_norm_1(hidden_states), 275 | height, 276 | width, 277 | output_attentions=output_attentions, 278 | ) 279 | 280 | attention_output = self_attention_outputs[0] 281 | outputs = self_attention_outputs[1:] 282 | 283 | attention_output = self.drop_path(attention_output) 284 | hidden_states = attention_output + hidden_states 285 | 286 | mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width) 287 | 288 | mlp_output = self.drop_path(mlp_output) 289 | layer_output = mlp_output + hidden_states 290 | 291 | outputs = (layer_output,) + outputs 292 | return outputs 293 | 294 | class SegformerEncoder(nn.Module): 295 | def __init__(self, config): 296 | super().__init__() 297 | self.config = config 298 | 299 | drop_path_decays = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] 300 | 301 | embeddings = [] 302 | for i in range(config.num_encoder_blocks): 303 | embeddings.append( 304 | SegformerOverlapPatchEmbeddings( 305 | patch_size=config.patch_sizes[i], 306 | stride=config.strides[i], 307 | num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], 308 | hidden_size=config.hidden_sizes[i], 309 | ) 310 | ) 311 | self.patch_embeddings = nn.ModuleList(embeddings) 312 | 313 | blocks, cur = [], 0 314 | for i in range(config.num_encoder_blocks): 315 | layers = [] 316 | if i != 0: 317 | cur += config.depths[i - 1] 318 | for j in range(config.depths[i]): 319 | layers.append( 320 | SegformerLayer( 321 | config, 322 | hidden_size=config.hidden_sizes[i], 323 | num_attention_heads=config.num_attention_heads[i], 324 | drop_path=drop_path_decays[cur + j], 325 | sequence_reduction_ratio=config.sr_ratios[i], 326 | mlp_ratio=config.mlp_ratios[i], 327 | ) 328 | ) 329 | blocks.append(nn.ModuleList(layers)) 330 | 331 | self.block = nn.ModuleList(blocks) 332 | 333 | self.layer_norm = nn.ModuleList( 334 | [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)] 335 | ) 336 | 337 | def forward( 338 | self, 339 | pixel_values: torch.FloatTensor, 340 | output_attentions: Optional[bool] = False, 341 | output_hidden_states: Optional[bool] = False, 342 | return_dict: Optional[bool] = True, 343 | ) -> Union[Tuple, BaseModelOutput]: 344 | all_hidden_states = () if output_hidden_states else None 345 | all_self_attentions = () if output_attentions else None 346 | 347 | batch_size = pixel_values.shape[0] 348 | 349 | hidden_states = pixel_values 350 | for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)): 351 | embedding_layer, block_layer, norm_layer = x 352 | hidden_states, height, width = embedding_layer(hidden_states) 353 | for i, blk in enumerate(block_layer): 354 | layer_outputs = blk(hidden_states, height, width, output_attentions) 355 | hidden_states = layer_outputs[0] 356 | if output_attentions: 357 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 358 | hidden_states = norm_layer(hidden_states) 359 | if idx != len(self.patch_embeddings) - 1 or ( 360 | idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage 361 | ): 362 | hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous() 363 | if output_hidden_states: 364 | all_hidden_states = all_hidden_states + (hidden_states,) 365 | 366 | if not return_dict: 367 | return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) 368 | return BaseModelOutput( 369 | last_hidden_state=hidden_states, 370 | hidden_states=all_hidden_states, 371 | 
attentions=all_self_attentions, 372 | ) 373 | 374 | class SegformerPreTrainedModel(PreTrainedModel): 375 | config_class = SegformerConfig 376 | base_model_prefix = "segformer" 377 | main_input_name = "pixel_values" 378 | 379 | def _init_weights(self, module): 380 | if isinstance(module, (nn.Linear, nn.Conv2d)): 381 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 382 | if module.bias is not None: 383 | module.bias.data.zero_() 384 | elif isinstance(module, nn.Embedding): 385 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 386 | if module.padding_idx is not None: 387 | module.weight.data[module.padding_idx].zero_() 388 | elif isinstance(module, nn.LayerNorm): 389 | module.bias.data.zero_() 390 | module.weight.data.fill_(1.0) 391 | 392 | SEGFORMER_START_DOCSTRING = r""" 393 | This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use 394 | it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and 395 | behavior. 396 | 397 | Parameters: 398 | config ([`SegformerConfig`]): Model configuration class with all the parameters of the model. 399 | Initializing with a config file does not load the weights associated with the model, only the 400 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 401 | """ 402 | 403 | SEGFORMER_INPUTS_DOCSTRING = r""" 404 | Args: 405 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 406 | Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using 407 | [`AutoImageProcessor`]. See [`SegformerImageProcessor.__call__`] for details. 408 | 409 | output_attentions (`bool`, *optional*): 410 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 411 | tensors for more detail. 412 | output_hidden_states (`bool`, *optional*): 413 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 414 | more detail. 415 | return_dict (`bool`, *optional*): 416 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
417 | """ 418 | 419 | @add_start_docstrings( 420 | "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.", 421 | SEGFORMER_START_DOCSTRING, 422 | ) 423 | class SegformerModel(SegformerPreTrainedModel): 424 | def __init__(self, config): 425 | super().__init__(config) 426 | self.config = config 427 | 428 | self.encoder = SegformerEncoder(config) 429 | self.post_init() 430 | 431 | def _prune_heads(self, heads_to_prune): 432 | for layer, heads in heads_to_prune.items(): 433 | self.encoder.layer[layer].attention.prune_heads(heads) 434 | 435 | @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) 436 | @add_code_sample_docstrings( 437 | checkpoint=_CHECKPOINT_FOR_DOC, 438 | output_type=BaseModelOutput, 439 | config_class=_CONFIG_FOR_DOC, 440 | modality="vision", 441 | expected_output=_EXPECTED_OUTPUT_SHAPE, 442 | ) 443 | def forward( 444 | self, 445 | pixel_values: torch.FloatTensor, 446 | output_attentions: Optional[bool] = None, 447 | output_hidden_states: Optional[bool] = None, 448 | return_dict: Optional[bool] = None, 449 | ) -> Union[Tuple, BaseModelOutput]: 450 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 451 | output_hidden_states = ( 452 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 453 | ) 454 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 455 | 456 | encoder_outputs = self.encoder( 457 | pixel_values, 458 | output_attentions=output_attentions, 459 | output_hidden_states=output_hidden_states, 460 | return_dict=return_dict, 461 | ) 462 | sequence_output = encoder_outputs[0] 463 | 464 | if not return_dict: 465 | return (sequence_output,) + encoder_outputs[1:] 466 | 467 | return BaseModelOutput( 468 | last_hidden_state=sequence_output, 469 | hidden_states=encoder_outputs.hidden_states, 470 | attentions=encoder_outputs.attentions, 471 | ) 472 | 473 | 474 | @add_start_docstrings( 475 | """ 476 | SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden 477 | states) e.g. for ImageNet. 
478 | """, 479 | SEGFORMER_START_DOCSTRING, 480 | ) 481 | class SegformerForImageClassification(SegformerPreTrainedModel): 482 | def __init__(self, config): 483 | super().__init__(config) 484 | 485 | #import pdb; pdb.set_trace() 486 | 487 | self.num_labels = config.num_labels 488 | #self.num_labels2 = config.num_labels2 489 | self.segformer = SegformerModel(config) 490 | 491 | # Classifier head 492 | self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels) 493 | 494 | # Initialize weights and apply final processing 495 | self.post_init() 496 | 497 | @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 498 | @add_code_sample_docstrings( 499 | checkpoint=_IMAGE_CLASS_CHECKPOINT, 500 | output_type=SegFormerImageClassifierOutput, 501 | config_class=_CONFIG_FOR_DOC, 502 | expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, 503 | ) 504 | def forward( 505 | self, 506 | pixel_values: Optional[torch.FloatTensor] = None, 507 | labels: Optional[torch.LongTensor] = None, 508 | output_attentions: Optional[bool] = None, 509 | output_hidden_states: Optional[bool] = None, 510 | return_dict: Optional[bool] = None, 511 | ) -> Union[Tuple, SegFormerImageClassifierOutput]: 512 | r""" 513 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 514 | Labels for computing the image classification/regression loss. Indices should be in `[0, ..., 515 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 516 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 517 | """ 518 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 519 | 520 | outputs = self.segformer( 521 | pixel_values, 522 | output_attentions=output_attentions, 523 | output_hidden_states=output_hidden_states, 524 | return_dict=return_dict, 525 | ) 526 | 527 | sequence_output = outputs[0] 528 | 529 | # convert last hidden states to (batch_size, height*width, hidden_size) 530 | batch_size = sequence_output.shape[0] 531 | if self.config.reshape_last_stage: 532 | # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) 533 | sequence_output = sequence_output.permute(0, 2, 3, 1) 534 | sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1]) 535 | 536 | # # global average pooling 537 | sequence_output = sequence_output.mean(dim=1) 538 | 539 | logits = self.classifier(sequence_output) 540 | 541 | #import pdb; pdb.set_trace() 542 | 543 | loss = None 544 | if labels is not None: 545 | if self.config.problem_type is None: 546 | if self.num_labels == 1: 547 | self.config.problem_type = "regression" 548 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 549 | self.config.problem_type = "single_label_classification" 550 | else: 551 | self.config.problem_type = "multi_label_classification" 552 | 553 | if self.config.problem_type == "regression": 554 | loss_fct = MSELoss() 555 | if self.num_labels == 1: 556 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 557 | else: 558 | loss = loss_fct(logits, labels) 559 | elif self.config.problem_type == "single_label_classification": 560 | loss_fct = CrossEntropyLoss() 561 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 562 | elif self.config.problem_type == "multi_label_classification": 563 | loss_fct = BCEWithLogitsLoss() 564 | loss = loss_fct(logits, labels) 565 | if not return_dict: 566 | output = (logits,) + 
outputs[1:] 567 | return ((loss,) + output) if loss is not None else output 568 | 569 | return SegFormerImageClassifierOutput( 570 | loss=loss, 571 | logits=logits, 572 | hidden_states=outputs.hidden_states, 573 | attentions=outputs.attentions, 574 | ) 575 | 576 | class SegformerMLP(nn.Module): 577 | def __init__(self, config: SegformerConfig, input_dim): 578 | super().__init__() 579 | self.proj = nn.Linear(input_dim, config.decoder_hidden_size) 580 | 581 | def forward(self, hidden_states: torch.Tensor): 582 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 583 | hidden_states = self.proj(hidden_states) 584 | return hidden_states 585 | 586 | from torch.autograd import Function 587 | 588 | class ReverseLayerF(Function): 589 | @staticmethod 590 | def forward(ctx, x, alpha): 591 | ctx.alpha = alpha 592 | 593 | return x.view_as(x) 594 | 595 | @staticmethod 596 | def backward(ctx, grad_output): 597 | output = grad_output.neg() * ctx.alpha 598 | 599 | return output, None 600 | 601 | class SegformerDecodeHead(SegformerPreTrainedModel): 602 | def __init__(self, config, alpha): 603 | super().__init__(config) 604 | # # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size 605 | mlps = [] 606 | for i in range(config.num_encoder_blocks): 607 | mlp = SegformerMLP(config, input_dim=config.hidden_sizes[i]) 608 | mlps.append(mlp) 609 | self.linear_c = nn.ModuleList(mlps) 610 | 611 | # # the following 3 layers implement the ConvModule of the original implementation 612 | self.linear_fuse = nn.Conv2d( 613 | in_channels=config.decoder_hidden_size * config.num_encoder_blocks, 614 | out_channels=config.decoder_hidden_size, 615 | kernel_size=1, 616 | bias=False, 617 | ) 618 | self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size) 619 | self.activation = nn.ReLU() 620 | 621 | self.dropout = nn.Dropout(config.classifier_dropout_prob) 622 | self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1) 623 | 624 | #self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1) 625 | 626 | #self.classifier2 = nn.Conv2d(config.decoder_hidden_size, config.num_labels2, kernel_size=1) 627 | #config.num_labels2 = 2 628 | config.num_labels2 = 3 629 | self.classifier2 = nn.Linear(config.decoder_hidden_size, config.num_labels2) 630 | 631 | self.config = config 632 | 633 | def forward(self, encoder_hidden_states: torch.FloatTensor, alpha: Optional[torch.FloatTensor] = None) -> torch.Tensor: 634 | batch_size = encoder_hidden_states[-1].shape[0] 635 | 636 | all_hidden_states = () 637 | for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c): 638 | if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3: 639 | height = width = int(math.sqrt(encoder_hidden_state.shape[-1])) 640 | encoder_hidden_state = ( 641 | encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous() 642 | ) 643 | 644 | # # unify channel dimension 645 | height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3] 646 | encoder_hidden_state = mlp(encoder_hidden_state) 647 | encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1) 648 | encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width) 649 | 650 | # # # upsample 651 | encoder_hidden_state = nn.functional.interpolate( 652 | encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False 653 | ) 654 | 655 | 
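            # collect this stage's channel-unified, upsampled feature map; all stages are
            # concatenated (in reverse order) and fused by linear_fuse below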
all_hidden_states += (encoder_hidden_state,) 656 | 657 | hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1)) 658 | hidden_states = self.batch_norm(hidden_states) 659 | hidden_states = self.activation(hidden_states) 660 | hidden_states = self.dropout(hidden_states) 661 | 662 | # # logits are of shape (batch_size, num_labels, height / 4, width / 4) 663 | #logits = self.classifier(hidden_states) 664 | 665 | logits = self.classifier(hidden_states) 666 | 667 | #logits2 = self.classifier2(hidden_states) 668 | 669 | hidden_states = F.adaptive_avg_pool2d(hidden_states, 1).reshape(batch_size, -1) 670 | 671 | #print(alpha) 672 | 673 | hidden_states = ReverseLayerF.apply(hidden_states, alpha) 674 | 675 | logits2 = self.classifier2(hidden_states) 676 | 677 | logits2 = F.softmax(logits2) 678 | 679 | #return logits 680 | return logits, logits2 681 | 682 | @add_start_docstrings( 683 | """SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes. """, 684 | SEGFORMER_START_DOCSTRING, 685 | ) 686 | class SegformerForSemanticSegmentation(SegformerPreTrainedModel): 687 | def __init__(self, config, alpha): 688 | super().__init__(config) 689 | self.segformer = SegformerModel(config) 690 | 691 | # # SegformerForSemanticSegmentation 692 | 693 | #print(config) 694 | #print(config.num_labels) 695 | 696 | #self.decode_head = SegformerDecodeHead(config) 697 | 698 | #self.alpha = alpha 699 | 700 | #self.decode_head = SegformerDecodeHead(config, self.alpha) 701 | #self.decode_head, self.decode_head2 = SegformerDecodeHead(config, self.alpha) 702 | 703 | self.alpha = alpha 704 | self.decode_head = SegformerDecodeHead(config, self.alpha) 705 | 706 | #self.decode_head2 = SegformerDecodeHead(config, alpha) 707 | 708 | #self.decode_head = SegformerDecodeHead(config, alpha) 709 | 710 | # # Initialize weights and apply final processing 711 | self.post_init() 712 | 713 | @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 714 | @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) 715 | def forward( 716 | self, 717 | pixel_values: torch.FloatTensor, 718 | labels: Optional[torch.LongTensor] = None, 719 | output_attentions: Optional[bool] = None, 720 | output_hidden_states: Optional[bool] = None, 721 | return_dict: Optional[bool] = None, 722 | alpha: Optional[torch.FloatTensor] = None, 723 | image_patches: Optional[torch.FloatTensor] = None, 724 | data2: Optional[torch.FloatTensor] = None, 725 | target2: Optional[torch.FloatTensor] = None, 726 | ) -> Union[Tuple, SemanticSegmenterOutput]: 727 | r""" 728 | labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): 729 | Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., 730 | config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
731 | 732 | Returns: 733 | 734 | Examples: 735 | 736 | ```python 737 | >>> from transformers import AutoImageProcessor, SegformerForSemanticSegmentation 738 | >>> from PIL import Image 739 | >>> import requests 740 | 741 | >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") 742 | >>> model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") 743 | 744 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 745 | >>> image = Image.open(requests.get(url, stream=True).raw) 746 | 747 | >>> inputs = image_processor(images=image, return_tensors="pt") 748 | >>> outputs = model(**inputs) 749 | >>> logits = outputs.logits # # shape (batch_size, num_labels, height / 4, width / 4) 750 | >>> list(logits.shape) 751 | [1, 150, 128, 128] 752 | ```""" 753 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 754 | output_hidden_states = ( 755 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 756 | ) 757 | 758 | outputs = self.segformer( 759 | pixel_values, 760 | output_attentions=output_attentions, 761 | output_hidden_states=True, # # we need the intermediate hidden states 762 | return_dict=return_dict, 763 | ) 764 | 765 | #print(return_dict) 766 | # # True 767 | 768 | encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] 769 | 770 | #logits = self.decode_head(encoder_hidden_states) 771 | #logits, logits2 = self.decode_head(encoder_hidden_states) 772 | 773 | #print(alpha) 774 | 775 | logits, logits2 = self.decode_head(encoder_hidden_states, alpha) 776 | 777 | #logits, logits2 = self.decode_head(encoder_hidden_states) 778 | 779 | #logits2 = self.decode_head2(encoder_hidden_states) 780 | 781 | #print(logits2) 782 | #print(logits2.shape) 783 | 784 | import segmentation_models_pytorch as smp 785 | 786 | loss = None 787 | 788 | if labels is not None: 789 | upsampled_logits = nn.functional.interpolate( 790 | logits, size=labels.shape[-2:], mode="bilinear", align_corners=False 791 | ) 792 | if self.config.num_labels > 1: 793 | loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) 794 | loss = loss_fct(upsampled_logits, labels) 795 | 796 | # print(upsampled_logits.shape) 797 | # print(labels.shape) 798 | 799 | # # Dice Loss 800 | 801 | #loss += smp.losses.JaccardLoss(mode="multiclass", classes=6)(y_pred=upsampled_logits, y_true=labels) 802 | loss += smp.losses.DiceLoss(mode="multiclass", classes=6)(y_pred=upsampled_logits, y_true=labels) 803 | 804 | batch_size = labels.shape[0] 805 | domain_label = torch.zeros(batch_size).long().cuda() 806 | 807 | err_s_domain = nn.CrossEntropyLoss()(logits2, domain_label) 808 | domain_label = torch.ones(batch_size).long().cuda() 809 | 810 | #_, domain_output = net(image_patches, alpha) 811 | #_, domain_output = self.decode_head(encoder_hidden_states, alpha) 812 | 813 | #_, domain_output = net(image_patches, alpha) 814 | 815 | # outputs2 = self.segformer( 816 | # pixel_values, 817 | # output_attentions=output_attentions, 818 | # output_hidden_states=True, # we need the intermediate hidden states 819 | # return_dict=return_dict, 820 | # ) 821 | 822 | outputs2 = self.segformer( 823 | image_patches, 824 | output_attentions=output_attentions, 825 | output_hidden_states=True, # we need the intermediate hidden states 826 | return_dict=return_dict, 827 | ) 828 | 829 | # outputs2 = self.segformer( 830 | # pixel_values, 831 | # 
output_attentions=output_attentions, 832 | # output_hidden_states=True, # # we need the intermediate hidden states 833 | # return_dict=return_dict, 834 | # ) 835 | 836 | domain_output = outputs2.hidden_states 837 | _, domain_output = self.decode_head(domain_output, alpha) 838 | 839 | err_t_domain = nn.CrossEntropyLoss()(domain_output, domain_label) 840 | 841 | loss += err_s_domain + err_t_domain 842 | 843 | #loss += err_s_domain + err_t_domain 844 | 845 | #print(loss) 846 | #print(err_s_domain + err_t_domain) 847 | 848 | #import pdb; pdb.set_trace() 849 | 850 | #output2, oo222 = net(data2, alpha) 851 | 852 | # # # data2, target2 853 | 854 | outputs3 = self.segformer( 855 | data2, 856 | output_attentions=output_attentions, 857 | output_hidden_states=True, # # we need the intermediate hidden states 858 | return_dict=return_dict, 859 | ) 860 | outputt22 = outputs3.hidden_states 861 | output2, logitss22 = self.decode_head(outputt22, alpha) 862 | upsampled_logitss22 = nn.functional.interpolate( 863 | output2, size=labels.shape[-2:], mode="bilinear", align_corners=False 864 | ) 865 | #loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) 866 | #losss22 = loss_fct(upsampled_logitss22, target2) 867 | loss += loss_fct(upsampled_logitss22, target2) 868 | 869 | #loss += CrossEntropy2d(output2, target2, weight=weights) 870 | 871 | #loss += smp.losses.JaccardLoss(mode="multiclass", classes=6)(y_pred=output2, y_true=target2) 872 | 873 | #loss += smp.losses.JaccardLoss(mode="multiclass", classes=6)(y_pred=upsampled_logitss22, y_true=target2) 874 | loss += smp.losses.DiceLoss(mode="multiclass", classes=6)(y_pred=upsampled_logitss22, y_true=target2) 875 | 876 | #domain_label2 = (2*torch.ones(BATCH_SIZE).long()).cuda() 877 | 878 | domain_label2 = (2*torch.ones(batch_size).long()).cuda() 879 | 880 | #loss += nn.CrossEntropyLoss()(oo222, domain_label2) 881 | 882 | loss += nn.CrossEntropyLoss()(logitss22, domain_label2) 883 | 884 | #import segmentation_models_pytorch as smp 885 | #loss += smp.losses.JaccardLoss(mode="multiclass", classes=6)(y_pred=upsampled_logits, y_true=labels) 886 | 887 | #jaccard_distance_loss1 = JaccardDistanceLoss() 888 | #loss += jaccard_distance_loss1(labels, upsampled_logits) 889 | 890 | # from torchmetrics import JaccardIndex 891 | # #pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] 892 | # #upsampled_logits = 1 - upsampled_logits 893 | # jaccard = JaccardIndex(num_classes=150) 894 | # loss += jaccard(upsampled_logits, labels) 895 | 896 | #print(loss) 897 | 898 | #loss += lovaszloss(upsampled_logits, labels) 899 | 900 | #import segmentation_models_pytorch as smp 901 | #loss += smp.losses.JaccardLoss(mode="multiclass", classes=151)(y_pred=upsampled_logits, y_true=labels) 902 | 903 | #import pdb; pdb.set_trace() 904 | 905 | #print(loss) 906 | 907 | elif self.config.num_labels == 1: 908 | valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float() 909 | loss_fct = BCEWithLogitsLoss(reduction="none") 910 | loss = loss_fct(upsampled_logits.squeeze(1), labels.float()) 911 | loss = (loss * valid_mask).mean() 912 | 913 | else: 914 | raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}") 915 | 916 | if not return_dict: 917 | if output_hidden_states: 918 | output = (logits,) + outputs[1:] 919 | else: 920 | output = (logits,) + outputs[2:] 921 | return ((loss,) + output) if loss is not None else output 922 | 923 | return SemanticSegmenterOutput( 924 | loss=loss, 925 | logits=logits, 926 | 
hidden_states=outputs.hidden_states if output_hidden_states else None, 927 | attentions=outputs.attentions, 928 | ) 929 | 930 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import sys 5 | #get_ipython().system('{sys.executable} -m pip install pytorch_ssim') 6 | #get_ipython().system('{sys.executable} -m pip install pytorch_msssim') 7 | #pip install pytorch_ssim 8 | 9 | # #!{sys.executable} -m conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia 10 | # #!{sys.executable} -m pip install timm 11 | 12 | # # torch.save(model.state_dict(), '/Data/ndionelis/segformermain30082023noSubsaa22') 13 | # # data_dir = '/Data/ndionelis/bingmap/mainTheDatasetNoSubsa2' 14 | 15 | # # # This is to test: torch.save(model.state_dict(), './segformermain30082023') 16 | # # # Also: data_dir = '../../CVUSA/bingmap/mainTheDataset' 17 | 18 | # import numpy as np 19 | # from skimage import io 20 | # from glob import glob 21 | # from tqdm import tqdm_notebook as tqdm 22 | # from sklearn.metrics import confusion_matrix 23 | # import random 24 | # import itertools 25 | # # # Matplotlib 26 | # import matplotlib.pyplot as plt 27 | # #get_ipython().run_line_magic('matplotlib', 'inline') 28 | # #from IPython import get_ipython 29 | # #get_ipython().run_line_magic('matplotlib', 'inline') 30 | # #exec(%matplotlib inline) 31 | # # # Torch imports 32 | # import torch 33 | # import torch.nn as nn 34 | # import torch.nn.functional as F 35 | # import torch.utils.data as data 36 | # import torch.optim as optim 37 | # import torch.optim.lr_scheduler 38 | # import torch.nn.init 39 | # from torch.autograd import Variable 40 | # import torchvision.transforms as T 41 | # import albumentations as A 42 | # import segmentation_models_pytorch as smp 43 | # import kornia 44 | 45 | # WINDOW_SIZE = (256, 256) # Patch size 46 | # STRIDE = 32 # # # # Stride for testing 47 | # IN_CHANNELS = 3 # Number of input channels (e.g. 
RGB) 48 | # #FOLDER = "./ISPRS_dataset/" # # Replace with your "/path/to/the/ISPRS/dataset/folder/" 49 | # #FOLDER = "../" 50 | # FOLDER = "/Data/ndionelis/" 51 | # #BATCH_SIZE = 10 # Number of samples in a mini-batch 52 | # #BATCH_SIZE = 64 53 | # #BATCH_SIZE = 128 54 | # #BATCH_SIZE = 30 55 | # BATCH_SIZE = 10 56 | 57 | # LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # # Label names 58 | # N_CLASSES = len(LABELS) # # Number of classes 59 | # #print(N_CLASSES) 60 | 61 | # WEIGHTS = torch.ones(N_CLASSES) # # # Weights for class balancing 62 | # CACHE = True # # Store the dataset in-memory 63 | 64 | # #DATASET = 'Vaihingen' 65 | # DATASET = 'Potsdam' 66 | # if DATASET == 'Potsdam': 67 | # MAIN_FOLDER = FOLDER + 'Potsdam/' 68 | # # Uncomment the next line for IRRG data 69 | # # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 70 | # # # For RGB data 71 | # #print(MAIN_FOLDER) 72 | # #sadfszf 73 | # #print(MAIN_FOLDER) 74 | # #asdfklsz 75 | # DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 76 | # LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 77 | # ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 78 | # elif DATASET == 'Vaihingen': 79 | # MAIN_FOLDER = FOLDER + 'Vaihingen/' 80 | # #print(MAIN_FOLDER) 81 | # #asdfszdf 82 | # DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 83 | # LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 84 | # ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 85 | 86 | import numpy as np 87 | from skimage import io 88 | from glob import glob 89 | from tqdm import tqdm_notebook as tqdm 90 | from sklearn.metrics import confusion_matrix 91 | import random 92 | import itertools 93 | # # Matplotlib 94 | import matplotlib.pyplot as plt 95 | #get_ipython().run_line_magic('matplotlib', 'inline') 96 | # # The PyTorch imports 97 | import torch 98 | import torch.nn as nn 99 | import torch.nn.functional as F 100 | import torch.utils.data as data 101 | import torch.optim as optim 102 | import torch.optim.lr_scheduler 103 | import torch.nn.init 104 | from torch.autograd import Variable 105 | import torchvision.transforms as T 106 | import albumentations as A 107 | import segmentation_models_pytorch as smp 108 | import time 109 | import kornia 110 | 111 | WINDOW_SIZE = (256, 256) # # The patch size 112 | WINDOW_SIZE = (512, 512) # # Patch size 113 | STRIDE = 32 # # # # # Stride for testing 114 | IN_CHANNELS = 3 # Number of input channels (e.g. 
RGB) 115 | #FOLDER = "./ISPRS_dataset/" # Replace with your "/path/to/the/ISPRS/dataset/folder/" 116 | #FOLDER = "../../" 117 | FOLDER = "/Data/ndionelis/" 118 | #BATCH_SIZE = 10 # # Number of samples in a mini-batch 119 | #BATCH_SIZE = 64 120 | #BATCH_SIZE = 128 121 | BATCH_SIZE = 10 122 | #BATCH_SIZE = 30 123 | 124 | LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # # Label names 125 | N_CLASSES = len(LABELS) # # Number of classes 126 | #print(N_CLASSES) 127 | 128 | WEIGHTS = torch.ones(N_CLASSES) # # # Weights for class balancing 129 | CACHE = True # # Store the dataset in-memory 130 | 131 | DATASET = 'Vaihingen' 132 | #DATASET = 'Potsdam' 133 | 134 | if DATASET == 'Potsdam': 135 | MAIN_FOLDER = FOLDER + 'Potsdam/' 136 | # Uncomment the next line for IRRG data 137 | # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 138 | # # For RGB data 139 | DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 140 | LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 141 | ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 142 | elif DATASET == 'Vaihingen': 143 | MAIN_FOLDER = FOLDER + 'Vaihingen/' 144 | DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 145 | LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 146 | ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 147 | 148 | # ISPRS color palette 149 | # # Let's define the standard ISPRS color palette 150 | palette = {0 : (255, 255, 255), # Impervious surfaces (white) 151 | 1 : (0, 0, 255), # # Buildings (blue) 152 | 2 : (0, 255, 255), # Low vegetation (cyan) 153 | 3 : (0, 255, 0), # Trees (green) 154 | 4 : (255, 255, 0), # Cars (yellow) 155 | 5 : (255, 0, 0), # Clutter (red) 156 | 6 : (0, 0, 0)} # Undefined (black) 157 | invert_palette = {v: k for k, v in palette.items()} 158 | 159 | def convert_to_color(arr_2d, palette=palette): 160 | """ Numeric labels to RGB-color encoding """ 161 | arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8) 162 | 163 | for c, i in palette.items(): 164 | m = arr_2d == c 165 | arr_3d[m] = i 166 | 167 | return arr_3d 168 | 169 | def convert_from_color(arr_3d, palette=invert_palette): 170 | arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) 171 | 172 | for c, i in palette.items(): 173 | m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) 174 | arr_2d[m] = i 175 | 176 | return arr_2d 177 | 178 | img = io.imread('/Data/ndionelis/Vaihingen/top/top_mosaic_09cm_area11.tif') 179 | fig = plt.figure() 180 | fig.add_subplot(121) 181 | plt.imshow(img) 182 | 183 | gt = io.imread('/Data/ndionelis/Vaihingen/gts_for_participants/top_mosaic_09cm_area11.tif') 184 | fig.add_subplot(122) 185 | plt.imshow(gt) 186 | plt.show() 187 | 188 | array_gt = convert_from_color(gt) 189 | print("Ground truth in numerical format has shape ({},{}) : \n".format(*array_gt.shape[:2]), array_gt) 190 | 191 | def get_random_pos(img, window_shape): 192 | w, h = window_shape 193 | W, H = img.shape[-2:] 194 | x1 = random.randint(0, W - w - 1) 195 | x2 = x1 + w 196 | y1 = random.randint(0, H - h - 1) 197 | y2 = y1 + h 198 | return x1, x2, y1, y2 199 | 200 | def CrossEntropy2d(input, target, weight=None, size_average=True): 201 | dim = input.dim() 202 | if dim == 2: 203 | return F.cross_entropy(input, target, weight, size_average) 204 | elif dim == 4: 205 | output = 
input.view(input.size(0),input.size(1), -1) 206 | output = torch.transpose(output,1,2).contiguous() 207 | output = output.view(-1,output.size(2)) 208 | target = target.view(-1) 209 | return F.cross_entropy(output, target,weight, size_average) 210 | else: 211 | raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim)) 212 | 213 | def accuracy(input, target): 214 | return 100 * float(np.count_nonzero(input == target)) / target.size 215 | 216 | def sliding_window(top, step=10, window_size=(20,20)): 217 | for x in range(0, top.shape[0], step): 218 | if x + window_size[0] > top.shape[0]: 219 | x = top.shape[0] - window_size[0] 220 | for y in range(0, top.shape[1], step): 221 | if y + window_size[1] > top.shape[1]: 222 | y = top.shape[1] - window_size[1] 223 | yield x, y, window_size[0], window_size[1] 224 | 225 | def count_sliding_window(top, step=10, window_size=(20,20)): 226 | c = 0 227 | for x in range(0, top.shape[0], step): 228 | if x + window_size[0] > top.shape[0]: 229 | x = top.shape[0] - window_size[0] 230 | for y in range(0, top.shape[1], step): 231 | if y + window_size[1] > top.shape[1]: 232 | y = top.shape[1] - window_size[1] 233 | c += 1 234 | return c 235 | 236 | def grouper(n, iterable): 237 | it = iter(iterable) 238 | while True: 239 | chunk = tuple(itertools.islice(it, n)) 240 | if not chunk: 241 | return 242 | yield chunk 243 | 244 | def metrics(predictions, gts, label_values=LABELS): 245 | cm = confusion_matrix( 246 | gts, 247 | predictions, 248 | labels = range(len(label_values))) 249 | for0 = 0 250 | for1 = 0 251 | for2 = 0 252 | for3 = 0 253 | for4 = 0 254 | for5 = 0 255 | for6 = 0 256 | print("Confusion matrix :") 257 | print(cm) 258 | # # Compute global accuracy 259 | total = sum(sum(cm)) 260 | accuracy = sum([cm[x][x] for x in range(len(cm))]) 261 | accuracy *= 100 / float(total) 262 | print("{} pixels processed".format(total)) 263 | print("Total accuracy : {}%".format(accuracy)) 264 | 265 | # # # Compute F1 score 266 | # F1Score = np.zeros(len(label_values)) 267 | # for i in range(len(label_values)): 268 | # try: 269 | # F1Score[i] = 2. * cm[i,i] / (np.sum(cm[i,:]) + np.sum(cm[:,i])) 270 | # except: 271 | # # # Ignore exception if there is no element in class i for test set 272 | # pass 273 | # print("F1Score :") 274 | # for l_id, score in enumerate(F1Score): 275 | # print("{}: {}".format(label_values[l_id], score)) 276 | 277 | # # # Compute kappa coefficient 278 | # total = np.sum(cm) 279 | # pa = np.trace(cm) / float(total) 280 | # pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total*total) 281 | # kappa = (pa - pe) / (1 - pe); 282 | # print("Kappa: " + str(kappa)) 283 | # return accuracy 284 | 285 | totalfor0 = [] 286 | totalfor0.append(for0) 287 | totalfor0.append(for1) 288 | totalfor0.append(for2) 289 | totalfor0.append(for3) 290 | totalfor0.append(for4) 291 | totalfor0.append(for5) 292 | totalfor0.append(for6) 293 | 294 | # # Compute F1 score 295 | totalF1Score = 0 296 | 297 | #print(label_values) 298 | #F1Score = np.zeros(len(label_values)) 299 | F1Score = np.zeros(len(label_values)) 300 | 301 | for i in range(len(label_values)): 302 | try: 303 | F1Score[i] = 2. 
* cm[i,i] / (np.sum(cm[i,:]) + np.sum(cm[:,i])) 304 | except: 305 | # # Ignore exception if there is no element in class i for test set 306 | pass 307 | print("F1Score :") 308 | for l_id, score in enumerate(F1Score): 309 | #print(l_id) 310 | 311 | #print("{}: {}, {}, {}".format(label_values[l_id], l_id, totalfor0[l_id], score)) 312 | print("{}: {}, {}".format(label_values[l_id], totalfor0[l_id], score)) 313 | #print("{}: {}".format(label_values[l_id], score)) 314 | #totalF1Score += ((for0) / (for0 + for1 + for2 + for3 + for4 + for5)) * score 315 | support = np.sum(cm, axis=1); totalF1Score += (float(support[l_id]) / max(1.0, float(np.sum(support)))) * score # weight each class F1 by its ground-truth pixel support; the for0..for6 counters above are never updated, so dividing by sum(totalfor0) would raise ZeroDivisionError 316 | #if l_id < 5: 317 | # totalF1Score += ((totalfor0[l_id]) / (sum(totalfor0))) * score 318 | 319 | # totalfor0 = [] 320 | # totalfor0.append(for0) 321 | # totalfor0.append(for1) 322 | # totalfor0.append(for2) 323 | # totalfor0.append(for3) 324 | # totalfor0.append(for4) 325 | # totalfor0.append(for5) 326 | # totalfor0.append(for6) 327 | 328 | # print(for0) 329 | # print(for1) 330 | # print(for2) 331 | # print(for3) 332 | # print(for4) 333 | # print(for5) 334 | # #print(for6) 335 | 336 | print(totalF1Score) 337 | 338 | #accuracy = totalF1Score 339 | 340 | mini = 1 341 | nclass = 6 342 | maxi = nclass 343 | nbins = nclass 344 | predict = predictions + 1 345 | target = gts + 1 346 | 347 | predict = torch.from_numpy(np.array(predict)) 348 | target = torch.from_numpy(np.array(target)) 349 | 350 | #img = img.to(torch.float32) 351 | 352 | predict = predict.float() * (target < 7).float() 353 | #predict = predict * (target < 7) 354 | intersection = predict * (predict == target).float() 355 | #intersection = predict * (predict == target) 356 | # # areas of intersection and union 357 | # element 0 in intersection marks the main difference from np.bincount; setting the boundary to -1 is necessary. 358 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi) 359 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi) 360 | target = target.float() 361 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi) 362 | area_union = area_pred + area_lab - area_inter 363 | assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area" 364 | #return area_inter.float(), area_union.float() 365 | #iou_main = area_inter.float() / area_union.float() 366 | #iou_main = area_inter / area_union 367 | 368 | #iou_main = area_inter.float() / area_union.float() 369 | iou_main = 1.0 * np.sum(area_inter.float().numpy(), axis=0) / np.sum(np.spacing(1)+area_union.float().numpy(), axis=0) 370 | 371 | # mini = 1 372 | # #nclass = 6 373 | # nclass = 5 374 | # maxi = nclass 375 | # nbins = nclass 376 | # predict = predictions + 1 377 | # target = gts + 1 378 | 379 | # predict = torch.from_numpy(np.array(predict)) 380 | # target = torch.from_numpy(np.array(target)) 381 | 382 | # #img = img.to(torch.float32) 383 | 384 | # #predict = predict.float() * (target < 7).float() 385 | # predict = predict.float() * (target < 6).float() 386 | # #predict = predict * (target < 7) 387 | # intersection = predict * (predict == target).float() 388 | # #intersection = predict * (predict == target) 389 | # # # areas of intersection and union 390 | # # element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary. 
391 | # area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi) 392 | # area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi) 393 | # target = target.float() 394 | # area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi) 395 | # area_union = area_pred + area_lab - area_inter 396 | # assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area" 397 | # #return area_inter.float(), area_union.float() 398 | # #iou_main = area_inter.float() / area_union.float() 399 | # #iou_main = area_inter / area_union 400 | 401 | # #iou_main = area_inter.float() / area_union.float() 402 | # iou_main = 1.0 * np.sum(area_inter.float().numpy(), axis=0) / np.sum(np.spacing(1)+area_union.float().numpy(), axis=0) 403 | 404 | #print(iou_main) 405 | 406 | #print('') 407 | 408 | print(iou_main) 409 | 410 | # #imPred = imPred * (imLab >= 0) 411 | # #numClass = 6 412 | # numClass = 6 + 1 413 | 414 | # #imPred = predictions 415 | # #imLab = gts 416 | 417 | # imPred = predictions[gts <= 5] + 1 418 | # imLab = gts[gts <= 5] + 1 419 | 420 | # #imPred = predictions[gts <= 5] 421 | # #imLab = gts[gts <= 5] 422 | 423 | # #imPred = imPred * (imLab <= 5) 424 | 425 | # imPred = imPred * (imLab <= 6) 426 | 427 | # # # Compute area intersection: 428 | # intersection = imPred * (imPred == imLab) 429 | # (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass)) 430 | 431 | # # Compute area union: 432 | # (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 433 | # (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 434 | # area_union = area_pred + area_lab - area_intersection 435 | # #return (area_intersection, area_union) 436 | # IoU = 1.0 * np.sum(area_intersection, axis=0) / np.sum(np.spacing(1)+area_union, axis=0) 437 | 438 | # print(IoU) 439 | 440 | #accuracy = iou_main 441 | 442 | # # # Compute kappa coefficient 443 | # total = np.sum(cm) 444 | # pa = np.trace(cm) / float(total) 445 | # pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total*total) 446 | # kappa = (pa - pe) / (1 - pe); 447 | # print("Kappa: " + str(kappa)) 448 | 449 | return accuracy 450 | 451 | # # Dataset class 452 | class ISPRS_dataset(torch.utils.data.Dataset): 453 | def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER, 454 | cache=False, augmentation=True): 455 | super(ISPRS_dataset, self).__init__() 456 | self.augmentation = augmentation 457 | self.cache = cache 458 | self.data_files = [DATA_FOLDER.format(id) for id in ids] 459 | self.label_files = [LABEL_FOLDER.format(id) for id in ids] 460 | for f in self.data_files + self.label_files: 461 | if not os.path.isfile(f): 462 | raise KeyError('{} is not a file !'.format(f)) 463 | self.data_cache_ = {} 464 | self.label_cache_ = {} 465 | 466 | def __len__(self): 467 | return 10000 468 | 469 | @classmethod 470 | def data_augmentation(cls, *arrays, flip=True, mirror=True): 471 | will_flip, will_mirror = False, False 472 | #will_rotate = False 473 | #will_rotate2 = False 474 | if flip and random.random() < 0.5: 475 | will_flip = True 476 | if mirror and random.random() < 0.5: 477 | will_mirror = True 478 | 479 | results = [] 480 | for array in arrays: 481 | if will_flip: 482 | if len(array.shape) == 2: 483 | array = array[::-1, :] 484 | else: 485 | array = array[:, ::-1, :] 486 | if will_mirror: 487 | if len(array.shape) == 2: 488 | array = array[:, ::-1] 489 | else: 490 | array = array[:, :, ::-1] 491 | 492 | 
results.append(np.copy(array)) 493 | 494 | return tuple(results) 495 | 496 | def __getitem__(self, i): 497 | random_idx = random.randint(0, len(self.data_files) - 1) 498 | if random_idx in self.data_cache_.keys(): 499 | data = self.data_cache_[random_idx] 500 | else: 501 | data = 1/255 * np.asarray(io.imread(self.data_files[random_idx]).transpose((2,0,1)), dtype='float32') 502 | if self.cache: 503 | self.data_cache_[random_idx] = data 504 | 505 | if random_idx in self.label_cache_.keys(): 506 | label = self.label_cache_[random_idx] 507 | else: 508 | label = np.asarray(convert_from_color(io.imread(self.label_files[random_idx])), dtype='int64') 509 | if self.cache: 510 | self.label_cache_[random_idx] = label 511 | 512 | x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE) 513 | data_p = data[:, x1:x2,y1:y2] 514 | label_p = label[x1:x2,y1:y2] 515 | 516 | data_p, label_p = self.data_augmentation(data_p, label_p) 517 | 518 | return (torch.from_numpy(data_p), 519 | torch.from_numpy(label_p)) 520 | 521 | from torch.autograd import Function 522 | class ReverseLayerF(Function): 523 | @staticmethod 524 | def forward(ctx, x, alpha): 525 | ctx.alpha = alpha 526 | return x.view_as(x) 527 | 528 | @staticmethod 529 | def backward(ctx, grad_output): 530 | output = grad_output.neg() * ctx.alpha 531 | return output, None 532 | 533 | class SegNet(nn.Module): 534 | @staticmethod 535 | def weight_init(m): 536 | if isinstance(m, nn.Linear): 537 | torch.nn.init.kaiming_normal(m.weight.data) 538 | 539 | def __init__(self, in_channels=IN_CHANNELS, out_channels=N_CLASSES): 540 | super(SegNet, self).__init__() 541 | self.pool = nn.MaxPool2d(2, return_indices=True) 542 | self.unpool = nn.MaxUnpool2d(2) 543 | 544 | self.conv1_1 = nn.Conv2d(in_channels, 64, 3, padding=1) 545 | self.conv1_1_bn = nn.BatchNorm2d(64) 546 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 547 | self.conv1_2_bn = nn.BatchNorm2d(64) 548 | 549 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 550 | self.conv2_1_bn = nn.BatchNorm2d(128) 551 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 552 | self.conv2_2_bn = nn.BatchNorm2d(128) 553 | 554 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 555 | self.conv3_1_bn = nn.BatchNorm2d(256) 556 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 557 | self.conv3_2_bn = nn.BatchNorm2d(256) 558 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 559 | self.conv3_3_bn = nn.BatchNorm2d(256) 560 | 561 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 562 | self.conv4_1_bn = nn.BatchNorm2d(512) 563 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 564 | self.conv4_2_bn = nn.BatchNorm2d(512) 565 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 566 | self.conv4_3_bn = nn.BatchNorm2d(512) 567 | 568 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 569 | self.conv5_1_bn = nn.BatchNorm2d(512) 570 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 571 | self.conv5_2_bn = nn.BatchNorm2d(512) 572 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 573 | self.conv5_3_bn = nn.BatchNorm2d(512) 574 | 575 | self.ll00 = nn.Linear(512, 3) 576 | 577 | self.conv5_3_D = nn.Conv2d(512, 512, 3, padding=1) 578 | self.conv5_3_D_bn = nn.BatchNorm2d(512) 579 | self.conv5_2_D = nn.Conv2d(512, 512, 3, padding=1) 580 | self.conv5_2_D_bn = nn.BatchNorm2d(512) 581 | self.conv5_1_D = nn.Conv2d(512, 512, 3, padding=1) 582 | self.conv5_1_D_bn = nn.BatchNorm2d(512) 583 | 584 | self.conv4_3_D = nn.Conv2d(512, 512, 3, padding=1) 585 | self.conv4_3_D_bn = nn.BatchNorm2d(512) 586 | self.conv4_2_D = nn.Conv2d(512, 512, 3, padding=1) 587 | 
self.conv4_2_D_bn = nn.BatchNorm2d(512) 588 | self.conv4_1_D = nn.Conv2d(512, 256, 3, padding=1) 589 | self.conv4_1_D_bn = nn.BatchNorm2d(256) 590 | 591 | self.conv3_3_D = nn.Conv2d(256, 256, 3, padding=1) 592 | self.conv3_3_D_bn = nn.BatchNorm2d(256) 593 | self.conv3_2_D = nn.Conv2d(256, 256, 3, padding=1) 594 | self.conv3_2_D_bn = nn.BatchNorm2d(256) 595 | self.conv3_1_D = nn.Conv2d(256, 128, 3, padding=1) 596 | self.conv3_1_D_bn = nn.BatchNorm2d(128) 597 | 598 | self.conv2_2_D = nn.Conv2d(128, 128, 3, padding=1) 599 | self.conv2_2_D_bn = nn.BatchNorm2d(128) 600 | self.conv2_1_D = nn.Conv2d(128, 64, 3, padding=1) 601 | self.conv2_1_D_bn = nn.BatchNorm2d(64) 602 | 603 | self.conv1_2_D = nn.Conv2d(64, 64, 3, padding=1) 604 | self.conv1_2_D_bn = nn.BatchNorm2d(64) 605 | self.conv1_1_D = nn.Conv2d(64, out_channels, 3, padding=1) 606 | 607 | self.apply(self.weight_init) 608 | 609 | def forward(self, x, alpha): 610 | x = self.conv1_1_bn(F.relu(self.conv1_1(x))) 611 | x = self.conv1_2_bn(F.relu(self.conv1_2(x))) 612 | x, mask1 = self.pool(x) 613 | 614 | # Encoder block 2 615 | x = self.conv2_1_bn(F.relu(self.conv2_1(x))) 616 | x = self.conv2_2_bn(F.relu(self.conv2_2(x))) 617 | x, mask2 = self.pool(x) 618 | 619 | # Encoder block 3 620 | x = self.conv3_1_bn(F.relu(self.conv3_1(x))) 621 | x = self.conv3_2_bn(F.relu(self.conv3_2(x))) 622 | x = self.conv3_3_bn(F.relu(self.conv3_3(x))) 623 | x, mask3 = self.pool(x) 624 | 625 | # Encoder block 4 626 | x = self.conv4_1_bn(F.relu(self.conv4_1(x))) 627 | x = self.conv4_2_bn(F.relu(self.conv4_2(x))) 628 | x = self.conv4_3_bn(F.relu(self.conv4_3(x))) 629 | x, mask4 = self.pool(x) 630 | 631 | # Encoder block 5 632 | x = self.conv5_1_bn(F.relu(self.conv5_1(x))) 633 | x = self.conv5_2_bn(F.relu(self.conv5_2(x))) 634 | x = self.conv5_3_bn(F.relu(self.conv5_3(x))) 635 | x, mask5 = self.pool(x) 636 | 637 | #print(x) 638 | #print(x.shape) 639 | 640 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 641 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(2*BATCH_SIZE, -1) 642 | xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 643 | 644 | #xx22 = self.ll00(xx22) 645 | 646 | xx22 = ReverseLayerF.apply(xx22, alpha) 647 | 648 | xx22 = self.ll00(xx22) 649 | 650 | xx22 = F.softmax(xx22) 651 | 652 | # Decoder block 5 653 | x = self.unpool(x, mask5) 654 | x = self.conv5_3_D_bn(F.relu(self.conv5_3_D(x))) 655 | x = self.conv5_2_D_bn(F.relu(self.conv5_2_D(x))) 656 | x = self.conv5_1_D_bn(F.relu(self.conv5_1_D(x))) 657 | 658 | # Decoder block 4 659 | x = self.unpool(x, mask4) 660 | x = self.conv4_3_D_bn(F.relu(self.conv4_3_D(x))) 661 | x = self.conv4_2_D_bn(F.relu(self.conv4_2_D(x))) 662 | x = self.conv4_1_D_bn(F.relu(self.conv4_1_D(x))) 663 | 664 | # Decoder block 3 665 | x = self.unpool(x, mask3) 666 | x = self.conv3_3_D_bn(F.relu(self.conv3_3_D(x))) 667 | x = self.conv3_2_D_bn(F.relu(self.conv3_2_D(x))) 668 | x = self.conv3_1_D_bn(F.relu(self.conv3_1_D(x))) 669 | 670 | # # # Decoder block 2 671 | x = self.unpool(x, mask2) 672 | x = self.conv2_2_D_bn(F.relu(self.conv2_2_D(x))) 673 | x = self.conv2_1_D_bn(F.relu(self.conv2_1_D(x))) 674 | 675 | # Decoder block 1 676 | x = self.unpool(x, mask1) 677 | x = self.conv1_2_D_bn(F.relu(self.conv1_2_D(x))) 678 | x = F.log_softmax(self.conv1_1_D(x)) 679 | #return x 680 | return x, xx22 681 | 682 | def forward2(self, x): 683 | x = self.conv1_1_bn(F.relu(self.conv1_1(x))) 684 | x = self.conv1_2_bn(F.relu(self.conv1_2(x))) 685 | x, mask1 = self.pool(x) 686 | 687 | # Encoder block 2 688 | x = 
self.conv2_1_bn(F.relu(self.conv2_1(x))) 689 | x = self.conv2_2_bn(F.relu(self.conv2_2(x))) 690 | x, mask2 = self.pool(x) 691 | 692 | # Encoder block 3 693 | x = self.conv3_1_bn(F.relu(self.conv3_1(x))) 694 | x = self.conv3_2_bn(F.relu(self.conv3_2(x))) 695 | x = self.conv3_3_bn(F.relu(self.conv3_3(x))) 696 | x, mask3 = self.pool(x) 697 | 698 | # Encoder block 4 699 | x = self.conv4_1_bn(F.relu(self.conv4_1(x))) 700 | x = self.conv4_2_bn(F.relu(self.conv4_2(x))) 701 | x = self.conv4_3_bn(F.relu(self.conv4_3(x))) 702 | x, mask4 = self.pool(x) 703 | 704 | # Encoder block 5 705 | x = self.conv5_1_bn(F.relu(self.conv5_1(x))) 706 | x = self.conv5_2_bn(F.relu(self.conv5_2(x))) 707 | x = self.conv5_3_bn(F.relu(self.conv5_3(x))) 708 | x, mask5 = self.pool(x) 709 | 710 | #print(x) 711 | #print(x.shape) 712 | 713 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 714 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(2*BATCH_SIZE, -1) 715 | #xx22 = F.adaptive_avg_pool2d(x, 1).reshape(BATCH_SIZE, -1) 716 | 717 | #xx22 = self.ll00(xx22) 718 | 719 | #xx22 = ReverseLayerF.apply(xx22, alpha) 720 | 721 | #xx22 = self.ll00(xx22) 722 | 723 | #xx22 = F.softmax(xx22) 724 | 725 | # Decoder block 5 726 | x = self.unpool(x, mask5) 727 | x = self.conv5_3_D_bn(F.relu(self.conv5_3_D(x))) 728 | x = self.conv5_2_D_bn(F.relu(self.conv5_2_D(x))) 729 | x = self.conv5_1_D_bn(F.relu(self.conv5_1_D(x))) 730 | 731 | # Decoder block 4 732 | x = self.unpool(x, mask4) 733 | x = self.conv4_3_D_bn(F.relu(self.conv4_3_D(x))) 734 | x = self.conv4_2_D_bn(F.relu(self.conv4_2_D(x))) 735 | x = self.conv4_1_D_bn(F.relu(self.conv4_1_D(x))) 736 | 737 | # Decoder block 3 738 | x = self.unpool(x, mask3) 739 | x = self.conv3_3_D_bn(F.relu(self.conv3_3_D(x))) 740 | x = self.conv3_2_D_bn(F.relu(self.conv3_2_D(x))) 741 | x = self.conv3_1_D_bn(F.relu(self.conv3_1_D(x))) 742 | 743 | # # # Decoder block 2 744 | x = self.unpool(x, mask2) 745 | x = self.conv2_2_D_bn(F.relu(self.conv2_2_D(x))) 746 | x = self.conv2_1_D_bn(F.relu(self.conv2_1_D(x))) 747 | 748 | # Decoder block 1 749 | x = self.unpool(x, mask1) 750 | x = self.conv1_2_D_bn(F.relu(self.conv1_2_D(x))) 751 | x = F.log_softmax(self.conv1_1_D(x)) 752 | return x 753 | #return x, xx22 754 | 755 | net = SegNet() 756 | 757 | import os 758 | try: 759 | from urllib.request import URLopener 760 | except ImportError: 761 | from urllib import URLopener 762 | 763 | # # # Download VGG-16 weights from PyTorch 764 | vgg_url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth' 765 | if not os.path.isfile('./vgg16_bn-6c64b313.pth'): 766 | weights = URLopener().retrieve(vgg_url, './vgg16_bn-6c64b313.pth') 767 | 768 | vgg16_weights = torch.load('./vgg16_bn-6c64b313.pth') 769 | mapped_weights = {} 770 | for k_vgg, k_segnet in zip(vgg16_weights.keys(), net.state_dict().keys()): 771 | if "features" in k_vgg: 772 | mapped_weights[k_segnet] = vgg16_weights[k_vgg] 773 | print("Mapping {} to {}".format(k_vgg, k_segnet)) 774 | 775 | try: 776 | net.load_state_dict(mapped_weights) 777 | print("Loaded VGG-16 weights in SegNet !") 778 | except: 779 | # # Ignore missing keys 780 | pass 781 | 782 | # # Then, we load the network on GPU. 
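# --- Added illustrative sketch (annotation, not part of the original script) ---------------------
# The SegNet defined above exposes a DANN-style branch: forward(x, alpha) routes a pooled encoder
# feature through ReverseLayerF (which flips and scales gradients by alpha on the backward pass)
# and a 3-way domain head (self.ll00 followed by softmax). A minimal training step using that
# branch could look as follows. The second (target-domain) batch, the domain-label convention
# (0 = source, 1 = target) and the weight lambda_domain are assumptions made for this sketch only.
# Note that forward() reshapes the pooled feature with BATCH_SIZE, so every batch fed to it must
# contain exactly BATCH_SIZE samples.
def domain_adaptation_step(net, optimizer, src_imgs, src_masks, tgt_imgs, alpha, lambda_domain=0.1):
    seg_criterion = nn.NLLLoss()  # forward() ends with log_softmax, so an NLL loss applies directly
    optimizer.zero_grad()
    # Source domain: supervised segmentation loss plus the adversarial domain loss (domain id 0)
    seg_out, dom_src = net(src_imgs, alpha)
    src_domain = torch.zeros(dom_src.size(0), dtype=torch.long, device=src_imgs.device)
    loss = seg_criterion(seg_out, src_masks)
    loss = loss + lambda_domain * F.nll_loss(torch.log(dom_src + 1e-8), src_domain)
    # Target domain (no masks): only the domain loss; the reversed gradient pushes the encoder
    # towards domain-invariant features (domain id 1)
    _, dom_tgt = net(tgt_imgs, alpha)
    tgt_domain = torch.ones(dom_tgt.size(0), dtype=torch.long, device=tgt_imgs.device)
    loss = loss + lambda_domain * F.nll_loss(torch.log(dom_tgt + 1e-8), tgt_domain)
    loss.backward()
    optimizer.step()
    return loss.item()
# --------------------------------------------------------------------------------------------------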
783 | #net.cuda() 784 | 785 | # For SegFormer 786 | # We modify: https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/segformer/modeling_segformer.py#L746 787 | 788 | # The model SegFormer 789 | # # This is to test: torch.save(model.state_dict(), './segformermain30082023') 790 | # # Also: data_dir = '../../CVUSA/bingmap/mainTheDataset' 791 | 792 | # # We have modified: https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/segformer/modeling_segformer.py#L746 793 | #from transformers import SegformerForSemanticSegmentation 794 | 795 | from transformers import SegformerForSemanticSegmentation 796 | #import transformers 797 | 798 | id2label = {0 : 'Impervious surfaces', 799 | 1 : 'Buildings', 800 | 2 : 'Low vegetation', 801 | 3 : 'Trees', 802 | 4 : 'Cars', 803 | 5 : 'Clutter'} 804 | label2id = {v: k for k, v in id2label.items()} 805 | 806 | # net = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", 807 | # num_labels=6, 808 | # id2label=id2label, 809 | # label2id=label2id, 810 | # alpha=0.0, 811 | # ) 812 | # net = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", 813 | # num_labels=6, 814 | # id2label=id2label, 815 | # label2id=label2id, 816 | # alpha=0.0, 817 | # ) 818 | 819 | # Use B5 820 | net = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", 821 | num_labels=6, 822 | id2label=id2label, 823 | label2id=label2id, 824 | alpha=0.0, 825 | ) 826 | # # We use b5 827 | 828 | #net = SegNet() 829 | #net.cuda() 830 | net.cuda() 831 | 832 | # ### Loading the data 833 | # # # Load the datasets 834 | if DATASET == 'Potsdam': 835 | all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*'))) 836 | all_ids = ["".join(f.split('')[5:7]) for f in all_files] 837 | elif DATASET == 'Vaihingen': 838 | #all_ids = 839 | all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*'))) 840 | all_ids = [f.split('area')[-1].split('.')[0] for f in all_files] 841 | test_ids = list(set(all_ids) - set(train_ids)) 842 | print("Tiles for training : ", train_ids) 843 | print("Tiles for testing : ", test_ids) 844 | train_set = ISPRS_dataset(train_ids, cache=CACHE) 845 | train_loader = torch.utils.data.DataLoader(train_set,batch_size=BATCH_SIZE) 846 | 847 | # The optimizer 848 | #base_lr = 0.01 849 | #base_lr = 0.04 850 | #base_lr = 0.08 851 | base_lr = 0.01 852 | 853 | params_dict = dict(net.named_parameters()) 854 | params = [] 855 | for key, value in params_dict.items(): 856 | if '_D' in key: 857 | params += [{'params':[value],'lr': base_lr}] 858 | else: 859 | params += [{'params':[value],'lr': base_lr / 2}] 860 | 861 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0005) 862 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 35, 45], gamma=0.1) 863 | 864 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 865 | 866 | def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE): 867 | # # # # # /home/nikolaos/CVUSA/bingmap 868 | 869 | test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids) 870 | test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(id)), dtype='uint8') for id in test_ids) 871 | eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(id))) for id in test_ids) 872 | 873 | # # # # # /home/nikolaos/CVUSA/bingmap 874 | #print(test_images) 875 | #print(test_images.size) 876 | 877 | all_preds = [] 878 | all_gts = [] 879 | 880 | # # # Switch the 
network to inference mode 881 | #net.eval() 882 | 883 | net.eval() 884 | for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False): 885 | #pred = np.zeros(img.shape[:2] + (N_CLASSES,)) 886 | 887 | # fig = plt.figure() 888 | # plt.imshow(np.asarray(255 * img, dtype='uint8')) 889 | # plt.axis('off') 890 | # plt.show() 891 | # print(img) 892 | 893 | #print(img.shape) 894 | 895 | #from PIL import Image 896 | #img2 = Image.open('/home/nikolaos/CVUSA/bingmap/0000001.jpg') 897 | 898 | #print(DATA_FOLDER.format(id)) 899 | 900 | #img2 = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids) 901 | 902 | #img2 = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/0000001.jpg'), dtype='float32') for id in test_ids) 903 | 904 | # fig = plt.figure() 905 | # #plt.imshow(np.asarray(255 * img2, dtype='uint8')) 906 | # #plt.imshow(img2) 907 | # plt.imshow(np.asarray(255 * img2, dtype='uint8')) 908 | # plt.axis('off') 909 | # plt.show() 910 | 911 | # print(img2) 912 | 913 | #print(img2.shape) 914 | 915 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/0000001.jpg'), dtype='float32')) 916 | 917 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/0000017.jpg'), dtype='float32')) 918 | 919 | countertotal = 0 920 | foraverage = 0 921 | foraverage2 = 0 922 | foraverage3 = 0 923 | 924 | #data_dir = '/Data/ndionelis/bingmap/main_dataset/val/1' 925 | 926 | data_dir = '/Data/ndionelis/bingmap/mainTheDatasetNoSubsa2/val/1' 927 | 928 | # # This is to test: torch.save(model.state_dict(), './segformermain30082023') 929 | # # Also: data_dir = '../../CVUSA/bingmap/mainTheDataset' 930 | 931 | for file in os.listdir(data_dir): 932 | img = (1 / 255 * np.asarray(io.imread(os.path.join(data_dir, file)), dtype='float32')) 933 | 934 | countertotal += 1 935 | counterr11 = file 936 | #print(counterr11) 937 | 938 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/'+str(44515).zfill(7)+'.jpg'), dtype='float32')) 939 | 940 | pred = np.zeros(img.shape[:2] + (N_CLASSES,)) 941 | 942 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/0000019.jpg'), dtype='float32')) 943 | total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size 944 | for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total, leave=False)): 945 | # # # # # Display in progress results 946 | if i > 0 and total > 10 and i % int(10 * total / 100) == 0: 947 | _pred = np.argmax(pred, axis=-1) 948 | # fig = plt.figure() 949 | # #fig.add_subplot(1,3,1) 950 | # fig.add_subplot(1,2,1) 951 | # plt.imshow(np.asarray(255 * img, dtype='uint8')) 952 | # plt.axis('off') 953 | # #fig.add_subplot(1,3,2) 954 | # fig.add_subplot(1,2,2) 955 | # plt.imshow(convert_to_color(_pred)) 956 | # plt.axis('off') 957 | # #fig.add_subplot(1,3,3) 958 | # #plt.imshow(gt) 959 | # #plt.axis('off') 960 | # #clear_output() 961 | # plt.show() 962 | 963 | # # # # # Build the tensor 964 | image_patches = [np.copy(img[x:x+w, y:y+h]).transpose((2,0,1)) for x,y,w,h in coords] 965 | image_patches = np.asarray(image_patches) 966 | #image_patches = Variable(torch.from_numpy(image_patches).cuda(), volatile=True) 967 | with torch.no_grad(): 968 | #image_patches = Variable(torch.from_numpy(image_patches).cuda(), volatile=True) 969 | image_patches = Variable(torch.from_numpy(image_patches).cuda()) 970 | 971 | # # /home/nikolaos/CVUSA/bingmsap 972 
| 973 | #encoding = feature_extractor(image_patches, return_tensors="pt") 974 | #pixel_values = encoding.pixel_values.to(device) 975 | 976 | pixel_values = image_patches.to(device) 977 | 978 | outputs = net(pixel_values=pixel_values) 979 | 980 | logits = outputs.logits.cpu() 981 | 982 | upsampled_logits = nn.functional.interpolate(logits, 983 | size=image_patches.shape[3], # # (height, width) 984 | mode='bilinear', 985 | align_corners=False) 986 | 987 | #seg = upsampled_logits.argmax(dim=1)[0] 988 | 989 | outs = upsampled_logits.detach().cpu().numpy() 990 | 991 | #print(image_patches.shape) 992 | #print(image_patches) 993 | 994 | # # /home/nikolaos/CVUSA/bingmsap 995 | 996 | # # # # # Do the inference 997 | #outs = net(image_patches) 998 | #outs, _ = net(image_patches, 0.01) 999 | #outs = net.forward2(image_patches) 1000 | 1001 | #outs = net.forward2(image_patches) 1002 | 1003 | #outputs = model(pixel_values=pixel_values, labels=labels) 1004 | 1005 | #outs = outs.data.cpu().numpy() 1006 | 1007 | # # # Fill in the results array 1008 | for out, (x, y, w, h) in zip(outs, coords): 1009 | out = out.transpose((1,2,0)) 1010 | pred[x:x+w, y:y+h] += out 1011 | del(outs) 1012 | 1013 | #break 1014 | 1015 | pred = np.argmax(pred, axis=-1) 1016 | 1017 | # # # # # Display the result 1018 | #clear_output() 1019 | fig = plt.figure() 1020 | #fig.add_subplot(1,3,1) 1021 | #fig.add_subplot(1,2,1) 1022 | plt.imshow(np.asarray(255 * img, dtype='uint8')) 1023 | plt.axis('off') 1024 | #plt.show() 1025 | #fig.add_subplot(1,3,2) 1026 | #fig.add_subplot(1,2,2) 1027 | print(counterr11) 1028 | print(countertotal) 1029 | counterr11 = counterr11[:-4] 1030 | #print(counterr11) 1031 | try: 1032 | plt.savefig('./inputs/input'+str(counterr11)+'.png', bbox_inches='tight') 1033 | except: 1034 | os.mkdir('inputs') 1035 | plt.savefig('./inputs/input'+str(counterr11)+'.png', bbox_inches='tight') 1036 | #plt.savefig('inputim1.png', bbox_inches='tight') 1037 | fig2 = plt.figure() 1038 | plt.imshow(convert_to_color(pred)) 1039 | plt.axis('off') 1040 | #plt.show() 1041 | #fig.add_subplot(1,3,3) 1042 | #plt.imshow(gt) 1043 | #plt.axis('off') 1044 | #plt.show() 1045 | #plt.pause(10) 1046 | try: 1047 | plt.savefig('./outputs/output'+str(counterr11)+'.png', bbox_inches='tight') 1048 | except: 1049 | os.mkdir('outputs') 1050 | plt.savefig('./outputs/output'+str(counterr11)+'.png', bbox_inches='tight') 1051 | #plt.savefig('outputim1.png', bbox_inches='tight') 1052 | 1053 | all_preds.append(pred) 1054 | #all_gts.append(gt_e) 1055 | 1056 | import pytorch_ssim 1057 | #import torch 1058 | #from torch.autograd import Variable 1059 | 1060 | #img1 = Variable(torch.rand(1, 1, 256, 256)) 1061 | #img2 = Variable(torch.rand(1, 1, 256, 256)) 1062 | 1063 | from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM 1064 | # X: (N,3,H,W) a batch of non-negative RGB images (0~255) 1065 | # Y: (N,3,H,W) 1066 | 1067 | fig3 = plt.figure() 1068 | plt.imshow(convert_to_color(pred)) 1069 | plt.axis('off') 1070 | #plt.show() 1071 | #fig.add_subplot(1,3,3) 1072 | #plt.imshow(gt) 1073 | #plt.axis('off') 1074 | #plt.show() 1075 | #plt.pause(10) 1076 | #plt.figtext(0.5, 0.01, str(loss.item())+' '+str(ssim_val.squeeze().item())+' '+str(iou_pytorch(a, b).item()), wrap=True, horizontalalignment='center', fontsize=12) 1077 | #"{:.2f}".format(round(a, 2)) 1078 | plt.figtext(0.5, 0.01, "RMSE: "+str("{:.4f}".format(loss.item()))+', SSIM: '+str("{:.4f}".format(ssim_val.squeeze().item()))+', IoU: '+str("{:.4f}".format(iou_pytorch(a, b).item())), wrap=True, 
horizontalalignment='center', fontsize=12) 1079 | try: 1080 | plt.savefig('./outputs/output'+str(counterr11)+'b.png', bbox_inches='tight') 1081 | except: 1082 | os.mkdir('outputs') 1083 | plt.savefig('./outputs/output'+str(counterr11)+'b.png', bbox_inches='tight') 1084 | 1085 | plt.close('all') 1086 | 1087 | clear_output() 1088 | torch.cuda.empty_cache() 1089 | 1090 | time.sleep(5) 1091 | 1092 | #clear_output() 1093 | 1094 | #metrics(pred.ravel(), gt_e.ravel()) 1095 | #accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]), np.concatenate([p.ravel() for p in all_gts]).ravel()) 1096 | 1097 | #break 1098 | 1099 | foraverage /= countertotal 1100 | foraverage2 /= countertotal 1101 | foraverage3 /= countertotal 1102 | 1103 | print(foraverage) 1104 | print(foraverage2) 1105 | print(foraverage3) 1106 | 1107 | #clear_output() 1108 | 1109 | #break 1110 | 1111 | if all: 1112 | return accuracy, all_preds, all_gts 1113 | else: 1114 | return accuracy 1115 | 1116 | from IPython.display import clear_output 1117 | 1118 | def train(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch = 5): 1119 | losses = np.zeros(1000000) 1120 | mean_losses = np.zeros(100000000) 1121 | weights = weights.cuda() 1122 | 1123 | criterion = nn.NLLLoss2d(weight=weights) 1124 | 1125 | iter_ = 0 1126 | 1127 | #train(net, optimizer, 10, scheduler) 1128 | 1129 | #net.load_state_dict(torch.load('./segformermain30082023')) 1130 | #net.load_state_dict(torch.load('./segnet_finale')) 1131 | 1132 | #net.load_state_dict(torch.load('./segformermain300820232')) 1133 | 1134 | net.load_state_dict(torch.load('/Data/ndionelis/formodels/srmr2ehw4h44')) 1135 | 1136 | #DATASET = 'Vaihingen' 1137 | DATASET = 'Potsdam' 1138 | 1139 | if DATASET == 'Potsdam': 1140 | MAIN_FOLDER = FOLDER + 'Potsdam/' 1141 | # Uncomment the next line for IRRG data 1142 | # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 1143 | # For RGB data 1144 | DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 1145 | LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 1146 | ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 1147 | elif DATASET == 'Vaihingen': 1148 | MAIN_FOLDER = FOLDER + 'Vaihingen/' 1149 | DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 1150 | LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 1151 | ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 1152 | 1153 | print("Tiles for testing : ", test_ids) 1154 | 1155 | def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE): 1156 | test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids) 1157 | 1158 | test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(id)), dtype='uint8') for id in test_ids) 1159 | eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(id))) for id in test_ids) 1160 | all_preds = [] 1161 | all_gts = [] 1162 | 1163 | net.eval() 1164 | for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False): 1165 | pred = np.zeros(img.shape[:2] + (N_CLASSES,)) 1166 | 1167 | total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size 1168 | for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total, leave=False)): 1169 | if i > 0 and 
total > 10 and i % int(10 * total / 100) == 0: 1170 | _pred = np.argmax(pred, axis=-1) 1171 | fig = plt.figure() 1172 | fig.add_subplot(1,3,1) 1173 | plt.imshow(np.asarray(255 * img, dtype='uint8')) 1174 | fig.add_subplot(1,3,2) 1175 | plt.imshow(convert_to_color(_pred)) 1176 | fig.add_subplot(1,3,3) 1177 | plt.imshow(gt) 1178 | clear_output() 1179 | plt.show() 1180 | 1181 | image_patches = [np.copy(img[x:x+w, y:y+h]).transpose((2,0,1)) for x,y,w,h in coords] 1182 | image_patches = np.asarray(image_patches) 1183 | #image_patches = Variable(torch.from_numpy(image_patches).cuda(), volatile=True) 1184 | 1185 | with torch.no_grad(): 1186 | image_patches = Variable(torch.from_numpy(image_patches).cuda()) 1187 | 1188 | # # Do the inference 1189 | # outs = net(image_patches) 1190 | # outs = outs.data.cpu().numpy() 1191 | 1192 | pixel_values = image_patches.to(device) 1193 | outputs = net(pixel_values=pixel_values) 1194 | 1195 | logits = outputs.logits.cpu() 1196 | upsampled_logits = nn.functional.interpolate(logits, 1197 | size=image_patches.shape[3], # # (height, width) 1198 | mode='bilinear', 1199 | align_corners=False) 1200 | 1201 | #print(upsampled_logits.shape) 1202 | # # torch.Size([10, 150, 256, 256]) 1203 | 1204 | #seg = upsampled_logits.argmax(dim=1)[0] 1205 | 1206 | outs = upsampled_logits.detach().cpu().numpy() 1207 | 1208 | for out, (x, y, w, h) in zip(outs, coords): 1209 | out = out.transpose((1,2,0)) 1210 | pred[x:x+w, y:y+h] += out 1211 | del(outs) 1212 | 1213 | pred = np.argmax(pred, axis=-1) 1214 | 1215 | clear_output() 1216 | 1217 | fig = plt.figure() 1218 | #fig.add_subplot(1,3,1) 1219 | plt.imshow(np.asarray(255 * img, dtype='uint8')) 1220 | plt.axis('off') 1221 | plt.savefig('Input1.png', bbox_inches='tight') 1222 | 1223 | fig = plt.figure() 1224 | #fig.add_subplot(1,3,2) 1225 | plt.imshow(convert_to_color(pred)) 1226 | plt.axis('off') 1227 | plt.savefig('Output1.png', bbox_inches='tight') 1228 | 1229 | fig = plt.figure() 1230 | plt.imshow(gt) 1231 | plt.axis('off') 1232 | plt.savefig('Correct1.png', bbox_inches='tight') 1233 | 1234 | all_preds.append(pred) 1235 | all_gts.append(gt_e) 1236 | 1237 | clear_output() 1238 | 1239 | metrics(pred.ravel(), gt_e.ravel()) 1240 | accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]), np.concatenate([p.ravel() for p in all_gts]).ravel()) 1241 | 1242 | if all: 1243 | return accuracy, all_preds, all_gts 1244 | else: 1245 | return accuracy 1246 | 1247 | #_, all_preds, all_gts = test(net, test_ids, all=True, stride=32) 1248 | 1249 | accAccuracy, all_preds, all_gts = test(net, test_ids, all=True, stride=32) 1250 | 1251 | print('') 1252 | #print(accAccuracy) 1253 | 1254 | print(accAccuracy) 1255 | 1256 | for p, id_ in zip(all_gts, test_ids): 1257 | img = convert_to_color(p) 1258 | plt.imshow(img) and plt.show() 1259 | #io.imsave('./inference_tile_{}.png'.format(id_), img) 1260 | #io.imsave('./testing2_tile2_{}.png'.format(id_), img) 1261 | #io.imsave('./tst1_tl1_{}.png'.format(id_), img) 1262 | #io.imsave('./test2_tl2_{}.png'.format(id_), img) 1263 | io.imsave('./testt2_ttll2_{}.png'.format(id_), img) 1264 | 1265 | #_, all_preds, all_gts = test(net, test_ids, all=True, stride=32) 1266 | 1267 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Define PyTorch dataset and dataloaders 5 | from torch.utils.data import Dataset 6 | import os 7 | from PIL 
import Image 8 | import torch 9 | import numpy as np 10 | import random 11 | 12 | # We set the random seed 13 | #SEED = 17 14 | #SEED = 71 15 | SEED = random.randint(1, 10000) 16 | print(SEED) 17 | torch.cuda.empty_cache() 18 | random.seed(SEED) 19 | np.random.seed(SEED) 20 | torch.manual_seed(SEED) 21 | torch.cuda.manual_seed_all(SEED) 22 | 23 | #NUMWORKERS = 6 24 | NUMWORKERS = 0 25 | 26 | # We have modified: https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/segformer/modeling_segformer.py#L746 27 | 28 | import requests, zipfile, io 29 | def download_data(): 30 | url = "https://www.dropbox.com/s/l1e45oht447053f/ADE20k_toy_dataset.zip?dl=1" 31 | r = requests.get(url) 32 | z = zipfile.ZipFile(io.BytesIO(r.content)) 33 | z.extractall() 34 | download_data() 35 | from datasets import load_dataset 36 | load_entire_dataset = False 37 | if load_entire_dataset: 38 | dataset = load_dataset("scene_parse_150") 39 | 40 | class SemanticSegmentationDataset(Dataset): 41 | def __init__(self, root_dir, feature_extractor, train=True): 42 | self.root_dir = root_dir 43 | self.feature_extractor = feature_extractor 44 | self.train = train 45 | sub_path = "training" if self.train else "validation" 46 | self.img_dir = os.path.join(self.root_dir, "images", sub_path) 47 | self.ann_dir = os.path.join(self.root_dir, "annotations", sub_path) 48 | image_file_names = [] 49 | for root, dirs, files in os.walk(self.img_dir): 50 | image_file_names.extend(files) 51 | self.images = sorted(image_file_names) 52 | annotation_file_names = [] 53 | for root, dirs, files in os.walk(self.ann_dir): 54 | annotation_file_names.extend(files) 55 | self.annotations = sorted(annotation_file_names) 56 | assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps" 57 | 58 | def __len__(self): 59 | return len(self.images) 60 | 61 | def __getitem__(self, idx): 62 | image = Image.open(os.path.join(self.img_dir, self.images[idx])) 63 | segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx])) 64 | encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt") 65 | for k,v in encoded_inputs.items(): 66 | encoded_inputs[k].squeeze_() 67 | return encoded_inputs 68 | 69 | from transformers import SegformerFeatureExtractor 70 | #import transformers 71 | root_dir = './ADE20k_toy_dataset' 72 | feature_extractor = SegformerFeatureExtractor(reduce_labels=True) 73 | train_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor) 74 | valid_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor, train=False) 75 | 76 | print("Number of training examples:", len(train_dataset)) 77 | print("Number of validation examples:", len(valid_dataset)) 78 | 79 | encoded_inputs = train_dataset[0] 80 | encoded_inputs["pixel_values"].shape 81 | encoded_inputs["labels"].shape 82 | encoded_inputs["labels"] 83 | encoded_inputs["labels"].squeeze().unique() 84 | 85 | # Define corresponding dataloaders 86 | from torch.utils.data import DataLoader 87 | train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True) 88 | valid_dataloader = DataLoader(valid_dataset, batch_size=2) 89 | batch = next(iter(train_dataloader)) 90 | for k,v in batch.items(): 91 | print(k, v.shape) 92 | batch["labels"].shape 93 | 94 | mask = (batch["labels"] != 255) 95 | mask 96 | 97 | batch["labels"][mask] 98 | 99 | # We define the model 100 | from transformers import 
SegformerForSemanticSegmentation 101 | import json 102 | from huggingface_hub import cached_download, hf_hub_url 103 | 104 | # # # load id2label mapping from a JSON on the hub 105 | # repo_id = "datasets/huggingface/label-files" 106 | # filename = "ade20k-id2label.json" 107 | 108 | # id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r")) 109 | # id2label = {int(k): v for k, v in id2label.items()} 110 | # label2id = {v: k for k, v in id2label.items()} 111 | 112 | id2label = json.load(open("./ade20k-id2label.json", "r")) 113 | id2label = {int(k): v for k, v in id2label.items()} 114 | label2id = {v: k for k, v in id2label.items()} 115 | 116 | # Define our model 117 | # model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", 118 | # num_labels=6, 119 | # id2label=id2label, 120 | # label2id=label2id, 121 | # ) 122 | 123 | from datasets import load_metric 124 | metric = load_metric("mean_iou") 125 | 126 | #import torch 127 | from torch import nn 128 | from sklearn.metrics import accuracy_score 129 | from tqdm.notebook import tqdm 130 | from skimage import io 131 | import matplotlib.pyplot as plt 132 | from glob import glob 133 | from tqdm import tqdm_notebook as tqdm 134 | from sklearn.metrics import confusion_matrix 135 | #import random 136 | import itertools 137 | # # Matplotlib 138 | import matplotlib.pyplot as plt 139 | #get_ipython().run_line_magic('matplotlib', 'inline') 140 | #from IPython import get_ipython 141 | #get_ipython().run_line_magic('matplotlib', 'inline') 142 | #exec(%matplotlib inline) 143 | # # Torch imports 144 | import torch 145 | import torch.nn as nn 146 | import torch.nn.functional as F 147 | import torch.utils.data as data 148 | import torch.optim as optim 149 | import torch.optim.lr_scheduler 150 | import torch.nn.init 151 | from torch.autograd import Variable 152 | import torchvision.transforms as T 153 | import albumentations as A 154 | import segmentation_models_pytorch as smp 155 | import kornia 156 | 157 | # # Parameters 158 | WINDOW_SIZE = (256, 256) # # Patch size 159 | STRIDE = 32 # Stride for testing 160 | IN_CHANNELS = 3 # # Number of input channels (e.g. 
RGB) 161 | #FOLDER = "./ISPRS_dataset/" # Replace with your "/path/to/the/ISPRS/dataset/folder/" 162 | #FOLDER = "../../" 163 | FOLDER = "/Data/ndionelis/" 164 | #BATCH_SIZE = 64 165 | #BATCH_SIZE = 128 166 | BATCH_SIZE = 10 167 | #BATCH_SIZE = 30 168 | 169 | LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # # Label names 170 | N_CLASSES = len(LABELS) # # Number of classes 171 | #print(N_CLASSES) 172 | 173 | WEIGHTS = torch.ones(N_CLASSES) # # # Weights for class balancing 174 | CACHE = True # # # Store the dataset in-memory 175 | 176 | #DATASET = 'Vaihingen' 177 | DATASET = 'Potsdam' 178 | 179 | MAIN_FOLDER = FOLDER + 'Potsdam/' 180 | # # Uncomment the next line for IRRG data 181 | # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 182 | # # For RGB data 183 | #print(MAIN_FOLDER) 184 | DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 185 | #LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 186 | LABEL_FOLDER = MAIN_FOLDER + 'top_potsdam_{}_label.tif' 187 | #ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 188 | ERODED_FOLDER = MAIN_FOLDER + 'top_potsdam_{}_label_noBoundary.tif' 189 | 190 | # if DATASET == 'Potsdam': 191 | # MAIN_FOLDER = FOLDER + 'Potsdam/' 192 | # # # Uncomment the next line for IRRG data 193 | # # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 194 | # # # For RGB data 195 | # #print(MAIN_FOLDER) 196 | # #sadfszf 197 | # DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 198 | # #LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 199 | # LABEL_FOLDER = MAIN_FOLDER + 'top_potsdam_{}_label.tif' 200 | # #ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 201 | # ERODED_FOLDER = MAIN_FOLDER + 'top_potsdam_{}_label_noBoundary.tif' 202 | 203 | # elif DATASET == 'Vaihingen': 204 | # MAIN_FOLDER = FOLDER + 'Vaihingen/' 205 | # #print(MAIN_FOLDER) 206 | # #asdfszdf 207 | # DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 208 | # LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 209 | # ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 210 | 211 | # # ISPRS color palette 212 | # # # Let's define the standard ISPRS color palette 213 | palette = {0 : (255, 255, 255), # Impervious surfaces (white) 214 | 1 : (0, 0, 255), # # Buildings (blue) 215 | 2 : (0, 255, 255), # Low vegetation (cyan) 216 | 3 : (0, 255, 0), # # Trees (green) 217 | 4 : (255, 255, 0), # Cars (yellow) 218 | 5 : (255, 0, 0), # Clutter (red) 219 | 6 : (0, 0, 0)} # Undefined (black) 220 | invert_palette = {v: k for k, v in palette.items()} 221 | 222 | def convert_to_color(arr_2d, palette=palette): 223 | arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8) 224 | for c, i in palette.items(): 225 | m = arr_2d == c 226 | arr_3d[m] = i 227 | return arr_3d 228 | 229 | def convert_from_color(arr_3d, palette=invert_palette): 230 | arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) 231 | for c, i in palette.items(): 232 | m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) 233 | arr_2d[m] = i 234 | return arr_2d 235 | 236 | def get_random_pos(img, window_shape): 237 | w, h = window_shape 238 | W, H = img.shape[-2:] 239 | x1 = random.randint(0, W - w - 1) 240 | x2 = x1 + w 241 | y1 = random.randint(0, H - h - 1) 242 | y2 = y1 + h 243 
| return x1, x2, y1, y2 244 | 245 | def CrossEntropy2d(input, target, weight=None, size_average=True): 246 | dim = input.dim() 247 | if dim == 2: 248 | return F.cross_entropy(input, target, weight, size_average) 249 | elif dim == 4: 250 | output = input.view(input.size(0),input.size(1), -1) 251 | output = torch.transpose(output,1,2).contiguous() 252 | output = output.view(-1,output.size(2)) 253 | target = target.view(-1) 254 | return F.cross_entropy(output, target,weight, size_average) 255 | else: 256 | raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim)) 257 | 258 | def accuracy(input, target): 259 | return 100 * float(np.count_nonzero(input == target)) / target.size 260 | 261 | def sliding_window(top, step=10, window_size=(20,20)): 262 | for x in range(0, top.shape[0], step): 263 | if x + window_size[0] > top.shape[0]: 264 | x = top.shape[0] - window_size[0] 265 | for y in range(0, top.shape[1], step): 266 | if y + window_size[1] > top.shape[1]: 267 | y = top.shape[1] - window_size[1] 268 | yield x, y, window_size[0], window_size[1] 269 | 270 | def count_sliding_window(top, step=10, window_size=(20,20)): 271 | c = 0 272 | for x in range(0, top.shape[0], step): 273 | if x + window_size[0] > top.shape[0]: 274 | x = top.shape[0] - window_size[0] 275 | for y in range(0, top.shape[1], step): 276 | if y + window_size[1] > top.shape[1]: 277 | y = top.shape[1] - window_size[1] 278 | c += 1 279 | return c 280 | 281 | def grouper(n, iterable): 282 | it = iter(iterable) 283 | while True: 284 | chunk = tuple(itertools.islice(it, n)) 285 | if not chunk: 286 | return 287 | yield chunk 288 | 289 | def metrics(predictions, gts, label_values=LABELS): 290 | cm = confusion_matrix( 291 | gts, 292 | predictions, 293 | range(len(label_values))) 294 | print("Confusion matrix :") 295 | print(cm) 296 | # # Compute global accuracy 297 | total = sum(sum(cm)) 298 | accuracy = sum([cm[x][x] for x in range(len(cm))]) 299 | accuracy *= 100 / float(total) 300 | print("{} pixels processed".format(total)) 301 | print("Total accuracy : {}%".format(accuracy)) 302 | # # # Compute F1 score 303 | F1Score = np.zeros(len(label_values)) 304 | for i in range(len(label_values)): 305 | try: 306 | F1Score[i] = 2. 
* cm[i,i] / (np.sum(cm[i,:]) + np.sum(cm[:,i])) 307 | except: 308 | # # Ignore exception if there is no element in class i for test set 309 | #pass 310 | pass 311 | print("F1Score :") 312 | for l_id, score in enumerate(F1Score): 313 | print("{}: {}".format(label_values[l_id], score)) 314 | # # Compute kappa coefficient 315 | total = np.sum(cm) 316 | pa = np.trace(cm) / float(total) 317 | pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total*total) 318 | kappa = (pa - pe) / (1 - pe); 319 | print("Kappa: " + str(kappa)) 320 | return accuracy 321 | 322 | # Load the dataset 323 | class ISPRS_dataset(torch.utils.data.Dataset): 324 | def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER, 325 | cache=False, augmentation=True): 326 | super(ISPRS_dataset, self).__init__() 327 | self.augmentation = augmentation 328 | self.cache = cache 329 | self.data_files = [DATA_FOLDER.format(id) for id in ids] 330 | self.label_files = [LABEL_FOLDER.format(id) for id in ids] 331 | for f in self.data_files + self.label_files: 332 | if not os.path.isfile(f): 333 | raise KeyError('{} is not a file !'.format(f)) 334 | self.data_cache_ = {} 335 | self.label_cache_ = {} 336 | 337 | def __len__(self): 338 | return 10000 339 | 340 | @classmethod 341 | def data_augmentation(cls, *arrays, flip=True, mirror=True): 342 | will_flip, will_mirror = False, False 343 | #will_rotate = False 344 | #will_rotate2 = False 345 | if flip and random.random() < 0.5: 346 | will_flip = True 347 | if mirror and random.random() < 0.5: 348 | will_mirror = True 349 | 350 | results = [] 351 | for array in arrays: 352 | if will_flip: 353 | if len(array.shape) == 2: 354 | array = array[::-1, :] 355 | else: 356 | array = array[:, ::-1, :] 357 | if will_mirror: 358 | if len(array.shape) == 2: 359 | array = array[:, ::-1] 360 | else: 361 | array = array[:, :, ::-1] 362 | 363 | results.append(np.copy(array)) 364 | 365 | return tuple(results) 366 | 367 | def __getitem__(self, i): 368 | random_idx = random.randint(0, len(self.data_files) - 1) 369 | if random_idx in self.data_cache_.keys(): 370 | data = self.data_cache_[random_idx] 371 | else: 372 | data = 1/255 * np.asarray(io.imread(self.data_files[random_idx]).transpose((2,0,1)), dtype='float32') 373 | if self.cache: 374 | self.data_cache_[random_idx] = data 375 | 376 | if random_idx in self.label_cache_.keys(): 377 | label = self.label_cache_[random_idx] 378 | else: 379 | label = np.asarray(convert_from_color(io.imread(self.label_files[random_idx])), dtype='int64') 380 | if self.cache: 381 | self.label_cache_[random_idx] = label 382 | 383 | x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE) 384 | data_p = data[:, x1:x2,y1:y2] 385 | label_p = label[x1:x2,y1:y2] 386 | 387 | #data_p, label_p = self.data_augmentation(data_p, label_p) 388 | data_p, label_p = self.data_augmentation(data_p, label_p) 389 | 390 | # # # # # Return the torch.Tensor values 391 | return (torch.from_numpy(data_p), 392 | torch.from_numpy(label_p)) 393 | 394 | # model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", 395 | # num_labels=6, 396 | # id2label=id2label, 397 | # label2id=label2id, 398 | # ) 399 | 400 | id2label = {0 : 'Impervious surfaces', 401 | 1 : 'Buildings', 402 | 2 : 'Low vegetation', 403 | 3 : 'Trees', 404 | 4 : 'Cars', 405 | 5 : 'Clutter'} 406 | label2id = {v: k for k, v in id2label.items()} 407 | 408 | # model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", 409 | # num_labels=6, 410 | # id2label=id2label, 411 | # label2id=label2id, 412 | # ) 
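# --- Added note (annotation, not original code) ----------------------------------------------------
# The stock SegformerForSemanticSegmentation in transformers v4.33.0 has no `alpha` behaviour; the
# keyword in the call below is only meaningful because this repository relies on the locally
# modified modeling_segformer.py referenced at the top of this file. Presumably `alpha` is the
# coefficient of the gradient-reversal / domain-adaptation branch added in that modification, with
# alpha=0.0 disabling it; this reading is an assumption, not something stated in the code. If such
# a coefficient were ramped up over training, a common, purely illustrative schedule is the DANN one:
#
#   import math
#   def grl_alpha(step, total_steps):
#       p = step / max(1, total_steps)
#       return 2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0
# ----------------------------------------------------------------------------------------------------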
413 | 414 | # We use B5 415 | model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", 416 | num_labels=6, 417 | id2label=id2label, 418 | label2id=label2id, 419 | alpha=0.0, 420 | ) 421 | # Use b5 422 | 423 | # Load the data 424 | if DATASET == 'Potsdam': 425 | all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*'))) 426 | #print(all_files) 427 | #all_ids = ["".join(f.split('')[5:7]) for f in all_files] 428 | #print(all_ids) 429 | 430 | elif DATASET == 'Vaihingen': 431 | #all_ids = 432 | all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*'))) 433 | all_ids = [f.split('area')[-1].split('.')[0] for f in all_files] 434 | 435 | #print(all_files) 436 | #print(all_ids) 437 | 438 | #folderName = '../../MainPotsd' 439 | folderName = '/Data/ndionelis/MainPotsd' 440 | names_data = os.listdir(folderName) # to not load all data in a single tensor, load only the names 441 | length_names = len(names_data) 442 | 443 | from torchvision import transforms, datasets 444 | import shutil 445 | 446 | data_dir = '/Data/ndionelis/mainpo4244' 447 | 448 | train_ids = [] 449 | iTheLoopNumber = 0 450 | for iTheLoop in training_data: 451 | train_ids.append(training_data[iTheLoopNumber][12:-8]) 452 | iTheLoopNumber += 1 453 | 454 | print("Tiles for training : ", train_ids) 455 | train_set = ISPRS_dataset(train_ids, cache=CACHE) 456 | #print(len(train_set)) 457 | #train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE) 458 | #train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE) 459 | train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUMWORKERS) 460 | 461 | # train_dataloader_iter = iter(train_dataloader) 462 | # #train_dataloader_iter_next = next(train_dataloader_iter) 463 | # #train_dataloader_iter_next1, train_dataloader_iter_next2 = train_dataloader_iter_next 464 | # train_dataloader_iter_next1, train_dataloader_iter_next2 = next(train_dataloader_iter) 465 | # print(train_dataloader_iter_next1) 466 | # print(train_dataloader_iter_next2) 467 | # print(train_dataloader_iter_next1.shape) 468 | # print(train_dataloader_iter_next2.shape) 469 | 470 | # train_dataloader_iter = iter(train_dataloader) 471 | # #train_dataloader_iter_next = next(train_dataloader_iter) 472 | # #train_dataloader_iter_next1, train_dataloader_iter_next2 = train_dataloader_iter_next 473 | # train_dataloader_iter_next1, train_dataloader_iter_next2 = next(train_dataloader_iter) 474 | # print(train_dataloader_iter_next1) 475 | # print(train_dataloader_iter_next2) 476 | # print(train_dataloader_iter_next1.shape) 477 | # print(train_dataloader_iter_next2.shape) 478 | # # torch.Size([10, 3, 256, 256]) 479 | # # torch.Size([10, 256, 256]) 480 | 481 | # data_transforms = { 482 | # 'train': transforms.Compose([ 483 | # transforms.RandomCrop((256, 256)), 484 | # transforms.ToTensor(), 485 | # ]), 486 | # 'val': transforms.Compose([ 487 | # transforms.RandomCrop((256, 256)), 488 | # transforms.ToTensor(), 489 | # ]), 490 | # } 491 | # image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), 492 | # data_transforms[x]) 493 | # for x in ['train', 'val']} 494 | 495 | # #print(image_datasets['train'].imgs) 496 | # #print(image_datasets['train'].targets) 497 | 498 | # #print(dataloaders['train']) 499 | # dataloaders_iter = iter(dataloaders['train']) 500 | # dataloaders_iter_next1, dataloaders_iter_next2 = next(dataloaders_iter) 501 | # print(dataloaders_iter_next1) 502 | # print(dataloaders_iter_next2) 503 | # 
print(dataloaders_iter_next1.shape) 504 | # print(dataloaders_iter_next2.shape) 505 | 506 | DATASET = 'Vaihingen' 507 | #DATASET = 'Potsdam' 508 | 509 | MAIN_FOLDER = FOLDER + 'Vaihingen/' 510 | DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 511 | #LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 512 | LABEL_FOLDER = MAIN_FOLDER + 'top_mosaic_09cm_area{}.tif' 513 | #ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 514 | ERODED_FOLDER = MAIN_FOLDER + 'top_mosaic_09cm_area{}_noBoundary.tif' 515 | 516 | # if DATASET == 'Potsdam': 517 | # MAIN_FOLDER = FOLDER + 'Potsdam/' 518 | # # Uncomment the next line for IRRG data 519 | # # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif' 520 | # # # For RGB data 521 | # #print(MAIN_FOLDER) 522 | # #sadfszf 523 | # DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif' 524 | # LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif' 525 | # ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif' 526 | 527 | # elif DATASET == 'Vaihingen': 528 | # MAIN_FOLDER = FOLDER + 'Vaihingen/' 529 | # #print(MAIN_FOLDER) 530 | # #asdfszdf 531 | # DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif' 532 | # LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif' 533 | # ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif' 534 | 535 | # # Dataset class 536 | class ISPRS_dataset(torch.utils.data.Dataset): 537 | def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER, 538 | cache=False, augmentation=True): 539 | super(ISPRS_dataset, self).__init__() 540 | self.augmentation = augmentation 541 | self.cache = cache 542 | self.data_files = [DATA_FOLDER.format(id) for id in ids] 543 | self.label_files = [LABEL_FOLDER.format(id) for id in ids] 544 | for f in self.data_files + self.label_files: 545 | if not os.path.isfile(f): 546 | raise KeyError('{} is not a file !'.format(f)) 547 | self.data_cache_ = {} 548 | self.label_cache_ = {} 549 | 550 | def __len__(self): 551 | return 10000 552 | 553 | @classmethod 554 | def data_augmentation(cls, *arrays, flip=True, mirror=True): 555 | will_flip, will_mirror = False, False 556 | #will_rotate = False 557 | #will_rotate2 = False 558 | if flip and random.random() < 0.5: 559 | will_flip = True 560 | if mirror and random.random() < 0.5: 561 | will_mirror = True 562 | 563 | results = [] 564 | for array in arrays: 565 | if will_flip: 566 | if len(array.shape) == 2: 567 | array = array[::-1, :] 568 | else: 569 | array = array[:, ::-1, :] 570 | if will_mirror: 571 | if len(array.shape) == 2: 572 | array = array[:, ::-1] 573 | else: 574 | array = array[:, :, ::-1] 575 | 576 | results.append(np.copy(array)) 577 | 578 | return tuple(results) 579 | 580 | def __getitem__(self, i): 581 | # # Pick a random image 582 | random_idx = random.randint(0, len(self.data_files) - 1) 583 | 584 | # If the tile hasn't been loaded yet, put in cache 585 | if random_idx in self.data_cache_.keys(): 586 | data = self.data_cache_[random_idx] 587 | else: 588 | # Data is normalized in [0, 1] 589 | data = 1/255 * np.asarray(io.imread(self.data_files[random_idx]).transpose((2,0,1)), dtype='float32') 590 | if self.cache: 591 | self.data_cache_[random_idx] = data 592 | 593 | if random_idx in self.label_cache_.keys(): 594 | label = self.label_cache_[random_idx] 595 | else: 596 | # # Labels 
are converted from RGB to their numeric values 597 | label = np.asarray(convert_from_color(io.imread(self.label_files[random_idx])), dtype='int64') 598 | if self.cache: 599 | self.label_cache_[random_idx] = label 600 | 601 | x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE) 602 | data_p = data[:, x1:x2,y1:y2] 603 | label_p = label[x1:x2,y1:y2] 604 | 605 | data_p, label_p = self.data_augmentation(data_p, label_p) 606 | 607 | return (torch.from_numpy(data_p), 608 | torch.from_numpy(label_p)) 609 | 610 | # Load the datasets 611 | if DATASET == 'Potsdam': 612 | all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*'))) 613 | #print(all_files) 614 | 615 | #all_ids = ["".join(f.split('')[5:7]) for f in all_files] 616 | #print(all_ids) 617 | 618 | elif DATASET == 'Vaihingen': 619 | #all_ids = 620 | all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*'))) 621 | all_ids = [f.split('area')[-1].split('.')[0] for f in all_files] 622 | 623 | #print(all_files) 624 | #print(all_ids) 625 | 626 | # train_loader2_iter = iter(train_loader2) 627 | # train_loader2_iter_next1, train_loader2_iter_next2 = next(train_loader2_iter) 628 | # print(train_loader2_iter_next1) 629 | # print(train_loader2_iter_next2) 630 | # print(train_loader2_iter_next1.shape) 631 | # print(train_loader2_iter_next2.shape) 632 | # # torch.Size([10, 3, 256, 256]) 633 | # # torch.Size([10, 256, 256]) 634 | 635 | #folderName = '../../Vaihingen/top' 636 | folderName = '/Data/ndionelis/Vaihingen/top' 637 | names_data = os.listdir(folderName) # to not load all data in a single tensor, load only the names 638 | length_names = len(names_data) 639 | 640 | from torchvision import transforms, datasets 641 | #import shutil 642 | 643 | #data_dir = '/Data/ndionelis/mainva2' 644 | data_dir = '/Data/ndionelis/mainva4244' 645 | 646 | train_ids = [] 647 | iTheLoopNumber = 0 648 | for iTheLoop in training_data: 649 | train_ids.append(training_data[iTheLoopNumber][20:-4]) 650 | iTheLoopNumber += 1 651 | #print(train_ids) 652 | 653 | print("Tiles for training : ", train_ids) 654 | #print("Tiles for testing : ", test_ids) 655 | 656 | train_set = ISPRS_dataset(train_ids, cache=CACHE) 657 | #print(len(train_set)) 658 | #train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE) 659 | #train_loader2 = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE) 660 | train_loader2 = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUMWORKERS) 661 | 662 | # train_loader2_iter = iter(train_loader2) 663 | # train_loader2_iter_next1, train_loader2_iter_next2 = next(train_loader2_iter) 664 | # print(train_loader2_iter_next1) 665 | # print(train_loader2_iter_next2) 666 | # print(train_loader2_iter_next1.shape) 667 | # print(train_loader2_iter_next2.shape) 668 | # # torch.Size([10, 3, 256, 256]) 669 | # # torch.Size([10, 256, 256]) 670 | 671 | # data_transforms = { 672 | # 'train': transforms.Compose([ 673 | # transforms.ToTensor(), 674 | # ]), 675 | # 'val': transforms.Compose([ 676 | # transforms.ToTensor(), 677 | # ]), 678 | # } 679 | # image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), 680 | # data_transforms[x]) 681 | # for x in ['train', 'val']} 682 | # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE, 683 | # shuffle=True, num_workers=8) 684 | # for x in ['train', 'val']} 685 | # dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} 686 | # class_names = image_datasets['train'].classes 687 | 688 | # Define the optimizer 689 | 
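# Illustrative notes, not part of the original script:
# (1) The AdamW learning rate of 6e-5 set just below is the value commonly
#     used when fine-tuning pretrained SegFormer checkpoints; the commented
#     alternative of 6e-7 is a 100x smaller variant of the same setting.
# (2) The Vaihingen tile ids collected above with training_data[i][20:-4]
#     assume bare filenames such as 'top_mosaic_09cm_area11.tif'
#     (20-character prefix 'top_mosaic_09cm_area', 4-character suffix '.tif').
_example_vaihingen_name = 'top_mosaic_09cm_area11.tif'   # hypothetical filename, for illustration only
assert _example_vaihingen_name[20:-4] == '11'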
#optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006) 690 | #optimizer = torch.optim.AdamW(model.parameters(), lr=0.0000006) 691 | 692 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006) 693 | 694 | folderName = '/Data/ndionelis/bingmap/19' 695 | #folderName = '../../CVUSA/bingmap/19' 696 | 697 | names_data = os.listdir(folderName) # # to not load all data in a single tensor, load only the names 698 | length_names = len(names_data) 699 | #print(names_data) 700 | #print(length_names) 701 | 702 | # print(training_data) 703 | # print(len(training_data)) 704 | # print(test_data) 705 | # print(len(test_data)) 706 | 707 | from torchvision import transforms, datasets 708 | 709 | #data_transforms = transforms.Compose([transforms.ToTensor(),]) 710 | #data_transforms = transforms.ToTensor() 711 | 712 | #import shutil 713 | 714 | #data_dir = '../../CVUSA/bingmap/mainTheDatasetNoSubsaa223newnew' 715 | #data_dir = '/Data/ndionelis/bingmap/maindatasetnewnew2' 716 | data_dir = '/Data/ndionelis/bingmap/maindatasetnewnew4244' 717 | 718 | #data_dir = '/Data/ndionelis/bingmap/mainDataset' 719 | # image_datasets = {x: datasets.ImageFolder(data_dir, 720 | # data_transforms) 721 | # if x in training_data; 722 | # else: continue} 723 | #image_datasets = datasets.ImageFolder(os.path.join(data_dir, training_data), data_transforms) 724 | #image_datasets2 = datasets.ImageFolder(os.path.join(data_dir, test_data), data_transforms) 725 | 726 | data_transforms = { 727 | 'train': transforms.Compose([ 728 | transforms.ToTensor(), 729 | ]), 730 | 'val': transforms.Compose([ 731 | transforms.ToTensor(), 732 | ]), 733 | } 734 | 735 | #data_dir = 'data/hymenoptera_data' 736 | image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), 737 | data_transforms[x]) 738 | for x in ['train', 'val']} 739 | # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE, 740 | # shuffle=True, num_workers=8) 741 | # for x in ['train', 'val']} 742 | # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE, 743 | # shuffle=True, num_workers=6) 744 | # for x in ['train', 'val']} 745 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE, 746 | shuffle=True, num_workers=NUMWORKERS) 747 | for x in ['train', 'val']} 748 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} 749 | class_names = image_datasets['train'].classes 750 | 751 | #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 752 | 753 | # Move model to GPU 754 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 755 | print(device) 756 | 757 | model.to(device) 758 | model.train() 759 | 760 | #model.load_state_dict(torch.load('/Data/ndionelis/segformermain30082023')) 761 | 762 | w, h = WINDOW_SIZE 763 | 764 | for epoch in range(200): 765 | print("Epoch:", epoch) 766 | #train_loader2_iter = iter(train_loader2) 767 | train_loader2_iter = iter(train_loader2) 768 | dataloaders_iter = iter(dataloaders['train']) 769 | #for idx, batch in enumerate(tqdm(train_dataloader)): 770 | for idx, (pixel_values, labels) in enumerate(tqdm(train_dataloader)): 771 | pixel_values, labels = Variable(pixel_values.cuda()), Variable(labels.cuda()) 772 | epochs = 200 773 | #epochs = 50 774 | data_dataloaders = next(dataloaders_iter) 775 | #data_dataloaders2 = next(dataloaders_iter2) 776 | #data_dataloaders, labels_dataloaders = data_dataloaders 777 | data_dataloaders, _ = data_dataloaders 778 | #w, h = WINDOW_SIZE 779 | W, H = 
data_dataloaders.shape[-2:] 780 | # x1 = random.randint(0, W - w - 1) 781 | # x2 = x1 + w 782 | # y1 = random.randint(0, H - h - 1) 783 | # y2 = y1 + h 784 | # data_dataloaders = data_dataloaders[:, :, x1:x2, y1:y2] 785 | for iToLoop in range(BATCH_SIZE): 786 | if iToLoop == 0: 787 | image_patches = torch.zeros((BATCH_SIZE, 3, w, h), device=device) 788 | x1 = random.randint(0, W - w - 1) 789 | x2 = x1 + w 790 | y1 = random.randint(0, H - h - 1) 791 | y2 = y1 + h 792 | #data_dataloaders = data_dataloaders[:, :, x1:x2, y1:y2] 793 | #data_dataloaders = data_dataloaders[iToLoop, :, x1:x2, y1:y2] 794 | image_patches[iToLoop, :, :, :] = data_dataloaders[iToLoop, :, x1:x2, y1:y2] 795 | 796 | #p = float(idx + e * len(train_dataloader)) / epochs / len(train_dataloader) 797 | p = float(idx + (epoch+1) * len(train_dataloader)) / epochs / len(train_dataloader) 798 | alpha = 2. / (1. + np.exp(-10 * p)) - 1 799 | 800 | #encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt") 801 | #pixel_values = encoded_inputs.pixel_values.to(device) 802 | #labels = encoded_inputs.labels.to(device) 803 | 804 | #image_patches = Variable(torch.from_numpy(image_patches).cuda()) 805 | #encoding = feature_extractor(image_patches, return_tensors="pt") 806 | #pixel_values = encoding.pixel_values.to(device) 807 | 808 | # # # # # get the inputs 809 | #pixel_values = batch["pixel_values"].to(device) 810 | #labels = batch["labels"].to(device) 811 | #print(pixel_values.shape) 812 | #print(labels.shape) 813 | 814 | # #import matplotlib.pyplot as plt 815 | # plt.figure() 816 | # plt.imshow(pixel_values[0,:,:,:].permute(1, 2, 0).cpu()) 817 | # plt.show() 818 | 819 | # plt.figure() 820 | # plt.imshow((labels.unsqueeze(1))[0,:,:,:].permute(1, 2, 0).cpu()) 821 | # plt.show() 822 | 823 | # plt.figure() 824 | # plt.imshow(pixel_values[1,:,:,:].permute(1, 2, 0).cpu()) 825 | # plt.show() 826 | 827 | # plt.figure() 828 | # plt.imshow((labels.unsqueeze(1))[1,:,:,:].permute(1, 2, 0).cpu()) 829 | # plt.show() 830 | # # # torch.Size([10, 3, 256, 256]) 831 | # # # torch.Size([10, 256, 256]) 832 | 833 | # # torch.Size([2, 3, 512, 512]) 834 | # # torch.Size([2, 512, 512]) 835 | 836 | # # # zero the parameter gradients 837 | optimizer.zero_grad() 838 | 839 | data2, target2 = next(train_loader2_iter) 840 | data2, target2 = Variable(data2.cuda()), Variable(target2.cuda()) 841 | #print(alpha) 842 | 843 | # import matplotlib.pyplot as plt 844 | # plt.imshow(pixel_values[3,:,:,:].permute(1, 2, 0).cpu().numpy()) 845 | # plt.savefig('beforeIm.png') 846 | 847 | # # radias = int(0.1 * 256) // 2 848 | # # kernel_size = radias * 2 + 1 849 | # # blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), 850 | # # stride=1, padding=0, bias=False, groups=3) 851 | # # blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), 852 | # # stride=1, padding=0, bias=False, groups=3) 853 | # # #k = kernel_size 854 | # # #r = radias 855 | 856 | # # blur = nn.Sequential( 857 | # # nn.ReflectionPad2d(radias), 858 | # # blur_h, 859 | # # blur_v 860 | # # ) 861 | 862 | # # #self.pil_to_tensor = transforms.ToTensor() 863 | # # #self.tensor_to_pil = transforms.ToPILImage() 864 | 865 | # # #img = self.pil_to_tensor(img).unsqueeze(0) 866 | 867 | # # sigma = np.random.uniform(0.1, 2.0) 868 | # # x = np.arange(-radias, radias + 1) 869 | # # x = np.exp(-np.power(x, 2) / (2 * sigma * sigma)) 870 | # # x = x / x.sum() 871 | # # x = torch.from_numpy(x).view(1, -1).repeat(3, 1) 872 | 873 | # # blur_h.weight.data.copy_(x.view(3, 1, kernel_size, 1)) 874 | # # 
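# Illustrative sketch, not part of the original script: the schedule computed
# above,
#     p     = (idx + (epoch + 1) * len(train_dataloader)) / (epochs * len(train_dataloader))
#     alpha = 2 / (1 + exp(-10 * p)) - 1,
# is the standard DANN ramp-up (Ganin & Lempitsky) for the coefficient of a
# gradient reversal layer: alpha grows from near 0 towards 1 as training
# progresses, so the domain-alignment signal is phased in gradually. The
# modified SegFormer in modeling_SegFormer.py receives this alpha in its
# forward pass; a typical gradient reversal layer (whether it is implemented
# exactly like this is an assumption) looks like:
import torch

class _GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)                       # identity on the forward pass
    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None     # reversed, scaled gradient w.r.t. x

# usage inside a domain-adversarial head: reversed_feats = _GradReverse.apply(features, alpha)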
blur_v.weight.data.copy_(x.view(3, 1, 1, kernel_size)) 875 | # # pixel_values = blur(pixel_values) 876 | 877 | # blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)) 878 | # pixel_values = blurrer(pixel_values) 879 | 880 | # plt.figure() 881 | # plt.imshow(pixel_values[3,:,:,:].permute(1, 2, 0).cpu().numpy()) 882 | # plt.savefig('afterImage.png') 883 | 884 | # if epoch > 8: 885 | # if random.random() < 0.5: 886 | # # # # # Gaussian blur 887 | # #applier = T.RandomApply(transforms=[T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=0.5) 888 | # #applier = T.RandomApply(transforms=[T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=1.0) 889 | # applier = T.RandomVerticalFlip(p=1.0) 890 | # pixel_values = applier(pixel_values) 891 | # labels = applier(labels) 892 | # image_patches = applier(image_patches) 893 | # data2 = applier(data2) 894 | # target2 = applier(target2) 895 | 896 | # if random.random() < 0.5: 897 | # # # # Gaussian blur 898 | # #applier = T.RandomApply(transforms=[T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=0.5) 899 | # #applier = T.RandomApply(transforms=[T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=1.0) 900 | # applier = T.RandomHorizontalFlip(p=1.0) 901 | # pixel_values = applier(pixel_values) 902 | # labels = applier(labels) 903 | # image_patches = applier(image_patches) 904 | # data2 = applier(data2) 905 | # target2 = applier(target2) 906 | 907 | # if random.random() < 0.5: 908 | # blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)) 909 | # pixel_values = blurrer(pixel_values) 910 | # image_patches = blurrer(image_patches) 911 | # data2 = blurrer(data2) 912 | 913 | #print(pixel_values) 914 | #print(pixel_values.shape) 915 | 916 | # We modify: https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/segformer/modeling_segformer.py#L746 917 | 918 | # # forward + backward + optimize 919 | #outputs = model(pixel_values=pixel_values, labels=labels, alpha=alpha) 920 | outputs = model(pixel_values=pixel_values, labels=labels, alpha=alpha, image_patches=image_patches, data2=data2, target2=target2) 921 | #outputs = model(pixel_values=pixel_values, labels=labels) 922 | 923 | # For the above, we have modified: https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/segformer/modeling_segformer.py#L746 924 | 925 | loss, logits = outputs.loss, outputs.logits 926 | 927 | # output2, oo222 = net(data2, alpha) 928 | # #loss += 10.0 * CrossEntropy2d(output2, target2, weight=weights) 929 | # loss += CrossEntropy2d(output2, target2, weight=weights) 930 | # loss += smp.losses.JaccardLoss(mode="multiclass", classes=6)(y_pred=output2, y_true=target2) 931 | # domain_label2 = (2*torch.ones(BATCH_SIZE).long()).cuda() 932 | # #err_s_domain = nn.CrossEntropyLoss()(oo222, domain_label2) 933 | # loss += nn.CrossEntropyLoss()(oo222, domain_label2) 934 | #loss.backward() 935 | 936 | loss.backward() 937 | optimizer.step() 938 | print("Loss:", loss.item()) 939 | 940 | # # # # # evaluate 941 | # with torch.no_grad(): 942 | # upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False) 943 | # predicted = upsampled_logits.argmax(dim=1) 944 | 945 | # # # note that the metric expects predictions + labels as numpy arrays 946 | # metric.add_batch(predictions = predicted.detach().cpu().numpy(), references = labels.detach().cpu().numpy()) 947 | 948 | # # # # let's print loss and metrics every 100 batches 949 | # #if idx % 100 == 0: 950 | # if idx % 10 == 0: 951 | # 
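# Illustrative sketch, not part of the original script: a dependency-free
# per-batch mean-IoU monitor, equivalent in spirit to the commented-out
# `evaluate`-based metric above. It could be called inside the batch loop as
#     miou = batch_mean_iou(logits, labels, num_labels=len(id2label))
# using the loop variables defined above.
import torch
import torch.nn.functional as F

def batch_mean_iou(logits, labels, num_labels, ignore_index=255):
    with torch.no_grad():
        up = F.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
        pred = up.argmax(dim=1)
        valid = labels != ignore_index
        ious = []
        for c in range(num_labels):
            inter = ((pred == c) & (labels == c) & valid).sum().item()
            union = (((pred == c) | (labels == c)) & valid).sum().item()
            if union > 0:
                ious.append(inter / union)
    return sum(ious) / max(len(ious), 1)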
metrics = metric._compute(predictions = predicted.detach().cpu().numpy(), references = labels.detach().cpu().numpy(), num_labels=len(id2label), 952 | # ignore_index=255, 953 | # reduce_labels=False, # we've already reduced the labels before) 954 | # ) 955 | 956 | # print("Loss:", loss.item()) 957 | # print("Mean_iou:", metrics["mean_iou"]) 958 | # print("Mean accuracy:", metrics["mean_accuracy"]) 959 | 960 | 961 | # ## Inference 962 | ##image = Image.open('./ADE20k_toy_dataset/images/training/ADE_train_00000001.jpg') 963 | 964 | def ade_palette(): 965 | """ADE20K palette that maps each class to RGB values.""" 966 | return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 967 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 968 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 969 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 970 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 971 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 972 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 973 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 974 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 975 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 976 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 977 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 978 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 979 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 980 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 981 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 982 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 983 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 984 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 985 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 986 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 987 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 988 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 989 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 990 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 991 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 992 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 993 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 994 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], 995 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 996 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 997 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 998 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 999 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 1000 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 1001 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 1002 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 1003 | [102, 255, 0], [92, 0, 255]] 1004 | 1005 | def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE): 1006 | # # /home/nikolaos/CVUSA/bingmap 1007 | 1008 | test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids) 1009 | test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(id)), dtype='uint8') for id in test_ids) 1010 | eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(id))) for id 
in test_ids) 1011 | 1012 | # # /home/nikolaos/CVUSA/bingmap 1013 | #print(test_images) 1014 | #print(test_images.size) 1015 | 1016 | all_preds = [] 1017 | all_gts = [] 1018 | net.eval() 1019 | 1020 | for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False): 1021 | #pred = np.zeros(img.shape[:2] + (N_CLASSES,)) 1022 | 1023 | # fig = plt.figure() 1024 | # plt.imshow(np.asarray(255 * img, dtype='uint8')) 1025 | # plt.axis('off') 1026 | # plt.show() 1027 | 1028 | # print(img) 1029 | 1030 | #print(img.shape) 1031 | 1032 | #from PIL import Image 1033 | #img2 = Image.open('/home/nikolaos/CVUSA/bingmap/0000001.jpg') 1034 | 1035 | #print(DATA_FOLDER.format(id)) 1036 | 1037 | #img2 = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids) 1038 | 1039 | #img2 = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/0000001.jpg'), dtype='float32') for id in test_ids) 1040 | 1041 | # fig = plt.figure() 1042 | # #plt.imshow(np.asarray(255 * img2, dtype='uint8')) 1043 | # #plt.imshow(img2) 1044 | # plt.imshow(np.asarray(255 * img2, dtype='uint8')) 1045 | # plt.axis('off') 1046 | # plt.show() 1047 | 1048 | # print(img2) 1049 | 1050 | #print(img2.shape) 1051 | 1052 | #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/0000001.jpg'), dtype='float32')) 1053 | 1054 | # #img = (1 / 255 * np.asarray(io.imread('/Data/ndionelis/bingmap/19/'+str(44506).zfill(7)+'.jpg'), dtype='float32')) 1055 | 1056 | # #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/0000019.jpg'), dtype='float32')) 1057 | 1058 | # #img = (1 / 255 * np.asarray(io.imread('/home/nikolaos/CVUSA/bingmap/19/0000020.jpg'), dtype='float32')) 1059 | 1060 | # #image 1061 | 1062 | # #plt.figure 1063 | # #plt.imshow(img) 1064 | # #plt.show() 1065 | 1066 | pred = np.zeros(img.shape[:2] + (N_CLASSES,)) 1067 | 1068 | total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size 1069 | for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total, leave=False)): 1070 | if i > 0 and total > 10 and i % int(10 * total / 100) == 0: 1071 | _pred = np.argmax(pred, axis=-1) 1072 | # fig = plt.figure() 1073 | # #fig.add_subplot(1,3,1) 1074 | # fig.add_subplot(1,2,1) 1075 | # plt.imshow(np.asarray(255 * img, dtype='uint8')) 1076 | # plt.axis('off') 1077 | # #fig.add_subplot(1,3,2) 1078 | # fig.add_subplot(1,2,2) 1079 | # plt.imshow(convert_to_color(_pred)) 1080 | # plt.axis('off') 1081 | # #fig.add_subplot(1,3,3) 1082 | # #plt.imshow(gt) 1083 | # #plt.axis('off') 1084 | # #clear_output() 1085 | # plt.show() 1086 | 1087 | image_patches = [np.copy(img[x:x+w, y:y+h]).transpose((2,0,1)) for x,y,w,h in coords] 1088 | image_patches = np.asarray(image_patches) 1089 | with torch.no_grad(): 1090 | image_patches = Variable(torch.from_numpy(image_patches).cuda()) 1091 | 1092 | #print(image_patches.shape) 1093 | #print(image_patches) 1094 | 1095 | # torch.Size([10, 3, 256, 256]) 1096 | # tensor([[[[0.0980, 0.0902, 0.0784, ..., 0.1176, 0.1137, 0.1020], 1097 | # [0.0863, 0.0706, 0.0549, ..., 0.1098, 0.1059, 0.1020], 1098 | 1099 | # # /home/nikolaos/CVUSA/bingmsap 1100 | 1101 | # from PIL import Image 1102 | # img2 = Image.open('/home/nikolaos/CVUSA/bingmap/0000001.jpg') 1103 | 1104 | # from torchvision import transforms 1105 | # img2 = transforms.Resize((256, 256))(img2) 1106 | 1107 | # #image_patches = transforms.ToTensor()(img) 1108 | # image_patches = 
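# Illustrative sketch, not part of the original script: count_sliding_window,
# sliding_window and grouper used above are helpers defined earlier in this
# file (not shown here). Minimal equivalents, with image-border handling
# omitted, look roughly like:
#
#     def sliding_window(img, step, window_size):
#         w, h = window_size
#         for x in range(0, img.shape[0] - w + 1, step):
#             for y in range(0, img.shape[1] - h + 1, step):
#                 yield x, y, w, h                  # top-left corner + patch size
#
#     def grouper(n, iterable):
#         it = iter(iterable)
#         while chunk := list(itertools.islice(it, n)):
#             yield chunk                           # batches of at most n windows
#
# count_sliding_window(img, step, window_size) just counts how many windows
# the generator above would yield.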
transforms.ToTensor()(img2).unsqueeze_(0) 1109 | 1110 | # for iloop in range(2, 10): 1111 | # img2 = Image.open('/home/nikolaos/CVUSA/bingmap/000000'+str(iloop)+'.jpg') 1112 | # img2 = transforms.Resize((256, 256))(img2) 1113 | # img2 = transforms.ToTensor()(img2).unsqueeze_(0) 1114 | # image_patches = torch.cat((image_patches, img2), 0) 1115 | 1116 | # img2 = Image.open('/home/nikolaos/CVUSA/bingmap/0000010.jpg') 1117 | # img2 = transforms.Resize((256, 256))(img2) 1118 | # img2 = transforms.ToTensor()(img2).unsqueeze_(0) 1119 | # image_patches = torch.cat((image_patches, img2), 0).cuda() 1120 | 1121 | # #print(image_patches.shape) 1122 | # #print(image_patches) 1123 | 1124 | # # Do the inference 1125 | #outs = net(image_patches) 1126 | #outs, _ = net(image_patches, 0.01) 1127 | #outs = net.forward2(image_patches) 1128 | #outs = outs.data.cpu().numpy() 1129 | 1130 | # # prepare the image for the model 1131 | #encoding = feature_extractor(image_patches, return_tensors="pt") 1132 | #pixel_values = encoding.pixel_values.to(device) 1133 | #print(pixel_values.shape) 1134 | 1135 | #encoding = feature_extractor(image_patches, return_tensors="pt") 1136 | #pixel_values = encoding.pixel_values.to(device) 1137 | 1138 | pixel_values = image_patches.to(device) 1139 | outputs = model(pixel_values=pixel_values) 1140 | 1141 | logits = outputs.logits.cpu() 1142 | 1143 | upsampled_logits = nn.functional.interpolate(logits, 1144 | size=image_patches.shape[3], # # (height, width) 1145 | mode='bilinear', 1146 | align_corners=False) 1147 | 1148 | #seg = upsampled_logits.argmax(dim=1)[0] 1149 | outs = upsampled_logits.detach().cpu().numpy() 1150 | 1151 | # color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # # height, width, 3 1152 | # palette = np.array(ade_palette()) 1153 | # for label, color in enumerate(palette): 1154 | # color_seg[seg == label, :] = color 1155 | # # # Convert to BGR 1156 | # color_seg = color_seg[..., ::-1] 1157 | 1158 | # # Show image + mask 1159 | # #img = np.array(image) * 0.5 + color_seg * 0.5 1160 | # #img = np.array(image_patches.cpu()) * 0.5 + color_seg * 0.5 1161 | # img = color_seg 1162 | # img = img.astype(np.uint8) 1163 | 1164 | # plt.figure(figsize=(15, 10)) 1165 | # plt.imshow(img) 1166 | # plt.show() 1167 | 1168 | # # # Fill in the results array 1169 | for out, (x, y, w, h) in zip(outs, coords): 1170 | out = out.transpose((1,2,0)) 1171 | pred[x:x+w, y:y+h] += out 1172 | del(outs) 1173 | 1174 | #break 1175 | 1176 | pred = np.argmax(pred, axis=-1) 1177 | 1178 | fig = plt.figure() 1179 | #fig.add_subplot(1,3,1) 1180 | fig.add_subplot(1,2,1) 1181 | plt.imshow(np.asarray(255 * img, dtype='uint8')) 1182 | plt.axis('off') 1183 | #fig.add_subplot(1,3,2) 1184 | fig.add_subplot(1,2,2) 1185 | plt.imshow(convert_to_color(pred)) 1186 | plt.axis('off') 1187 | #fig.add_subplot(1,3,3) 1188 | #plt.imshow(gt) 1189 | #plt.axis('off') 1190 | #plt.show() 1191 | plt.savefig('./theOutputImage.png') 1192 | #plt.savefig('./tesssttt2_tttlll2.png') 1193 | #io.imsave('./tesssttt2_tttlll2.png') 1194 | #plt.pause(10) 1195 | 1196 | all_preds.append(pred) 1197 | all_gts.append(gt_e) 1198 | 1199 | # # # prepare the image for the model 1200 | # encoding = feature_extractor(img, return_tensors="pt") 1201 | # pixel_values = encoding.pixel_values.to(device) 1202 | 1203 | # img = np.array(image) * 0.5 + color_seg * 0.5 1204 | # img = img.astype(np.uint8) 1205 | 1206 | # plt.figure(figsize=(15, 10)) 1207 | # plt.imshow(img) 1208 | # plt.show() 1209 | 1210 | #_, all_preds, all_gts = test(model, 
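# Illustrative notes, not part of the original script:
# (1) Inside test(), per-class scores are accumulated over overlapping windows
#     (pred[x:x+w, y:y+h] += out) before the final argmax; with the stride of
#     32 used in the call further below, each pixel is covered by many windows,
#     so the summed scores act as a simple test-time ensemble.
# (2) That call unpacks three values, so test() is expected to finish by
#     aggregating the collected predictions and returning them when all=True,
#     for example (the metrics() helper is assumed to be defined earlier in
#     this file):
#
#         accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]),
#                            np.concatenate([p.ravel() for p in all_gts]).ravel())
#         if all:
#             return accuracy, all_preds, all_gts
#         return accuracy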
test_ids, all=True, stride=32) 1211 | 1212 | # # save the model 1213 | #torch.save(model.state_dict(), '/Data/ndionelis/segformermain30082023') 1214 | torch.save(model.state_dict(), '/Data/ndionelis/formodels/srmr2ehw4h44') 1215 | #model.load_state_dict(torch.load('/Data/ndionelis/segformermain30082023')) 1216 | 1217 | _, all_preds, all_gts = test(model, test_ids, all=True, stride=32) 1218 | 1219 | # # # prepare the image for the model 1220 | # encoding = feature_extractor(image, return_tensors="pt") 1221 | # pixel_values = encoding.pixel_values.to(device) 1222 | 1223 | # # forward pass 1224 | #outputs = model(pixel_values=pixel_values) 1225 | 1226 | from torch import nn 1227 | import numpy as np 1228 | import matplotlib.pyplot as plt 1229 | 1230 | # img = np.array(image) * 0.5 + color_seg * 0.5 1231 | # img = img.astype(np.uint8) 1232 | 1233 | # plt.figure(figsize=(15, 10)) 1234 | # plt.imshow(img) 1235 | # plt.show() 1236 | 1237 | --------------------------------------------------------------------------------