├── 0311_predict_tlcnetU_process_wholeimg.ipynb ├── LICENSE ├── README.md ├── asset ├── crop.jpg ├── orthorectification.jpg ├── pansharpening.jpg ├── quac.jpg └── registration.jpg ├── configs ├── fcn8s_pascal.yml ├── frrnB_cityscapes.yml ├── scratch.yml ├── tlcnet_zy3bh.yml ├── tlcnetu_zy3bh.yml ├── tlcnetu_zy3bh_equalweight.yml ├── tlcnetu_zy3bh_mux.yml ├── tlcnetu_zy3bh_tlc.yml ├── tlcnetu_zy3bh_tlcmux.yml ├── tlcnetu_zy3bh_us.yml ├── tlcnetu_zy3bh_us_cn.yml ├── tlcnetu_zy3bh_us_pre_loss.yml ├── unet_zy3bh.yml ├── unet_zy3bh_mux.yml ├── unet_zy3bh_tlc.yml ├── val_tlcnetu_zy3bh.yml ├── val_tlcnetu_zy3bh_t1.yml ├── val_tlcnetu_zy3bh_testus.yml ├── val_tlcnetu_zy3bh_us.yml ├── val_tlcnetu_zy3bh_us_cn.yml ├── val_tlcnetu_zy3bh_us_onlycn.yml └── val_tlcnetu_zy3bh_us_pre.yml ├── demo_deeppred.m ├── evaluate.py ├── object_val_prcimgv2.m ├── pred.rar ├── pred_zy3bh_tlcnetU.py ├── pred_zy3bh_tlcnetU_mux.py ├── pred_zy3bh_tlcnetU_tlc.py ├── pred_zy3bh_tlcnetU_tlcmux.py ├── ptsemseg ├── augmentations │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── augmentations.cpython-36.pyc │ │ └── diyaugmentation.cpython-36.pyc │ ├── augmentations.py │ └── diyaugmentation.py ├── caffe_pb2.py ├── loader │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── ade20k_loader.cpython-36.pyc │ │ ├── camvid_loader.cpython-36.pyc │ │ ├── cityscapes_loader.cpython-36.pyc │ │ ├── diy_dataset.cpython-36.pyc │ │ ├── diyloader.cpython-36.pyc │ │ ├── mapillary_vistas_loader.cpython-36.pyc │ │ ├── mit_sceneparsing_benchmark_loader.cpython-36.pyc │ │ ├── nyuv2_loader.cpython-36.pyc │ │ ├── pascal_voc_loader.cpython-36.pyc │ │ └── sunrgbd_loader.cpython-36.pyc │ ├── ade20k_loader.py │ ├── camvid_loader.py │ ├── cityscapes_loader.py │ ├── diy_dataset.py │ ├── diyloader.py │ ├── mapillary_vistas_loader.py │ ├── mit_sceneparsing_benchmark_loader.py │ ├── nyuv2_loader.py │ ├── pascal_voc_loader.py │ └── sunrgbd_loader.py ├── loss │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── loss.cpython-36.pyc │ └── loss.py ├── metrics.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── fcn.cpython-36.pyc │ │ ├── frrn.cpython-36.pyc │ │ ├── icnet.cpython-36.pyc │ │ ├── linknet.cpython-36.pyc │ │ ├── pspnet.cpython-36.pyc │ │ ├── segnet.cpython-36.pyc │ │ ├── submodule.cpython-36.pyc │ │ ├── tlcnet.cpython-36.pyc │ │ ├── unet.cpython-36.pyc │ │ └── utils.cpython-36.pyc │ ├── fcn.py │ ├── frrn.py │ ├── icnet.py │ ├── linknet.py │ ├── pspnet.py │ ├── refinenet.py │ ├── segnet.py │ ├── submodule.py │ ├── tlcnet.py │ ├── tlcnetold.py │ ├── unet.py │ ├── utils.py │ └── utilsold.py ├── optimizers │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-36.pyc ├── schedulers │ ├── __init__.py │ └── schedulers.py └── utils.py ├── runs ├── tlcnetu_zy3bh │ ├── V1.rar │ └── finetune.tar ├── tlcnetu_zy3bh_mux.rar ├── tlcnetu_zy3bh_tlc.rar └── tlcnetu_zy3bh_tlcmux.rar ├── sample ├── img │ ├── img_bj0216.tif │ ├── img_bj0217.tif │ ├── img_bj0219.tif │ ├── img_bj0317.tif │ ├── img_bj0322.tif │ ├── img_bj0324.tif │ └── img_bj0325.tif ├── lab │ ├── lab_bj0216.tif │ ├── lab_bj0217.tif │ ├── lab_bj0219.tif │ ├── lab_bj0317.tif │ ├── lab_bj0322.tif │ ├── lab_bj0324.tif │ └── lab_bj0325.tif ├── lab_floor │ ├── lab_bj0216.tif │ ├── lab_bj0217.tif │ ├── lab_bj0219.tif │ ├── lab_bj0317.tif │ ├── lab_bj0322.tif │ ├── lab_bj0324.tif │ └── lab_bj0325.tif └── tlc │ ├── tlc_bj0216.tif │ ├── tlc_bj0217.tif │ ├── tlc_bj0219.tif │ ├── tlc_bj0317.tif │ ├── 
tlc_bj0322.tif │ ├── tlc_bj0324.tif │ └── tlc_bj0325.tif ├── test_zy3bh_tlcnetU.py └── train_zy3bh_tlcnetU_loss.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 YinxiaCao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

A deep learning method for building height estimation using high-resolution multi-view imagery over urban areas: A case study of 42 Chinese cities.

2 | 3 | We introduce high-resolution ZY-3 multi-view images to estimate building height at a spatial resolution of 2.5 m. We propose a multi-spectral, multi-view, and multi-task deep network (called M3Net) for building height estimation, where ZY-3 multi-spectral and multi-view images are fused in a multi-task learning framework. By preprocessing the data from [Amap](https://amap.com) (see Section 2 of the paper for details), we obtained 4723 samples from the 42 cities (Table 1) and randomly selected 70%, 10%, and 20% of them for training, validation, and testing, respectively. Paper link: [website](https://www.sciencedirect.com/science/article/pii/S0034425721003102) 4 | 5 |
by Yinxia Cao, Xin Huang
6 | 7 | --------------------- 8 | ## Getting Started 9 | 10 | #### Requirements: 11 | - pytorch >= 1.8.0 (lower versions may also work) 12 | - python >= 3.6 13 | 14 | ## Prepare the training set 15 | 16 | See the sample directory. Due to copyright restrictions, the whole dataset is not publicly available at present. 17 | However, the reference height data from Amap are accessible for research use. Here is the download [link](https://pan.baidu.com/s/1bBTvZcPM6PeOXxxW3j_jOg) (extraction code: 4gn2). The provided data are the original files, so preprocessing is needed before use. 18 | ``` 19 | for the sample directory: 20 | --img: the multi-spectral images with four bands (B, G, R, and NIR) 21 | --lab: the building height (unit: meter) 22 | --lab_floor: the number of floors of buildings 23 | --tlc: the multi-view images with three bands (nadir, forward, and backward viewing angles) 24 | ``` 25 | Note that the open ZY-3 data from the ISPRS organization are a good starting point; see this [link](https://www.isprs.org/data/zy-3/Default-HongKong-StMaxime.aspx). 26 | Take Hong Kong, China, for example: 27 | ![image](https://user-images.githubusercontent.com/39206462/158020784-6eb7d27e-6d93-4c42-b211-17d543675ba7.png) 28 | This image can be used to test the performance of the pretrained building height model. 29 | 30 | ## Preprocess ZY-3 images 31 | - Reference: https://www.cnblogs.com/enviidl/p/16541009.html 32 | - Step-by-step workflow: ortho-rectification, image-to-image registration, pan-sharpening, radiometric correction (i.e., quick atmospheric correction (QUAC)), and image cropping. 33 | - Software: ENVI 5.3 34 | - The resolution of all images at each step is set to 2.5 m. 35 | - The detailed procedures are shown below: 36 | #### 1. ortho-rectification 37 | Apply the ENVI tool called `RPC orthorectification workflow` to all ZY-3 images, including the multi-spectral image and the nadir, backward, and forward images. 38 | ![](asset/orthorectification.jpg) 39 | 40 | #### 2. image-to-image registration 41 | Apply the ENVI tool called `Image Registration workflow` to the nadir image (as the reference) and the other images (as warp images). 42 | Thus, all warp images are registered to the reference image. 43 | ![](asset/registration.jpg) 44 | 45 | #### 3. pan-sharpening 46 | Apply the ENVI tool called `Gram-Schmidt Pan Sharpening` to the original multi-spectral and nadir images. 47 | Thus, the two images are fused to generate a high-resolution multi-spectral image. 48 | ![](asset/pansharpening.jpg) 49 | 50 | #### 4. radiometric correction 51 | Note that all original images from the data provider have been radiometrically corrected, but they still suffer from 52 | atmospheric effects. 53 | Thus, apply the ENVI tool called `quick atmospheric correction (QUAC)` to the fused multi-spectral images from step 3. 54 | ![](asset/quac.jpg) 55 | 56 | #### 5. image cropping 57 | All images should be cropped to the same size. 58 | Apply the ENVI tool called `layer stacking` to the multi-spectral and multi-view images. 59 | ![](asset/crop.jpg) 60 | 61 | ## Predict the height model 62 | #### 1. download the pretrained weights in the `runs` directory. 63 | #### 2. run the prediction code and revise the paths of the data and weights. 64 | ``` 65 | data_path = r'sample' # the path of images 66 | resume = r'runs\tlcnetu_zy3bh\V1\finetune_298.tar' # the path of pretrained weights 67 | ``` 68 | - whole image 69 | use `jupyterlab` to run the following notebook (first `pip install jupyterlab`, then type `jupyter lab` in the command prompt):
70 | ``` 71 | 0311_predict_tlcnetU_process_wholeimg.ipynb 72 | ``` 73 | 74 | - testset 75 | ``` 76 | python pred_zy3bh_tlcnetU.py # the proposed model with two encoders for multi-spectral and multi-view images 77 | python pred_zy3bh_tlcnetU_mux.py # the model with one encoder for multi-spectral images 78 | python pred_zy3bh_tlcnetU_tlc.py # the model with one encoder for multi-view images 79 | python pred_zy3bh_tlcnetU_tlcmux.py # the model with one encoder for the multi-spectral and multi-view images stacked along the channel dimension 80 | ``` 81 | #### 3. the predicted results can be seen in `pred.rar` 82 | 83 | ## Postprocessing 84 | Project the building height onto the building footprints. 85 | ``` 86 | demo_deeppred.m 87 | ``` 88 | 89 | ## Train the height model 90 | #### 1. Prepare your dataset 91 | #### 2. edit the data path and run 92 | ``` 93 | python train_zy3bh_tlcnetU_loss.py 94 | ``` 95 | 96 | #### 3. Evaluate on the test set 97 | see the pretrained model in the `runs/` directory 98 | ``` 99 | python evaluate.py 100 | ``` 101 | 102 | If there is any issue, please feel free to contact me. My email address is yinxcao@163.com or yinxcao@whu.edu.cn; see also my [ResearchGate profile](https://www.researchgate.net/profile/Yinxia-Cao). 103 | 104 | ## Interesting application in Bangalore, India 105 | Update on 2022.2.26: 106 | We directly applied the model trained in China to Bangalore and obtained amazing results, shown below. 107 | 1. Results over Bangalore 108 | ![image](https://user-images.githubusercontent.com/39206462/155845595-80a7cecb-ae88-4ef6-bcd2-f9dabaea6771.png) 109 | 2. Enlarged views 110 | ![image](https://user-images.githubusercontent.com/39206462/155845516-f891da88-a178-4fd6-9edc-8eb5bcb26278.png) 111 | 112 | Note that the acquisition dates of the ZY-3 images and the Google images are different, as are their spatial resolutions, 113 | and therefore, there are some differences between the Google images and our results. 114 | The above results show that our method outperforms the random forest method and captures rich building details. 115 | 116 | 117 | ## Citation 118 | 119 | If you find this repo useful for your research, please consider citing the paper: 120 | ``` 121 | @article{cao2021deep, 122 | title={A deep learning method for building height estimation using high-resolution multi-view imagery over urban areas: A case study of 42 Chinese cities}, 123 | author={Cao, Yinxia and Huang, Xin}, 124 | journal={Remote Sensing of Environment}, 125 | volume={264}, 126 | pages={112590}, 127 | year={2021}, 128 | publisher={Elsevier} 129 | } 130 | ``` 131 | ## Acknowledgement 132 | Thanks for advice from the supervisor [Xin Huang](https://scholar.google.com/citations?user=TS6FzEwAAAAJ&hl=zh-CN), Doctor [Mengmeng Li](https://scholar.google.com/citations?user=TwTgEzwAAAAJ&hl=en), Professor [Xuecao Li](https://scholar.google.com.hk/citations?user=r2p47SEAAAAJ&hl=zh-CN), and anonymous reviewers. The `ptsemseg` code in this repository is adapted from pytorch-semseg:
133 | ``` 134 | @article{mshahsemseg, 135 | Author = {Meet P Shah}, 136 | Title = {Semantic Segmentation Architectures Implemented in PyTorch.}, 137 | Journal = {https://github.com/meetshah1995/pytorch-semseg}, 138 | Year = {2017} 139 | } 140 | ``` 141 | -------------------------------------------------------------------------------- /asset/crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/asset/crop.jpg -------------------------------------------------------------------------------- /asset/orthorectification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/asset/orthorectification.jpg -------------------------------------------------------------------------------- /asset/pansharpening.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/asset/pansharpening.jpg -------------------------------------------------------------------------------- /asset/quac.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/asset/quac.jpg -------------------------------------------------------------------------------- /asset/registration.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/asset/registration.jpg -------------------------------------------------------------------------------- /configs/fcn8s_pascal.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: fcn8s 3 | data: 4 | dataset: pascal 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 'same' 8 | img_cols: 'same' 9 | path: /private/home/meetshah/datasets/VOC/060817/VOCdevkit/VOC2012/ 10 | sbd_path: /private/home/meetshah/datasets/VOC/benchmark_RELEASE/ 11 | training: 12 | train_iters: 300000 13 | batch_size: 1 14 | val_interval: 1000 15 | n_workers: 16 16 | print_interval: 50 17 | optimizer: 18 | name: 'sgd' 19 | lr: 1.0e-10 20 | weight_decay: 0.0005 21 | momentum: 0.99 22 | loss: 23 | name: 'cross_entropy' 24 | size_average: False 25 | lr_schedule: 26 | resume: fcn8s_pascal_best_model.pkl 27 | -------------------------------------------------------------------------------- /configs/frrnB_cityscapes.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: frrnB 3 | data: 4 | dataset: cityscapes 5 | train_split: train 6 | val_split: val 7 | img_rows: 512 8 | img_cols: 1024 9 | path: /private/home/meetshah/misc_code/ps/data/VOCdevkit/VOC2012/ 10 | training: 11 | train_iters: 85000 12 | batch_size: 2 13 | val_interval: 500 14 | print_interval: 25 15 | optimizer: 16 | lr: 1.0e-4 17 | l_rate: 1.0e-4 18 | l_schedule: 19 | momentum: 0.99 20 | weight_decay: 0.0005 21 | resume: frrnB_cityscapes_best_model.pkl 22 | visdom: False 23 | -------------------------------------------------------------------------------- /configs/scratch.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: unet 3 | data: 4 | 
dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 'same' 8 | img_cols: 'same' 9 | path: E:\Yinxcao\ningbo\fcn\data 10 | n_class: 7 11 | training: 12 | epochs: 200 13 | batch_size: 16 14 | val_interval: 2 15 | n_workers: 16 16 | print_interval: 1 17 | optimizer: 18 | name: 'sgd' 19 | lr: 1.0e-3 20 | weight_decay: 0.0005 21 | momentum: 0.99 22 | loss: 23 | name: 'cross_entropy' 24 | size_average: False 25 | lr_schedule: 1 26 | learning_rate: 0.0001 27 | resume: 'runs\\scratch\\V1\\finetune_120.tar' 28 | device: 'cuda' 29 | savepath: 'runs\\scratch\\V1\\pred' 30 | -------------------------------------------------------------------------------- /configs/tlcnet_zy3bh.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnet 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 256 8 | img_cols: 'same' 9 | path: F:\yinxcao\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 200 14 | batch_size: 8 15 | val_interval: 2 16 | n_workers: 16 17 | print_interval: 1 18 | optimizer: 19 | name: 'sgd' 20 | lr: 1.0e-4 21 | weight_decay: 0.0005 22 | momentum: 0.99 23 | loss: 24 | name: 'cross_entropy' 25 | size_average: True 26 | lr_schedule: 1 27 | learning_rate: 0.001 28 | resume: '' 29 | device: 'cuda' 30 | augmentation: 31 | rflip: 0.5 32 | vflip: 0.5 33 | rotate: 30 34 | savepath: 'runs\\tlcnet_zy3bh' 35 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: '' 17 | device: 'cuda' 18 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh_equalweight.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: '' 17 | device: 'cuda' 18 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh_mux.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetumux 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: '' 17 | device: 'cuda' 18 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh_tlc.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetutlc 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: '' 
17 | device: 'cuda' 18 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh_tlcmux.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetutlcmux 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: '' 17 | device: 'cuda' 18 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh_us.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample_us 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: '' 17 | device: 'cuda' 18 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh_us_cn.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample 10 | path_us: D:\yinxcao\height\sample_us 11 | n_class: 1 12 | n_maxdisp: 192 13 | training: 14 | epochs: 300 15 | batch_size: 16 16 | learning_rate: 0.001 17 | resume: '' 18 | device: 'cuda' 19 | -------------------------------------------------------------------------------- /configs/tlcnetu_zy3bh_us_pre_loss.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\yinxcao\height\sample_us 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 150 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: 'runs\\tlcnetu_zy3bh\\V1\\finetune_298.tar' 17 | device: 'cuda' 18 | -------------------------------------------------------------------------------- /configs/unet_zy3bh.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: unet 3 | in_channels: 7 4 | data: 5 | dataset: diydata 6 | train_split: train_aug 7 | val_split: val 8 | img_rows: 400 9 | img_cols: 'same' 10 | path: D:\yinxcao\height\sample 11 | n_class: 1 12 | n_maxdisp: 192 13 | training: 14 | epochs: 300 15 | batch_size: 16 16 | learning_rate: 0.001 17 | resume: '' 18 | device: 'cuda' 19 | 20 | -------------------------------------------------------------------------------- /configs/unet_zy3bh_mux.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: unet 3 | in_channels: 4 4 | data: 5 | dataset: diydata 6 | train_split: train_aug 7 | val_split: val 8 | img_rows: 400 9 | img_cols: 'same' 10 | path: D:\yinxcao\height\sample 11 | n_class: 1 12 | n_maxdisp: 192 13 | training: 14 | epochs: 300 15 | batch_size: 16 16 | learning_rate: 0.001 17 | resume: '' 18 | device: 'cuda' 19 | 20 | -------------------------------------------------------------------------------- /configs/unet_zy3bh_tlc.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: 
unet 3 | in_channels: 3 4 | data: 5 | dataset: diydata 6 | train_split: train_aug 7 | val_split: val 8 | img_rows: 400 9 | img_cols: 'same' 10 | path: D:\yinxcao\height\sample 11 | n_class: 1 12 | n_maxdisp: 192 13 | training: 14 | epochs: 300 15 | batch_size: 16 16 | learning_rate: 0.001 17 | resume: '' 18 | device: 'cuda' 19 | 20 | -------------------------------------------------------------------------------- /configs/val_tlcnetu_zy3bh.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: D:\cn\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: 'runs\\tlcnetu_zy3bh\\V4\\finetune_300.tar' 17 | device: 'cuda' 18 | savepath: D:\cn\pred 19 | -------------------------------------------------------------------------------- /configs/val_tlcnetu_zy3bh_t1.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path1: F:\yinxcao\sample 10 | path2: D:\cn\sample 11 | n_class: 1 12 | n_maxdisp: 192 13 | training: 14 | epochs: 300 15 | batch_size: 16 16 | learning_rate: 0.001 17 | resume: 'runs\\tlcnetu_zy3bh\\V4\\finetune_300.tar' 18 | device: 'cuda' 19 | savepath: D:\cn\pred_t1 20 | -------------------------------------------------------------------------------- /configs/val_tlcnetu_zy3bh_testus.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: E:\yinxcao\us\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: 'runs\\tlcnetu_zy3bh\\V4\\finetune_300.tar' 17 | device: 'cuda' 18 | savepath: D:\cn\pred_us 19 | -------------------------------------------------------------------------------- /configs/val_tlcnetu_zy3bh_us.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: E:\yinxcao\us\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: 'runs\\tlcnetu_zy3bh_us_loss\\V1\\finetune_300.tar' 17 | device: 'cuda' 18 | savepath: '' 19 | 20 | -------------------------------------------------------------------------------- /configs/val_tlcnetu_zy3bh_us_cn.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: E:\yinxcao\us\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: 'runs\\tlcnetu_zy3bh_us_cn_loss\\V1\\finetune_300.tar' 17 | device: 'cuda' 18 | savepath: '' 19 | 20 | -------------------------------------------------------------------------------- /configs/val_tlcnetu_zy3bh_us_onlycn.yml: -------------------------------------------------------------------------------- 1 
| model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: E:\yinxcao\us\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: 'runs\\tlcnetu_zy3bh\\V4\\finetune_300.tar' 17 | device: 'cuda' 18 | savepath: '' 19 | 20 | -------------------------------------------------------------------------------- /configs/val_tlcnetu_zy3bh_us_pre.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: tlcnetu 3 | data: 4 | dataset: diydata 5 | train_split: train_aug 6 | val_split: val 7 | img_rows: 400 8 | img_cols: 'same' 9 | path: E:\yinxcao\us\sample 10 | n_class: 1 11 | n_maxdisp: 192 12 | training: 13 | epochs: 300 14 | batch_size: 16 15 | learning_rate: 0.001 16 | resume: 'runs\\tlcnetu_zy3bh_us_pre_loss\\V2\\finetune_150.tar' 17 | device: 'cuda' 18 | savepath: '' 19 | 20 | -------------------------------------------------------------------------------- /demo_deeppred.m: -------------------------------------------------------------------------------- 1 | % iroot = 'D:\yinxcao\height\pred_deep\'; 2 | iroot = ''; 3 | 4 | fcode ='ningbo'; 5 | % fcode='shenzhen1'; 6 | % fcode='shenzhen2'; 7 | 8 | for i=1%3:num 9 | tic; 10 | file = [iroot, fcode, '\pred'];%fullfile(filelist(i).folder, filelist(i).name); 11 | ipath=[file,'\']; % the path of data 12 | respath=[file,'\']; 13 | % if ~isfolder(respath) 14 | % mkdir(respath) 15 | % end 16 | % path0= [respath, 'predtlcnetu_200_segh']; 17 | path1 = [respath, 'predtlcnetu_200_obj1.tif']; 18 | % path2 = [respath, 'predtlcnetu_200_obj1000']; 19 | % path3 =[respath, 'predtlcnetu_200_obj500']; 20 | % path4 = [respath, 'predtlcnetu_200_obj100']; 21 | % if isfile(path0) 22 | % continue; 23 | % end 24 | % 0. read 2.5 m deep prediction 25 | [predheight, R]=geotiffread([ipath, 'predtlcnetu_200.tif']);%predheight=predheight'; 26 | predseg=imread([ipath, 'predtlcnetu_seg.tif']); 27 | %predseg=predseg'; 28 | info = geotiffinfo([ipath, 'predtlcnetu_200.tif']); 29 | 30 | pred=cat(3,predheight,single(predseg)); 31 | clear('predseg','predheight'); 32 | % 0. seg mask height 33 | % if ~isfile(path0) 34 | % disp('process seg.*height'); 35 | % nenviwrite(pred(:,:,1).*pred(:,:,2),path0); 36 | % end 37 | % 1. object processing 38 | pvalue=98; % 75 39 | ifun=@(block_struct) object_val_prcimgv2(block_struct.data,pvalue); 40 | predobject=blockproc(pred, [800,800],ifun); 41 | % add 42 | predobject = uint8(predobject); 43 | % 2. save object result 44 | if ~isfile(path1) 45 | disp('process object75 2.5'); 46 | geotiffwrite(path1, predobject, R,"GeoKeyDirectoryTag", ... 47 | info.GeoTIFFTags.GeoKeyDirectoryTag) 48 | % nenviwrite(predobject, path1); 49 | % imwrite(predobject, [path1,'.tif']); 50 | end 51 | % % 3. 
downsampling to any scale 52 | % if ~isfile(path2) 53 | % disp('process object75 1000'); 54 | % scale=1000; 55 | % preobj1000=func_resize(predobject, pred(:,:,2), 1, 2.5/scale); 56 | % nenviwrite(preobj1000, path2); 57 | % end 58 | % 59 | % %500m 60 | % if ~isfile(path3) 61 | % disp('process object75 500'); 62 | % scale=500; 63 | % preobj500=func_resize(predobject, pred(:,:,2), 1, 2.5/scale); 64 | % nenviwrite(preobj500, path3); 65 | % end 66 | % % 100m 67 | % if ~isfile(path4) 68 | % disp('process object75 100'); 69 | % scale=100; 70 | % preobj100=func_resize(predobject, pred(:,:,2), 1, 2.5/scale); 71 | % nenviwrite(preobj100, path4); 72 | % end 73 | 74 | toc; 75 | end 76 | 77 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | test ningbo images 3 | ''' 4 | import os 5 | import yaml 6 | import shutil 7 | import torch 8 | import random 9 | import argparse 10 | import numpy as np 11 | 12 | from ptsemseg.models import get_model 13 | from ptsemseg.utils import get_logger 14 | from tensorboardX import SummaryWriter 15 | from ptsemseg.loader.diy_dataset import dataloaderbh 16 | import sklearn.metrics 17 | import matplotlib.pyplot as plt 18 | import tifffile as tif 19 | 20 | 21 | def main(cfg, writer, logger): 22 | 23 | # Setup device 24 | device = torch.device(cfg["training"]["device"]) 25 | 26 | # Setup Dataloader 27 | data_path = "sample" # cfg["data"]["path"] 28 | n_classes = cfg["data"]["n_class"] 29 | n_maxdisp = cfg["data"]["n_maxdisp"] 30 | batch_size = cfg["training"]["batch_size"] 31 | epochs = cfg["training"]["epochs"] 32 | learning_rate = cfg["training"]["learning_rate"] 33 | patchsize = cfg["data"]["img_rows"] 34 | 35 | _, _, valimg, vallab = dataloaderbh(data_path) 36 | 37 | # Setup Model 38 | model = get_model(cfg["model"], n_maxdisp=n_maxdisp, n_classes=n_classes).to(device) 39 | if torch.cuda.device_count() > 1: 40 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 41 | 42 | #resume = cfg["training"]["resume"] 43 | resume = r'runs\tlcnetu_zy3bh\V1\finetune.tar' 44 | if os.path.isfile(resume): 45 | print("=> loading checkpoint '{}'".format(resume)) 46 | checkpoint = torch.load(resume) 47 | model.load_state_dict(checkpoint['state_dict']) 48 | # optimizer.load_state_dict(checkpoint['optimizer']) 49 | print("=> loaded checkpoint '{}' (epoch {})" 50 | .format(resume, checkpoint['epoch'])) 51 | else: 52 | print("=> no checkpoint found at resume") 53 | print("=> Will start from scratch.") 54 | 55 | model.eval() 56 | 57 | for idx, imgpath in enumerate(valimg[0:20]): 58 | name = os.path.basename(vallab[idx]) 59 | respath = os.path.join(cfg["savepath"],'pred'+name) 60 | y_true = tif.imread(vallab[idx]) 61 | y_true = y_true.astype(np.int16)*3 62 | # random crop: test and train is the same 63 | mux = tif.imread(imgpath[0])/10000 # convert to surface reflectance (SR): 0-1 64 | tlc = tif.imread(imgpath[1])/10000 # stretch to 0-1 65 | 66 | offset = mux.shape[0] - patchsize 67 | x1 = random.randint(0, offset) 68 | y1 = random.randint(0, offset) 69 | mux = mux[x1:x1 + patchsize, y1:y1 + patchsize, :] 70 | tlc = tlc[x1:x1 + patchsize, y1:y1 + patchsize, :] 71 | y_true = y_true[x1:x1 + patchsize, y1:y1 + patchsize] 72 | 73 | img = np.concatenate((mux, tlc), axis=2) 74 | img[img > 1] = 1 # ensure data range is 0-1 75 | # remove tlc 76 | # img[:,:,4:] = 0 77 | 78 | img = img.transpose((2, 0, 1)) 79 | img = np.expand_dims(img, 0) 
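# note: at this point `img` has shape (1, 7, patchsize, patchsize): the 4 multi-spectral bands (B, G, R, NIR) followed by the 3 multi-view bands (nadir, forward, backward); the model's two outputs below are the height map (y_res[0]) and the building segmentation map (y_res[1])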
80 | img = torch.from_numpy(img).float() 81 | y_res = model(img.to(device)) 82 | 83 | y_pred = y_res[0] # height 84 | y_pred = y_pred.cpu().detach().numpy() 85 | y_pred = np.squeeze(y_pred) 86 | rmse = myrmse(y_true, y_pred) 87 | 88 | y_seg = y_res[1] # seg 89 | y_seg = y_seg.cpu().detach().numpy() 90 | y_seg = np.argmax(y_seg.squeeze(), axis=0) # C H W=> H W 91 | precision, recall, f1score = metricsperclass(y_true, y_seg, value=1) # 92 | print('rmse: %.3f, segerror: ua %.3f, pa %.3f, f1 %.3f'%(rmse, precision, recall, f1score)) 93 | 94 | tif.imsave((os.path.join(cfg["savepath"],'mux'+name)), mux) 95 | tif.imsave( (os.path.join(cfg["savepath"], 'ref' + name)), y_true) 96 | tif.imsave( (os.path.join(cfg["savepath"], 'pred' + name)), y_pred) 97 | tif.imsave((os.path.join(cfg["savepath"], 'seg' + name)), y_seg.astype(np.uint8)) 98 | 99 | # 100 | # color encode: change to the 101 | # get color info 102 | # _, color_values = get_colored_info('class_dict.csv') 103 | # prediction = color_encode(y_pred, color_values) 104 | # label = color_encode(y_true, color_values) 105 | 106 | # plt.subplot(131) 107 | # plt.title('Image', fontsize='large', fontweight='bold') 108 | # plt.imshow(mux[:, :, 0:3]/1000) 109 | # plt.subplot(132) 110 | # plt.title('Ref', fontsize='large', fontweight='bold') 111 | # plt.imshow(y_true) 112 | # # plt.subplot(143) 113 | # # plt.title('Pred', fontsize='large', fontweight='bold') 114 | # # plt.imshow(prediction) 115 | # plt.subplot(133) 116 | # plt.title('Pred %.3f'%scores, fontsize='large', fontweight='bold') 117 | # plt.imshow(y_pred) 118 | # plt.savefig(os.path.join(cfg["savepath"], 'fig'+name)) 119 | # plt.close() 120 | 121 | 122 | def gray2rgb(image): 123 | res=np.zeros((image.shape[0], image.shape[1], 3)) 124 | res[ :, :, 0] = image.copy() 125 | res[ :, :, 1] = image.copy() 126 | res[ :, :, 2] = image.copy() 127 | return res 128 | 129 | 130 | def metrics(y_true, y_pred, ignorevalue=0): 131 | y_true = y_true.flatten() 132 | y_pred = y_pred.flatten() 133 | maskid = np.where(y_true!=ignorevalue) 134 | y_true = y_true[maskid] 135 | y_pred = y_pred[maskid] 136 | accuracy = sklearn.metrics.accuracy_score(y_true, y_pred) 137 | kappa = sklearn.metrics.cohen_kappa_score(y_true, y_pred) 138 | f1_micro = sklearn.metrics.f1_score(y_true, y_pred, average="micro") 139 | f1_macro = sklearn.metrics.f1_score(y_true, y_pred, average="macro") 140 | f1_weighted = sklearn.metrics.f1_score(y_true, y_pred, average="weighted") 141 | recall_micro = sklearn.metrics.recall_score(y_true, y_pred, average="micro") 142 | recall_macro = sklearn.metrics.recall_score(y_true, y_pred, average="macro") 143 | recall_weighted = sklearn.metrics.recall_score(y_true, y_pred, average="weighted") 144 | precision_micro = sklearn.metrics.precision_score(y_true, y_pred, average="micro") 145 | precision_macro = sklearn.metrics.precision_score(y_true, y_pred, average="macro") 146 | precision_weighted = sklearn.metrics.precision_score(y_true, y_pred, average="weighted") 147 | 148 | return dict( 149 | accuracy=accuracy, 150 | kappa=kappa, 151 | f1_micro=f1_micro, 152 | f1_macro=f1_macro, 153 | f1_weighted=f1_weighted, 154 | recall_micro=recall_micro, 155 | recall_macro=recall_macro, 156 | recall_weighted=recall_weighted, 157 | precision_micro=precision_micro, 158 | precision_macro=precision_macro, 159 | precision_weighted=precision_weighted, 160 | ) 161 | 162 | def myrmse(y_true, ypred): 163 | diff=y_true.flatten()-ypred.flatten() 164 | return np.sqrt(np.mean(diff*diff)) 165 | 166 | 167 | def metricsperclass(y_true, 
y_pred, value): 168 | y_pred = y_pred.flatten() 169 | y_true = np.where(y_true>0, np.ones_like(y_true), np.zeros_like(y_true)).flatten() 170 | 171 | tp=len(np.where((y_true==value) & (y_pred==value))[0]) 172 | tn=len(np.where(y_true==value)[0]) 173 | fn = len(np.where(y_pred == value)[0]) 174 | precision = tp/(1e-10+fn) 175 | recall = tp/(1e-10+tn) 176 | f1score = 2*precision*recall/(precision+recall+1e-10) 177 | return precision, recall, f1score 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser(description="config") 182 | parser.add_argument( 183 | "--config", 184 | nargs="?", 185 | type=str, 186 | default="configs/tlcnetu_zy3bh.yml", 187 | help="Configuration file to use", 188 | ) 189 | 190 | args = parser.parse_args() 191 | 192 | with open(args.config) as fp: 193 | cfg = yaml.load(fp) 194 | 195 | #run_id = random.randint(1, 100000) 196 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], "V1") 197 | writer = SummaryWriter(log_dir=logdir) 198 | 199 | print("RUNDIR: {}".format(logdir)) 200 | shutil.copy(args.config, logdir) 201 | 202 | logger = get_logger(logdir) 203 | logger.info("Let the games begin") 204 | 205 | main(cfg, writer, logger) 206 | -------------------------------------------------------------------------------- /object_val_prcimgv2.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/object_val_prcimgv2.m -------------------------------------------------------------------------------- /pred.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/pred.rar -------------------------------------------------------------------------------- /pred_zy3bh_tlcnetU.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 2020.12.28 validate us samples 3 | ''' 4 | 5 | import os 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 7 | 8 | import os 9 | import torch 10 | from tqdm import tqdm 11 | import numpy as np 12 | import tifffile as tif 13 | 14 | from torch.utils import data 15 | from ptsemseg.models import TLCNetU 16 | from ptsemseg.loader.diy_dataset import dataloaderbh_testall 17 | from ptsemseg.loader.diyloader import myImageFloder 18 | from ptsemseg.metrics import heightacc 19 | 20 | def main(): 21 | 22 | # Setup device 23 | device = 'cuda' 24 | 25 | # Setup Dataloader 26 | data_path = r'sample' 27 | batch_size = 16 28 | # Load dataset 29 | testimg, testlab, nameid = dataloaderbh_testall(data_path, [0,0,1]) # all images for testing 30 | 31 | testdataloader = torch.utils.data.DataLoader( 32 | myImageFloder(testimg, testlab, num=16), 33 | batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 34 | 35 | # Setup Model 36 | model = TLCNetU(n_classes=1).to(device) 37 | if torch.cuda.device_count() > 1: 38 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 39 | 40 | # print the model 41 | start_epoch = 0 42 | resume = r'runs\tlcnetu_zy3bh\V1\finetune_298.tar' 43 | if os.path.isfile(resume): 44 | print("=> loading checkpoint '{}'".format(resume)) 45 | checkpoint = torch.load(resume) 46 | model.load_state_dict(checkpoint['state_dict']) 47 | # optimizer.load_state_dict(checkpoint['optimizer']) 48 | print("=> loaded checkpoint '{}' (epoch {})" 49 | .format(resume, checkpoint['epoch'])) 50 | 
start_epoch = checkpoint['epoch'] 51 | else: 52 | print("=> no checkpoint found at resume") 53 | print("=> Will start from scratch.") 54 | return 55 | 56 | model.eval() 57 | acc = heightacc() 58 | counts = 0 59 | respath = os.path.dirname(os.path.dirname(resume)).replace('runs', 'pred') 60 | if not os.path.exists(respath): 61 | os.makedirs(respath) 62 | 63 | with torch.no_grad(): 64 | for x, y_true in tqdm(testdataloader): 65 | y_pred, y_seg = model.forward(x.to(device)) 66 | y_pred = y_pred.cpu().detach().numpy() 67 | 68 | acc.update(y_pred, y_true.numpy(), x.shape[0]) 69 | 70 | # save to tif 71 | y_pred = np.squeeze(y_pred, axis=1) # B H W 72 | y_seg = np.argmax(y_seg.cpu().numpy(), axis=1).astype(np.uint8) # B H W 73 | count = x.shape[0] 74 | names = nameid[counts:counts+count] 75 | for k in range(count): 76 | tif.imsave((os.path.join(respath,'pred_'+names[k]+'.tif')), y_pred[k]) 77 | tif.imsave((os.path.join(respath,'seg_'+names[k]+'.tif')), y_seg[k]) 78 | tif.imsave((os.path.join(respath, 'seg_' + names[k] + '_clr.tif')), y_seg[k] * 255) 79 | counts += count 80 | 81 | res = acc.getacc() 82 | print('r2, rmse, mae, se') 83 | print('%.6f %.6f %.6f %.6f' % (res[0], res[1], res[2], res[3])) 84 | print(res) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /pred_zy3bh_tlcnetU_mux.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 2020.12.28 validate us samples 3 | ''' 4 | 5 | import os 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 7 | 8 | import os 9 | import torch 10 | from tqdm import tqdm 11 | import numpy as np 12 | import tifffile as tif 13 | 14 | from torch.utils import data 15 | from ptsemseg.models import TLCNetUmux 16 | from ptsemseg.loader.diy_dataset import dataloaderbh_testall 17 | from ptsemseg.loader.diyloader import myImageFloder_mux 18 | from ptsemseg.metrics import heightacc 19 | 20 | def main(): 21 | 22 | # Setup device 23 | device = 'cuda' 24 | 25 | # Setup Dataloader 26 | data_path = r'sample' 27 | batch_size = 16 28 | # Load dataset 29 | testimg, testlab, nameid = dataloaderbh_testall(data_path,[0,0,1]) 30 | 31 | testdataloader = torch.utils.data.DataLoader( 32 | myImageFloder_mux(testimg, testlab), 33 | batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True) 34 | 35 | # Setup Model 36 | model = TLCNetUmux(n_classes=1).to(device) 37 | if torch.cuda.device_count() > 1: 38 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 39 | 40 | # print the model 41 | start_epoch = 0 42 | resume = r'runs\tlcnetu_zy3bh_mux\V1\finetune_298.tar' 43 | if os.path.isfile(resume): 44 | print("=> loading checkpoint '{}'".format(resume)) 45 | checkpoint = torch.load(resume) 46 | model.load_state_dict(checkpoint['state_dict']) 47 | # optimizer.load_state_dict(checkpoint['optimizer']) 48 | print("=> loaded checkpoint '{}' (epoch {})" 49 | .format(resume, checkpoint['epoch'])) 50 | start_epoch = checkpoint['epoch'] 51 | else: 52 | print("=> no checkpoint found at resume") 53 | print("=> Will start from scratch.") 54 | return 55 | 56 | model.eval() 57 | acc = heightacc() 58 | counts = 0 59 | respath = os.path.dirname(os.path.dirname(resume)).replace('runs', 'pred') 60 | if not os.path.exists(respath): 61 | os.makedirs(respath) 62 | 63 | with torch.no_grad(): 64 | for x, y_true in tqdm(testdataloader): 65 | y_pred, y_seg = model.forward(x.to(device)) 66 | y_pred = y_pred.cpu().detach().numpy() 67 | 68 | 
acc.update(y_pred, y_true.numpy(), x.shape[0]) 69 | 70 | # save to tif 71 | y_pred = np.squeeze(y_pred, axis=1) # B H W 72 | y_seg = np.argmax(y_seg.cpu().numpy(), axis=1).astype(np.uint8) # B H W 73 | count = x.shape[0] 74 | names = nameid[counts:counts+count] 75 | for k in range(count): 76 | tif.imsave((os.path.join(respath,'pred_'+names[k]+'.tif')), y_pred[k]) 77 | tif.imsave((os.path.join(respath,'seg_'+names[k]+'.tif')), y_seg[k]) 78 | tif.imsave((os.path.join(respath, 'seg_' + names[k] + '_clr.tif')), y_seg[k] * 255) 79 | counts += count 80 | 81 | res = acc.getacc() 82 | print('r2, rmse, mae, se') 83 | print('%.6f %.6f %.6f %.6f' % (res[0], res[1], res[2], res[3])) 84 | print(res) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /pred_zy3bh_tlcnetU_tlc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 2020.12.28 validate us samples 3 | ''' 4 | 5 | import os 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 7 | 8 | import os 9 | import torch 10 | from tqdm import tqdm 11 | import numpy as np 12 | import tifffile as tif 13 | 14 | from torch.utils import data 15 | from ptsemseg.models import TLCNetUtlc 16 | from ptsemseg.loader.diy_dataset import dataloaderbh_testall 17 | from ptsemseg.loader.diyloader import myImageFloder_tlc 18 | from ptsemseg.metrics import heightacc 19 | 20 | def main(): 21 | 22 | # Setup device 23 | device = 'cuda' 24 | 25 | # Setup Dataloader 26 | data_path = r'sample' 27 | batch_size = 16 28 | # Load dataset 29 | testimg, testlab, nameid = dataloaderbh_testall(data_path,[0,0,1]) 30 | 31 | testdataloader = torch.utils.data.DataLoader( 32 | myImageFloder_tlc(testimg, testlab, num=16), 33 | batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 34 | 35 | # Setup Model 36 | model = TLCNetUtlc(n_classes=1).to(device) 37 | if torch.cuda.device_count() > 1: 38 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 39 | 40 | # print the model 41 | start_epoch = 0 42 | resume = r'runs\tlcnetu_zy3bh_tlc\V1\finetune_298.tar' 43 | if os.path.isfile(resume): 44 | print("=> loading checkpoint '{}'".format(resume)) 45 | checkpoint = torch.load(resume) 46 | model.load_state_dict(checkpoint['state_dict']) 47 | # optimizer.load_state_dict(checkpoint['optimizer']) 48 | print("=> loaded checkpoint '{}' (epoch {})" 49 | .format(resume, checkpoint['epoch'])) 50 | start_epoch = checkpoint['epoch'] 51 | else: 52 | print("=> no checkpoint found at resume") 53 | print("=> Will start from scratch.") 54 | return 55 | 56 | model.eval() 57 | acc = heightacc() 58 | counts = 0 59 | respath = os.path.dirname(os.path.dirname(resume)).replace('runs', 'pred') 60 | if not os.path.exists(respath): 61 | os.makedirs(respath) 62 | 63 | with torch.no_grad(): 64 | for x, y_true in tqdm(testdataloader): 65 | y_pred, y_seg = model.forward(x.to(device)) 66 | y_pred = y_pred.cpu().detach().numpy() 67 | 68 | acc.update(y_pred, y_true.numpy(), x.shape[0]) 69 | 70 | # save to tif 71 | y_pred = np.squeeze(y_pred, axis=1) # B H W 72 | y_seg = np.argmax(y_seg.cpu().numpy(), axis=1).astype(np.uint8) # B H W 73 | count = x.shape[0] 74 | names = nameid[counts:counts+count] 75 | for k in range(count): 76 | tif.imsave((os.path.join(respath,'pred_'+names[k]+'.tif')), y_pred[k]) 77 | tif.imsave((os.path.join(respath,'seg_'+names[k]+'.tif')), y_seg[k]) 78 | tif.imsave((os.path.join(respath, 'seg_' + names[k] + '_clr.tif')), y_seg[k] * 255) 79 | 
counts += count 80 | 81 | res = acc.getacc() 82 | print('r2, rmse, mae, se') 83 | print('%.6f %.6f %.6f %.6f' % (res[0], res[1], res[2], res[3])) 84 | print(res) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /pred_zy3bh_tlcnetU_tlcmux.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 2020.12.28 validate us samples 3 | ''' 4 | 5 | import os 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 7 | 8 | import os 9 | import torch 10 | from tqdm import tqdm 11 | import numpy as np 12 | import tifffile as tif 13 | 14 | from torch.utils import data 15 | from ptsemseg.models import TLCNetUtlcmux 16 | from ptsemseg.loader.diy_dataset import dataloaderbh_testall 17 | from ptsemseg.loader.diyloader import myImageFloder 18 | from ptsemseg.metrics import heightacc 19 | 20 | def main(): 21 | 22 | # Setup device 23 | device = 'cuda' 24 | 25 | # Setup Dataloader 26 | data_path = r'sample' 27 | batch_size = 16 28 | # Load dataset 29 | testimg, testlab, nameid = dataloaderbh_testall(data_path,[0,0,1]) 30 | 31 | testdataloader = torch.utils.data.DataLoader( 32 | myImageFloder(testimg, testlab, num=16), 33 | batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 34 | 35 | # Setup Model 36 | model = TLCNetUtlcmux(n_classes=1).to(device) 37 | if torch.cuda.device_count() > 1: 38 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 39 | 40 | # print the model 41 | start_epoch = 0 42 | resume = r'runs\tlcnetu_zy3bh_tlcmux\V1\finetune_298.tar' 43 | if os.path.isfile(resume): 44 | print("=> loading checkpoint '{}'".format(resume)) 45 | checkpoint = torch.load(resume) 46 | model.load_state_dict(checkpoint['state_dict']) 47 | # optimizer.load_state_dict(checkpoint['optimizer']) 48 | print("=> loaded checkpoint '{}' (epoch {})" 49 | .format(resume, checkpoint['epoch'])) 50 | start_epoch = checkpoint['epoch'] 51 | else: 52 | print("=> no checkpoint found at resume") 53 | print("=> Will start from scratch.") 54 | return 55 | 56 | model.eval() 57 | acc = heightacc() 58 | counts = 0 59 | respath = os.path.dirname(os.path.dirname(resume)).replace('runs', 'pred') 60 | if not os.path.exists(respath): 61 | os.makedirs(respath) 62 | 63 | with torch.no_grad(): 64 | for x, y_true in tqdm(testdataloader): 65 | y_pred, y_seg = model.forward(x.to(device)) 66 | y_pred = y_pred.cpu().detach().numpy() 67 | 68 | acc.update(y_pred, y_true.numpy(), x.shape[0]) 69 | 70 | # save to tif 71 | y_pred = np.squeeze(y_pred, axis=1) # B H W 72 | y_seg = np.argmax(y_seg.cpu().numpy(), axis=1).astype(np.uint8) # B H W 73 | count = x.shape[0] 74 | names = nameid[counts:counts+count] 75 | for k in range(count): 76 | tif.imsave((os.path.join(respath,'pred_'+names[k]+'.tif')), y_pred[k]) 77 | tif.imsave((os.path.join(respath,'seg_'+names[k]+'.tif')), y_seg[k]) 78 | tif.imsave((os.path.join(respath, 'seg_' + names[k] + '_clr.tif')), y_seg[k] * 255) 79 | counts += count 80 | 81 | res = acc.getacc() 82 | print('r2, rmse, mae, se') 83 | print('%.6f %.6f %.6f %.6f' % (res[0], res[1], res[2], res[3])) 84 | print(res) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /ptsemseg/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ptsemseg.augmentations.augmentations import ( 3 | AdjustContrast, 4 | AdjustGamma, 5 | 
AdjustBrightness, 6 | AdjustSaturation, 7 | AdjustHue, 8 | RandomCrop, 9 | RandomHorizontallyFlip, 10 | RandomVerticallyFlip, 11 | Scale, 12 | RandomSized, 13 | RandomSizedCrop, 14 | RandomRotate, 15 | RandomTranslate, 16 | CenterCrop, 17 | Compose, 18 | ) 19 | 20 | logger = logging.getLogger("ptsemseg") 21 | 22 | key2aug = { 23 | "gamma": AdjustGamma, 24 | "hue": AdjustHue, 25 | "brightness": AdjustBrightness, 26 | "saturation": AdjustSaturation, 27 | "contrast": AdjustContrast, 28 | "rcrop": RandomCrop, 29 | "hflip": RandomHorizontallyFlip, 30 | "vflip": RandomVerticallyFlip, 31 | "scale": Scale, 32 | "rsize": RandomSized, 33 | "rsizecrop": RandomSizedCrop, 34 | "rotate": RandomRotate, 35 | "translate": RandomTranslate, 36 | "ccrop": CenterCrop, 37 | } 38 | 39 | 40 | def get_composed_augmentations(aug_dict): 41 | if aug_dict is None: 42 | logger.info("Using No Augmentations") 43 | return None 44 | 45 | augmentations = [] 46 | for aug_key, aug_param in aug_dict.items(): 47 | augmentations.append(key2aug[aug_key](aug_param)) 48 | logger.info("Using {} aug with params {}".format(aug_key, aug_param)) 49 | return Compose(augmentations) 50 | -------------------------------------------------------------------------------- /ptsemseg/augmentations/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/augmentations/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/augmentations/__pycache__/augmentations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/augmentations/__pycache__/augmentations.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/augmentations/__pycache__/diyaugmentation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/augmentations/__pycache__/diyaugmentation.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/augmentations/augmentations.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import numpy as np 5 | import torchvision.transforms.functional as tf 6 | 7 | from PIL import Image, ImageOps 8 | 9 | 10 | class Compose(object): 11 | def __init__(self, augmentations): 12 | self.augmentations = augmentations 13 | self.PIL2Numpy = False 14 | 15 | def __call__(self, img, mask): 16 | if isinstance(img, np.ndarray): 17 | img = Image.fromarray(img, mode="RGB") 18 | mask = Image.fromarray(mask, mode="L") 19 | self.PIL2Numpy = True 20 | 21 | assert img.size == mask.size 22 | for a in self.augmentations: 23 | img, mask = a(img, mask) 24 | 25 | if self.PIL2Numpy: 26 | img, mask = np.array(img), np.array(mask, dtype=np.uint8) 27 | 28 | return img, mask 29 | 30 | 31 | class RandomCrop(object): 32 | def __init__(self, size, padding=0): 33 | if isinstance(size, numbers.Number): 34 | self.size = (int(size), int(size)) 35 | else: 36 | self.size = size 37 | self.padding = padding 38 | 39 | def __call__(self, img, mask): 40 | if 
self.padding > 0: 41 | img = ImageOps.expand(img, border=self.padding, fill=0) 42 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 43 | 44 | assert img.size == mask.size 45 | w, h = img.size 46 | th, tw = self.size 47 | if w == tw and h == th: 48 | return img, mask 49 | if w < tw or h < th: 50 | return (img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST)) 51 | 52 | x1 = random.randint(0, w - tw) 53 | y1 = random.randint(0, h - th) 54 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))) 55 | 56 | 57 | class AdjustGamma(object): 58 | def __init__(self, gamma): 59 | self.gamma = gamma 60 | 61 | def __call__(self, img, mask): 62 | assert img.size == mask.size 63 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask 64 | 65 | 66 | class AdjustSaturation(object): 67 | def __init__(self, saturation): 68 | self.saturation = saturation 69 | 70 | def __call__(self, img, mask): 71 | assert img.size == mask.size 72 | return ( 73 | tf.adjust_saturation(img, random.uniform(1 - self.saturation, 1 + self.saturation)), 74 | mask, 75 | ) 76 | 77 | 78 | class AdjustHue(object): 79 | def __init__(self, hue): 80 | self.hue = hue 81 | 82 | def __call__(self, img, mask): 83 | assert img.size == mask.size 84 | return tf.adjust_hue(img, random.uniform(-self.hue, self.hue)), mask 85 | 86 | 87 | class AdjustBrightness(object): 88 | def __init__(self, bf): 89 | self.bf = bf 90 | 91 | def __call__(self, img, mask): 92 | assert img.size == mask.size 93 | return tf.adjust_brightness(img, random.uniform(1 - self.bf, 1 + self.bf)), mask 94 | 95 | 96 | class AdjustContrast(object): 97 | def __init__(self, cf): 98 | self.cf = cf 99 | 100 | def __call__(self, img, mask): 101 | assert img.size == mask.size 102 | return tf.adjust_contrast(img, random.uniform(1 - self.cf, 1 + self.cf)), mask 103 | 104 | 105 | class CenterCrop(object): 106 | def __init__(self, size): 107 | if isinstance(size, numbers.Number): 108 | self.size = (int(size), int(size)) 109 | else: 110 | self.size = size 111 | 112 | def __call__(self, img, mask): 113 | assert img.size == mask.size 114 | w, h = img.size 115 | th, tw = self.size 116 | x1 = int(round((w - tw) / 2.0)) 117 | y1 = int(round((h - th) / 2.0)) 118 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))) 119 | 120 | 121 | class RandomHorizontallyFlip(object): 122 | def __init__(self, p): 123 | self.p = p 124 | 125 | def __call__(self, img, mask): 126 | if random.random() < self.p: 127 | return (img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)) 128 | return img, mask 129 | 130 | 131 | class RandomVerticallyFlip(object): 132 | def __init__(self, p): 133 | self.p = p 134 | 135 | def __call__(self, img, mask): 136 | if random.random() < self.p: 137 | return (img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM)) 138 | return img, mask 139 | 140 | 141 | class RandomRotate(object): 142 | def __init__(self, degree): 143 | self.degree = degree 144 | 145 | def __call__(self, img, mask): 146 | rotate_degree = random.random() * 2 * self.degree - self.degree 147 | return ( 148 | tf.affine( 149 | img, 150 | translate=(0, 0), 151 | scale=1.0, 152 | angle=rotate_degree, 153 | resample=Image.BILINEAR, 154 | fillcolor=(0, 0, 0), 155 | shear=0.0, 156 | ), 157 | tf.affine( 158 | mask, 159 | translate=(0, 0), 160 | scale=1.0, 161 | angle=rotate_degree, 162 | resample=Image.NEAREST, 163 | fillcolor=250, 164 | shear=0.0, 165 | ), 166 | ) 167 | 168 | 169 
| class RandomTranslate(object): 170 | def __init__(self, offset): 171 | # tuple (delta_x, delta_y) 172 | self.offset = offset 173 | 174 | def __call__(self, img, mask): 175 | assert img.size == mask.size 176 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0]) 177 | y_offset = int(2 * (random.random() - 0.5) * self.offset[1]) 178 | 179 | x_crop_offset = x_offset 180 | y_crop_offset = y_offset 181 | if x_offset < 0: 182 | x_crop_offset = 0 183 | if y_offset < 0: 184 | y_crop_offset = 0 185 | 186 | cropped_img = tf.crop( 187 | img, 188 | y_crop_offset, 189 | x_crop_offset, 190 | img.size[1] - abs(y_offset), 191 | img.size[0] - abs(x_offset), 192 | ) 193 | 194 | if x_offset >= 0 and y_offset >= 0: 195 | padding_tuple = (0, 0, x_offset, y_offset) 196 | 197 | elif x_offset >= 0 and y_offset < 0: 198 | padding_tuple = (0, abs(y_offset), x_offset, 0) 199 | 200 | elif x_offset < 0 and y_offset >= 0: 201 | padding_tuple = (abs(x_offset), 0, 0, y_offset) 202 | 203 | elif x_offset < 0 and y_offset < 0: 204 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0) 205 | 206 | return ( 207 | tf.pad(cropped_img, padding_tuple, padding_mode="reflect"), 208 | tf.affine( 209 | mask, 210 | translate=(-x_offset, -y_offset), 211 | scale=1.0, 212 | angle=0.0, 213 | shear=0.0, 214 | fillcolor=250, 215 | ), 216 | ) 217 | 218 | 219 | class FreeScale(object): 220 | def __init__(self, size): 221 | self.size = tuple(reversed(size)) # size: (h, w) 222 | 223 | def __call__(self, img, mask): 224 | assert img.size == mask.size 225 | return (img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST)) 226 | 227 | 228 | class Scale(object): 229 | def __init__(self, size): 230 | self.size = size 231 | 232 | def __call__(self, img, mask): 233 | assert img.size == mask.size 234 | w, h = img.size 235 | if (w >= h and w == self.size) or (h >= w and h == self.size): 236 | return img, mask 237 | if w > h: 238 | ow = self.size 239 | oh = int(self.size * h / w) 240 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)) 241 | else: 242 | oh = self.size 243 | ow = int(self.size * w / h) 244 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)) 245 | 246 | 247 | class RandomSizedCrop(object): 248 | def __init__(self, size): 249 | self.size = size 250 | 251 | def __call__(self, img, mask): 252 | assert img.size == mask.size 253 | for attempt in range(10): 254 | area = img.size[0] * img.size[1] 255 | target_area = random.uniform(0.45, 1.0) * area 256 | aspect_ratio = random.uniform(0.5, 2) 257 | 258 | w = int(round(math.sqrt(target_area * aspect_ratio))) 259 | h = int(round(math.sqrt(target_area / aspect_ratio))) 260 | 261 | if random.random() < 0.5: 262 | w, h = h, w 263 | 264 | if w <= img.size[0] and h <= img.size[1]: 265 | x1 = random.randint(0, img.size[0] - w) 266 | y1 = random.randint(0, img.size[1] - h) 267 | 268 | img = img.crop((x1, y1, x1 + w, y1 + h)) 269 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 270 | assert img.size == (w, h) 271 | 272 | return ( 273 | img.resize((self.size, self.size), Image.BILINEAR), 274 | mask.resize((self.size, self.size), Image.NEAREST), 275 | ) 276 | 277 | # Fallback 278 | scale = Scale(self.size) 279 | crop = CenterCrop(self.size) 280 | return crop(*scale(img, mask)) 281 | 282 | 283 | class RandomSized(object): 284 | def __init__(self, size): 285 | self.size = size 286 | self.scale = Scale(self.size) 287 | self.crop = RandomCrop(self.size) 288 | 289 | def __call__(self, img, mask): 290 | assert img.size == 
mask.size 291 | 292 | w = int(random.uniform(0.5, 2) * img.size[0]) 293 | h = int(random.uniform(0.5, 2) * img.size[1]) 294 | 295 | img, mask = (img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)) 296 | 297 | return self.crop(*self.scale(img, mask)) 298 | -------------------------------------------------------------------------------- /ptsemseg/augmentations/diyaugmentation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | custom definition 3 | ''' 4 | import random 5 | import numpy as np 6 | from skimage.transform import rotate 7 | from skimage.exposure import rescale_intensity 8 | 9 | # from scipy.misc import imrotate 10 | # import cv2 11 | 12 | def my_segmentation_transforms(image, segmentation): 13 | # rotate 14 | if random.random() > 0.5: 15 | angle = (np.random.randint(11) + 1) * 15 16 | # angles=[30, 60, 120, 150, 45, 135] old ones 17 | # angle = random.choice(angles) 18 | image = rotate(image, angle) # 1: Bi-linear (default) 19 | segmentation = rotate(segmentation, angle, order=0) # Nearest-neighbor 20 | # segmentation = imrotate(segmentation, angle, interp='nearest') # old ones 21 | 22 | #flip left-right 23 | if random.random() > 0.5: 24 | image = np.fliplr(image) 25 | segmentation = np.fliplr(segmentation) 26 | 27 | #flip up-down 28 | if random.random() > 0.5: 29 | image = np.flipud(image) 30 | segmentation = np.flipud(segmentation) 31 | 32 | # brightness 33 | ratio=random.random() 34 | if ratio>0.5: 35 | image = rescale_intensity(image, out_range=(0, ratio)) #(0.5, 1) 36 | 37 | return image, segmentation 38 | 39 | 40 | def my_segmentation_transforms_crop(image, segmentation, th): 41 | # random crop 42 | h=image.shape[0] 43 | offset=h-th 44 | x1 = random.randint(0, offset) 45 | y1 = random.randint(0, offset) 46 | image = image[x1:x1+th, y1:y1+th,:] 47 | segmentation = segmentation[x1:x1+th, y1:y1+th] 48 | 49 | # rotate 50 | if random.random() > 0.5: 51 | angles=[30, 60, 120, 150, 45, 135] 52 | angle = random.choice(angles) 53 | image = rotate(image, angle) 54 | segmentation = rotate(segmentation, angle, order=0) # Nearest-neighbor 55 | # segmentation = imrotate(segmentation, angle, interp='nearest') # old ones 56 | 57 | #flip left-right 58 | if random.random() > 0.5: 59 | image = np.fliplr(image) 60 | segmentation = np.fliplr(segmentation) 61 | 62 | #flip up-down 63 | if random.random() > 0.5: 64 | image = np.flipud(image) 65 | segmentation = np.flipud(segmentation) 66 | 67 | return image, segmentation -------------------------------------------------------------------------------- /ptsemseg/loader/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from ptsemseg.loader.pascal_voc_loader import pascalVOCLoader 4 | from ptsemseg.loader.camvid_loader import camvidLoader 5 | from ptsemseg.loader.ade20k_loader import ADE20KLoader 6 | from ptsemseg.loader.mit_sceneparsing_benchmark_loader import MITSceneParsingBenchmarkLoader 7 | from ptsemseg.loader.cityscapes_loader import cityscapesLoader 8 | from ptsemseg.loader.nyuv2_loader import NYUv2Loader 9 | from ptsemseg.loader.sunrgbd_loader import SUNRGBDLoader 10 | from ptsemseg.loader.mapillary_vistas_loader import mapillaryVistasLoader 11 | 12 | 13 | def get_loader(name): 14 | """get_loader 15 | 16 | :param name: 17 | """ 18 | return { 19 | "pascal": pascalVOCLoader, 20 | "camvid": camvidLoader, 21 | "ade20k": ADE20KLoader, 22 | "mit_sceneparsing_benchmark": MITSceneParsingBenchmarkLoader, 23 | 
"cityscapes": cityscapesLoader, 24 | "nyuv2": NYUv2Loader, 25 | "sunrgbd": SUNRGBDLoader, 26 | "vistas": mapillaryVistasLoader, 27 | }[name] 28 | -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/ade20k_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/ade20k_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/camvid_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/camvid_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/cityscapes_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/cityscapes_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/diy_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/diy_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/diyloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/diyloader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/mapillary_vistas_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/mapillary_vistas_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/mit_sceneparsing_benchmark_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/mit_sceneparsing_benchmark_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/nyuv2_loader.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/nyuv2_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/pascal_voc_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/pascal_voc_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/__pycache__/sunrgbd_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loader/__pycache__/sunrgbd_loader.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loader/ade20k_loader.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import torchvision 4 | import numpy as np 5 | import scipy.misc as m 6 | import matplotlib.pyplot as plt 7 | 8 | from torch.utils import data 9 | 10 | from ptsemseg.utils import recursive_glob 11 | 12 | 13 | class ADE20KLoader(data.Dataset): 14 | def __init__( 15 | self, 16 | root, 17 | split="training", 18 | is_transform=False, 19 | img_size=512, 20 | augmentations=None, 21 | img_norm=True, 22 | test_mode=False, 23 | ): 24 | self.root = root 25 | self.split = split 26 | self.is_transform = is_transform 27 | self.augmentations = augmentations 28 | self.img_norm = img_norm 29 | self.test_mode = test_mode 30 | self.n_classes = 150 31 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 32 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 33 | self.files = collections.defaultdict(list) 34 | 35 | if not self.test_mode: 36 | for split in ["training", "validation"]: 37 | file_list = recursive_glob( 38 | rootdir=self.root + "images/" + self.split + "/", suffix=".jpg" 39 | ) 40 | self.files[split] = file_list 41 | 42 | def __len__(self): 43 | return len(self.files[self.split]) 44 | 45 | def __getitem__(self, index): 46 | img_path = self.files[self.split][index].rstrip() 47 | lbl_path = img_path[:-4] + "_seg.png" 48 | 49 | img = m.imread(img_path) 50 | img = np.array(img, dtype=np.uint8) 51 | 52 | lbl = m.imread(lbl_path) 53 | lbl = np.array(lbl, dtype=np.int32) 54 | 55 | if self.augmentations is not None: 56 | img, lbl = self.augmentations(img, lbl) 57 | 58 | if self.is_transform: 59 | img, lbl = self.transform(img, lbl) 60 | 61 | return img, lbl 62 | 63 | def transform(self, img, lbl): 64 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 65 | img = img[:, :, ::-1] # RGB -> BGR 66 | img = img.astype(np.float64) 67 | img -= self.mean 68 | if self.img_norm: 69 | # Resize scales images from 0 to 255, thus we need 70 | # to divide by 255.0 71 | img = img.astype(float) / 255.0 72 | # NHWC -> NCHW 73 | img = img.transpose(2, 0, 1) 74 | 75 | lbl = self.encode_segmap(lbl) 76 | classes = np.unique(lbl) 77 | lbl = lbl.astype(float) 78 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 79 | lbl = lbl.astype(int) 80 | assert np.all(classes == np.unique(lbl)) 81 | 82 | img = torch.from_numpy(img).float() 83 | lbl = torch.from_numpy(lbl).long() 84 | return 
img, lbl 85 | 86 | def encode_segmap(self, mask): 87 | # Refer : http://groups.csail.mit.edu/vision/datasets/ADE20K/code/loadAde20K.m 88 | mask = mask.astype(int) 89 | label_mask = np.zeros((mask.shape[0], mask.shape[1])) 90 | label_mask = (mask[:, :, 0] / 10.0) * 256 + mask[:, :, 1] 91 | return np.array(label_mask, dtype=np.uint8) 92 | 93 | def decode_segmap(self, temp, plot=False): 94 | # TODO:(@meetshah1995) 95 | # Verify that the color mapping is 1-to-1 96 | r = temp.copy() 97 | g = temp.copy() 98 | b = temp.copy() 99 | for l in range(0, self.n_classes): 100 | r[temp == l] = 10 * (l % 10) 101 | g[temp == l] = l 102 | b[temp == l] = 0 103 | 104 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 105 | rgb[:, :, 0] = r / 255.0 106 | rgb[:, :, 1] = g / 255.0 107 | rgb[:, :, 2] = b / 255.0 108 | if plot: 109 | plt.imshow(rgb) 110 | plt.show() 111 | else: 112 | return rgb 113 | 114 | 115 | if __name__ == "__main__": 116 | local_path = "/Users/meet/data/ADE20K_2016_07_26/" 117 | dst = ADE20KLoader(local_path, is_transform=True) 118 | trainloader = data.DataLoader(dst, batch_size=4) 119 | for i, data_samples in enumerate(trainloader): 120 | imgs, labels = data_samples 121 | if i == 0: 122 | img = torchvision.utils.make_grid(imgs).numpy() 123 | img = np.transpose(img, (1, 2, 0)) 124 | img = img[:, :, ::-1] 125 | plt.imshow(img) 126 | plt.show() 127 | for j in range(4): 128 | plt.imshow(dst.decode_segmap(labels.numpy()[j])) 129 | plt.show() 130 | -------------------------------------------------------------------------------- /ptsemseg/loader/camvid_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import torch 4 | import numpy as np 5 | import scipy.misc as m 6 | import matplotlib.pyplot as plt 7 | 8 | from torch.utils import data 9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate 10 | 11 | 12 | class camvidLoader(data.Dataset): 13 | def __init__( 14 | self, 15 | root, 16 | split="train", 17 | is_transform=False, 18 | img_size=None, 19 | augmentations=None, 20 | img_norm=True, 21 | test_mode=False, 22 | ): 23 | self.root = root 24 | self.split = split 25 | self.img_size = [360, 480] 26 | self.is_transform = is_transform 27 | self.augmentations = augmentations 28 | self.img_norm = img_norm 29 | self.test_mode = test_mode 30 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 31 | self.n_classes = 12 32 | self.files = collections.defaultdict(list) 33 | 34 | if not self.test_mode: 35 | for split in ["train", "test", "val"]: 36 | file_list = os.listdir(root + "/" + split) 37 | self.files[split] = file_list 38 | 39 | def __len__(self): 40 | return len(self.files[self.split]) 41 | 42 | def __getitem__(self, index): 43 | img_name = self.files[self.split][index] 44 | img_path = self.root + "/" + self.split + "/" + img_name 45 | lbl_path = self.root + "/" + self.split + "annot/" + img_name 46 | 47 | img = m.imread(img_path) 48 | img = np.array(img, dtype=np.uint8) 49 | 50 | lbl = m.imread(lbl_path) 51 | lbl = np.array(lbl, dtype=np.int8) 52 | 53 | if self.augmentations is not None: 54 | img, lbl = self.augmentations(img, lbl) 55 | 56 | if self.is_transform: 57 | img, lbl = self.transform(img, lbl) 58 | 59 | return img, lbl 60 | 61 | def transform(self, img, lbl): 62 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 63 | img = img[:, :, ::-1] # RGB -> BGR 64 | img = img.astype(np.float64) 65 | img -= self.mean 66 | if self.img_norm: 67 | # Resize 
scales images from 0 to 255, thus we need 68 | # to divide by 255.0 69 | img = img.astype(float) / 255.0 70 | # NHWC -> NCHW 71 | img = img.transpose(2, 0, 1) 72 | 73 | img = torch.from_numpy(img).float() 74 | lbl = torch.from_numpy(lbl).long() 75 | return img, lbl 76 | 77 | def decode_segmap(self, temp, plot=False): 78 | Sky = [128, 128, 128] 79 | Building = [128, 0, 0] 80 | Pole = [192, 192, 128] 81 | Road = [128, 64, 128] 82 | Pavement = [60, 40, 222] 83 | Tree = [128, 128, 0] 84 | SignSymbol = [192, 128, 128] 85 | Fence = [64, 64, 128] 86 | Car = [64, 0, 128] 87 | Pedestrian = [64, 64, 0] 88 | Bicyclist = [0, 128, 192] 89 | Unlabelled = [0, 0, 0] 90 | 91 | label_colours = np.array( 92 | [ 93 | Sky, 94 | Building, 95 | Pole, 96 | Road, 97 | Pavement, 98 | Tree, 99 | SignSymbol, 100 | Fence, 101 | Car, 102 | Pedestrian, 103 | Bicyclist, 104 | Unlabelled, 105 | ] 106 | ) 107 | r = temp.copy() 108 | g = temp.copy() 109 | b = temp.copy() 110 | for l in range(0, self.n_classes): 111 | r[temp == l] = label_colours[l, 0] 112 | g[temp == l] = label_colours[l, 1] 113 | b[temp == l] = label_colours[l, 2] 114 | 115 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 116 | rgb[:, :, 0] = r / 255.0 117 | rgb[:, :, 1] = g / 255.0 118 | rgb[:, :, 2] = b / 255.0 119 | return rgb 120 | 121 | 122 | if __name__ == "__main__": 123 | local_path = "/home/meetshah1995/datasets/segnet/CamVid" 124 | augmentations = Compose([RandomRotate(10), RandomHorizontallyFlip()]) 125 | 126 | dst = camvidLoader(local_path, is_transform=True, augmentations=augmentations) 127 | bs = 4 128 | trainloader = data.DataLoader(dst, batch_size=bs) 129 | for i, data_samples in enumerate(trainloader): 130 | imgs, labels = data_samples 131 | imgs = imgs.numpy()[:, ::-1, :, :] 132 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 133 | f, axarr = plt.subplots(bs, 2) 134 | for j in range(bs): 135 | axarr[j][0].imshow(imgs[j]) 136 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 137 | plt.show() 138 | a = input() 139 | if a == "ex": 140 | break 141 | else: 142 | plt.close() 143 | -------------------------------------------------------------------------------- /ptsemseg/loader/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | 6 | from torch.utils import data 7 | 8 | from ptsemseg.utils import recursive_glob 9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale 10 | 11 | 12 | class cityscapesLoader(data.Dataset): 13 | """cityscapesLoader 14 | 15 | https://www.cityscapes-dataset.com 16 | 17 | Data is derived from CityScapes, and can be downloaded from here: 18 | https://www.cityscapes-dataset.com/downloads/ 19 | 20 | Many Thanks to @fvisin for the loader repo: 21 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py 22 | """ 23 | 24 | colors = [ # [ 0, 0, 0], 25 | [128, 64, 128], 26 | [244, 35, 232], 27 | [70, 70, 70], 28 | [102, 102, 156], 29 | [190, 153, 153], 30 | [153, 153, 153], 31 | [250, 170, 30], 32 | [220, 220, 0], 33 | [107, 142, 35], 34 | [152, 251, 152], 35 | [0, 130, 180], 36 | [220, 20, 60], 37 | [255, 0, 0], 38 | [0, 0, 142], 39 | [0, 0, 70], 40 | [0, 60, 100], 41 | [0, 80, 100], 42 | [0, 0, 230], 43 | [119, 11, 32], 44 | ] 45 | 46 | label_colours = dict(zip(range(19), colors)) 47 | 48 | mean_rgb = { 49 | "pascal": [103.939, 116.779, 123.68], 50 | "cityscapes": [0.0, 0.0, 0.0], 51 | } # pascal mean for PSPNet and 
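Several of the loaders here (ADE20K, CamVid, Cityscapes, NYUv2, MIT SceneParsing, SUN-RGBD) repeat the same preprocessing recipe inside `transform()`: resize, convert RGB to BGR, subtract a per-channel mean expressed in 0–255 units, optionally divide by 255, then reorder HWC to CHW before wrapping in tensors. The sketch below merely restates that recipe for one image (resizing omitted); the mean values are the ones hard-coded in those loaders:

```python
# Sketch only: the per-image normalization shared by several loaders above/below.
import numpy as np
import torch

MEAN_BGR = np.array([104.00699, 116.66877, 122.67892])   # per-channel mean, 0-255 units

def preprocess(img_uint8_rgb, img_norm=True):
    img = img_uint8_rgb[:, :, ::-1].astype(np.float64)    # RGB -> BGR
    img -= MEAN_BGR                                       # mean is subtracted first...
    if img_norm:
        img = img / 255.0                                 # ...then the result is rescaled
    img = img.transpose(2, 0, 1)                          # HWC -> CHW
    return torch.from_numpy(img).float()
```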
ICNet pre-trained model 52 | 53 | def __init__( 54 | self, 55 | root, 56 | split="train", 57 | is_transform=False, 58 | img_size=(512, 1024), 59 | augmentations=None, 60 | img_norm=True, 61 | version="cityscapes", 62 | test_mode=False, 63 | ): 64 | """__init__ 65 | 66 | :param root: 67 | :param split: 68 | :param is_transform: 69 | :param img_size: 70 | :param augmentations 71 | """ 72 | self.root = root 73 | self.split = split 74 | self.is_transform = is_transform 75 | self.augmentations = augmentations 76 | self.img_norm = img_norm 77 | self.n_classes = 19 78 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 79 | self.mean = np.array(self.mean_rgb[version]) 80 | self.files = {} 81 | 82 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split) 83 | self.annotations_base = os.path.join(self.root, "gtFine", self.split) 84 | 85 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png") 86 | 87 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 88 | self.valid_classes = [ 89 | 7, 90 | 8, 91 | 11, 92 | 12, 93 | 13, 94 | 17, 95 | 19, 96 | 20, 97 | 21, 98 | 22, 99 | 23, 100 | 24, 101 | 25, 102 | 26, 103 | 27, 104 | 28, 105 | 31, 106 | 32, 107 | 33, 108 | ] 109 | self.class_names = [ 110 | "unlabelled", 111 | "road", 112 | "sidewalk", 113 | "building", 114 | "wall", 115 | "fence", 116 | "pole", 117 | "traffic_light", 118 | "traffic_sign", 119 | "vegetation", 120 | "terrain", 121 | "sky", 122 | "person", 123 | "rider", 124 | "car", 125 | "truck", 126 | "bus", 127 | "train", 128 | "motorcycle", 129 | "bicycle", 130 | ] 131 | 132 | self.ignore_index = 250 133 | self.class_map = dict(zip(self.valid_classes, range(19))) 134 | 135 | if not self.files[split]: 136 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 137 | 138 | print("Found %d %s images" % (len(self.files[split]), split)) 139 | 140 | def __len__(self): 141 | """__len__""" 142 | return len(self.files[self.split]) 143 | 144 | def __getitem__(self, index): 145 | """__getitem__ 146 | 147 | :param index: 148 | """ 149 | img_path = self.files[self.split][index].rstrip() 150 | lbl_path = os.path.join( 151 | self.annotations_base, 152 | img_path.split(os.sep)[-2], 153 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png", 154 | ) 155 | 156 | img = m.imread(img_path) 157 | img = np.array(img, dtype=np.uint8) 158 | 159 | lbl = m.imread(lbl_path) 160 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8)) 161 | 162 | if self.augmentations is not None: 163 | img, lbl = self.augmentations(img, lbl) 164 | 165 | if self.is_transform: 166 | img, lbl = self.transform(img, lbl) 167 | 168 | return img, lbl 169 | 170 | def transform(self, img, lbl): 171 | """transform 172 | 173 | :param img: 174 | :param lbl: 175 | """ 176 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 177 | img = img[:, :, ::-1] # RGB -> BGR 178 | img = img.astype(np.float64) 179 | img -= self.mean 180 | if self.img_norm: 181 | # Resize scales images from 0 to 255, thus we need 182 | # to divide by 255.0 183 | img = img.astype(float) / 255.0 184 | # NHWC -> NCHW 185 | img = img.transpose(2, 0, 1) 186 | 187 | classes = np.unique(lbl) 188 | lbl = lbl.astype(float) 189 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 190 | lbl = lbl.astype(int) 191 | 192 | if not np.all(classes == np.unique(lbl)): 193 | print("WARN: resizing labels yielded fewer classes") 194 | 195 | if not 
np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): 196 | print("after det", classes, np.unique(lbl)) 197 | raise ValueError("Segmentation map contained invalid class values") 198 | 199 | img = torch.from_numpy(img).float() 200 | lbl = torch.from_numpy(lbl).long() 201 | 202 | return img, lbl 203 | 204 | def decode_segmap(self, temp): 205 | r = temp.copy() 206 | g = temp.copy() 207 | b = temp.copy() 208 | for l in range(0, self.n_classes): 209 | r[temp == l] = self.label_colours[l][0] 210 | g[temp == l] = self.label_colours[l][1] 211 | b[temp == l] = self.label_colours[l][2] 212 | 213 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 214 | rgb[:, :, 0] = r / 255.0 215 | rgb[:, :, 1] = g / 255.0 216 | rgb[:, :, 2] = b / 255.0 217 | return rgb 218 | 219 | def encode_segmap(self, mask): 220 | # Put all void classes to zero 221 | for _voidc in self.void_classes: 222 | mask[mask == _voidc] = self.ignore_index 223 | for _validc in self.valid_classes: 224 | mask[mask == _validc] = self.class_map[_validc] 225 | return mask 226 | 227 | 228 | if __name__ == "__main__": 229 | import matplotlib.pyplot as plt 230 | 231 | augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)]) 232 | 233 | local_path = "/datasets01/cityscapes/112817/" 234 | dst = cityscapesLoader(local_path, is_transform=True, augmentations=augmentations) 235 | bs = 4 236 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 237 | for i, data_samples in enumerate(trainloader): 238 | imgs, labels = data_samples 239 | import pdb 240 | 241 | pdb.set_trace() 242 | imgs = imgs.numpy()[:, ::-1, :, :] 243 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 244 | f, axarr = plt.subplots(bs, 2) 245 | for j in range(bs): 246 | axarr[j][0].imshow(imgs[j]) 247 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 248 | plt.show() 249 | a = input() 250 | if a == "ex": 251 | break 252 | else: 253 | plt.close() 254 | -------------------------------------------------------------------------------- /ptsemseg/loader/diyloader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | new files yinxcao 3 | used for ningbo high-resolution images 4 | format: png 5 | April 25, 2020 6 | ''' 7 | 8 | import torch.utils.data as data 9 | from PIL import Image, ImageOps 10 | import numpy as np 11 | import torch 12 | import tifffile as tif 13 | from ptsemseg.augmentations.diyaugmentation import my_segmentation_transforms, my_segmentation_transforms_crop 14 | import random 15 | 16 | IMG_EXTENSIONS = [ 17 | '.jpg', '.JPG', '.jpeg', '.JPEG', 18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 19 | '.tif' #new added 20 | ] 21 | 22 | 23 | def is_image_file(filename): 24 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 25 | 26 | 27 | def default_loader(path): 28 | return np.asarray(Image.open(path)) 29 | 30 | 31 | def stretch_img(image, nrange): 32 | #according the range [low,high] to rescale image 33 | h, w, nbands = np.shape(image) 34 | image_stretch = np.zeros(shape=(h, w, nbands), dtype=np.float32) 35 | for i in range(nbands): 36 | image_stretch[:, :, i] = 1.0*(image[:, :, i]-nrange[1, i])/(nrange[0, i]-nrange[1, i]) 37 | return image_stretch 38 | 39 | 40 | def gray2rgb(image): 41 | res=np.zeros((image.shape[0], image.shape[1], 3)) 42 | res[ :, :, 0] = image.copy() 43 | res[ :, :, 1] = image.copy() 44 | res[ :, :, 2] = image.copy() 45 | return res 46 | 47 | 48 | def readtif(name): 49 | # is gray 50 | # should be fused with spectral bands 51 | 
img=tif.imread(name) 52 | return img 53 | 54 | 55 | class myImageFloderold(data.Dataset): 56 | def __init__(self, imgpath, labpath): # data loader #params nrange 57 | self.imgpath = imgpath 58 | self.labpath = labpath 59 | 60 | def __getitem__(self, index): 61 | imgpath_ = self.imgpath[index] 62 | labpath_ = self.labpath[index] 63 | img = default_loader(imgpath_) 64 | lab = default_loader(labpath_) # 0, 1, ..., N_CLASS-1 65 | img = img[:, :, ::-1] / 255 # RGB => BGR 66 | img = torch.tensor(img, dtype=torch.float).permute(2, 0, 1) # H W C => C H W 67 | lab = torch.tensor(lab, dtype=torch.long)-1 68 | return img, lab # new added 69 | 70 | def __len__(self): 71 | return len(self.imgpath) 72 | 73 | # img, tlc, lab 2020.7.27 74 | # update: 2020.9.11: lab, the unit has been changed to meters (float) rather than floor number. 75 | class myImageFloder(data.Dataset): 76 | def __init__(self, imgpath, labpath, augmentations=False, num= 0): # data loader #params nrange 77 | self.imgpath = imgpath 78 | self.labpath = labpath 79 | if num>0: 80 | self.imgpath = imgpath[:num] 81 | self.labpath = labpath[:num] 82 | self.augmentations = augmentations 83 | 84 | def __getitem__(self, index): 85 | muxpath_ = self.imgpath[index, 0] 86 | tlcpath_ = self.imgpath[index, 1] 87 | labpath_ = self.labpath[index] 88 | mux = tif.imread(muxpath_) / 10000 # convert to surface reflectance (SR): 0-1 89 | # tlc = tif.imread(tlcpath_)/950 # stretch to 0-1 90 | tlc = tif.imread(tlcpath_) / 10000 # convert to 0-1 91 | img = np.concatenate((mux, tlc), axis=2) # the third dimension 92 | img[img>1]=1 # ensure data range is 0-1 93 | lab = tif.imread(labpath_) # building floor * 3 (meters) in float format 94 | 95 | if self.augmentations: 96 | img, lab = my_segmentation_transforms(img, lab) 97 | 98 | img = img.transpose((2, 0, 1)) # H W C => C H W 99 | # lab = lab.astype(np.int16) * 3 : storing the number of floor, deprecated 100 | lab = np.expand_dims(lab, axis=0) 101 | 102 | img = torch.tensor(img.copy(), dtype=torch.float) 103 | lab = torch.tensor(lab.copy(), dtype=torch.float) 104 | return img, lab 105 | 106 | def __len__(self): 107 | return len(self.imgpath) 108 | 109 | 110 | # only load tlc (3bands) 111 | class myImageFloder_tlc(data.Dataset): 112 | def __init__(self, imgpath, labpath, augmentations=False, num=0): # data loader #params nrange 113 | self.imgpath = imgpath 114 | self.labpath = labpath 115 | if num>0: 116 | self.imgpath = imgpath[:num] 117 | self.labpath = labpath[:num] 118 | self.augmentations = augmentations 119 | 120 | def __getitem__(self, index): 121 | tlcpath_ = self.imgpath[index, 1] 122 | labpath_ = self.labpath[index] 123 | img = tif.imread(tlcpath_)/10000 # stretch to 0-1 124 | img[img>1] = 1 # ensure data range is 0-1 125 | lab = tif.imread(labpath_) # building floor * 3 (meters) in float format 126 | 127 | if self.augmentations: 128 | img, lab = my_segmentation_transforms(img, lab) 129 | 130 | img = img.transpose((2, 0, 1)) # H W C => C H W 131 | lab = np.expand_dims(lab, axis=0) 132 | 133 | img = torch.tensor(img.copy(), dtype=torch.float) 134 | lab = torch.tensor(lab.copy(), dtype=torch.float) 135 | return img, lab 136 | 137 | def __len__(self): 138 | return len(self.imgpath) 139 | 140 | 141 | # only load mux (4 bands) and lab 142 | class myImageFloder_mux(data.Dataset): 143 | def __init__(self, imgpath, labpath, augmentations=False, num=0): # data loader #params nrange 144 | self.imgpath = imgpath 145 | self.labpath = labpath 146 | if num>0: 147 | self.imgpath = imgpath[:num] 148 | self.labpath = 
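`myImageFloder` loads the ZY-3 building-height samples: `imgpath` is indexed as a two-column array (column 0 the multispectral MUX chip, column 1 the three-line-camera TLC chip), both are divided by 10000 to bring the digital numbers into a 0–1 reflectance-like range, and the label is a float height map in metres. A hedged construction sketch follows (the path arrays are placeholders, not the project's actual file lists):

```python
# Sketch only: building the ZY-3 dataset and wrapping it in a DataLoader.
# Paths are placeholders; band counts assume 4-band MUX and 3-band TLC chips.
import numpy as np
from torch.utils.data import DataLoader
from ptsemseg.loader.diyloader import myImageFloder

imgpath = np.array([
    ["path/to/img_0001.tif", "path/to/tlc_0001.tif"],   # [mux, tlc] per sample
    ["path/to/img_0002.tif", "path/to/tlc_0002.tif"],
])
labpath = np.array(["path/to/lab_0001.tif", "path/to/lab_0002.tif"])

train_set = myImageFloder(imgpath, labpath, augmentations=True)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
# each item: img of shape (7, H, W) after concatenating MUX and TLC bands,
# lab of shape (1, H, W) holding building height in metres
```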
labpath[:num] 149 | self.augmentations = augmentations 150 | 151 | def __getitem__(self, index): 152 | muxpath_ = self.imgpath[index, 0] 153 | tlcpath_ = self.imgpath[index, 1] 154 | labpath_ = self.labpath[index] 155 | img = tif.imread(muxpath_) / 10000 # convert to surface reflectance (SR): 0-1 156 | # tlc = tif.imread(tlcpath_)/950 # stretch to 0-1 157 | # tlc = tif.imread(tlcpath_) / 10000 # convert to 0-1 158 | # img = np.concatenate((mux, tlc), axis=2) # the third dimension 159 | img[img>1]=1 # ensure data range is 0-1 160 | lab = tif.imread(labpath_) # building floor * 3 (meters) in float format 161 | 162 | if self.augmentations: 163 | img, lab = my_segmentation_transforms(img, lab) 164 | 165 | img = img.transpose((2, 0, 1)) # H W C => C H W 166 | # lab = lab.astype(np.int16) * 3 : storing the number of floor, deprecated 167 | lab = np.expand_dims(lab, axis=0) 168 | 169 | img = torch.tensor(img.copy(), dtype=torch.float) 170 | lab = torch.tensor(lab.copy(), dtype=torch.float) 171 | return img, lab 172 | 173 | def __len__(self): 174 | return len(self.imgpath) 175 | 176 | # img, tlc, lab 2020.8.3 177 | ''' 178 | class myImageFloder_tlc(data.Dataset): 179 | def __init__(self, imgpath, labpath, patchsize=256, augmentations=False): # data loader #params nrange 180 | self.imgpath = imgpath 181 | self.labpath = labpath 182 | self.patchsize = patchsize 183 | self.augmentations = augmentations 184 | 185 | def __getitem__(self, index): 186 | muxpath_ = self.imgpath[index, 0] 187 | tlcpath_ = self.imgpath[index, 1] 188 | labpath_ = self.labpath[index] 189 | mux = tif.imread(muxpath_) # convert to surface reflectance (SR): 0-1 190 | tlc = tif.imread(tlcpath_) # convert to 0-1 191 | lab = tif.imread(labpath_) # building floor * 3 (meters) 192 | # 1. clip 193 | # random crop: test and train is the same 194 | offset = mux.shape[0] - self.patchsize 195 | x1 = random.randint(0, offset) 196 | y1 = random.randint(0, offset) 197 | mux = mux[x1:x1 + self.patchsize, y1:y1 + self.patchsize, :]/ 10000 198 | tlc = tlc[x1:x1 + self.patchsize, y1:y1 + self.patchsize, :] / 10000 199 | lab = lab[x1:x1 + self.patchsize, y1:y1 + self.patchsize] 200 | 201 | # 2. 
normalize 202 | #img = np.concatenate((mux, gray2rgb(tlc[:,:,0]), gray2rgb(tlc[:,:,1]), gray2rgb(tlc[:,:,2])), axis=2) 203 | img = np.concatenate((mux, tlc), axis=2) 204 | #img[img>1]=1 # ensure data range is 0-1 205 | 206 | if self.augmentations: 207 | img, lab = my_segmentation_transforms(img, lab) 208 | 209 | img = img.transpose((2, 0, 1)) 210 | lab = lab.astype(np.int16) * 3 211 | 212 | img = torch.tensor(img.copy(), dtype=torch.float) #H W C => C H W 213 | lab = torch.tensor(lab.copy(), dtype=torch.float) 214 | return img, lab 215 | 216 | def __len__(self): 217 | return len(self.imgpath) 218 | ''' -------------------------------------------------------------------------------- /ptsemseg/loader/mapillary_vistas_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | 6 | from torch.utils import data 7 | from PIL import Image 8 | 9 | from ptsemseg.utils import recursive_glob 10 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate 11 | 12 | 13 | class mapillaryVistasLoader(data.Dataset): 14 | def __init__( 15 | self, 16 | root, 17 | split="training", 18 | img_size=(640, 1280), 19 | is_transform=True, 20 | augmentations=None, 21 | test_mode=False, 22 | ): 23 | self.root = root 24 | self.split = split 25 | self.is_transform = is_transform 26 | self.augmentations = augmentations 27 | self.n_classes = 65 28 | 29 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 30 | self.mean = np.array([80.5423, 91.3162, 81.4312]) 31 | self.files = {} 32 | 33 | self.images_base = os.path.join(self.root, self.split, "images") 34 | self.annotations_base = os.path.join(self.root, self.split, "labels") 35 | 36 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".jpg") 37 | 38 | self.class_ids, self.class_names, self.class_colors = self.parse_config() 39 | 40 | self.ignore_id = 250 41 | 42 | if not self.files[split]: 43 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 44 | 45 | print("Found %d %s images" % (len(self.files[split]), split)) 46 | 47 | def parse_config(self): 48 | with open(os.path.join(self.root, "config.json")) as config_file: 49 | config = json.load(config_file) 50 | 51 | labels = config["labels"] 52 | 53 | class_names = [] 54 | class_ids = [] 55 | class_colors = [] 56 | print("There are {} labels in the config file".format(len(labels))) 57 | for label_id, label in enumerate(labels): 58 | class_names.append(label["readable"]) 59 | class_ids.append(label_id) 60 | class_colors.append(label["color"]) 61 | 62 | return class_names, class_ids, class_colors 63 | 64 | def __len__(self): 65 | """__len__""" 66 | return len(self.files[self.split]) 67 | 68 | def __getitem__(self, index): 69 | """__getitem__ 70 | :param index: 71 | """ 72 | img_path = self.files[self.split][index].rstrip() 73 | lbl_path = os.path.join( 74 | self.annotations_base, os.path.basename(img_path).replace(".jpg", ".png") 75 | ) 76 | 77 | img = Image.open(img_path) 78 | lbl = Image.open(lbl_path) 79 | 80 | if self.augmentations is not None: 81 | img, lbl = self.augmentations(img, lbl) 82 | 83 | if self.is_transform: 84 | img, lbl = self.transform(img, lbl) 85 | 86 | return img, lbl 87 | 88 | def transform(self, img, lbl): 89 | if self.img_size == ("same", "same"): 90 | pass 91 | else: 92 | img = img.resize( 93 | (self.img_size[0], self.img_size[1]), resample=Image.LANCZOS 94 | ) # uint8 with RGB mode 95 | lbl = 
lbl.resize((self.img_size[0], self.img_size[1])) 96 | img = np.array(img).astype(np.float64) / 255.0 97 | img = torch.from_numpy(img.transpose(2, 0, 1)).float() # From HWC to CHW 98 | lbl = torch.from_numpy(np.array(lbl)).long() 99 | lbl[lbl == 65] = self.ignore_id 100 | return img, lbl 101 | 102 | def decode_segmap(self, temp): 103 | r = temp.copy() 104 | g = temp.copy() 105 | b = temp.copy() 106 | for l in range(0, self.n_classes): 107 | r[temp == l] = self.class_colors[l][0] 108 | g[temp == l] = self.class_colors[l][1] 109 | b[temp == l] = self.class_colors[l][2] 110 | 111 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 112 | rgb[:, :, 0] = r / 255.0 113 | rgb[:, :, 1] = g / 255.0 114 | rgb[:, :, 2] = b / 255.0 115 | return rgb 116 | 117 | 118 | if __name__ == "__main__": 119 | augment = Compose([RandomHorizontallyFlip(), RandomRotate(6)]) 120 | 121 | local_path = "/private/home/meetshah/datasets/seg/vistas/" 122 | dst = mapillaryVistasLoader( 123 | local_path, img_size=(512, 1024), is_transform=True, augmentations=augment 124 | ) 125 | bs = 8 126 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=4, shuffle=True) 127 | for i, data_samples in enumerate(trainloader): 128 | x = dst.decode_segmap(data_samples[1][0].numpy()) 129 | print("batch :", i) 130 | -------------------------------------------------------------------------------- /ptsemseg/loader/mit_sceneparsing_benchmark_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | 6 | from torch.utils import data 7 | 8 | from ptsemseg.utils import recursive_glob 9 | 10 | 11 | class MITSceneParsingBenchmarkLoader(data.Dataset): 12 | """MITSceneParsingBenchmarkLoader 13 | 14 | http://sceneparsing.csail.mit.edu/ 15 | 16 | Data is derived from ADE20k, and can be downloaded from here: 17 | http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip 18 | 19 | NOTE: this loader is not designed to work with the original ADE20k dataset; 20 | for that you will need the ADE20kLoader 21 | 22 | This class can also be extended to load data for places challenge: 23 | https://github.com/CSAILVision/placeschallenge/tree/master/sceneparsing 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | root, 30 | split="training", 31 | is_transform=False, 32 | img_size=512, 33 | augmentations=None, 34 | img_norm=True, 35 | test_mode=False, 36 | ): 37 | """__init__ 38 | 39 | :param root: 40 | :param split: 41 | :param is_transform: 42 | :param img_size: 43 | """ 44 | self.root = root 45 | self.split = split 46 | self.is_transform = is_transform 47 | self.augmentations = augmentations 48 | self.img_norm = img_norm 49 | self.n_classes = 151 # 0 is reserved for "other" 50 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 51 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 52 | self.files = {} 53 | 54 | self.images_base = os.path.join(self.root, "images", self.split) 55 | self.annotations_base = os.path.join(self.root, "annotations", self.split) 56 | 57 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".jpg") 58 | 59 | if not self.files[split]: 60 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 61 | 62 | print("Found %d %s images" % (len(self.files[split]), split)) 63 | 64 | def __len__(self): 65 | """__len__""" 66 | return len(self.files[self.split]) 67 | 68 | def __getitem__(self, index): 69 | """__getitem__ 70 | 71 | 
:param index: 72 | """ 73 | img_path = self.files[self.split][index].rstrip() 74 | lbl_path = os.path.join(self.annotations_base, os.path.basename(img_path)[:-4] + ".png") 75 | 76 | img = m.imread(img_path, mode="RGB") 77 | img = np.array(img, dtype=np.uint8) 78 | 79 | lbl = m.imread(lbl_path) 80 | lbl = np.array(lbl, dtype=np.uint8) 81 | 82 | if self.augmentations is not None: 83 | img, lbl = self.augmentations(img, lbl) 84 | 85 | if self.is_transform: 86 | img, lbl = self.transform(img, lbl) 87 | 88 | return img, lbl 89 | 90 | def transform(self, img, lbl): 91 | """transform 92 | 93 | :param img: 94 | :param lbl: 95 | """ 96 | if self.img_size == ("same", "same"): 97 | pass 98 | else: 99 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 100 | img = img[:, :, ::-1] # RGB -> BGR 101 | img = img.astype(np.float64) 102 | img -= self.mean 103 | if self.img_norm: 104 | # Resize scales images from 0 to 255, thus we need 105 | # to divide by 255.0 106 | img = img.astype(float) / 255.0 107 | # NHWC -> NCHW 108 | img = img.transpose(2, 0, 1) 109 | 110 | classes = np.unique(lbl) 111 | lbl = lbl.astype(float) 112 | if self.img_size == ("same", "same"): 113 | pass 114 | else: 115 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 116 | lbl = lbl.astype(int) 117 | 118 | if not np.all(classes == np.unique(lbl)): 119 | print("WARN: resizing labels yielded fewer classes") 120 | 121 | if not np.all(np.unique(lbl) < self.n_classes): 122 | raise ValueError("Segmentation map contained invalid class values") 123 | 124 | img = torch.from_numpy(img).float() 125 | lbl = torch.from_numpy(lbl).long() 126 | 127 | return img, lbl 128 | -------------------------------------------------------------------------------- /ptsemseg/loader/nyuv2_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import torch 4 | import numpy as np 5 | import scipy.misc as m 6 | 7 | from torch.utils import data 8 | 9 | from ptsemseg.utils import recursive_glob 10 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale 11 | 12 | 13 | class NYUv2Loader(data.Dataset): 14 | """ 15 | NYUv2 loader 16 | Download From (only 13 classes): 17 | test source: http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz 18 | train source: http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz 19 | test_labels source: 20 | https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz 21 | train_labels source: 22 | https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | root, 29 | split="training", 30 | is_transform=False, 31 | img_size=(480, 640), 32 | augmentations=None, 33 | img_norm=True, 34 | test_mode=False, 35 | ): 36 | self.root = root 37 | self.is_transform = is_transform 38 | self.n_classes = 14 39 | self.augmentations = augmentations 40 | self.img_norm = img_norm 41 | self.test_mode = test_mode 42 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 43 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 44 | self.files = collections.defaultdict(list) 45 | self.cmap = self.color_map(normalized=False) 46 | 47 | split_map = {"training": "train", "val": "test"} 48 | self.split = split_map[split] 49 | 50 | for split in ["train", "test"]: 51 | file_list = recursive_glob(rootdir=self.root + split + "/", suffix="png") 52 | 
self.files[split] = file_list 53 | 54 | def __len__(self): 55 | return len(self.files[self.split]) 56 | 57 | def __getitem__(self, index): 58 | img_path = self.files[self.split][index].rstrip() 59 | img_number = img_path.split("_")[-1][:4] 60 | lbl_path = os.path.join( 61 | self.root, self.split + "_annot", "new_nyu_class13_" + img_number + ".png" 62 | ) 63 | 64 | img = m.imread(img_path) 65 | img = np.array(img, dtype=np.uint8) 66 | 67 | lbl = m.imread(lbl_path) 68 | lbl = np.array(lbl, dtype=np.uint8) 69 | 70 | if not (len(img.shape) == 3 and len(lbl.shape) == 2): 71 | return self.__getitem__(np.random.randint(0, self.__len__())) 72 | 73 | if self.augmentations is not None: 74 | img, lbl = self.augmentations(img, lbl) 75 | 76 | if self.is_transform: 77 | img, lbl = self.transform(img, lbl) 78 | 79 | return img, lbl 80 | 81 | def transform(self, img, lbl): 82 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 83 | img = img[:, :, ::-1] # RGB -> BGR 84 | img = img.astype(np.float64) 85 | img -= self.mean 86 | if self.img_norm: 87 | # Resize scales images from 0 to 255, thus we need 88 | # to divide by 255.0 89 | img = img.astype(float) / 255.0 90 | # NHWC -> NCHW 91 | img = img.transpose(2, 0, 1) 92 | 93 | classes = np.unique(lbl) 94 | lbl = lbl.astype(float) 95 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 96 | lbl = lbl.astype(int) 97 | assert np.all(classes == np.unique(lbl)) 98 | 99 | img = torch.from_numpy(img).float() 100 | lbl = torch.from_numpy(lbl).long() 101 | return img, lbl 102 | 103 | def color_map(self, N=256, normalized=False): 104 | """ 105 | Return Color Map in PASCAL VOC format 106 | """ 107 | 108 | def bitget(byteval, idx): 109 | return (byteval & (1 << idx)) != 0 110 | 111 | dtype = "float32" if normalized else "uint8" 112 | cmap = np.zeros((N, 3), dtype=dtype) 113 | for i in range(N): 114 | r = g = b = 0 115 | c = i 116 | for j in range(8): 117 | r = r | (bitget(c, 0) << 7 - j) 118 | g = g | (bitget(c, 1) << 7 - j) 119 | b = b | (bitget(c, 2) << 7 - j) 120 | c = c >> 3 121 | 122 | cmap[i] = np.array([r, g, b]) 123 | 124 | cmap = cmap / 255.0 if normalized else cmap 125 | return cmap 126 | 127 | def decode_segmap(self, temp): 128 | r = temp.copy() 129 | g = temp.copy() 130 | b = temp.copy() 131 | for l in range(0, self.n_classes): 132 | r[temp == l] = self.cmap[l, 0] 133 | g[temp == l] = self.cmap[l, 1] 134 | b[temp == l] = self.cmap[l, 2] 135 | 136 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 137 | rgb[:, :, 0] = r / 255.0 138 | rgb[:, :, 1] = g / 255.0 139 | rgb[:, :, 2] = b / 255.0 140 | return rgb 141 | 142 | 143 | if __name__ == "__main__": 144 | import matplotlib.pyplot as plt 145 | 146 | augmentations = Compose([Scale(512), RandomRotate(10), RandomHorizontallyFlip()]) 147 | 148 | local_path = "/home/meet/datasets/NYUv2/" 149 | dst = NYUv2Loader(local_path, is_transform=True, augmentations=augmentations) 150 | bs = 4 151 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 152 | for i, datas in enumerate(trainloader): 153 | imgs, labels = datas 154 | imgs = imgs.numpy()[:, ::-1, :, :] 155 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 156 | f, axarr = plt.subplots(bs, 2) 157 | for j in range(bs): 158 | axarr[j][0].imshow(imgs[j]) 159 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 160 | plt.show() 161 | a = input() 162 | if a == "ex": 163 | break 164 | else: 165 | plt.close() 166 | -------------------------------------------------------------------------------- 
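The `color_map` helper above builds the standard PASCAL-VOC-style palette bit by bit: for each class index, bits 0, 1 and 2 feed the next most-significant bit of R, G and B respectively, then the index is shifted right by three and the step repeats. A small self-check (sketch only) shows that this reproduces the palette listed in `get_pascal_labels()` further below:

```python
# Worked check of the bit-interleaved palette used by color_map (sketch only).
def bitget(byteval, idx):
    return (byteval & (1 << idx)) != 0

def voc_color(i):
    r = g = b = 0
    c = i
    for j in range(8):
        r |= int(bitget(c, 0)) << (7 - j)
        g |= int(bitget(c, 1)) << (7 - j)
        b |= int(bitget(c, 2)) << (7 - j)
        c >>= 3
    return r, g, b

assert voc_color(0) == (0, 0, 0)       # background
assert voc_color(1) == (128, 0, 0)     # bit 0 -> high bit of R
assert voc_color(2) == (0, 128, 0)     # bit 1 -> high bit of G
assert voc_color(5) == (128, 0, 128)   # bits 0 and 2 -> R and B
assert voc_color(8) == (64, 0, 0)      # after c >>= 3, bit 0 -> next bit of R
```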
/ptsemseg/loader/pascal_voc_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | import collections 4 | import json 5 | import torch 6 | import numpy as np 7 | import scipy.misc as m 8 | import scipy.io as io 9 | import matplotlib.pyplot as plt 10 | import glob 11 | 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from torch.utils import data 15 | from torchvision import transforms 16 | 17 | 18 | class pascalVOCLoader(data.Dataset): 19 | """Data loader for the Pascal VOC semantic segmentation dataset. 20 | 21 | Annotations from both the original VOC data (which consist of RGB images 22 | in which colours map to specific classes) and the SBD (Berkely) dataset 23 | (where annotations are stored as .mat files) are converted into a common 24 | `label_mask` format. Under this format, each mask is an (M,N) array of 25 | integer values from 0 to 21, where 0 represents the background class. 26 | 27 | The label masks are stored in a new folder, called `pre_encoded`, which 28 | is added as a subdirectory of the `SegmentationClass` folder in the 29 | original Pascal VOC data layout. 30 | 31 | A total of five data splits are provided for working with the VOC data: 32 | train: The original VOC 2012 training data - 1464 images 33 | val: The original VOC 2012 validation data - 1449 images 34 | trainval: The combination of `train` and `val` - 2913 images 35 | train_aug: The unique images present in both the train split and 36 | training images from SBD: - 8829 images (the unique members 37 | of the result of combining lists of length 1464 and 8498) 38 | train_aug_val: The original VOC 2012 validation data minus the images 39 | present in `train_aug` (This is done with the same logic as 40 | the validation set used in FCN PAMI paper, but with VOC 2012 41 | rather than VOC 2011) - 904 images 42 | """ 43 | 44 | def __init__( 45 | self, 46 | root, 47 | sbd_path=None, 48 | split="train_aug", 49 | is_transform=False, 50 | img_size=512, 51 | augmentations=None, 52 | img_norm=True, 53 | test_mode=False, 54 | ): 55 | self.root = root 56 | self.sbd_path = sbd_path 57 | self.split = split 58 | self.is_transform = is_transform 59 | self.augmentations = augmentations 60 | self.img_norm = img_norm 61 | self.test_mode = test_mode 62 | self.n_classes = 21 63 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 64 | self.files = collections.defaultdict(list) 65 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 66 | 67 | if not self.test_mode: 68 | for split in ["train", "val", "trainval"]: 69 | path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt") 70 | file_list = tuple(open(path, "r")) 71 | file_list = [id_.rstrip() for id_ in file_list] 72 | self.files[split] = file_list 73 | self.setup_annotations() 74 | 75 | self.tf = transforms.Compose( 76 | [ 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 79 | ] 80 | ) 81 | 82 | def __len__(self): 83 | return len(self.files[self.split]) 84 | 85 | def __getitem__(self, index): 86 | im_name = self.files[self.split][index] 87 | im_path = pjoin(self.root, "JPEGImages", im_name + ".jpg") 88 | lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded", im_name + ".png") 89 | im = Image.open(im_path) 90 | lbl = Image.open(lbl_path) 91 | if self.augmentations is not None: 92 | im, lbl = self.augmentations(im, lbl) 93 | if self.is_transform: 94 | im, lbl = self.transform(im, lbl) 95 | 
return im, lbl 96 | 97 | def transform(self, img, lbl): 98 | if self.img_size == ("same", "same"): 99 | pass 100 | else: 101 | img = img.resize((self.img_size[0], self.img_size[1])) # uint8 with RGB mode 102 | lbl = lbl.resize((self.img_size[0], self.img_size[1])) 103 | img = self.tf(img) 104 | lbl = torch.from_numpy(np.array(lbl)).long() 105 | lbl[lbl == 255] = 0 106 | return img, lbl 107 | 108 | def get_pascal_labels(self): 109 | """Load the mapping that associates pascal classes with label colors 110 | 111 | Returns: 112 | np.ndarray with dimensions (21, 3) 113 | """ 114 | return np.asarray( 115 | [ 116 | [0, 0, 0], 117 | [128, 0, 0], 118 | [0, 128, 0], 119 | [128, 128, 0], 120 | [0, 0, 128], 121 | [128, 0, 128], 122 | [0, 128, 128], 123 | [128, 128, 128], 124 | [64, 0, 0], 125 | [192, 0, 0], 126 | [64, 128, 0], 127 | [192, 128, 0], 128 | [64, 0, 128], 129 | [192, 0, 128], 130 | [64, 128, 128], 131 | [192, 128, 128], 132 | [0, 64, 0], 133 | [128, 64, 0], 134 | [0, 192, 0], 135 | [128, 192, 0], 136 | [0, 64, 128], 137 | ] 138 | ) 139 | 140 | def encode_segmap(self, mask): 141 | """Encode segmentation label images as pascal classes 142 | 143 | Args: 144 | mask (np.ndarray): raw segmentation label image of dimension 145 | (M, N, 3), in which the Pascal classes are encoded as colours. 146 | 147 | Returns: 148 | (np.ndarray): class map with dimensions (M,N), where the value at 149 | a given location is the integer denoting the class index. 150 | """ 151 | mask = mask.astype(int) 152 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 153 | for ii, label in enumerate(self.get_pascal_labels()): 154 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 155 | label_mask = label_mask.astype(int) 156 | return label_mask 157 | 158 | def decode_segmap(self, label_mask, plot=False): 159 | """Decode segmentation class labels into a color image 160 | 161 | Args: 162 | label_mask (np.ndarray): an (M,N) array of integer values denoting 163 | the class label at each spatial location. 164 | plot (bool, optional): whether to show the resulting color image 165 | in a figure. 166 | 167 | Returns: 168 | (np.ndarray, optional): the resulting decoded color image. 169 | """ 170 | label_colours = self.get_pascal_labels() 171 | r = label_mask.copy() 172 | g = label_mask.copy() 173 | b = label_mask.copy() 174 | for ll in range(0, self.n_classes): 175 | r[label_mask == ll] = label_colours[ll, 0] 176 | g[label_mask == ll] = label_colours[ll, 1] 177 | b[label_mask == ll] = label_colours[ll, 2] 178 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 179 | rgb[:, :, 0] = r / 255.0 180 | rgb[:, :, 1] = g / 255.0 181 | rgb[:, :, 2] = b / 255.0 182 | if plot: 183 | plt.imshow(rgb) 184 | plt.show() 185 | else: 186 | return rgb 187 | 188 | def setup_annotations(self): 189 | """Sets up Berkley annotations by adding image indices to the 190 | `train_aug` split and pre-encode all segmentation labels into the 191 | common label_mask format (if this has not already been done). 
This 192 | function also defines the `train_aug` and `train_aug_val` data splits 193 | according to the description in the class docstring 194 | """ 195 | sbd_path = self.sbd_path 196 | target_path = pjoin(self.root, "SegmentationClass/pre_encoded") 197 | if not os.path.exists(target_path): 198 | os.makedirs(target_path) 199 | path = pjoin(sbd_path, "dataset/train.txt") 200 | sbd_train_list = tuple(open(path, "r")) 201 | sbd_train_list = [id_.rstrip() for id_ in sbd_train_list] 202 | train_aug = self.files["train"] + sbd_train_list 203 | 204 | # keep unique elements (stable) 205 | train_aug = [train_aug[i] for i in sorted(np.unique(train_aug, return_index=True)[1])] 206 | self.files["train_aug"] = train_aug 207 | set_diff = set(self.files["val"]) - set(train_aug) # remove overlap 208 | self.files["train_aug_val"] = list(set_diff) 209 | 210 | pre_encoded = glob.glob(pjoin(target_path, "*.png")) 211 | expected = np.unique(self.files["train_aug"] + self.files["val"]).size 212 | 213 | if len(pre_encoded) != expected: 214 | print("Pre-encoding segmentation masks...") 215 | for ii in tqdm(sbd_train_list): 216 | lbl_path = pjoin(sbd_path, "dataset/cls", ii + ".mat") 217 | data = io.loadmat(lbl_path) 218 | lbl = data["GTcls"][0]["Segmentation"][0].astype(np.int32) 219 | lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) 220 | m.imsave(pjoin(target_path, ii + ".png"), lbl) 221 | 222 | for ii in tqdm(self.files["trainval"]): 223 | fname = ii + ".png" 224 | lbl_path = pjoin(self.root, "SegmentationClass", fname) 225 | lbl = self.encode_segmap(m.imread(lbl_path)) 226 | lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) 227 | m.imsave(pjoin(target_path, fname), lbl) 228 | 229 | assert expected == 9733, "unexpected dataset sizes" 230 | 231 | 232 | # Leave code for debugging purposes 233 | # import ptsemseg.augmentations as aug 234 | # if __name__ == '__main__': 235 | # # local_path = '/home/meetshah1995/datasets/VOCdevkit/VOC2012/' 236 | # bs = 4 237 | # augs = aug.Compose([aug.RandomRotate(10), aug.RandomHorizontallyFlip()]) 238 | # dst = pascalVOCLoader(root=local_path, is_transform=True, augmentations=augs) 239 | # trainloader = data.DataLoader(dst, batch_size=bs) 240 | # for i, data in enumerate(trainloader): 241 | # imgs, labels = data 242 | # imgs = imgs.numpy()[:, ::-1, :, :] 243 | # imgs = np.transpose(imgs, [0,2,3,1]) 244 | # f, axarr = plt.subplots(bs, 2) 245 | # for j in range(bs): 246 | # axarr[j][0].imshow(imgs[j]) 247 | # axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 248 | # plt.show() 249 | # a = raw_input() 250 | # if a == 'ex': 251 | # break 252 | # else: 253 | # plt.close() 254 | -------------------------------------------------------------------------------- /ptsemseg/loader/sunrgbd_loader.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | 6 | from torch.utils import data 7 | 8 | from ptsemseg.utils import recursive_glob 9 | from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale 10 | 11 | 12 | class SUNRGBDLoader(data.Dataset): 13 | """SUNRGBD loader 14 | 15 | Download From: 16 | http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz 17 | test source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz 18 | train source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-train_images.tgz 19 | 20 | first 5050 in this is test, later 5051 is train 21 | test and train labels source: 22 | 
https://github.com/ankurhanda/sunrgbd-meta-data/raw/master/sunrgbd_train_test_labels.tar.gz 23 | """ 24 | 25 | def __init__( 26 | self, 27 | root, 28 | split="training", 29 | is_transform=False, 30 | img_size=(480, 640), 31 | augmentations=None, 32 | img_norm=True, 33 | test_mode=False, 34 | ): 35 | self.root = root 36 | self.is_transform = is_transform 37 | self.n_classes = 38 38 | self.augmentations = augmentations 39 | self.img_norm = img_norm 40 | self.test_mode = test_mode 41 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 42 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 43 | self.files = collections.defaultdict(list) 44 | self.anno_files = collections.defaultdict(list) 45 | self.cmap = self.color_map(normalized=False) 46 | 47 | split_map = {"training": "train", "val": "test"} 48 | self.split = split_map[split] 49 | 50 | for split in ["train", "test"]: 51 | file_list = sorted(recursive_glob(rootdir=self.root + split + "/", suffix="jpg")) 52 | self.files[split] = file_list 53 | 54 | for split in ["train", "test"]: 55 | file_list = sorted( 56 | recursive_glob(rootdir=self.root + "annotations/" + split + "/", suffix="png") 57 | ) 58 | self.anno_files[split] = file_list 59 | 60 | def __len__(self): 61 | return len(self.files[self.split]) 62 | 63 | def __getitem__(self, index): 64 | img_path = self.files[self.split][index].rstrip() 65 | lbl_path = self.anno_files[self.split][index].rstrip() 66 | # img_number = img_path.split('/')[-1] 67 | # lbl_path = os.path.join(self.root, 'annotations', img_number).replace('jpg', 'png') 68 | 69 | img = m.imread(img_path) 70 | img = np.array(img, dtype=np.uint8) 71 | 72 | lbl = m.imread(lbl_path) 73 | lbl = np.array(lbl, dtype=np.uint8) 74 | 75 | if not (len(img.shape) == 3 and len(lbl.shape) == 2): 76 | return self.__getitem__(np.random.randint(0, self.__len__())) 77 | 78 | if self.augmentations is not None: 79 | img, lbl = self.augmentations(img, lbl) 80 | 81 | if self.is_transform: 82 | img, lbl = self.transform(img, lbl) 83 | 84 | return img, lbl 85 | 86 | def transform(self, img, lbl): 87 | img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode 88 | img = img[:, :, ::-1] # RGB -> BGR 89 | img = img.astype(np.float64) 90 | img -= self.mean 91 | if self.img_norm: 92 | # Resize scales images from 0 to 255, thus we need 93 | # to divide by 255.0 94 | img = img.astype(float) / 255.0 95 | # NHWC -> NCHW 96 | img = img.transpose(2, 0, 1) 97 | 98 | classes = np.unique(lbl) 99 | lbl = lbl.astype(float) 100 | lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 101 | lbl = lbl.astype(int) 102 | assert np.all(classes == np.unique(lbl)) 103 | 104 | img = torch.from_numpy(img).float() 105 | lbl = torch.from_numpy(lbl).long() 106 | return img, lbl 107 | 108 | def color_map(self, N=256, normalized=False): 109 | """ 110 | Return Color Map in PASCAL VOC format 111 | """ 112 | 113 | def bitget(byteval, idx): 114 | return (byteval & (1 << idx)) != 0 115 | 116 | dtype = "float32" if normalized else "uint8" 117 | cmap = np.zeros((N, 3), dtype=dtype) 118 | for i in range(N): 119 | r = g = b = 0 120 | c = i 121 | for j in range(8): 122 | r = r | (bitget(c, 0) << 7 - j) 123 | g = g | (bitget(c, 1) << 7 - j) 124 | b = b | (bitget(c, 2) << 7 - j) 125 | c = c >> 3 126 | 127 | cmap[i] = np.array([r, g, b]) 128 | 129 | cmap = cmap / 255.0 if normalized else cmap 130 | return cmap 131 | 132 | def decode_segmap(self, temp): 133 | r = temp.copy() 134 | g = temp.copy() 135 | b = 
temp.copy() 136 | for l in range(0, self.n_classes): 137 | r[temp == l] = self.cmap[l, 0] 138 | g[temp == l] = self.cmap[l, 1] 139 | b[temp == l] = self.cmap[l, 2] 140 | 141 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 142 | rgb[:, :, 0] = r / 255.0 143 | rgb[:, :, 1] = g / 255.0 144 | rgb[:, :, 2] = b / 255.0 145 | return rgb 146 | 147 | 148 | if __name__ == "__main__": 149 | import matplotlib.pyplot as plt 150 | 151 | augmentations = Compose([Scale(512), RandomRotate(10), RandomHorizontallyFlip()]) 152 | 153 | local_path = "/home/meet/datasets/SUNRGBD/" 154 | dst = SUNRGBDLoader(local_path, is_transform=True, augmentations=augmentations) 155 | bs = 4 156 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 157 | for i, data_samples in enumerate(trainloader): 158 | imgs, labels = data_samples 159 | imgs = imgs.numpy()[:, ::-1, :, :] 160 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 161 | f, axarr = plt.subplots(bs, 2) 162 | for j in range(bs): 163 | axarr[j][0].imshow(imgs[j]) 164 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 165 | plt.show() 166 | a = input() 167 | if a == "ex": 168 | break 169 | else: 170 | plt.close() 171 | -------------------------------------------------------------------------------- /ptsemseg/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import functools 3 | 4 | from ptsemseg.loss.loss import ( 5 | cross_entropy2d, 6 | bootstrapped_cross_entropy2d, 7 | multi_scale_cross_entropy2d, 8 | ) 9 | 10 | 11 | logger = logging.getLogger("ptsemseg") 12 | 13 | key2loss = { 14 | "cross_entropy": cross_entropy2d, 15 | "bootstrapped_cross_entropy": bootstrapped_cross_entropy2d, 16 | "multi_scale_cross_entropy": multi_scale_cross_entropy2d, 17 | } 18 | 19 | 20 | def get_loss_function(cfg): 21 | if cfg["training"]["loss"] is None: 22 | logger.info("Using default cross entropy loss") 23 | return cross_entropy2d 24 | 25 | else: 26 | loss_dict = cfg["training"]["loss"] 27 | loss_name = loss_dict["name"] 28 | loss_params = {k: v for k, v in loss_dict.items() if k != "name"} 29 | 30 | if loss_name not in key2loss: 31 | raise NotImplementedError("Loss {} not implemented".format(loss_name)) 32 | 33 | logger.info("Using {} with {} params".format(loss_name, loss_params)) 34 | return functools.partial(key2loss[loss_name], **loss_params) 35 | -------------------------------------------------------------------------------- /ptsemseg/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loss/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/loss/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def multitaskloss_train(input, target): 5 | loss = 0.5*F.smooth_l1_loss(input[0], target, size_average=True) +\ 6 | 0.7*F.smooth_l1_loss(input[1], target, size_average=True) + \ 7 | 
F.smooth_l1_loss(input[2], target, size_average=True) 8 | return loss 9 | 10 | 11 | def cross_entropy2d(input, target, weight=None, size_average=True): 12 | n, c, h, w = input.size() 13 | nt, ht, wt = target.size() 14 | 15 | # Handle inconsistent size between input and target 16 | if h != ht and w != wt: # upsample labels 17 | input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True) 18 | 19 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 20 | target = target.view(-1) 21 | loss = F.cross_entropy( 22 | input, target, weight=weight, size_average=size_average, ignore_index=250 23 | ) 24 | return loss 25 | 26 | 27 | def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None): 28 | if not isinstance(input, tuple): 29 | return cross_entropy2d(input=input, target=target, weight=weight, size_average=size_average) 30 | 31 | # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16] 32 | if scale_weight is None: # scale_weight: torch tensor type 33 | n_inp = len(input) 34 | scale = 0.4 35 | scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to( 36 | target.device 37 | ) 38 | 39 | loss = 0.0 40 | for i, inp in enumerate(input): 41 | loss = loss + scale_weight[i] * cross_entropy2d( 42 | input=inp, target=target, weight=weight, size_average=size_average 43 | ) 44 | 45 | return loss 46 | 47 | 48 | def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True): 49 | 50 | batch_size = input.size()[0] 51 | 52 | def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True): 53 | 54 | n, c, h, w = input.size() 55 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 56 | target = target.view(-1) 57 | loss = F.cross_entropy( 58 | input, target, weight=weight, reduce=False, size_average=False, ignore_index=250 59 | ) 60 | 61 | topk_loss, _ = loss.topk(K) 62 | reduced_topk_loss = topk_loss.sum() / K 63 | 64 | return reduced_topk_loss 65 | 66 | loss = 0.0 67 | # Bootstrap from each image not entire batch 68 | for i in range(batch_size): 69 | loss += _bootstrap_xentropy_single( 70 | input=torch.unsqueeze(input[i], 0), 71 | target=torch.unsqueeze(target[i], 0), 72 | K=K, 73 | weight=weight, 74 | size_average=size_average, 75 | ) 76 | return loss / float(batch_size) 77 | -------------------------------------------------------------------------------- /ptsemseg/metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from score written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | 4 | import numpy as np 5 | 6 | 7 | class runningScore(object): 8 | def __init__(self, n_classes): 9 | self.n_classes = n_classes 10 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 11 | 12 | def _fast_hist(self, label_true, label_pred, n_class): 13 | mask = (label_true >= 0) & (label_true < n_class) 14 | hist = np.bincount( 15 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 16 | ).reshape(n_class, n_class) 17 | return hist 18 | 19 | def update(self, label_trues, label_preds): 20 | for lt, lp in zip(label_trues, label_preds): 21 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 22 | 23 | def get_scores(self): 24 | """Returns accuracy score evaluation result. 
25 | - overall accuracy 26 | - mean accuracy 27 | - mean IU 28 | - fwavacc 29 | """ 30 | hist = self.confusion_matrix 31 | acc = np.diag(hist).sum() / hist.sum() 32 | acc_cls = np.diag(hist) / hist.sum(axis=1) 33 | acc_cls = np.nanmean(acc_cls) 34 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 35 | mean_iu = np.nanmean(iu) 36 | freq = hist.sum(axis=1) / hist.sum() 37 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 38 | cls_iu = dict(zip(range(self.n_classes), iu)) 39 | 40 | return ( 41 | { 42 | "Overall Acc: \t": acc, 43 | "Mean Acc : \t": acc_cls, 44 | "FreqW Acc : \t": fwavacc, 45 | "Mean IoU : \t": mean_iu, 46 | }, 47 | cls_iu, 48 | ) 49 | 50 | def reset(self): 51 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 52 | 53 | 54 | class averageMeter(object): 55 | """Computes and stores the average and current value""" 56 | 57 | def __init__(self): 58 | self.reset() 59 | 60 | def reset(self): 61 | self.val = 0 62 | self.avg = 0 63 | self.sum = 0 64 | self.count = 0 65 | 66 | def update(self, val, n=1): 67 | self.val = val 68 | self.sum += val * n 69 | self.count += n 70 | self.avg = self.sum / self.count 71 | 72 | 73 | class heightacc(object): 74 | ''' 75 | compute acc 76 | ''' 77 | def __init__(self): 78 | self.reset() 79 | 80 | def reset(self): 81 | #self.r2 = 0 不好计算 82 | self.mse = 0 83 | self.se = 0 84 | self.mae = 0 85 | #self.mape = 0 86 | self.count = 0 87 | self.yrefmean = 0 88 | self.ypref2 = 0 89 | 90 | def update(self, ypred, yref, num): 91 | self.se += np.mean(ypred-yref)*num 92 | 93 | self.mae += np.mean(np.abs(ypred-yref))*num 94 | 95 | self.mse += np.mean((ypred-yref)**2)*num 96 | 97 | #self.mape += np.mean(np.abs((ypred-yref)/(1e-8+yref)))*num 98 | 99 | self.yrefmean += np.mean(yref)*num 100 | 101 | self.ypref2 += np.mean(yref**2)*num 102 | 103 | self.count += num 104 | 105 | def getacc(self): 106 | se = self.se/self.count 107 | mae = self.mae/self.count 108 | mse = self.mse/self.count 109 | #mape = self.mape/self.count 110 | rmse = np.sqrt(mse) 111 | 112 | yrefmean = self.yrefmean/self.count 113 | yref2 = self.ypref2/self.count 114 | r2 = 1 - mse/(yref2 -yrefmean**2) 115 | return r2, rmse, mae, se 116 | -------------------------------------------------------------------------------- /ptsemseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torchvision.models as models 3 | 4 | from ptsemseg.models.fcn import fcn8s, fcn16s, fcn32s 5 | from ptsemseg.models.segnet import segnet 6 | from ptsemseg.models.unet import unet 7 | from ptsemseg.models.pspnet import pspnet 8 | from ptsemseg.models.icnet import icnet 9 | from ptsemseg.models.linknet import linknet 10 | from ptsemseg.models.frrn import frrn 11 | from ptsemseg.models.tlcnet import TLCNet, TLCNetU, TLCNetUmux, TLCNetUtlc, TLCNetUtlcmux 12 | 13 | def get_model(model_dict, n_maxdisp=256, n_classes=1, version=None): 14 | name = model_dict["arch"] 15 | model = _get_model_instance(name) 16 | param_dict = copy.deepcopy(model_dict) 17 | param_dict.pop("arch") 18 | 19 | if name in ["frrnA", "frrnB"]: 20 | model = model(n_classes, **param_dict) 21 | 22 | elif name in ["fcn32s", "fcn16s", "fcn8s"]: 23 | model = model(n_classes=n_classes, **param_dict) 24 | vgg16 = models.vgg16(pretrained=True) 25 | model.init_vgg16_params(vgg16) 26 | 27 | elif name == "segnet": 28 | model = model(n_classes=n_classes, **param_dict) 29 | vgg16 = models.vgg16(pretrained=True) 30 | model.init_vgg16_params(vgg16) 31 | 32 | 
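    # Note on the branches below: "arch" selects the model class via _get_model_instance,
    # and every other key of the config's model section survives in param_dict, so it is
    # forwarded unchanged as a constructor keyword argument (e.g. feature_scale or
    # in_channels for unet) without this function having to know about it.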
elif name == "unet": 33 | model = model(n_classes=n_classes, **param_dict) 34 | 35 | elif name == "pspnet": 36 | model = model(n_classes=n_classes, **param_dict) 37 | 38 | elif name == "icnet": 39 | model = model(n_classes=n_classes, **param_dict) 40 | 41 | elif name == "icnetBN": 42 | model = model(n_classes=n_classes, **param_dict) 43 | 44 | elif name == "tlcnet": 45 | model = model(maxdisp=n_maxdisp, **param_dict) 46 | 47 | elif name == "tlcnetu": 48 | model = model(n_classes=n_classes, **param_dict) 49 | 50 | elif name=="tlcnetumux": # 2020.10.3 add 51 | model = model(n_classes=n_classes, **param_dict) 52 | 53 | elif name=="tlcnetutlc": # 2020.10.3 add 54 | model = model(n_classes=n_classes, **param_dict) 55 | 56 | elif name=="tlcnetutlcmux": # 2020.10.5 add 57 | model = model(n_classes=n_classes, **param_dict) 58 | 59 | else: 60 | model = model(n_classes=n_classes, **param_dict) 61 | 62 | return model 63 | 64 | 65 | def _get_model_instance(name): 66 | try: 67 | return { 68 | "fcn32s": fcn32s, 69 | "fcn8s": fcn8s, 70 | "fcn16s": fcn16s, 71 | "unet": unet, 72 | "segnet": segnet, 73 | "pspnet": pspnet, 74 | "icnet": icnet, 75 | "icnetBN": icnet, 76 | "linknet": linknet, 77 | "frrnA": frrn, 78 | "frrnB": frrn, 79 | "tlcnet": TLCNet, 80 | "tlcnetu": TLCNetU, 81 | "tlcnetumux": TLCNetUmux, 82 | "tlcnetutlc": TLCNetUtlc, 83 | "tlcnetutlcmux": TLCNetUtlcmux 84 | }[name] 85 | except: 86 | raise ("Model {} not available".format(name)) 87 | -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/fcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/fcn.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/frrn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/frrn.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/icnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/icnet.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/linknet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/linknet.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/pspnet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/pspnet.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/segnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/segnet.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/submodule.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/submodule.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/tlcnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/tlcnet.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/unet.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/models/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/models/frrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ptsemseg.models.utils import FRRU, RU, conv2DBatchNormRelu, conv2DGroupNormRelu 6 | 7 | frrn_specs_dic = { 8 | "A": { 9 | "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16]], 10 | "decoder": [[2, 192, 8], [2, 192, 4], [2, 48, 2]], 11 | }, 12 | "B": { 13 | "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16], [2, 384, 32]], 14 | "decoder": [[2, 192, 16], [2, 192, 8], [2, 192, 4], [2, 48, 2]], 15 | }, 16 | } 17 | 18 | 19 | class frrn(nn.Module): 20 | """ 21 | Full Resolution Residual Networks for Semantic Segmentation 22 | URL: https://arxiv.org/abs/1611.08323 23 | 24 | References: 25 | 1) Original Author's code: https://github.com/TobyPDE/FRRN 26 | 2) TF implementation by @kiwonjoon: https://github.com/hiwonjoon/tf-frrn 27 | """ 28 | 29 | def __init__(self, n_classes=21, model_type="B", group_norm=False, n_groups=16): 30 | super(frrn, self).__init__() 31 | self.n_classes = n_classes 32 | self.model_type = model_type 33 | self.group_norm = group_norm 34 | self.n_groups = n_groups 35 | 36 | if self.group_norm: 37 | self.conv1 = conv2DGroupNormRelu(3, 48, 5, 1, 2) 38 | else: 39 | self.conv1 = conv2DBatchNormRelu(3, 48, 5, 1, 2) 40 | 41 | self.up_residual_units = [] 42 | self.down_residual_units = [] 43 | for i in range(3): 44 | 
self.up_residual_units.append( 45 | RU( 46 | channels=48, 47 | kernel_size=3, 48 | strides=1, 49 | group_norm=self.group_norm, 50 | n_groups=self.n_groups, 51 | ) 52 | ) 53 | self.down_residual_units.append( 54 | RU( 55 | channels=48, 56 | kernel_size=3, 57 | strides=1, 58 | group_norm=self.group_norm, 59 | n_groups=self.n_groups, 60 | ) 61 | ) 62 | 63 | self.up_residual_units = nn.ModuleList(self.up_residual_units) 64 | self.down_residual_units = nn.ModuleList(self.down_residual_units) 65 | 66 | self.split_conv = nn.Conv2d(48, 32, kernel_size=1, padding=0, stride=1, bias=False) 67 | 68 | # each spec is as (n_blocks, channels, scale) 69 | self.encoder_frru_specs = frrn_specs_dic[self.model_type]["encoder"] 70 | 71 | self.decoder_frru_specs = frrn_specs_dic[self.model_type]["decoder"] 72 | 73 | # encoding 74 | prev_channels = 48 75 | self.encoding_frrus = {} 76 | for n_blocks, channels, scale in self.encoder_frru_specs: 77 | for block in range(n_blocks): 78 | key = "_".join(map(str, ["encoding_frru", n_blocks, channels, scale, block])) 79 | setattr( 80 | self, 81 | key, 82 | FRRU( 83 | prev_channels=prev_channels, 84 | out_channels=channels, 85 | scale=scale, 86 | group_norm=self.group_norm, 87 | n_groups=self.n_groups, 88 | ), 89 | ) 90 | prev_channels = channels 91 | 92 | # decoding 93 | self.decoding_frrus = {} 94 | for n_blocks, channels, scale in self.decoder_frru_specs: 95 | # pass through decoding FRRUs 96 | for block in range(n_blocks): 97 | key = "_".join(map(str, ["decoding_frru", n_blocks, channels, scale, block])) 98 | setattr( 99 | self, 100 | key, 101 | FRRU( 102 | prev_channels=prev_channels, 103 | out_channels=channels, 104 | scale=scale, 105 | group_norm=self.group_norm, 106 | n_groups=self.n_groups, 107 | ), 108 | ) 109 | prev_channels = channels 110 | 111 | self.merge_conv = nn.Conv2d( 112 | prev_channels + 32, 48, kernel_size=1, padding=0, stride=1, bias=False 113 | ) 114 | 115 | self.classif_conv = nn.Conv2d( 116 | 48, self.n_classes, kernel_size=1, padding=0, stride=1, bias=True 117 | ) 118 | 119 | def forward(self, x): 120 | 121 | # pass to initial conv 122 | x = self.conv1(x) 123 | 124 | # pass through residual units 125 | for i in range(3): 126 | x = self.up_residual_units[i](x) 127 | 128 | # divide stream 129 | y = x 130 | z = self.split_conv(x) 131 | 132 | prev_channels = 48 133 | # encoding 134 | for n_blocks, channels, scale in self.encoder_frru_specs: 135 | # maxpool bigger feature map 136 | y_pooled = F.max_pool2d(y, stride=2, kernel_size=2, padding=0) 137 | # pass through encoding FRRUs 138 | for block in range(n_blocks): 139 | key = "_".join(map(str, ["encoding_frru", n_blocks, channels, scale, block])) 140 | y, z = getattr(self, key)(y_pooled, z) 141 | prev_channels = channels 142 | 143 | # decoding 144 | for n_blocks, channels, scale in self.decoder_frru_specs: 145 | # bilinear upsample smaller feature map 146 | upsample_size = torch.Size([_s * 2 for _s in y.size()[-2:]]) 147 | y_upsampled = F.upsample(y, size=upsample_size, mode="bilinear", align_corners=True) 148 | # pass through decoding FRRUs 149 | for block in range(n_blocks): 150 | key = "_".join(map(str, ["decoding_frru", n_blocks, channels, scale, block])) 151 | # print("Incoming FRRU Size: ", key, y_upsampled.shape, z.shape) 152 | y, z = getattr(self, key)(y_upsampled, z) 153 | # print("Outgoing FRRU Size: ", key, y.shape, z.shape) 154 | prev_channels = channels 155 | 156 | # merge streams 157 | x = torch.cat( 158 | [F.upsample(y, scale_factor=2, mode="bilinear", align_corners=True), z], dim=1 159 
| ) 160 | x = self.merge_conv(x) 161 | 162 | # pass through residual units 163 | for i in range(3): 164 | x = self.down_residual_units[i](x) 165 | 166 | # final 1x1 conv to get classification 167 | x = self.classif_conv(x) 168 | 169 | return x 170 | -------------------------------------------------------------------------------- /ptsemseg/models/linknet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ptsemseg.models.utils import conv2DBatchNormRelu, linknetUp, residualBlock 4 | 5 | 6 | class linknet(nn.Module): 7 | def __init__( 8 | self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True 9 | ): 10 | super(linknet, self).__init__() 11 | self.is_deconv = is_deconv 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | self.layers = [2, 2, 2, 2] # Currently hardcoded for ResNet-18 16 | 17 | filters = [64, 128, 256, 512] 18 | filters = [x / self.feature_scale for x in filters] 19 | 20 | self.inplanes = filters[0] 21 | 22 | # Encoder 23 | self.convbnrelu1 = conv2DBatchNormRelu( 24 | in_channels=3, k_size=7, n_filters=64, padding=3, stride=2, bias=False 25 | ) 26 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 27 | 28 | block = residualBlock 29 | self.encoder1 = self._make_layer(block, filters[0], self.layers[0]) 30 | self.encoder2 = self._make_layer(block, filters[1], self.layers[1], stride=2) 31 | self.encoder3 = self._make_layer(block, filters[2], self.layers[2], stride=2) 32 | self.encoder4 = self._make_layer(block, filters[3], self.layers[3], stride=2) 33 | self.avgpool = nn.AvgPool2d(7) 34 | 35 | # Decoder 36 | self.decoder4 = linknetUp(filters[3], filters[2]) 37 | self.decoder4 = linknetUp(filters[2], filters[1]) 38 | self.decoder4 = linknetUp(filters[1], filters[0]) 39 | self.decoder4 = linknetUp(filters[0], filters[0]) 40 | 41 | # Final Classifier 42 | self.finaldeconvbnrelu1 = nn.Sequential( 43 | nn.ConvTranspose2d(filters[0], 32 / feature_scale, 3, 2, 1), 44 | nn.BatchNorm2d(32 / feature_scale), 45 | nn.ReLU(inplace=True), 46 | ) 47 | self.finalconvbnrelu2 = conv2DBatchNormRelu( 48 | in_channels=32 / feature_scale, 49 | k_size=3, 50 | n_filters=32 / feature_scale, 51 | padding=1, 52 | stride=1, 53 | ) 54 | self.finalconv3 = nn.Conv2d(32 / feature_scale, n_classes, 2, 2, 0) 55 | 56 | def _make_layer(self, block, planes, blocks, stride=1): 57 | downsample = None 58 | if stride != 1 or self.inplanes != planes * block.expansion: 59 | downsample = nn.Sequential( 60 | nn.Conv2d( 61 | self.inplanes, 62 | planes * block.expansion, 63 | kernel_size=1, 64 | stride=stride, 65 | bias=False, 66 | ), 67 | nn.BatchNorm2d(planes * block.expansion), 68 | ) 69 | layers = [] 70 | layers.append(block(self.inplanes, planes, stride, downsample)) 71 | self.inplanes = planes * block.expansion 72 | for i in range(1, blocks): 73 | layers.append(block(self.inplanes, planes)) 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | # Encoder 78 | x = self.convbnrelu1(x) 79 | x = self.maxpool(x) 80 | 81 | e1 = self.encoder1(x) 82 | e2 = self.encoder2(e1) 83 | e3 = self.encoder3(e2) 84 | e4 = self.encoder4(e3) 85 | 86 | # Decoder with Skip Connections 87 | d4 = self.decoder4(e4) 88 | d4 += e3 89 | d3 = self.decoder3(d4) 90 | d3 += e2 91 | d2 = self.decoder2(d3) 92 | d2 += e1 93 | d1 = self.decoder1(d2) 94 | 95 | # Final Classification 96 | f1 = self.finaldeconvbnrelu1(d1) 97 | f2 = self.finalconvbnrelu2(f1) 98 | f3 
= self.finalconv3(f2) 99 | 100 | return f3 101 | -------------------------------------------------------------------------------- /ptsemseg/models/refinenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class refinenet(nn.Module): 5 | """ 6 | RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation 7 | URL: https://arxiv.org/abs/1611.06612 8 | 9 | References: 10 | 1) Original Author's MATLAB code: https://github.com/guosheng/refinenet 11 | 2) TF implementation by @eragonruan: https://github.com/eragonruan/refinenet-image-segmentation 12 | """ 13 | 14 | def __init__(self, n_classes=21): 15 | super(refinenet, self).__init__() 16 | self.n_classes = n_classes 17 | 18 | def forward(self, x): 19 | pass 20 | -------------------------------------------------------------------------------- /ptsemseg/models/segnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ptsemseg.models.utils import segnetDown2, segnetDown3, segnetUp2, segnetUp3 4 | 5 | 6 | class segnet(nn.Module): 7 | def __init__(self, n_classes=21, in_channels=3, is_unpooling=True): 8 | super(segnet, self).__init__() 9 | 10 | self.in_channels = in_channels 11 | self.is_unpooling = is_unpooling 12 | 13 | self.down1 = segnetDown2(self.in_channels, 64) 14 | self.down2 = segnetDown2(64, 128) 15 | self.down3 = segnetDown3(128, 256) 16 | self.down4 = segnetDown3(256, 512) 17 | self.down5 = segnetDown3(512, 512) 18 | 19 | self.up5 = segnetUp3(512, 512) 20 | self.up4 = segnetUp3(512, 256) 21 | self.up3 = segnetUp3(256, 128) 22 | self.up2 = segnetUp2(128, 64) 23 | self.up1 = segnetUp2(64, n_classes) 24 | 25 | def forward(self, inputs): 26 | 27 | down1, indices_1, unpool_shape1 = self.down1(inputs) 28 | down2, indices_2, unpool_shape2 = self.down2(down1) 29 | down3, indices_3, unpool_shape3 = self.down3(down2) 30 | down4, indices_4, unpool_shape4 = self.down4(down3) 31 | down5, indices_5, unpool_shape5 = self.down5(down4) 32 | 33 | up5 = self.up5(down5, indices_5, unpool_shape5) 34 | up4 = self.up4(up5, indices_4, unpool_shape4) 35 | up3 = self.up3(up4, indices_3, unpool_shape3) 36 | up2 = self.up2(up3, indices_2, unpool_shape2) 37 | up1 = self.up1(up2, indices_1, unpool_shape1) 38 | 39 | return up1 40 | 41 | def init_vgg16_params(self, vgg16): 42 | blocks = [self.down1, self.down2, self.down3, self.down4, self.down5] 43 | 44 | features = list(vgg16.features.children()) 45 | 46 | vgg_layers = [] 47 | for _layer in features: 48 | if isinstance(_layer, nn.Conv2d): 49 | vgg_layers.append(_layer) 50 | 51 | merged_layers = [] 52 | for idx, conv_block in enumerate(blocks): 53 | if idx < 2: 54 | units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit] 55 | else: 56 | units = [ 57 | conv_block.conv1.cbr_unit, 58 | conv_block.conv2.cbr_unit, 59 | conv_block.conv3.cbr_unit, 60 | ] 61 | for _unit in units: 62 | for _layer in _unit: 63 | if isinstance(_layer, nn.Conv2d): 64 | merged_layers.append(_layer) 65 | 66 | assert len(vgg_layers) == len(merged_layers) 67 | 68 | for l1, l2 in zip(vgg_layers, merged_layers): 69 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 70 | assert l1.weight.size() == l2.weight.size() 71 | assert l1.bias.size() == l2.bias.size() 72 | l2.weight.data = l1.weight.data 73 | l2.bias.data = l1.bias.data 74 | -------------------------------------------------------------------------------- /ptsemseg/models/submodule.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | 10 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation): 11 | 12 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False), 13 | nn.BatchNorm2d(out_planes)) 14 | 15 | 16 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad): 17 | 18 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride,bias=False), 19 | nn.BatchNorm3d(out_planes)) 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 24 | super(BasicBlock, self).__init__() 25 | 26 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation), 27 | nn.ReLU(inplace=True)) 28 | 29 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 30 | 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | out = self.conv1(x) 36 | out = self.conv2(out) 37 | 38 | if self.downsample is not None: 39 | x = self.downsample(x) 40 | 41 | out += x 42 | 43 | return out 44 | 45 | class matchshifted(nn.Module): 46 | def __init__(self): 47 | super(matchshifted, self).__init__() 48 | 49 | def forward(self, left, right, shift): 50 | batch, filters, height, width = left.size() 51 | shifted_left = F.pad(torch.index_select(left, 3, Variable(torch.LongTensor([i for i in range(shift,width)])).cuda()),(shift,0,0,0)) 52 | shifted_right = F.pad(torch.index_select(right, 3, Variable(torch.LongTensor([i for i in range(width-shift)])).cuda()),(shift,0,0,0)) 53 | out = torch.cat((shifted_left,shifted_right),1).view(batch,filters*2,1,height,width) 54 | return out 55 | 56 | class disparityregression(nn.Module): 57 | def __init__(self, maxdisp): 58 | super(disparityregression, self).__init__() 59 | self.disp = Variable(torch.Tensor(np.reshape(np.array(range(maxdisp)),[1,maxdisp,1,1])).cuda(), requires_grad=False) 60 | 61 | def forward(self, x): 62 | disp = self.disp.repeat(x.size()[0],1,x.size()[2],x.size()[3]) 63 | out = torch.sum(x*disp,1) 64 | return out 65 | 66 | 67 | class feature_extraction(nn.Module): 68 | def __init__(self, n_channels=3): 69 | super(feature_extraction, self).__init__() 70 | self.inplanes = 32 71 | self.n_channels=n_channels #input channels 72 | self.firstconv = nn.Sequential(convbn(self.n_channels, 32, 3, 2, 1, 1), 73 | nn.ReLU(inplace=True), 74 | convbn(32, 32, 3, 1, 1, 1), 75 | nn.ReLU(inplace=True), 76 | convbn(32, 32, 3, 1, 1, 1), 77 | nn.ReLU(inplace=True)) 78 | 79 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1,1,1) 80 | #self.layer2 = self._make_layer(BasicBlock, 64, 16, 2,1,1) original 81 | self.layer2 = self._make_layer(BasicBlock, 64, 3, 2, 1, 1) # changed to simple one 82 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1,1,1) 83 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1,1,2) 84 | 85 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64,64)), 86 | convbn(128, 32, 1, 1, 0, 1), 87 | nn.ReLU(inplace=True)) 88 | 89 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32,32)), 90 | convbn(128, 32, 1, 1, 0, 1), 91 | nn.ReLU(inplace=True)) 92 | 93 | self.branch3 = 
nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)), 94 | convbn(128, 32, 1, 1, 0, 1), 95 | nn.ReLU(inplace=True)) 96 | 97 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)), 98 | convbn(128, 32, 1, 1, 0, 1), 99 | nn.ReLU(inplace=True)) 100 | 101 | self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride = 1, bias=False)) 104 | 105 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | nn.Conv2d(self.inplanes, planes * block.expansion, 110 | kernel_size=1, stride=stride, bias=False), 111 | nn.BatchNorm2d(planes * block.expansion),) 112 | 113 | layers = [] 114 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation)) 115 | self.inplanes = planes * block.expansion 116 | for i in range(1, blocks): 117 | layers.append(block(self.inplanes, planes,1,None,pad,dilation)) 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | output = self.firstconv(x) 123 | output = self.layer1(output) 124 | output_raw = self.layer2(output) 125 | output = self.layer3(output_raw) 126 | output_skip = self.layer4(output) 127 | 128 | 129 | output_branch1 = self.branch1(output_skip) 130 | output_branch1 = F.interpolate(output_branch1, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear', align_corners=True) 131 | 132 | output_branch2 = self.branch2(output_skip) 133 | output_branch2 = F.interpolate(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear', align_corners=True) 134 | 135 | output_branch3 = self.branch3(output_skip) 136 | output_branch3 = F.interpolate(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear', align_corners=True) 137 | 138 | output_branch4 = self.branch4(output_skip) 139 | output_branch4 = F.interpolate(output_branch4, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear', align_corners=True) 140 | 141 | output_feature = torch.cat((output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1) 142 | output_feature = self.lastconv(output_feature) 143 | 144 | return output_feature 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /ptsemseg/models/tlcnetold.py: -------------------------------------------------------------------------------- 1 | ''' 2 | nad, fwd, bwd network 3 | refer to PSMN 4 | ''' 5 | 6 | from __future__ import print_function 7 | from ptsemseg.models.submodule import * 8 | from ptsemseg.models.utils import unetUpsimple 9 | 10 | class hourglass(nn.Module): 11 | def __init__(self, inplanes): 12 | super(hourglass, self).__init__() 13 | 14 | self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes * 2, kernel_size=3, stride=2, pad=1), 15 | nn.ReLU(inplace=True)) 16 | 17 | self.conv2 = convbn_3d(inplanes * 2, inplanes * 2, kernel_size=3, stride=1, pad=1) 18 | 19 | self.conv3 = nn.Sequential(convbn_3d(inplanes * 2, inplanes * 2, kernel_size=3, stride=2, pad=1), 20 | nn.ReLU(inplace=True)) 21 | 22 | self.conv4 = nn.Sequential(convbn_3d(inplanes * 2, inplanes * 2, kernel_size=3, stride=1, pad=1), 23 | nn.ReLU(inplace=True)) 24 | 25 | self.conv5 = nn.Sequential( 26 | nn.ConvTranspose3d(inplanes * 2, inplanes * 2, kernel_size=3, padding=1, output_padding=1, stride=2, 27 | bias=False), 28 | nn.BatchNorm3d(inplanes * 2)) # +conv2 29 | 30 
| self.conv6 = nn.Sequential( 31 | nn.ConvTranspose3d(inplanes * 2, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2, 32 | bias=False), 33 | nn.BatchNorm3d(inplanes)) # +x 34 | 35 | def forward(self, x, presqu, postsqu): 36 | 37 | out = self.conv1(x) # in:1/4 out:1/8 38 | pre = self.conv2(out) # in:1/8 out:1/8 39 | if postsqu is not None: 40 | pre = F.relu(pre + postsqu, inplace=True) 41 | else: 42 | pre = F.relu(pre, inplace=True) 43 | 44 | out = self.conv3(pre) # in:1/8 out:1/16 45 | out = self.conv4(out) # in:1/16 out:1/16 46 | 47 | if presqu is not None: 48 | post = F.relu(self.conv5(out) + presqu, inplace=True) # in:1/16 out:1/8 49 | else: 50 | post = F.relu(self.conv5(out) + pre, inplace=True) 51 | 52 | out = self.conv6(post) # in:1/8 out:1/4 53 | 54 | return out, pre, post 55 | 56 | 57 | class TLCNet(nn.Module): 58 | def __init__(self, maxdisp): 59 | super(TLCNet, self).__init__() 60 | self.maxdisp = maxdisp # max floor 61 | 62 | self.feature_extraction = feature_extraction() 63 | # add 64 | # self.dropout = nn.Dropout2d(p=0.5, inplace=False) 65 | # 1/4 w x h 66 | self.up1 = unetUpsimple(32, 32, True) 67 | self.up2 = unetUpsimple(32, 32, True) 68 | self.up3 = nn.Conv2d(32, 1, 1) 69 | #old 70 | self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1), 71 | nn.ReLU(inplace=True), 72 | convbn_3d(32, 32, 3, 1, 1), 73 | nn.ReLU(inplace=True)) 74 | 75 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 76 | nn.ReLU(inplace=True), 77 | convbn_3d(32, 32, 3, 1, 1)) 78 | 79 | self.dres2 = hourglass(32) 80 | 81 | self.dres3 = hourglass(32) 82 | 83 | self.dres4 = hourglass(32) 84 | 85 | self.classif1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 86 | nn.ReLU(inplace=True), 87 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 88 | 89 | self.classif2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 90 | nn.ReLU(inplace=True), 91 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 92 | 93 | self.classif3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 94 | nn.ReLU(inplace=True), 95 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | m.weight.data.normal_(0, math.sqrt(2. / n)) 101 | elif isinstance(m, nn.Conv3d): 102 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 103 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | elif isinstance(m, nn.BatchNorm3d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | elif isinstance(m, nn.Linear): 111 | m.bias.data.zero_() 112 | 113 | def forward(self, tlc): 114 | heights=tlc.size()[2] 115 | widths=tlc.size()[3] 116 | # weight sharing 117 | muximg_fea = self.feature_extraction(tlc[:, 0:3, :, :]) # need decoder N x C x 1/4 H x 1/4 W 118 | refimg_fea = self.feature_extraction(tlc[:, 3:6, :, :]) 119 | targetimg_fea = self.feature_extraction(tlc[:, 6:9, :, :]) 120 | bwdimg_fea = self.feature_extraction(tlc[:, 9:12, :, :]) # need cost fusion 121 | 122 | # learning fine border: predicting building boundary 123 | costmux = self.up1(muximg_fea) 124 | costmux = self.up2(costmux) 125 | costmux = self.up3(costmux) 126 | predmux = torch.squeeze(costmux, 1) 127 | # return predmux 128 | 129 | cost1 = self.costvolume(refimg_fea, targetimg_fea) 130 | cost2 = self.costvolume(refimg_fea, bwdimg_fea) 131 | # max pooling to aggregate different volumes 132 | cost = torch.max(cost1, cost2) 133 | # correlation layer: fail to compile 134 | # costa = correlate(refimg_fea, targetimg_fea) # 135 | # costb = correlate(refimg_fea, bwdimg_fea) # 136 | # cost = torch.cat((costa, costb, muximg_fea), 1) # concatenate mux 137 | 138 | cost = cost.contiguous() # contigous in memory 139 | 140 | # 3D CNN 141 | cost0 = self.dres0(cost) 142 | cost0 = self.dres1(cost0) + cost0 143 | 144 | out1, pre1, post1 = self.dres2(cost0, None, None) 145 | out1 = out1 + cost0 146 | 147 | out2, pre2, post2 = self.dres3(out1, pre1, post1) 148 | out2 = out2 + cost0 149 | 150 | out3, pre3, post3 = self.dres4(out2, pre1, post2) 151 | out3 = out3 + cost0 152 | 153 | cost1 = self.classif1(out1) 154 | cost2 = self.classif2(out2) + cost1 155 | cost3 = self.classif3(out3) + cost2 156 | 157 | # deep supervision 158 | if self.training: 159 | cost1 = F.interpolate(cost1, [self.maxdisp, heights, widths], mode='trilinear', align_corners=True) 160 | cost2 = F.interpolate(cost2, [self.maxdisp, heights, widths], mode='trilinear', align_corners=True) 161 | 162 | cost1 = torch.squeeze(cost1, 1) 163 | pred1 = F.softmax(cost1, dim=1) 164 | pred1 = disparityregression(self.maxdisp)(pred1) 165 | 166 | cost2 = torch.squeeze(cost2, 1) 167 | pred2 = F.softmax(cost2, dim=1) 168 | pred2 = disparityregression(self.maxdisp)(pred2) 169 | 170 | cost3 = F.interpolate(cost3, [self.maxdisp, heights, widths], mode='trilinear', align_corners=True) 171 | cost3 = torch.squeeze(cost3, 1) 172 | pred3 = F.softmax(cost3, dim=1) 173 | # For your information: This formulation 'softmax(c)' learned "similarity" 174 | # while 'softmax(-c)' learned 'matching cost' as mentioned in the paper. 175 | # However, 'c' or '-c' do not affect the performance because feature-based cost volume provided flexibility. 
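        # disparityregression (defined in submodule.py) is a soft arg-max over the
        # height/disparity bins: the softmax probabilities are multiplied by the bin
        # indices 0..maxdisp-1 and summed along that dimension, so the prediction is a
        # continuous per-pixel value rather than a hard bin decision.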
176 | pred3 = disparityregression(self.maxdisp)(pred3) 177 | 178 | # return three supervision only when training 179 | # predict the floor number 180 | if self.training: 181 | return pred1, pred2, pred3, predmux 182 | else: 183 | return pred3 184 | 185 | def costvolume(self, refimg_fea , targetimg_fea): 186 | # matching: aggregate cost by conjunction along the disparity dimension 187 | # Shape: N x 2C x D/4 x (1/4)H x (1/4)W 188 | cost = Variable( 189 | torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1] * 2, self.maxdisp // 4, refimg_fea.size()[2], 190 | refimg_fea.size()[3]).zero_()).cuda() 191 | for i in range(self.maxdisp // 4): 192 | if i > 0: 193 | cost[:, :refimg_fea.size()[1], i, :, i:] = refimg_fea[:, :, :, i:] 194 | cost[:, refimg_fea.size()[1]:, i, :, i:] = targetimg_fea[:, :, :, :-i] 195 | else: 196 | cost[:, :refimg_fea.size()[1], i, :, :] = refimg_fea 197 | cost[:, refimg_fea.size()[1]:, i, :, :] = targetimg_fea 198 | return cost 199 | 200 | 201 | class PSMNet(nn.Module): 202 | def __init__(self, maxdisp): 203 | super(PSMNet, self).__init__() 204 | self.maxdisp = maxdisp 205 | 206 | self.feature_extraction = feature_extraction() 207 | 208 | self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1), 209 | nn.ReLU(inplace=True), 210 | convbn_3d(32, 32, 3, 1, 1), 211 | nn.ReLU(inplace=True)) 212 | 213 | self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 214 | nn.ReLU(inplace=True), 215 | convbn_3d(32, 32, 3, 1, 1)) 216 | 217 | self.dres2 = hourglass(32) 218 | 219 | self.dres3 = hourglass(32) 220 | 221 | self.dres4 = hourglass(32) 222 | 223 | self.classif1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 224 | nn.ReLU(inplace=True), 225 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 226 | 227 | self.classif2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 228 | nn.ReLU(inplace=True), 229 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 230 | 231 | self.classif3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1), 232 | nn.ReLU(inplace=True), 233 | nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False)) 234 | 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. / n)) 239 | elif isinstance(m, nn.Conv3d): 240 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 241 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 242 | elif isinstance(m, nn.BatchNorm2d): 243 | m.weight.data.fill_(1) 244 | m.bias.data.zero_() 245 | elif isinstance(m, nn.BatchNorm3d): 246 | m.weight.data.fill_(1) 247 | m.bias.data.zero_() 248 | elif isinstance(m, nn.Linear): 249 | m.bias.data.zero_() 250 | 251 | def forward(self, left, right): 252 | 253 | refimg_fea = self.feature_extraction(left) 254 | targetimg_fea = self.feature_extraction(right) 255 | 256 | # matching: aggregate cost 257 | cost = Variable( 258 | torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1] * 2, self.maxdisp // 4, refimg_fea.size()[2], 259 | refimg_fea.size()[3]).zero_()).cuda() 260 | 261 | for i in range(self.maxdisp // 4): 262 | if i > 0: 263 | cost[:, :refimg_fea.size()[1], i, :, i:] = refimg_fea[:, :, :, i:] 264 | cost[:, refimg_fea.size()[1]:, i, :, i:] = targetimg_fea[:, :, :, :-i] 265 | else: 266 | cost[:, :refimg_fea.size()[1], i, :, :] = refimg_fea 267 | cost[:, refimg_fea.size()[1]:, i, :, :] = targetimg_fea 268 | cost = cost.contiguous() 269 | 270 | # 3D CNN 271 | cost0 = self.dres0(cost) 272 | cost0 = self.dres1(cost0) + cost0 273 | 274 | out1, pre1, post1 = self.dres2(cost0, None, None) 275 | out1 = out1 + cost0 276 | 277 | out2, pre2, post2 = self.dres3(out1, pre1, post1) 278 | out2 = out2 + cost0 279 | 280 | out3, pre3, post3 = self.dres4(out2, pre1, post2) 281 | out3 = out3 + cost0 282 | 283 | cost1 = self.classif1(out1) 284 | cost2 = self.classif2(out2) + cost1 285 | cost3 = self.classif3(out3) + cost2 286 | 287 | # deep supervision 288 | if self.training: 289 | cost1 = F.upsample(cost1, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear') 290 | cost2 = F.upsample(cost2, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear') 291 | 292 | cost1 = torch.squeeze(cost1, 1) 293 | pred1 = F.softmax(cost1, dim=1) 294 | pred1 = disparityregression(self.maxdisp)(pred1) 295 | 296 | cost2 = torch.squeeze(cost2, 1) 297 | pred2 = F.softmax(cost2, dim=1) 298 | pred2 = disparityregression(self.maxdisp)(pred2) 299 | 300 | cost3 = F.upsample(cost3, [self.maxdisp, left.size()[2], left.size()[3]], mode='trilinear') 301 | cost3 = torch.squeeze(cost3, 1) 302 | pred3 = F.softmax(cost3, dim=1) 303 | # For your information: This formulation 'softmax(c)' learned "similarity" 304 | # while 'softmax(-c)' learned 'matching cost' as mentioned in the paper. 305 | # However, 'c' or '-c' do not affect the performance because feature-based cost volume provided flexibility. 
306 | pred3 = disparityregression(self.maxdisp)(pred3) 307 | 308 | # return three supervision only when training 309 | if self.training: 310 | return pred1, pred2, pred3 311 | else: 312 | return pred3 -------------------------------------------------------------------------------- /ptsemseg/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ptsemseg.models.utils import unetConv2, unetUp 4 | 5 | 6 | class unet(nn.Module): 7 | def __init__( 8 | self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True 9 | ): 10 | super(unet, self).__init__() 11 | self.is_deconv = is_deconv 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | 16 | filters = [64, 128, 256, 512, 1024] 17 | filters = [int(x / self.feature_scale) for x in filters] 18 | 19 | # downsampling 20 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 21 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 22 | 23 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 24 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 25 | 26 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 27 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 28 | 29 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 30 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 31 | 32 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 33 | 34 | # upsampling 35 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 36 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 37 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 38 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 39 | 40 | # final conv (without any concat) 41 | self.final = nn.Conv2d(filters[0], n_classes, 1) 42 | 43 | def forward(self, inputs): 44 | conv1 = self.conv1(inputs) 45 | maxpool1 = self.maxpool1(conv1) 46 | 47 | conv2 = self.conv2(maxpool1) 48 | maxpool2 = self.maxpool2(conv2) 49 | 50 | conv3 = self.conv3(maxpool2) 51 | maxpool3 = self.maxpool3(conv3) 52 | 53 | conv4 = self.conv4(maxpool3) 54 | maxpool4 = self.maxpool4(conv4) 55 | 56 | center = self.center(maxpool4) 57 | up4 = self.up_concat4(conv4, center) 58 | up3 = self.up_concat3(conv3, up4) 59 | up2 = self.up_concat2(conv2, up3) 60 | up1 = self.up_concat1(conv1, up2) 61 | 62 | final = self.final(up1) 63 | 64 | return final 65 | -------------------------------------------------------------------------------- /ptsemseg/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop 4 | 5 | logger = logging.getLogger("ptsemseg") 6 | 7 | key2opt = { 8 | "sgd": SGD, 9 | "adam": Adam, 10 | "asgd": ASGD, 11 | "adamax": Adamax, 12 | "adadelta": Adadelta, 13 | "adagrad": Adagrad, 14 | "rmsprop": RMSprop, 15 | } 16 | 17 | 18 | def get_optimizer(cfg): 19 | if cfg["training"]["optimizer"] is None: 20 | logger.info("Using SGD optimizer") 21 | return SGD 22 | 23 | else: 24 | opt_name = cfg["training"]["optimizer"]["name"] 25 | if opt_name not in key2opt: 26 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name)) 27 | 28 | logger.info("Using {} optimizer".format(opt_name)) 29 | return key2opt[opt_name] 30 | -------------------------------------------------------------------------------- 
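A minimal usage sketch for get_optimizer above (not part of the repository): the "name" key picks the optimizer class from key2opt, and it is assumed here, mirroring how get_loss_function handles its params, that the remaining keys of the optimizer section are passed to the class constructor. The concrete keys "lr" and "momentum" are illustrative only.

from ptsemseg.optimizers import get_optimizer

cfg = {"training": {"optimizer": {"name": "sgd", "lr": 1.0e-2, "momentum": 0.9}}}  # hypothetical config fragment
optimizer_cls = get_optimizer(cfg)  # returns the torch.optim.SGD class, not an instance
optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}
# optimizer = optimizer_cls(model.parameters(), **optimizer_params)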
/ptsemseg/optimizers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/ptsemseg/optimizers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptsemseg/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CosineAnnealingLR 4 | 5 | from ptsemseg.schedulers.schedulers import WarmUpLR, ConstantLR, PolynomialLR 6 | 7 | logger = logging.getLogger("ptsemseg") 8 | 9 | key2scheduler = { 10 | "constant_lr": ConstantLR, 11 | "poly_lr": PolynomialLR, 12 | "multi_step": MultiStepLR, 13 | "cosine_annealing": CosineAnnealingLR, 14 | "exp_lr": ExponentialLR, 15 | } 16 | 17 | 18 | def get_scheduler(optimizer, scheduler_dict): 19 | if scheduler_dict is None: 20 | logger.info("Using No LR Scheduling") 21 | return ConstantLR(optimizer) 22 | 23 | s_type = scheduler_dict["name"] 24 | scheduler_dict.pop("name") 25 | 26 | logging.info("Using {} scheduler with {} params".format(s_type, scheduler_dict)) 27 | 28 | warmup_dict = {} 29 | if "warmup_iters" in scheduler_dict: 30 | # This can be done in a more pythonic way... 31 | warmup_dict["warmup_iters"] = scheduler_dict.get("warmup_iters", 100) 32 | warmup_dict["mode"] = scheduler_dict.get("warmup_mode", "linear") 33 | warmup_dict["gamma"] = scheduler_dict.get("warmup_factor", 0.2) 34 | 35 | logger.info( 36 | "Using Warmup with {} iters {} gamma and {} mode".format( 37 | warmup_dict["warmup_iters"], warmup_dict["gamma"], warmup_dict["mode"] 38 | ) 39 | ) 40 | 41 | scheduler_dict.pop("warmup_iters", None) 42 | scheduler_dict.pop("warmup_mode", None) 43 | scheduler_dict.pop("warmup_factor", None) 44 | 45 | base_scheduler = key2scheduler[s_type](optimizer, **scheduler_dict) 46 | return WarmUpLR(optimizer, base_scheduler, **warmup_dict) 47 | 48 | return key2scheduler[s_type](optimizer, **scheduler_dict) 49 | -------------------------------------------------------------------------------- /ptsemseg/schedulers/schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class ConstantLR(_LRScheduler): 5 | def __init__(self, optimizer, last_epoch=-1): 6 | super(ConstantLR, self).__init__(optimizer, last_epoch) 7 | 8 | def get_lr(self): 9 | return [base_lr for base_lr in self.base_lrs] 10 | 11 | 12 | class PolynomialLR(_LRScheduler): 13 | def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1): 14 | self.decay_iter = decay_iter 15 | self.max_iter = max_iter 16 | self.gamma = gamma 17 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 18 | 19 | def get_lr(self): 20 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: 21 | return [base_lr for base_lr in self.base_lrs] 22 | else: 23 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.gamma 24 | return [base_lr * factor for base_lr in self.base_lrs] 25 | 26 | 27 | class WarmUpLR(_LRScheduler): 28 | def __init__( 29 | self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1 30 | ): 31 | self.mode = mode 32 | self.scheduler = scheduler 33 | self.warmup_iters = warmup_iters 34 | self.gamma = gamma 35 | super(WarmUpLR, self).__init__(optimizer, last_epoch) 36 | 
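    # get_lr below scales the wrapped scheduler's rates during warmup: for the first
    # warmup_iters steps the factor ramps linearly from gamma up to 1 (mode="linear")
    # or stays fixed at gamma (mode="constant"); once last_epoch reaches warmup_iters,
    # the wrapped scheduler's learning rates are returned unchanged.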
37 | def get_lr(self): 38 | cold_lrs = self.scheduler.get_lr() 39 | 40 | if self.last_epoch < self.warmup_iters: 41 | if self.mode == "linear": 42 | alpha = self.last_epoch / float(self.warmup_iters) 43 | factor = self.gamma * (1 - alpha) + alpha 44 | 45 | elif self.mode == "constant": 46 | factor = self.gamma 47 | else: 48 | raise KeyError("WarmUp type {} not implemented".format(self.mode)) 49 | 50 | return [factor * base_lr for base_lr in cold_lrs] 51 | 52 | return cold_lrs 53 | -------------------------------------------------------------------------------- /ptsemseg/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc Utility functions 3 | """ 4 | import os 5 | import logging 6 | import datetime 7 | import numpy as np 8 | 9 | from collections import OrderedDict 10 | 11 | 12 | def recursive_glob(rootdir=".", suffix=""): 13 | """Performs recursive glob with given suffix and rootdir 14 | :param rootdir is the root directory 15 | :param suffix is the suffix to be searched 16 | """ 17 | return [ 18 | os.path.join(looproot, filename) 19 | for looproot, _, filenames in os.walk(rootdir) 20 | for filename in filenames 21 | if filename.endswith(suffix) 22 | ] 23 | 24 | 25 | def alpha_blend(input_image, segmentation_mask, alpha=0.5): 26 | """Alpha Blending utility to overlay RGB masks on RBG images 27 | :param input_image is a np.ndarray with 3 channels 28 | :param segmentation_mask is a np.ndarray with 3 channels 29 | :param alpha is a float value 30 | """ 31 | blended = np.zeros(input_image.size, dtype=np.float32) 32 | blended = input_image * alpha + segmentation_mask * (1 - alpha) 33 | return blended 34 | 35 | 36 | def convert_state_dict(state_dict): 37 | """Converts a state dict saved from a dataParallel module to normal 38 | module state_dict inplace 39 | :param state_dict is the loaded DataParallel model_state 40 | """ 41 | if not next(iter(state_dict)).startswith("module."): 42 | return state_dict # abort if dict is not a DataParallel model_state 43 | new_state_dict = OrderedDict() 44 | for k, v in state_dict.items(): 45 | name = k[7:] # remove `module.` 46 | new_state_dict[name] = v 47 | return new_state_dict 48 | 49 | 50 | def get_logger(logdir): 51 | logger = logging.getLogger("ptsemseg") 52 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 53 | ts = ts.replace(":", "_").replace("-", "_") 54 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 55 | hdlr = logging.FileHandler(file_path) 56 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 57 | hdlr.setFormatter(formatter) 58 | logger.addHandler(hdlr) 59 | logger.setLevel(logging.INFO) 60 | return logger 61 | -------------------------------------------------------------------------------- /runs/tlcnetu_zy3bh/V1.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/runs/tlcnetu_zy3bh/V1.rar -------------------------------------------------------------------------------- /runs/tlcnetu_zy3bh/finetune.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/runs/tlcnetu_zy3bh/finetune.tar -------------------------------------------------------------------------------- /runs/tlcnetu_zy3bh_mux.rar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/runs/tlcnetu_zy3bh_mux.rar -------------------------------------------------------------------------------- /runs/tlcnetu_zy3bh_tlc.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/runs/tlcnetu_zy3bh_tlc.rar -------------------------------------------------------------------------------- /runs/tlcnetu_zy3bh_tlcmux.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/runs/tlcnetu_zy3bh_tlcmux.rar -------------------------------------------------------------------------------- /sample/img/img_bj0216.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/img/img_bj0216.tif -------------------------------------------------------------------------------- /sample/img/img_bj0217.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/img/img_bj0217.tif -------------------------------------------------------------------------------- /sample/img/img_bj0219.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/img/img_bj0219.tif -------------------------------------------------------------------------------- /sample/img/img_bj0317.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/img/img_bj0317.tif -------------------------------------------------------------------------------- /sample/img/img_bj0322.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/img/img_bj0322.tif -------------------------------------------------------------------------------- /sample/img/img_bj0324.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/img/img_bj0324.tif -------------------------------------------------------------------------------- /sample/img/img_bj0325.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/img/img_bj0325.tif -------------------------------------------------------------------------------- /sample/lab/lab_bj0216.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab/lab_bj0216.tif -------------------------------------------------------------------------------- /sample/lab/lab_bj0217.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab/lab_bj0217.tif -------------------------------------------------------------------------------- /sample/lab/lab_bj0219.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab/lab_bj0219.tif -------------------------------------------------------------------------------- /sample/lab/lab_bj0317.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab/lab_bj0317.tif -------------------------------------------------------------------------------- /sample/lab/lab_bj0322.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab/lab_bj0322.tif -------------------------------------------------------------------------------- /sample/lab/lab_bj0324.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab/lab_bj0324.tif -------------------------------------------------------------------------------- /sample/lab/lab_bj0325.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab/lab_bj0325.tif -------------------------------------------------------------------------------- /sample/lab_floor/lab_bj0216.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab_floor/lab_bj0216.tif -------------------------------------------------------------------------------- /sample/lab_floor/lab_bj0217.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab_floor/lab_bj0217.tif -------------------------------------------------------------------------------- /sample/lab_floor/lab_bj0219.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab_floor/lab_bj0219.tif -------------------------------------------------------------------------------- /sample/lab_floor/lab_bj0317.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab_floor/lab_bj0317.tif -------------------------------------------------------------------------------- /sample/lab_floor/lab_bj0322.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab_floor/lab_bj0322.tif -------------------------------------------------------------------------------- /sample/lab_floor/lab_bj0324.tif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab_floor/lab_bj0324.tif -------------------------------------------------------------------------------- /sample/lab_floor/lab_bj0325.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/lab_floor/lab_bj0325.tif -------------------------------------------------------------------------------- /sample/tlc/tlc_bj0216.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/tlc/tlc_bj0216.tif -------------------------------------------------------------------------------- /sample/tlc/tlc_bj0217.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/tlc/tlc_bj0217.tif -------------------------------------------------------------------------------- /sample/tlc/tlc_bj0219.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/tlc/tlc_bj0219.tif -------------------------------------------------------------------------------- /sample/tlc/tlc_bj0317.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/tlc/tlc_bj0317.tif -------------------------------------------------------------------------------- /sample/tlc/tlc_bj0322.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/tlc/tlc_bj0322.tif -------------------------------------------------------------------------------- /sample/tlc/tlc_bj0324.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/tlc/tlc_bj0324.tif -------------------------------------------------------------------------------- /sample/tlc/tlc_bj0325.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauraset/BuildingHeightModel/d13d16a02dfdfa042d63492009a7cc8e46032170/sample/tlc/tlc_bj0325.tif -------------------------------------------------------------------------------- /test_zy3bh_tlcnetU.py: -------------------------------------------------------------------------------- 1 | ''' 2 | predict on images 3 | ''' 4 | import os 5 | import yaml 6 | import shutil 7 | import torch 8 | import random 9 | import argparse 10 | import numpy as np 11 | 12 | from ptsemseg.models import get_model 13 | from ptsemseg.utils import get_logger 14 | from tensorboardX import SummaryWriter 15 | from ptsemseg.loader.diy_dataset import dataloaderbh 16 | import sklearn.metrics 17 | import matplotlib.pyplot as plt 18 | import tifffile as tif 19 | 20 | 21 | def main(cfg, writer, logger): 22 | 23 | # Setup device 24 | device = torch.device(cfg["training"]["device"]) 25 
| 26 | # Setup Dataloader 27 | data_path = cfg["data"]["path"] 28 | n_classes = cfg["data"]["n_class"] 29 | n_maxdisp = cfg["data"]["n_maxdisp"] 30 | batch_size = cfg["training"]["batch_size"] 31 | epochs = cfg["training"]["epochs"] 32 | learning_rate = cfg["training"]["learning_rate"] 33 | patchsize = cfg["data"]["img_rows"] 34 | 35 | _, _, valimg, vallab = dataloaderbh(data_path) 36 | 37 | # Setup Model 38 | model = get_model(cfg["model"], n_maxdisp=n_maxdisp, n_classes=n_classes).to(device) 39 | if torch.cuda.device_count() > 1: 40 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 41 | 42 | #resume = cfg["training"]["resume"] 43 | resume = r'runs\tlcnetu_zy3bh\V1\finetune_298.tar' 44 | if os.path.isfile(resume): 45 | print("=> loading checkpoint '{}'".format(resume)) 46 | checkpoint = torch.load(resume) 47 | model.load_state_dict(checkpoint['state_dict']) 48 | # optimizer.load_state_dict(checkpoint['optimizer']) 49 | print("=> loaded checkpoint '{}' (epoch {})" 50 | .format(resume, checkpoint['epoch'])) 51 | else: 52 | print("=> no checkpoint found at resume") 53 | print("=> Will start from scratch.") 54 | 55 | model.eval() 56 | 57 | for idx, imgpath in enumerate(valimg[0:20]): 58 | name = os.path.basename(vallab[idx]) 59 | respath = os.path.join(cfg["savepath"],'pred'+name) 60 | y_true = tif.imread(vallab[idx]) 61 | y_true = y_true.astype(np.int16)*3 62 | # random crop: test and train is the same 63 | mux = tif.imread(imgpath[0])/10000 # convert to surface reflectance (SR): 0-1 64 | tlc = tif.imread(imgpath[1])/10000 # stretch to 0-1 65 | 66 | offset = mux.shape[0] - patchsize 67 | x1 = random.randint(0, offset) 68 | y1 = random.randint(0, offset) 69 | mux = mux[x1:x1 + patchsize, y1:y1 + patchsize, :] 70 | tlc = tlc[x1:x1 + patchsize, y1:y1 + patchsize, :] 71 | y_true = y_true[x1:x1 + patchsize, y1:y1 + patchsize] 72 | 73 | img = np.concatenate((mux, tlc), axis=2) 74 | img[img > 1] = 1 # ensure data range is 0-1 75 | # remove tlc 76 | # img[:,:,4:] = 0 77 | 78 | img = img.transpose((2, 0, 1)) 79 | img = np.expand_dims(img, 0) 80 | img = torch.from_numpy(img).float() 81 | y_res = model(img.to(device)) 82 | 83 | y_pred = y_res[0] # height 84 | y_pred = y_pred.cpu().detach().numpy() 85 | y_pred = np.squeeze(y_pred) 86 | rmse = myrmse(y_true, y_pred) 87 | 88 | y_seg = y_res[1] # seg 89 | y_seg = y_seg.cpu().detach().numpy() 90 | y_seg = np.argmax(y_seg.squeeze(), axis=0) # C H W=> H W 91 | precision, recall, f1score = metricsperclass(y_true, y_seg, value=1) # 92 | print('rmse: %.3f, segerror: ua %.3f, pa %.3f, f1 %.3f'%(rmse, precision, recall, f1score)) 93 | 94 | # tif.imsave((os.path.join(cfg["savepath"],'mux'+name)), mux) 95 | # tif.imsave( (os.path.join(cfg["savepath"], 'ref' + name)), y_true) 96 | # tif.imsave( (os.path.join(cfg["savepath"], 'pred' + name)), y_pred) 97 | tif.imsave((os.path.join(cfg["savepath"], 'seg' + name)), y_seg.astype(np.uint8)) 98 | 99 | # 100 | # color encode: change to the 101 | # get color info 102 | # _, color_values = get_colored_info('class_dict.csv') 103 | # prediction = color_encode(y_pred, color_values) 104 | # label = color_encode(y_true, color_values) 105 | 106 | # plt.subplot(131) 107 | # plt.title('Image', fontsize='large', fontweight='bold') 108 | # plt.imshow(mux[:, :, 0:3]/1000) 109 | # plt.subplot(132) 110 | # plt.title('Ref', fontsize='large', fontweight='bold') 111 | # plt.imshow(y_true) 112 | # # plt.subplot(143) 113 | # # plt.title('Pred', fontsize='large', fontweight='bold') 114 | # # 
plt.imshow(prediction) 115 | # plt.subplot(133) 116 | # plt.title('Pred %.3f'%scores, fontsize='large', fontweight='bold') 117 | # plt.imshow(y_pred) 118 | # plt.savefig(os.path.join(cfg["savepath"], 'fig'+name)) 119 | # plt.close() 120 | 121 | 122 | def gray2rgb(image): 123 | res=np.zeros((image.shape[0], image.shape[1], 3)) 124 | res[ :, :, 0] = image.copy() 125 | res[ :, :, 1] = image.copy() 126 | res[ :, :, 2] = image.copy() 127 | return res 128 | 129 | 130 | def metrics(y_true, y_pred, ignorevalue=0): 131 | y_true = y_true.flatten() 132 | y_pred = y_pred.flatten() 133 | maskid = np.where(y_true!=ignorevalue) 134 | y_true = y_true[maskid] 135 | y_pred = y_pred[maskid] 136 | accuracy = sklearn.metrics.accuracy_score(y_true, y_pred) 137 | kappa = sklearn.metrics.cohen_kappa_score(y_true, y_pred) 138 | f1_micro = sklearn.metrics.f1_score(y_true, y_pred, average="micro") 139 | f1_macro = sklearn.metrics.f1_score(y_true, y_pred, average="macro") 140 | f1_weighted = sklearn.metrics.f1_score(y_true, y_pred, average="weighted") 141 | recall_micro = sklearn.metrics.recall_score(y_true, y_pred, average="micro") 142 | recall_macro = sklearn.metrics.recall_score(y_true, y_pred, average="macro") 143 | recall_weighted = sklearn.metrics.recall_score(y_true, y_pred, average="weighted") 144 | precision_micro = sklearn.metrics.precision_score(y_true, y_pred, average="micro") 145 | precision_macro = sklearn.metrics.precision_score(y_true, y_pred, average="macro") 146 | precision_weighted = sklearn.metrics.precision_score(y_true, y_pred, average="weighted") 147 | 148 | return dict( 149 | accuracy=accuracy, 150 | kappa=kappa, 151 | f1_micro=f1_micro, 152 | f1_macro=f1_macro, 153 | f1_weighted=f1_weighted, 154 | recall_micro=recall_micro, 155 | recall_macro=recall_macro, 156 | recall_weighted=recall_weighted, 157 | precision_micro=precision_micro, 158 | precision_macro=precision_macro, 159 | precision_weighted=precision_weighted, 160 | ) 161 | 162 | def myrmse(y_true, ypred): 163 | diff=y_true.flatten()-ypred.flatten() 164 | return np.sqrt(np.mean(diff*diff)) 165 | 166 | 167 | def metricsperclass(y_true, y_pred, value): 168 | y_pred = y_pred.flatten() 169 | y_true = np.where(y_true>0, np.ones_like(y_true), np.zeros_like(y_true)).flatten() 170 | 171 | tp=len(np.where((y_true==value) & (y_pred==value))[0]) 172 | tn=len(np.where(y_true==value)[0]) 173 | fn = len(np.where(y_pred == value)[0]) 174 | precision = tp/(1e-10+fn) 175 | recall = tp/(1e-10+tn) 176 | f1score = 2*precision*recall/(precision+recall+1e-10) 177 | return precision, recall, f1score 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser(description="config") 182 | parser.add_argument( 183 | "--config", 184 | nargs="?", 185 | type=str, 186 | default="configs/tlcnetu_zy3bh.yml", 187 | help="Configuration file to use", 188 | ) 189 | 190 | args = parser.parse_args() 191 | 192 | with open(args.config) as fp: 193 | cfg = yaml.load(fp) 194 | 195 | #run_id = random.randint(1, 100000) 196 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], "V1") 197 | writer = SummaryWriter(log_dir=logdir) 198 | 199 | print("RUNDIR: {}".format(logdir)) 200 | shutil.copy(args.config, logdir) 201 | 202 | logger = get_logger(logdir) 203 | logger.info("Let the games begin") 204 | 205 | main(cfg, writer, logger) 206 | -------------------------------------------------------------------------------- /train_zy3bh_tlcnetU_loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | date: 
2020.7.27 3 | author: yinxia cao 4 | function: train building height using unet method 5 | @Update: 2020.10.8 uncertainty weighting multi-loss 6 | ''' 7 | 8 | import os 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '7' 10 | 11 | import yaml 12 | import shutil 13 | import torch 14 | import random 15 | import argparse 16 | import numpy as np 17 | from tqdm import tqdm 18 | 19 | from torch.utils import data 20 | from ptsemseg.models import get_model 21 | from ptsemseg.utils import get_logger 22 | from tensorboardX import SummaryWriter #change tensorboardX 23 | from ptsemseg.loader.diy_dataset import dataloaderbh 24 | from ptsemseg.loader.diyloader import myImageFloder 25 | import torch.nn.functional as F 26 | # from segmentation_models_pytorch_revised import DeepLabV3Plus 27 | 28 | def main(cfg, writer, logger): 29 | 30 | # Setup seeds 31 | torch.manual_seed(cfg.get("seed", 1337)) 32 | torch.cuda.manual_seed(cfg.get("seed", 1337)) 33 | np.random.seed(cfg.get("seed", 1337)) 34 | random.seed(cfg.get("seed", 1337)) 35 | 36 | # Setup device 37 | device = torch.device(cfg["training"]["device"]) 38 | 39 | # Setup Dataloader 40 | data_path = cfg["data"]["path"] 41 | n_classes = cfg["data"]["n_class"] 42 | n_maxdisp = cfg["data"]["n_maxdisp"] 43 | batch_size = cfg["training"]["batch_size"] 44 | epochs = cfg["training"]["epochs"] 45 | 46 | # Load dataset 47 | trainimg, trainlab, valimg, vallab = dataloaderbh(data_path) 48 | traindataloader = torch.utils.data.DataLoader( 49 | myImageFloder(trainimg, trainlab, True), 50 | batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) 51 | 52 | testdataloader = torch.utils.data.DataLoader( 53 | myImageFloder(valimg, vallab), 54 | batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True) 55 | 56 | # Setup Model 57 | # model = DeepLabV3Plus("resnet18", encoder_weights='imagenet' ) 58 | model = get_model(cfg["model"], n_maxdisp=n_maxdisp, n_classes=n_classes).to(device) 59 | if torch.cuda.device_count() > 1: 60 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 61 | 62 | # print the model 63 | start_epoch = 0 64 | resume = cfg["training"]["resume"] 65 | if os.path.isfile(resume): 66 | print("=> loading checkpoint '{}'".format(resume)) 67 | checkpoint = torch.load(resume) 68 | model.load_state_dict(checkpoint['state_dict']) 69 | # optimizer.load_state_dict(checkpoint['optimizer']) 70 | print("=> loaded checkpoint '{}' (epoch {})" 71 | .format(resume, checkpoint['epoch'])) 72 | start_epoch = checkpoint['epoch'] 73 | else: 74 | print("=> no checkpoint found at resume") 75 | print("=> Will start from scratch.") 76 | 77 | # define task-dependent log_variance 78 | log_var_a = torch.zeros((1,), requires_grad=True) 79 | log_var_b = torch.zeros((1,), requires_grad=True) 80 | # log_var_c = torch.tensor(1.) # fix the weight of semantic segmentation 81 | log_var_c = torch.zeros((1,), requires_grad=True) 82 | 83 | # get all parameters (model parameters + task dependent log variances) 84 | params = ([p for p in model.parameters()] + [log_var_a] + [log_var_b] + [log_var_c]) 85 | 86 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 87 | optimizer = torch.optim.Adam(params, lr=cfg["training"]["learning_rate"], betas=(0.9, 0.999)) 88 | 89 | criterion = 'rmse' #useless 90 | 91 | for epoch in range(epochs-start_epoch): 92 | epoch = start_epoch + epoch 93 | adjust_learning_rate(optimizer, epoch) 94 | model.train() 95 | train_loss = list() 96 | train_mse = 0. 
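        # train_mse accumulates the plain, unweighted MSE of the final height head so
        # that an interpretable epoch RMSE can be reported alongside the learned-weight
        # multi-task loss; `count` below tracks how many samples it is averaged over,
        # and vara/varb/varc record the per-batch values of the learned log-variances.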
97 | count = 0 98 | print_count = 0 99 | vara = list() 100 | varb = list() 101 | varc = list() 102 | # with tqdm(enumerate(dataloader), total=len(dataloader), leave=True) as iterator: 103 | for x, y_true in tqdm(traindataloader): 104 | x = x.to(device, non_blocking=True) 105 | y_true = y_true.to(device, non_blocking=True) 106 | 107 | ypred1, ypred2, ypred3, ypred4 = model.forward(x) 108 | y_truebi = torch.where(y_true > 0, torch.ones_like(y_true), torch.zeros_like(y_true)) 109 | y_truebi = y_truebi.long().view(-1).to(device, non_blocking=True) 110 | ypred3 = ypred3.transpose(1, 2).transpose(2, 3).contiguous().view(-1, 2) 111 | loss_mse = F.mse_loss(ypred4 , y_true, reduction='mean').cpu().detach().numpy() 112 | loss = loss_weight([ypred1, ypred2, ypred3, ypred4], 113 | [y_true, y_truebi], 114 | [log_var_a.to(device), log_var_b.to(device), log_var_c.to(device)]) 115 | 116 | optimizer.zero_grad() 117 | loss.backward() 118 | optimizer.step() 119 | 120 | train_loss.append(loss.cpu().detach().numpy()) 121 | train_mse += loss_mse*x.shape[0] 122 | count += x.shape[0] 123 | 124 | vara.append(log_var_a.cpu().detach().numpy()) 125 | varb.append(log_var_b.cpu().detach().numpy()) 126 | varc.append(log_var_c.cpu().detach().numpy()) 127 | 128 | if print_count%20 ==0: 129 | print('training loss %.3f, rmse %.3f, vara %.2f, b %.2f, c %.2f' % 130 | (loss.item(), np.sqrt(loss_mse), log_var_a, log_var_b, log_var_c)) 131 | print_count += 1 132 | 133 | train_rmse = np.sqrt(train_mse/count) 134 | # test 135 | val_rmse = test_epoch(model, criterion, 136 | testdataloader, device, epoch) 137 | print("epoch %d rmse: train %.3f, test %.3f" % (epoch, train_rmse, val_rmse)) 138 | 139 | # save models 140 | if epoch % 2 == 0: # every five internval 141 | savefilename = os.path.join(logdir, 'finetune_'+str(epoch)+'.tar') 142 | torch.save({ 143 | 'epoch': epoch, 144 | 'state_dict': model.state_dict(), 145 | 'train_loss': np.mean(train_loss), 146 | 'test_loss': np.mean(val_rmse), #*100 147 | }, savefilename) 148 | # 149 | writer.add_scalar('train loss', 150 | (np.mean(train_loss)), #average 151 | epoch) 152 | writer.add_scalar('train rmse', 153 | (np.mean(train_rmse)), #average 154 | epoch) 155 | writer.add_scalar('val rmse', 156 | (np.mean(val_rmse)), #average 157 | epoch) 158 | writer.add_scalar('weight a', 159 | (np.mean(vara)), #average 160 | epoch) 161 | writer.add_scalar('weight b', 162 | (np.mean(varb)), #average 163 | epoch) 164 | writer.add_scalar('weight c', 165 | (np.mean(varc)), #average 166 | epoch) 167 | writer.close() 168 | 169 | 170 | def adjust_learning_rate(optimizer, epoch): 171 | if epoch <= 200: 172 | lr = cfg["training"]["learning_rate"] 173 | elif epoch <=250: 174 | lr = cfg["training"]["learning_rate"] * 0.1 175 | elif epoch <=300: 176 | lr = cfg["training"]["learning_rate"] * 0.01 177 | else: 178 | lr = cfg["training"]["learning_rate"] * 0.025 # 0.0025 before 179 | print(lr) 180 | for param_group in optimizer.param_groups: 181 | param_group['lr'] = lr 182 | return lr #added 183 | 184 | 185 | # def rmse(disp, gt): 186 | # errmap = torch.sqrt(torch.pow((disp - gt), 2).mean()) 187 | # return errmap # rmse 188 | 189 | 190 | # def mse(disp, gt): 191 | # return (disp-gt)**2. 192 | 193 | # custom loss 194 | def loss_weight_ori(y_pred, y_true, log_vars): 195 | loss = 0 196 | for i in range(len(y_pred)): 197 | precision = torch.exp(-log_vars[i]) 198 | diff = (y_pred[i]-y_true[i])**2. 
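        # Homoscedastic-uncertainty weighting (cf. Kendall et al., CVPR 2018): each
        # weighted task contributes exp(-log_var_i) * L_i + log_var_i, so a task with
        # a large learned log-variance is down-weighted, while the additive log_var_i
        # term keeps the variances from growing without bound. loss_weight() below
        # applies the same rule to the three height (MSE) terms and leaves the
        # segmentation cross-entropy unweighted.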
199 | loss += torch.sum(precision * diff + log_vars[i], -1) 200 | return torch.mean(loss) 201 | 202 | 203 | def loss_weight(y_pred, y_true, log_vars): 204 | #loss 0 tlc height 205 | precision0 = torch.exp(-log_vars[0]) 206 | diff0 = F.mse_loss(y_pred[0],y_true[0],reduction='mean') 207 | loss0 = diff0*precision0 + log_vars[0] 208 | #loss 1 mux height 209 | precision1 = torch.exp(-log_vars[1]) 210 | diff1 = F.mse_loss(y_pred[1], y_true[0], reduction='mean') 211 | loss1 = diff1*precision1 + log_vars[1] 212 | #loss 2 mux segmentation 213 | loss2 = F.cross_entropy(y_pred[2], y_true[1], reduction='mean') 214 | #loss 3 final height 215 | precision3 = torch.exp(-log_vars[2]) 216 | diff3 = F.mse_loss(y_pred[3], y_true[0], reduction='mean') 217 | loss3 = diff3*precision3 + log_vars[2] 218 | return loss0+loss1+loss3+loss2 219 | 220 | 221 | def crossentrop(ypred, y_true, device='cuda'): 222 | y_truebi = torch.where(y_true > 0, torch.ones_like(y_true), torch.zeros_like(y_true)) 223 | y_truebi = y_truebi.long().view(-1).to(device) 224 | ypred = ypred.transpose(1, 2).transpose(2, 3).contiguous().view(-1, 2) 225 | return F.cross_entropy(ypred, y_truebi) 226 | 227 | 228 | def test_epoch(model, criterion, dataloader, device, epoch): 229 | model.eval() 230 | with torch.no_grad(): 231 | losses = 0. 232 | count = 0 233 | for x, y_true in tqdm(dataloader): 234 | x = x.to(device, non_blocking =True) 235 | y_true = y_true.to(device, non_blocking =True) 236 | 237 | y_pred, _ = model.forward(x) 238 | lossv = F.mse_loss(y_pred, y_true, reduction='mean').cpu().detach().numpy() 239 | losses += lossv*x.shape[0] 240 | count += x.shape[0] 241 | 242 | lossfinal = np.sqrt(losses/count) 243 | print('test error %.3f rmse' % lossfinal) 244 | return lossfinal 245 | 246 | 247 | if __name__ == "__main__": 248 | parser = argparse.ArgumentParser(description="config") 249 | parser.add_argument( 250 | "--config", 251 | nargs="?", 252 | type=str, 253 | default="configs/tlcnetu_zy3bh.yml", 254 | help="Configuration file to use", 255 | ) 256 | 257 | args = parser.parse_args() 258 | 259 | with open(args.config) as fp: 260 | cfg = yaml.load(fp, Loader=yaml.FullLoader) 261 | 262 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], "V1") 263 | writer = SummaryWriter(log_dir=logdir) 264 | 265 | print("RUNDIR: {}".format(logdir)) 266 | shutil.copy(args.config, logdir) 267 | 268 | logger = get_logger(logdir) 269 | logger.info("Let the games begin") 270 | 271 | main(cfg, writer, logger) 272 | --------------------------------------------------------------------------------
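A minimal, self-contained sketch of the uncertainty-weighted multi-task objective used in train_zy3bh_tlcnetU_loss.py. This is illustrative only: the toy backbone, heads, tensor shapes, and data below are invented for demonstration, and only the exp(-log_var) * loss + log_var weighting mirrors loss_weight() in that script.

# Illustrative sketch (not a repository file): one uncertainty-weighted
# multi-task training step in the style of loss_weight() above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

backbone = torch.nn.Conv2d(4, 8, 3, padding=1)   # toy shared encoder
height_head = torch.nn.Conv2d(8, 1, 1)           # toy height-regression head
seg_head = torch.nn.Conv2d(8, 2, 1)              # toy 2-class segmentation head

# learned log-variance acting as the weight of the height loss
log_var_h = torch.zeros(1, requires_grad=True)

# the log-variance is optimized jointly with the model parameters
params = (list(backbone.parameters()) + list(height_head.parameters())
          + list(seg_head.parameters()) + [log_var_h])
optimizer = torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999))

x = torch.randn(2, 4, 64, 64)                    # fake 4-band patches
y_height = torch.rand(2, 1, 64, 64)              # fake normalized height target
y_seg = (y_height.squeeze(1) > 0.5).long()       # fake building mask

for step in range(3):
    feat = backbone(x)
    pred_h = height_head(feat)
    pred_s = seg_head(feat)

    loss_h = F.mse_loss(pred_h, y_height)
    loss_s = F.cross_entropy(pred_s, y_seg)
    # height term weighted by its learned homoscedastic uncertainty;
    # segmentation term enters unweighted, as in loss_weight()
    loss = torch.exp(-log_var_h) * loss_h + log_var_h + loss_s

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(step, float(loss), float(log_var_h))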