├── LICENSE
├── README.md
├── ckp
│   └── README.TXT
├── comparator.py
├── config.py
├── dataloaders
│   ├── __init__.py
│   ├── custom_transforms.py
│   └── dataset.py
├── erf.py
├── erf
│   ├── README.TXT
│   └── selected_pixel_all.json
├── example
│   ├── test
│   │   ├── Images
│   │   │   ├── Barren_07_2524.tif
│   │   │   ├── Forest_01_2715.tif
│   │   │   ├── GrassCrops_09_2014.tif
│   │   │   ├── Shrubland_10_1706.tif
│   │   │   └── Urban_11_1326.tif
│   │   ├── Masks
│   │   │   ├── Barren_07_2524.tif
│   │   │   ├── Forest_01_2715.tif
│   │   │   ├── GrassCrops_09_2014.tif
│   │   │   ├── Shrubland_10_1706.tif
│   │   │   └── Urban_11_1326.tif
│   │   └── test.txt
│   ├── train
│   │   ├── Images
│   │   │   ├── Barren_00_0214.tif
│   │   │   ├── Barren_01_0213.tif
│   │   │   ├── Forest_03_0209.tif
│   │   │   ├── Forest_05_0415.tif
│   │   │   ├── GrassCrops_00_1014.tif
│   │   │   ├── GrassCrops_02_0608.tif
│   │   │   ├── Shrubland_00_0614.tif
│   │   │   ├── Shrubland_02_1021.tif
│   │   │   ├── SnowIce_00_1415.tif
│   │   │   ├── SnowIce_01_0909.tif
│   │   │   ├── Urban_00_0515.tif
│   │   │   ├── Urban_01_0609.tif
│   │   │   ├── Water_01_0922.tif
│   │   │   ├── Water_02_1010.tif
│   │   │   ├── Wetlands_00_1426.tif
│   │   │   └── Wetlands_01_0416.tif
│   │   ├── Masks
│   │   │   ├── Barren_00_0214.tif
│   │   │   ├── Barren_01_0213.tif
│   │   │   ├── Barren_03_0207.tif
│   │   │   ├── Barren_08_0207.tif
│   │   │   ├── Barren_10_0206.tif
│   │   │   ├── Barren_11_0207.tif
│   │   │   ├── Forest_03_0209.tif
│   │   │   ├── Forest_05_0415.tif
│   │   │   ├── Forest_06_0515.tif
│   │   │   ├── Forest_07_1209.tif
│   │   │   ├── Forest_08_0805.tif
│   │   │   ├── Forest_10_1028.tif
│   │   │   ├── GrassCrops_00_1014.tif
│   │   │   ├── GrassCrops_02_0608.tif
│   │   │   ├── GrassCrops_05_0806.tif
│   │   │   ├── GrassCrops_07_1723.tif
│   │   │   ├── GrassCrops_08_1326.tif
│   │   │   ├── GrassCrops_11_1014.tif
│   │   │   ├── Shrubland_00_0614.tif
│   │   │   ├── Shrubland_02_1021.tif
│   │   │   ├── Shrubland_03_0807.tif
│   │   │   ├── Shrubland_04_0711.tif
│   │   │   ├── Shrubland_05_0621.tif
│   │   │   ├── Shrubland_08_1222.tif
│   │   │   ├── SnowIce_00_1415.tif
│   │   │   ├── SnowIce_01_0909.tif
│   │   │   ├── SnowIce_02_0712.tif
│   │   │   ├── SnowIce_05_0708.tif
│   │   │   ├── SnowIce_08_0823.tif
│   │   │   ├── SnowIce_10_1022.tif
│   │   │   ├── Urban_00_0515.tif
│   │   │   ├── Urban_01_0609.tif
│   │   │   ├── Urban_02_0814.tif
│   │   │   ├── Urban_05_0608.tif
│   │   │   ├── Urban_06_1113.tif
│   │   │   ├── Urban_08_0817.tif
│   │   │   ├── Water_01_0922.tif
│   │   │   ├── Water_02_1010.tif
│   │   │   ├── Water_04_0726.tif
│   │   │   ├── Water_05_0517.tif
│   │   │   ├── Water_07_0706.tif
│   │   │   ├── Water_08_0606.tif
│   │   │   ├── Wetlands_00_1426.tif
│   │   │   ├── Wetlands_01_0416.tif
│   │   │   ├── Wetlands_04_1107.tif
│   │   │   ├── Wetlands_05_0713.tif
│   │   │   ├── Wetlands_06_1127.tif
│   │   │   └── Wetlands_11_1217.tif
│   │   └── train.txt
│   └── val
│       ├── Images
│       │   ├── Barren_05_0208.tif
│       │   ├── Forest_00_1008.tif
│       │   ├── GrassCrops_01_0413.tif
│       │   ├── Shrubland_01_0508.tif
│       │   ├── SnowIce_11_1123.tif
│       │   ├── Urban_03_0619.tif
│       │   ├── Water_11_0728.tif
│       │   └── Wetlands_07_1826.tif
│       ├── Masks
│       │   ├── Barren_05_0208.tif
│       │   ├── Forest_00_1008.tif
│       │   ├── GrassCrops_01_0413.tif
│       │   ├── Shrubland_01_0508.tif
│       │   ├── SnowIce_11_1123.tif
│       │   ├── Urban_03_0619.tif
│       │   ├── Water_11_0728.tif
│       │   └── Wetlands_07_1826.tif
│       └── val.txt
├── inference-mix
│   └── README.TXT
├── inference.py
├── inference
│   └── README.TXT
├── interpretation
│   └── calculate_erf.py
├── model
│   ├── __init__.py
│   ├── deeplab
│   │   ├── __init__.py
│   │   ├── aspp.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── drn.cpython-37.pyc
│   │   │   │   ├── mobilenet.cpython-37.pyc
│   │   │   │   ├── resnet.cpython-37.pyc
│   │   │   │   └── xception.cpython-37.pyc
│   │   │   ├── drn.py
│   │   │   ├── mobilenet.py
│   │   │   ├── resnet.py
│   │   │   └── xception.py
│   │   ├── decoder.py
│   │   ├── deeplab.py
│   │   └── sync_batchnorm
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   ├── batchnorm.cpython-37.pyc
│   │       │   ├── comm.cpython-37.pyc
│   │       │   └── replicate.cpython-37.pyc
│   │       ├── batchnorm.py
│   │       ├── comm.py
│   │       ├── replicate.py
│   │       └── unittest.py
│   ├── mfcnn
│   │   ├── __init__.py
│   │   ├── mfcnn_model.py
│   │   └── mfcnn_parts.py
│   ├── mscff
│   │   ├── __init__.py
│   │   ├── mscff_model.py
│   │   └── mscff_parts.py
│   ├── munet
│   │   ├── __init__.py
│   │   ├── munet_model.py
│   │   └── munet_parts.py
│   ├── tlnet
│   │   ├── __init__.py
│   │   ├── tlnet_model.py
│   │   └── tlnet_parts.py
│   ├── unet
│   │   ├── __init__.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── unet_1
│   │   ├── __init__.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── unet_2
│   │   ├── __init__.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── unet_3
│   │   ├── __init__.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── unet_dilation
│   │   ├── __init__.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── unets1
│   │   ├── __init__.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── unets2
│   │   ├── __init__.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   └── unets3
│       ├── __init__.py
│       ├── unet_model.py
│       └── unet_parts.py
├── train.py
└── utils
    ├── __pycache__
    │   ├── calculate_weights.cpython-37.pyc
    │   ├── f_boundary.cpython-37.pyc
    │   ├── img_saver.cpython-37.pyc
    │   ├── loss.cpython-37.pyc
    │   ├── lr_scheduler.cpython-37.pyc
    │   ├── metrics.cpython-37.pyc
    │   ├── net_convert.cpython-37.pyc
    │   ├── saver.cpython-37.pyc
    │   ├── summaries.cpython-37.pyc
    │   └── tracker.cpython-37.pyc
    ├── calculate_weights.py
    ├── f_boundary.py
    ├── img_saver.py
    ├── loss.py
    ├── lr_scheduler.py
    ├── metrics.py
    ├── net_convert.py
    ├── saver.py
    ├── summaries.py
    └── tracker.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 LK-Peng
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # CNN-based-Cloud-Detection-Methods
  2 | ## Paper: Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery
  3 | 
  4 | ### TODO
  5 | - [x] Support different convolutional neural networks for cloud detection
  6 | - [x] Support calculation of the effective receptive field
  7 | - [x] Multi-GPU training
  8 | 
  9 | 
 10 | 
 11 | * The supported networks are as follows:
 12 | 
 13 | |Method|Reference|
 14 | |:-:|:-:|
 15 | |TL-Net|[Transferring deep learning models for cloud detection between Landsat-8 and Proba-V](https://www.sciencedirect.com/science/article/pii/S0924271619302801)|
 16 | |MUNet|[Multi-sensor cloud and cloud shadow segmentation with a convolutional neural network](https://www.sciencedirect.com/science/article/pii/S0034425719302159)|
 17 | |UNet|[U-net: Convolutional networks for biomedical image segmentation](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28)|
 18 | |MF-CNN|[Cloud detection in remote sensing images based on multiscale features-convolutional neural network](https://ieeexplore.ieee.org/document/8625476)|
 19 | |MSCFF|[Deep learning based cloud detection for medium and high resolution remote sensing images of different sensors](https://www.sciencedirect.com/science/article/pii/S0924271619300565)|
 20 | |DeepLabv3+|[Encoder-decoder with atrous separable convolution for semantic image segmentation](https://arxiv.org/abs/1802.02611)|
 21 | |UNet-1|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 22 | |UNet-2|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 23 | |UNet-3|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 24 | |UNet-D2|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 25 | |UNet-D4|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 26 | |UNet-S3|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 27 | |UNet-S2|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 28 | |UNet-S1|Understanding the Role of Receptive Field of Convolutional Neural Network for Cloud Detection in Landsat 8 OLI Imagery|
 29 | 
 30 | 
 31 | * The trained models can be downloaded from the links below:
 32 | 
 33 | |Input Band Number|Band|Download Link|Password|
 34 | |:-:|:-:|:-:|:-:|
 35 | |8|red/green/blue/NIR/SWIR1/SWIR2/cirrus/TIR1|[Baidu Netdisk](https://pan.baidu.com/s/1obbeQlKybN40EW5lO6XUqQ?pwd=3tre)|3tre|
 36 | |6|red/green/blue/NIR/SWIR1/SWIR2|[Baidu Netdisk](https://pan.baidu.com/s/1xAf6PnOfokroxmcQlIUhUA?pwd=m6nt)|m6nt|
 37 | |4|red/green/blue/NIR|[Baidu Netdisk](https://pan.baidu.com/s/1nYHaIWZ0aA3MsxqHdviG5Q?pwd=qy48)|qy48|
 38 | 
 39 | The trained models for the 8-channel input can also be downloaded from **[Google Drive](https://drive.google.com/drive/folders/1Av1Gl3WEug_G2UC4WZgddI1YVdvCiwfW?usp=sharing)**.
 40 | 
 41 | 
 42 | 
 43 | ### Introduction
 44 | This is a PyTorch (1.7.1) implementation of **various convolutional neural networks (CNNs) for cloud detection in Landsat 8 OLI imagery**. Currently, we train these networks on the [L8 Biome](https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data) dataset. **The related paper** aims to **understand the role of the receptive field of CNNs for cloud detection in Landsat 8 OLI imagery** and is under review.
 45 | 
 46 | 
 47 | 
 48 | ### Installation
 49 | The code was tested with **Anaconda** and **Python 3.7.3**.
 50 | 
 51 | 0. For the **PyTorch** dependency, see [pytorch.org](https://pytorch.org/) for more details.
 52 | 
 53 | 1. For the **Captum** dependency, used for computing the effective receptive field, see [captum.ai](https://captum.ai/) for more details.
 54 | 
 55 | 2. For the **GDAL** dependency, used for reading and writing raster data, use version 2.3.3.
 56 | 
 57 | 
 58 | 
 59 | ### Training
 60 | Follow the steps below to train your model:
 61 | 
 62 | 0. Configure your dataset paths in [config.py](https://github.com/LK-Peng/CNN-based-Cloud-Detection-Methods/blob/main/config.py)
 63 | ```Python
 64 | def get_config_tr(net_name):
 65 |     ...
 66 |     parser.add_argument('--train-root', type=str,
 67 |                         default='./example/train/Images',
 68 |                         help='image root of train set')
 69 |     parser.add_argument('--train-list', type=str,
 70 |                         default='./example/train/train.txt',
 71 |                         help='image list of train set')
 72 |     parser.add_argument('--val-root', type=str,
 73 |                         default='./example/val/Images',
 74 |                         help='image root of validation set')
 75 |     parser.add_argument('--val-list', type=str,
 76 |                         default='./example/val/val.txt',
 77 |                         help='image list of validation set')
 78 | ```
 79 | 
 80 | 1. Configure the network you want to use in [config.py](https://github.com/LK-Peng/CNN-based-Cloud-Detection-Methods/blob/main/config.py)
 81 | ```Python
 82 | def get_config_tr(net_name):
 83 |     ...
 84 |     parser.add_argument('--net', type=str, default='{}'.format(net_name),
 85 |                         choices=['DeeplabV3Plus', 'MFCNN', 'MSCFF', 'MUNet',
 86 |                                  'TLNet', 'UNet', 'UNet-3', 'UNet-2', 'UNet-1',
 87 |                                  'UNet-dilation', 'UNetS3', 'UNetS2', 'UNetS1'],
 88 |                         help='network name (default: ?)')
 89 | ```
 90 | 
 91 | or [train.py](https://github.com/LK-Peng/CNN-based-Cloud-Detection-Methods/blob/main/train.py)
 92 | 
 93 | ```Python
 94 | def main():
 95 |     # choices=['DeeplabV3Plus', 'MFCNN', 'MSCFF', 'MUNet', 'TLNet', 'UNet', 'UNet-3', 'UNet-2', 'UNet-1', 'UNet-dilation', 'UNetS3', 'UNetS2', 'UNetS1']
 96 |     args = get_config_tr('TLNet')
 97 | ```
 98 | 
 99 | 2. Run the training script
100 | ```Shell
101 | python train.py
102 | ```
103 | 
104 | ### Others
105 | 0. [inference.py](https://github.com/LK-Peng/CNN-based-Cloud-Detection-Methods/blob/main/inference.py) is used for predicting cloud detection results and outputting accuracies.
106 | 
107 | 1. [erf.py](https://github.com/LK-Peng/CNN-based-Cloud-Detection-Methods/blob/main/erf.py) is used for computing the effective receptive field.
108 | 
109 | 2. [comparator.py](https://github.com/LK-Peng/CNN-based-Cloud-Detection-Methods/blob/main/comparator.py) is used for computing the accuracies of the predicted results (see the invocation sketch below).
110 | 
111 | 
112 | 
113 | ### Acknowledgement
114 | * [DeepLab-V3-Plus](https://github.com/jfzhang95/pytorch-deeplab-xception)
115 | 
116 | * [UNet](https://github.com/milesial/Pytorch-UNet)
--------------------------------------------------------------------------------
/ckp/README.TXT:
--------------------------------------------------------------------------------
1 | Folder to store optimal models.
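For orientation, here is a minimal invocation sketch for `comparator.py` (reproduced in full below). The paths are illustrative, not part of the repository: as written, the script's `__main__` block overwrites `--pre-root` and `--out-file` with its hard-coded `net_root` dictionary, so point that dictionary at your own prediction folders before running.

```Shell
# hypothetical run; assumes predicted masks exist under ./inference/<net-name>
# and ground-truth masks under ./example/test/Masks
python comparator.py --gt-root ./example/test/Masks --num-proc 4 --batch-size 24
```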
--------------------------------------------------------------------------------
/comparator.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import json
  3 | import time
  4 | import argparse
  5 | import numpy as np
  6 | from tqdm import tqdm
  7 | from multiprocessing import Pool
  8 | from torch.utils.data import DataLoader
  9 | 
 10 | from dataloaders.dataset import MaskSet
 11 | from utils.metrics import Evaluator, BoundaryEvaluator
 12 | 
 13 | 
 14 | class Comparator(object):
 15 |     def __init__(self, args):
 16 |         self.args = args
 17 | 
 18 |         # define dataloader
 19 |         kwargs = {'num_workers': args.workers, 'pin_memory': True}
 20 |         dataset = MaskSet(args)
 21 |         self.mask_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, **kwargs)
 22 | 
 23 |         # define multiprocess
 24 |         if args.num_proc:
 25 |             self.p = Pool(processes=args.num_proc)
 26 |         else:
 27 |             self.p = None
 28 | 
 29 |         # Define Evaluator
 30 |         self.evaluator = Evaluator(self.args.num_classes)
 31 |         self.boundaryevaluator_3 = BoundaryEvaluator(self.args.num_classes, self.p, self.args.num_proc, bound_th=3)
 32 |         self.boundaryevaluator_5 = BoundaryEvaluator(self.args.num_classes, self.p, self.args.num_proc, bound_th=5)
 33 | 
 34 |     def cal_metric(self):
 35 |         tbar = tqdm(self.mask_loader, desc='\r')
 36 |         num_mask = len(self.mask_loader.dataset)
 37 |         print('numImages: {}'.format(num_mask))
 38 |         # metric_img = dict()
 39 |         for i, sample in enumerate(tbar):
 40 |             gt_mask, pre_mask = sample['gt'].numpy(), sample['pre'].numpy()
 41 | 
 42 |             self.evaluator.add_batch(gt_mask, pre_mask)
 43 |             self.boundaryevaluator_3.add_batch(gt_mask, pre_mask)
 44 |             self.boundaryevaluator_5.add_batch(gt_mask, pre_mask)
 45 | 
 46 |         metric_dct = {
 47 |             'PA': self.evaluator.Pixel_Accuracy(),
 48 |             'MPA': self.evaluator.Pixel_Accuracy_Class(),
 49 |             'MIoU': self.evaluator.Mean_Intersection_over_Union(),
 50 |             'FWIoU': self.evaluator.Frequency_Weighted_Intersection_over_Union(),
 51 |             'Precision': self.evaluator.Precision(),
 52 |             'Recall': self.evaluator.Recall(),
 53 |             'F1': self.evaluator.F_score(),
 54 |             'F_boundary_3': self.boundaryevaluator_3.F_score_boundary().tolist(),
 55 |             'Pr_boundary_3': self.boundaryevaluator_3.Precision_boundary().tolist(),
 56 |             'Re_boundary_3': self.boundaryevaluator_3.Recall_boundary().tolist(),
 57 |             'F_boundary_5': self.boundaryevaluator_5.F_score_boundary().tolist(),
 58 |             'Pr_boundary_5': self.boundaryevaluator_5.Precision_boundary().tolist(),
 59 |             'Re_boundary_5': self.boundaryevaluator_5.Recall_boundary().tolist(),
 60 |         }
 61 |         with open(self.args.out_file, 'w') as f:
 62 |             json.dump(metric_dct, f, indent=4)
 63 | 
 64 | 
 65 | if __name__ == '__main__':
 66 |     parser = argparse.ArgumentParser(description='Compare two masks')
 67 |     parser.add_argument('--workers', type=int, default=4,
 68 |                         metavar='N', help='dataloader threads')
 69 |     parser.add_argument('--batch-size', type=int, default=24,
 70 |                         metavar='N', help='input batch size for comparison (default: auto)')
 71 |     parser.add_argument('--pre-root', type=str,
 72 |                         default=None,
 73 |                         help='mask root of prediction')
 74 |     parser.add_argument('--gt-root', type=str,
 75 |                         default='./example/test/Masks',
 76 |                         help='mask root of ground truth')
 77 |     parser.add_argument('--merge-class', action='store_true', default=True,
 78 |                         help='whether to merge classes in the ground truth')
 79 |     parser.add_argument('--num-classes', type=int, default=2,
 80 |                         help='the number of classes (default:2)')
 81 |     parser.add_argument('--num-proc', type=int, default=4,
 82 |                         help='the number of processes (default:4)')
 83 |     parser.add_argument('--selected-file', type=str,
 84 |                         default='./inference/cld_clr_tile_list.json',
 85 |                         help='list of files needed to compute boundary accuracy')
 86 |     parser.add_argument('--out-file', type=str,
 87 |                         default=None,
 88 |                         help='output file')
 89 | 
 90 |     args = parser.parse_args()
 91 | 
 92 |     net_root = {
 93 |         'DeeplabV3Plus-seed1': './inference/DeeplabV3Plus-seed1',
 94 |         'DeeplabV3Plus-seed2': './inference/DeeplabV3Plus-seed2',
 95 |         'DeeplabV3Plus-seed3': './inference/DeeplabV3Plus-seed3',
 96 |         'DeeplabV3Plus-seed4': './inference/DeeplabV3Plus-seed4',
 97 |     }
 98 | 
 99 |     for net in net_root.keys():
100 |         args.pre_root = net_root[net]
101 |         args.out_file = os.path.join('./inference-mix', net + '.json')
102 |         print('prediction: {}'.format(args.pre_root))
103 |         print('ground truth: {}'.format(args.gt_root))
104 |         start = time.time()
105 |         comparator = Comparator(args)
106 |         comparator.cal_metric()
107 |         comparator.p.close()  # close the process pool
108 |         print('Using {}s!'.format(time.time() - start))
109 | 
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import DataLoader
 2 | 
 3 | from dataloaders import dataset
 4 | 
 5 | 
 6 | def make_data_loader(args, **kwargs):
 7 | 
 8 |     if args.dataset == 'RS':
 9 |         train_set = dataset.RSSet(args, split='train')
10 |         val_set = dataset.RSSet(args, split='val')
11 |         train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
12 |         val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
13 |         return train_loader, val_loader
14 |     else:
15 |         raise NotImplementedError
16 | 
17 | 
--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import random
  3 | import numpy as np
  4 | 
  5 | from PIL import Image, ImageOps, ImageFilter
  6 | 
  7 | 
  8 | class Normalize(object):
  9 |     """Normalize a numpy image with mean and standard deviation.
 10 |     Args:
 11 |         mean (tuple): means for each channel.
 12 |         std (tuple): standard deviations for each channel.
 13 |     """
 14 |     def __init__(self, mean=[0., 0., 0.], std=[1., 1., 1.], no_gt=False):
 15 |         self.mean = mean
 16 |         self.std = std
 17 |         self.no_gt = no_gt
 18 | 
 19 |     def __call__(self, sample):
 20 |         # numpy image: C X H X W
 21 |         img = sample['image']
 22 |         mask = sample['label']
 23 |         img = np.array(img).astype(np.float32)
 24 |         if not self.no_gt:
 25 |             mask = np.array(mask).astype(np.float32)
 26 | 
 27 |         # swap axis for broadcast operations
 28 |         img = img.transpose((1, 2, 0))  # H X W X C
 29 |         img -= self.mean
 30 |         img /= self.std
 31 |         # restore the order of axis
 32 |         img = img.transpose((2, 0, 1))
 33 | 
 34 |         return {'image': img,
 35 |                 'label': mask}
 36 | 
 37 | 
 38 | class ToTensor(object):
 39 |     """Convert ndarrays in sample to Tensors."""
 40 |     def __init__(self, no_gt=False):
 41 |         self.no_gt = no_gt
 42 | 
 43 |     def __call__(self, sample):
 44 |         # swap color axis because
 45 |         # numpy image: C X H X W
 46 |         # torch image: C X H X W
 47 |         img = sample['image']
 48 |         mask = sample['label']
 49 |         img = np.array(img).astype(np.float32)
 50 |         img = torch.from_numpy(img).float()
 51 |         if not self.no_gt:
 52 |             mask = np.array(mask).astype(np.float32)
 53 |             mask = torch.from_numpy(mask).float()
 54 | 
 55 |         return {'image': img,
 56 |                 'label': mask}
 57 | 
 58 | 
 59 | class RandomHorizontalFlip(object):
 60 |     def __call__(self, sample):
 61 |         img = sample['image']
 62 |         mask = sample['label']
 63 |         if random.random() < 0.5:
 64 |             img = img.transpose(Image.FLIP_LEFT_RIGHT)
 65 |             mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
 66 | 
 67 |         return {'image': img,
 68 |                 'label': mask}
 69 | 
 70 | 
 71 | class RandomRotate(object):
 72 |     def __init__(self, degree):
 73 |         self.degree = degree
 74 | 
 75 |     def __call__(self, sample):
 76 |         img = sample['image']
 77 |         mask = sample['label']
 78 |         rotate_degree = random.uniform(-1*self.degree, self.degree)
 79 |         img = img.rotate(rotate_degree, Image.BILINEAR)
 80 |         mask = mask.rotate(rotate_degree, Image.NEAREST)
 81 | 
 82 |         return {'image': img,
 83 |                 'label': mask}
 84 | 
 85 | 
 86 | class RandomGaussianBlur(object):
 87 |     def __call__(self, sample):
 88 |         img = sample['image']
 89 |         mask = sample['label']
 90 |         if random.random() < 0.5:
 91 |             img = img.filter(ImageFilter.GaussianBlur(
 92 |                 radius=random.random()))
 93 | 
 94 |         return {'image': img,
 95 |                 'label': mask}
 96 | 
 97 | 
 98 | class RandomScaleCrop(object):
 99 |     def __init__(self, base_size, crop_size, fill=0):
100 |         self.base_size = base_size
101 |         self.crop_size = crop_size
102 |         self.fill = fill
103 | 
104 |     def __call__(self, sample):
105 |         img = sample['image']
106 |         mask = sample['label']
107 |         # random scale (short edge)
108 |         short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
109 |         w, h = img.size
110 |         if h > w:
111 |             ow = short_size
112 |             oh = int(1.0 * h * ow / w)
113 |         else:
114 |             oh = short_size
115 |             ow = int(1.0 * w * oh / h)
116 |         img = img.resize((ow, oh), Image.BILINEAR)
117 |         mask = mask.resize((ow, oh), Image.NEAREST)
118 |         # pad crop
119 |         if short_size < self.crop_size:
120 |             padh = self.crop_size - oh if oh < self.crop_size else 0
121 |             padw = self.crop_size - ow if ow < self.crop_size else 0
122 |             img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
123 |             mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
124 |         # random crop crop_size
125 |         w, h = img.size
126 |         x1 = random.randint(0, w - self.crop_size)
127 |         y1 = random.randint(0, h - self.crop_size)
128 |         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
129 |         mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
130 | 
131 |         return {'image': img,
132 |                 'label': mask}
133 | 
134 | 
135 | class FixScaleCrop(object):
136 |     def __init__(self, crop_size):
137 |         self.crop_size = crop_size
138 | 
139 |     def __call__(self, sample):
140 |         img = sample['image']
141 |         mask = sample['label']
142 |         w, h = img.size
143 |         if w > h:
144 |             oh = self.crop_size
145 |             ow = int(1.0 * w * oh / h)
146 |         else:
147 |             ow = self.crop_size
148 |             oh = int(1.0 * h * ow / w)
149 |         img = img.resize((ow, oh), Image.BILINEAR)
150 |         mask = mask.resize((ow, oh), Image.NEAREST)
151 |         # center crop
152 |         w, h = img.size
153 |         x1 = int(round((w - self.crop_size) / 2.))
154 |         y1 = int(round((h - self.crop_size) / 2.))
155 |         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
156 |         mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
157 | 
158 |         return {'image': img,
159 |                 'label': mask}
160 | 
161 | class FixedResize(object):
162 |     def __init__(self, size):
163 |         self.size = (size, size)  # size: (h, w)
164 | 
165 |     def __call__(self, sample):
166 |         img = sample['image']
167 |         mask = sample['label']
168 | 
169 |         assert img.size == mask.size
170 | 
171 |         img = img.resize(self.size, Image.BILINEAR)
172 |         mask = mask.resize(self.size, Image.NEAREST)
173 | 
174 |         return {'image': img,
175 |                 'label': mask}
--------------------------------------------------------------------------------
/dataloaders/dataset.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import json
  3 | import numpy as np
  4 | from osgeo import gdal
  5 | from torchvision import transforms
  6 | from torch.utils.data import Dataset
  7 | 
  8 | import dataloaders.custom_transforms as ctr
  9 | # import custom_transforms as ctr
 10 | 
 11 | 
 12 | class RSSet(Dataset):
 13 |     def __init__(self, args, split):
 14 |         super().__init__()
 15 |         if split == 'train':
 16 |             set_list = args.train_list
 17 |             set_root = args.train_root
 18 |         elif split == 'val':
 19 |             set_list = args.val_list
 20 |             set_root = args.val_root
 21 |         elif split == 'test':
 22 |             set_list = args.test_list
 23 |             set_root = args.test_root
 24 |         with open(set_list, 'r') as f:
 25 |             lines = f.readlines()
 26 |         img_files = []
 27 |         mask_files = []
 28 |         for line in lines:
 29 |             img_files.append(os.path.join(set_root, line.split()[0]))
 30 |             mask_files.append(img_files[-1].replace('Images', 'Masks'))
 31 |         self.img_files = img_files
 32 |         self.mask_files = mask_files
 33 |         self.split = split
 34 |         self.args = args
 35 | 
 36 |     def __getitem__(self, item):
 37 |         img = gdal.Open(self.img_files[item]).ReadAsArray()[0:self.args.in_channels, :]
 38 |         if self.split == 'train':
 39 |             sample = {'image': img,
 40 |                       'label': self._transform_mask_multi(gdal.Open(self.mask_files[item]).ReadAsArray())}
 41 |             return self._transform_tr(sample)
 42 |         elif self.split == 'val':
 43 |             sample = {'image': img,
 44 |                       'label': self._transform_mask_multi(gdal.Open(self.mask_files[item]).ReadAsArray())}
 45 |             return self._transform_val(sample)
 46 |         elif self.split == 'test':
 47 |             sample = {'image': img, 'label': np.array([])}
 48 |             return self._transform_test(sample)
 49 | 
 50 |     def _transform_tr(self, img):
 51 |         data_transforms = transforms.Compose([
 52 |             ctr.Normalize(mean=self.args.mean, std=self.args.std),
 53 |             ctr.ToTensor()
 54 |         ])
 55 |         return data_transforms(img)
 56 | 
 57 |     def _transform_val(self, img):
 58 |         data_transforms = transforms.Compose([
 59 |             ctr.Normalize(mean=self.args.mean, std=self.args.std),
 60 |             ctr.ToTensor()
 61 |         ])
 62 |         return data_transforms(img)
 63 | 
 64 |     def _transform_test(self, img):
 65 |         data_transforms = transforms.Compose([
 66 |             ctr.Normalize(mean=self.args.mean, std=self.args.std, no_gt=True),
 67 |             ctr.ToTensor(no_gt=True)
 68 |         ])
 69 |         return data_transforms(img)
 70 | 
 71 |     def _transform_mask_binary(self, mask):
 72 |         # convert mask value to 0,1,2...
 73 |         # mask(mix2): 1 -- clear, 2 -- cloud
 74 |         mask[mask == 1] = 0
 75 |         mask[mask == 2] = 1
 76 | 
 77 |         return mask
 78 | 
 79 |     def _transform_mask_multi(self, mask):
 80 |         # convert mask value to 0,1,2...
 81 |         # mask(mix): 0 -- fill, 64 -- cloud shadow, 128 -- clear, 192 -- thin cloud, 255 -- cloud
 82 |         mask[mask == 64] = 0
 83 |         mask[mask == 128] = 0
 84 |         mask[mask == 192] = 1
 85 |         mask[mask == 255] = 1
 86 | 
 87 |         return mask
 88 | 
 89 |     def __len__(self):
 90 |         return len(self.img_files)
 91 | 
 92 | 
 93 | class RSERFSet(Dataset):
 94 |     def __init__(self, args):
 95 |         super().__init__()
 96 | 
 97 |         """
 98 |         Example:
 99 |             pixels = {'Barren_02_0319.txt': [[121, 122], [125,127], [122,131]]}
100 |         """
101 |         with open(args.pixel_list, 'r') as f:
102 |             pixels = json.load(f)
103 |         selected_files = list(pixels.keys())
104 |         img_files, targets = [], []
105 |         for key in pixels.keys():
106 |             if key in selected_files:
107 |                 img_files.extend([os.path.join(args.img_root, key)] * len(pixels[key]))
108 |                 targets.extend(pixels[key])
109 |         self.img_files = img_files
110 |         self.targets = targets
111 |         self.args = args
112 | 
113 |     def __getitem__(self, item):
114 |         img = gdal.Open(self.img_files[item]).ReadAsArray()[0:self.args.in_channels, :]
115 |         sample = {'image': img,
116 |                   'label': np.array(self.targets[item])}
117 |         return self._transform_erf(sample)
118 | 
119 |     def _transform_erf(self, img):
120 |         data_transforms = transforms.Compose([
121 |             ctr.Normalize(mean=self.args.mean, std=self.args.std, no_gt=True),
122 |             ctr.ToTensor(no_gt=True)
123 |         ])
124 |         return data_transforms(img)
125 | 
126 |     def __len__(self):
127 |         return len(self.img_files)
128 | 
129 | 
130 | class MaskSet(Dataset):
131 |     def __init__(self, args):
132 |         super().__init__()
133 | 
134 |         filelist = os.listdir(args.pre_root)
135 |         with open(args.selected_file, 'r') as f:
136 |             selected_files = json.load(f)
137 |         self.pre_files = [os.path.join(args.pre_root, file) for file in filelist if file in selected_files]
138 |         self.gt_files = [os.path.join(args.gt_root, os.path.split(file)[-1]) for file in self.pre_files]
139 |         # in real ground truth mask:
140 |         # mask(mix): 0 -- fill, 64 -- cloud shadow, 128 -- clear, 192 -- thin cloud, 255 -- cloud
141 |         self.merge_class = args.merge_class
142 | 
143 |     def __getitem__(self, item):
144 |         sample = {'pre': gdal.Open(self.pre_files[item]).ReadAsArray(),
145 |                   'gt': gdal.Open(self.gt_files[item]).ReadAsArray()}
146 |         if self.merge_class:
147 |             sample['gt'] = self._transform_mask_multi(sample['gt'])
148 |         return sample
149 | 
150 |     def _transform_mask_multi(self, mask):
151 |         # convert mask value to 0,1,2...
152 |         # mask(mix): 0 -- fill, 64 -- cloud shadow, 128 -- clear, 192 -- thin cloud, 255 -- cloud
153 |         mask[mask == 64] = 0
154 |         mask[mask == 128] = 0
155 |         mask[mask == 192] = 1
156 |         mask[mask == 255] = 1
157 | 
158 |         return mask
159 | 
160 |     def __len__(self):
161 |         return len(self.pre_files)
162 | 
163 | 
164 | if __name__ == "__main__":
165 |     import os
166 |     import argparse
167 |     import numpy as np
168 |     from osgeo import gdal
169 |     from tqdm import tqdm
170 |     from torch.utils.data import DataLoader
171 | 
172 | 
173 |     def save_img(tiff, out_file, projection=None, geotransform=None):
174 |         # save tiff image
175 | 
176 |         NP2GDAL_CONVERSION = {
177 |             "uint8": 1,
178 |             "int8": 1,
179 |             "uint16": 2,
180 |             "int16": 3,
181 |             "uint32": 4,
182 |             "int32": 5,
183 |             "float32": 6,
184 |             "float64": 7,
185 |             "complex64": 10,
186 |             "complex128": 11,
187 |         }  # convert np to gdal
188 |         gdal_type = NP2GDAL_CONVERSION[tiff.dtype.name]
189 |         if len(tiff.shape) == 2:
190 |             tiff = np.expand_dims(tiff, axis=0)
191 |         channel, row, col = tiff.shape
192 |         # use the GTiff driver to create the output dataset
193 |         gtiff_driver = gdal.GetDriverByName('GTiff')
194 |         out_ds = gtiff_driver.Create(out_file, col, row, channel, gdal_type)
195 |         if projection is not None and geotransform is not None:
196 |             out_ds.SetProjection(projection)  # set the projection
197 |             out_ds.SetGeoTransform(geotransform)  # set the geotransform
198 |         # write data to the output dataset
199 |         for iband in range(channel):
200 |             out_ds.GetRasterBand(iband + 1).WriteArray(tiff[iband, :, :])
201 |         del out_ds
202 | 
203 |     parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")  # create ArgumentParser object
204 |     parser.add_argument('--workers', type=int, default=4,
205 |                         metavar='N', help='dataloader threads')
206 |     parser.add_argument('--batch-size', type=int, default=16,
207 |                         metavar='N', help='input batch size for \
208 |                                 training (default: auto)')
209 |     parser.add_argument('--train-root', type=str, default='./example/train/Images',
210 |                         help='image root of train set')
211 |     parser.add_argument('--train-list', type=str, default='./example/train/train.txt',
212 |                         help='image list of train set')
213 |     parser.add_argument('--val-root', type=str, default='./example/val/Images',
214 |                         help='image root of validation set')
215 |     parser.add_argument('--val-list', type=str, default='./example/val/val.txt',
216 |                         help='image list of validation set')
217 |     parser.add_argument('--mean', type=str,
218 |                         default='0.432, 0.398, 0.411, 0.479, 0.240, 0.192, 0.037, 268.051',
219 |                         help='mean of each channel (used in data normalization), \
220 |                              must be a comma-separated list of floats only')
221 |     parser.add_argument('--std', type=str,
222 |                         default='0.313, 0.295, 0.311, 0.285, 0.162, 0.132, 0.079, 25.412',
223 |                         help='standard deviation of each channel (used in data normalization), \
224 |                              must be a comma-separated list of floats only')
225 |     parser.add_argument('--output-root', type=str, default='./example/train/check',
226 |                         help='output root of the check images')
227 | 
228 |     args = parser.parse_args()  # analyze parameters
229 | 
230 |     args.mean = [float(s) for s in args.mean.split(',')]
231 |     args.std = [float(s) for s in args.std.split(',')]
232 | 
233 |     kwargs = {'num_workers': args.workers, 'pin_memory': True}
234 |     train_set = RSSet(args, split='train')
235 |     train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
236 |     tbar = tqdm(train_loader, desc='\r')
237 |     for i, sample in enumerate(tbar):
238 |         if i == 0:
239 |             image, target = sample['image'].numpy(), sample['label'].numpy()
240 |             for j in range(args.batch_size):
241 |                 # save image
242 |                 image_temp = image[j, :]
243 |                 save_img(image_temp, os.path.join(args.output_root, 'train_check_{}.tif'.format(j)))
244 |                 # save target
245 |                 target_temp = target[j, :]
246 |                 save_img(target_temp, os.path.join(args.output_root, 'train_check_target_{}.tif'.format(j)))
247 |         break
248 | 
249 |     val_set = RSSet(args, split='val')
250 |     val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, **kwargs)
251 |     tbar = tqdm(val_loader, desc='\r')
252 |     for i, sample in enumerate(tbar):
253 |         if i == 0:
254 |             image, target = sample['image'].numpy(), sample['label'].numpy()
255 |             for j in range(args.batch_size):
256 |                 # save image
257 |                 image_temp = image[j, :]
258 |                 save_img(image_temp, os.path.join(args.output_root, 'val_check_{}.tif'.format(j)))
259 |                 # save target
260 |                 target_temp = target[j, :]
261 |                 save_img(target_temp, os.path.join(args.output_root, 'val_check_target_{}.tif'.format(j)))
262 |         break
263 | 
264 | 
--------------------------------------------------------------------------------
/erf.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import os
  3 | import time
  4 | import json
  5 | import random
  6 | import numpy as np
  7 | from osgeo import gdal
  8 | from tqdm import tqdm
  9 | from torch.utils.data import DataLoader
 10 | from matplotlib.colors import LinearSegmentedColormap
 11 | from torchvision import transforms
 12 | from captum.attr import Saliency
 13 | 
 14 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
 15 | 
 16 | from dataloaders.dataset import RSERFSet
 17 | from model import get_network
 18 | from interpretation.calculate_erf import calculate_erf
 19 | from utils.img_saver import save_img
 20 | from config import get_config_erf
 21 | 
 22 | 
 23 | class ERF(object):
 24 |     def __init__(self, args):
 25 |         self.args = args
 26 | 
 27 |         # Define Dataloader
 28 |         kwargs = {'num_workers': args.workers, 'pin_memory': True}
 29 |         erf_set = RSERFSet(args)
 30 |         self.erf_loader = DataLoader(erf_set, batch_size=args.batch_size, shuffle=False, **kwargs)
 31 | 
 32 |         # Define network
 33 |         model = get_network(args)
 34 | 
 35 |         # count parameters
 36 |         param_count = 0
 37 |         for param in model.parameters():
 38 |             param_count += param.view(-1).size()[0]
 39 |         print('Total parameters: {}M ({})'.format(param_count / 1e6, param_count))
 40 |         self.model = model
 41 | 
 42 |         # Using cuda
 43 |         if args.cuda:
 44 |             self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
 45 |             self.model = self.model.cuda()
 46 | 
 47 |     def cal_erf(self, load_path):
 48 |         # load
 49 |         checkpoint = torch.load(load_path)
 50 |         if self.args.cuda:
 51 |             self.model.module.load_state_dict(checkpoint['state_dict'])
 52 |         else:
 53 |             self.model.load_state_dict(checkpoint['state_dict'])
 54 |         self.model.eval()
 55 |         sl = Saliency(self.model)
 56 | 
 57 |         tbar = tqdm(self.erf_loader, desc='\r')
 58 |         num_pixel = len(self.erf_loader.dataset)  # pixel num
 59 |         erf_img = dict()
 60 |         out_path = os.path.join(self.args.out_path, os.path.split(load_path)[-1])
 61 |         if not os.path.exists(out_path):
 62 |             os.mkdir(out_path)
 63 | 
 64 |         for i, sample in enumerate(tbar):
 65 |             image, targets = sample['image'], sample['label'].numpy()
 66 |             if self.args.cuda:
 67 |                 image = image.cuda()
 68 |             if self.args.max_pro:
 69 |                 with torch.no_grad():
 70 |                     output = self.model(image)
 71 |                 pred = output.data.cpu().numpy()
 72 |                 pred = np.argmax(pred, axis=1).astype(np.uint8)
 73 |                 targets_tup = [(pred[i, targets[i][0], targets[i][1]], targets[i][0], targets[i][1]) for i in
 74 |                                range(len(targets))]
 75 |             else:
 76 |                 targets_tup = [(1, hw[0], hw[1]) for hw in targets]
 77 |             image.requires_grad = True
 78 |             sl_attr = sl.attribute(image, target=targets_tup, abs=False)
 79 |             sl_img = sl_attr.data.cpu().numpy()
 80 | 
 81 |             b, _, _, _ = sl_img.shape
 82 |             for ib in range(b):
 83 |                 filename = os.path.split(self.erf_loader.dataset.img_files[i * self.args.batch_size + ib])[-1]
 84 |                 filename = filename.split(sep='.')[0] + '_h%.3d' % targets[ib][0] + '_w%.3d' % targets[ib][1] + '.tif'
 85 |                 erf_img[filename] = calculate_erf(sl_img[ib, :], targets[ib, 0], targets[ib, 1])
 86 |                 if self.args.save_img:  # and i % 100 == 0
 87 |                     save_img(sl_img[ib, :], os.path.join(out_path, filename))
 88 | 
 89 |         with open(os.path.join(self.args.out_path, os.path.split(load_path)[-1] + '.json'), 'w') as f:
 90 |             json.dump(erf_img, f, indent=4)
 91 | 
 92 |         print('ERF:')
 93 |         print('[numPixels: %5d]' % num_pixel)
 94 |         print("mean ERF:{}, std ERF:{}".format(np.mean(list(erf_img.values())), np.std(list(erf_img.values()))))
 95 | 
 96 | 
 97 | def main():
 98 | 
 99 |     load_roots = {
100 |         'DeeplabV3Plus-seed1': './ckp/DeeplabV3Plus-seed1.pth.tar',
101 |         'DeeplabV3Plus-seed2': './ckp/DeeplabV3Plus-seed2.pth.tar',
102 |         'DeeplabV3Plus-seed3': './ckp/DeeplabV3Plus-seed3.pth.tar',
103 |         'DeeplabV3Plus-seed4': './ckp/DeeplabV3Plus-seed4.pth.tar',
104 |     }
105 | 
106 |     net_names = [
107 |         'DeeplabV3Plus-seed1', 'DeeplabV3Plus-seed2', 'DeeplabV3Plus-seed3', 'DeeplabV3Plus-seed4',
108 |     ]
109 | 
110 |     for net_name in net_names:
111 |         print('Using model {}'.format(net_name))
112 |         start1 = time.time()
113 | 
114 |         if 'dilation' in net_name:
115 |             args = get_config_erf('UNet-dilation')
116 |             args.dilation = int(net_name.split(sep='-')[0][-1])
117 |         else:
118 |             args = get_config_erf('-'.join(net_name.split(sep='-')[0:-1]))
119 |         args.seed = int(net_name.split(sep='-')[-1][4:])
120 |         if 'MSCFF' in args.net or 'DeeplabV3Plus' in args.net:
121 |             args.batch_size = 6
122 | 
123 |         # define parameters files
124 |         args.load_paths = [load_roots[net_name]]
125 | 
126 |         # define output path
127 |         args.out_path = os.path.join('./erf', net_name)
128 |         if not os.path.exists(args.out_path):
129 |             os.mkdir(args.out_path)
130 | 
131 |         print(args)
132 | 
133 |         torch.manual_seed(args.seed)  # set seed for the CPU
134 |         np.random.seed(args.seed)
135 |         torch.cuda.manual_seed_all(args.seed)
136 |         torch.backends.cudnn.deterministic = True
137 |         torch.backends.cudnn.benchmark = False
138 | 
139 |         # erf = ERF(args)
140 |         for load_path in args.load_paths:
141 |             erf = ERF(args)
142 |             erf.cal_erf(load_path)
143 |             del erf
144 | 
145 |         print('Using {}s!'.format(time.time() - start1))
146 | 
147 | 
148 | if __name__ == '__main__':
149 |     start = time.time()
150 |     main()
151 |     print('Using {}s!'.format(time.time() - start))
--------------------------------------------------------------------------------
/erf/README.TXT:
--------------------------------------------------------------------------------
1 | Folder to store the results of the effective receptive field.
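To make the measurement in `erf.py` concrete: the effective receptive field (ERF) is read off the gradient of a single output pixel's score with respect to the input image. Below is a minimal, self-contained sketch (not part of the repository) of that Captum call pattern; the toy two-layer network is a hypothetical stand-in for the repository's models, and `calculate_erf` from `interpretation/calculate_erf.py` would then summarize the gradient map into an ERF size.

```Python
import torch
import torch.nn as nn
from captum.attr import Saliency

# toy stand-in for the repository's segmentation networks:
# 8 input bands -> per-pixel logits for 2 classes
model = nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 2, kernel_size=3, padding=1),
).eval()

image = torch.randn(1, 8, 64, 64, requires_grad=True)

# target=(class, h, w) selects one output pixel, mirroring targets_tup in
# erf.py; abs=False keeps signed gradients, as erf.py does
grad = Saliency(model).attribute(image, target=(1, 32, 32), abs=False)

print(grad.shape)  # torch.Size([1, 8, 64, 64]); the spatial support of the
                   # nonzero gradients around (32, 32) approximates the ERF
```

The `selected_pixel_all.json` file below supplies the per-image `[h, w]` query pixels that `RSERFSet` feeds into this computation.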
-------------------------------------------------------------------------------- /erf/selected_pixel_all.json: -------------------------------------------------------------------------------- 1 | { 2 | "Barren_07_2524.tif": [ 3 | [ 4 | 124, 5 | 130 6 | ], 7 | [ 8 | 127, 9 | 122 10 | ] 11 | ], 12 | "Forest_01_2715.tif": [ 13 | [ 14 | 125, 15 | 127 16 | ], 17 | [ 18 | 122, 19 | 115 20 | ] 21 | ], 22 | "GrassCrops_09_2014.tif": [ 23 | [ 24 | 122, 25 | 123 26 | ], 27 | [ 28 | 129, 29 | 118 30 | ] 31 | ], 32 | "Shrubland_10_1706.tif": [ 33 | [ 34 | 127, 35 | 125 36 | ], 37 | [ 38 | 129, 39 | 130 40 | ] 41 | ], 42 | "Urban_11_1326.tif": [ 43 | [ 44 | 125, 45 | 124 46 | ], 47 | [ 48 | 128, 49 | 125 50 | ] 51 | ] 52 | } -------------------------------------------------------------------------------- /example/test/Images/Barren_07_2524.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Images/Barren_07_2524.tif -------------------------------------------------------------------------------- /example/test/Images/Forest_01_2715.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Images/Forest_01_2715.tif -------------------------------------------------------------------------------- /example/test/Images/GrassCrops_09_2014.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Images/GrassCrops_09_2014.tif -------------------------------------------------------------------------------- /example/test/Images/Shrubland_10_1706.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Images/Shrubland_10_1706.tif -------------------------------------------------------------------------------- /example/test/Images/Urban_11_1326.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Images/Urban_11_1326.tif -------------------------------------------------------------------------------- /example/test/Masks/Barren_07_2524.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Masks/Barren_07_2524.tif -------------------------------------------------------------------------------- /example/test/Masks/Forest_01_2715.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Masks/Forest_01_2715.tif -------------------------------------------------------------------------------- /example/test/Masks/GrassCrops_09_2014.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Masks/GrassCrops_09_2014.tif -------------------------------------------------------------------------------- /example/test/Masks/Shrubland_10_1706.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Masks/Shrubland_10_1706.tif -------------------------------------------------------------------------------- /example/test/Masks/Urban_11_1326.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/test/Masks/Urban_11_1326.tif -------------------------------------------------------------------------------- /example/test/test.txt: -------------------------------------------------------------------------------- 1 | Barren_07_2524.tif 2 | Forest_01_2715.tif 3 | GrassCrops_09_2014.tif 4 | Shrubland_10_1706.tif 5 | Urban_11_1326.tif 6 | -------------------------------------------------------------------------------- /example/train/Images/Barren_00_0214.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Barren_00_0214.tif -------------------------------------------------------------------------------- /example/train/Images/Barren_01_0213.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Barren_01_0213.tif -------------------------------------------------------------------------------- /example/train/Images/Forest_03_0209.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Forest_03_0209.tif -------------------------------------------------------------------------------- /example/train/Images/Forest_05_0415.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Forest_05_0415.tif -------------------------------------------------------------------------------- /example/train/Images/GrassCrops_00_1014.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/GrassCrops_00_1014.tif -------------------------------------------------------------------------------- /example/train/Images/GrassCrops_02_0608.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/GrassCrops_02_0608.tif -------------------------------------------------------------------------------- /example/train/Images/Shrubland_00_0614.tif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Shrubland_00_0614.tif -------------------------------------------------------------------------------- /example/train/Images/Shrubland_02_1021.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Shrubland_02_1021.tif -------------------------------------------------------------------------------- /example/train/Images/SnowIce_00_1415.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/SnowIce_00_1415.tif -------------------------------------------------------------------------------- /example/train/Images/SnowIce_01_0909.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/SnowIce_01_0909.tif -------------------------------------------------------------------------------- /example/train/Images/Urban_00_0515.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Urban_00_0515.tif -------------------------------------------------------------------------------- /example/train/Images/Urban_01_0609.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Urban_01_0609.tif -------------------------------------------------------------------------------- /example/train/Images/Water_01_0922.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Water_01_0922.tif -------------------------------------------------------------------------------- /example/train/Images/Water_02_1010.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Water_02_1010.tif -------------------------------------------------------------------------------- /example/train/Images/Wetlands_00_1426.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Wetlands_00_1426.tif -------------------------------------------------------------------------------- /example/train/Images/Wetlands_01_0416.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Images/Wetlands_01_0416.tif 
-------------------------------------------------------------------------------- /example/train/Masks/Barren_00_0214.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Barren_00_0214.tif -------------------------------------------------------------------------------- /example/train/Masks/Barren_01_0213.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Barren_01_0213.tif -------------------------------------------------------------------------------- /example/train/Masks/Barren_03_0207.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Barren_03_0207.tif -------------------------------------------------------------------------------- /example/train/Masks/Barren_08_0207.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Barren_08_0207.tif -------------------------------------------------------------------------------- /example/train/Masks/Barren_10_0206.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Barren_10_0206.tif -------------------------------------------------------------------------------- /example/train/Masks/Barren_11_0207.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Barren_11_0207.tif -------------------------------------------------------------------------------- /example/train/Masks/Forest_03_0209.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Forest_03_0209.tif -------------------------------------------------------------------------------- /example/train/Masks/Forest_05_0415.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Forest_05_0415.tif -------------------------------------------------------------------------------- /example/train/Masks/Forest_06_0515.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Forest_06_0515.tif -------------------------------------------------------------------------------- /example/train/Masks/Forest_07_1209.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Forest_07_1209.tif -------------------------------------------------------------------------------- /example/train/Masks/Forest_08_0805.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Forest_08_0805.tif -------------------------------------------------------------------------------- /example/train/Masks/Forest_10_1028.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Forest_10_1028.tif -------------------------------------------------------------------------------- /example/train/Masks/GrassCrops_00_1014.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/GrassCrops_00_1014.tif -------------------------------------------------------------------------------- /example/train/Masks/GrassCrops_02_0608.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/GrassCrops_02_0608.tif -------------------------------------------------------------------------------- /example/train/Masks/GrassCrops_05_0806.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/GrassCrops_05_0806.tif -------------------------------------------------------------------------------- /example/train/Masks/GrassCrops_07_1723.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/GrassCrops_07_1723.tif -------------------------------------------------------------------------------- /example/train/Masks/GrassCrops_08_1326.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/GrassCrops_08_1326.tif -------------------------------------------------------------------------------- /example/train/Masks/GrassCrops_11_1014.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/GrassCrops_11_1014.tif -------------------------------------------------------------------------------- /example/train/Masks/Shrubland_00_0614.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Shrubland_00_0614.tif -------------------------------------------------------------------------------- 
/example/train/Masks/Shrubland_02_1021.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Shrubland_02_1021.tif -------------------------------------------------------------------------------- /example/train/Masks/Shrubland_03_0807.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Shrubland_03_0807.tif -------------------------------------------------------------------------------- /example/train/Masks/Shrubland_04_0711.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Shrubland_04_0711.tif -------------------------------------------------------------------------------- /example/train/Masks/Shrubland_05_0621.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Shrubland_05_0621.tif -------------------------------------------------------------------------------- /example/train/Masks/Shrubland_08_1222.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Shrubland_08_1222.tif -------------------------------------------------------------------------------- /example/train/Masks/SnowIce_00_1415.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/SnowIce_00_1415.tif -------------------------------------------------------------------------------- /example/train/Masks/SnowIce_01_0909.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/SnowIce_01_0909.tif -------------------------------------------------------------------------------- /example/train/Masks/SnowIce_02_0712.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/SnowIce_02_0712.tif -------------------------------------------------------------------------------- /example/train/Masks/SnowIce_05_0708.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/SnowIce_05_0708.tif -------------------------------------------------------------------------------- /example/train/Masks/SnowIce_08_0823.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/SnowIce_08_0823.tif 
-------------------------------------------------------------------------------- /example/train/Masks/SnowIce_10_1022.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/SnowIce_10_1022.tif -------------------------------------------------------------------------------- /example/train/Masks/Urban_00_0515.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Urban_00_0515.tif -------------------------------------------------------------------------------- /example/train/Masks/Urban_01_0609.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Urban_01_0609.tif -------------------------------------------------------------------------------- /example/train/Masks/Urban_02_0814.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Urban_02_0814.tif -------------------------------------------------------------------------------- /example/train/Masks/Urban_05_0608.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Urban_05_0608.tif -------------------------------------------------------------------------------- /example/train/Masks/Urban_06_1113.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Urban_06_1113.tif -------------------------------------------------------------------------------- /example/train/Masks/Urban_08_0817.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Urban_08_0817.tif -------------------------------------------------------------------------------- /example/train/Masks/Water_01_0922.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Water_01_0922.tif -------------------------------------------------------------------------------- /example/train/Masks/Water_02_1010.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Water_02_1010.tif -------------------------------------------------------------------------------- /example/train/Masks/Water_04_0726.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Water_04_0726.tif -------------------------------------------------------------------------------- /example/train/Masks/Water_05_0517.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Water_05_0517.tif -------------------------------------------------------------------------------- /example/train/Masks/Water_07_0706.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Water_07_0706.tif -------------------------------------------------------------------------------- /example/train/Masks/Water_08_0606.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Water_08_0606.tif -------------------------------------------------------------------------------- /example/train/Masks/Wetlands_00_1426.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Wetlands_00_1426.tif -------------------------------------------------------------------------------- /example/train/Masks/Wetlands_01_0416.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Wetlands_01_0416.tif -------------------------------------------------------------------------------- /example/train/Masks/Wetlands_04_1107.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Wetlands_04_1107.tif -------------------------------------------------------------------------------- /example/train/Masks/Wetlands_05_0713.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Wetlands_05_0713.tif -------------------------------------------------------------------------------- /example/train/Masks/Wetlands_06_1127.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Wetlands_06_1127.tif -------------------------------------------------------------------------------- /example/train/Masks/Wetlands_11_1217.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/train/Masks/Wetlands_11_1217.tif -------------------------------------------------------------------------------- /example/train/train.txt: 
-------------------------------------------------------------------------------- 1 | Barren_00_0214.tif 2 | Barren_01_0213.tif 3 | Forest_03_0209.tif 4 | Forest_05_0415.tif 5 | GrassCrops_00_1014.tif 6 | GrassCrops_02_0608.tif 7 | Shrubland_00_0614.tif 8 | Shrubland_02_1021.tif 9 | SnowIce_00_1415.tif 10 | SnowIce_01_0909.tif 11 | Urban_00_0515.tif 12 | Urban_01_0609.tif 13 | Water_01_0922.tif 14 | Water_02_1010.tif 15 | Wetlands_00_1426.tif 16 | Wetlands_01_0416.tif 17 | -------------------------------------------------------------------------------- /example/val/Images/Barren_05_0208.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/Barren_05_0208.tif -------------------------------------------------------------------------------- /example/val/Images/Forest_00_1008.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/Forest_00_1008.tif -------------------------------------------------------------------------------- /example/val/Images/GrassCrops_01_0413.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/GrassCrops_01_0413.tif -------------------------------------------------------------------------------- /example/val/Images/Shrubland_01_0508.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/Shrubland_01_0508.tif -------------------------------------------------------------------------------- /example/val/Images/SnowIce_11_1123.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/SnowIce_11_1123.tif -------------------------------------------------------------------------------- /example/val/Images/Urban_03_0619.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/Urban_03_0619.tif -------------------------------------------------------------------------------- /example/val/Images/Water_11_0728.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/Water_11_0728.tif -------------------------------------------------------------------------------- /example/val/Images/Wetlands_07_1826.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Images/Wetlands_07_1826.tif -------------------------------------------------------------------------------- /example/val/Masks/Barren_05_0208.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/Barren_05_0208.tif -------------------------------------------------------------------------------- /example/val/Masks/Forest_00_1008.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/Forest_00_1008.tif -------------------------------------------------------------------------------- /example/val/Masks/GrassCrops_01_0413.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/GrassCrops_01_0413.tif -------------------------------------------------------------------------------- /example/val/Masks/Shrubland_01_0508.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/Shrubland_01_0508.tif -------------------------------------------------------------------------------- /example/val/Masks/SnowIce_11_1123.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/SnowIce_11_1123.tif -------------------------------------------------------------------------------- /example/val/Masks/Urban_03_0619.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/Urban_03_0619.tif -------------------------------------------------------------------------------- /example/val/Masks/Water_11_0728.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/Water_11_0728.tif -------------------------------------------------------------------------------- /example/val/Masks/Wetlands_07_1826.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/example/val/Masks/Wetlands_07_1826.tif -------------------------------------------------------------------------------- /example/val/val.txt: -------------------------------------------------------------------------------- 1 | Barren_05_0208.tif 2 | Forest_00_1008.tif 3 | GrassCrops_01_0413.tif 4 | Shrubland_01_0508.tif 5 | SnowIce_11_1123.tif 6 | Urban_03_0619.tif 7 | Water_11_0728.tif 8 | Wetlands_07_1826.tif 9 | -------------------------------------------------------------------------------- /inference-mix/README.TXT: -------------------------------------------------------------------------------- 1 | Folder to store the accuracy of the prediction results. 2 | The accuracy can be calculated from a subset of the prediction results using comparator.py.
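A minimal sketch of the accuracy check the README above describes, assuming single-band class-ID masks that PIL can read; the helper name and the file paths are hypothetical, and comparator.py defines the project's actual interface:

import numpy as np
from PIL import Image

def pixel_accuracy(pred_path, mask_path):
    # compare a predicted mask against its reference mask pixel by pixel
    pred = np.array(Image.open(pred_path))
    gt = np.array(Image.open(mask_path))
    return float((pred == gt).mean())

# average over a chosen subset of the prediction results (hypothetical pair list)
pairs = [('inference/DeeplabV3Plus-seed1/Urban_11_1326.tif', 'example/test/Masks/Urban_11_1326.tif')]
print(sum(pixel_accuracy(p, m) for p, m in pairs) / len(pairs))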
-------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | 10 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 11 | 12 | from dataloaders.dataset import RSSet 13 | from model import get_network 14 | from utils.loss import SegmentationLosses 15 | from utils.calculate_weights import calculate_weigths_labels 16 | from utils.metrics import Evaluator, BoundaryEvaluator 17 | from utils.img_saver import save_img 18 | from config import get_config_test 19 | 20 | 21 | class Inference(object): 22 | def __init__(self, args): 23 | self.args = args 24 | self.nclass = args.num_classes # set before it is used for class weights and evaluators below 25 | # Define Dataloader 26 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 27 | if args.no_gt: 28 | test_set = RSSet(args, split='test') 29 | self.test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 30 | else: 31 | args.val_list, args.val_root = args.test_list, args.test_root 32 | test_set = RSSet(args, split='val') 33 | self.test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 34 | 35 | # Define Criterion 36 | # whether to use class balanced weights 37 | if args.use_balanced_weights: 38 | classes_weights_path = os.path.join(os.path.split(args.train_list)[0], 39 | args.dataset + '_classes_weights.npy') 40 | if os.path.isfile(classes_weights_path): 41 | weight = np.load(classes_weights_path) 42 | else: # class weights are derived from the training split ('train' split name assumed to exist in RSSet) 43 | train_loader = DataLoader(RSSet(args, split='train'), batch_size=args.batch_size, shuffle=False, **kwargs) 44 | weight = calculate_weigths_labels(os.path.split(args.train_list)[0], args.dataset, train_loader, self.nclass) 45 | weight = torch.from_numpy(weight.astype(np.float32)) 46 | else: 47 | weight = None 48 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type) 49 | 50 | 51 | 52 | # Define network 53 | model = get_network(args) 54 | 55 | # count parameters 56 | param_count = 0 57 | for param in model.parameters(): 58 | param_count += param.view(-1).size()[0] 59 | print('Total parameters: {}M ({})'.format(param_count / 1e6, param_count)) 60 | self.model = model 61 | 62 | # define multiprocess 63 | if args.num_proc: 64 | self.p = Pool(processes=args.num_proc) 65 | else: 66 | self.p = None 67 | 68 | # Define Evaluator 69 | self.evaluator = Evaluator(self.nclass) 70 | self.boundaryevaluator_3 = BoundaryEvaluator(self.nclass, self.p, self.args.num_proc, bound_th=3) 71 | self.boundaryevaluator_5 = BoundaryEvaluator(self.nclass, self.p, self.args.num_proc, bound_th=5) 72 | 73 | # Using cuda 74 | if args.cuda: 75 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 76 | self.model = self.model.cuda() 77 | 78 | # define dict to save metric 79 | self.metric_dct = dict() 80 | 81 | def test(self, load_path): 82 | # load 83 | checkpoint = torch.load(load_path) 84 | if self.args.cuda: 85 | self.model.module.load_state_dict(checkpoint['state_dict']) 86 | else: 87 | self.model.load_state_dict(checkpoint['state_dict']) 88 | self.model.eval() 89 | 90 | tbar = tqdm(self.test_loader, desc='\r') 91 | num_img_val = len(self.test_loader.dataset) # image num 92 | out_path = os.path.join(self.args.out_path, os.path.split(load_path)[-1]) 93 | if not os.path.exists(out_path): 94 | os.mkdir(out_path) 95 | 96 | val_loss = 0.0 97 | self.evaluator.reset() 98 |
self.boundaryevaluator_3.reset() 99 | self.boundaryevaluator_5.reset() 100 | for i, sample in enumerate(tbar): 101 | image, target = sample['image'], sample['label'] 102 | if self.args.cuda: 103 | image, target = image.cuda(), target.cuda() 104 | with torch.no_grad(): 105 | output = self.model(image) 106 | if not self.args.no_gt: # ground truth is available 107 | loss = self.criterion(output, target) 108 | val_loss += loss.item() 109 | tbar.set_description('Validation loss: %.3f' % (val_loss / (i + 1))) 110 | pred, target = output.data.cpu().numpy(), target.cpu().numpy() 111 | pred = np.argmax(pred, axis=1).astype(np.uint8) 112 | 113 | if not self.args.no_gt: # ground truth is available 114 | time.sleep(0.1) 115 | self.evaluator.add_batch(target, pred) 116 | self.boundaryevaluator_3.add_batch(target, pred) 117 | self.boundaryevaluator_5.add_batch(target, pred) 118 | 119 | if self.args.save_img: 120 | c, _, _ = pred.shape 121 | for ic in range(c): 122 | filename = os.path.split(self.test_loader.dataset.img_files[i * self.args.batch_size + ic])[-1] 123 | save_img(pred[ic, :, :], os.path.join(out_path, filename)) 124 | 125 | self.metric_dct = { 126 | 'PA': self.evaluator.Pixel_Accuracy(), 127 | 'MPA': self.evaluator.Pixel_Accuracy_Class(), 128 | 'MIoU': self.evaluator.Mean_Intersection_over_Union(), 129 | 'FWIoU': self.evaluator.Frequency_Weighted_Intersection_over_Union(), 130 | 'Precision': self.evaluator.Precision(), 131 | 'Recall': self.evaluator.Recall(), 132 | 'F1': self.evaluator.F_score(), 133 | 'F_boundary_3': self.boundaryevaluator_3.F_score_boundary().tolist(), 134 | 'Pr_boundary_3': self.boundaryevaluator_3.Precision_boundary().tolist(), 135 | 'Re_boundary_3': self.boundaryevaluator_3.Recall_boundary().tolist(), 136 | 'F_boundary_5': self.boundaryevaluator_5.F_score_boundary().tolist(), 137 | 'Pr_boundary_5': self.boundaryevaluator_5.Precision_boundary().tolist(), 138 | 'Re_boundary_5': self.boundaryevaluator_5.Recall_boundary().tolist(), 139 | 'loss': val_loss / num_img_val, 140 | } 141 | 142 | print('Validation:') 143 | print('[numImages: %5d]' % num_img_val) 144 | print('Loss: %.3f' % val_loss) 145 | print(self.metric_dct) 146 | 147 | 148 | def main(): 149 | 150 | load_roots = { 151 | 'DeeplabV3Plus-seed1': './ckp/DeeplabV3Plus-seed1.pth.tar', 152 | 'DeeplabV3Plus-seed2': './ckp/DeeplabV3Plus-seed2.pth.tar', 153 | 'DeeplabV3Plus-seed3': './ckp/DeeplabV3Plus-seed3.pth.tar', 154 | 'DeeplabV3Plus-seed4': './ckp/DeeplabV3Plus-seed4.pth.tar', 155 | } 156 | 157 | net_names = [ 158 | 'DeeplabV3Plus-seed1', 'DeeplabV3Plus-seed2', 'DeeplabV3Plus-seed3', 'DeeplabV3Plus-seed4', 159 | ] 160 | 161 | for net_name in net_names: 162 | print('Using model {}'.format(net_name)) 163 | start1 = time.time() 164 | 165 | if 'dilation' in net_name: 166 | args = get_config_test('UNet-dilation') 167 | args.dilation = int(net_name.split(sep='-')[0][-1]) 168 | else: 169 | args = get_config_test('-'.join(net_name.split(sep='-')[0:-1])) 170 | args.seed = int(net_name.split(sep='-')[-1][4:]) 171 | if 'MSCFF' in args.net or 'Deeplab' in args.net: 172 | args.batch_size = 256 173 | 174 | # define parameters files 175 | args.load_paths = [load_roots[net_name]] 176 | 177 | # define output path 178 | args.out_path = os.path.join('./inference', net_name) 179 | if not os.path.exists(args.out_path): 180 | os.mkdir(args.out_path) 181 | 182 | print(args) 183 | 184 | torch.manual_seed(args.seed) # set seed for the CPU 185 | np.random.seed(args.seed) 186 | torch.cuda.manual_seed_all(args.seed) 187 | 
torch.backends.cudnn.deterministic = True 188 | torch.backends.cudnn.benchmark = False 189 | 190 | for load_path in args.load_paths: 191 | print(load_path) 192 | start2 = time.time() 193 | inference = Inference(args) 194 | inference.test(load_path) 195 | with open(os.path.join(args.out_path, '{}-pixel.json'.format(os.path.split(load_path)[-1])), 'w') as f: 196 | json.dump(inference.metric_dct, f, indent=4) 197 | if inference.p is not None: inference.p.close() # the Pool exists only when args.num_proc is set 198 | del inference 199 | print('One parameter file using {}s!'.format(time.time() - start2)) 200 | 201 | print('All parameter files using {}s!'.format(time.time() - start1)) 202 | 203 | 204 | if __name__ == '__main__': 205 | start = time.time() 206 | main() 207 | print('Using {}s!'.format(time.time() - start)) 208 | 209 | -------------------------------------------------------------------------------- /inference/README.TXT: -------------------------------------------------------------------------------- 1 | Folder to store the prediction results. -------------------------------------------------------------------------------- /interpretation/calculate_erf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Date: 2020-12 3 | Function: compute the effective receptive field (ERF), both with the original method of the paper below and with a direction-aware semivariogram (a short usage sketch appears further below) 4 | Ref: Luo W, Li Y, Urtasun R, et al. Understanding the effective receptive field in 5 | deep convolutional neural networks[J]. arXiv preprint arXiv:1701.04128, 2017. 6 | """ 7 | 8 | import numpy as np 9 | 10 | 11 | def calculate_erf(img, h_ct, w_ct): 12 | img = np.sum(abs(img), axis=0) 13 | # keep pixels whose response exceeds (1 - 95.45%) of the center pixel, then return the side length of an equal-area square 14 | erf = img > ((1 - 0.9545) * img[h_ct, w_ct]) 15 | return erf.sum() ** 0.5 16 | 17 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplab import DeepLab 2 | from .mfcnn import MFCNN 3 | from .mscff import MSCFF 4 | from .munet import MUNet 5 | from .tlnet import TUNet as TLNet # tlnet exports TUNet; aliased so the 'TLNet' branch below resolves 6 | from .unet import UNet 7 | from .unet_3 import UNet_3 8 | from .unet_2 import UNet_2 9 | from .unet_1 import UNet_1 10 | from .unet_dilation import UNet_dil 11 | from .unets3 import UNetS3 # the stripped-down U-Nets live in the unets1/unets2/unets3 packages 12 | from .unets2 import UNetS2 13 | from .unets1 import UNetS1 14 | 15 | 16 | def get_network(args): 17 | if args.net == 'DeeplabV3Plus': 18 | net = DeepLab(backbone=args.backbone, output_stride=args.out_stride, 19 | num_classes=args.num_classes, in_channels=args.in_channels, 20 | sync_bn=args.sync_bn, freeze_bn=args.freeze_bn, 21 | pretrained=args.pretrained) 22 | elif args.net == 'MFCNN': 23 | net = MFCNN(n_channels=args.in_channels, n_classes=args.num_classes, dropout_p=args.dropout_p) 24 | elif args.net == 'MSCFF': 25 | net = MSCFF(n_channels=args.in_channels, n_classes=args.num_classes) 26 | elif args.net == 'MUNet': 27 | net = MUNet(n_channels=args.in_channels, n_classes=args.num_classes) 28 | elif args.net == 'TLNet': 29 | net = TLNet(n_channels=args.in_channels, n_classes=args.num_classes) 30 | elif args.net == 'UNet': 31 | net = UNet(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False) 32 | elif args.net == 'UNetS3': 33 | net = UNetS3(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False) 34 | elif args.net == 'UNetS2': 35 | net = UNetS2(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False) 36 | elif args.net == 'UNetS1': 37 | net = UNetS1(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False) 38 | elif args.net == 'UNet-3': 39 | net
= UNet_3(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False) 40 | elif args.net == 'UNet-2': 41 | net = UNet_2(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False) 42 | elif args.net == 'UNet-1': 43 | net = UNet_1(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False) 44 | elif args.net == 'UNet-dilation': 45 | net = UNet_dil(n_channels=args.in_channels, n_classes=args.num_classes, bilinear=False, 46 | maxpool=False, dilation=args.dilation) 47 | else: 48 | raise NotImplementedError('The network {} is not supported yet'.format(args.net)) 49 | 50 | return net 51 | 52 | 53 | -------------------------------------------------------------------------------- /model/deeplab/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplab import DeepLab -------------------------------------------------------------------------------- /model/deeplab/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if output_stride == 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 60 | self.bn1 = BatchNorm(256) 61 | self.relu = nn.ReLU() 62 | self.dropout = nn.Dropout(0.5) 63 | self._init_weight() 64 | 65 | def forward(self, x): 66 | x1 = self.aspp1(x) 67 | x2 = self.aspp2(x) 68 | x3 = self.aspp3(x) 69 | x4 = self.aspp4(x) 70 | x5 = self.global_avg_pool(x) 71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 72 | x = torch.cat((x1, x2, x3, 
x4, x5), dim=1) 73 | 74 | x = self.conv1(x) 75 | x = self.bn1(x) 76 | x = self.relu(x) 77 | 78 | return self.dropout(x) 79 | 80 | def _init_weight(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /model/deeplab/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet101 2 | from .xception import AlignedXception 3 | from .drn import drn_d_54 4 | from .mobilenet import MobileNetV2 5 | 6 | 7 | def build_backbone(backbone, output_stride, in_channels, BatchNorm, pretrained=True): 8 | if backbone == 'resnet': 9 | return ResNet101(output_stride, in_channels, BatchNorm, pretrained=pretrained) 10 | elif backbone == 'xception': 11 | return AlignedXception(output_stride, BatchNorm) 12 | elif backbone == 'drn': 13 | return drn_d_54(BatchNorm) 14 | elif backbone == 'mobilenet': 15 | return MobileNetV2(output_stride, BatchNorm) 16 | else: 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /model/deeplab/backbone/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/backbone/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/backbone/__pycache__/drn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/backbone/__pycache__/drn.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/backbone/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/backbone/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/backbone/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/backbone/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/backbone/__pycache__/xception.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/backbone/__pycache__/xception.cpython-37.pyc 
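Looking back at interpretation/calculate_erf.py above: a pixel counts toward the effective receptive field when its aggregated gradient exceeds (1 - 95.45%) of the center pixel's response, and the returned value is the side length of a square with the same pixel count. A minimal, self-contained sketch on a synthetic input (the Gaussian gradient map and its sizes are illustrative assumptions, not repository data):

import numpy as np

def calculate_erf(img, h_ct, w_ct):  # restated from interpretation/calculate_erf.py
    img = np.sum(abs(img), axis=0)
    erf = img > ((1 - 0.9545) * img[h_ct, w_ct])
    return erf.sum() ** 0.5

# synthetic CHW gradient map: a Gaussian bump (sigma = 10 px) centred in a 65x65 grid
h = w = 65
yy, xx = np.mgrid[0:h, 0:w]
grad = np.exp(-((yy - h // 2) ** 2 + (xx - w // 2) ** 2) / (2 * 10.0 ** 2))[None]
# the thresholded region is a disc of radius ~2.5*sigma, so this prints ~44
print(calculate_erf(grad, h // 2, w // 2))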
-------------------------------------------------------------------------------- /model/deeplab/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from model.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 
109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /model/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from model.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class Bottleneck(nn.Module): 9 | expansion = 4 10 | 11 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 12 | super(Bottleneck, self).__init__() 13 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 14 | self.bn1 = BatchNorm(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 16 | dilation=dilation, padding=dilation, bias=False) 17 | self.bn2 = BatchNorm(planes) 18 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 19 | self.bn3 = BatchNorm(planes * 4) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.downsample = downsample 22 | self.stride = stride 23 | self.dilation = dilation 24 | 25 | def forward(self, x): 26 | residual = x 27 | 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv3(out) 37 | out = self.bn3(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class ResNet(nn.Module): 49 | 50 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 51 | self.inplanes = 64 52 | super(ResNet, self).__init__() 53 | blocks = [1, 2, 4] 54 | if output_stride == 16: 55 | strides = [1, 2, 2, 1] 56 | dilations = [1, 1, 1, 2] # original 57 | elif output_stride == 8: 58 | strides = [1, 2, 1, 1] 59 | dilations = [1, 1, 2, 4] # original 60 | # dilations = [1, 1, 1, 1] 61 | # dilations = [1, 1, 2, 1] 62 | else: 63 | raise NotImplementedError 
64 | 65 | # Modules 66 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 67 | self.bn1 = BatchNorm(64) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 70 | 71 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 72 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 73 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 74 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 75 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 76 | self._init_weight() 77 | 78 | if pretrained: 79 | self._load_pretrained_model() 80 | 81 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 82 | downsample = None 83 | if stride != 1 or self.inplanes != planes * block.expansion: 84 | downsample = nn.Sequential( 85 | nn.Conv2d(self.inplanes, planes * block.expansion, 86 | kernel_size=1, stride=stride, bias=False), 87 | BatchNorm(planes * block.expansion), 88 | ) 89 | 90 | layers = [] 91 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 92 | self.inplanes = planes * block.expansion 93 | for i in range(1, blocks): 94 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 95 | 96 | return nn.Sequential(*layers) 97 | 98 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | BatchNorm(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 109 | downsample=downsample, BatchNorm=BatchNorm)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, len(blocks)): 112 | layers.append(block(self.inplanes, planes, stride=1, 113 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, input): 118 | x = self.conv1(input) 119 | x = self.bn1(x) 120 | x = self.relu(x) 121 | x = self.maxpool(x) 122 | 123 | x = self.layer1(x) 124 | low_level_feat = x 125 | x = self.layer2(x) 126 | x = self.layer3(x) 127 | x = self.layer4(x) 128 | return x, low_level_feat 129 | 130 | def _init_weight(self): 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 135 | elif isinstance(m, SynchronizedBatchNorm2d): 136 | m.weight.data.fill_(1) 137 | m.bias.data.zero_() 138 | elif isinstance(m, nn.BatchNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | 142 | def _load_pretrained_model(self): 143 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 144 | model_dict = {} 145 | state_dict = self.state_dict() 146 | for k, v in pretrain_dict.items(): 147 | if k in state_dict: 148 | model_dict[k] = v 149 | state_dict.update(model_dict) 150 | self.load_state_dict(state_dict) 151 | 152 | 153 | def ResNet101(output_stride, in_channels, BatchNorm, pretrained=True): 154 | """Constructs a ResNet-101 model. 155 | Args: 156 | pretrained (bool): If True, returns a model pre-trained on ImageNet 157 | """ 158 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 159 | 160 | # get the parameters of the first convolution 161 | out_channels = model.conv1.out_channels 162 | kernel_size = model.conv1.kernel_size 163 | stride = model.conv1.stride 164 | padding = model.conv1.padding 165 | bias = model.conv1.bias 166 | # replace the number of input channels in the first convolution 167 | model.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 168 | stride=stride, padding=padding, bias=bias is not None) # nn.Conv2d expects a bool here 169 | nn.init.xavier_normal_(model.conv1.weight) 170 | if bias is not None: 171 | model.conv1.bias.data.fill_(0) 172 | 173 | return model 174 | 175 | 176 | if __name__ == "__main__": 177 | import torch 178 | model = ResNet101(output_stride=8, in_channels=3, BatchNorm=nn.BatchNorm2d, pretrained=True) 179 | input = torch.rand(1, 3, 512, 512) 180 | output, low_level_feat = model(input) 181 | print(output.size()) 182 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /model/deeplab/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 23 | BatchNorm(256), 24 | nn.ReLU(), 25 | nn.Dropout(0.5), 26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 27 | BatchNorm(256), 28 | nn.ReLU(), 29 | nn.Dropout(0.1), 30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 31 | self._init_weight() 32 | 33 | 34 | def forward(self, x, low_level_feat): 35 | low_level_feat = self.conv1(low_level_feat) 36 | low_level_feat = self.bn1(low_level_feat) 37 | low_level_feat = self.relu(low_level_feat) 38 | 39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv(x) 42 | 43 | return x 44 | 45 | def _init_weight(self): 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | torch.nn.init.kaiming_normal_(m.weight)
49 | elif isinstance(m, SynchronizedBatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | 56 | def build_decoder(num_classes, backbone, BatchNorm): 57 | return Decoder(num_classes, backbone, BatchNorm) -------------------------------------------------------------------------------- /model/deeplab/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | from .aspp import build_aspp 6 | from .decoder import build_decoder 7 | from .backbone import build_backbone 8 | 9 | 10 | class DeepLab(nn.Module): 11 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21, in_channels=3, 12 | sync_bn=True, freeze_bn=False, pretrained=True): 13 | super(DeepLab, self).__init__() 14 | if backbone == 'drn': 15 | output_stride = 8 16 | 17 | if sync_bn == True: 18 | BatchNorm = SynchronizedBatchNorm2d 19 | else: 20 | BatchNorm = nn.BatchNorm2d 21 | 22 | self.backbone = build_backbone(backbone, output_stride, in_channels, BatchNorm, pretrained=pretrained) 23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 24 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 25 | 26 | if freeze_bn: 27 | self.freeze_bn() 28 | 29 | def forward(self, input): 30 | x, low_level_feat = self.backbone(input) 31 | x = self.aspp(x) 32 | x = self.decoder(x, low_level_feat) 33 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 34 | 35 | return x 36 | 37 | def freeze_bn(self): 38 | for m in self.modules(): 39 | if isinstance(m, SynchronizedBatchNorm2d): 40 | m.eval() 41 | elif isinstance(m, nn.BatchNorm2d): 42 | m.eval() 43 | 44 | def get_1x_lr_params(self): 45 | modules = [self.backbone] 46 | for i in range(len(modules)): 47 | for m in modules[i].named_modules(): 48 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 49 | or isinstance(m[1], nn.BatchNorm2d): 50 | for p in m[1].parameters(): 51 | if p.requires_grad: 52 | yield p 53 | 54 | def get_10x_lr_params(self): 55 | modules = [self.aspp, self.decoder] 56 | for i in range(len(modules)): 57 | for m in modules[i].named_modules(): 58 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 59 | or isinstance(m[1], nn.BatchNorm2d): 60 | for p in m[1].parameters(): 61 | if p.requires_grad: 62 | yield p 63 | 64 | 65 | if __name__ == "__main__": 66 | model = DeepLab(backbone='mobilenet', output_stride=16) 67 | model.eval() 68 | input = torch.rand(1, 3, 513, 513) 69 | output = model(input) 70 | print(output.size()) 71 | 72 | 73 | -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/model/deeplab/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as data parallel will trigger a callback on each module, all slave devices should 59 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 60 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine the message to be passed 63 | back to each slave device. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register a slave device. 85 | Args: 86 | identifier: an identifier, usually the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages are first collected from each device (including the master device), and then 101 | a callback will be invoked to compute the message to be sent back to each device 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belong to the master.'
116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphic, we assign each sub-module a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by 51 | the original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation.
69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /model/deeplab/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol, rtol=rtol),  # pass rtol through; it was previously accepted but ignored 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /model/mfcnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .mfcnn_model import MFCNN 2 | -------------------------------------------------------------------------------- /model/mfcnn/mfcnn_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .mfcnn_parts import * 6 | 7 | 8 | class MFCNN(nn.Module): 9 | def __init__(self, n_channels, n_classes, dropout_p=0.2): 10 | super(MFCNN, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | 14 | # feature map module 15 | self.fmm = FMM(n_channels) 16 | 17 | # multiscale module 18 | self.msm = Multiscale() 19 | 20 | # up-sampling module 21 | self.up1 = Up(1536, 512) 22 | self.up2 = Up(768, 256, bn=True) 23 | self.up3 = Up(384, 128, bn=True) 24 | 25 | # out 26 | self.dp = nn.Dropout(p=dropout_p) 27 | self.outc = OutConv(128, n_classes) 28 | 29 | def forward(self, x): 30 | # feature map module 31 | x1, x2, x3 = self.fmm(x) 32 | 33 | # multiscale module 34 | x4 = self.msm(x3) 35 | 36 | # up-sampling module 37 | # resolving image size inconsistencies 38 | diffH, diffW = x3.size()[2] - x4.size()[2], x3.size()[3] - x4.size()[3] 39 | x4 = F.pad(x4, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2]) 40 | x = self.up1(torch.cat((x3, x4), dim=1)) 41 | 42 | # resolving image size inconsistencies 43 | diffH, diffW = x2.size()[2] - x.size()[2], x2.size()[3] - x.size()[3] 44 | x = F.pad(x, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 
2]) 45 | x = self.up2(torch.cat((x2, x), dim=1)) 46 | 47 | # resolving image size inconsistencies 48 | diffH, diffW = x1.size()[2] - x.size()[2], x1.size()[3] - x.size()[3] 49 | x = F.pad(x, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2]) 50 | x = self.up3(torch.cat((x1, x), dim=1)) 51 | 52 | # out 53 | x = self.dp(x) 54 | return self.outc(x) 55 | -------------------------------------------------------------------------------- /model/mfcnn/mfcnn_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the MFCNN model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super(DoubleConv, self).__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class FMM(nn.Module): 29 | """Feature map module""" 30 | 31 | def __init__(self, in_channels): 32 | super(FMM, self).__init__() 33 | self.stage1 = nn.Sequential( 34 | nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1), 35 | nn.ReLU(inplace=True), 36 | DoubleConv(64, 128, 96) 37 | ) 38 | self.stage2 = nn.Sequential( 39 | nn.MaxPool2d(2), 40 | DoubleConv(128, 256, 192) 41 | ) 42 | self.stage3 = nn.Sequential( 43 | nn.MaxPool2d(2), 44 | DoubleConv(256, 512, 256) 45 | ) 46 | 47 | def forward(self, x): 48 | x1 = self.stage1(x) 49 | x2 = self.stage2(x1) 50 | x3 = self.stage3(x2) 51 | 52 | return x1, x2, x3 53 | 54 | 55 | class Multiscale(nn.Module): 56 | """Multiscale module""" 57 | 58 | def __init__(self): 59 | super(Multiscale, self).__init__() 60 | 61 | self.scale1 = ScaleBlock(16) 62 | self.scale2 = ScaleBlock(8) 63 | self.scale3 = ScaleBlock(4) 64 | self.scale4 = ScaleBlock(2) 65 | 66 | def forward(self, x): 67 | x1 = self.scale1(x) 68 | x2 = self.scale2(x) 69 | x3 = self.scale3(x) 70 | x4 = self.scale4(x) 71 | 72 | # resolving image size inconsistencies 73 | maxH = max([x1.size()[2], x2.size()[2], x3.size()[2], x4.size()[2]]) 74 | maxW = max([x1.size()[3], x2.size()[3], x3.size()[3], x4.size()[3]]) 75 | x1 = self._padding(x1, maxH, maxW) 76 | x2 = self._padding(x2, maxH, maxW) 77 | x3 = self._padding(x3, maxH, maxW) 78 | x4 = self._padding(x4, maxH, maxW) 79 | return torch.cat((x1, x2, x3, x4), dim=1) 80 | 81 | def _padding(self, x, maxH, maxW): 82 | diffH, diffW = maxH - x.size()[2], maxW - x.size()[3] 83 | return F.pad(x, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2]) 84 | 85 | 86 | class ScaleBlock(nn.Module): 87 | """Used in multiscale module""" 88 | 89 | def __init__(self, pool_size): 90 | super(ScaleBlock, self).__init__() 91 | self.scale = nn.Sequential( 92 | nn.AvgPool2d(pool_size), 93 | nn.Conv2d(512, 256, kernel_size=1), 94 | nn.ReLU(inplace=True), 95 | nn.Upsample(scale_factor=pool_size, mode='bilinear', align_corners=True), 96 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 97 | nn.ReLU(inplace=True) 98 | ) 99 | 100 | def forward(self, x): 101 | return self.scale(x) 102 | 103 | 104 | class Up(nn.Module): 105 | """Convolution then upscaling"""
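    # note: unlike the classic U-Net decoder, this block receives the already-concatenated
    # feature map, reduces channels with a 3x3 convolution first, and only then upsamples
    # bilinearly by a factor of 2; BatchNorm is optional via `bn=True`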
106 | 107 | def __init__(self, in_channels, out_channels, bn=False): 108 | super().__init__() 109 | if bn: 110 | self.conv_up = nn.Sequential( 111 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 112 | nn.BatchNorm2d(out_channels), 113 | nn.ReLU(inplace=True), 114 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 115 | ) 116 | else: 117 | self.conv_up = nn.Sequential( 118 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 119 | nn.ReLU(inplace=True), 120 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 121 | ) 122 | 123 | def forward(self, x): 124 | return self.conv_up(x) 125 | 126 | 127 | class OutConv(nn.Module): 128 | def __init__(self, in_channels, out_channels): 129 | super(OutConv, self).__init__() 130 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 131 | 132 | def forward(self, x): 133 | return self.conv(x) 134 | -------------------------------------------------------------------------------- /model/mscff/__init__.py: -------------------------------------------------------------------------------- 1 | from .mscff_model import MSCFF 2 | -------------------------------------------------------------------------------- /model/mscff/mscff_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .mscff_parts import * 6 | 7 | 8 | class MSCFF(nn.Module): 9 | def __init__(self, n_channels, n_classes): 10 | super(MSCFF, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | 14 | # encoder 15 | self.inc = CBRR(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | self.encoder_DCBRR1 = CBRR(512, 512, 2) 20 | self.encoder_DCBRR2 = CBRR(512, 512, 4) 21 | 22 | # decoder 23 | self.decoder_DCBRR1 = CBRR(512, 512, 4) 24 | self.decoder_DCBRR2 = CBRR(512, 512, 2) 25 | self.decoder_CBRR = CBRR(512, 512) 26 | self.up1 = Up(512, 256) 27 | self.up2 = Up(256, 128) 28 | self.up3 = Up(128, 64) 29 | 30 | # multi-scale feature fusion 31 | self.ms_block1 = MultiScaleBlock(512, n_classes, scale_factor=8) 32 | self.ms_block2 = MultiScaleBlock(512, n_classes, scale_factor=8) 33 | self.ms_block3 = MultiScaleBlock(512, n_classes, scale_factor=8) 34 | self.ms_block4 = MultiScaleBlock(256, n_classes, scale_factor=4) 35 | self.ms_block5 = MultiScaleBlock(128, n_classes, scale_factor=2) 36 | self.ms_block6 = MultiScaleBlock(64, n_classes) 37 | 38 | # out 39 | self.outc = OutConv(6 * n_classes, n_classes) 40 | 41 | def forward(self, x): 42 | # encoder 43 | x1 = self.inc(x) 44 | x2 = self.down1(x1) 45 | x3 = self.down2(x2) 46 | x4 = self.down3(x3) 47 | x5 = self.encoder_DCBRR1(x4) 48 | x6 = self.encoder_DCBRR2(x5) 49 | 50 | # decoder 51 | x7 = self.decoder_DCBRR1(x6) + x6 52 | x8 = self.decoder_DCBRR2(x7) + x5 53 | x9 = self.decoder_CBRR(x8) + x4 54 | 55 | # resolving image size inconsistencies 56 | x10 = self.up1(x9) 57 | diffH, diffW = x3.size()[2] - x10.size()[2], x3.size()[3] - x10.size()[3] 58 | x10 = F.pad(x10, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2]) + x3 59 | 60 | # resolving image size inconsistencies 61 | x11 = self.up2(x10) 62 | diffH, diffW = x2.size()[2] - x11.size()[2], x2.size()[3] - x11.size()[3] 63 | x11 = F.pad(x11, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2]) + x2 64 | 65 | # resolving image size inconsistencies 66 | x12 = 
self.up3(x11) 67 | diffH, diffW = x1.size()[2] - x12.size()[2], x1.size()[3] - x12.size()[3] 68 | x12 = F.pad(x12, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2]) + x1 69 | 70 | # multi-scale feature fusion 71 | x7, x8, x9, x10, x11, x12 = self._padding(self.ms_block1(x7), x.size()[2], x.size()[3]), \ 72 | self._padding(self.ms_block2(x8), x.size()[2], x.size()[3]), \ 73 | self._padding(self.ms_block3(x9), x.size()[2], x.size()[3]), \ 74 | self._padding(self.ms_block4(x10), x.size()[2], x.size()[3]), \ 75 | self._padding(self.ms_block5(x11), x.size()[2], x.size()[3]), \ 76 | self._padding(self.ms_block6(x12), x.size()[2], x.size()[3]) 77 | x_ms = torch.cat((x7, x8, x9, x10, x11, x12), dim=1) 78 | 79 | logits = self.outc(x_ms) 80 | return logits 81 | 82 | def _padding(self, x, maxH, maxW): 83 | diffH, diffW = maxH - x.size()[2], maxW - x.size()[3] 84 | return F.pad(x, [diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2]) 85 | -------------------------------------------------------------------------------- /model/mscff/mscff_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the MSCFF model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class CBRR(nn.Module): 9 | """Convolution/Dilated Convolution, BN, ReLU with residual unit""" 10 | 11 | def __init__(self, in_channels, out_channels, dilation=1): 12 | super(CBRR, self).__init__() 13 | 14 | self.block1 = nn.Sequential( 15 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(inplace=True), 18 | ) 19 | self.block2 = nn.Sequential( 20 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation), 24 | nn.BatchNorm2d(out_channels), 25 | nn.ReLU(inplace=True), 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.block1(x) 30 | return self.block2(x) + x 31 | 32 | 33 | class Down(nn.Module): 34 | """Downscaling with maxpool then CBRR""" 35 | 36 | def __init__(self, in_channels, out_channels): 37 | super().__init__() 38 | self.maxpool_conv = nn.Sequential( 39 | nn.MaxPool2d(2), 40 | CBRR(in_channels, out_channels) 41 | ) 42 | 43 | def forward(self, x): 44 | return self.maxpool_conv(x) 45 | 46 | 47 | class Up(nn.Module): 48 | """Upscaling then CBRR""" 49 | 50 | def __init__(self, in_channels, out_channels): 51 | super().__init__() 52 | 53 | self.up_conv = nn.Sequential( 54 | nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1), 55 | CBRR(in_channels, out_channels) 56 | ) 57 | 58 | def forward(self, x): 59 | return self.up_conv(x) 60 | 61 | 62 | class MultiScaleBlock(nn.Module): 63 | """Block in multi-scale feature fusion""" 64 | 65 | def __init__(self, in_channels, out_channels, scale_factor=None): 66 | super(MultiScaleBlock, self).__init__() 67 | 68 | if scale_factor: 69 | self.ms_block = nn.Sequential( 70 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 71 | nn.BatchNorm2d(out_channels), 72 | nn.ReLU(inplace=True), 73 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True) 74 | ) 75 | else: 76 | self.ms_block = nn.Sequential( 77 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 78 | nn.BatchNorm2d(out_channels),
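                # no Upsample in this branch: scale_factor=None means the incoming
                # features are already at the reference (full) resolution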
79 | nn.ReLU(inplace=True) 80 | ) 81 | 82 | def forward(self, x): 83 | return self.ms_block(x) 84 | 85 | 86 | class OutConv(nn.Module): 87 | def __init__(self, in_channels, out_channels): 88 | super(OutConv, self).__init__() 89 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 90 | 91 | def forward(self, x): 92 | return self.conv(x) 93 | -------------------------------------------------------------------------------- /model/munet/__init__.py: -------------------------------------------------------------------------------- 1 | from .munet_model import MUNet 2 | -------------------------------------------------------------------------------- /model/munet/munet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network. """ 2 | 3 | from .munet_parts import * 4 | 5 | 6 | class MUNet(nn.Module): 7 | def __init__(self, n_channels, n_classes): 8 | super(MUNet, self).__init__() 9 | self.n_channels = n_channels 10 | self.n_classes = n_classes 11 | 12 | self.inc = DoubleConv(n_channels, 32) 13 | self.down1 = Down(32, 64) 14 | self.down2 = Down(64, 128) 15 | self.down3 = Down(128, 256) 16 | self.down4 = Down(256, 512) 17 | self.up1 = Up(512, 256) 18 | self.up2 = Up(256, 128) 19 | self.up3 = Up(128, 64) 20 | self.up4 = Up(64, 32) 21 | self.outc = OutConv(32, n_classes) 22 | 23 | def forward(self, x): 24 | x1 = self.inc(x) 25 | x2 = self.down1(x1) 26 | x3 = self.down2(x2) 27 | x4 = self.down3(x3) 28 | x = self.down4(x4) 29 | x = self.up1(x, x4) 30 | x = self.up2(x, x3) 31 | x = self.up3(x, x2) 32 | x = self.up4(x, x1) 33 | logits = self.outc(x) 34 | return logits 35 | 36 | -------------------------------------------------------------------------------- /model/munet/munet_parts.py: -------------------------------------------------------------------------------- 1 | """Parts of the U-Net model""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(Convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True), 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpooling then double convolution""" 30 | 31 | def __init__(self, in_channels, out_channels, mid_channels=None): 32 | super(Down, self).__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels, mid_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double convolution""" 44 | 45 | def __init__(self, in_channels, out_channels): 46 | super(Up, self).__init__() 47 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 48 | self.conv = DoubleConv(in_channels, out_channels) 49 | 50 | def forward(self, x1, x2): 51 | x1 = self.up(x1) 52 | 53 | # # padding, make the two images the same size 54 | # diffh = x2.size()[2] - x1.size()[2] 55 | # 
diffw = x2.size()[3] - x1.size()[3] 56 | # x1 = F.pad(x1, [diffw // 2, diffw - diffw // 2, diffh // 2, diffh - diffh // 2]) 57 | # if you have padding issues, see 58 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 59 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 60 | 61 | x = torch.cat((x2, x1), dim=1) 62 | return self.conv(x) 63 | 64 | 65 | class OutConv(nn.Module): 66 | def __init__(self, in_channels, out_channels): 67 | super(OutConv, self).__init__() 68 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 69 | 70 | def forward(self, x): 71 | return self.conv(x) 72 | 73 | -------------------------------------------------------------------------------- /model/tlnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .tlnet_model import TLNet 2 | -------------------------------------------------------------------------------- /model/tlnet/tlnet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .tlnet_parts import * 6 | 7 | 8 | class TLNet(nn.Module): 9 | def __init__(self, n_channels, n_classes, dilations=[1, 1, 1, 1]): 10 | super(TLNet, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | 14 | self.inc = DoubleConv(n_channels, 32) 15 | self.down1 = Down(32, 64, dilations=dilations[0:2]) 16 | self.down2 = Down(64, 128, dilations=dilations[2:]) 17 | self.up1 = Up(128, 64) 18 | self.up2 = Up(64, 32) 19 | self.outc = OutConv(32, n_classes) 20 | 21 | def forward(self, x): 22 | x1 = self.inc(x) 23 | x2 = self.down1(x1) 24 | x = self.down2(x2) 25 | x = self.up1(x, x2) 26 | x = self.up2(x, x1) 27 | logits = self.outc(x) 28 | return logits 29 | -------------------------------------------------------------------------------- /model/tlnet/tlnet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the TL-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class SeparableConv2d(nn.Module): 9 | """Depthwise separable convolution, equivalent to Keras' SeparableConv2D.""" 10 | def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, padding=1): 11 | super(SeparableConv2d, self).__init__() 12 | 13 | self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, dilation=dilation, 14 | padding=padding, stride=1, groups=in_channels) 15 | self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) 16 | 17 | def forward(self, x): 18 | x = self.depthwise(x) 19 | return self.pointwise(x) 20 | 21 | 22 | class DoubleConv(nn.Module): 23 | """(convolution => [BN] => ReLU) * 2""" 24 | 25 | def __init__(self, in_channels, out_channels, mid_channels=None, dilations=[1, 1]): 26 | super().__init__() 27 | if not mid_channels: 28 | mid_channels = out_channels 29 | self.double_conv = nn.Sequential( 30 | SeparableConv2d(in_channels, mid_channels, kernel_size=3, dilation=dilations[0], padding=dilations[0]), 31 | nn.BatchNorm2d(mid_channels), 32 | nn.ReLU(inplace=True), 33 | SeparableConv2d(mid_channels, out_channels, kernel_size=3, dilation=dilations[1], padding=dilations[1]), 34 | nn.BatchNorm2d(out_channels), 35 | nn.ReLU(inplace=True) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.double_conv(x) 40 | 41 | 42 | 
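# Why separable convolutions: a standard 3x3 convolution uses in_ch * out_ch * 9
# weights, whereas the depthwise + pointwise pair above uses in_ch * 9 + in_ch * out_ch.
# For in_ch=64 and out_ch=128 that is 73,728 versus 8,768 weights, roughly an 8.4x
# reduction (biases ignored). A quick, hypothetical shape check (not part of the
# original file):
#
#     conv = SeparableConv2d(64, 128, kernel_size=3, dilation=2, padding=2)
#     out = conv(torch.randn(1, 64, 32, 32))
#     assert out.shape == (1, 128, 32, 32)  # padding=dilation preserves H and W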
class Down(nn.Module): 43 | """Downscaling with maxpool then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, dilations=[1, 1]): 46 | super().__init__() 47 | self.maxpool_conv = nn.Sequential( 48 | nn.MaxPool2d(2), 49 | DoubleConv(in_channels, out_channels, dilations=dilations) 50 | ) 51 | 52 | def forward(self, x): 53 | return self.maxpool_conv(x) 54 | 55 | 56 | class Up(nn.Module): 57 | """Upscaling then double conv""" 58 | 59 | def __init__(self, in_channels, out_channels): 60 | super().__init__() 61 | 62 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 63 | self.conv = DoubleConv(in_channels, out_channels) 64 | 65 | def forward(self, x1, x2): 66 | x1 = self.up(x1) 67 | 68 | # # padding, make the two images the same size 69 | # diffh = x2.size()[2] - x1.size()[2] 70 | # diffw = x2.size()[3] - x1.size()[3] 71 | # x1 = F.pad(x1, [diffw // 2, diffw - diffw // 2, diffh // 2, diffh - diffh // 2]) 72 | # if you have padding issues, see 73 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 74 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 75 | 76 | x = torch.cat([x2, x1], dim=1) 77 | return self.conv(x) 78 | 79 | 80 | class OutConv(nn.Module): 81 | def __init__(self, in_channels, out_channels): 82 | super(OutConv, self).__init__() 83 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 84 | 85 | def forward(self, x): 86 | return self.conv(x) 87 | -------------------------------------------------------------------------------- /model/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNet 2 | -------------------------------------------------------------------------------- /model/unet/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNet(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True): 10 | super(UNet, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | self.down4 = Down(512, 1024 // factor) 21 | self.up1 = Up(1024, 512 // factor, bilinear) 22 | self.up2 = Up(512, 256 // factor, bilinear) 23 | self.up3 = Up(256, 128 // factor, bilinear) 24 | self.up4 = Up(128, 64, bilinear) 25 | self.outc = OutConv(64, n_classes) 26 | 27 | def forward(self, x): 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | x3 = self.down2(x2) 31 | x4 = self.down3(x3) 32 | x5 = self.down4(x4) 33 | x = self.up1(x5, x4) 34 | x = self.up2(x, x3) 35 | x = self.up3(x, x2) 36 | x = self.up4(x, x1) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unet/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, 
mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # # input is CHW 60 | # diffY = x2.size()[2] - x1.size()[2] 61 | # diffX = x2.size()[3] - x1.size()[3] 62 | 63 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | # diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | -------------------------------------------------------------------------------- /model/unet_1/__init__.py: -------------------------------------------------------------------------------- 1 | """Reduce the number of pooling steps""" 2 | from .unet_model import UNet_1 -------------------------------------------------------------------------------- /model/unet_1/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNet_1(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True): 10 | super(UNet_1, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | # self.down2 = Down(128, 256) 18 | # self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | # self.down4 = Down(512, 1024 // factor) 21 | # self.up1 = Up(1024, 512 // factor, bilinear) 22 | # self.up2 = Up(512, 256 // factor, bilinear) 23 | # self.up3 = Up(256, 128 // factor, bilinear) 24 | self.up4 = Up(128, 64, bilinear) 25 | self.outc = OutConv(64, 
n_classes) 26 | 27 | def forward(self, x): 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | # x3 = self.down2(x2) 31 | # x4 = self.down3(x3) 32 | # x5 = self.down4(x4) 33 | # x = self.up1(x5, x4) 34 | # x = self.up2(x, x3) 35 | # x = self.up3(x3, x2) 36 | x = self.up4(x2, x1) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unet_1/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # # input is CHW 60 | # diffY = x2.size()[2] - x1.size()[2] 61 | # diffX = x2.size()[3] - x1.size()[3] 62 | 63 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | # diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | -------------------------------------------------------------------------------- /model/unet_2/__init__.py: -------------------------------------------------------------------------------- 1 | """Reduce the number of pooling steps""" 2 | from .unet_model import UNet_2 3 | -------------------------------------------------------------------------------- /model/unet_2/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form 
the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNet_2(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True): 10 | super(UNet_2, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | # self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | # self.down4 = Down(512, 1024 // factor) 21 | # self.up1 = Up(1024, 512 // factor, bilinear) 22 | # self.up2 = Up(512, 256 // factor, bilinear) 23 | self.up3 = Up(256, 128 // factor, bilinear) 24 | self.up4 = Up(128, 64, bilinear) 25 | self.outc = OutConv(64, n_classes) 26 | 27 | def forward(self, x): 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | x3 = self.down2(x2) 31 | # x4 = self.down3(x3) 32 | # x5 = self.down4(x4) 33 | # x = self.up1(x5, x4) 34 | # x = self.up2(x, x3) 35 | x = self.up3(x3, x2) 36 | x = self.up4(x, x1) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unet_2/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # # input is CHW 60 | # diffY = x2.size()[2] - x1.size()[2] 61 | # diffX = x2.size()[3] - x1.size()[3] 62 | 63 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | # diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return 
self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | -------------------------------------------------------------------------------- /model/unet_3/__init__.py: -------------------------------------------------------------------------------- 1 | """Reduce the number of pooling steps""" 2 | from .unet_model import UNet_3 3 | -------------------------------------------------------------------------------- /model/unet_3/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNet_3(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True): 10 | super(UNet_3, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | # self.down4 = Down(512, 1024 // factor) 21 | # self.up1 = Up(1024, 512 // factor, bilinear) 22 | self.up2 = Up(512, 256 // factor, bilinear) 23 | self.up3 = Up(256, 128 // factor, bilinear) 24 | self.up4 = Up(128, 64, bilinear) 25 | self.outc = OutConv(64, n_classes) 26 | 27 | def forward(self, x): 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | x3 = self.down2(x2) 31 | x4 = self.down3(x3) 32 | # x5 = self.down4(x4) 33 | # x = self.up1(x5, x4) 34 | x = self.up2(x4, x3) 35 | x = self.up3(x, x2) 36 | x = self.up4(x, x1) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unet_3/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, 
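                                   # the third argument below sets mid_channels to in_channels // 2:
                                   # bilinear upsampling does not reduce channels, so the first 3x3
                                   # conv of the DoubleConv performs the reduction instead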
out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # # input is CHW 60 | # diffY = x2.size()[2] - x1.size()[2] 61 | # diffX = x2.size()[3] - x1.size()[3] 62 | 63 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | # diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | -------------------------------------------------------------------------------- /model/unet_dilation/__init__.py: -------------------------------------------------------------------------------- 1 | """Replace the deepest pooling step with dilated convolutions""" 2 | from .unet_model import UNet_dil 3 | -------------------------------------------------------------------------------- /model/unet_dilation/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNet_dil(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True, maxpool=True, dilation=1): 10 | super(UNet_dil, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | self.down4 = Down(512, 1024 // factor, maxpool=maxpool, dilation=dilation) 21 | self.up1 = Up(1024, 512 // factor, bilinear, uplayer=False) 22 | self.up2 = Up(512, 256 // factor, bilinear) 23 | self.up3 = Up(256, 128 // factor, bilinear) 24 | self.up4 = Up(128, 64, bilinear) 25 | self.outc = OutConv(64, n_classes) 26 | 27 | def forward(self, x): 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | x3 = self.down2(x2) 31 | x4 = self.down3(x3) 32 | x5 = self.down4(x4) 33 | x = self.up1(x5, x4) 34 | x = self.up2(x, x3) 35 | x = self.up3(x, x2) 36 | x = self.up4(x, x1) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unet_dilation/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, dilation=1, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, dilation=dilation, padding=dilation), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, 
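                      # with stride 1, the padding=dilation setting below keeps H and W
                      # unchanged for a 3x3 kernel at any dilation rate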
kernel_size=3, dilation=dilation, padding=dilation), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """ Dilation conv or downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels, maxpool=True, dilation=1): 32 | super().__init__() 33 | if maxpool: 34 | self.maxpool_conv = nn.Sequential( 35 | nn.MaxPool2d(2), 36 | DoubleConv(in_channels, out_channels) 37 | ) 38 | else: 39 | self.maxpool_conv = nn.Sequential( 40 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, dilation=dilation, padding=dilation), 41 | DoubleConv(in_channels, out_channels, dilation=dilation) 42 | ) 43 | 44 | def forward(self, x): 45 | return self.maxpool_conv(x) 46 | 47 | 48 | class Up(nn.Module): 49 | """Upscaling then double conv""" 50 | 51 | def __init__(self, in_channels, out_channels, bilinear=True, uplayer=True): 52 | super().__init__() 53 | self.uplayer = uplayer 54 | if uplayer: 55 | # if bilinear, use the normal convolutions to reduce the number of channels 56 | if bilinear: 57 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 58 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 59 | else: 60 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 61 | self.conv = DoubleConv(in_channels, out_channels) 62 | else: 63 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=1, stride=1) 64 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 65 | 66 | def forward(self, x1, x2): 67 | x1 = self.up(x1) 68 | # input is CHW 69 | # diffY = x2.size()[2] - x1.size()[2] 70 | # diffX = x2.size()[3] - x1.size()[3] 71 | 72 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 73 | # diffY // 2, diffY - diffY // 2]) 74 | # if you have padding issues, see 75 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 76 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 77 | x = torch.cat([x2, x1], dim=1) 78 | return self.conv(x) 79 | 80 | 81 | class OutConv(nn.Module): 82 | def __init__(self, in_channels, out_channels): 83 | super(OutConv, self).__init__() 84 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 85 | 86 | def forward(self, x): 87 | return self.conv(x) 88 | -------------------------------------------------------------------------------- /model/unets1/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNetS1 2 | -------------------------------------------------------------------------------- /model/unets1/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNetS1(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True): 10 | super(UNetS1, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | self.down4 = Down(512, 1024 // factor) 21 | self.up1 = Up(1024, 512 // factor, bilinear) 22 | 
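        # from here on the decoder uses UpNoConcat, i.e. it upsamples without concatenating
        # encoder features: the three shallowest skip connections are removed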
self.up2 = UpNoConcat(512, 256 // factor, bilinear) 23 | self.up3 = UpNoConcat(256, 128 // factor, bilinear) 24 | self.up4 = UpNoConcat(128, 64, bilinear) 25 | self.outc = OutConv(64, n_classes) 26 | 27 | def forward(self, x): 28 | x = self.inc(x) 29 | x = self.down1(x) 30 | x = self.down2(x) 31 | x4 = self.down3(x) 32 | x = self.down4(x4) 33 | x = self.up1(x, x4) 34 | x = self.up2(x) 35 | x = self.up3(x) 36 | x = self.up4(x) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unets1/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # # input is CHW 60 | # diffY = x2.size()[2] - x1.size()[2] 61 | # diffX = x2.size()[3] - x1.size()[3] 62 | 63 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | # diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class UpNoConcat(nn.Module): 73 | """Upscaling then double conv""" 74 | 75 | def __init__(self, in_channels, out_channels, bilinear=True): 76 | super().__init__() 77 | 78 | # if bilinear, use the normal convolutions to reduce the number of channels 79 | if bilinear: 80 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 81 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 82 | else: 83 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 84 | self.conv = DoubleConv(in_channels // 2, out_channels) 85 | 86 | def 
forward(self, x): 87 | x = self.up(x) 88 | return self.conv(x) 89 | 90 | 91 | class OutConv(nn.Module): 92 | def __init__(self, in_channels, out_channels): 93 | super(OutConv, self).__init__() 94 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 95 | 96 | def forward(self, x): 97 | return self.conv(x) 98 | -------------------------------------------------------------------------------- /model/unets2/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNetS2 2 | -------------------------------------------------------------------------------- /model/unets2/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNetS2(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True): 10 | super(UNetS2, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | self.down4 = Down(512, 1024 // factor) 21 | self.up1 = Up(1024, 512 // factor, bilinear) 22 | self.up2 = Up(512, 256 // factor, bilinear) 23 | self.up3 = UpNoConcat(256, 128 // factor, bilinear) 24 | self.up4 = UpNoConcat(128, 64, bilinear) 25 | self.outc = OutConv(64, n_classes) 26 | 27 | def forward(self, x): 28 | x = self.inc(x) 29 | x = self.down1(x) 30 | x3 = self.down2(x) 31 | x4 = self.down3(x3) 32 | x5 = self.down4(x4) 33 | x = self.up1(x5, x4) 34 | x = self.up2(x, x3) 35 | x = self.up3(x) 36 | x = self.up4(x) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unets2/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, 
out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is CHW 60 | # diffY = x2.size()[2] - x1.size()[2] 61 | # diffX = x2.size()[3] - x1.size()[3] 62 | 63 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | # diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class UpNoConcat(nn.Module): 73 | """Upscaling then double conv""" 74 | 75 | def __init__(self, in_channels, out_channels, bilinear=True): 76 | super().__init__() 77 | 78 | # if bilinear, use the normal convolutions to reduce the number of channels 79 | if bilinear: 80 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 81 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 82 | else: 83 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 84 | self.conv = DoubleConv(in_channels // 2, out_channels) 85 | 86 | def forward(self, x): 87 | x = self.up(x) 88 | return self.conv(x) 89 | 90 | 91 | class OutConv(nn.Module): 92 | def __init__(self, in_channels, out_channels): 93 | super(OutConv, self).__init__() 94 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 95 | 96 | def forward(self, x): 97 | return self.conv(x) 98 | -------------------------------------------------------------------------------- /model/unets3/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNetS3 2 | -------------------------------------------------------------------------------- /model/unets3/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNetS3(nn.Module): 9 | def __init__(self, n_channels, n_classes, bilinear=True): 10 | super(UNetS3, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(n_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | self.down4 = Down(512, 1024 // factor) 21 | self.up1 = Up(1024, 512 // factor, bilinear) 22 | self.up2 = Up(512, 256 // factor, bilinear) 23 | self.up3 = Up(256, 128 // factor, bilinear) 24 | self.up4 = UpNoConcat(128, 64, bilinear) 25 | self.outc = OutConv(64, n_classes) 26 | 27 | def forward(self, x): 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | x3 = self.down2(x2) 31 | x4 = self.down3(x3) 32 | x5 = self.down4(x4) 33 | x = self.up1(x5, x4) 34 | x = self.up2(x, x3) 35 | x = self.up3(x, x2) 36 | x = self.up4(x) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /model/unets3/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import 
torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is CHW 60 | # diffY = x2.size()[2] - x1.size()[2] 61 | # diffX = x2.size()[3] - x1.size()[3] 62 | 63 | # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | # diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class UpNoConcat(nn.Module): 73 | """Upscaling then double conv""" 74 | 75 | def __init__(self, in_channels, out_channels, bilinear=True): 76 | super().__init__() 77 | 78 | # if bilinear, use the normal convolutions to reduce the number of channels 79 | if bilinear: 80 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 81 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 82 | else: 83 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 84 | self.conv = DoubleConv(in_channels // 2, out_channels) 85 | 86 | def forward(self, x): 87 | x = self.up(x) 88 | return self.conv(x) 89 | 90 | 91 | class OutConv(nn.Module): 92 | def __init__(self, in_channels, out_channels): 93 | super(OutConv, self).__init__() 94 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 95 | 96 | def forward(self, x): 97 | return self.conv(x) 98 | 
--------------------------------------------------------------------------------
/utils/__pycache__/f_boundary.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/f_boundary.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/img_saver.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/img_saver.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/net_convert.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/net_convert.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/saver.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/saver.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/summaries.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/summaries.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tracker.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LK-Peng/CNN-based-Cloud-Detection-Methods/1393a6886e62f1ed5a612d57c5a725c763a6b2cc/utils/__pycache__/tracker.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/calculate_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | 
5 | 
6 | def calculate_weigths_labels(train_list_dir, dataset, dataloader, num_classes):
7 |     # Accumulate per-class pixel counts over the whole dataloader
8 |     z = np.zeros((num_classes,))
9 |     # Initialize tqdm
10 |     tqdm_batch = tqdm(dataloader)
11 |     print('Calculating class weights')
12 |     for sample in tqdm_batch:
13 |         y = sample['label']
14 |         y = y.detach().cpu().numpy()
15 |         mask = (y >= 0) & (y < num_classes)
16 |         labels = y[mask].astype(np.uint8)
17 |         count_l = np.bincount(labels, minlength=num_classes)
18 |         z += count_l
19 |     tqdm_batch.close()
20 |     total_frequency = np.sum(z)
21 |     class_weights = []
22 |     for frequency in z:
23 |         class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))  # log-frequency weighting: rare classes get weights up to 1/ln(1.02), about 50.5
24 |         class_weights.append(class_weight)
25 |     ret = np.array(class_weights)
26 |     classes_weights_path = os.path.join(train_list_dir, dataset+'_classes_weights.npy')
27 |     np.save(classes_weights_path, ret)
28 | 
29 |     return ret
--------------------------------------------------------------------------------
/utils/f_boundary.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | # Code adapted from:
5 | # https://github.com/fperazzi/davis/blob/master/python/lib/davis/measures/f_boundary.py
6 | #
7 | # Source License
8 | #
9 | # BSD 3-Clause License
10 | #
11 | # Copyright (c) 2017,
12 | # All rights reserved.
13 | #
14 | # Redistribution and use in source and binary forms, with or without
15 | # modification, are permitted provided that the following conditions are met:
16 | #
17 | # * Redistributions of source code must retain the above copyright notice, this
18 | #   list of conditions and the following disclaimer.
19 | #
20 | # * Redistributions in binary form must reproduce the above copyright notice,
21 | #   this list of conditions and the following disclaimer in the documentation
22 | #   and/or other materials provided with the distribution.
23 | #
24 | # * Neither the name of the copyright holder nor the names of its
25 | #   contributors may be used to endorse or promote products derived from
26 | #   this software without specific prior written permission.
27 | #
28 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
31 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
32 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
33 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
34 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
35 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
36 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
37 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38 | ##############################################################################
39 | #
40 | # Based on:
41 | # ----------------------------------------------------------------------------
42 | # A Benchmark Dataset and Evaluation Methodology for Video Object Segmentation
43 | # Copyright (c) 2016 Federico Perazzi
44 | # Licensed under the BSD License [see LICENSE for details]
45 | # Written by Federico Perazzi
46 | # ----------------------------------------------------------------------------
47 | """
48 | 
49 | import math
50 | import numpy as np
51 | from multiprocessing import Pool
52 | from tqdm import tqdm
53 | 
54 | """ Utilities for computing, reading and saving benchmark evaluation."""
55 | 
56 | 
57 | def eval_mask_boundary(seg_mask, gt_mask, num_classes, p=None, num_proc=10, bound_th=0.008):
58 |     """
59 |     Compute per-class boundary statistics for a batch of segmentation masks
60 |     Arguments:
61 |         seg_mask (ndarray): segmentation mask prediction
62 |         gt_mask (ndarray): segmentation mask ground truth
63 |         num_classes (int): number of classes
64 |         p: multiprocessing Pool; if p is not None, num_proc is ignored
65 |         num_proc: number of processes
66 |     Returns:
67 |         confusion_matrix_pc (ndarray): num_classes x 4 array of boundary counts, one row per class:
68 |         (matched fg boundary pixels, total fg boundary pixels, matched gt boundary pixels, total gt boundary pixels)
69 |     """
70 |     if num_proc > 1 and not p:
71 |         p = Pool(processes=num_proc)
72 |     batch_size = seg_mask.shape[0]
73 | 
74 |     # Fpc = list(np.zeros(num_classes))
75 |     # Ppc = list(np.zeros(num_classes))
76 |     # Rpc = list(np.zeros(num_classes))
77 |     confusion_matrix_pc = np.zeros((num_classes, 4))
78 |     # for class_id in tqdm(range(num_classes)):
79 |     for class_id in range(num_classes):
80 |         args = [((seg_mask[i] == class_id).astype(np.uint8),
81 |                  (gt_mask[i] == class_id).astype(np.uint8),
82 |                  gt_mask[i] == 255,
83 |                  bound_th)
84 |                 for i in range(batch_size)]
85 |         if p:
86 |             temp = p.map(db_eval_boundary_wrapper, args)
87 |         else:
88 |             temp = [db_eval_boundary_wrapper(args[i]) for i in range(batch_size)]
89 |         temp = np.array(temp)
90 | 
91 |         # # F score
92 |         # Fs = temp[:, 0]
93 |         # Fs[np.isnan(Fs)] = 0  # if valid batch?
94 |         # Fpc[class_id] = Fs  # f-score of every batch
95 | 
96 |         # # precision
97 |         # Ps = temp[:, 1]
98 |         # Ps[np.isnan(Ps)] = 0
99 |         # Ppc[class_id] = Ps
100 | 
101 |         # # recall
102 |         # Rs = temp[:, 2]
103 |         # Rs[np.isnan(Rs)] = 0
104 |         # Rpc[class_id] = Rs
105 |         # return {'F': np.array(Fpc).transpose(1, 0),
106 |         #         'Precision': np.array(Ppc).transpose(1, 0),
107 |         #         'Recall': np.array(Rpc).transpose(1, 0)}
108 |         confusion_matrix_pc[class_id, :] = np.sum(temp, axis=0)
109 |     return confusion_matrix_pc
110 | 
111 | 
112 | def db_eval_boundary_wrapper(args):
113 |     foreground_mask, gt_mask, ignore, bound_th = args
114 |     return db_eval_boundary(foreground_mask, gt_mask, ignore, bound_th)
115 | 
116 | 
117 | def db_eval_boundary(foreground_mask, gt_mask, ignore_mask, bound_th=0.008):
118 |     """
119 |     Per-frame boundary evaluation.
120 |     Calculates precision/recall for boundaries between foreground_mask and
121 |     gt_mask using morphological operators to speed it up.
122 |     Arguments:
123 |         foreground_mask (ndarray): binary segmentation image.
124 |         gt_mask (ndarray): binary annotated image.
125 |     Returns:
126 |         (tuple): raw boundary match counts, in the order
127 |         (matched fg boundary pixels, total fg boundary pixels,
128 |          matched gt boundary pixels, total gt boundary pixels)
129 |     """
130 |     assert np.atleast_3d(foreground_mask).shape[2] == 1
131 | 
132 |     bound_pix = bound_th if bound_th >= 1 else \
133 |         np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))
134 | 
135 |     # print(bound_pix)
136 |     # print(gt.shape)
137 |     # print(np.unique(gt))
138 |     foreground_mask[ignore_mask] = 0
139 |     gt_mask[ignore_mask] = 0
140 | 
141 |     # Get the pixel boundaries of both masks
142 |     fg_boundary = seg2bmap(foreground_mask)
143 |     gt_boundary = seg2bmap(gt_mask)
144 | 
145 |     from skimage.morphology import binary_dilation, disk
146 | 
147 |     fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
148 |     gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
149 | 
150 |     # Get the intersection
151 |     gt_match = gt_boundary * fg_dil
152 |     fg_match = fg_boundary * gt_dil
153 | 
154 |     # Area of the intersection
155 |     n_fg = np.sum(fg_boundary)
156 |     n_gt = np.sum(gt_boundary)
157 | 
158 |     # # % Compute precision and recall
159 |     # if n_fg == 0 and n_gt > 0:
160 |     #     precision = 1
161 |     #     recall = 0
162 |     # elif n_fg > 0 and n_gt == 0:
163 |     #     precision = 0
164 |     #     recall = 1
165 |     # elif n_fg == 0 and n_gt == 0:
166 |     #     precision = 1
167 |     #     recall = 1
168 |     # else:
169 |     #     precision = np.sum(fg_match) / float(n_fg)
170 |     #     recall = np.sum(gt_match) / float(n_gt)
171 | 
172 |     # # Compute F measure
173 |     # if precision + recall == 0:
174 |     #     F = 0
175 |     # else:
176 |     #     F = 2 * precision * recall / (precision + recall)
177 | 
178 |     # return F, precision, recall
179 |     return np.sum(fg_match), float(n_fg), np.sum(gt_match), float(n_gt)
180 | 
181 | 
182 | def seg2bmap(seg, width=None, height=None):
183 |     """
184 |     From a segmentation, compute a binary boundary map with 1 pixel wide
185 |     boundaries. The boundary pixels are offset by 1/2 pixel towards the
186 |     origin from the actual segment boundary.
187 |     Arguments:
188 |         seg : Segments labeled from 1..k.
189 |         width : Width of desired bmap <= seg.shape[1]
190 |         height : Height of desired bmap <= seg.shape[0]
191 |     Returns:
192 |         bmap (ndarray): Binary boundary map.
193 |     David Martin
194 |     January 2003
195 |     """
196 | 
197 |     seg = seg.astype(bool)  # the builtin bool; np.bool was removed in NumPy 1.24
198 |     seg[seg > 0] = 1
199 | 
200 |     assert np.atleast_3d(seg).shape[2] == 1
201 | 
202 |     width = seg.shape[1] if width is None else width
203 |     height = seg.shape[0] if height is None else height
204 | 
205 |     h, w = seg.shape[:2]
206 | 
207 |     ar1 = float(width) / float(height)
208 |     ar2 = float(w) / float(h)
209 | 
210 |     assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
211 |         "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
212 | 
213 |     e = np.zeros_like(seg)
214 |     s = np.zeros_like(seg)
215 |     se = np.zeros_like(seg)
216 | 
217 |     e[:, :-1] = seg[:, 1:]
218 |     s[:-1, :] = seg[1:, :]
219 |     se[:-1, :-1] = seg[1:, 1:]
220 | 
221 |     b = seg ^ e | seg ^ s | seg ^ se
222 |     b[-1, :] = seg[-1, :] ^ e[-1, :]
223 |     b[:, -1] = seg[:, -1] ^ s[:, -1]
224 |     b[-1, -1] = 0
225 | 
226 |     if w == width and h == height:
227 |         bmap = b
228 |     else:
229 |         bmap = np.zeros((height, width))
230 |         for x in range(w):
231 |             for y in range(h):
232 |                 if b[y, x]:
233 |                     j = 1 + math.floor((y - 1) * height / h)
234 |                     i = 1 + math.floor((x - 1) * width / w)  # the column index must scale by w, not h
235 |                     bmap[j, i] = 1
236 | 
237 |     return bmap
--------------------------------------------------------------------------------
/utils/img_saver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from osgeo import gdal
3 | 
4 | 
5 | def save_img(tiff, out_file, proj=None, geot=(0, 30, 0, 0, 0, 30)):  # geot = (x0, x-res, 0, y0, 0, y-res); the default assumes a 30 m grid
6 |     """ save tiff image """
7 |     NP2GDAL_CONVERSION = {
8 |         "uint8": 1,
9 |         "int8": 1,
10 |         "uint16": 2,
11 |         "int16": 3,
12 |         "uint32": 4,
13 |         "int32": 5,
14 |         "float32": 6,
15 |         "float64": 7,
16 |         "complex64": 10,
17 |         "complex128": 11,
18 |     }  # convert np to gdal
19 |     gdal_type = NP2GDAL_CONVERSION[tiff.dtype.name]
20 |     if len(tiff.shape) == 2:
21 |         tiff = np.expand_dims(tiff, axis=0)
22 |     channel, row, col = tiff.shape
23 |     # create data set
24 |     gtiff_driver = gdal.GetDriverByName('GTiff')
25 |     out_ds = gtiff_driver.Create(out_file, col, row, channel, gdal_type)
26 |     if proj is not None:
27 |         out_ds.SetProjection(proj)  # projection
28 |     if geot is not None:
29 |         out_ds.SetGeoTransform(geot)  # geotransform
30 |     # write
31 |     for iband in range(channel):
32 |         out_ds.GetRasterBand(iband+1).WriteArray(tiff[iband, :, :])
33 |     del out_ds  # releasing the dataset flushes it to disk
34 | 
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class SegmentationLosses(object):
6 |     def __init__(self, weight=None, reduction='mean', batch_average=False, ignore_index=255, cuda=False):
7 |         self.ignore_index = ignore_index
8 |         self.weight = weight
9 |         self.reduction = reduction
10 |         self.batch_average = batch_average  # When 'reduction' is set to 'mean', 'batch_average' is redundant.
11 |         self.cuda = cuda
12 | 
13 |     def build_loss(self, mode='ce'):
14 |         """Choices: ['ce', 'focal' or 'wb']"""
15 |         if mode == 'ce':
16 |             return self.CrossEntropyLoss
17 |         elif mode == 'focal':
18 |             return self.FocalLoss
19 |         elif mode == 'wb':
20 |             return self.WeightedBalanceLoss
21 |         else:
22 |             raise NotImplementedError
23 | 
24 |     def CrossEntropyLoss(self, logit, target):
25 |         n, c, h, w = logit.size()
26 |         criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
27 |                                         reduction=self.reduction)
28 |         if self.cuda:
29 |             criterion = criterion.cuda()
30 | 
31 |         loss = criterion(logit, target.long())
32 | 
33 |         if self.batch_average:
34 |             loss /= n
35 | 
36 |         return loss
37 | 
38 |     def FocalLoss(self, logit, target, gamma=2, alpha=0.5):
39 |         n, c, h, w = logit.size()
40 |         criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
41 |                                         reduction=self.reduction)
42 |         if self.cuda:
43 |             criterion = criterion.cuda()
44 | 
45 |         logpt = -criterion(logit, target.long())
46 |         pt = torch.exp(logpt)
47 |         if alpha is not None:
48 |             logpt *= alpha
49 |         loss = -((1 - pt) ** gamma) * logpt
50 | 
51 |         if self.batch_average:
52 |             loss /= n
53 | 
54 |         return loss
55 | 
56 |     def WeightedBalanceLoss(self, logit, target, T=0.4):
57 |         # n, c, h, w = logit.size()
58 |         #
59 |         # m = nn.LogSoftmax(dim=1)
60 |         # logit_prob = m(logit)
61 |         #
62 |         # logit_label = torch.argmax(logit_prob, dim=1)
63 |         # w1 = torch.zeros((n, h, w)).cuda()
64 |         # w1[(target.long() - logit_label) == 1] = 1
65 |         # w1 = w1.unsqueeze(1)
66 |         #
67 |         # w2 = torch.ones((n, h, w))*T
68 |         # w2 = np.maximum(w2, torch.exp(logit_prob[:, 1, :, :]).cpu().detach())
69 |         # w2 = w2.cuda().unsqueeze(1)
70 |         #
71 |         # target = target.unsqueeze(1)
72 |         # loss = torch.mean(-w1*target*logit_prob[:, 1, :, :] - w2*(1-target)*logit_prob[:, 0, :, :])
73 |         #
74 |         # if self.batch_average:
75 |         #     loss /= n
76 |         #
77 |         # return loss
78 |         raise NotImplementedError  # only the commented-out draft above exists
79 | 
80 | 
81 | if __name__ == "__main__":
82 |     loss = SegmentationLosses(cuda=True)
83 |     a = torch.rand(4, 2, 7, 7).cuda()
84 |     b = torch.rand(4, 7, 7).cuda()
85 |     print(loss.CrossEntropyLoss(a, b).item())
86 |     print(loss.FocalLoss(a, b, gamma=0, alpha=None).item())
87 |     print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item())
88 |     # print(loss.WeightedBalanceLoss(a, b).item())  # WeightedBalanceLoss is not implemented
89 | 
90 |     a = torch.rand(4, 2, 7, 7).cuda()
91 |     b = torch.rand(4, 7, 7).cuda()
92 |     print(loss.CrossEntropyLoss(a, b).item())
93 |     print(loss.FocalLoss(a, b, gamma=0, alpha=None).item())
94 |     print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item())
95 | 
96 |     a = torch.rand(4, 2, 7, 7).cuda()
97 |     b = torch.rand(4, 7, 7).cuda()
98 |     print(loss.CrossEntropyLoss(a, b).item())
99 |     print(loss.FocalLoss(a, b, gamma=0, alpha=None).item())
100 |     print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item())
101 | 
102 | 
103 | 
--------------------------------------------------------------------------------
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | 
11 | import math
12 | 
13 | 
14 | class LR_Scheduler(object):
15 |     """Learning Rate Scheduler
16 | 
17 |     Step mode: ``lr = baselr * 0.1 ^ floor(epoch / lr_step)``
18 | 
19 |     Cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter / maxiter))``
20 | 
21 |     Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
22 | 
23 |     Args:
24 |         args:
25 |             :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`, `step`),
26 |             :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
27 |             :attr:`args.lr_step`
28 | 
29 |         iters_per_epoch: number of iterations per epoch
30 |     """
31 |     def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
32 |                  lr_step=20, warmup_epochs=0):
33 |         self.mode = mode
34 |         print('Using {} LR Scheduler!'.format(self.mode))
35 |         self.lr = base_lr
36 |         if mode == 'step':
37 |             assert lr_step
38 |         self.lr_step = lr_step
39 |         self.iters_per_epoch = iters_per_epoch
40 |         self.N = num_epochs * iters_per_epoch
41 |         self.epoch = -1
42 |         self.warmup_iters = warmup_epochs * iters_per_epoch
43 | 
44 |     def __call__(self, optimizer, i, epoch, best_pred):
45 |         T = epoch * self.iters_per_epoch + i
46 |         if self.mode == 'cos':
47 |             lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
48 |         elif self.mode == 'poly':
49 |             lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
50 |         elif self.mode == 'step':
51 |             lr = self.lr * (0.1 ** (epoch // self.lr_step))
52 |         else:
53 |             raise NotImplementedError
54 |         # warm up lr schedule
55 |         if self.warmup_iters > 0 and T < self.warmup_iters:
56 |             lr = lr * 1.0 * T / self.warmup_iters
57 |         if epoch > self.epoch:
58 |             print('\n=>Epoch %i, learning rate = %.4f, \
59 |                 previous best = %.4f' % (epoch, lr, best_pred))
60 |             self.epoch = epoch
61 |         assert lr >= 0
62 |         self._adjust_learning_rate(optimizer, lr)
63 | 
64 |     def _adjust_learning_rate(self, optimizer, lr):
65 |         if len(optimizer.param_groups) == 1:
66 |             optimizer.param_groups[0]['lr'] = lr
67 |         else:
68 |             # enlarge the lr at the head: the param groups after group 0 get 10x the base lr
69 |             optimizer.param_groups[0]['lr'] = lr
70 |             for i in range(1, len(optimizer.param_groups)):
71 |                 optimizer.param_groups[i]['lr'] = lr * 10
72 | 
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | # Ignore divide-by-zero warnings: every division by zero below is handled explicitly
4 | np.seterr(divide='ignore', invalid='ignore')
5 | 
6 | from utils.f_boundary import eval_mask_boundary
7 | 
8 | 
9 | class Evaluator(object):
10 |     def __init__(self, num_class):
11 |         self.num_class = num_class
12 |         self.confusion_matrix = np.zeros((self.num_class,)*2)
13 | 
14 |     def Pixel_Accuracy(self):
15 |         Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
16 |         return Acc
17 | 
18 |     def Pixel_Accuracy_Class(self):
19 |         Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
20 |         Acc = np.nanmean(Acc)
21 |         return Acc
22 | 
23 |     def Precision(self):
24 |         assert self.num_class == 2
25 |         pr = self.confusion_matrix[1, 1] / (self.confusion_matrix[1, 1] + self.confusion_matrix[0, 1])
26 |         return 1.0 if np.isnan(pr) else pr
27 | 
28 |     def Recall(self):
29 |         assert self.num_class == 2
30 |         re = self.confusion_matrix[1, 1] / (self.confusion_matrix[1, 1] + self.confusion_matrix[1, 0])
31 |         return 1.0 if np.isnan(re) else re
32 | 
33 |     def F_score(self):
34 |         assert self.num_class == 2
35 |         pr, re = self.Precision(), self.Recall()
36 |         if pr + re == 0:
37 |             return 0.0
38 |         else:
39 |             return 2.0 * pr * re / (pr + re)
40 | 
41 |     def Mean_Intersection_over_Union(self):
42 |         MIoU = np.diag(self.confusion_matrix) / (
43 |                 np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
44 |                 np.diag(self.confusion_matrix))
45 |         # MIoU = np.nanmean(MIoU)
46 |         return MIoU[1]  # IoU of class 1 (the positive class), not the mean over classes
47 | 
48 |     def Frequency_Weighted_Intersection_over_Union(self):
49 |         freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
50 |         iu = np.diag(self.confusion_matrix) / (
51 |                 np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
52 |                 np.diag(self.confusion_matrix))
53 | 
54 |         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
55 |         return FWIoU
56 | 
57 |     def _generate_matrix(self, gt_image, pre_image):
58 |         mask = (gt_image >= 0) & (gt_image < self.num_class)
59 |         label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
60 |         count = np.bincount(label, minlength=self.num_class**2)
61 |         confusion_matrix = count.reshape(self.num_class, self.num_class)
62 |         return confusion_matrix
63 | 
64 |     def add_batch(self, gt_image, pre_image):
65 |         assert gt_image.shape == pre_image.shape
66 |         self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
67 | 
68 |     def reset(self):
69 |         self.confusion_matrix = np.zeros((self.num_class,) * 2)
70 | 
71 | 
72 | class BoundaryEvaluator(object):
73 |     def __init__(self, num_class, p=None, num_proc=10, bound_th=0.008):
74 |         self.num_class = num_class
75 |         self.p = p
76 |         self.num_proc = num_proc
77 |         self.bound_th = bound_th
78 |         self.confusion_matrix_pc = np.zeros((self.num_class, 4))
79 | 
80 |     def Precision_boundary(self):
81 |         pr = self.confusion_matrix_pc[:, 0] / self.confusion_matrix_pc[:, 1]
82 |         pr[np.isnan(pr)] = 1.0
83 |         return pr
84 | 
85 |     def Recall_boundary(self):
86 |         re = self.confusion_matrix_pc[:, 2] / self.confusion_matrix_pc[:, 3]
87 |         re[np.isnan(re)] = 1.0
88 |         return re
89 | 
90 |     def F_score_boundary(self):
91 |         pr, re = self.Precision_boundary(), self.Recall_boundary()
92 |         f_score = 2 * pr * re / (pr + re)
93 |         f_score[np.isnan(f_score)] = 0.0
94 |         return f_score
95 | 
96 |     def _generate_matrix(self, gt_image, pre_image):
97 |         if len(gt_image.shape) == 2:
98 |             pre_image, gt_image = np.expand_dims(pre_image, axis=0), np.expand_dims(gt_image, axis=0)
99 |         return eval_mask_boundary(pre_image, gt_image, self.num_class, self.p, self.num_proc, self.bound_th)
100 | 
101 |     def add_batch(self, gt_image, pre_image):
102 |         assert gt_image.shape == pre_image.shape
103 |         self.confusion_matrix_pc += self._generate_matrix(gt_image, pre_image)
104 | 
105 |     def reset(self):
106 |         self.confusion_matrix_pc = np.zeros((self.num_class, 4))
107 | 
--------------------------------------------------------------------------------
/utils/net_convert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class NetConvert1(nn.Module):
6 |     '''
7 |     Extract the part of a UNet from the first pooling layer onwards
8 |     '''
9 |     def __init__(self, model, net_name):
10 |         super(NetConvert1, self).__init__()
11 | 
12 |         self.net_name = net_name
13 |         # drop the layers that come before the first pooling layer
14 |         self.down1 = model.down1
15 |         if net_name == 'UNet-2':
16 |             self.down2 = model.down2
17 |             self.up3 = model.up3
18 |         if net_name == 'UNet-3':
19 |             self.down2 = model.down2
20 |             self.down3 = model.down3
21 |             self.up2 = model.up2
22 |             self.up3 = model.up3
23 |         if net_name in ['UNet', 'UNet-dilation']:
24 |             self.down2 = model.down2
25 |             self.down3 = model.down3
26 |             self.down4 = model.down4
27 |             self.up1 = model.up1
28 |             self.up2 = model.up2
29 |             self.up3 = model.up3
30 |         self.up4 = model.up4
31 |         self.outc = model.outc
32 | 
33 |     def forward(self, x):
34 |         x1 = x
35 |         x2 = self.down1(x1)
36 |         if self.net_name == 'UNet-2':
37 |             x3 = self.down2(x2)
38 |             x = self.up3(x3, x2)
39 |         if self.net_name == 'UNet-3':
40 |             x3 = self.down2(x2)
41 |             x4 = self.down3(x3)
42 |             x = self.up2(x4, x3)
43 |             x = self.up3(x, x2)
44 |         if self.net_name in ['UNet', 'UNet-dilation']:
45 |             x3 = self.down2(x2)
46 |             x4 = self.down3(x3)
47 |             x5 = self.down4(x4)
48 |             x = self.up1(x5, x4)
49 |             x = self.up2(x, x3)
50 |             x = self.up3(x, x2)
51 |         if self.net_name == 'UNet-1':
52 |             x = self.up4(x2, x1)
53 |         else:
54 |             x = self.up4(x, x1)
55 |         x = self.outc(x)
56 |         return x
57 | 
58 | 
59 | class NetConvert2(nn.Module):
60 |     '''
61 |     Extract the part of a UNet that comes after the last upsampling layer
62 |     '''
63 |     def __init__(self, model):
64 |         super(NetConvert2, self).__init__()
65 | 
66 |         # keep only the layers that follow the last up-convolution
67 |         self.up4_conv = nn.Sequential(*list(model.up4.children())[1:])
68 |         self.outc = model.outc
69 | 
70 |     def forward(self, x):
71 |         '''
72 |         x1: produced by upsampling
73 |         x: produced by the skip connection
74 |         '''
75 |         x = self.up4_conv(x)
76 |         x = self.outc(x)
77 |         return x
78 | 
79 | 
80 | class NetConvert3(nn.Module):
81 |     '''
82 |     Extract the part of a UNet up to and including the last upsampling layer
83 |     '''
84 |     def __init__(self, model, net_name):
85 |         super(NetConvert3, self).__init__()
86 | 
87 |         self.net_name = net_name
88 |         # keep the layers up to and including the last up-convolution
89 |         self.inc = model.inc
90 |         self.down1 = model.down1
91 |         if net_name == 'UNet-2':
92 |             self.down2 = model.down2
93 |             self.up3 = model.up3
94 |         if net_name == 'UNet-3':
95 |             self.down2 = model.down2
96 |             self.down3 = model.down3
97 |             self.up2 = model.up2
98 |             self.up3 = model.up3
99 |         if net_name in ['UNet', 'UNet-dilation']:
100 |             self.down2 = model.down2
101 |             self.down3 = model.down3
102 |             self.down4 = model.down4
103 |             self.up1 = model.up1
104 |             self.up2 = model.up2
105 |             self.up3 = model.up3
106 |         self.up_conv = model.up4.up
107 | 
108 |     def forward(self, x):
109 |         x1 = self.inc(x)
110 |         x2 = self.down1(x1)
111 |         if self.net_name == 'UNet-2':
112 |             x3 = self.down2(x2)
113 |             x = self.up3(x3, x2)
114 |         if self.net_name == 'UNet-3':
115 |             x3 = self.down2(x2)
116 |             x4 = self.down3(x3)
117 |             x = self.up2(x4, x3)
118 |             x = self.up3(x, x2)
119 |         if self.net_name in ['UNet', 'UNet-dilation']:
120 |             x3 = self.down2(x2)
121 |             x4 = self.down3(x3)
122 |             x5 = self.down4(x4)
123 |             x = self.up1(x5, x4)
124 |             x = self.up2(x, x3)
125 |             x = self.up3(x, x2)
126 |         if self.net_name == 'UNet-1':
127 |             x = self.up_conv(x2)
128 |         else:
129 |             x = self.up_conv(x)
130 |         return x
131 | 
132 | 
133 | class NetConvertShort(nn.Module):
134 |     '''
135 |     Extract the part of a UNet that comes after the last upsampling layer
136 |     '''
137 |     def __init__(self, model):
138 |         super(NetConvertShort, self).__init__()
139 | 
140 |         self.inc = model.inc
141 |         # keep only the layers that follow the last up-convolution
142 |         self.up4_conv = nn.Sequential(*list(model.up4.children())[1:])
143 |         self.outc = model.outc
144 | 
145 |     def forward(self, x):
146 |         '''
147 |         x1: produced by upsampling
148 |         x: produced by the skip connection
149 |         '''
150 |         x1 = self.inc(x[:, 0:8, :])
151 |         x1 = torch.cat([x1, x[:, 8:, :]], dim=1)
152 |         x1 = self.up4_conv(x1)
153 |         x1 = self.outc(x1)
154 |         return x1
155 | 
--------------------------------------------------------------------------------
/utils/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import torch
4 | from collections import OrderedDict
5 | import glob
6 | 
7 | 
8 | class Saver(object):
9 | 
10 |     def __init__(self, args):
11 |         self.args = args
12 |         self.directory = os.path.join('run', args.dataset, args.checkname)
13 |         self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))
14 |         run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0
15 | 
16 |         self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
17 |         if not os.path.exists(self.experiment_dir):
18 |             os.makedirs(self.experiment_dir)
19 | 
20 |     def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
21 |         """Saves checkpoint to disk"""
22 |         filename = os.path.join(self.experiment_dir, filename)
23 |         torch.save(state, filename)
24 |         if is_best:
25 |             best_pred = state['best_pred']
26 |             with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f:
27 |                 f.write(str(best_pred))
28 |             if self.runs:
29 |                 previous_miou = [0.0]
30 |                 for run in self.runs:
31 |                     run_id = run.split('_')[-1]
32 |                     path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt')
33 |                     if os.path.exists(path):
34 |                         with open(path, 'r') as f:
35 |                             miou = float(f.readline())
36 |                         previous_miou.append(miou)
37 |                     else:
38 |                         continue
39 |                 max_miou = max(previous_miou)
40 |                 if best_pred > max_miou:
41 |                     shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))
42 |             else:
43 |                 shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))
44 | 
45 |     def save_experiment_config(self):
46 |         logfile = os.path.join(self.experiment_dir, 'parameters.txt')
47 |         log_file = open(logfile, 'w')
48 |         p = OrderedDict()
49 |         p['dataset'] = self.args.dataset
50 |         p['backbone'] = self.args.backbone
51 |         p['out_stride'] = self.args.out_stride
52 |         p['lr'] = self.args.lr
53 |         p['lr_scheduler'] = self.args.lr_scheduler
54 |         p['loss_type'] = self.args.loss_type
55 |         p['epoch'] = self.args.epochs
56 |         # p['base_size'] = self.args.base_size
57 |         # p['crop_size'] = self.args.crop_size
58 | 
59 |         for key, val in p.items():
60 |             log_file.write(key + ':' + str(val) + '\n')
61 |         log_file.close()
--------------------------------------------------------------------------------
/utils/summaries.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.utils import make_grid
4 | from torch.utils.tensorboard import SummaryWriter
5 | 
6 | 
7 | class TensorboardSummary(object):
8 |     def __init__(self, directory):
9 |         self.directory = directory
10 | 
11 |     def create_summary(self):
12 |         writer = SummaryWriter(log_dir=os.path.join(self.directory))
13 |         return writer
14 | 
15 |     def visualize_image(self, writer, dataset, image, target, output, global_step):
16 |         pass
17 | 
--------------------------------------------------------------------------------
/utils/tracker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 | from collections import OrderedDict
5 | 
6 | 
7 | class Tracker(object):
8 | 
9 |     def __init__(self, run_directory, run_filename='training_loop_performance'):
10 |         self.out_csv_file = os.path.join(run_directory, run_filename + '.csv')
11 |         self.out_json_file = os.path.join(run_directory, run_filename + '.json')
12 |         self.run_data = []
13 |         self.results_epoch = OrderedDict()
14 | 
15 |     def begin_epoch(self):
16 |         self.results_epoch = OrderedDict()
17 | 
18 |     def train_epoch(self, epoch, train_loss, lr):
19 |         self.results_epoch['epoch'] = epoch
20 |         self.results_epoch['train loss'] = train_loss
21 |         self.results_epoch['learning rate'] = lr
22 | 
23 |     def val_epoch(self, epoch, val_loss, pa, mpa, miou, fwiou):
24 |         assert epoch == self.results_epoch['epoch']
25 |         self.results_epoch['val loss'] = val_loss
26 |         self.results_epoch['PA'] = pa
27 |         self.results_epoch['MPA'] = mpa
28 |         self.results_epoch['MIoU'] = miou
29 |         self.results_epoch['FWIoU'] = fwiou
30 | 
31 |     def end_epoch(self):
32 |         self.run_data.append(self.results_epoch)
33 |         self.results_epoch = OrderedDict()
34 |         self._save()
35 | 
36 |     def _save(self):
37 |         # save as csv file
38 |         pd.DataFrame.from_dict(
39 |             self.run_data, orient='columns'
40 |         ).to_csv(self.out_csv_file)
41 | 
42 |         # save as json file
43 |         with open(self.out_json_file, 'w', encoding='utf-8') as f:
44 |             json.dump(self.run_data, f, ensure_ascii=False, indent=4)
45 | 
--------------------------------------------------------------------------------
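A sketch of how Evaluator and Tracker combine in a validation loop. This is illustrative only: the run directory and the metric values are made up, train.py wires these pieces up for real, and the repository root is assumed to be on the Python path. The float() casts keep the NumPy scalars JSON-serializable when Tracker saves:

    import os
    import numpy as np
    from utils.metrics import Evaluator
    from utils.tracker import Tracker

    evaluator = Evaluator(num_class=2)
    gt = np.random.randint(0, 2, (4, 64, 64))    # fake ground-truth masks
    pred = np.random.randint(0, 2, (4, 64, 64))  # fake predictions
    evaluator.add_batch(gt, pred)

    os.makedirs('run/example', exist_ok=True)    # Tracker expects the directory to exist
    tracker = Tracker('run/example')
    tracker.begin_epoch()
    tracker.train_epoch(epoch=0, train_loss=0.70, lr=1e-3)
    tracker.val_epoch(epoch=0, val_loss=0.65,
                      pa=float(evaluator.Pixel_Accuracy()),
                      mpa=float(evaluator.Pixel_Accuracy_Class()),
                      miou=float(evaluator.Mean_Intersection_over_Union()),
                      fwiou=float(evaluator.Frequency_Weighted_Intersection_over_Union()))
    tracker.end_epoch()  # writes training_loop_performance.csv / .json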