├── README.md ├── __pycache__ ├── data_config.cpython-37.pyc └── utils.cpython-37.pyc ├── data_config.py ├── data_preparation ├── dsifn_cd_256.m ├── find_mean_std.py └── levir_cd_256.m ├── datasets ├── CD_dataset.py ├── __pycache__ │ ├── CD_dataset.cpython-37.pyc │ └── data_utils.cpython-37.pyc └── data_utils.py ├── demo_LEVIR.py ├── eval_cd.py ├── images ├── .DS_Store ├── Figure 1.jpg ├── Figure 10.jpg └── Figure 9.jpg ├── main_cd.py ├── misc ├── __pycache__ │ ├── imutils.cpython-37.pyc │ ├── logger_tool.cpython-37.pyc │ └── metric_tool.cpython-37.pyc ├── imutils.py ├── logger_tool.py ├── metric_tool.py ├── pyutils.py └── torchutils.py ├── models ├── EGCTNet.py ├── TransformerBaseNetworks.py ├── __init__.py ├── __pycache__ │ ├── EGCTNet.cpython-37.pyc │ ├── TransformerBaseNetworks.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── base_model.cpython-37.pyc │ ├── basic_model.cpython-37.pyc │ ├── evaluator.cpython-37.pyc │ ├── help_funcs.cpython-37.pyc │ ├── losses.cpython-37.pyc │ ├── networks.cpython-37.pyc │ ├── pixel_shuffel_up.cpython-37.pyc │ ├── resnet.cpython-37.pyc │ └── trainer.cpython-37.pyc ├── base_model.py ├── basic_model.py ├── evaluator.py ├── help_funcs.py ├── losses.py ├── networks.py ├── pixel_shuffel_up.py ├── resnet.py ├── sync_batchnorm │ ├── __init__.c │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── batchnorm.cpython-37.pyc │ │ ├── comm.cpython-37.pyc │ │ └── replicate.cpython-37.pyc │ ├── batchnorm.c │ ├── batchnorm.py │ ├── batchnorm_reimpl.c │ ├── batchnorm_reimpl.py │ ├── comm.c │ ├── comm.py │ ├── replicate.c │ ├── replicate.py │ ├── unittest.c │ └── unittest.py └── trainer.py ├── requirements.txt ├── samples_LEVIR ├── A │ ├── test_102_0512_0000.png │ ├── test_113_0256.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ ├── test_7_0256_0512.png │ ├── train_36_0512_0512.png │ ├── train_386_0512_0768.png │ ├── train_412_0512_0768.png │ └── val_27_0000_0256.png ├── B │ ├── test_102_0512_0000.png │ ├── test_113_0256.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ ├── test_7_0256_0512.png │ ├── train_36_0512_0512.png │ ├── train_386_0512_0768.png │ ├── train_412_0512_0768.png │ └── val_27_0000_0256.png ├── label │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ ├── test_7_0256_0512.png │ ├── train_36_0512_0512.png │ ├── train_386_0512_0768.png │ ├── train_412_0512_0768.png │ └── val_27_0000_0256.png ├── list │ └── demo.txt ├── predict_CD_BIT │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ └── test_7_0256_0512.png ├── predict_CD_ChangeFormerV6 │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ └── test_7_0256_0512.png ├── predict_CD_DTCDSCN │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ └── test_7_0256_0512.png ├── predict_CD_SiamUnet_conc │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ 
├── test_77_0512_0256.png │ └── test_7_0256_0512.png ├── predict_CD_SiamUnet_diff │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ └── test_7_0256_0512.png ├── predict_CD_Unet │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ └── test_7_0256_0512.png └── predict_ChangeFormerV6 │ ├── test_102_0512_0000.png │ ├── test_121_0768_0256.png │ ├── test_2_0000_0000.png │ ├── test_2_0000_0512.png │ ├── test_55_0256_0000.png │ ├── test_77_0512_0256.png │ └── test_7_0256_0512.png └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # EGCTNet: Building Change Detection Based on an Edge-Guided Convolutional Neural Network Combined with a Transformer 2 | (Published in Remote Sensing) 3 | 4 | Here, we provide the PyTorch implementation of the paper: Building Change Detection Based on an Edge-Guided Convolutional Neural Network Combined with a Transformer. 5 | 6 | For more information, please see our paper at [MDPI](https://www.mdpi.com/2072-4292/14/18/4524). 7 | 8 | ## Network Architecture 9 | ![image-20210228153142126](./images/Figure 1.jpg) 10 | 11 | ## Quantitative & Qualitative Results on LEVIR-CD and WHU-CD 12 | LEVIR-CD 13 | ![image-20210228153142126](./images/Figure 9.jpg) 14 | WHU-CD 15 | ![image-20210228153142126](./images/Figure 10.jpg) 16 | ## Requirements 17 | 18 | ``` 19 | Python 3.8.0 20 | pytorch 1.10.1 21 | torchvision 0.11.2 22 | einops 0.3.2 23 | ``` 24 | 25 | Please see `requirements.txt` for all the other requirements. 26 | 27 | 28 | ## Train on LEVIR-CD 29 | 30 | You can train the model by running `main_cd.py` from the command line. 31 | 32 | 33 | ## Evaluate on LEVIR-CD 34 | 35 | You can evaluate the model by running `eval_cd.py` from the command line. 36 | 37 | 38 | 39 | ## Dataset Preparation 40 | 41 | ### Data structure 42 | 43 | ``` 44 | """ 45 | Change detection data set with pixel-level binary labels; 46 | ├─A 47 | ├─B 48 | ├─label 49 | ├─label_edge 50 | └─list 51 | """ 52 | ``` 53 | 54 | `A`: images of t1 phase; 55 | 56 | `B`: images of t2 phase; 57 | 58 | `label`: label maps; 59 | 60 | `label_edge`: edge maps obtained by applying the Canny edge detection operator to the label maps; 61 | 62 | `list`: contains `train.txt, val.txt and test.txt`; each file records the image names (XXX.png) in the change detection dataset. 63 | 64 | ### Data Download 65 | 66 | LEVIR-CD: https://justchenhao.github.io/LEVIR/ 67 | 68 | WHU-CD: https://study.rsgis.whu.edu.cn/pages/download/building_dataset.html 69 | 70 | DSIFN-CD: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset 71 | 72 | ## License 73 | 74 | Code is released for non-commercial and research purposes **only**. For commercial purposes, please contact the authors. 75 | 76 | ## Citation 77 | 78 | If you use this code for your research, please cite our paper: 79 | 80 | ``` 81 | MDPI and ACS Style 82 | Xia, L.; Chen, J.; Luo, J.; Zhang, J.; Yang, D.; Shen, Z. Building Change Detection Based on an Edge-Guided Convolutional Neural Network Combined with a Transformer. Remote Sens. 2022, 14, 4524. https://doi.org/10.3390/rs14184524 83 | 84 | AMA Style 85 | Xia L, Chen J, Luo J, Zhang J, Yang D, Shen Z.
Building Change Detection Based on an Edge-Guided Convolutional Neural Network Combined with a Transformer. Remote Sensing. 2022; 14(18):4524. https://doi.org/10.3390/rs14184524 86 | 87 | Chicago/Turabian Style 88 | Xia, Liegang, Jun Chen, Jiancheng Luo, Junxia Zhang, Dezhi Yang, and Zhanfeng Shen. 2022. "Building Change Detection Based on an Edge-Guided Convolutional Neural Network Combined with a Transformer" Remote Sensing 14, no. 18: 4524. https://doi.org/10.3390/rs14184524 89 | 90 | ``` 91 | 92 | ## References 93 | Appreciate the work from the following repositories: 94 | 95 | - https://github.com/wgcban/ChangeFormer (Our EGCTNet is implemented on the code provided in this repository) 96 | 97 | -------------------------------------------------------------------------------- /__pycache__/data_config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/__pycache__/data_config.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /data_config.py: -------------------------------------------------------------------------------- 1 | 2 | class DataConfig: 3 | data_name = "" 4 | root_dir = "" 5 | label_transform = "" 6 | def get_data_config(self, data_name): 7 | self.data_name = data_name 8 | if data_name == 'LEVIR': 9 | self.label_transform = "norm" 10 | self.root_dir = r'Z:\cj\datasourse\LEVIR-CD-256' 11 | elif data_name == 'WHU': 12 | self.label_transform = "norm" 13 | self.root_dir = r'E:\bianhuajiance\WHU' 14 | elif data_name == 'WHU-512-100': 15 | self.label_transform = "norm" 16 | self.root_dir = r'E:\bianhuajiance\database\WHU-CD-512-100' 17 | elif data_name == 'WHU-512-0': 18 | self.label_transform = "norm" 19 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-0' 20 | elif data_name == 'WHU-512-10': 21 | self.label_transform = "norm" 22 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-10' 23 | elif data_name == 'WHU-512-20': 24 | self.label_transform = "norm" 25 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-20' 26 | elif data_name == 'WHU-512-30': 27 | self.label_transform = "norm" 28 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-30' 29 | elif data_name == 'WHU-512-30-only': 30 | self.label_transform = "norm" 31 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-30-only' 32 | elif data_name == 'WHU-512-40': 33 | self.label_transform = "norm" 34 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-40' 35 | elif data_name == 'WHU-512-40-only': 36 | self.label_transform = "norm" 37 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-40-only' 38 | elif data_name == 'WHU-512-50': 39 | self.label_transform = "norm" 40 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-50' 41 | elif data_name == 'WHU-512-50-only': 42 | self.label_transform = "norm" 43 | self.root_dir = r'Z:\cj\datasourse\WHU-CD-512-50-only' 44 | elif data_name == 'quick_start_LEVIR': 45 | self.root_dir = './samples_LEVIR/' 46 | elif data_name == 'quick_start_WHU': 47 | self.root_dir = './samples_WHU/' 48 | else: 49 | raise TypeError('%s has not defined' % data_name) 50 | return self 51 | 52 | 
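Note that the `root_dir` values in `data_config.py` above are the authors' local dataset paths and need to be pointed at your own copies of the datasets before training. As a minimal usage sketch (assuming the repository root is the working directory and the `LEVIR` entry defined above), the config is looked up by name and then consumed by the data-loading code:

```python
from data_config import DataConfig

# Look up a dataset configuration by name; unknown names raise a TypeError.
cfg = DataConfig().get_data_config('LEVIR')

print(cfg.root_dir)         # dataset root, e.g. the LEVIR-CD-256 path configured above
print(cfg.label_transform)  # 'norm': CD_dataset.py divides the 0/255 label maps by 255
```

The same dataset names are what `main_cd.py` and `eval_cd.py` expect through their `--data_name` argument.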
-------------------------------------------------------------------------------- /data_preparation/dsifn_cd_256.m: -------------------------------------------------------------------------------- 1 | %Dataset preparation code for DSFIN dataset (MATLAB) 2 | %Download DSFIN dataset here: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset 3 | %This code generate 256x256 image partches required for the train/val/test 4 | %Please create folders according to following format. 5 | %DSIFN_256 6 | %------(train) 7 | % |---> A 8 | % |---> B 9 | % |---> label 10 | %------(val) 11 | % |---> A 12 | % |---> B 13 | % |---> label 14 | %------(test) 15 | % |---> A 16 | % |---> B 17 | % |---> label 18 | %Then run this code 19 | %Then copy all images in train-A, val-A, test-A to a folder name A 20 | %Then copy all images in train-B, val-B, test-B to a folder name B 21 | %Then copy all images in train-label, val-label, test-label to a folder name label 22 | 23 | 24 | 25 | 26 | clear all; 27 | close all; 28 | clc; 29 | 30 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 31 | %Train-A 32 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/t1/*.jpg')); 33 | for i=1:1:length(imgs_name) 34 | img_file_name = imgs_name{1,i}; 35 | temp = imread(strcat('DSIFN/download/Archive/train/t1/', img_file_name)); 36 | c=1; 37 | for j=1:2 38 | for k=1:2 39 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 40 | imwrite(patch, strcat('DSIFN_256/train/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 41 | c=c+1; 42 | end 43 | end 44 | 45 | end 46 | 47 | %Train-B 48 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/t2/*.jpg')); 49 | for i=1:1:length(imgs_name) 50 | img_file_name = imgs_name{1,i}; 51 | temp = imread(strcat('DSIFN/download/Archive/train/t2/', img_file_name)); 52 | c=1; 53 | for j=1:2 54 | for k=1:2 55 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 56 | imwrite(patch, strcat('DSIFN_256/train/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 57 | c=c+1; 58 | end 59 | end 60 | 61 | end 62 | 63 | %Train-label 64 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/mask/*.png')); 65 | for i=1:1:length(imgs_name) 66 | img_file_name = imgs_name{1,i}; 67 | temp = imread(strcat('DSIFN/download/Archive/train/mask/',img_file_name)); 68 | c=1; 69 | for j=1:2 70 | for k=1:2 71 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 72 | imwrite(patch, strcat('DSIFN_256/train/label/', img_file_name(1:end-4), '_', num2str(c), '.png')); 73 | c=c+1; 74 | end 75 | end 76 | 77 | end 78 | 79 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 80 | %test-A 81 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/t1/*.jpg')); 82 | for i=1:1:length(imgs_name) 83 | img_file_name = imgs_name{1,i}; 84 | temp = imread(strcat('DSIFN/download/Archive/test/t1/', img_file_name)); 85 | c=1; 86 | for j=1:2 87 | for k=1:2 88 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 89 | imwrite(patch, strcat('DSIFN_256/test/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 90 | c=c+1; 91 | end 92 | end 93 | 94 | end 95 | 96 | %test-B 97 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/t2/*.jpg')); 98 | for i=1:1:length(imgs_name) 99 | img_file_name = imgs_name{1,i}; 100 | temp = imread(strcat('DSIFN/download/Archive/test/t2/', img_file_name)); 101 | c=1; 102 | for j=1:2 103 | for k=1:2 104 | patch = temp((j-1)*256+1: j*256, 
(k-1)*256+1: k*256, :); 105 | imwrite(patch, strcat('DSIFN_256/test/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 106 | c=c+1; 107 | end 108 | end 109 | 110 | end 111 | 112 | %test-label 113 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/mask/*.png')); 114 | for i=1:1:length(imgs_name) 115 | img_file_name = imgs_name{1,i}; 116 | temp = imread(strcat('DSIFN/download/Archive/test/mask/',img_file_name)); 117 | c=1; 118 | for j=1:2 119 | for k=1:2 120 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 121 | imwrite(patch, strcat('DSIFN_256/test/label/', img_file_name(1:end-4), '_', num2str(c), '.png')); 122 | c=c+1; 123 | end 124 | end 125 | 126 | end 127 | 128 | 129 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 130 | %val-A 131 | imgs_name = struct2cell(dir('DSIFN/download/Archive/val/t1/*.jpg')); 132 | for i=1:1:length(imgs_name) 133 | img_file_name = imgs_name{1,i}; 134 | temp = imread(strcat('DSIFN/download/Archive/val/t1/', img_file_name)); 135 | c=1; 136 | for j=1:2 137 | for k=1:2 138 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 139 | imwrite(patch, strcat('DSIFN_256/val/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 140 | c=c+1; 141 | end 142 | end 143 | 144 | end 145 | 146 | %val-B 147 | imgs_name = struct2cell(dir('DSIFN/download/Archive/val/t2/*.jpg')); 148 | for i=1:1:length(imgs_name) 149 | img_file_name = imgs_name{1,i}; 150 | temp = imread(strcat('DSIFN/download/Archive/val/t2/', img_file_name)); 151 | c=1; 152 | for j=1:2 153 | for k=1:2 154 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 155 | imwrite(patch, strcat('DSIFN_256/val/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 156 | c=c+1; 157 | end 158 | end 159 | 160 | end 161 | 162 | %val-label 163 | imgs_name = struct2cell(dir('DSIFN/download/Archive/val/mask/*.png')); 164 | for i=1:1:length(imgs_name) 165 | img_file_name = imgs_name{1,i}; 166 | temp = imread(strcat('DSIFN/download/Archive/val/mask/',img_file_name)); 167 | c=1; 168 | for j=1:2 169 | for k=1:2 170 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 171 | imwrite(patch, strcat('DSIFN_256/val/label/', img_file_name(1:end-4), '_', num2str(c), '.png')); 172 | c=c+1; 173 | end 174 | end 175 | 176 | end 177 | -------------------------------------------------------------------------------- /data_preparation/find_mean_std.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | if __name__ == '__main__': 7 | filepath = r"/media/lidan/ssd2/CDData/LEVIR-CD256/B/" # Dataset directory 8 | pathDir = os.listdir(filepath) # Images in dataset directory 9 | num = len(pathDir) # Here (512512) is the size of each image 10 | 11 | print("Computing mean...") 12 | data_mean = np.zeros(3) 13 | for idx in range(len(pathDir)): 14 | filename = pathDir[idx] 15 | img = Image.open(os.path.join(filepath, filename)) 16 | img = np.array(img) / 255.0 17 | print(img.shape) 18 | data_mean += np.mean(img) # Take all the data of the first dimension in the three-dimensional matrix 19 | # As the use of gray images, so calculate a channel on it 20 | data_mean = data_mean / num 21 | 22 | print("Computing var...") 23 | data_std = 0. 
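    # Caveat on the statistics computed in this script: np.mean(img) in the loop above
    # averages over all pixels and channels at once, so the three entries of data_mean
    # end up identical. If per-channel means are intended (data_mean is initialised as
    # np.zeros(3)), a per-channel variant would be: data_mean += img.mean(axis=(0, 1)).
    # The std loop below converts each image to grayscale ('L') and resizes it to
    # 256x256, so it yields a single grayscale statistic rather than per-channel values.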
24 | for idx in range(len(pathDir)): 25 | filename = pathDir[idx] 26 | img = Image.open(os.path.join(filepath, filename)).convert('L').resize((256, 256)) 27 | img = np.array(img) / 255.0 28 | data_std += np.std(img) 29 | 30 | data_std = data_std / num 31 | print("mean:{}".format(data_mean)) 32 | print("std:{}".format(data_std)) -------------------------------------------------------------------------------- /data_preparation/levir_cd_256.m: -------------------------------------------------------------------------------- 1 | %Dataset preparation code for DSFIN dataset (MATLAB) 2 | %Download DSFIN dataset here: https://justchenhao.github.io/LEVIR/ 3 | %This code generate 256x256 image partches required for the train/val/test 4 | %Please create folders according to following format. 5 | %DSIFN_256 6 | %------(train) 7 | % |---> A 8 | % |---> B 9 | % |---> label 10 | %------(val) 11 | % |---> A 12 | % |---> B 13 | % |---> label 14 | %------(test) 15 | % |---> A 16 | % |---> B 17 | % |---> label 18 | %Then run this code 19 | %Then copy all images in train-A, val-A, test-A to a folder name A 20 | %Then copy all images in train-B, val-B, test-B to a folder name B 21 | %Then copy all images in train-label, val-label, test-label to a folder name label 22 | 23 | clear all; 24 | close all; 25 | clc; 26 | 27 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 28 | %Train-A 29 | imgs_name = struct2cell(dir('LEVIR-CD/train/A/*.png')); 30 | for i=1:1:length(imgs_name) 31 | img_file_name = imgs_name{1,i}; 32 | temp = imread(strcat('LEVIR-CD/train/A/',img_file_name)); 33 | c=1; 34 | for j=1:4 35 | for k=1:4 36 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 37 | imwrite(patch, strcat('LEVIR-CD256/train/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 38 | c=c+1; 39 | end 40 | end 41 | 42 | end 43 | 44 | %Train-B 45 | imgs_name = struct2cell(dir('LEVIR-CD/train/B/*.png')); 46 | for i=1:1:length(imgs_name) 47 | img_file_name = imgs_name{1,i}; 48 | temp = imread(strcat('LEVIR-CD/train/B/',img_file_name)); 49 | c=1; 50 | for j=1:4 51 | for k=1:4 52 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 53 | imwrite(patch, strcat('LEVIR-CD256/train/B/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 54 | c=c+1; 55 | end 56 | end 57 | 58 | end 59 | 60 | %Train-label 61 | imgs_name = struct2cell(dir('LEVIR-CD/train/label/*.png')); 62 | for i=1:1:length(imgs_name) 63 | img_file_name = imgs_name{1,i}; 64 | temp = imread(strcat('LEVIR-CD/train/label/',img_file_name)); 65 | c=1; 66 | for j=1:4 67 | for k=1:4 68 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 69 | imwrite(patch, strcat('LEVIR-CD256/train/label/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 70 | c=c+1; 71 | end 72 | end 73 | 74 | end 75 | 76 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 77 | %Test-A 78 | imgs_name = struct2cell(dir('LEVIR-CD/test/A/*.png')); 79 | for i=1:1:length(imgs_name) 80 | img_file_name = imgs_name{1,i}; 81 | temp = imread(strcat('LEVIR-CD/test/A/',img_file_name)); 82 | c=1; 83 | for j=1:4 84 | for k=1:4 85 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 86 | imwrite(patch, strcat('LEVIR-CD256/test/A/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 87 | c=c+1; 88 | end 89 | end 90 | 91 | end 92 | 93 | %Test-B 94 | imgs_name = struct2cell(dir('LEVIR-CD/test/B/*.png')); 95 | for i=1:1:length(imgs_name) 96 | img_file_name = imgs_name{1,i}; 97 | temp = imread(strcat('LEVIR-CD/test/B/',img_file_name)); 98 | 
c=1; 99 | for j=1:4 100 | for k=1:4 101 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 102 | imwrite(patch, strcat('LEVIR-CD256/test/B/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 103 | c=c+1; 104 | end 105 | end 106 | 107 | end 108 | 109 | %Test-label 110 | imgs_name = struct2cell(dir('LEVIR-CD/test/label/*.png')); 111 | for i=1:1:length(imgs_name) 112 | img_file_name = imgs_name{1,i}; 113 | temp = imread(strcat('LEVIR-CD/test/label/',img_file_name)); 114 | c=1; 115 | for j=1:4 116 | for k=1:4 117 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 118 | imwrite(patch, strcat('LEVIR-CD256/test/label/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 119 | c=c+1; 120 | end 121 | end 122 | 123 | end 124 | 125 | 126 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 127 | %val-A 128 | imgs_name = struct2cell(dir('LEVIR-CD/val/A/*.png')); 129 | for i=1:1:length(imgs_name) 130 | img_file_name = imgs_name{1,i}; 131 | temp = imread(strcat('LEVIR-CD/val/A/',img_file_name)); 132 | c=1; 133 | for j=1:4 134 | for k=1:4 135 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 136 | imwrite(patch, strcat('LEVIR-CD256/val/A/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 137 | c=c+1; 138 | end 139 | end 140 | 141 | end 142 | 143 | %val-B 144 | imgs_name = struct2cell(dir('LEVIR-CD/val/B/*.png')); 145 | for i=1:1:length(imgs_name) 146 | img_file_name = imgs_name{1,i}; 147 | temp = imread(strcat('LEVIR-CD/val/B/',img_file_name)); 148 | c=1; 149 | for j=1:4 150 | for k=1:4 151 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 152 | imwrite(patch, strcat('LEVIR-CD256/val/B/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 153 | c=c+1; 154 | end 155 | end 156 | 157 | end 158 | 159 | %val-label 160 | imgs_name = struct2cell(dir('LEVIR-CD/val/label/*.png')); 161 | for i=1:1:length(imgs_name) 162 | img_file_name = imgs_name{1,i}; 163 | temp = imread(strcat('LEVIR-CD/val/label/',img_file_name)); 164 | c=1; 165 | for j=1:4 166 | for k=1:4 167 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 168 | imwrite(patch, strcat('LEVIR-CD256/val/label/', img_file_name(1:end-4), '_', num2str(j+k-1), '.png')); 169 | c=c+1; 170 | end 171 | end 172 | 173 | end -------------------------------------------------------------------------------- /datasets/CD_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | 变化检测数据集 3 | """ 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from torchvision import transforms 11 | from torch.utils import data 12 | 13 | from datasets.data_utils import CDDataAugmentation 14 | 15 | 16 | """ 17 | CD data set with pixel-level labels; 18 | ├─image 19 | ├─image_post 20 | ├─label 21 | ├─label_edge 22 | └─list 23 | """ 24 | IMG_FOLDER_NAME = "A" 25 | IMG_POST_FOLDER_NAME = 'B' 26 | LIST_FOLDER_NAME = 'list' 27 | ANNOT_FOLDER_NAME = "label" 28 | EDGE_FOLDER_NAME = "label_edge" 29 | 30 | IGNORE = 255 31 | 32 | label_suffix='.png' # jpg for gan dataset, others : png 33 | 34 | def load_img_name_list(dataset_path): 35 | img_name_list = np.loadtxt(dataset_path, dtype=np.str) 36 | if img_name_list.ndim == 2: 37 | return img_name_list[:, 0] 38 | return img_name_list 39 | 40 | 41 | def load_image_label_list_from_npy(npy_path, img_name_list): 42 | cls_labels_dict = np.load(npy_path, allow_pickle=True).item() 43 | return [cls_labels_dict[img_name] for img_name in img_name_list] 44 | 45 | 46 | def 
get_img_post_path(root_dir,img_name): 47 | return os.path.join(root_dir, IMG_POST_FOLDER_NAME, img_name) 48 | 49 | 50 | def get_img_path(root_dir, img_name): 51 | return os.path.join(root_dir, IMG_FOLDER_NAME, img_name) 52 | 53 | 54 | def get_label_path(root_dir, img_name): 55 | return os.path.join(root_dir, ANNOT_FOLDER_NAME, img_name.replace('.jpg', label_suffix)) 56 | 57 | def get_edge_label_path(root_dir, img_name): 58 | return os.path.join(root_dir, EDGE_FOLDER_NAME, img_name.replace('.jpg', label_suffix)) 59 | 60 | class ImageDataset(data.Dataset): 61 | """VOCdataloder""" 62 | def __init__(self, root_dir, split='train', img_size=256, is_train=True,to_tensor=True): 63 | super(ImageDataset, self).__init__() 64 | self.root_dir = root_dir 65 | self.img_size = img_size 66 | self.split = split #train | train_aug | val 67 | # self.list_path = self.root_dir + '/' + LIST_FOLDER_NAME + '/' + self.list + '.txt' 68 | self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split+'.txt') 69 | self.img_name_list = load_img_name_list(self.list_path) 70 | 71 | self.A_size = len(self.img_name_list) # get the size of dataset A 72 | self.to_tensor = to_tensor 73 | if is_train: 74 | self.augm = CDDataAugmentation( 75 | img_size=self.img_size, 76 | with_random_hflip=True, 77 | with_random_vflip=True, 78 | with_scale_random_crop=True, 79 | with_random_blur=True, 80 | random_color_tf=True 81 | ) 82 | else: 83 | self.augm = CDDataAugmentation( 84 | img_size=self.img_size 85 | ) 86 | def __getitem__(self, index): 87 | name = self.img_name_list[index] 88 | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.A_size]) 89 | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.A_size]) 90 | 91 | img = np.asarray(Image.open(A_path).convert('RGB')) 92 | img_B = np.asarray(Image.open(B_path).convert('RGB')) 93 | 94 | [img, img_B], _ = self.augm.transform([img, img_B],[], to_tensor=self.to_tensor) 95 | 96 | return {'A': img, 'B': img_B, 'name': name} 97 | 98 | def __len__(self): 99 | """Return the total number of images in the dataset.""" 100 | return self.A_size 101 | 102 | 103 | class CDDataset(ImageDataset): 104 | 105 | def __init__(self, root_dir, img_size, split='train', is_train=True, label_transform=None, 106 | to_tensor=True): 107 | super(CDDataset, self).__init__(root_dir, img_size=img_size, split=split, is_train=is_train, 108 | to_tensor=to_tensor) 109 | self.label_transform = label_transform 110 | 111 | def __getitem__(self, index): 112 | name = self.img_name_list[index] 113 | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.A_size]) 114 | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.A_size]) 115 | img = np.asarray(Image.open(A_path).convert('RGB')) 116 | img_B = np.asarray(Image.open(B_path).convert('RGB')) 117 | L_path = get_label_path(self.root_dir, self.img_name_list[index % self.A_size]) 118 | L_edge_path = get_edge_label_path(self.root_dir, self.img_name_list[index % self.A_size]) 119 | 120 | label = np.array(Image.open(L_path), dtype=np.uint8) 121 | edge_label = np.array(Image.open(L_edge_path), dtype=np.uint8) 122 | 123 | # 二分类中,前景标注为255 124 | if self.label_transform == 'norm': 125 | label = label // 255 126 | edge_label = edge_label // 255 127 | 128 | [img, img_B], [label, edge_label] = self.augm.transform([img, img_B], [label, edge_label], to_tensor=self.to_tensor) 129 | 130 | # label = label.numpy() 131 | # label = np.transpose(label, (1, 2, 0)) 132 | # # 显示图片 133 | # plt.imshow(label) 134 | # 
plt.show() 135 | 136 | return {'name': name, 'A': img, 'B': img_B, 'L': label, "L_edge": edge_label} 137 | 138 | -------------------------------------------------------------------------------- /datasets/__pycache__/CD_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/datasets/__pycache__/CD_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/datasets/__pycache__/data_utils.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | from PIL import Image 5 | from PIL import ImageFilter 6 | 7 | import torchvision.transforms.functional as TF 8 | from torchvision import transforms 9 | import torch 10 | 11 | 12 | def to_tensor_and_norm(imgs, labels): 13 | # to tensor 14 | imgs = [TF.to_tensor(img) for img in imgs] 15 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0) 16 | for img in labels] 17 | 18 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 19 | for img in imgs] 20 | return imgs, labels 21 | 22 | 23 | class CDDataAugmentation: 24 | 25 | def __init__( 26 | self, 27 | img_size, 28 | with_random_hflip=False, 29 | with_random_vflip=False, 30 | with_random_rot=False, 31 | with_random_crop=False, 32 | with_scale_random_crop=False, 33 | with_random_blur=False, 34 | random_color_tf=False 35 | ): 36 | self.img_size = img_size 37 | if self.img_size is None: 38 | self.img_size_dynamic = True 39 | else: 40 | self.img_size_dynamic = False 41 | self.with_random_hflip = with_random_hflip 42 | self.with_random_vflip = with_random_vflip 43 | self.with_random_rot = with_random_rot 44 | self.with_random_crop = with_random_crop 45 | self.with_scale_random_crop = with_scale_random_crop 46 | self.with_random_blur = with_random_blur 47 | self.random_color_tf=random_color_tf 48 | def transform(self, imgs, labels, to_tensor=True): 49 | """ 50 | :param imgs: [ndarray,] 51 | :param labels: [ndarray,] 52 | :return: [ndarray,],[ndarray,] 53 | """ 54 | # resize image and covert to tensor 55 | imgs = [TF.to_pil_image(img) for img in imgs] 56 | if self.img_size is None: 57 | self.img_size = None 58 | 59 | if not self.img_size_dynamic: 60 | if imgs[0].size != (self.img_size, self.img_size): 61 | imgs = [TF.resize(img, [self.img_size, self.img_size], interpolation=3) 62 | for img in imgs] 63 | else: 64 | self.img_size = imgs[0].size[0] 65 | 66 | labels = [TF.to_pil_image(img) for img in labels] 67 | if len(labels) != 0: 68 | if labels[0].size != (self.img_size, self.img_size): 69 | labels = [TF.resize(img, [self.img_size, self.img_size], interpolation=0) 70 | for img in labels] 71 | 72 | random_base = 0.5 73 | if self.with_random_hflip and random.random() > 0.5: 74 | imgs = [TF.hflip(img) for img in imgs] 75 | labels = [TF.hflip(img) for img in labels] 76 | 77 | if self.with_random_vflip and random.random() > 0.5: 78 | imgs = [TF.vflip(img) for img in imgs] 79 | labels = [TF.vflip(img) for img in labels] 80 | 81 | if self.with_random_rot and random.random() > 
random_base: 82 | angles = [90, 180, 270] 83 | index = random.randint(0, 2) 84 | angle = angles[index] 85 | imgs = [TF.rotate(img, angle) for img in imgs] 86 | labels = [TF.rotate(img, angle) for img in labels] 87 | 88 | if self.with_random_crop and random.random() > 0: 89 | i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). \ 90 | get_params(img=imgs[0], scale=(0.8, 1.2), ratio=(1, 1)) 91 | 92 | imgs = [TF.resized_crop(img, i, j, h, w, 93 | size=(self.img_size, self.img_size), 94 | interpolation=Image.CUBIC) 95 | for img in imgs] 96 | 97 | labels = [TF.resized_crop(img, i, j, h, w, 98 | size=(self.img_size, self.img_size), 99 | interpolation=Image.NEAREST) 100 | for img in labels] 101 | 102 | if self.with_scale_random_crop: 103 | # rescale 104 | scale_range = [1, 1.2] 105 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 106 | 107 | imgs = [pil_rescale(img, target_scale, order=3) for img in imgs] 108 | labels = [pil_rescale(img, target_scale, order=0) for img in labels] 109 | # crop 110 | imgsize = imgs[0].size # h, w 111 | box = get_random_crop_box(imgsize=imgsize, cropsize=self.img_size) 112 | imgs = [pil_crop(img, box, cropsize=self.img_size, default_value=0) 113 | for img in imgs] 114 | labels = [pil_crop(img, box, cropsize=self.img_size, default_value=255) 115 | for img in labels] 116 | 117 | if self.with_random_blur and random.random() > 0: 118 | radius = random.random() 119 | imgs = [img.filter(ImageFilter.GaussianBlur(radius=radius)) 120 | for img in imgs] 121 | 122 | if self.random_color_tf: 123 | color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3) 124 | imgs_tf = [] 125 | for img in imgs: 126 | tf = transforms.ColorJitter( 127 | color_jitter.brightness, 128 | color_jitter.contrast, 129 | color_jitter.saturation, 130 | color_jitter.hue) 131 | imgs_tf.append(tf(img)) 132 | imgs = imgs_tf 133 | 134 | if to_tensor: 135 | # to tensor 136 | imgs = [TF.to_tensor(img) for img in imgs] 137 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0) 138 | for img in labels] 139 | 140 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 141 | for img in imgs] 142 | 143 | return imgs, labels 144 | 145 | 146 | def pil_crop(image, box, cropsize, default_value): 147 | assert isinstance(image, Image.Image) 148 | img = np.array(image) 149 | 150 | if len(img.shape) == 3: 151 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 152 | else: 153 | cont = np.ones((cropsize, cropsize), img.dtype)*default_value 154 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 155 | 156 | return Image.fromarray(cont) 157 | 158 | 159 | def get_random_crop_box(imgsize, cropsize): 160 | h, w = imgsize 161 | ch = min(cropsize, h) 162 | cw = min(cropsize, w) 163 | 164 | w_space = w - cropsize 165 | h_space = h - cropsize 166 | 167 | if w_space > 0: 168 | cont_left = 0 169 | img_left = random.randrange(w_space + 1) 170 | else: 171 | cont_left = random.randrange(-w_space + 1) 172 | img_left = 0 173 | 174 | if h_space > 0: 175 | cont_top = 0 176 | img_top = random.randrange(h_space + 1) 177 | else: 178 | cont_top = random.randrange(-h_space + 1) 179 | img_top = 0 180 | 181 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 182 | 183 | 184 | def pil_rescale(img, scale, order): 185 | assert isinstance(img, Image.Image) 186 | height, width = img.size 187 | target_size = (int(np.round(height*scale)), 
int(np.round(width*scale))) 188 | return pil_resize(img, target_size, order) 189 | 190 | 191 | def pil_resize(img, size, order): 192 | assert isinstance(img, Image.Image) 193 | if size[0] == img.size[0] and size[1] == img.size[1]: 194 | return img 195 | if order == 3: 196 | resample = Image.BICUBIC 197 | elif order == 0: 198 | resample = Image.NEAREST 199 | return img.resize(size[::-1], resample) 200 | -------------------------------------------------------------------------------- /demo_LEVIR.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import utils 4 | import torch 5 | from models.basic_model import CDEvaluator 6 | 7 | import os 8 | 9 | """ 10 | quick start 11 | 12 | sample files in ./samples 13 | 14 | save prediction files in the ./samples/predict 15 | 16 | """ 17 | 18 | 19 | def get_args(): 20 | # ------------ 21 | # args 22 | # ------------ 23 | parser = ArgumentParser() 24 | parser.add_argument('--project_name', default='CD_ChangeFormerV6_LEVIR_b16_lr0.0001_adamw_train_test_200_linear_ce_multi_train_True_multi_infer_False_shuffle_AB_False_embed_dim_256', type=str) 25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 26 | parser.add_argument('--checkpoint_root', default='/media/lidan/ssd2/ChangeFormer/checkpoints/', type=str) 27 | parser.add_argument('--output_folder', default='samples_LEVIR/predict_CD_ChangeFormerV6', type=str) 28 | 29 | # data 30 | parser.add_argument('--num_workers', default=0, type=int) 31 | parser.add_argument('--dataset', default='CDDataset', type=str) 32 | parser.add_argument('--data_name', default='quick_start_LEVIR', type=str) 33 | 34 | parser.add_argument('--batch_size', default=1, type=int) 35 | parser.add_argument('--split', default="demo", type=str) 36 | parser.add_argument('--img_size', default=256, type=int) 37 | 38 | # model 39 | parser.add_argument('--n_class', default=2, type=int) 40 | parser.add_argument('--embed_dim', default=256, type=int) 41 | parser.add_argument('--net_G', default='ChangeFormerV6', type=str, 42 | help='ChangeFormerV6 | CD_SiamUnet_diff | SiamUnet_conc | Unet | DTCDSCN | base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|') 43 | parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str) 44 | 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | args = get_args() 52 | utils.get_device(args) 53 | device = torch.device("cuda:%s" % args.gpu_ids[0] 54 | if torch.cuda.is_available() and len(args.gpu_ids)>0 55 | else "cpu") 56 | args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name) 57 | os.makedirs(args.output_folder, exist_ok=True) 58 | 59 | log_path = os.path.join(args.output_folder, 'log_vis.txt') 60 | 61 | data_loader = utils.get_loader(args.data_name, img_size=args.img_size, 62 | batch_size=args.batch_size, 63 | split=args.split, is_train=False) 64 | 65 | model = CDEvaluator(args) 66 | model.load_checkpoint(args.checkpoint_name) 67 | model.eval() 68 | 69 | for i, batch in enumerate(data_loader): 70 | name = batch['name'] 71 | print('process: %s' % name) 72 | score_map = model._forward_pass(batch) 73 | model._save_predictions() 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /eval_cd.py: -------------------------------------------------------------------------------- 1 | from argparse import 
ArgumentParser 2 | import torch 3 | from models.evaluator import * 4 | 5 | print(torch.cuda.is_available()) 6 | 7 | 8 | """ 9 | eval the CD model 10 | """ 11 | 12 | def main(): 13 | # ------------ 14 | # args 15 | # ------------ 16 | parser = ArgumentParser() 17 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 18 | parser.add_argument('--project_name', default='test', type=str) 19 | parser.add_argument('--print_models', default=False, type=bool, help='print models') 20 | parser.add_argument('--checkpoints_root', default='checkpoints', type=str) 21 | parser.add_argument('--vis_root', default='vis', type=str) 22 | 23 | # data 24 | parser.add_argument('--num_workers', default=8, type=int) 25 | parser.add_argument('--dataset', default='CDDataset', type=str) 26 | parser.add_argument('--data_name', default='LEVIR', type=str) 27 | 28 | parser.add_argument('--batch_size', default=1, type=int) 29 | parser.add_argument('--split', default="test", type=str) 30 | 31 | parser.add_argument('--img_size', default=256, type=int) 32 | 33 | # model 34 | parser.add_argument('--n_class', default=2, type=int) 35 | parser.add_argument('--embed_dim', default=256, type=int) 36 | parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8_dedim8', type=str, 37 | help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|') 38 | 39 | parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str) 40 | 41 | args = parser.parse_args() 42 | utils.get_device(args) 43 | print(args.gpu_ids) 44 | 45 | # checkpoints dir 46 | args.checkpoint_dir = os.path.join(args.checkpoints_root, args.project_name) 47 | os.makedirs(args.checkpoint_dir, exist_ok=True) 48 | # visualize dir 49 | args.vis_dir = os.path.join(args.vis_root, args.project_name) 50 | os.makedirs(args.vis_dir, exist_ok=True) 51 | 52 | dataloader = utils.get_loader(args.data_name, img_size=args.img_size, 53 | batch_size=args.batch_size, is_train=False, 54 | split=args.split) 55 | model = CDEvaluator(args=args, dataloader=dataloader) 56 | 57 | model.eval_models(checkpoint_name=args.checkpoint_name) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | 63 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/images/.DS_Store -------------------------------------------------------------------------------- /images/Figure 1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/images/Figure 1.jpg -------------------------------------------------------------------------------- /images/Figure 10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/images/Figure 10.jpg -------------------------------------------------------------------------------- /images/Figure 9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/images/Figure 9.jpg -------------------------------------------------------------------------------- /main_cd.py: 
-------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | from models.trainer import * 4 | 5 | print(torch.cuda.is_available()) 6 | 7 | """ 8 | the main function for training the CD networks 9 | """ 10 | 11 | 12 | def train(args): 13 | dataloaders = utils.get_loaders(args) 14 | model = CDTrainer(args=args, dataloaders=dataloaders) 15 | model.train_models() 16 | 17 | 18 | def test(args): 19 | from models.evaluator import CDEvaluator 20 | dataloader = utils.get_loader(args.data_name, img_size=args.img_size, 21 | batch_size=args.batch_size, is_train=False, 22 | split='test', dataset=args.dataset) 23 | model = CDEvaluator(args=args, dataloader=dataloader) 24 | 25 | model.eval_models() 26 | 27 | 28 | def test2(args): 29 | from models.evaluator import CDEvaluator 30 | 31 | model = CDEvaluator(args=args, dataloader='') 32 | model.pred_gdal_blocks_write(r'E:\bianhuajiance\cq\fengjiexian\nanbu\2020\500115_clip4.tif', 33 | r'E:\bianhuajiance\cq\fengjiexian\nanbu\2021\500115_clip4.tif') 34 | 35 | 36 | if __name__ == '__main__': 37 | # ------------ 38 | # args 39 | # ------------ 40 | parser = ArgumentParser() 41 | parser.add_argument('--gpu_ids', type=str, default='-1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 42 | parser.add_argument('--project_name', default='chongqing_EGCTNet_eas_bs16_0.0001', type=str) 43 | parser.add_argument('--checkpoint_root', default='checkpoints', type=str) 44 | parser.add_argument('--vis_root', default='vis', type=str) 45 | 46 | # data 47 | parser.add_argument('--num_workers', default=2, type=int) 48 | parser.add_argument('--dataset', default='CDDataset', type=str) 49 | parser.add_argument('--data_name', default='WHU-512-100', type=str) 50 | 51 | parser.add_argument('--batch_size', default=1, type=int) 52 | parser.add_argument('--split', default="train", type=str) 53 | parser.add_argument('--split_val', default="val", type=str) 54 | 55 | parser.add_argument('--img_size', default=512, type=int) 56 | parser.add_argument('--shuffle_AB', default=False, type=str) 57 | 58 | # model 59 | parser.add_argument('--n_class', default=2, type=int) 60 | parser.add_argument('--embed_dim', default=32, type=int) 61 | parser.add_argument('--pretrain', default=None, type=str) 62 | parser.add_argument('--multi_scale_train', default=False, type=str) 63 | parser.add_argument('--multi_scale_infer', default=False, type=str) 64 | parser.add_argument('--multi_pred_weights', nargs='+', type=float, default=[0.5, 0.5, 0.5, 0.8, 1.0]) 65 | 66 | parser.add_argument('--net_G', default='EGCTNet', type=str, 67 | help='base_resnet18 | base_transformer_pos_s4 | ' 68 | 'base_transformer_pos_s4_dd8 | ' 69 | 'base_transformer_pos_s4_dd8_dedim8|ChangeFormerV5|SiamUnet_diff') 70 | parser.add_argument('--loss', default='eas', type=str) 71 | 72 | # optimizer 73 | parser.add_argument('--optimizer', default='adamw', type=str) 74 | parser.add_argument('--lr', default=0.002, type=float) 75 | parser.add_argument('--max_epochs', default=100, type=int) 76 | parser.add_argument('--lr_policy', default='linear', type=str, 77 | help='linear | step') 78 | parser.add_argument('--lr_decay_iters', default=100, type=int) 79 | 80 | args = parser.parse_args() 81 | utils.get_device(args) 82 | print(args.gpu_ids) 83 | 84 | # checkpoints dir 85 | args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name) 86 | os.makedirs(args.checkpoint_dir, exist_ok=True) 87 | # visualize dir 88 | args.vis_dir = os.path.join(args.vis_root, 
args.project_name) 89 | os.makedirs(args.vis_dir, exist_ok=True) 90 | 91 | # train(args) 92 | test2(args) 93 | -------------------------------------------------------------------------------- /misc/__pycache__/imutils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/misc/__pycache__/imutils.cpython-37.pyc -------------------------------------------------------------------------------- /misc/__pycache__/logger_tool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/misc/__pycache__/logger_tool.cpython-37.pyc -------------------------------------------------------------------------------- /misc/__pycache__/metric_tool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/misc/__pycache__/metric_tool.cpython-37.pyc -------------------------------------------------------------------------------- /misc/imutils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | from PIL import ImageFilter 6 | import PIL 7 | import tifffile 8 | 9 | 10 | def cv_rotate(image, angle, borderValue): 11 | """ 12 | rot angle, fill with borderValue 13 | """ 14 | # grab the dimensions of the image and then determine the 15 | # center 16 | (h, w) = image.shape[:2] 17 | (cX, cY) = (w // 2, h // 2) 18 | 19 | # grab the rotation matrix (applying the negative of the 20 | # angle to rotate clockwise), then grab the sine and cosine 21 | # (i.e., the rotation components of the matrix) 22 | # -angle位置参数为角度参数负值表示顺时针旋转; 1.0位置参数scale是调整尺寸比例(图像缩放参数),建议0.75 23 | M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0) 24 | cos = np.abs(M[0, 0]) 25 | sin = np.abs(M[0, 1]) 26 | 27 | # compute the new bounding dimensions of the image 28 | nW = int((h * sin) + (w * cos)) 29 | nH = int((h * cos) + (w * sin)) 30 | 31 | # adjust the rotation matrix to take into account translation 32 | M[0, 2] += (nW / 2) - cX 33 | M[1, 2] += (nH / 2) - cY 34 | if isinstance(borderValue, int): 35 | values = (borderValue, borderValue, borderValue) 36 | else: 37 | values = borderValue 38 | # perform the actual rotation and return the image 39 | return cv2.warpAffine(image, M, (nW, nH), borderValue=values) 40 | 41 | 42 | def pil_resize(img, size, order): 43 | if size[0] == img.shape[0] and size[1] == img.shape[1]: 44 | return img 45 | 46 | if order == 3: 47 | resample = Image.BICUBIC 48 | elif order == 0: 49 | resample = Image.NEAREST 50 | 51 | return np.asarray(Image.fromarray(img).resize(size[::-1], resample)) 52 | 53 | 54 | def pil_rescale(img, scale, order): 55 | height, width = img.shape[:2] 56 | target_size = (int(np.round(height*scale)), int(np.round(width*scale))) 57 | return pil_resize(img, target_size, order) 58 | 59 | 60 | def pil_rotate(img, degree, default_value): 61 | if isinstance(default_value, tuple): 62 | values = (default_value[0], default_value[1], default_value[2], 0) 63 | else: 64 | values = (default_value, default_value, default_value,0) 65 | img = Image.fromarray(img) 66 | if img.mode =='RGB': 67 | # set img padding == default_value 68 | img2 = img.convert('RGBA') 69 | rot = img2.rotate(degree, 
expand=1) 70 | fff = Image.new('RGBA', rot.size, values) # 灰色 71 | out = Image.composite(rot, fff, rot) 72 | img = out.convert(img.mode) 73 | 74 | else: 75 | # set label padding == default_value 76 | img2 = img.convert('RGBA') 77 | rot = img2.rotate(degree, expand=1) 78 | # a white image same size as rotated image 79 | fff = Image.new('RGBA', rot.size, values) 80 | # create a composite image using the alpha layer of rot as a mask 81 | out = Image.composite(rot, fff, rot) 82 | img = out.convert(img.mode) 83 | 84 | return np.asarray(img) 85 | 86 | 87 | def random_resize_long_image_list(img_list, min_long, max_long): 88 | target_long = random.randint(min_long, max_long) 89 | h, w = img_list[0].shape[:2] 90 | if w < h: 91 | scale = target_long / h 92 | else: 93 | scale = target_long / w 94 | out = [] 95 | for img in img_list: 96 | out.append(pil_rescale(img, scale, 3) ) 97 | return out 98 | 99 | 100 | def random_resize_long(img, min_long, max_long): 101 | target_long = random.randint(min_long, max_long) 102 | h, w = img.shape[:2] 103 | 104 | if w < h: 105 | scale = target_long / h 106 | else: 107 | scale = target_long / w 108 | 109 | return pil_rescale(img, scale, 3) 110 | 111 | 112 | def random_scale_list(img_list, scale_range, order): 113 | """ 114 | 输入:图像列表 115 | """ 116 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 117 | 118 | if isinstance(img_list, tuple): 119 | assert img_list.__len__() == 2 120 | img1 = [] 121 | img2 = [] 122 | for img in img_list[0]: 123 | img1.append(pil_rescale(img, target_scale, order[0])) 124 | for img in img_list[1]: 125 | img2.append(pil_rescale(img, target_scale, order[1])) 126 | return (img1, img2) 127 | else: 128 | out = [] 129 | for img in img_list: 130 | out.append(pil_rescale(img, target_scale, order)) 131 | return out 132 | 133 | 134 | def random_scale(img, scale_range, order): 135 | 136 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 137 | 138 | if isinstance(img, tuple): 139 | return (pil_rescale(img[0], target_scale, order[0]), pil_rescale(img[1], target_scale, order[1])) 140 | else: 141 | return pil_rescale(img, target_scale, order) 142 | 143 | 144 | def random_rotate_list(img_list, max_degree, default_values): 145 | degree = random.random() * max_degree 146 | if isinstance(img_list, tuple): 147 | assert img_list.__len__() == 2 148 | img1 = [] 149 | img2 = [] 150 | for img in img_list[0]: 151 | assert isinstance(img, np.ndarray) 152 | img1.append((pil_rotate(img, degree, default_values[0]))) 153 | for img in img_list[1]: 154 | img2.append((pil_rotate(img, degree, default_values[1]))) 155 | return (img1, img2) 156 | else: 157 | out = [] 158 | for img in img_list: 159 | out.append(pil_rotate(img, degree, default_values)) 160 | return out 161 | 162 | 163 | def random_rotate(img, max_degree, default_values): 164 | degree = random.random() * max_degree 165 | if isinstance(img, tuple): 166 | return (pil_rotate(img[0], degree, default_values[0]), 167 | pil_rotate(img[1], degree, default_values[1])) 168 | else: 169 | return pil_rotate(img, degree, default_values) 170 | 171 | 172 | def random_lr_flip_list(img_list): 173 | 174 | if bool(random.getrandbits(1)): 175 | if isinstance(img_list, tuple): 176 | assert img_list.__len__()==2 177 | img1=list((np.fliplr(m) for m in img_list[0])) 178 | img2=list((np.fliplr(m) for m in img_list[1])) 179 | 180 | return (img1, img2) 181 | else: 182 | return list([np.fliplr(m) for m in img_list]) 183 | else: 184 | return img_list 185 | 186 | 187 | def 
random_lr_flip(img): 188 | 189 | if bool(random.getrandbits(1)): 190 | if isinstance(img, tuple): 191 | return tuple([np.fliplr(m) for m in img]) 192 | else: 193 | return np.fliplr(img) 194 | else: 195 | return img 196 | 197 | 198 | def get_random_crop_box(imgsize, cropsize): 199 | h, w = imgsize 200 | 201 | ch = min(cropsize, h) 202 | cw = min(cropsize, w) 203 | 204 | w_space = w - cropsize 205 | h_space = h - cropsize 206 | 207 | if w_space > 0: 208 | cont_left = 0 209 | img_left = random.randrange(w_space + 1) 210 | else: 211 | cont_left = random.randrange(-w_space + 1) 212 | img_left = 0 213 | 214 | if h_space > 0: 215 | cont_top = 0 216 | img_top = random.randrange(h_space + 1) 217 | else: 218 | cont_top = random.randrange(-h_space + 1) 219 | img_top = 0 220 | 221 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 222 | 223 | 224 | def random_crop_list(images_list, cropsize, default_values): 225 | 226 | if isinstance(images_list, tuple): 227 | imgsize = images_list[0][0].shape[:2] 228 | elif isinstance(images_list, list): 229 | imgsize = images_list[0].shape[:2] 230 | else: 231 | raise RuntimeError('do not support the type of image_list') 232 | if isinstance(default_values, int): default_values = (default_values,) 233 | 234 | box = get_random_crop_box(imgsize, cropsize) 235 | if isinstance(images_list, tuple): 236 | assert images_list.__len__()==2 237 | img1 = [] 238 | img2 = [] 239 | for img in images_list[0]: 240 | f = default_values[0] 241 | if len(img.shape) == 3: 242 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 243 | else: 244 | cont = np.ones((cropsize, cropsize), img.dtype)*f 245 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 246 | img1.append(cont) 247 | for img in images_list[1]: 248 | f = default_values[1] 249 | if len(img.shape) == 3: 250 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 251 | else: 252 | cont = np.ones((cropsize, cropsize), img.dtype)*f 253 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 254 | img2.append(cont) 255 | return (img1, img2) 256 | else: 257 | out = [] 258 | for img in images_list: 259 | f = default_values 260 | if len(img.shape) == 3: 261 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype) * f 262 | else: 263 | cont = np.ones((cropsize, cropsize), img.dtype) * f 264 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 265 | out.append(cont) 266 | return out 267 | 268 | 269 | def random_crop(images, cropsize, default_values): 270 | 271 | if isinstance(images, np.ndarray): images = (images,) 272 | if isinstance(default_values, int): default_values = (default_values,) 273 | 274 | imgsize = images[0].shape[:2] 275 | box = get_random_crop_box(imgsize, cropsize) 276 | 277 | new_images = [] 278 | for img, f in zip(images, default_values): 279 | 280 | if len(img.shape) == 3: 281 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 282 | else: 283 | cont = np.ones((cropsize, cropsize), img.dtype)*f 284 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 285 | new_images.append(cont) 286 | 287 | if len(new_images) == 1: 288 | new_images = new_images[0] 289 | 290 | return new_images 291 | 292 | 293 | def top_left_crop(img, cropsize, default_value): 294 | 295 | h, w = img.shape[:2] 296 | 297 | ch = min(cropsize, h) 298 | cw = min(cropsize, w) 299 | 300 | if len(img.shape) == 2: 301 | container = np.ones((cropsize, cropsize), img.dtype)*default_value 302 | 
else: 303 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 304 | 305 | container[:ch, :cw] = img[:ch, :cw] 306 | 307 | return container 308 | 309 | 310 | def center_crop(img, cropsize, default_value=0): 311 | 312 | h, w = img.shape[:2] 313 | 314 | ch = min(cropsize, h) 315 | cw = min(cropsize, w) 316 | 317 | sh = h - cropsize 318 | sw = w - cropsize 319 | 320 | if sw > 0: 321 | cont_left = 0 322 | img_left = int(round(sw / 2)) 323 | else: 324 | cont_left = int(round(-sw / 2)) 325 | img_left = 0 326 | 327 | if sh > 0: 328 | cont_top = 0 329 | img_top = int(round(sh / 2)) 330 | else: 331 | cont_top = int(round(-sh / 2)) 332 | img_top = 0 333 | 334 | if len(img.shape) == 2: 335 | container = np.ones((cropsize, cropsize), img.dtype)*default_value 336 | else: 337 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 338 | 339 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 340 | img[img_top:img_top+ch, img_left:img_left+cw] 341 | 342 | return container 343 | 344 | 345 | def HWC_to_CHW(img): 346 | return np.transpose(img, (2, 0, 1)) 347 | 348 | 349 | def pil_blur(img, radius): 350 | return np.array(Image.fromarray(img).filter(ImageFilter.GaussianBlur(radius=radius))) 351 | 352 | 353 | def random_blur(img): 354 | radius = random.random() 355 | # print('add blur: ', radius) 356 | if isinstance(img, list): 357 | out = [] 358 | for im in img: 359 | out.append(pil_blur(im, radius)) 360 | return out 361 | elif isinstance(img, np.ndarray): 362 | return pil_blur(img, radius) 363 | else: 364 | print(img) 365 | raise RuntimeError("do not support the input image type!") 366 | 367 | 368 | def save_image(image_numpy, image_path): 369 | """Save a numpy image to the disk 370 | Parameters: 371 | image_numpy (numpy array) -- input numpy array 372 | image_path (str) -- the path of the image 373 | """ 374 | image_pil = Image.fromarray(np.array(image_numpy,dtype=np.uint8)) 375 | image_pil.save(image_path) 376 | 377 | 378 | def im2arr(img_path, mode=1, dtype=np.uint8): 379 | """ 380 | :param img_path: 381 | :param mode: 382 | :return: numpy.ndarray, shape: H*W*C 383 | """ 384 | if mode==1: 385 | img = PIL.Image.open(img_path) 386 | arr = np.asarray(img, dtype=dtype) 387 | else: 388 | arr = tifffile.imread(img_path) 389 | if arr.ndim == 3: 390 | a, b, c = arr.shape 391 | if a < b and a < c: # 当arr为C*H*W时,需要交换通道顺序 392 | arr = arr.transpose([1,2,0]) 393 | # print('shape: ', arr.shape, 'dytpe: ',arr.dtype) 394 | return arr 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | -------------------------------------------------------------------------------- /misc/logger_tool.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, outfile): 7 | self.terminal = sys.stdout 8 | self.log_path = outfile 9 | now = time.strftime("%c") 10 | self.write('================ (%s) ================\n' % now) 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | with open(self.log_path, mode='a') as f: 15 | f.write(message) 16 | 17 | def write_dict(self, dict): 18 | message = '' 19 | for k, v in dict.items(): 20 | message += '%s: %.7f ' % (k, v) 21 | self.write(message) 22 | 23 | def write_dict_str(self, dict): 24 | message = '' 25 | for k, v in dict.items(): 26 | message += '%s: %s ' % (k, v) 27 | self.write(message) 28 | 29 | def flush(self): 30 | self.terminal.flush() 31 | 32 | 33 | class Timer: 34 | def __init__(self, starting_msg 
= None): 35 | self.start = time.time() 36 | self.stage_start = self.start 37 | 38 | if starting_msg is not None: 39 | print(starting_msg, time.ctime(time.time())) 40 | 41 | def __enter__(self): 42 | return self 43 | 44 | def __exit__(self, exc_type, exc_val, exc_tb): 45 | return 46 | 47 | def update_progress(self, progress): 48 | self.elapsed = time.time() - self.start 49 | self.est_total = self.elapsed / progress 50 | self.est_remaining = self.est_total - self.elapsed 51 | self.est_finish = int(self.start + self.est_total) 52 | 53 | 54 | def str_estimated_complete(self): 55 | return str(time.ctime(self.est_finish)) 56 | 57 | def str_estimated_remaining(self): 58 | return str(self.est_remaining/3600) + 'h' 59 | 60 | def estimated_remaining(self): 61 | return self.est_remaining/3600 62 | 63 | def get_stage_elapsed(self): 64 | return time.time() - self.stage_start 65 | 66 | def reset_stage(self): 67 | self.stage_start = time.time() 68 | 69 | def lapse(self): 70 | out = time.time() - self.stage_start 71 | self.stage_start = time.time() 72 | return out 73 | 74 | -------------------------------------------------------------------------------- /misc/metric_tool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ################### metrics ################### 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.initialized = False 9 | self.val = None 10 | self.avg = None 11 | self.sum = None 12 | self.count = None 13 | 14 | def initialize(self, val, weight): 15 | self.val = val 16 | self.avg = val 17 | self.sum = val * weight 18 | self.count = weight 19 | self.initialized = True 20 | 21 | def update(self, val, weight=1): 22 | if not self.initialized: 23 | self.initialize(val, weight) 24 | else: 25 | self.add(val, weight) 26 | 27 | def add(self, val, weight): 28 | self.val = val 29 | self.sum += val * weight 30 | self.count += weight 31 | self.avg = self.sum / self.count 32 | 33 | def value(self): 34 | return self.val 35 | 36 | def average(self): 37 | return self.avg 38 | 39 | def get_scores(self): 40 | scores_dict = cm2score(self.sum) 41 | return scores_dict 42 | 43 | def clear(self): 44 | self.initialized = False 45 | 46 | 47 | ################### cm metrics ################### 48 | class ConfuseMatrixMeter(AverageMeter): 49 | """Computes and stores the average and current value""" 50 | def __init__(self, n_class): 51 | super(ConfuseMatrixMeter, self).__init__() 52 | self.n_class = n_class 53 | 54 | def update_cm(self, pr, gt, weight=1): 55 | """获得当前混淆矩阵,并计算当前F1得分,并更新混淆矩阵""" 56 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr) 57 | self.update(val, weight) 58 | current_score = cm2F1(val) 59 | return current_score 60 | 61 | def get_scores(self): 62 | scores_dict = cm2score(self.sum) 63 | return scores_dict 64 | 65 | 66 | 67 | def harmonic_mean(xs): 68 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs) 69 | return harmonic_mean 70 | 71 | 72 | def cm2F1(confusion_matrix): 73 | hist = confusion_matrix 74 | n_class = hist.shape[0] 75 | tp = np.diag(hist) 76 | sum_a1 = hist.sum(axis=1) 77 | sum_a0 = hist.sum(axis=0) 78 | # ---------------------------------------------------------------------- # 79 | # 1. 
Accuracy & Class Accuracy 80 | # ---------------------------------------------------------------------- # 81 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 82 | 83 | # recall 84 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 85 | # acc_cls = np.nanmean(recall) 86 | 87 | # precision 88 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 89 | 90 | # F1 score 91 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps) 92 | mean_F1 = np.nanmean(F1) 93 | return mean_F1 94 | 95 | 96 | def cm2score(confusion_matrix): 97 | hist = confusion_matrix 98 | n_class = hist.shape[0] 99 | tp = np.diag(hist) 100 | sum_a1 = hist.sum(axis=1) 101 | sum_a0 = hist.sum(axis=0) 102 | # ---------------------------------------------------------------------- # 103 | # 1. Accuracy & Class Accuracy 104 | # ---------------------------------------------------------------------- # 105 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 106 | 107 | # recall 108 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 109 | # acc_cls = np.nanmean(recall) 110 | 111 | # precision 112 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 113 | 114 | # F1 score 115 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps) 116 | mean_F1 = np.nanmean(F1) 117 | # ---------------------------------------------------------------------- # 118 | # 2. Frequency weighted Accuracy & Mean IoU 119 | # ---------------------------------------------------------------------- # 120 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) 121 | mean_iu = np.nanmean(iu) 122 | 123 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps) 124 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 125 | 126 | # 127 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu)) 128 | 129 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision)) 130 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall)) 131 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1)) 132 | 133 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1} 134 | score_dict.update(cls_iou) 135 | score_dict.update(cls_F1) 136 | score_dict.update(cls_precision) 137 | score_dict.update(cls_recall) 138 | return score_dict 139 | 140 | 141 | def get_confuse_matrix(num_classes, label_gts, label_preds): 142 | """计算一组预测的混淆矩阵""" 143 | def __fast_hist(label_gt, label_pred): 144 | """ 145 | Collect values for Confusion Matrix 146 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix 147 | :param label_gt: ground-truth 148 | :param label_pred: prediction 149 | :return: values for confusion matrix 150 | """ 151 | mask = (label_gt >= 0) & (label_gt < num_classes) 152 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask], 153 | minlength=num_classes**2).reshape(num_classes, num_classes) 154 | return hist 155 | confusion_matrix = np.zeros((num_classes, num_classes)) 156 | for lt, lp in zip(label_gts, label_preds): 157 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten()) 158 | return confusion_matrix 159 | 160 | 161 | def get_mIoU(num_classes, label_gts, label_preds): 162 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds) 163 | score_dict = cm2score(confusion_matrix) 164 | return score_dict['miou'] 165 | -------------------------------------------------------------------------------- /misc/pyutils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import glob 5 | 6 | 7 | def seed_random(seed=2020): 8 | # fix the random seeds below so that data loading, random augmentation, etc. stay reproducible 9 | random.seed(seed) 10 | os.environ['PYTHONHASHSEED'] = str(seed) 11 | np.random.seed(seed) 12 | 13 | 14 | def mkdir(path): 15 | """create a single empty directory if it didn't exist 16 | 17 | Parameters: 18 | path (str) -- a single directory path 19 | """ 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | 24 | def get_paths(image_folder_path, suffix='*.png'): 25 | """Return the files matching the given suffix in a folder 26 | :param image_folder_path: str 27 | :param suffix: str 28 | :return: list 29 | """ 30 | paths = sorted(glob.glob(os.path.join(image_folder_path, suffix))) 31 | return paths 32 | 33 | 34 | def get_paths_from_list(image_folder_path, list): 35 | """Find the files named in list under the image folder and return their paths""" 36 | out = [] 37 | for item in list: 38 | path = os.path.join(image_folder_path,item) 39 | out.append(path) 40 | return sorted(out) 41 | 42 | 43 | -------------------------------------------------------------------------------- /models/TransformerBaseNetworks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torch 6 | 7 | from torch import nn 8 | from torch.nn import init 9 | from torch.nn import functional as F 10 | from torch.autograd import Function 11 | 12 | from math import sqrt 13 | 14 | import random 15 | 16 | class ConvBlock(torch.nn.Module): 17 | def __init__(self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None): 18 | super(ConvBlock, self).__init__() 19 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 20 | 21 | self.norm = norm 22 | if self.norm =='batch': 23 | self.bn = torch.nn.BatchNorm2d(output_size) 24 | elif self.norm == 'instance': 25 | self.bn = torch.nn.InstanceNorm2d(output_size) 26 | 27 | self.activation = activation 28 | if self.activation == 'relu': 29 | self.act = torch.nn.ReLU(True) 30 | elif self.activation == 'prelu': 31 | self.act = torch.nn.PReLU() 32 | elif self.activation == 'lrelu': 33 | self.act = torch.nn.LeakyReLU(0.2, True) 34 | elif self.activation == 'tanh': 35 | self.act = torch.nn.Tanh() 36 | elif self.activation == 'sigmoid': 37 | self.act = torch.nn.Sigmoid() 38 | 39 | def forward(self, x): 40 | if self.norm is not None: 41 | out = self.bn(self.conv(x)) 42 | else: 43 | out = self.conv(x) 44 | 45 | if self.activation != 'no': 46 | return self.act(out) 47 | else: 48 | return out 49 | 50 | class DeconvBlock(torch.nn.Module): 51 | def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm=None): 52 | super(DeconvBlock, self).__init__() 53 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 54 | 55 | self.norm = norm 56 | if self.norm == 'batch': 57 | self.bn = torch.nn.BatchNorm2d(output_size) 58 | elif self.norm == 'instance': 59 | self.bn = torch.nn.InstanceNorm2d(output_size) 60 | 61 | self.activation = activation 62 | if self.activation == 'relu': 63 | self.act = torch.nn.ReLU(True) 64 | elif self.activation == 'prelu': 65 | self.act = torch.nn.PReLU() 66 | elif self.activation == 'lrelu': 67 | self.act = torch.nn.LeakyReLU(0.2, True) 68 | elif self.activation == 'tanh': 69 | self.act = torch.nn.Tanh()
70 | elif self.activation == 'sigmoid': 71 | self.act = torch.nn.Sigmoid() 72 | 73 | def forward(self, x): 74 | if self.norm is not None: 75 | out = self.bn(self.deconv(x)) 76 | else: 77 | out = self.deconv(x) 78 | 79 | if self.activation is not None: 80 | return self.act(out) 81 | else: 82 | return out 83 | 84 | 85 | class ConvLayer(nn.Module): 86 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 87 | super(ConvLayer, self).__init__() 88 | # reflection_padding = kernel_size // 2 89 | # self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 90 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 91 | 92 | def forward(self, x): 93 | # out = self.reflection_pad(x) 94 | out = self.conv2d(x) 95 | return out 96 | 97 | 98 | class UpsampleConvLayer(torch.nn.Module): 99 | def __init__(self, in_channels, out_channels, kernel_size, stride): 100 | super(UpsampleConvLayer, self).__init__() 101 | self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=1) 102 | 103 | def forward(self, x): 104 | out = self.conv2d(x) 105 | return out 106 | 107 | 108 | class ResidualBlock(torch.nn.Module): 109 | def __init__(self, channels): 110 | super(ResidualBlock, self).__init__() 111 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1) 112 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1) 113 | self.relu = nn.ReLU() 114 | 115 | def forward(self, x): 116 | residual = x 117 | out = self.relu(self.conv1(x)) 118 | out = self.conv2(out) * 0.1 119 | out = torch.add(out, residual) 120 | return out 121 | 122 | 123 | 124 | def init_linear(linear): 125 | init.xavier_normal(linear.weight) 126 | linear.bias.data.zero_() 127 | 128 | 129 | def init_conv(conv, glu=True): 130 | init.kaiming_normal(conv.weight) 131 | if conv.bias is not None: 132 | conv.bias.data.zero_() 133 | 134 | 135 | class EqualLR: 136 | def __init__(self, name): 137 | self.name = name 138 | 139 | def compute_weight(self, module): 140 | weight = getattr(module, self.name + '_orig') 141 | fan_in = weight.data.size(1) * weight.data[0][0].numel() 142 | 143 | return weight * sqrt(2 / fan_in) 144 | 145 | @staticmethod 146 | def apply(module, name): 147 | fn = EqualLR(name) 148 | 149 | weight = getattr(module, name) 150 | del module._parameters[name] 151 | module.register_parameter(name + '_orig', nn.Parameter(weight.data)) 152 | module.register_forward_pre_hook(fn) 153 | 154 | return fn 155 | 156 | def __call__(self, module, input): 157 | weight = self.compute_weight(module) 158 | setattr(module, self.name, weight) 159 | 160 | 161 | def equal_lr(module, name='weight'): 162 | EqualLR.apply(module, name) 163 | 164 | return module -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /models/__pycache__/EGCTNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/EGCTNet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/TransformerBaseNetworks.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/TransformerBaseNetworks.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/basic_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/basic_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/evaluator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/evaluator.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/help_funcs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/help_funcs.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/pixel_shuffel_up.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/pixel_shuffel_up.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional 2 | from models.TransformerBaseNetworks import * 3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 4 | import math 5 | 6 | class ASPP_v1(nn.Module): 7 | def __init__(self, channel): 8 | super(ASPP_v1, self).__init__() 9 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 10 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 11 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 12 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 13 | self.leakyReLU = nn.LeakyReLU(inplace=True) 14 | self.conv_1x1_output = nn.Conv2d(channel * 4, channel, 1, 1) 15 | for m in self.modules(): 16 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 17 | if m.bias is not None: 18 | m.bias.data.zero_() 19 | 20 | def forward(self, x): 21 | dilate1_out = self.leakyReLU(self.dilate1(x)) 22 | dilate2_out = self.leakyReLU(self.dilate2(x)) 23 | dilate3_out = self.leakyReLU(self.dilate3(x)) 24 | dilate4_out = self.leakyReLU(self.dilate4(x)) 25 | out = self.conv_1x1_output(torch.cat([dilate1_out, dilate2_out, 26 | dilate3_out, dilate4_out], dim=1)) 27 | return out 28 | 29 | class OverlapPatchEmbed(nn.Module): 30 | """ Image to Patch Embedding 31 | """ 32 | 33 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 34 | super().__init__() 35 | img_size = to_2tuple(img_size) 36 | patch_size = to_2tuple(patch_size) 37 | 38 | self.img_size = img_size 39 | self.patch_size = patch_size 40 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 41 | self.num_patches = self.H * self.W 42 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 43 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 44 | self.norm = nn.LayerNorm(embed_dim) 45 | 46 | self.apply(self._init_weights) 47 | 48 | def _init_weights(self, m): 49 | if isinstance(m, nn.Linear): 50 | trunc_normal_(m.weight, std=.02) 51 | if isinstance(m, nn.Linear) and m.bias is not None: 52 | nn.init.constant_(m.bias, 0) 53 | elif isinstance(m, nn.LayerNorm): 54 | nn.init.constant_(m.bias, 0) 55 | nn.init.constant_(m.weight, 1.0) 56 | elif isinstance(m, nn.Conv2d): 57 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 58 | fan_out //= m.groups 59 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 60 | if m.bias is not None: 61 | m.bias.data.zero_() 62 | 63 | def forward(self, x): 64 | # pdb.set_trace() 65 | x = self.proj(x) 66 | _, _, H, W = x.shape 67 | x = x.flatten(2).transpose(1, 2) 68 | x = self.norm(x) 69 | 70 | return x, H, W 71 | 72 | class EdgeAddSementic(nn.Module): 73 | def __init__(self, sem_channel, edge_channel, out_channel, stride=4): 74 | super(EdgeAddSementic, self).__init__() 75 | self.conv1 = nn.Conv2d(sem_channel, out_channel, 1, 1) 76 | self.conv2 = nn.Conv2d(edge_channel, out_channel, 1, 1) 77 | # self.up = nn.ConvTranspose2d(32, 1, kernel_size=stride, stride=stride) 78 | 79 | self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear') 80 | self.upconv2 = nn.Conv2d(out_channel, 1, 3, padding=1) 81 | self.probability = nn.Sigmoid() 82 | 
self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) 83 | self.bn = nn.BatchNorm2d(out_channel) 84 | self.LeakRelu = nn.LeakyReLU(inplace=True) 85 | 86 | def forward(self, sem, edge): 87 | sem = self.conv1(sem) 88 | edge = self.conv2(edge) 89 | sem_up = self.upconv2(self.upsample(sem)) 90 | sem_up = self.probability(sem_up) 91 | out = edge * sem_up 92 | out = self.conv3(out) 93 | out = self.bn(out) 94 | out = self.LeakRelu(out) 95 | return out 96 | 97 | class SementicAddEdge(nn.Module): 98 | def __init__(self, sem_channel, edge_channel, out_channel, stride=2): 99 | super(SementicAddEdge, self).__init__() 100 | self.conv1 = nn.Conv2d(sem_channel, out_channel, 1, 1) 101 | self.conv2 = nn.Conv2d(edge_channel, out_channel, 1, 1) 102 | self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear') 103 | self.upconv2 = nn.Conv2d(out_channel, out_channel, 3, padding=1) 104 | self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) 105 | self.bn = nn.BatchNorm2d(out_channel) 106 | self.LeakRelu = nn.LeakyReLU(inplace=True) 107 | 108 | def forward(self, sem, edge): 109 | sem = self.conv1(sem) 110 | edge = self.conv2(edge) 111 | sem_up = self.upconv2(self.upsample(sem)) 112 | out = sem_up + edge 113 | out = self.conv3(out) 114 | out = self.bn(out) 115 | out = self.LeakRelu(out) 116 | return out 117 | 118 | class Attention(nn.Module): 119 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 120 | super().__init__() 121 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 122 | 123 | self.dim = dim 124 | self.num_heads = num_heads 125 | head_dim = dim // num_heads 126 | self.scale = qk_scale or head_dim ** -0.5 127 | 128 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 129 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 130 | self.attn_drop = nn.Dropout(attn_drop) 131 | self.proj = nn.Linear(dim, dim) 132 | self.proj_drop = nn.Dropout(proj_drop) 133 | 134 | self.sr_ratio = sr_ratio 135 | if sr_ratio > 1: 136 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 137 | self.norm = nn.LayerNorm(dim) 138 | 139 | self.apply(self._init_weights) 140 | 141 | def _init_weights(self, m): 142 | if isinstance(m, nn.Linear): 143 | trunc_normal_(m.weight, std=.02) 144 | if isinstance(m, nn.Linear) and m.bias is not None: 145 | nn.init.constant_(m.bias, 0) 146 | elif isinstance(m, nn.LayerNorm): 147 | nn.init.constant_(m.bias, 0) 148 | nn.init.constant_(m.weight, 1.0) 149 | elif isinstance(m, nn.Conv2d): 150 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 151 | fan_out //= m.groups 152 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 153 | if m.bias is not None: 154 | m.bias.data.zero_() 155 | 156 | def forward(self, x, H, W): 157 | 158 | B, N, C = x.shape 159 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 160 | 161 | if self.sr_ratio > 1: 162 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 163 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 164 | x_ = self.norm(x_) 165 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 166 | else: 167 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 168 | k, v = kv[0], kv[1] 169 | 170 | attn = (q @ k.transpose(-2, -1)) * self.scale 171 | attn = attn.softmax(dim=-1) 172 | attn = self.attn_drop(attn) 173 | 174 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 175 | x = self.proj(x) 176 | x = self.proj_drop(x) 
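# Note: the attended output below keeps the input token shape (B, N, C). When sr_ratio > 1, the strided
# conv above shortens only the key/value sequence, so the attention matrix is roughly sr_ratio**2
# smaller than full self-attention over all H*W tokens.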
177 | 178 | return x 179 | 180 | class DWConv(nn.Module): 181 | def __init__(self, dim=768): 182 | super(DWConv, self).__init__() 183 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 184 | 185 | def forward(self, x, H, W): 186 | B, N, C = x.shape 187 | x = x.transpose(1, 2).view(B, C, H, W) 188 | x = self.dwconv(x) 189 | x = x.flatten(2).transpose(1, 2) 190 | 191 | return x 192 | 193 | class Mlp(nn.Module): 194 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 195 | super().__init__() 196 | out_features = out_features or in_features 197 | hidden_features = hidden_features or in_features 198 | self.fc1 = nn.Linear(in_features, hidden_features) 199 | self.dwconv = DWConv(hidden_features) 200 | self.act = act_layer() 201 | self.fc2 = nn.Linear(hidden_features, out_features) 202 | self.drop = nn.Dropout(drop) 203 | 204 | self.apply(self._init_weights) 205 | 206 | def _init_weights(self, m): 207 | if isinstance(m, nn.Linear): 208 | trunc_normal_(m.weight, std=.02) 209 | if isinstance(m, nn.Linear) and m.bias is not None: 210 | nn.init.constant_(m.bias, 0) 211 | elif isinstance(m, nn.LayerNorm): 212 | nn.init.constant_(m.bias, 0) 213 | nn.init.constant_(m.weight, 1.0) 214 | elif isinstance(m, nn.Conv2d): 215 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 216 | fan_out //= m.groups 217 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 218 | if m.bias is not None: 219 | m.bias.data.zero_() 220 | 221 | def forward(self, x, H, W): 222 | x = self.fc1(x) 223 | x = self.dwconv(x, H, W) 224 | x = self.act(x) 225 | x = self.drop(x) 226 | x = self.fc2(x) 227 | x = self.drop(x) 228 | return x 229 | 230 | class Block(nn.Module): 231 | 232 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 233 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 234 | super().__init__() 235 | self.norm1 = norm_layer(dim) 236 | self.attn = Attention( 237 | dim, 238 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 239 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 240 | 241 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 242 | self.norm2 = norm_layer(dim) 243 | mlp_hidden_dim = int(dim * mlp_ratio) 244 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 245 | 246 | self.apply(self._init_weights) 247 | 248 | def _init_weights(self, m): 249 | if isinstance(m, nn.Linear): 250 | trunc_normal_(m.weight, std=.02) 251 | if isinstance(m, nn.Linear) and m.bias is not None: 252 | nn.init.constant_(m.bias, 0) 253 | elif isinstance(m, nn.LayerNorm): 254 | nn.init.constant_(m.bias, 0) 255 | nn.init.constant_(m.weight, 1.0) 256 | elif isinstance(m, nn.Conv2d): 257 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 258 | fan_out //= m.groups 259 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 260 | if m.bias is not None: 261 | m.bias.data.zero_() 262 | 263 | def forward(self, x, H, W): 264 | 265 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 266 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 267 | return x 268 | 269 | # 𝑈𝑝𝑠𝑎𝑚𝑙e 270 | class BilinearUp(nn.Module): 271 | def __init__(self, in_channels, out_channels, scale=2): 272 | super(BilinearUp, self).__init__() 273 | self.upsample = nn.Upsample(scale_factor=scale, mode='bilinear') 274 | self.finalconv2 = nn.Conv2d(in_channels, in_channels // 2, 3, padding=1) 275 | self.finalrelu1 = nn.LeakyReLU(inplace=True) 276 | self.finalconv3 = nn.Conv2d(in_channels // 2, out_channels, 3, padding=1) 277 | 278 | def forward(self, x): 279 | x = self.upsample(x) 280 | x = self.finalconv2(x) 281 | x = self.finalrelu1(x) 282 | x = self.finalconv3(x) 283 | x = self.finalrelu1(x) 284 | return x 285 | 286 | class EdgeFusion(nn.Module): 287 | def __init__(self, in_chn, out_chn): 288 | super(EdgeFusion, self).__init__() 289 | 290 | self.conv1 = torch.nn.Sequential( 291 | torch.nn.Conv2d(in_chn, in_chn, kernel_size=3, padding=1), 292 | torch.nn.ReLU(inplace=True), 293 | ) 294 | self.conv_out = torch.nn.Sequential( 295 | torch.nn.Conv2d(in_chn, out_chn, kernel_size=1, padding=0), 296 | torch.nn.BatchNorm2d(out_chn), 297 | torch.nn.ReLU(inplace=True), 298 | ) 299 | 300 | def forward(self, x, y): 301 | x1 = self.conv1(x) 302 | y1 = self.conv1(y) 303 | 304 | return self.conv_out(x + x1 + y + y1) -------------------------------------------------------------------------------- /models/basic_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from misc.imutils import save_image 6 | from models.networks import * 7 | 8 | 9 | class CDEvaluator(): 10 | 11 | def __init__(self, args): 12 | 13 | self.n_class = args.n_class 14 | # define G 15 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids) 16 | 17 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] 18 | if torch.cuda.is_available() and len(args.gpu_ids)>0 19 | else "cpu") 20 | 21 | print(self.device) 22 | 23 | self.checkpoint_dir = args.checkpoint_dir 24 | 25 | self.pred_dir = args.output_folder 26 | os.makedirs(self.pred_dir, exist_ok=True) 27 | 28 | def load_checkpoint(self, checkpoint_name='best_ckpt.pt'): 29 | 30 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)): 31 | # load the entire checkpoint 32 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), 33 | map_location=self.device) 34 | 35 | self.net_G.load_state_dict(checkpoint['model_G_state_dict']) 36 | self.net_G.to(self.device) 37 | # update some other states 38 | self.best_val_acc = checkpoint['best_val_acc'] 39 | self.best_epoch_id = 
checkpoint['best_epoch_id'] 40 | 41 | else: 42 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name) 43 | return self.net_G 44 | 45 | 46 | def _visualize_pred(self): 47 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True) 48 | pred_vis = pred * 255 49 | return pred_vis 50 | 51 | def _forward_pass(self, batch): 52 | self.batch = batch 53 | img_in1 = batch['A'].to(self.device) 54 | img_in2 = batch['B'].to(self.device) 55 | self.shape_h = img_in1.shape[-2] 56 | self.shape_w = img_in1.shape[-1] 57 | self.G_pred = self.net_G(img_in1, img_in2)[-1] 58 | return self._visualize_pred() 59 | 60 | def eval(self): 61 | self.net_G.eval() 62 | 63 | def _save_predictions(self): 64 | """ 65 | 保存模型输出结果,二分类图像 66 | """ 67 | 68 | preds = self._visualize_pred() 69 | name = self.batch['name'] 70 | for i, pred in enumerate(preds): 71 | file_name = os.path.join( 72 | self.pred_dir, name[i].replace('.jpg', '.png')) 73 | pred = pred[0].cpu().numpy() 74 | save_image(pred, file_name) 75 | 76 | -------------------------------------------------------------------------------- /models/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from torch.autograd import Variable as V 6 | from datasets.data_utils import CDDataAugmentation 7 | import torchvision.transforms.functional as TF 8 | from models.networks import * 9 | from misc.metric_tool import ConfuseMatrixMeter 10 | from misc.logger_tool import Logger 11 | from utils import de_norm 12 | import utils 13 | import cv2 14 | from tqdm import tqdm 15 | from osgeo import gdal,ogr,osr 16 | 17 | 18 | # Decide which device we want to run on 19 | # torch.cuda.current_device() 20 | 21 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | class CDEvaluator(): 25 | 26 | def __init__(self, args, dataloader): 27 | 28 | self.dataloader = dataloader 29 | 30 | self.n_class = args.n_class 31 | # define G 32 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids) 33 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0 34 | else "cpu") 35 | print(self.device) 36 | 37 | # define some other vars to record the training states 38 | self.running_metric = ConfuseMatrixMeter(n_class=self.n_class) 39 | 40 | # define logger file 41 | logger_path = os.path.join(args.checkpoint_dir, 'log_test.txt') 42 | self.logger = Logger(logger_path) 43 | self.logger.write_dict_str(args.__dict__) 44 | 45 | 46 | # training log 47 | self.epoch_acc = 0 48 | self.best_val_acc = 0.0 49 | self.best_epoch_id = 0 50 | 51 | self.steps_per_epoch = len(dataloader) 52 | 53 | self.G_pred = None 54 | self.pred_vis = None 55 | self.batch = None 56 | self.is_training = False 57 | self.batch_id = 0 58 | self.epoch_id = 0 59 | self.checkpoint_dir = args.checkpoint_dir 60 | self.vis_dir = args.vis_dir 61 | 62 | # check and create model dir 63 | if os.path.exists(self.checkpoint_dir) is False: 64 | os.mkdir(self.checkpoint_dir) 65 | if os.path.exists(self.vis_dir) is False: 66 | os.mkdir(self.vis_dir) 67 | 68 | 69 | def _load_checkpoint(self, checkpoint_name='best_ckpt.pt'): 70 | 71 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)): 72 | self.logger.write('loading last checkpoint...\n') 73 | # load the entire checkpoint 74 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), map_location=self.device) 75 | 76 | 
self.net_G.load_state_dict(checkpoint['model_G_state_dict']) 77 | 78 | self.net_G.to(self.device) 79 | 80 | # update some other states 81 | self.best_val_acc = checkpoint['best_val_acc'] 82 | self.best_epoch_id = checkpoint['best_epoch_id'] 83 | 84 | self.logger.write('Eval Historical_best_acc = %.4f (at epoch %d)\n' % 85 | (self.best_val_acc, self.best_epoch_id)) 86 | self.logger.write('\n') 87 | 88 | else: 89 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name) 90 | 91 | 92 | def _visualize_pred(self): 93 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True) 94 | pred_vis = pred * 255 95 | return pred_vis 96 | 97 | 98 | def _update_metric(self): 99 | """ 100 | update metric 101 | """ 102 | target = self.batch['L'].to(self.device).detach() 103 | G_pred = self.G_pred.detach() 104 | G_pred = torch.argmax(G_pred, dim=1) 105 | 106 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy()) 107 | return current_score 108 | 109 | def _collect_running_batch_states(self): 110 | 111 | running_acc = self._update_metric() 112 | 113 | m = len(self.dataloader) 114 | 115 | if np.mod(self.batch_id, 100) == 1: 116 | message = 'Is_training: %s. [%d,%d], running_mf1: %.5f\n' %\ 117 | (self.is_training, self.batch_id, m, running_acc) 118 | self.logger.write(message) 119 | 120 | # if np.mod(self.batch_id, 100) == 1: 121 | vis_input = utils.make_numpy_grid(de_norm(self.batch['A'])) 122 | vis_input2 = utils.make_numpy_grid(de_norm(self.batch['B'])) 123 | 124 | vis_pred = utils.make_numpy_grid(self._visualize_pred()) 125 | 126 | vis_gt = utils.make_numpy_grid(self.batch['L']) 127 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0) 128 | vis = np.clip(vis, a_min=0.0, a_max=1.0) 129 | file_name = os.path.join( 130 | self.vis_dir, 'eval_' + str(self.batch_id)+'.jpg') 131 | plt.imsave(file_name, vis) 132 | 133 | 134 | def _collect_epoch_states(self): 135 | 136 | scores_dict = self.running_metric.get_scores() 137 | 138 | np.save(os.path.join(self.checkpoint_dir, 'scores_dict.npy'), scores_dict) 139 | 140 | self.epoch_acc = scores_dict['mf1'] 141 | 142 | with open(os.path.join(self.checkpoint_dir, '%s.txt' % (self.epoch_acc)), 143 | mode='a') as file: 144 | pass 145 | 146 | message = '' 147 | for k, v in scores_dict.items(): 148 | message += '%s: %.5f ' % (k, v) 149 | self.logger.write('%s\n' % message) # save the message 150 | 151 | self.logger.write('\n') 152 | 153 | def _clear_cache(self): 154 | self.running_metric.clear() 155 | 156 | def _forward_pass(self, batch): 157 | self.batch = batch 158 | img_in1 = batch['A'].to(self.device) 159 | img_in2 = batch['B'].to(self.device) 160 | self.G_pred = self.net_G(img_in1, img_in2)[-1] 161 | 162 | def eval_models(self,checkpoint_name='best_ckpt.pt'): 163 | 164 | self._load_checkpoint(checkpoint_name) 165 | 166 | ################## Eval ################## 167 | ########################################## 168 | self.logger.write('Begin evaluation...\n') 169 | self._clear_cache() 170 | self.is_training = False 171 | self.net_G.eval() 172 | 173 | # Iterate over data. 
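# Each batch is pushed through the network with gradients disabled; _collect_running_batch_states()
# updates the running confusion matrix, logs the running mF1 every 100 batches and saves an
# A/B/prediction/label visualization, and _collect_epoch_states() turns the accumulated matrix into
# the final acc / mIoU / mF1 scores written to the log.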
174 | for self.batch_id, batch in enumerate(self.dataloader, 0): 175 | with torch.no_grad(): 176 | self._forward_pass(batch) 177 | self._collect_running_batch_states() 178 | self._collect_epoch_states() 179 | 180 | def block_gdal_input(self, img, img_size, crop=512, pad=0): # gdal分块读取 181 | [img_width, img_height] = img_size 182 | x_height = x_width = crop 183 | crop_width = x_width - 2 * pad 184 | crop_height = x_height - 2 * pad 185 | 186 | numBand = 3 187 | # numBand = img.RasterCount 188 | num_Xblock = img_width // crop_width 189 | x_start, x_end = [], [] 190 | x_start.append(0) 191 | for i in range(num_Xblock): 192 | xs = crop_width * (i + 1) - pad 193 | xe = crop_width * i + x_width - pad 194 | if (i == num_Xblock - 1): 195 | xs = img_width - crop_width - pad 196 | xe = min(xe, img_width) 197 | x_start.append(xs) 198 | x_end.append(xe) 199 | x_end.append(img_width) 200 | 201 | num_Yblock = img_height // crop_height 202 | y_start, y_end = [], [] 203 | y_start.append(0) 204 | for i in range(num_Yblock): 205 | ys = crop_height * (i + 1) - pad 206 | ye = crop_height * i + x_height - pad 207 | if (i == num_Yblock - 1): 208 | ys = img_height - crop_height - pad 209 | ye = min(ye, img_height) 210 | y_start.append(ys) 211 | y_end.append(ye) 212 | y_end.append(img_height) 213 | 214 | if img_width % crop_width > 0: 215 | num_Xblock = num_Xblock + 1 216 | if img_height % crop_height > 0: 217 | num_Yblock = num_Yblock + 1 218 | for i in range(num_Yblock): 219 | for j in range(num_Xblock): 220 | [x0, x1, y0, y1] = [x_start[j], x_end[j], y_start[i], y_end[i]] 221 | 222 | feature = np.zeros(np.append([y1 - y0, x1 - x0], numBand), np.float32) 223 | for ii in range(numBand): 224 | floatData = np.array(img.GetRasterBand(ii + 1).ReadAsArray(x0, y0, x1 - x0, y1 - y0),dtype=np.float32) 225 | # floatData = np.array(img.GetRasterBand(4-ii).ReadAsArray(x0,y0,x1-x0,y1-y0)) 226 | 227 | feature[..., ii] = (floatData/255-0.5)/0.5 228 | # feature[..., ii] = floatData 229 | 230 | if (i == 0): 231 | feature_pad = cv2.copyMakeBorder(feature, 232 | pad, x_height - pad - feature.shape[0], 233 | 0, 0, cv2.BORDER_REFLECT_101) 234 | else: 235 | feature_pad = cv2.copyMakeBorder(feature, 236 | 0, x_height - feature.shape[0], 237 | 0, 0, cv2.BORDER_REFLECT_101) 238 | if (j == 0): 239 | feature_pad = cv2.copyMakeBorder(feature_pad, 240 | 0, 0, pad, x_width - pad - feature_pad.shape[1], 241 | cv2.BORDER_REFLECT_101) 242 | else: 243 | feature_pad = cv2.copyMakeBorder(feature_pad, 244 | 0, 0, 0, x_width - feature_pad.shape[1], 245 | cv2.BORDER_REFLECT_101) 246 | 247 | yield feature_pad, [x0, x1, y0, y1] 248 | 249 | def pred_gdal_blocks_write(self, img_pathA, img_pathB,out_path=''): 250 | self._load_checkpoint() 251 | 252 | ################## Eval ################## 253 | ########################################## 254 | self.logger.write('Begin evaluation...\n') 255 | self._clear_cache() 256 | self.is_training = False 257 | self.net_G.eval() 258 | 259 | #logger.info('predicting %s' % img_pathA) 260 | 261 | batch_size = 1 262 | pad = 16 263 | x_width = 256 264 | x_height = 256 265 | crop_width = x_width - 2 * pad 266 | crop_height = x_height - 2 * pad 267 | datasetname = gdal.Open(img_pathA, gdal.GA_ReadOnly) 268 | # datasetname = reproject_dataset(img_path,5500,5500) 269 | if datasetname is None: 270 | print('Could not open %s' % img_pathA) 271 | img_width = datasetname.RasterXSize 272 | img_height = datasetname.RasterYSize 273 | imageSize = [img_width, img_height] 274 | nBand = datasetname.RasterCount 275 | 276 | datasetname2 
= gdal.Open(img_pathB, gdal.GA_ReadOnly) 277 | if datasetname2 is None: 278 | print('Could not open %s' % img_pathB) 279 | img_width2 = datasetname2.RasterXSize 280 | img_height2 = datasetname2.RasterYSize 281 | 282 | if img_width != img_width2 or img_height != img_height2: 283 | print("image extents do not match") 284 | return 285 | 286 | driver = gdal.GetDriverByName('GTiff') 287 | if out_path == '': 288 | out_path = img_pathA.rsplit('.', 1)[0] + '_res.tif' 289 | outRaster = driver.Create(out_path, img_width, img_height, 1, gdal.GDT_Byte) 290 | outband = outRaster.GetRasterBand(1) 291 | outRaster.SetGeoTransform(datasetname.GetGeoTransform()) 292 | outRaster.SetProjection(datasetname.GetProjection()) 293 | 294 | num_Xblock = img_width // crop_width 295 | if img_width % crop_width > 0: 296 | num_Xblock += 1 297 | num_Yblock = img_height // crop_height 298 | if img_height % crop_height > 0: 299 | num_Yblock += 1 300 | i = 0 301 | blocks = num_Xblock * num_Yblock 302 | # mask = np.zeros([batch_size, img_height, img_width],dtype=np.float32) 303 | input_gen = self.block_gdal_input(datasetname, imageSize, x_width, pad) 304 | input_gen2 = self.block_gdal_input(datasetname2, imageSize, x_width, pad) 305 | for i in tqdm(range(blocks)): 306 | imgA, xy = next(input_gen) 307 | imgB, xyB = next(input_gen2) 308 | if (xy[0] > 0): 309 | xs = xy[0] + pad 310 | else: 311 | xs = xy[0] 312 | 313 | if (xy[2] > 0): 314 | ys = xy[2] + pad 315 | else: 316 | ys = xy[2] 317 | #if np.max(imgA[pad: pad + crop_height, pad: pad + crop_width]) < 5: 318 | #predictions = np.zeros([batch_size, x_height, x_width]) 319 | #else: 320 | imgs = [] 321 | imgs.append(imgA) 322 | imgs = np.array(imgs) 323 | # imgs = imgs[:,np.newaxis] 324 | # np.squeeze(imgs) 325 | # np.expand_dims(imgs,axis=1) 326 | imgs = imgs.transpose(0, 3, 1, 2) 327 | imgs = V(torch.Tensor(np.array(imgs, np.float32)).to(self.device)) 328 | 329 | # imgs = TF.normalize(imgs, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 330 | # imgs = imgs.resize(1,3,256,256) 331 | 332 | imgs2 = [] 333 | imgs2.append(imgB) 334 | imgs2 = np.array(imgs2) 335 | # imgs = imgs[:,np.newaxis] 336 | # np.squeeze(imgs) 337 | # np.expand_dims(imgs,axis=1) 338 | imgs2 = imgs2.transpose(0, 3, 1, 2) 339 | imgs2 = V(torch.Tensor(np.array(imgs2, np.float32)).to(self.device)) 340 | # imgs2 = TF.to_tensor(np.array(imgs2[0], np.float32)).to(self.device) 341 | # imgs2 = TF.normalize(imgs2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 342 | # imgs2 = imgs2.resize(1, 3, 256, 256) 343 | 344 | predictions = self.net_G(imgs, imgs2)[-1] 345 | # predictions = predictions.numpy() 346 | predictions = torch.argmax(predictions, dim=1, keepdim=True) 347 | # print(predictions) 348 | predictions = predictions.cpu().numpy()[0][0] 349 | 350 | prediction = predictions[pad: pad + crop_height, 351 | pad: pad + crop_width] 352 | 353 | outband.WriteArray((prediction * 255).astype(np.uint8), xs, ys) 354 | # mask[0,ys: ys+crop_height,\ 355 | # xs : xs+crop_width] = prediction.astype(np.float32) 356 | 357 | # if(i%num_Xblock==0): 358 | # y=i//num_Xblock 359 | # logger.info('predicting data: [{}{}] {}%'.\ 360 | # format('=' * (y+1), 361 | # ' ' * (num_Yblock - y-1), 362 | # 100 * (y+1)/num_Yblock)) 363 | outband.FlushCache() 364 | # sys.stdout.flush() 365 | # i=i+1 366 | datasetname = None 367 | datasetname2 = None 368 | outRaster = None 369 | return # np.squeeze(mask*255).astype(np.int) 370 | -------------------------------------------------------------------------------- /models/help_funcs.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | class TwoLayerConv2d(nn.Sequential): 8 | def __init__(self, in_channels, out_channels, kernel_size=3): 9 | super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, 10 | padding=kernel_size // 2, stride=1, bias=False), 11 | nn.BatchNorm2d(in_channels), 12 | nn.ReLU(), 13 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 14 | padding=kernel_size // 2, stride=1) 15 | ) 16 | 17 | 18 | class Residual(nn.Module): 19 | def __init__(self, fn): 20 | super().__init__() 21 | self.fn = fn 22 | def forward(self, x, **kwargs): 23 | return self.fn(x, **kwargs) + x 24 | 25 | 26 | class Residual2(nn.Module): 27 | def __init__(self, fn): 28 | super().__init__() 29 | self.fn = fn 30 | def forward(self, x, x2, **kwargs): 31 | return self.fn(x, x2, **kwargs) + x 32 | 33 | 34 | class PreNorm(nn.Module): 35 | def __init__(self, dim, fn): 36 | super().__init__() 37 | self.norm = nn.LayerNorm(dim) 38 | self.fn = fn 39 | def forward(self, x, **kwargs): 40 | return self.fn(self.norm(x), **kwargs) 41 | 42 | 43 | class PreNorm2(nn.Module): 44 | def __init__(self, dim, fn): 45 | super().__init__() 46 | self.norm = nn.LayerNorm(dim) 47 | self.fn = fn 48 | def forward(self, x, x2, **kwargs): 49 | return self.fn(self.norm(x), self.norm(x2), **kwargs) 50 | 51 | 52 | class FeedForward(nn.Module): 53 | def __init__(self, dim, hidden_dim, dropout = 0.): 54 | super().__init__() 55 | self.net = nn.Sequential( 56 | nn.Linear(dim, hidden_dim), 57 | nn.GELU(), 58 | nn.Dropout(dropout), 59 | nn.Linear(hidden_dim, dim), 60 | nn.Dropout(dropout) 61 | ) 62 | def forward(self, x): 63 | return self.net(x) 64 | 65 | 66 | class Cross_Attention(nn.Module): 67 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True): 68 | super().__init__() 69 | inner_dim = dim_head * heads 70 | self.heads = heads 71 | self.scale = dim ** -0.5 72 | 73 | self.softmax = softmax 74 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 75 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 76 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 77 | 78 | self.to_out = nn.Sequential( 79 | nn.Linear(inner_dim, dim), 80 | nn.Dropout(dropout) 81 | ) 82 | 83 | def forward(self, x, m, mask = None): 84 | 85 | b, n, _, h = *x.shape, self.heads 86 | q = self.to_q(x) 87 | k = self.to_k(m) 88 | v = self.to_v(m) 89 | 90 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) 91 | 92 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 93 | mask_value = -torch.finfo(dots.dtype).max 94 | 95 | if mask is not None: 96 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 97 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 98 | mask = mask[:, None, :] * mask[:, :, None] 99 | dots.masked_fill_(~mask, mask_value) 100 | del mask 101 | 102 | if self.softmax: 103 | attn = dots.softmax(dim=-1) 104 | else: 105 | attn = dots 106 | # attn = dots 107 | # vis_tmp(dots) 108 | 109 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 110 | out = rearrange(out, 'b h n d -> b n (h d)') 111 | out = self.to_out(out) 112 | # vis_tmp2(out) 113 | 114 | return out 115 | 116 | 117 | class Attention(nn.Module): 118 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 119 | super().__init__() 120 | inner_dim = dim_head * heads 121 | self.heads = heads 122 | self.scale = dim ** -0.5 
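# Standard scaled dot-product attention: the single Linear below projects each token to q, k and v
# jointly (inner_dim * 3 outputs), which are then split into `heads` heads of size dim_head.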
123 | 124 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 125 | self.to_out = nn.Sequential( 126 | nn.Linear(inner_dim, dim), 127 | nn.Dropout(dropout) 128 | ) 129 | 130 | def forward(self, x, mask = None): 131 | b, n, _, h = *x.shape, self.heads 132 | qkv = self.to_qkv(x).chunk(3, dim = -1) 133 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 134 | 135 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 136 | mask_value = -torch.finfo(dots.dtype).max 137 | 138 | if mask is not None: 139 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 140 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 141 | mask = mask[:, None, :] * mask[:, :, None] 142 | dots.masked_fill_(~mask, mask_value) 143 | del mask 144 | 145 | attn = dots.softmax(dim=-1) 146 | 147 | 148 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 149 | out = rearrange(out, 'b h n d -> b n (h d)') 150 | out = self.to_out(out) 151 | return out 152 | 153 | 154 | class Transformer(nn.Module): 155 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): 156 | super().__init__() 157 | self.layers = nn.ModuleList([]) 158 | for _ in range(depth): 159 | self.layers.append(nn.ModuleList([ 160 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), 161 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 162 | ])) 163 | def forward(self, x, mask = None): 164 | for attn, ff in self.layers: 165 | x = attn(x, mask = mask) 166 | x = ff(x) 167 | return x 168 | 169 | 170 | class TransformerDecoder(nn.Module): 171 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): 172 | super().__init__() 173 | self.layers = nn.ModuleList([]) 174 | for _ in range(depth): 175 | self.layers.append(nn.ModuleList([ 176 | Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, 177 | dim_head = dim_head, dropout = dropout, 178 | softmax=softmax))), 179 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 180 | ])) 181 | def forward(self, x, m, mask = None): 182 | """target(query), memory""" 183 | for attn, ff in self.layers: 184 | x = attn(x, m, mask = mask) 185 | x = ff(x) 186 | return x 187 | 188 | from scipy.io import savemat 189 | def save_to_mat(x1, x2, fx1, fx2, cp, file_name): 190 | #Save to mat files 191 | x1_np = x1.detach().cpu().numpy() 192 | x2_np = x2.detach().cpu().numpy() 193 | 194 | fx1_0_np = fx1[0].detach().cpu().numpy() 195 | fx2_0_np = fx2[0].detach().cpu().numpy() 196 | fx1_1_np = fx1[1].detach().cpu().numpy() 197 | fx2_1_np = fx2[1].detach().cpu().numpy() 198 | fx1_2_np = fx1[2].detach().cpu().numpy() 199 | fx2_2_np = fx2[2].detach().cpu().numpy() 200 | fx1_3_np = fx1[3].detach().cpu().numpy() 201 | fx2_3_np = fx2[3].detach().cpu().numpy() 202 | fx1_4_np = fx1[4].detach().cpu().numpy() 203 | fx2_4_np = fx2[4].detach().cpu().numpy() 204 | 205 | cp_np = cp[-1].detach().cpu().numpy() 206 | 207 | mdic = {'x1': x1_np, 'x2': x2_np, 208 | 'fx1_0': fx1_0_np, 'fx1_1': fx1_1_np, 'fx1_2': fx1_2_np, 'fx1_3': fx1_3_np, 'fx1_4': fx1_4_np, 209 | 'fx2_0': fx2_0_np, 'fx2_1': fx2_1_np, 'fx2_2': fx2_2_np, 'fx2_3': fx2_3_np, 'fx2_4': fx2_4_np, 210 | "final_pred": cp_np} 211 | 212 | savemat("/media/lidan/ssd2/ChangeFormer/vis/mat/"+file_name+".mat", mdic) 213 | 214 | 215 | 216 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from torch.optim import lr_scheduler 6 | 7 | import functools 8 | from einops import rearrange 9 | 10 | import models 11 | from models.help_funcs import Transformer, TransformerDecoder, TwoLayerConv2d 12 | from models.EGCTNet import EGCTNet 13 | 14 | 15 | ############################################################################### 16 | # Helper Functions 17 | ############################################################################### 18 | 19 | def get_scheduler(optimizer, args): 20 | """Return a learning rate scheduler 21 | 22 | Parameters: 23 | optimizer -- the optimizer of the network 24 | args (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  25 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 26 | 27 | For 'linear', we keep the same learning rate for the first epochs 28 | and linearly decay the rate to zero over the next epochs. 29 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 30 | See https://pytorch.org/docs/stable/optim.html for more details. 31 | """ 32 | if args.lr_policy == 'linear': 33 | def lambda_rule(epoch): 34 | lr_l = 1.0 - epoch / float(args.max_epochs + 1) 35 | return lr_l 36 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 37 | elif args.lr_policy == 'step': 38 | step_size = args.max_epochs//3 39 | # args.lr_decay_iters 40 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1) 41 | else: 42 | return NotImplementedError('learning rate policy [%s] is not implemented', args.lr_policy) 43 | return scheduler 44 | 45 | 46 | class Identity(nn.Module): 47 | def forward(self, x): 48 | return x 49 | 50 | 51 | def get_norm_layer(norm_type='instance'): 52 | """Return a normalization layer 53 | 54 | Parameters: 55 | norm_type (str) -- the name of the normalization layer: batch | instance | none 56 | 57 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 58 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 59 | """ 60 | if norm_type == 'batch': 61 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 62 | elif norm_type == 'instance': 63 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 64 | elif norm_type == 'none': 65 | norm_layer = lambda x: Identity() 66 | else: 67 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 68 | return norm_layer 69 | 70 | 71 | def init_weights(net, init_type='normal', init_gain=0.02): 72 | """Initialize network weights. 73 | 74 | Parameters: 75 | net (network) -- network to be initialized 76 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 77 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 78 | 79 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 80 | work better for some applications. Feel free to try yourself. 
81 | """ 82 | def init_func(m): # define the initialization function 83 | classname = m.__class__.__name__ 84 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 85 | if init_type == 'normal': 86 | init.normal_(m.weight.data, 0.0, init_gain) 87 | elif init_type == 'xavier': 88 | init.xavier_normal_(m.weight.data, gain=init_gain) 89 | elif init_type == 'kaiming': 90 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 91 | elif init_type == 'orthogonal': 92 | init.orthogonal_(m.weight.data, gain=init_gain) 93 | else: 94 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 95 | if hasattr(m, 'bias') and m.bias is not None: 96 | init.constant_(m.bias.data, 0.0) 97 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 98 | init.normal_(m.weight.data, 1.0, init_gain) 99 | init.constant_(m.bias.data, 0.0) 100 | 101 | print('initialize network with %s' % init_type) 102 | net.apply(init_func) # apply the initialization function 103 | 104 | 105 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 106 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 107 | Parameters: 108 | net (network) -- the network to be initialized 109 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 110 | gain (float) -- scaling factor for normal, xavier and orthogonal. 111 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 112 | 113 | Return an initialized network. 114 | """ 115 | if len(gpu_ids) > 0: 116 | assert(torch.cuda.is_available()) 117 | net.to(gpu_ids[0]) 118 | if len(gpu_ids) > 1: 119 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 120 | init_weights(net, init_type, init_gain=init_gain) 121 | return net 122 | 123 | 124 | def define_G(args, init_type='normal', init_gain=0.02, gpu_ids=[]): 125 | if args.net_G == 'base_resnet18': 126 | net = ResNet(input_nc=3, output_nc=2, output_sigmoid=False) 127 | 128 | elif args.net_G == 'base_transformer_pos_s4': 129 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, 130 | with_pos='learned') 131 | 132 | elif args.net_G == 'base_transformer_pos_s4_dd8': 133 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, 134 | with_pos='learned', enc_depth=1, dec_depth=8) 135 | 136 | elif args.net_G == 'base_transformer_pos_s4_dd8_dedim8': 137 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, 138 | with_pos='learned', enc_depth=1, dec_depth=8, decoder_dim_head=8) 139 | 140 | elif args.net_G == 'EGCTNet': 141 | net = EGCTNet(img_size=args.img_size, input_nc=3, output_nc=2, embed_dim=args.embed_dim, num_classes=args.n_class) 142 | 143 | else: 144 | raise NotImplementedError('Generator model name [%s] is not recognized' % args.net_G) 145 | return init_net(net, init_type, init_gain, gpu_ids) 146 | 147 | 148 | ############################################################################### 149 | # main Functions 150 | ############################################################################### 151 | 152 | 153 | class ResNet(torch.nn.Module): 154 | def __init__(self, input_nc, output_nc, 155 | resnet_stages_num=5, backbone='resnet18', 156 | output_sigmoid=False, if_upsample_2x=True): 157 | """ 158 | In the constructor we instantiate two nn.Linear modules and assign them as 159 | member variables. 
160 | """ 161 | super(ResNet, self).__init__() 162 | expand = 1 163 | if backbone == 'resnet18': 164 | self.resnet = models.resnet18(pretrained=True, 165 | replace_stride_with_dilation=[False,True,True]) 166 | elif backbone == 'resnet34': 167 | self.resnet = models.resnet34(pretrained=True, 168 | replace_stride_with_dilation=[False,True,True]) 169 | elif backbone == 'resnet50': 170 | self.resnet = models.resnet50(pretrained=True, 171 | replace_stride_with_dilation=[False,True,True]) 172 | expand = 4 173 | else: 174 | raise NotImplementedError 175 | self.relu = nn.ReLU() 176 | self.upsamplex2 = nn.Upsample(scale_factor=2) 177 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') 178 | 179 | self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc) 180 | 181 | self.resnet_stages_num = resnet_stages_num 182 | 183 | self.if_upsample_2x = if_upsample_2x 184 | if self.resnet_stages_num == 5: 185 | layers = 512 * expand 186 | elif self.resnet_stages_num == 4: 187 | layers = 256 * expand 188 | elif self.resnet_stages_num == 3: 189 | layers = 128 * expand 190 | else: 191 | raise NotImplementedError 192 | self.conv_pred = nn.Conv2d(layers, 32, kernel_size=3, padding=1) 193 | 194 | self.output_sigmoid = output_sigmoid 195 | self.sigmoid = nn.Sigmoid() 196 | 197 | def forward(self, x1, x2): 198 | x1 = self.forward_single(x1) 199 | x2 = self.forward_single(x2) 200 | x = torch.abs(x1 - x2) 201 | if not self.if_upsample_2x: 202 | x = self.upsamplex2(x) 203 | x = self.upsamplex4(x) 204 | x = self.classifier(x) 205 | 206 | if self.output_sigmoid: 207 | x = self.sigmoid(x) 208 | return x 209 | 210 | def forward_single(self, x): 211 | # resnet layers 212 | x = self.resnet.conv1(x) 213 | x = self.resnet.bn1(x) 214 | x = self.resnet.relu(x) 215 | x = self.resnet.maxpool(x) 216 | 217 | x_4 = self.resnet.layer1(x) # 1/4, in=64, out=64 218 | x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128 219 | 220 | if self.resnet_stages_num > 3: 221 | x_8 = self.resnet.layer3(x_8) # 1/8, in=128, out=256 222 | 223 | if self.resnet_stages_num == 5: 224 | x_8 = self.resnet.layer4(x_8) # 1/32, in=256, out=512 225 | elif self.resnet_stages_num > 5: 226 | raise NotImplementedError 227 | 228 | if self.if_upsample_2x: 229 | x = self.upsamplex2(x_8) 230 | else: 231 | x = x_8 232 | # output layers 233 | x = self.conv_pred(x) 234 | return x 235 | 236 | 237 | class BASE_Transformer(ResNet): 238 | """ 239 | Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN 240 | """ 241 | def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5, 242 | token_len=4, token_trans=True, 243 | enc_depth=1, dec_depth=1, 244 | dim_head=64, decoder_dim_head=64, 245 | tokenizer=True, if_upsample_2x=True, 246 | pool_mode='max', pool_size=2, 247 | backbone='resnet18', 248 | decoder_softmax=True, with_decoder_pos=None, 249 | with_decoder=True): 250 | super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone, 251 | resnet_stages_num=resnet_stages_num, 252 | if_upsample_2x=if_upsample_2x, 253 | ) 254 | self.token_len = token_len 255 | self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1, 256 | padding=0, bias=False) 257 | self.tokenizer = tokenizer 258 | if not self.tokenizer: 259 | # if not use tokenzier,then downsample the feature map into a certain size 260 | self.pooling_size = pool_size 261 | self.pool_mode = pool_mode 262 | self.token_len = self.pooling_size * self.pooling_size 263 | 264 | self.token_trans = token_trans 265 | self.with_decoder = with_decoder 266 | dim = 
32 267 | mlp_dim = 2*dim 268 | 269 | self.with_pos = with_pos 270 | if with_pos == 'learned': 271 | self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32)) 272 | decoder_pos_size = 256//4 273 | self.with_decoder_pos = with_decoder_pos 274 | if self.with_decoder_pos == 'learned': 275 | self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32, 276 | decoder_pos_size, 277 | decoder_pos_size)) 278 | self.enc_depth = enc_depth 279 | self.dec_depth = dec_depth 280 | self.dim_head = dim_head 281 | self.decoder_dim_head = decoder_dim_head 282 | self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8, 283 | dim_head=self.dim_head, 284 | mlp_dim=mlp_dim, dropout=0) 285 | self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth, 286 | heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, 287 | softmax=decoder_softmax) 288 | 289 | def _forward_semantic_tokens(self, x): 290 | b, c, h, w = x.shape 291 | spatial_attention = self.conv_a(x) 292 | spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() 293 | spatial_attention = torch.softmax(spatial_attention, dim=-1) 294 | x = x.view([b, c, -1]).contiguous() 295 | tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) 296 | 297 | return tokens 298 | 299 | def _forward_reshape_tokens(self, x): 300 | # b,c,h,w = x.shape 301 | if self.pool_mode == 'max': 302 | x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size]) 303 | elif self.pool_mode == 'ave': 304 | x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size]) 305 | else: 306 | x = x 307 | tokens = rearrange(x, 'b c h w -> b (h w) c') 308 | return tokens 309 | 310 | def _forward_transformer(self, x): 311 | if self.with_pos: 312 | x += self.pos_embedding 313 | x = self.transformer(x) 314 | return x 315 | 316 | def _forward_transformer_decoder(self, x, m): 317 | b, c, h, w = x.shape 318 | if self.with_decoder_pos == 'fix': 319 | x = x + self.pos_embedding_decoder 320 | elif self.with_decoder_pos == 'learned': 321 | x = x + self.pos_embedding_decoder 322 | x = rearrange(x, 'b c h w -> b (h w) c') 323 | x = self.transformer_decoder(x, m) 324 | x = rearrange(x, 'b (h w) c -> b c h w', h=h) 325 | return x 326 | 327 | def _forward_simple_decoder(self, x, m): 328 | b, c, h, w = x.shape 329 | b, l, c = m.shape 330 | m = m.expand([h,w,b,l,c]) 331 | m = rearrange(m, 'h w b l c -> l b c h w') 332 | m = m.sum(0) 333 | x = x + m 334 | return x 335 | 336 | def forward(self, x1, x2): 337 | # forward backbone resnet 338 | x1 = self.forward_single(x1) 339 | x2 = self.forward_single(x2) 340 | 341 | # forward tokenzier 342 | if self.tokenizer: 343 | token1 = self._forward_semantic_tokens(x1) 344 | token2 = self._forward_semantic_tokens(x2) 345 | else: 346 | token1 = self._forward_reshape_tokens(x1) 347 | token2 = self._forward_reshape_tokens(x2) 348 | # forward transformer encoder 349 | if self.token_trans: 350 | self.tokens_ = torch.cat([token1, token2], dim=1) 351 | self.tokens = self._forward_transformer(self.tokens_) 352 | token1, token2 = self.tokens.chunk(2, dim=1) 353 | # forward transformer decoder 354 | if self.with_decoder: 355 | x1 = self._forward_transformer_decoder(x1, token1) 356 | x2 = self._forward_transformer_decoder(x2, token2) 357 | else: 358 | x1 = self._forward_simple_decoder(x1, token1) 359 | x2 = self._forward_simple_decoder(x2, token2) 360 | # feature differencing 361 | x = torch.abs(x1 - x2) 362 | if not self.if_upsample_2x: 363 | x = self.upsamplex2(x) 364 | x = self.upsamplex4(x) 
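        # (editor's note) At this point x is the absolute bitemporal feature
        # difference |f(x1) - f(x2)|, upsampled back towards the input
        # resolution (optional 2x, then 4x bilinear); the small conv head
        # below turns the 32-channel difference map into the 2-class change map.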
365 | # forward small cnn 366 | x = self.classifier(x) 367 | if self.output_sigmoid: 368 | x = self.sigmoid(x) 369 | outputs = [] 370 | outputs.append(x) 371 | return outputs 372 | 373 | -------------------------------------------------------------------------------- /models/pixel_shuffel_up.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | def icnr(x, scale=2, init=nn.init.kaiming_normal_): 7 | """ 8 | Checkerboard artifact free sub-pixel convolution 9 | https://arxiv.org/abs/1707.02937 10 | """ 11 | ni,nf,h,w = x.shape 12 | ni2 = int(ni/(scale**2)) 13 | k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1) 14 | k = k.contiguous().view(ni2, nf, -1) 15 | k = k.repeat(1, 1, scale**2) 16 | k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1) 17 | x.data.copy_(k) 18 | 19 | 20 | class PixelShuffle(nn.Module): 21 | """ 22 | Real-Time Single Image and Video Super-Resolution 23 | https://arxiv.org/abs/1609.05158 24 | """ 25 | def __init__(self, n_channels, scale): 26 | super(PixelShuffle, self).__init__() 27 | self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1) 28 | icnr(self.conv.weight) 29 | self.shuf = nn.PixelShuffle(scale) 30 | self.relu = nn.ReLU(inplace=True) 31 | 32 | def forward(self,x): 33 | x = self.shuf(self.relu(self.conv(x))) 34 | return x 35 | 36 | 37 | def upsample(in_channels, out_channels, upscale, kernel_size=3): 38 | # A series of x 2 upsamling until we get to the upscale we want 39 | layers = [] 40 | conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 41 | nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu') 42 | layers.append(conv1x1) 43 | for i in range(int(math.log(upscale, 2))): 44 | layers.append(PixelShuffle(out_channels, scale=2)) 45 | return nn.Sequential(*layers) 46 | 47 | 48 | class PS_UP(nn.Module): 49 | def __init__(self, upscale, conv_in_ch, num_classes): 50 | super(PS_UP, self).__init__() 51 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) 52 | 53 | def forward(self, x): 54 | x = self.upsample(x) 55 | return x -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | 5 | # from torchvision.models.utils import load_state_dict_from_url 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | 
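# --- Editor's illustrative sketch (not part of the original file) ----------
# How the URL table above is consumed: `_resnet()` (defined near the end of
# this file) builds the architecture and, when `pretrained=True`, fetches the
# matching state dict. The helper below is a hedged, never-called example of
# doing the same thing by hand.
def _example_load_pretrained_resnet18():
    state_dict = load_state_dict_from_url(model_urls['resnet18'], progress=True)
    model = resnet18(pretrained=False)   # architecture only
    model.load_state_dict(state_dict)    # ImageNet weights from the table above
    return model
# ---------------------------------------------------------------------------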
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes, out_planes, stride=1): 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 41 | base_width=64, dilation=1, norm_layer=None): 42 | super(BasicBlock, self).__init__() 43 | if norm_layer is None: 44 | norm_layer = nn.BatchNorm2d 45 | if groups != 1 or base_width != 64: 46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 47 | if dilation > 1: 48 | dilation = 1 49 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 51 | self.conv1 = conv3x3(inplanes, planes, stride) 52 | self.bn1 = norm_layer(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes) 55 | self.bn2 = norm_layer(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | identity = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 80 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 81 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 82 | # This variant is also known as ResNet V1.5 and improves accuracy according to 83 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
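    # (editor's note) Concretely the block below is 1x1-reduce -> 3x3 (carrying
    # the stride, per the V1.5 variant) -> 1x1-expand, so each Bottleneck's
    # output has `planes * expansion` (= 4x) channels.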
84 | 85 | expansion = 4 86 | 87 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 88 | base_width=64, dilation=1, norm_layer=None): 89 | super(Bottleneck, self).__init__() 90 | if norm_layer is None: 91 | norm_layer = nn.BatchNorm2d 92 | width = int(planes * (base_width / 64.)) * groups 93 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 94 | self.conv1 = conv1x1(inplanes, width) 95 | self.bn1 = norm_layer(width) 96 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 97 | self.bn2 = norm_layer(width) 98 | self.conv3 = conv1x1(width, planes * self.expansion) 99 | self.bn3 = norm_layer(planes * self.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | identity = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | identity = self.downsample(x) 120 | 121 | out += identity 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 130 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 131 | norm_layer=None, strides=None): 132 | super(ResNet, self).__init__() 133 | if norm_layer is None: 134 | norm_layer = nn.BatchNorm2d 135 | self._norm_layer = norm_layer 136 | 137 | self.strides = strides 138 | if self.strides is None: 139 | self.strides = [2, 2, 2, 2, 2] 140 | 141 | self.inplanes = 64 142 | self.dilation = 1 143 | if replace_stride_with_dilation is None: 144 | # each element in the tuple indicates if we should replace 145 | # the 2x2 stride with a dilated convolution instead 146 | replace_stride_with_dilation = [False, False, False] 147 | if len(replace_stride_with_dilation) != 3: 148 | raise ValueError("replace_stride_with_dilation should be None " 149 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 150 | self.groups = groups 151 | self.base_width = width_per_group 152 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3, 153 | bias=False) 154 | self.bn1 = norm_layer(self.inplanes) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1) 157 | self.layer1 = self._make_layer(block, 64, layers[0]) 158 | self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[2], 159 | dilate=replace_stride_with_dilation[0]) 160 | self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[3], 161 | dilate=replace_stride_with_dilation[1]) 162 | self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[4], 163 | dilate=replace_stride_with_dilation[2]) 164 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 165 | self.fc = nn.Linear(512 * block.expansion, num_classes) 166 | 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 170 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 171 | nn.init.constant_(m.weight, 1) 172 | nn.init.constant_(m.bias, 0) 173 | 174 | # Zero-initialize the last BN in each residual branch, 175 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
176 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 177 | if zero_init_residual: 178 | for m in self.modules(): 179 | if isinstance(m, Bottleneck): 180 | nn.init.constant_(m.bn3.weight, 0) 181 | elif isinstance(m, BasicBlock): 182 | nn.init.constant_(m.bn2.weight, 0) 183 | 184 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 185 | norm_layer = self._norm_layer 186 | downsample = None 187 | previous_dilation = self.dilation 188 | if dilate: 189 | self.dilation *= stride 190 | stride = 1 191 | if stride != 1 or self.inplanes != planes * block.expansion: 192 | downsample = nn.Sequential( 193 | conv1x1(self.inplanes, planes * block.expansion, stride), 194 | norm_layer(planes * block.expansion), 195 | ) 196 | 197 | layers = [] 198 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 199 | self.base_width, previous_dilation, norm_layer)) 200 | self.inplanes = planes * block.expansion 201 | for _ in range(1, blocks): 202 | layers.append(block(self.inplanes, planes, groups=self.groups, 203 | base_width=self.base_width, dilation=self.dilation, 204 | norm_layer=norm_layer)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def _forward_impl(self, x): 209 | # See note [TorchScript super()] 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | x = self.maxpool(x) 214 | 215 | x = self.layer1(x) 216 | x = self.layer2(x) 217 | x = self.layer3(x) 218 | x = self.layer4(x) 219 | 220 | x = self.avgpool(x) 221 | x = torch.flatten(x, 1) 222 | x = self.fc(x) 223 | 224 | return x 225 | 226 | def forward(self, x): 227 | return self._forward_impl(x) 228 | 229 | 230 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 231 | model = ResNet(block, layers, **kwargs) 232 | if pretrained: 233 | state_dict = load_state_dict_from_url(model_urls[arch], 234 | progress=progress) 235 | model.load_state_dict(state_dict) 236 | return model 237 | 238 | 239 | def resnet18(pretrained=False, progress=True, **kwargs): 240 | r"""ResNet-18 model from 241 | `"Deep Residual Learning for Image Recognition" `_ 242 | 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet34(pretrained=False, progress=True, **kwargs): 252 | r"""ResNet-34 model from 253 | `"Deep Residual Learning for Image Recognition" `_ 254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | progress (bool): If True, displays a progress bar of the download to stderr 258 | """ 259 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 260 | **kwargs) 261 | 262 | 263 | def resnet50(pretrained=False, progress=True, **kwargs): 264 | r"""ResNet-50 model from 265 | `"Deep Residual Learning for Image Recognition" `_ 266 | 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | progress (bool): If True, displays a progress bar of the download to stderr 270 | """ 271 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 272 | **kwargs) 273 | 274 | 275 | def resnet101(pretrained=False, progress=True, **kwargs): 276 | r"""ResNet-101 model from 277 | `"Deep Residual Learning for Image Recognition" `_ 278 | 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | 
progress (bool): If True, displays a progress bar of the download to stderr 282 | """ 283 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def resnet152(pretrained=False, progress=True, **kwargs): 288 | r"""ResNet-152 model from 289 | `"Deep Residual Learning for Image Recognition" `_ 290 | 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 296 | **kwargs) 297 | 298 | 299 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 300 | r"""ResNeXt-50 32x4d model from 301 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 302 | 303 | Args: 304 | pretrained (bool): If True, returns a model pre-trained on ImageNet 305 | progress (bool): If True, displays a progress bar of the download to stderr 306 | """ 307 | kwargs['groups'] = 32 308 | kwargs['width_per_group'] = 4 309 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 310 | pretrained, progress, **kwargs) 311 | 312 | 313 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 314 | r"""ResNeXt-101 32x8d model from 315 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 316 | 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on ImageNet 319 | progress (bool): If True, displays a progress bar of the download to stderr 320 | """ 321 | kwargs['groups'] = 32 322 | kwargs['width_per_group'] = 8 323 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 324 | pretrained, progress, **kwargs) 325 | 326 | 327 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 328 | r"""Wide ResNet-50-2 model from 329 | `"Wide Residual Networks" `_ 330 | 331 | The model is the same as ResNet except for the bottleneck number of channels 332 | which is twice larger in every block. The number of channels in outer 1x1 333 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 334 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 335 | 336 | Args: 337 | pretrained (bool): If True, returns a model pre-trained on ImageNet 338 | progress (bool): If True, displays a progress bar of the download to stderr 339 | """ 340 | kwargs['width_per_group'] = 64 * 2 341 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 342 | pretrained, progress, **kwargs) 343 | 344 | 345 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 346 | r"""Wide ResNet-101-2 model from 347 | `"Wide Residual Networks" `_ 348 | 349 | The model is the same as ResNet except for the bottleneck number of channels 350 | which is twice larger in every block. The number of channels in outer 1x1 351 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 352 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
353 | 354 | Args: 355 | pretrained (bool): If True, returns a model pre-trained on ImageNet 356 | progress (bool): If True, displays a progress bar of the download to stderr 357 | """ 358 | kwargs['width_per_group'] = 64 * 2 359 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 360 | pretrained, progress, **kwargs) 361 | -------------------------------------------------------------------------------- /models/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
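# --- Editor's sketch (hedged; mirrors the usage examples in the docstrings
# further below, not part of the original file). Typical multi-GPU usage of
# the classes defined in this package:
#
#     from models.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback
#     net = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3), SynchronizedBatchNorm2d(64))
#     net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
#     out = net(torch.randn(8, 3, 64, 64).cuda())
#
# On a single GPU or CPU the layer behaves like PyTorch's built-in BatchNorm.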
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | 132 | .. math:: 133 | 134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 135 | 136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 137 | standard-deviation are reduced across all devices during training. 138 | 139 | For example, when one uses `nn.DataParallel` to wrap the network during 140 | training, PyTorch's implementation normalize the tensor on each device using 141 | the statistics only on that device, which accelerated the computation and 142 | is also easy to implement, but the statistics might be inaccurate. 143 | Instead, in this synchronized version, the statistics will be computed 144 | over all training samples distributed on multiple devices. 145 | 146 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 147 | as the built-in PyTorch implementation. 148 | 149 | The mean and standard-deviation are calculated per-dimension over 150 | the mini-batches and gamma and beta are learnable parameter vectors 151 | of size C (where C is the input size). 152 | 153 | During training, this layer keeps a running estimate of its computed mean 154 | and variance. The running sum is kept with a default momentum of 0.1. 155 | 156 | During evaluation, this running mean/variance is used for normalization. 157 | 158 | Because the BatchNorm is done over the `C` dimension, computing statistics 159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 160 | 161 | Args: 162 | num_features: num_features from an expected input of size 163 | `batch_size x num_features [x width]` 164 | eps: a value added to the denominator for numerical stability. 165 | Default: 1e-5 166 | momentum: the value used for the running_mean and running_var 167 | computation. Default: 0.1 168 | affine: a boolean value that when set to ``True``, gives the layer learnable 169 | affine parameters. 
Default: ``True`` 170 | 171 | Shape: 172 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 174 | 175 | Examples: 176 | >>> # With Learnable Parameters 177 | >>> m = SynchronizedBatchNorm1d(100) 178 | >>> # Without Learnable Parameters 179 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 180 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 181 | >>> output = m(input) 182 | """ 183 | 184 | def _check_input_dim(self, input): 185 | if input.dim() != 2 and input.dim() != 3: 186 | raise ValueError('expected 2D or 3D input (got {}D input)' 187 | .format(input.dim())) 188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 189 | 190 | 191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 193 | of 3d inputs 194 | 195 | .. math:: 196 | 197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 198 | 199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 200 | standard-deviation are reduced across all devices during training. 201 | 202 | For example, when one uses `nn.DataParallel` to wrap the network during 203 | training, PyTorch's implementation normalize the tensor on each device using 204 | the statistics only on that device, which accelerated the computation and 205 | is also easy to implement, but the statistics might be inaccurate. 206 | Instead, in this synchronized version, the statistics will be computed 207 | over all training samples distributed on multiple devices. 208 | 209 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 210 | as the built-in PyTorch implementation. 211 | 212 | The mean and standard-deviation are calculated per-dimension over 213 | the mini-batches and gamma and beta are learnable parameter vectors 214 | of size C (where C is the input size). 215 | 216 | During training, this layer keeps a running estimate of its computed mean 217 | and variance. The running sum is kept with a default momentum of 0.1. 218 | 219 | During evaluation, this running mean/variance is used for normalization. 220 | 221 | Because the BatchNorm is done over the `C` dimension, computing statistics 222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 223 | 224 | Args: 225 | num_features: num_features from an expected input of 226 | size batch_size x num_features x height x width 227 | eps: a value added to the denominator for numerical stability. 228 | Default: 1e-5 229 | momentum: the value used for the running_mean and running_var 230 | computation. Default: 0.1 231 | affine: a boolean value that when set to ``True``, gives the layer learnable 232 | affine parameters. 
Default: ``True`` 233 | 234 | Shape: 235 | - Input: :math:`(N, C, H, W)` 236 | - Output: :math:`(N, C, H, W)` (same shape as input) 237 | 238 | Examples: 239 | >>> # With Learnable Parameters 240 | >>> m = SynchronizedBatchNorm2d(100) 241 | >>> # Without Learnable Parameters 242 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 243 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 244 | >>> output = m(input) 245 | """ 246 | 247 | def _check_input_dim(self, input): 248 | if input.dim() != 4: 249 | raise ValueError('expected 4D input (got {}D input)' 250 | .format(input.dim())) 251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 252 | 253 | 254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 256 | of 4d inputs 257 | 258 | .. math:: 259 | 260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 261 | 262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 263 | standard-deviation are reduced across all devices during training. 264 | 265 | For example, when one uses `nn.DataParallel` to wrap the network during 266 | training, PyTorch's implementation normalize the tensor on each device using 267 | the statistics only on that device, which accelerated the computation and 268 | is also easy to implement, but the statistics might be inaccurate. 269 | Instead, in this synchronized version, the statistics will be computed 270 | over all training samples distributed on multiple devices. 271 | 272 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 273 | as the built-in PyTorch implementation. 274 | 275 | The mean and standard-deviation are calculated per-dimension over 276 | the mini-batches and gamma and beta are learnable parameter vectors 277 | of size C (where C is the input size). 278 | 279 | During training, this layer keeps a running estimate of its computed mean 280 | and variance. The running sum is kept with a default momentum of 0.1. 281 | 282 | During evaluation, this running mean/variance is used for normalization. 283 | 284 | Because the BatchNorm is done over the `C` dimension, computing statistics 285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 286 | or Spatio-temporal BatchNorm 287 | 288 | Args: 289 | num_features: num_features from an expected input of 290 | size batch_size x num_features x depth x height x width 291 | eps: a value added to the denominator for numerical stability. 292 | Default: 1e-5 293 | momentum: the value used for the running_mean and running_var 294 | computation. Default: 0.1 295 | affine: a boolean value that when set to ``True``, gives the layer learnable 296 | affine parameters. 
Default: ``True`` 297 | 298 | Shape: 299 | - Input: :math:`(N, C, D, H, W)` 300 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 301 | 302 | Examples: 303 | >>> # With Learnable Parameters 304 | >>> m = SynchronizedBatchNorm3d(100) 305 | >>> # Without Learnable Parameters 306 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 307 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 308 | >>> output = m(input) 309 | """ 310 | 311 | def _check_input_dim(self, input): 312 | if input.dim() != 5: 313 | raise ValueError('expected 5D input (got {}D input)' 314 | .format(input.dim())) 315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 316 | -------------------------------------------------------------------------------- /models/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNormReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /models/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | 
# Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. 
This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /models/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 
62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /models/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | import utils 6 | from models.networks import * 7 | 8 | import torch 9 | import torch.optim as optim 10 | import numpy as np 11 | from misc.metric_tool import ConfuseMatrixMeter 12 | from models.losses import cross_entropy, myloss 13 | import models.losses as losses 14 | from models.losses import get_alpha, softmax_helper, FocalLoss, mIoULoss, mmIoULoss 15 | 16 | from misc.logger_tool import Logger, Timer 17 | 18 | from utils import de_norm 19 | 20 | from tqdm import tqdm 21 | 22 | class CDTrainer(): 23 | 24 | def __init__(self, args, dataloaders): 25 | self.args = args 26 | self.dataloaders = dataloaders 27 | 28 | self.n_class = args.n_class 29 | # define G 30 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids) 31 | 32 | 33 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0 34 | else "cpu") 35 | print(self.device) 36 | 37 | # Learning rate and Beta1 for Adam optimizers 38 | self.lr = args.lr 39 | 40 | # define optimizers 41 | if args.optimizer == "sgd": 42 | self.optimizer_G = optim.SGD(self.net_G.parameters(), lr=self.lr, 43 | momentum=0.9, 44 | weight_decay=5e-4) 45 | elif args.optimizer 
== "adam": 46 | self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr, 47 | weight_decay=0) 48 | elif args.optimizer == "adamw": 49 | self.optimizer_G = optim.AdamW(self.net_G.parameters(), lr=self.lr, 50 | betas=(0.9, 0.999), weight_decay=0.01) 51 | 52 | # self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr) 53 | 54 | # define lr schedulers 55 | self.exp_lr_scheduler_G = get_scheduler(self.optimizer_G, args) 56 | 57 | self.running_metric = ConfuseMatrixMeter(n_class=2) 58 | 59 | # define logger file 60 | logger_path = os.path.join(args.checkpoint_dir, 'log.txt') 61 | self.logger = Logger(logger_path) 62 | self.logger.write_dict_str(args.__dict__) 63 | # define timer 64 | self.timer = Timer() 65 | self.batch_size = args.batch_size 66 | 67 | # training log 68 | self.epoch_acc = 0 69 | self.best_val_acc = 0.0 70 | self.best_epoch_id = 0 71 | self.epoch_to_start = 0 72 | self.max_num_epochs = args.max_epochs 73 | 74 | self.global_step = 0 75 | self.steps_per_epoch = len(dataloaders['train']) 76 | self.total_steps = (self.max_num_epochs - self.epoch_to_start)*self.steps_per_epoch 77 | 78 | self.G_pred = None 79 | self.pred_vis = None 80 | self.batch = None 81 | self.G_loss = None 82 | self.is_training = False 83 | self.batch_id = 0 84 | self.epoch_id = 0 85 | self.checkpoint_dir = args.checkpoint_dir 86 | self.vis_dir = args.vis_dir 87 | 88 | self.shuffle_AB = args.shuffle_AB 89 | 90 | # define the loss functions 91 | self.multi_scale_train = args.multi_scale_train 92 | self.multi_scale_infer = args.multi_scale_infer 93 | self.weights = tuple(args.multi_pred_weights) 94 | if args.loss == 'ce': 95 | self._pxl_loss = cross_entropy 96 | elif args.loss == 'bce': 97 | self._pxl_loss = losses.binary_ce 98 | elif args.loss == 'fl': 99 | print('\n Calculating alpha in Focal-Loss (FL) ...') 100 | alpha = get_alpha(dataloaders['train']) # calculare class occurences 101 | print(f"alpha-0 (no-change)={alpha[0]}, alpha-1 (change)={alpha[1]}") 102 | self._pxl_loss = FocalLoss(apply_nonlin = softmax_helper, alpha = alpha, gamma = 2, smooth = 1e-5) 103 | elif args.loss == "miou": 104 | print('\n Calculating Class occurances in training set...') 105 | alpha = np.asarray(get_alpha(dataloaders['train'])) # calculare class occurences 106 | alpha = alpha/np.sum(alpha) 107 | # weights = torch.tensor([1.0, 1.0]).cuda() 108 | weights = 1-torch.from_numpy(alpha).cuda() 109 | print(f"Weights = {weights}") 110 | self._pxl_loss = mIoULoss(weight=weights, size_average=True, n_classes=args.n_class).cuda() 111 | elif args.loss == "mmiou": 112 | self._pxl_loss = mmIoULoss(n_classes=args.n_class).cuda() 113 | elif args.loss == "eas": 114 | # print('\n Calculating Class occurances in training set...') 115 | # alpha = get_alpha(dataloaders['train']) 116 | # print(f"alpha-0 (no-change)={alpha[0]}, alpha-1 (change)={alpha[1]}") 117 | self._pxl_loss = myloss(apply_nonlin=softmax_helper, alpha=None, gamma=2, smooth=1e-5) 118 | else: 119 | raise NotImplemented(args.loss) 120 | 121 | self.VAL_ACC = np.array([], np.float32) 122 | if os.path.exists(os.path.join(self.checkpoint_dir, 'val_acc.npy')): 123 | self.VAL_ACC = np.load(os.path.join(self.checkpoint_dir, 'val_acc.npy')) 124 | self.TRAIN_ACC = np.array([], np.float32) 125 | if os.path.exists(os.path.join(self.checkpoint_dir, 'train_acc.npy')): 126 | self.TRAIN_ACC = np.load(os.path.join(self.checkpoint_dir, 'train_acc.npy')) 127 | 128 | # check and create model dir 129 | if os.path.exists(self.checkpoint_dir) is False: 130 | os.mkdir(self.checkpoint_dir) 
131 | if os.path.exists(self.vis_dir) is False: 132 | os.mkdir(self.vis_dir) 133 | 134 | 135 | def _load_checkpoint(self, ckpt_name='last_ckpt.pt'): 136 | print("\n") 137 | if os.path.exists(os.path.join(self.checkpoint_dir, ckpt_name)): 138 | self.logger.write('loading last checkpoint...\n') 139 | # load the entire checkpoint 140 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, ckpt_name), 141 | map_location=self.device) 142 | # update net_G states 143 | self.net_G.load_state_dict(checkpoint['model_G_state_dict']) 144 | 145 | self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict']) 146 | self.exp_lr_scheduler_G.load_state_dict( 147 | checkpoint['exp_lr_scheduler_G_state_dict']) 148 | 149 | self.net_G.to(self.device) 150 | 151 | # update some other states 152 | self.epoch_to_start = checkpoint['epoch_id'] + 1 153 | self.best_val_acc = checkpoint['best_val_acc'] 154 | self.best_epoch_id = checkpoint['best_epoch_id'] 155 | 156 | self.total_steps = (self.max_num_epochs - self.epoch_to_start)*self.steps_per_epoch 157 | 158 | self.logger.write('Epoch_to_start = %d, Historical_best_acc = %.4f (at epoch %d)\n' % 159 | (self.epoch_to_start, self.best_val_acc, self.best_epoch_id)) 160 | self.logger.write('\n') 161 | elif self.args.pretrain is not None: 162 | print("Initializing backbone weights from: " + self.args.pretrain) 163 | self.net_G.load_state_dict(torch.load(self.args.pretrain), strict=False) 164 | self.net_G.to(self.device) 165 | self.net_G.eval() 166 | else: 167 | print('training from scratch...') 168 | print("\n") 169 | 170 | def _timer_update(self): 171 | self.global_step = (self.epoch_id-self.epoch_to_start) * self.steps_per_epoch + self.batch_id 172 | 173 | self.timer.update_progress((self.global_step + 1) / self.total_steps) 174 | est = self.timer.estimated_remaining() 175 | imps = (self.global_step + 1) * self.batch_size / self.timer.get_stage_elapsed() 176 | return imps, est 177 | 178 | def _visualize_pred(self): 179 | pred = torch.argmax(self.G_final_pred, dim=1, keepdim=True) 180 | pred_vis = pred * 255 181 | return pred_vis 182 | 183 | def _save_checkpoint(self, ckpt_name): 184 | torch.save({ 185 | 'epoch_id': self.epoch_id, 186 | 'best_val_acc': self.best_val_acc, 187 | 'best_epoch_id': self.best_epoch_id, 188 | 'model_G_state_dict': self.net_G.state_dict(), 189 | 'optimizer_G_state_dict': self.optimizer_G.state_dict(), 190 | 'exp_lr_scheduler_G_state_dict': self.exp_lr_scheduler_G.state_dict(), 191 | }, os.path.join(self.checkpoint_dir, ckpt_name)) 192 | 193 | def _update_lr_schedulers(self): 194 | self.exp_lr_scheduler_G.step() 195 | 196 | def _update_metric(self): 197 | """ 198 | update metric 199 | """ 200 | target = self.batch['L'].to(self.device).detach() 201 | 202 | G_pred = self.G_final_pred.detach() 203 | G_pred = torch.argmax(G_pred, dim=1) 204 | 205 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy()) 206 | return current_score 207 | 208 | def _collect_running_batch_states(self): 209 | 210 | running_acc = self._update_metric() 211 | 212 | m = len(self.dataloaders['train']) 213 | if self.is_training is False: 214 | m = len(self.dataloaders['val']) 215 | 216 | imps, est = self._timer_update() 217 | if np.mod(self.batch_id, 100) == 1: 218 | message = 'Is_training: %s. 
[%d,%d][%d,%d], imps: %.2f, est: %.2fh, G_loss: %.5f, running_mf1: %.5f\n' %\ 219 | (self.is_training, self.epoch_id, self.max_num_epochs-1, self.batch_id, m, 220 | imps*self.batch_size, est, 221 | self.G_loss.item(), running_acc) 222 | self.logger.write(message) 223 | 224 | 225 | if np.mod(self.batch_id, 500) == 1: 226 | vis_input = utils.make_numpy_grid(de_norm(self.batch['A'])) 227 | vis_input2 = utils.make_numpy_grid(de_norm(self.batch['B'])) 228 | 229 | vis_pred = utils.make_numpy_grid(self._visualize_pred()) 230 | 231 | vis_gt = utils.make_numpy_grid(self.batch['L']) 232 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0) 233 | vis = np.clip(vis, a_min=0.0, a_max=1.0) 234 | file_name = os.path.join( 235 | self.vis_dir, 'istrain_'+str(self.is_training)+'_'+ 236 | str(self.epoch_id)+'_'+str(self.batch_id)+'.jpg') 237 | plt.imsave(file_name, vis) 238 | 239 | def _collect_epoch_states(self): 240 | scores = self.running_metric.get_scores() 241 | self.epoch_acc = scores['mf1'] 242 | self.logger.write('Is_training: %s. Epoch %d / %d, epoch_mF1= %.5f\n' % 243 | (self.is_training, self.epoch_id, self.max_num_epochs-1, self.epoch_acc)) 244 | message = '' 245 | for k, v in scores.items(): 246 | message += '%s: %.5f ' % (k, v) 247 | self.logger.write(message+'\n') 248 | self.logger.write('\n') 249 | 250 | def _update_checkpoints(self): 251 | 252 | # save current model 253 | self._save_checkpoint(ckpt_name='last_ckpt.pt') 254 | self.logger.write('Latest model updated. Epoch_acc=%.4f, Historical_best_acc=%.4f (at epoch %d)\n' 255 | % (self.epoch_acc, self.best_val_acc, self.best_epoch_id)) 256 | self.logger.write('\n') 257 | 258 | # update the best model (based on eval acc) 259 | if self.epoch_acc > self.best_val_acc: 260 | self.best_val_acc = self.epoch_acc 261 | self.best_epoch_id = self.epoch_id 262 | self._save_checkpoint(ckpt_name='best_ckpt.pt') 263 | self.logger.write('*' * 10 + 'Best model updated!\n') 264 | self.logger.write('\n') 265 | 266 | def _update_training_acc_curve(self): 267 | # update train acc curve 268 | self.TRAIN_ACC = np.append(self.TRAIN_ACC, [self.epoch_acc]) 269 | np.save(os.path.join(self.checkpoint_dir, 'train_acc.npy'), self.TRAIN_ACC) 270 | 271 | def _update_val_acc_curve(self): 272 | # update val acc curve 273 | self.VAL_ACC = np.append(self.VAL_ACC, [self.epoch_acc]) 274 | np.save(os.path.join(self.checkpoint_dir, 'val_acc.npy'), self.VAL_ACC) 275 | 276 | def _clear_cache(self): 277 | self.running_metric.clear() 278 | 279 | 280 | def _forward_pass(self, batch): 281 | self.batch = batch 282 | img_in1 = batch['A'].to(self.device) 283 | img_in2 = batch['B'].to(self.device) 284 | 285 | self.G_pred = self.net_G(img_in1, img_in2) 286 | 287 | if self.multi_scale_infer == "True": 288 | self.G_final_pred = torch.zeros(self.G_pred[-1].size()).to(self.device) 289 | for pred in self.G_pred: 290 | if pred.size(2) != self.G_pred[-1].size(2): 291 | self.G_final_pred = self.G_final_pred + F.interpolate(pred, size=self.G_pred[-1].size(2), mode="nearest") 292 | else: 293 | self.G_final_pred = self.G_final_pred + pred 294 | self.G_final_pred = self.G_final_pred/len(self.G_pred) 295 | else: 296 | self.G_final_pred = self.G_pred[-1] 297 | 298 | 299 | def _backward_G(self): 300 | gt = self.batch['L'].to(self.device).float() 301 | 302 | if self.multi_scale_train == "True": 303 | i = 0 304 | temp_loss = 0.0 305 | for pred in self.G_pred: 306 | if pred.size(2) != gt.size(2): 307 | temp_loss = temp_loss + self.weights[i]*self._pxl_loss(pred, F.interpolate(gt, 
size=pred.size(2), mode="nearest")) 308 | else: 309 | temp_loss = temp_loss + self.weights[i]*self._pxl_loss(pred, gt) 310 | i+=1 311 | self.G_loss = temp_loss 312 | else: 313 | if self.args.loss == 'eas': 314 | gt_edge = self.batch['L_edge'].to(self.device).float() 315 | self.G_loss = self._pxl_loss(self.G_pred[-1], gt, self.G_pred[-2], gt_edge) 316 | else: 317 | self.G_loss = self._pxl_loss(self.G_pred[-1], gt) 318 | # print(self.G_pred[-1].shape) 319 | # print(gt.shape) 320 | self.G_loss.backward() 321 | 322 | 323 | def train_models(self): 324 | 325 | self._load_checkpoint() 326 | 327 | # loop over the dataset multiple times 328 | for self.epoch_id in range(self.epoch_to_start, self.max_num_epochs): 329 | 330 | ################## train ################# 331 | ########################################## 332 | self._clear_cache() 333 | self.is_training = True 334 | self.net_G.train() # Set model to training mode 335 | # Iterate over data. 336 | total = len(self.dataloaders['train']) 337 | self.logger.write('lr: %0.7f\n \n' % self.optimizer_G.param_groups[0]['lr']) 338 | for self.batch_id, batch in tqdm(enumerate(self.dataloaders['train'], 0), total=total): 339 | self._forward_pass(batch) 340 | # update G 341 | self.optimizer_G.zero_grad() 342 | self._backward_G() 343 | self.optimizer_G.step() 344 | self._collect_running_batch_states() 345 | self._timer_update() 346 | 347 | self._collect_epoch_states() 348 | self._update_training_acc_curve() 349 | self._update_lr_schedulers() 350 | 351 | 352 | ################## Eval ################## 353 | ########################################## 354 | self.logger.write('Begin evaluation...\n') 355 | self._clear_cache() 356 | self.is_training = False 357 | self.net_G.eval() 358 | 359 | # Iterate over data. 360 | for self.batch_id, batch in enumerate(self.dataloaders['val'], 0): 361 | with torch.no_grad(): 362 | self._forward_pass(batch) 363 | self._collect_running_batch_states() 364 | self._collect_epoch_states() 365 | 366 | ########### Update_Checkpoints ########### 367 | ########################################## 368 | self._update_val_acc_curve() 369 | self._update_checkpoints() 370 | 371 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=4.5=1_gnu 6 | blas=1.0=mkl 7 | brotli=1.0.9=h9c3ff4c_4 8 | bzip2=1.0.8=h7b6447c_0 9 | ca-certificates=2021.10.8=ha878542_0 10 | certifi=2021.10.8=py38h578d9bd_1 11 | colorama=0.4.4=pyh9f0ad1d_0 12 | cudatoolkit=10.2.89=hfd86e86_1 13 | cycler=0.11.0=pyhd8ed1ab_0 14 | dbus=1.13.18=hb2f20db_0 15 | einops=0.3.2=pyhd8ed1ab_0 16 | expat=2.2.10=h9c3ff4c_0 17 | ffmpeg=4.3=hf484d3e_0 18 | fontconfig=2.13.1=hba837de_1005 19 | fonttools=4.25.0=pyhd3eb1b0_0 20 | freetype=2.11.0=h70c0345_0 21 | gettext=0.19.8.1=hf34092f_1004 22 | giflib=5.2.1=h7b6447c_0 23 | glib=2.58.3=py38h73cb85d_1004 24 | gmp=6.2.1=h2531618_2 25 | gnutls=3.6.15=he1e5248_0 26 | gst-plugins-base=1.14.5=h0935bb2_2 27 | gstreamer=1.14.5=h36ae1b5_2 28 | icu=64.2=he1b5a44_1 29 | intel-openmp=2021.4.0=h06a4308_3561 30 | jpeg=9d=h7f8727e_0 31 | kiwisolver=1.3.1=py38h2531618_0 32 | lame=3.100=h7b6447c_0 33 | lcms2=2.12=h3be6417_0 34 | libedit=3.1.20210910=h7f8727e_0 35 | libffi=3.2.1=hf484d3e_1007 36 | libgcc-ng=9.3.0=h5101ec6_17 37 | libgfortran-ng=7.5.0=ha8ba4b0_17 38 | 
libgfortran4=7.5.0=ha8ba4b0_17 39 | libgomp=9.3.0=h5101ec6_17 40 | libiconv=1.15=h63c8f33_5 41 | libidn2=2.3.2=h7f8727e_0 42 | libpng=1.6.37=hbc83047_0 43 | libstdcxx-ng=9.3.0=hd4cf53a_17 44 | libtasn1=4.16.0=h27cfd23_0 45 | libtiff=4.2.0=h85742a9_0 46 | libunistring=0.9.10=h27cfd23_0 47 | libuuid=2.32.1=h7f98852_1000 48 | libuv=1.40.0=h7b6447c_0 49 | libwebp=1.2.0=h89dd481_0 50 | libwebp-base=1.2.0=h27cfd23_0 51 | libxcb=1.13=h7f98852_1003 52 | libxml2=2.9.10=hee79883_0 53 | lz4-c=1.9.3=h295c915_1 54 | matplotlib=3.4.3=py38h578d9bd_1 55 | matplotlib-base=3.4.3=py38hbbc1b5f_0 56 | mkl=2021.4.0=h06a4308_640 57 | mkl-service=2.4.0=py38h7f8727e_0 58 | mkl_fft=1.3.1=py38hd3c417c_0 59 | mkl_random=1.2.2=py38h51133e4_0 60 | munkres=1.1.4=pyh9f0ad1d_0 61 | ncurses=6.3=h7f8727e_2 62 | nettle=3.7.3=hbbd107a_1 63 | numpy=1.21.2=py38h20f2e39_0 64 | numpy-base=1.21.2=py38h79a1101_0 65 | olefile=0.46=pyhd3eb1b0_0 66 | openh264=2.1.1=h4ff587b_0 67 | openssl=1.1.1l=h7f8727e_0 68 | pcre=8.45=h9c3ff4c_0 69 | pillow=8.4.0=py38h5aabda8_0 70 | pip=21.2.4=py38h06a4308_0 71 | pthread-stubs=0.4=h36c2ea0_1001 72 | pyparsing=3.0.6=pyhd8ed1ab_0 73 | pyqt=5.9.2=py38h05f1152_4 74 | python=3.8.0=h0371630_2 75 | python-dateutil=2.8.2=pyhd8ed1ab_0 76 | python_abi=3.8=2_cp38 77 | pytorch=1.10.1=py3.8_cuda10.2_cudnn7.6.5_0 78 | pytorch-mutex=1.0=cuda 79 | qt=5.9.7=h0c104cb_3 80 | readline=7.0=h7b6447c_5 81 | scipy=1.7.1=py38h292c36d_2 82 | setuptools=58.0.4=py38h06a4308_0 83 | sip=4.19.13=py38he6710b0_0 84 | six=1.16.0=pyhd3eb1b0_0 85 | sqlite=3.33.0=h62c20be_0 86 | timm=0.4.12=pyhd8ed1ab_0 87 | tk=8.6.11=h1ccaba5_0 88 | torchaudio=0.10.1=py38_cu102 89 | torchvision=0.11.2=py38_cu102 90 | tornado=6.1=py38h497a2fe_1 91 | tqdm=4.62.3=pyhd8ed1ab_0 92 | typing_extensions=3.10.0.2=pyh06a4308_0 93 | wheel=0.37.0=pyhd3eb1b0_1 94 | xorg-libxau=1.0.9=h7f98852_0 95 | xorg-libxdmcp=1.1.3=h7f98852_0 96 | xz=5.2.5=h7b6447c_0 97 | zlib=1.2.11=h7f8727e_4 98 | zstd=1.4.9=haebb681_0 -------------------------------------------------------------------------------- /samples_LEVIR/A/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/A/test_113_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_113_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/A/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/A/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/A/test_2_0000_0512.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/A/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/A/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/A/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/A/train_36_0512_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/train_36_0512_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/A/train_386_0512_0768.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/train_386_0512_0768.png -------------------------------------------------------------------------------- /samples_LEVIR/A/train_412_0512_0768.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/train_412_0512_0768.png -------------------------------------------------------------------------------- /samples_LEVIR/A/val_27_0000_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/A/val_27_0000_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/B/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/B/test_113_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_113_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/B/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_121_0768_0256.png 
-------------------------------------------------------------------------------- /samples_LEVIR/B/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/B/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/B/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/B/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/B/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/B/train_36_0512_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/train_36_0512_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/B/train_386_0512_0768.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/train_386_0512_0768.png -------------------------------------------------------------------------------- /samples_LEVIR/B/train_412_0512_0768.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/train_412_0512_0768.png -------------------------------------------------------------------------------- /samples_LEVIR/B/val_27_0000_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/B/val_27_0000_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/label/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/label/test_121_0768_0256.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/label/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/label/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/label/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/label/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/label/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/label/train_36_0512_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/train_36_0512_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/label/train_386_0512_0768.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/train_386_0512_0768.png -------------------------------------------------------------------------------- /samples_LEVIR/label/train_412_0512_0768.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/train_412_0512_0768.png -------------------------------------------------------------------------------- /samples_LEVIR/label/val_27_0000_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/label/val_27_0000_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/list/demo.txt: 
-------------------------------------------------------------------------------- 1 | test_77_0512_0256.png 2 | test_102_0512_0000.png 3 | test_121_0768_0256.png 4 | test_2_0000_0000.png 5 | test_2_0000_0512.png 6 | test_7_0256_0512.png 7 | test_55_0256_0000.png 8 | -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_BIT/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_BIT/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_BIT/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_BIT/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_BIT/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_BIT/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_BIT/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_BIT/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_BIT/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_BIT/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_BIT/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_BIT/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_BIT/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_BIT/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_ChangeFormerV6/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_ChangeFormerV6/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_ChangeFormerV6/test_121_0768_0256.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_ChangeFormerV6/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_ChangeFormerV6/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_ChangeFormerV6/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_ChangeFormerV6/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_ChangeFormerV6/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_ChangeFormerV6/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_ChangeFormerV6/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_ChangeFormerV6/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_ChangeFormerV6/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_ChangeFormerV6/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_ChangeFormerV6/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_DTCDSCN/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_DTCDSCN/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_DTCDSCN/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_DTCDSCN/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_DTCDSCN/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_DTCDSCN/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_DTCDSCN/test_2_0000_0512.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_DTCDSCN/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_DTCDSCN/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_DTCDSCN/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_DTCDSCN/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_DTCDSCN/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_DTCDSCN/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_DTCDSCN/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_conc/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_conc/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_conc/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_conc/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_conc/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_conc/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_conc/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_conc/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_conc/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_conc/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_conc/test_77_0512_0256.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_conc/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_conc/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_conc/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_diff/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_diff/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_diff/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_diff/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_diff/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_diff/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_diff/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_diff/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_diff/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_diff/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_diff/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_diff/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_SiamUnet_diff/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_SiamUnet_diff/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_Unet/test_102_0512_0000.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_Unet/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_Unet/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_Unet/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_Unet/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_Unet/test_2_0000_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_Unet/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_Unet/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_Unet/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_Unet/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_Unet/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_Unet/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_CD_Unet/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_CD_Unet/test_7_0256_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_ChangeFormerV6/test_102_0512_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_ChangeFormerV6/test_102_0512_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_ChangeFormerV6/test_121_0768_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_ChangeFormerV6/test_121_0768_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_ChangeFormerV6/test_2_0000_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_ChangeFormerV6/test_2_0000_0000.png 
-------------------------------------------------------------------------------- /samples_LEVIR/predict_ChangeFormerV6/test_2_0000_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_ChangeFormerV6/test_2_0000_0512.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_ChangeFormerV6/test_55_0256_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_ChangeFormerV6/test_55_0256_0000.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_ChangeFormerV6/test_77_0512_0256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_ChangeFormerV6/test_77_0512_0256.png -------------------------------------------------------------------------------- /samples_LEVIR/predict_ChangeFormerV6/test_7_0256_0512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen11221/EGCTNet_pytorch/c35da37f05ca494aca4444050e80bde26e541ec2/samples_LEVIR/predict_ChangeFormerV6/test_7_0256_0512.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torchvision import utils 5 | 6 | import data_config 7 | from datasets.CD_dataset import CDDataset 8 | 9 | 10 | def get_loader(data_name, img_size=256, batch_size=8, split='test', 11 | is_train=False, dataset='CDDataset'): 12 | dataConfig = data_config.DataConfig().get_data_config(data_name) 13 | root_dir = dataConfig.root_dir 14 | label_transform = dataConfig.label_transform 15 | 16 | if dataset == 'CDDataset': 17 | data_set = CDDataset(root_dir=root_dir, split=split, 18 | img_size=img_size, is_train=is_train, 19 | label_transform=label_transform) 20 | else: 21 | raise NotImplementedError( 22 | 'Wrong dataset name %s (choose one from [CDDataset])' 23 | % dataset) 24 | 25 | shuffle = is_train 26 | dataloader = DataLoader(data_set, batch_size=batch_size, 27 | shuffle=shuffle, num_workers=0) 28 | 29 | return dataloader 30 | 31 | 32 | def get_loaders(args): 33 | 34 | data_name = args.data_name 35 | dataConfig = data_config.DataConfig().get_data_config(data_name) 36 | root_dir = dataConfig.root_dir 37 | label_transform = dataConfig.label_transform 38 | split = args.split 39 | split_val = 'val' 40 | if hasattr(args, 'split_val'): 41 | split_val = args.split_val 42 | if args.dataset == 'CDDataset': 43 | training_set = CDDataset(root_dir=root_dir, split=split, 44 | img_size=args.img_size,is_train=True, 45 | label_transform=label_transform) 46 | val_set = CDDataset(root_dir=root_dir, split=split_val, 47 | img_size=args.img_size,is_train=False, 48 | label_transform=label_transform) 49 | else: 50 | raise NotImplementedError( 51 | 'Wrong dataset name %s (choose one from [CDDataset,])' 52 | % args.dataset) 53 | 54 | datasets = {'train': training_set, 'val': val_set} 55 | dataloaders = {x: DataLoader(datasets[x], batch_size=args.batch_size, 56 
| shuffle=True, num_workers=args.num_workers) 57 | for x in ['train', 'val']} 58 | 59 | return dataloaders 60 | 61 | 62 | def make_numpy_grid(tensor_data, pad_value=0, padding=0): 63 | tensor_data = tensor_data.detach() 64 | vis = utils.make_grid(tensor_data, pad_value=pad_value, padding=padding) 65 | vis = np.array(vis.cpu()).transpose((1,2,0)) 66 | if vis.shape[2] == 1: 67 | vis = np.concatenate([vis, vis, vis], axis=-1)  # repeat a single-channel grid to get an (H, W, 3) RGB image 68 | return vis 69 | 70 | 71 | def de_norm(tensor_data): 72 | return tensor_data * 0.5 + 0.5 73 | 74 | 75 | def get_device(args): 76 | # set gpu ids 77 | str_ids = args.gpu_ids.split(',') 78 | args.gpu_ids = [] 79 | for str_id in str_ids: 80 | id = int(str_id) 81 | if id >= 0: 82 | args.gpu_ids.append(id) 83 | if len(args.gpu_ids) > 0: 84 | torch.cuda.set_device(args.gpu_ids[0]) --------------------------------------------------------------------------------
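
For quick reference, below is a minimal usage sketch of the helpers defined in `utils.py` above (`get_loader`, `make_numpy_grid`, `de_norm`), mirroring how `trainer.py` assembles its training snapshots. The data name `'LEVIR'`, the batch size, and the output filename are illustrative assumptions; the actual data name must match an entry in `data_config.py`, which is not shown here.

```python
# Usage sketch only: assumes data_config.py defines a 'LEVIR' entry pointing at a
# prepared dataset (A/, B/, label/, list/). Literal values below are illustrative.
import numpy as np
import matplotlib.pyplot as plt

from utils import get_loader, make_numpy_grid, de_norm

# Build a test-split dataloader (same call signature as utils.get_loader above).
dataloader = get_loader('LEVIR', img_size=256, batch_size=4, split='test', is_train=False)

batch = next(iter(dataloader))
vis_A = make_numpy_grid(de_norm(batch['A']))  # t1 images, de-normalised back to [0, 1]
vis_B = make_numpy_grid(de_norm(batch['B']))  # t2 images
vis_L = make_numpy_grid(batch['L'])           # binary change labels

# Stack the grids vertically and save, as trainer.py does in _collect_running_batch_states.
vis = np.clip(np.concatenate([vis_A, vis_B, vis_L], axis=0), a_min=0.0, a_max=1.0)
plt.imsave('demo_batch.jpg', vis)
```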