├── .gitignore ├── README.md ├── STADE-CDNet ├── .gitignore ├── data_config.py ├── data_preparation │ ├── dsifn_cd_256.m │ ├── find_mean_std.py │ └── levir_cd_256.m ├── datasets │ ├── CD_dataset.py │ └── data_utils.py ├── eval_zl.py ├── main_zl.py ├── misc │ ├── imutils.py │ ├── logger_tool.py │ ├── metric_tool.py │ ├── pyutils.py │ └── torchutils.py ├── models │ ├── ChangeFormer.py │ ├── ChangeFormerBaseNetworks.py │ ├── DTCDSCN.py │ ├── SiamUnet_conc.py │ ├── SiamUnet_diff.py │ ├── Unet.py │ ├── __init__.py │ ├── basic_model.py │ ├── evaluator.py │ ├── help_funcs.py │ ├── losses.py │ ├── networks.py │ ├── pixel_shuffel_up.py │ ├── resnet.py │ └── trainer.py ├── samples_DSIFN │ ├── A │ │ ├── 0_2.png │ │ ├── 1_1.png │ │ ├── 2_4.png │ │ ├── 3_4.png │ │ ├── 4_4.png │ │ ├── 5_3.png │ │ ├── 6_3.png │ │ ├── 7_4.png │ │ ├── 8_3.png │ │ └── 9_3.png │ ├── B │ │ ├── 0_2.png │ │ ├── 1_1.png │ │ ├── 2_4.png │ │ ├── 3_4.png │ │ ├── 4_4.png │ │ ├── 5_3.png │ │ ├── 6_3.png │ │ ├── 7_4.png │ │ ├── 8_3.png │ │ └── 9_3.png │ ├── label │ │ ├── 0_2.png │ │ ├── 1_1.png │ │ ├── 2_4.png │ │ ├── 3_4.png │ │ ├── 4_4.png │ │ ├── 5_3.png │ │ ├── 6_3.png │ │ ├── 7_4.png │ │ ├── 8_3.png │ │ └── 9_3.png │ └── list │ │ └── demo.txt └── utils.py └── image ├── 1 (2).png ├── 11.png ├── 16.png ├── 22.png ├── 33.png ├── 4.png ├── 44.png ├── 5.png ├── 55.jpg ├── 6.png ├── 66.png ├── 7.png └── 77.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# STADE-CDNet: Spatial-Temporal Attention With Difference Enhancement-Based Network for Remote Sensing Image Change Detection
## Requirements

Python 3.8.0
pytorch 1.10.1
torchvision 0.11.2
einops 0.3.2

## Installation
Clone this repo:
```bash
git clone https://github.com/LiLisaZhi/STADE-CDNet.git
cd STADE-CDNet
```

## Dataset Preparation

```
"""
Change detection data set with pixel-level binary labels;
├─A
├─B
├─label
└─list
"""
```
`A`: pre-change (time 1) images;
`B`: post-change (time 2) images;
`label`: binary change label maps;
`list`: contains train.txt, val.txt and test.txt; each file records the image names (XXX.png) of the corresponding split of the change detection dataset.
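For reference, a minimal sketch (not part of this repository) of how the `list` files could be generated, assuming the per-split folders produced by the data-preparation scripts below; the loader in `datasets/CD_dataset.py` simply reads one image name per line from `list/<split>.txt`:

```python
# Hypothetical helper: write list/train.txt, list/val.txt and list/test.txt,
# one image name per line, from already-prepared per-split folders such as
# LEVIR-CD256/train/A, LEVIR-CD256/val/A and LEVIR-CD256/test/A.
import os

def write_split_lists(prepared_root, dataset_root):
    os.makedirs(os.path.join(dataset_root, "list"), exist_ok=True)
    for split in ("train", "val", "test"):
        names = sorted(os.listdir(os.path.join(prepared_root, split, "A")))
        with open(os.path.join(dataset_root, "list", split + ".txt"), "w") as f:
            f.write("\n".join(names) + "\n")

# Example (paths are placeholders):
# write_split_lists("LEVIR-CD256", "LEVIR-CD256")
```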
## Links to download processed datasets
- LEVIR-CD: [`click here to download`](https://justchenhao.github.io/LEVIR/)
- DSIFN-CD: [`click here to download`](https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset)
## References
We appreciate the work from the following repositories:
```
https://github.com/justchenhao/BIT_CD
```

```
https://github.com/wgcban/ChangeFormer
```

(The code implementation of our STADE-CDNet method references these code repositories.)
## Contact
lisa_zhi@foxmail.com

--------------------------------------------------------------------------------
/STADE-CDNet/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.DS_Store
./sha256
./sha256.pub
--------------------------------------------------------------------------------
/STADE-CDNet/data_config.py:
--------------------------------------------------------------------------------

class DataConfig:
    data_name = ""
    root_dir = ""
    label_transform = ""

    def get_data_config(self, data_name):
        # '/where' is a placeholder: set it to the root directory of the corresponding dataset.
        self.data_name = data_name
        if data_name == 'LEVIR':
            self.label_transform = "norm"
            self.root_dir = '/where'
        elif data_name == 'DSIFN':
            self.label_transform = "norm"
            self.root_dir = '/where'
        elif data_name == 'WHU':
            self.label_transform = "norm"
            self.root_dir = '/where'
        elif data_name == 'CDD':
            self.label_transform = "norm"
            self.root_dir = '/where'
        elif data_name == 'TYPO':
            self.label_transform = "norm"
            self.root_dir = '/where'
        elif data_name == 'quick_start_LEVIR':
            self.root_dir = '/where'
        elif data_name == 'quick_start_DSIFN':
            self.root_dir = '/where'
        else:
            raise TypeError('%s has not been defined' % data_name)
        return self


if __name__ == '__main__':
    data = DataConfig().get_data_config(data_name='LEVIR')
    print(data.data_name)
    print(data.root_dir)
    print(data.label_transform)

--------------------------------------------------------------------------------
/STADE-CDNet/data_preparation/dsifn_cd_256.m:
--------------------------------------------------------------------------------
%Dataset preparation code for the DSIFN dataset (MATLAB)
%Download the DSIFN dataset here: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset
%This code generates the 256x256 image patches required for the train/val/test splits.
%Please create folders according to the following format.
5 | %DSIFN_256 6 | %------(train) 7 | % |---> A 8 | % |---> B 9 | % |---> label 10 | %------(val) 11 | % |---> A 12 | % |---> B 13 | % |---> label 14 | %------(test) 15 | % |---> A 16 | % |---> B 17 | % |---> label 18 | %Then run this code 19 | %Then copy all images in train-A, val-A, test-A to a folder name A 20 | %Then copy all images in train-B, val-B, test-B to a folder name B 21 | %Then copy all images in train-label, val-label, test-label to a folder name label 22 | 23 | 24 | 25 | 26 | clear all; 27 | close all; 28 | clc; 29 | 30 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 31 | %Train-A 32 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/t1/*.jpg')); 33 | for i=1:1:length(imgs_name) 34 | img_file_name = imgs_name{1,i}; 35 | temp = imread(strcat('DSIFN/download/Archive/train/t1/', img_file_name)); 36 | c=1; 37 | for j=1:2 38 | for k=1:2 39 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 40 | imwrite(patch, strcat('DSIFN_256/train/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 41 | c=c+1; 42 | end 43 | end 44 | 45 | end 46 | 47 | %Train-B 48 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/t2/*.jpg')); 49 | for i=1:1:length(imgs_name) 50 | img_file_name = imgs_name{1,i}; 51 | temp = imread(strcat('DSIFN/download/Archive/train/t2/', img_file_name)); 52 | c=1; 53 | for j=1:2 54 | for k=1:2 55 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 56 | imwrite(patch, strcat('DSIFN_256/train/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 57 | c=c+1; 58 | end 59 | end 60 | 61 | end 62 | 63 | %Train-label 64 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/mask/*.png')); 65 | for i=1:1:length(imgs_name) 66 | img_file_name = imgs_name{1,i}; 67 | temp = imread(strcat('DSIFN/download/Archive/train/mask/',img_file_name)); 68 | c=1; 69 | for j=1:2 70 | for k=1:2 71 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 72 | imwrite(patch, strcat('DSIFN_256/train/label/', img_file_name(1:end-4), '_', num2str(c), '.png')); 73 | c=c+1; 74 | end 75 | end 76 | 77 | end 78 | 79 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 80 | %test-A 81 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/t1/*.jpg')); 82 | for i=1:1:length(imgs_name) 83 | img_file_name = imgs_name{1,i}; 84 | temp = imread(strcat('DSIFN/download/Archive/test/t1/', img_file_name)); 85 | c=1; 86 | for j=1:2 87 | for k=1:2 88 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 89 | imwrite(patch, strcat('DSIFN_256/test/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 90 | c=c+1; 91 | end 92 | end 93 | 94 | end 95 | 96 | %test-B 97 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/t2/*.jpg')); 98 | for i=1:1:length(imgs_name) 99 | img_file_name = imgs_name{1,i}; 100 | temp = imread(strcat('DSIFN/download/Archive/test/t2/', img_file_name)); 101 | c=1; 102 | for j=1:2 103 | for k=1:2 104 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 105 | imwrite(patch, strcat('DSIFN_256/test/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 106 | c=c+1; 107 | end 108 | end 109 | 110 | end 111 | 112 | %test-label 113 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/mask/*.png')); 114 | for i=1:1:length(imgs_name) 115 | img_file_name = imgs_name{1,i}; 116 | temp = imread(strcat('DSIFN/download/Archive/test/mask/',img_file_name)); 117 | c=1; 118 | for j=1:2 119 | for k=1:2 120 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 121 | imwrite(patch, 
strcat('DSIFN_256/test/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
            c=c+1;
        end
    end

end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%val-A
imgs_name = struct2cell(dir('DSIFN/download/Archive/val/t1/*.jpg'));
for i=1:1:length(imgs_name)
    img_file_name = imgs_name{1,i};
    temp = imread(strcat('DSIFN/download/Archive/val/t1/', img_file_name));
    c=1;
    for j=1:2
        for k=1:2
            patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
            imwrite(patch, strcat('DSIFN_256/val/A/', img_file_name(1:end-4), '_', num2str(c), '.png'));
            c=c+1;
        end
    end

end

%val-B
imgs_name = struct2cell(dir('DSIFN/download/Archive/val/t2/*.jpg'));
for i=1:1:length(imgs_name)
    img_file_name = imgs_name{1,i};
    temp = imread(strcat('DSIFN/download/Archive/val/t2/', img_file_name));
    c=1;
    for j=1:2
        for k=1:2
            patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
            imwrite(patch, strcat('DSIFN_256/val/B/', img_file_name(1:end-4), '_', num2str(c), '.png'));
            c=c+1;
        end
    end

end

%val-label
imgs_name = struct2cell(dir('DSIFN/download/Archive/val/mask/*.png'));
for i=1:1:length(imgs_name)
    img_file_name = imgs_name{1,i};
    temp = imread(strcat('DSIFN/download/Archive/val/mask/', img_file_name));
    c=1;
    for j=1:2
        for k=1:2
            patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
            imwrite(patch, strcat('DSIFN_256/val/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
            c=c+1;
        end
    end

end
--------------------------------------------------------------------------------
/STADE-CDNet/data_preparation/find_mean_std.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from PIL import Image


if __name__ == '__main__':
    filepath = r"/where"  # Dataset directory ('/where' is a placeholder: set it to your image folder)
    pathDir = os.listdir(filepath)  # Images in the dataset directory
    num = len(pathDir)  # Number of images in the directory

    print("Computing mean...")
    data_mean = np.zeros(3)
    for idx in range(len(pathDir)):
        filename = pathDir[idx]
        img = Image.open(os.path.join(filepath, filename)).convert('RGB')
        img = np.array(img) / 255.0
        # Accumulate the per-channel mean over the spatial dimensions (H, W)
        data_mean += np.mean(img, axis=(0, 1))
    data_mean = data_mean / num

    print("Computing std...")
    data_std = np.zeros(3)
    for idx in range(len(pathDir)):
        filename = pathDir[idx]
        img = Image.open(os.path.join(filepath, filename)).convert('RGB')
        img = np.array(img) / 255.0
        # Accumulate the per-channel standard deviation over the spatial dimensions (H, W)
        data_std += np.std(img, axis=(0, 1))

    data_std = data_std / num
    print("mean:{}".format(data_mean))
    print("std:{}".format(data_std))
--------------------------------------------------------------------------------
/STADE-CDNet/data_preparation/levir_cd_256.m:
--------------------------------------------------------------------------------
%Dataset preparation code for the LEVIR-CD dataset (MATLAB)
%Download the LEVIR-CD dataset here: https://www.dropbox.com/s/h9jl2ygznsaeg5d/LEVIR-CD-256.zip
%This code generates the 256x256 image patches required for the train/val/test splits.
%Please create folders according to the following format.
5 | %DSIFN-CD-256 6 | %------(train) 7 | % |---> A 8 | % |---> B 9 | % |---> label 10 | %------(val) 11 | % |---> A 12 | % |---> B 13 | % |---> label 14 | %------(test) 15 | % |---> A 16 | % |---> B 17 | % |---> label 18 | %Then run this code 19 | %Then copy all images in train-A, val-A, test-A to a folder name A 20 | %Then copy all images in train-B, val-B, test-B to a folder name B 21 | %Then copy all images in train-label, val-label, test-label to a folder name label 22 | 23 | clear all; 24 | close all; 25 | clc; 26 | 27 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 28 | %Train-A 29 | imgs_name = struct2cell(dir('LEVIR-CD/train/A/*.png')); 30 | for i=1:1:length(imgs_name) 31 | img_file_name = imgs_name{1,i}; 32 | temp = imread(strcat('LEVIR-CD/train/A/',img_file_name)); 33 | c=1; 34 | for j=1:4 35 | for k=1:4 36 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 37 | imwrite(patch, strcat('LEVIR-CD256/train/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 38 | c=c+1; 39 | end 40 | end 41 | 42 | end 43 | 44 | %Train-B 45 | imgs_name = struct2cell(dir('LEVIR-CD/train/B/*.png')); 46 | for i=1:1:length(imgs_name) 47 | img_file_name = imgs_name{1,i}; 48 | temp = imread(strcat('LEVIR-CD/train/B/',img_file_name)); 49 | c=1; 50 | for j=1:4 51 | for k=1:4 52 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 53 | imwrite(patch, strcat('LEVIR-CD256/train/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 54 | c=c+1; 55 | end 56 | end 57 | 58 | end 59 | 60 | %Train-label 61 | imgs_name = struct2cell(dir('LEVIR-CD/train/label/*.png')); 62 | for i=1:1:length(imgs_name) 63 | img_file_name = imgs_name{1,i}; 64 | temp = imread(strcat('LEVIR-CD/train/label/',img_file_name)); 65 | c=1; 66 | for j=1:4 67 | for k=1:4 68 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 69 | imwrite(patch, strcat('LEVIR-CD256/train/label/', img_file_name(1:end-4), '_', num2str(c), '.png')); 70 | c=c+1; 71 | end 72 | end 73 | 74 | end 75 | 76 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 77 | %Test-A 78 | imgs_name = struct2cell(dir('LEVIR-CD/test/A/*.png')); 79 | for i=1:1:length(imgs_name) 80 | img_file_name = imgs_name{1,i}; 81 | temp = imread(strcat('LEVIR-CD/test/A/',img_file_name)); 82 | c=1; 83 | for j=1:4 84 | for k=1:4 85 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 86 | imwrite(patch, strcat('LEVIR-CD256/test/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 87 | c=c+1; 88 | end 89 | end 90 | 91 | end 92 | 93 | %Test-B 94 | imgs_name = struct2cell(dir('LEVIR-CD/test/B/*.png')); 95 | for i=1:1:length(imgs_name) 96 | img_file_name = imgs_name{1,i}; 97 | temp = imread(strcat('LEVIR-CD/test/B/',img_file_name)); 98 | c=1; 99 | for j=1:4 100 | for k=1:4 101 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 102 | imwrite(patch, strcat('LEVIR-CD256/test/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 103 | c=c+1; 104 | end 105 | end 106 | 107 | end 108 | 109 | %Test-label 110 | imgs_name = struct2cell(dir('LEVIR-CD/test/label/*.png')); 111 | for i=1:1:length(imgs_name) 112 | img_file_name = imgs_name{1,i}; 113 | temp = imread(strcat('LEVIR-CD/test/label/',img_file_name)); 114 | c=1; 115 | for j=1:4 116 | for k=1:4 117 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 118 | imwrite(patch, strcat('LEVIR-CD256/test/label/', img_file_name(1:end-4), '_', num2str(c), '.png')); 119 | c=c+1; 120 | end 121 | end 122 | 123 | end 124 | 125 | 126 | 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 127 | %val-A 128 | imgs_name = struct2cell(dir('LEVIR-CD/val/A/*.png')); 129 | for i=1:1:length(imgs_name) 130 | img_file_name = imgs_name{1,i}; 131 | temp = imread(strcat('LEVIR-CD/val/A/',img_file_name)); 132 | c=1; 133 | for j=1:4 134 | for k=1:4 135 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 136 | imwrite(patch, strcat('LEVIR-CD256/val/A/', img_file_name(1:end-4), '_', num2str(c), '.png')); 137 | c=c+1; 138 | end 139 | end 140 | 141 | end 142 | 143 | %val-B 144 | imgs_name = struct2cell(dir('LEVIR-CD/val/B/*.png')); 145 | for i=1:1:length(imgs_name) 146 | img_file_name = imgs_name{1,i}; 147 | temp = imread(strcat('LEVIR-CD/val/B/',img_file_name)); 148 | c=1; 149 | for j=1:4 150 | for k=1:4 151 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 152 | imwrite(patch, strcat('LEVIR-CD256/val/B/', img_file_name(1:end-4), '_', num2str(c), '.png')); 153 | c=c+1; 154 | end 155 | end 156 | 157 | end 158 | 159 | %val-label 160 | imgs_name = struct2cell(dir('LEVIR-CD/val/label/*.png')); 161 | for i=1:1:length(imgs_name) 162 | img_file_name = imgs_name{1,i}; 163 | temp = imread(strcat('LEVIR-CD/val/label/',img_file_name)); 164 | c=1; 165 | for j=1:4 166 | for k=1:4 167 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :); 168 | imwrite(patch, strcat('LEVIR-CD256/val/label/', img_file_name(1:end-4), '_', num2str(c), '.png')); 169 | c=c+1; 170 | end 171 | end 172 | 173 | end 174 | -------------------------------------------------------------------------------- /STADE-CDNet/datasets/CD_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | change detection data 3 | """ 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | 9 | from torch.utils import data 10 | 11 | from datasets.data_utils import CDDataAugmentation 12 | 13 | 14 | """ 15 | CD data set with pixel-level labels; 16 | ├─image 17 | ├─image_post 18 | ├─label 19 | └─list 20 | """ 21 | IMG_FOLDER_NAME = "A" 22 | IMG_POST_FOLDER_NAME = 'B' 23 | LIST_FOLDER_NAME = 'list' 24 | ANNOT_FOLDER_NAME = "label" 25 | 26 | IGNORE = 255 27 | 28 | label_suffix='.png' # jpg for gan dataset, others : png 29 | 30 | def load_img_name_list(dataset_path): 31 | img_name_list = np.loadtxt(dataset_path, dtype=np.str) 32 | if img_name_list.ndim == 2: 33 | return img_name_list[:, 0] 34 | return img_name_list 35 | 36 | 37 | def load_image_label_list_from_npy(npy_path, img_name_list): 38 | cls_labels_dict = np.load(npy_path, allow_pickle=True).item() 39 | return [cls_labels_dict[img_name] for img_name in img_name_list] 40 | 41 | 42 | def get_img_post_path(root_dir,img_name): 43 | return os.path.join(root_dir, IMG_POST_FOLDER_NAME, img_name) 44 | 45 | 46 | def get_img_path(root_dir, img_name): 47 | return os.path.join(root_dir, IMG_FOLDER_NAME, img_name) 48 | 49 | 50 | def get_label_path(root_dir, img_name): 51 | return os.path.join(root_dir, ANNOT_FOLDER_NAME, img_name.replace('.jpg', label_suffix)) 52 | 53 | 54 | class ImageDataset(data.Dataset): 55 | """VOCdataloder""" 56 | def __init__(self, root_dir, split='train', img_size=256, is_train=True,to_tensor=True): 57 | super(ImageDataset, self).__init__() 58 | self.root_dir = root_dir 59 | self.img_size = img_size 60 | self.split = split #train | train_aug | val 61 | # self.list_path = self.root_dir + '/' + LIST_FOLDER_NAME + '/' + self.list + '.txt' 62 | self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split+'.txt') 63 | 
self.img_name_list = load_img_name_list(self.list_path) 64 | 65 | self.A_size = len(self.img_name_list) # get the size of dataset A 66 | self.to_tensor = to_tensor 67 | if is_train: 68 | self.augm = CDDataAugmentation( 69 | img_size=self.img_size, 70 | with_random_hflip=True, 71 | with_random_vflip=True, 72 | with_scale_random_crop=True, 73 | with_random_blur=True, 74 | random_color_tf=True 75 | ) 76 | else: 77 | self.augm = CDDataAugmentation( 78 | img_size=self.img_size 79 | ) 80 | def __getitem__(self, index): 81 | name = self.img_name_list[index] 82 | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.A_size]) 83 | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.A_size]) 84 | 85 | img = np.asarray(Image.open(A_path).convert('RGB')) 86 | img_B = np.asarray(Image.open(B_path).convert('RGB')) 87 | 88 | [img, img_B], _ = self.augm.transform([img, img_B],[], to_tensor=self.to_tensor) 89 | 90 | return {'A': img, 'B': img_B, 'name': name} 91 | 92 | def __len__(self): 93 | """Return the total number of images in the dataset.""" 94 | return self.A_size 95 | 96 | 97 | class CDDataset(ImageDataset): 98 | 99 | def __init__(self, root_dir, img_size, split='train', is_train=True, label_transform=None, 100 | to_tensor=True): 101 | super(CDDataset, self).__init__(root_dir, img_size=img_size, split=split, is_train=is_train, 102 | to_tensor=to_tensor) 103 | self.label_transform = label_transform 104 | 105 | def __getitem__(self, index): 106 | name = self.img_name_list[index] 107 | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.A_size]) 108 | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.A_size]) 109 | img = np.asarray(Image.open(A_path).convert('RGB')) 110 | img_B = np.asarray(Image.open(B_path).convert('RGB')) 111 | L_path = get_label_path(self.root_dir, self.img_name_list[index % self.A_size]) 112 | 113 | label = np.array(Image.open(L_path), dtype=np.uint8) 114 | # if you are getting error because of dim mismatch ad [:,:,0] at the end 115 | 116 | # 二分类中,前景标注为255 117 | if self.label_transform == 'norm': 118 | label = label // 255 119 | 120 | [img, img_B], [label] = self.augm.transform([img, img_B], [label], to_tensor=self.to_tensor) 121 | # print(label.max()) 122 | 123 | return {'name': name, 'A': img, 'B': img_B, 'L': label} 124 | 125 | -------------------------------------------------------------------------------- /STADE-CDNet/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | from PIL import Image 5 | from PIL import ImageFilter 6 | 7 | import torchvision.transforms.functional as TF 8 | from torchvision import transforms 9 | import torch 10 | 11 | 12 | def to_tensor_and_norm(imgs, labels): 13 | # to tensor 14 | imgs = [TF.to_tensor(img) for img in imgs] 15 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0) 16 | for img in labels] 17 | 18 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 19 | for img in imgs] 20 | return imgs, labels 21 | 22 | 23 | class CDDataAugmentation: 24 | 25 | def __init__( 26 | self, 27 | img_size, 28 | with_random_hflip=False, 29 | with_random_vflip=False, 30 | with_random_rot=False, 31 | with_random_crop=False, 32 | with_scale_random_crop=False, 33 | with_random_blur=False, 34 | random_color_tf=False 35 | ): 36 | self.img_size = img_size 37 | if self.img_size is None: 38 | self.img_size_dynamic = True 39 | else: 40 | self.img_size_dynamic 
= False 41 | self.with_random_hflip = with_random_hflip 42 | self.with_random_vflip = with_random_vflip 43 | self.with_random_rot = with_random_rot 44 | self.with_random_crop = with_random_crop 45 | self.with_scale_random_crop = with_scale_random_crop 46 | self.with_random_blur = with_random_blur 47 | self.random_color_tf=random_color_tf 48 | def transform(self, imgs, labels, to_tensor=True): 49 | """ 50 | :param imgs: [ndarray,] 51 | :param labels: [ndarray,] 52 | :return: [ndarray,],[ndarray,] 53 | """ 54 | # resize image and covert to tensor 55 | imgs = [TF.to_pil_image(img) for img in imgs] 56 | if self.img_size is None: 57 | self.img_size = None 58 | 59 | if not self.img_size_dynamic: 60 | if imgs[0].size != (self.img_size, self.img_size): 61 | imgs = [TF.resize(img, [self.img_size, self.img_size], interpolation=3) 62 | for img in imgs] 63 | else: 64 | self.img_size = imgs[0].size[0] 65 | 66 | labels = [TF.to_pil_image(img) for img in labels] 67 | if len(labels) != 0: 68 | if labels[0].size != (self.img_size, self.img_size): 69 | labels = [TF.resize(img, [self.img_size, self.img_size], interpolation=0) 70 | for img in labels] 71 | 72 | random_base = 0.5 73 | if self.with_random_hflip and random.random() > 0.5: 74 | imgs = [TF.hflip(img) for img in imgs] 75 | labels = [TF.hflip(img) for img in labels] 76 | 77 | if self.with_random_vflip and random.random() > 0.5: 78 | imgs = [TF.vflip(img) for img in imgs] 79 | labels = [TF.vflip(img) for img in labels] 80 | 81 | if self.with_random_rot and random.random() > random_base: 82 | angles = [90, 180, 270] 83 | index = random.randint(0, 2) 84 | angle = angles[index] 85 | imgs = [TF.rotate(img, angle) for img in imgs] 86 | labels = [TF.rotate(img, angle) for img in labels] 87 | 88 | if self.with_random_crop and random.random() > 0: 89 | i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). 
\ 90 | get_params(img=imgs[0], scale=(0.8, 1.2), ratio=(1, 1)) 91 | 92 | imgs = [TF.resized_crop(img, i, j, h, w, 93 | size=(self.img_size, self.img_size), 94 | interpolation=Image.CUBIC) 95 | for img in imgs] 96 | 97 | labels = [TF.resized_crop(img, i, j, h, w, 98 | size=(self.img_size, self.img_size), 99 | interpolation=Image.NEAREST) 100 | for img in labels] 101 | 102 | if self.with_scale_random_crop: 103 | # rescale 104 | scale_range = [1, 1.2] 105 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 106 | 107 | imgs = [pil_rescale(img, target_scale, order=3) for img in imgs] 108 | labels = [pil_rescale(img, target_scale, order=0) for img in labels] 109 | # crop 110 | imgsize = imgs[0].size # h, w 111 | box = get_random_crop_box(imgsize=imgsize, cropsize=self.img_size) 112 | imgs = [pil_crop(img, box, cropsize=self.img_size, default_value=0) 113 | for img in imgs] 114 | labels = [pil_crop(img, box, cropsize=self.img_size, default_value=255) 115 | for img in labels] 116 | 117 | if self.with_random_blur and random.random() > 0: 118 | radius = random.random() 119 | imgs = [img.filter(ImageFilter.GaussianBlur(radius=radius)) 120 | for img in imgs] 121 | 122 | if self.random_color_tf: 123 | color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3) 124 | imgs_tf = [] 125 | for img in imgs: 126 | tf = transforms.ColorJitter( 127 | color_jitter.brightness, 128 | color_jitter.contrast, 129 | color_jitter.saturation, 130 | color_jitter.hue) 131 | imgs_tf.append(tf(img)) 132 | imgs = imgs_tf 133 | 134 | if to_tensor: 135 | # to tensor 136 | imgs = [TF.to_tensor(img) for img in imgs] 137 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0) 138 | for img in labels] 139 | 140 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 141 | for img in imgs] 142 | 143 | return imgs, labels 144 | 145 | 146 | def pil_crop(image, box, cropsize, default_value): 147 | assert isinstance(image, Image.Image) 148 | img = np.array(image) 149 | 150 | if len(img.shape) == 3: 151 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 152 | else: 153 | cont = np.ones((cropsize, cropsize), img.dtype)*default_value 154 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 155 | 156 | return Image.fromarray(cont) 157 | 158 | 159 | def get_random_crop_box(imgsize, cropsize): 160 | h, w = imgsize 161 | ch = min(cropsize, h) 162 | cw = min(cropsize, w) 163 | 164 | w_space = w - cropsize 165 | h_space = h - cropsize 166 | 167 | if w_space > 0: 168 | cont_left = 0 169 | img_left = random.randrange(w_space + 1) 170 | else: 171 | cont_left = random.randrange(-w_space + 1) 172 | img_left = 0 173 | 174 | if h_space > 0: 175 | cont_top = 0 176 | img_top = random.randrange(h_space + 1) 177 | else: 178 | cont_top = random.randrange(-h_space + 1) 179 | img_top = 0 180 | 181 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 182 | 183 | 184 | def pil_rescale(img, scale, order): 185 | assert isinstance(img, Image.Image) 186 | height, width = img.size 187 | target_size = (int(np.round(height*scale)), int(np.round(width*scale))) 188 | return pil_resize(img, target_size, order) 189 | 190 | 191 | def pil_resize(img, size, order): 192 | assert isinstance(img, Image.Image) 193 | if size[0] == img.size[0] and size[1] == img.size[1]: 194 | return img 195 | if order == 3: 196 | resample = Image.BICUBIC 197 | elif order == 0: 198 | resample = Image.NEAREST 
199 | return img.resize(size[::-1], resample) 200 | -------------------------------------------------------------------------------- /STADE-CDNet/eval_zl.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | from models.evaluator import * 4 | 5 | print(torch.cuda.is_available()) 6 | 7 | 8 | """ 9 | eval the CD model 10 | """ 11 | 12 | def main(): 13 | # ------------ 14 | # args 15 | # ------------ 16 | parser = ArgumentParser() 17 | parser.add_argument('--gpu_ids', type=str, default="your need", help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 18 | parser.add_argument('--project_name', default='test', type=str) 19 | parser.add_argument('--print_models', default=False, type=bool, help='print models') 20 | parser.add_argument('--checkpoints_root', default='checkpoints', type=str) 21 | parser.add_argument('--vis_root', default='vis', type=str) 22 | 23 | # data 24 | parser.add_argument('--num_workers', default="your need", type=int) 25 | parser.add_argument('--dataset', default='CDDataset', type=str) 26 | parser.add_argument('--data_name', default='LEVIR', type=str) 27 | 28 | parser.add_argument('--batch_size', default=1, type=int) 29 | parser.add_argument('--split', default="test", type=str) 30 | 31 | parser.add_argument('--img_size', default="your data need", type=int) 32 | 33 | # model 34 | parser.add_argument('--n_class', default=2, type=int) 35 | parser.add_argument('--embed_dim', default="your need", type=int) 36 | parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8_dedim8', type=str, 37 | help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|') 38 | 39 | parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str) 40 | 41 | args = parser.parse_args() 42 | utils.get_device(args) 43 | print(args.gpu_ids) 44 | 45 | # checkpoints dir 46 | args.checkpoint_dir = os.path.join(args.checkpoints_root, args.project_name) 47 | os.makedirs(args.checkpoint_dir, exist_ok=True) 48 | # visualize dir 49 | args.vis_dir = os.path.join(args.vis_root, args.project_name) 50 | os.makedirs(args.vis_dir, exist_ok=True) 51 | 52 | dataloader = utils.get_loader(args.data_name, img_size=args.img_size, 53 | batch_size=args.batch_size, is_train=False, 54 | split=args.split) 55 | model = CDEvaluator(args=args, dataloader=dataloader) 56 | 57 | model.eval_models(checkpoint_name=args.checkpoint_name) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | 63 | -------------------------------------------------------------------------------- /STADE-CDNet/main_zl.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | from models.trainer import * 4 | 5 | print(torch.cuda.is_available()) 6 | 7 | """ 8 | the main function for training the CD networks 9 | """ 10 | 11 | 12 | def train(args): 13 | dataloaders = utils.get_loaders(args) 14 | model = CDTrainer(args=args, dataloaders=dataloaders) 15 | model.train_models() 16 | 17 | 18 | def test(args): 19 | from models.evaluator import CDEvaluator 20 | dataloader = utils.get_loader(args.data_name, img_size=args.img_size, 21 | batch_size=args.batch_size, is_train=False, 22 | split='test') 23 | model = CDEvaluator(args=args, dataloader=dataloader) 24 | 25 | model.eval_models() 26 | 27 | 28 | if __name__ == '__main__': 29 | # ------------ 30 | # args 31 | # ------------ 32 | parser = ArgumentParser() 33 | parser.add_argument('--gpu_ids', 
type=str, default='your need', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 34 | parser.add_argument('--project_name', default='STADE-CD', type=str) 35 | parser.add_argument('--checkpoint_root', default='where', type=str) 36 | parser.add_argument('--vis_root', default='where', type=str) 37 | parser.add_argument('--output_folder', default='samples_LEVIR/predict_STADE-CD', type=str) 38 | # data 39 | parser.add_argument('--num_workers', default=2, type=int) 40 | parser.add_argument('--dataset', default='CDDataset', type=str) 41 | parser.add_argument('--data_name', default='DSIFN', type=str) 42 | 43 | parser.add_argument('--batch_size', default="your need", type=int,help='The parameters I set = 8') # 44 | parser.add_argument('--split', default="train", type=str) 45 | parser.add_argument('--split_val', default="val", type=str) 46 | 47 | parser.add_argument('--img_size', default=256, type=int) 48 | parser.add_argument('--shuffle_AB', default=False, type=str) 49 | 50 | # model 51 | parser.add_argument('--n_class', default=2, type=int) 52 | parser.add_argument('--embed_dim', default=64, type=int) 53 | parser.add_argument('--pretrain', default=None, type=str) 54 | parser.add_argument('--multi_scale_train', default=False, type=str) 55 | parser.add_argument('--multi_scale_infer', default=False, type=str) 56 | parser.add_argument('--multi_pred_weights', nargs = '+', type = float, default = "your need",) 57 | 58 | parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8', type=str, 59 | help='base_resnet18 | base_transformer_pos_s4 | ' 60 | 'base_transformer_pos_s4_dd8 | ' 61 | 'base_transformer_pos_s4_dd8_dedim8|ChangeFormerV5|SiamUnet_diff') 62 | parser.add_argument('--loss', default='ce', type=str) 63 | 64 | # optimizer 65 | parser.add_argument('--optimizer', default='adamw', type=str) 66 | parser.add_argument('--lr', default="your need", type=float,help='The parameters I set = 0.00009567') 67 | parser.add_argument('--max_epochs', default=406, type=int) 68 | parser.add_argument('--lr_policy', default='linear', type=str, 69 | help='linear | step') 70 | parser.add_argument('--lr_decay_iters', default=100, type=int) 71 | 72 | args = parser.parse_args() 73 | utils.get_device(args) 74 | print(args.gpu_ids) 75 | 76 | # checkpoints dir 77 | args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name) 78 | os.makedirs(args.checkpoint_dir, exist_ok=True) 79 | # visualize dir 80 | args.vis_dir = os.path.join(args.vis_root, args.project_name) 81 | os.makedirs(args.vis_dir, exist_ok=True) 82 | 83 | train(args) 84 | 85 | test(args) 86 | -------------------------------------------------------------------------------- /STADE-CDNet/misc/imutils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | from PIL import ImageFilter 6 | import PIL 7 | import tifffile 8 | 9 | 10 | def cv_rotate(image, angle, borderValue): 11 | """ 12 | rot angle, fill with borderValue 13 | """ 14 | # grab the dimensions of the image and then determine the 15 | # center 16 | (h, w) = image.shape[:2] 17 | (cX, cY) = (w // 2, h // 2) 18 | 19 | # grab the rotation matrix (applying the negative of the 20 | # angle to rotate clockwise), then grab the sine and cosine 21 | # (i.e., the rotation components of the matrix) 22 | # -angle位置参数为角度参数负值表示顺时针旋转; 1.0位置参数scale是调整尺寸比例(图像缩放参数),建议0.75 23 | M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0) 24 | cos = np.abs(M[0, 0]) 25 | sin = np.abs(M[0, 1]) 26 | 27 
| # compute the new bounding dimensions of the image 28 | nW = int((h * sin) + (w * cos)) 29 | nH = int((h * cos) + (w * sin)) 30 | 31 | # adjust the rotation matrix to take into account translation 32 | M[0, 2] += (nW / 2) - cX 33 | M[1, 2] += (nH / 2) - cY 34 | if isinstance(borderValue, int): 35 | values = (borderValue, borderValue, borderValue) 36 | else: 37 | values = borderValue 38 | # perform the actual rotation and return the image 39 | return cv2.warpAffine(image, M, (nW, nH), borderValue=values) 40 | 41 | 42 | def pil_resize(img, size, order): 43 | if size[0] == img.shape[0] and size[1] == img.shape[1]: 44 | return img 45 | 46 | if order == 3: 47 | resample = Image.BICUBIC 48 | elif order == 0: 49 | resample = Image.NEAREST 50 | 51 | return np.asarray(Image.fromarray(img).resize(size[::-1], resample)) 52 | 53 | 54 | def pil_rescale(img, scale, order): 55 | height, width = img.shape[:2] 56 | target_size = (int(np.round(height*scale)), int(np.round(width*scale))) 57 | return pil_resize(img, target_size, order) 58 | 59 | 60 | def pil_rotate(img, degree, default_value): 61 | if isinstance(default_value, tuple): 62 | values = (default_value[0], default_value[1], default_value[2], 0) 63 | else: 64 | values = (default_value, default_value, default_value,0) 65 | img = Image.fromarray(img) 66 | if img.mode =='RGB': 67 | # set img padding == default_value 68 | img2 = img.convert('RGBA') 69 | rot = img2.rotate(degree, expand=1) 70 | fff = Image.new('RGBA', rot.size, values) # 灰色 71 | out = Image.composite(rot, fff, rot) 72 | img = out.convert(img.mode) 73 | 74 | else: 75 | # set label padding == default_value 76 | img2 = img.convert('RGBA') 77 | rot = img2.rotate(degree, expand=1) 78 | # a white image same size as rotated image 79 | fff = Image.new('RGBA', rot.size, values) 80 | # create a composite image using the alpha layer of rot as a mask 81 | out = Image.composite(rot, fff, rot) 82 | img = out.convert(img.mode) 83 | 84 | return np.asarray(img) 85 | 86 | 87 | def random_resize_long_image_list(img_list, min_long, max_long): 88 | target_long = random.randint(min_long, max_long) 89 | h, w = img_list[0].shape[:2] 90 | if w < h: 91 | scale = target_long / h 92 | else: 93 | scale = target_long / w 94 | out = [] 95 | for img in img_list: 96 | out.append(pil_rescale(img, scale, 3) ) 97 | return out 98 | 99 | 100 | def random_resize_long(img, min_long, max_long): 101 | target_long = random.randint(min_long, max_long) 102 | h, w = img.shape[:2] 103 | 104 | if w < h: 105 | scale = target_long / h 106 | else: 107 | scale = target_long / w 108 | 109 | return pil_rescale(img, scale, 3) 110 | 111 | 112 | def random_scale_list(img_list, scale_range, order): 113 | """ 114 | 输入:图像列表 115 | """ 116 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 117 | 118 | if isinstance(img_list, tuple): 119 | assert img_list.__len__() == 2 120 | img1 = [] 121 | img2 = [] 122 | for img in img_list[0]: 123 | img1.append(pil_rescale(img, target_scale, order[0])) 124 | for img in img_list[1]: 125 | img2.append(pil_rescale(img, target_scale, order[1])) 126 | return (img1, img2) 127 | else: 128 | out = [] 129 | for img in img_list: 130 | out.append(pil_rescale(img, target_scale, order)) 131 | return out 132 | 133 | 134 | def random_scale(img, scale_range, order): 135 | 136 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 137 | 138 | if isinstance(img, tuple): 139 | return (pil_rescale(img[0], target_scale, order[0]), pil_rescale(img[1], target_scale, 
order[1])) 140 | else: 141 | return pil_rescale(img, target_scale, order) 142 | 143 | 144 | def random_rotate_list(img_list, max_degree, default_values): 145 | degree = random.random() * max_degree 146 | if isinstance(img_list, tuple): 147 | assert img_list.__len__() == 2 148 | img1 = [] 149 | img2 = [] 150 | for img in img_list[0]: 151 | assert isinstance(img, np.ndarray) 152 | img1.append((pil_rotate(img, degree, default_values[0]))) 153 | for img in img_list[1]: 154 | img2.append((pil_rotate(img, degree, default_values[1]))) 155 | return (img1, img2) 156 | else: 157 | out = [] 158 | for img in img_list: 159 | out.append(pil_rotate(img, degree, default_values)) 160 | return out 161 | 162 | 163 | def random_rotate(img, max_degree, default_values): 164 | degree = random.random() * max_degree 165 | if isinstance(img, tuple): 166 | return (pil_rotate(img[0], degree, default_values[0]), 167 | pil_rotate(img[1], degree, default_values[1])) 168 | else: 169 | return pil_rotate(img, degree, default_values) 170 | 171 | 172 | def random_lr_flip_list(img_list): 173 | 174 | if bool(random.getrandbits(1)): 175 | if isinstance(img_list, tuple): 176 | assert img_list.__len__()==2 177 | img1=list((np.fliplr(m) for m in img_list[0])) 178 | img2=list((np.fliplr(m) for m in img_list[1])) 179 | 180 | return (img1, img2) 181 | else: 182 | return list([np.fliplr(m) for m in img_list]) 183 | else: 184 | return img_list 185 | 186 | 187 | def random_lr_flip(img): 188 | 189 | if bool(random.getrandbits(1)): 190 | if isinstance(img, tuple): 191 | return tuple([np.fliplr(m) for m in img]) 192 | else: 193 | return np.fliplr(img) 194 | else: 195 | return img 196 | 197 | 198 | def get_random_crop_box(imgsize, cropsize): 199 | h, w = imgsize 200 | 201 | ch = min(cropsize, h) 202 | cw = min(cropsize, w) 203 | 204 | w_space = w - cropsize 205 | h_space = h - cropsize 206 | 207 | if w_space > 0: 208 | cont_left = 0 209 | img_left = random.randrange(w_space + 1) 210 | else: 211 | cont_left = random.randrange(-w_space + 1) 212 | img_left = 0 213 | 214 | if h_space > 0: 215 | cont_top = 0 216 | img_top = random.randrange(h_space + 1) 217 | else: 218 | cont_top = random.randrange(-h_space + 1) 219 | img_top = 0 220 | 221 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 222 | 223 | 224 | def random_crop_list(images_list, cropsize, default_values): 225 | 226 | if isinstance(images_list, tuple): 227 | imgsize = images_list[0][0].shape[:2] 228 | elif isinstance(images_list, list): 229 | imgsize = images_list[0].shape[:2] 230 | else: 231 | raise RuntimeError('do not support the type of image_list') 232 | if isinstance(default_values, int): default_values = (default_values,) 233 | 234 | box = get_random_crop_box(imgsize, cropsize) 235 | if isinstance(images_list, tuple): 236 | assert images_list.__len__()==2 237 | img1 = [] 238 | img2 = [] 239 | for img in images_list[0]: 240 | f = default_values[0] 241 | if len(img.shape) == 3: 242 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 243 | else: 244 | cont = np.ones((cropsize, cropsize), img.dtype)*f 245 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 246 | img1.append(cont) 247 | for img in images_list[1]: 248 | f = default_values[1] 249 | if len(img.shape) == 3: 250 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 251 | else: 252 | cont = np.ones((cropsize, cropsize), img.dtype)*f 253 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 254 | 
img2.append(cont) 255 | return (img1, img2) 256 | else: 257 | out = [] 258 | for img in images_list: 259 | f = default_values 260 | if len(img.shape) == 3: 261 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype) * f 262 | else: 263 | cont = np.ones((cropsize, cropsize), img.dtype) * f 264 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 265 | out.append(cont) 266 | return out 267 | 268 | 269 | def random_crop(images, cropsize, default_values): 270 | 271 | if isinstance(images, np.ndarray): images = (images,) 272 | if isinstance(default_values, int): default_values = (default_values,) 273 | 274 | imgsize = images[0].shape[:2] 275 | box = get_random_crop_box(imgsize, cropsize) 276 | 277 | new_images = [] 278 | for img, f in zip(images, default_values): 279 | 280 | if len(img.shape) == 3: 281 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f 282 | else: 283 | cont = np.ones((cropsize, cropsize), img.dtype)*f 284 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 285 | new_images.append(cont) 286 | 287 | if len(new_images) == 1: 288 | new_images = new_images[0] 289 | 290 | return new_images 291 | 292 | 293 | def top_left_crop(img, cropsize, default_value): 294 | 295 | h, w = img.shape[:2] 296 | 297 | ch = min(cropsize, h) 298 | cw = min(cropsize, w) 299 | 300 | if len(img.shape) == 2: 301 | container = np.ones((cropsize, cropsize), img.dtype)*default_value 302 | else: 303 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 304 | 305 | container[:ch, :cw] = img[:ch, :cw] 306 | 307 | return container 308 | 309 | 310 | def center_crop(img, cropsize, default_value=0): 311 | 312 | h, w = img.shape[:2] 313 | 314 | ch = min(cropsize, h) 315 | cw = min(cropsize, w) 316 | 317 | sh = h - cropsize 318 | sw = w - cropsize 319 | 320 | if sw > 0: 321 | cont_left = 0 322 | img_left = int(round(sw / 2)) 323 | else: 324 | cont_left = int(round(-sw / 2)) 325 | img_left = 0 326 | 327 | if sh > 0: 328 | cont_top = 0 329 | img_top = int(round(sh / 2)) 330 | else: 331 | cont_top = int(round(-sh / 2)) 332 | img_top = 0 333 | 334 | if len(img.shape) == 2: 335 | container = np.ones((cropsize, cropsize), img.dtype)*default_value 336 | else: 337 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 338 | 339 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 340 | img[img_top:img_top+ch, img_left:img_left+cw] 341 | 342 | return container 343 | 344 | 345 | def HWC_to_CHW(img): 346 | return np.transpose(img, (2, 0, 1)) 347 | 348 | 349 | def pil_blur(img, radius): 350 | return np.array(Image.fromarray(img).filter(ImageFilter.GaussianBlur(radius=radius))) 351 | 352 | 353 | def random_blur(img): 354 | radius = random.random() 355 | # print('add blur: ', radius) 356 | if isinstance(img, list): 357 | out = [] 358 | for im in img: 359 | out.append(pil_blur(im, radius)) 360 | return out 361 | elif isinstance(img, np.ndarray): 362 | return pil_blur(img, radius) 363 | else: 364 | print(img) 365 | raise RuntimeError("do not support the input image type!") 366 | 367 | 368 | def save_image(image_numpy, image_path): 369 | """Save a numpy image to the disk 370 | Parameters: 371 | image_numpy (numpy array) -- input numpy array 372 | image_path (str) -- the path of the image 373 | """ 374 | image_pil = Image.fromarray(np.array(image_numpy,dtype=np.uint8)) 375 | image_pil.save(image_path) 376 | 377 | 378 | def im2arr(img_path, mode=1, dtype=np.uint8): 379 | """ 380 | :param img_path: 381 | :param 
mode: 382 | :return: numpy.ndarray, shape: H*W*C 383 | """ 384 | if mode==1: 385 | img = PIL.Image.open(img_path) 386 | arr = np.asarray(img, dtype=dtype) 387 | else: 388 | arr = tifffile.imread(img_path) 389 | if arr.ndim == 3: 390 | a, b, c = arr.shape 391 | if a < b and a < c: # 当arr为C*H*W时,需要交换通道顺序 392 | arr = arr.transpose([1,2,0]) 393 | # print('shape: ', arr.shape, 'dytpe: ',arr.dtype) 394 | return arr 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | -------------------------------------------------------------------------------- /STADE-CDNet/misc/logger_tool.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, outfile): 7 | self.terminal = sys.stdout 8 | self.log_path = outfile 9 | now = time.strftime("%c") 10 | self.write('================ (%s) ================\n' % now) 11 | 12 | def write(self, message): 13 | self.terminal.write(message) 14 | with open(self.log_path, mode='a') as f: 15 | f.write(message) 16 | 17 | def write_dict(self, dict): 18 | message = '' 19 | for k, v in dict.items(): 20 | message += '%s: %.7f ' % (k, v) 21 | self.write(message) 22 | 23 | def write_dict_str(self, dict): 24 | message = '' 25 | for k, v in dict.items(): 26 | message += '%s: %s ' % (k, v) 27 | self.write(message) 28 | 29 | def flush(self): 30 | self.terminal.flush() 31 | 32 | 33 | class Timer: 34 | def __init__(self, starting_msg = None): 35 | self.start = time.time() 36 | self.stage_start = self.start 37 | 38 | if starting_msg is not None: 39 | print(starting_msg, time.ctime(time.time())) 40 | 41 | def __enter__(self): 42 | return self 43 | 44 | def __exit__(self, exc_type, exc_val, exc_tb): 45 | return 46 | 47 | def update_progress(self, progress): 48 | self.elapsed = time.time() - self.start 49 | self.est_total = self.elapsed / progress 50 | self.est_remaining = self.est_total - self.elapsed 51 | self.est_finish = int(self.start + self.est_total) 52 | 53 | 54 | def str_estimated_complete(self): 55 | return str(time.ctime(self.est_finish)) 56 | 57 | def str_estimated_remaining(self): 58 | return str(self.est_remaining/3600) + 'h' 59 | 60 | def estimated_remaining(self): 61 | return self.est_remaining/3600 62 | 63 | def get_stage_elapsed(self): 64 | return time.time() - self.stage_start 65 | 66 | def reset_stage(self): 67 | self.stage_start = time.time() 68 | 69 | def lapse(self): 70 | out = time.time() - self.stage_start 71 | self.stage_start = time.time() 72 | return out 73 | 74 | -------------------------------------------------------------------------------- /STADE-CDNet/misc/metric_tool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ################### metrics ################### 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.initialized = False 9 | self.val = None 10 | self.avg = None 11 | self.sum = None 12 | self.count = None 13 | 14 | def initialize(self, val, weight): 15 | self.val = val 16 | self.avg = val 17 | self.sum = val * weight 18 | self.count = weight 19 | self.initialized = True 20 | 21 | def update(self, val, weight=1): 22 | if not self.initialized: 23 | self.initialize(val, weight) 24 | else: 25 | self.add(val, weight) 26 | 27 | def add(self, val, weight): 28 | self.val = val 29 | self.sum += val * weight 30 | self.count += weight 31 | self.avg = self.sum / self.count 32 | 33 | def value(self): 34 | return 
self.val 35 | 36 | def average(self): 37 | return self.avg 38 | 39 | def get_scores(self): 40 | scores_dict = cm2score(self.sum) 41 | return scores_dict 42 | 43 | def clear(self): 44 | self.initialized = False 45 | 46 | 47 | ################### cm metrics ################### 48 | class ConfuseMatrixMeter(AverageMeter): 49 | """Computes and stores the average and current value""" 50 | def __init__(self, n_class): 51 | super(ConfuseMatrixMeter, self).__init__() 52 | self.n_class = n_class 53 | 54 | def update_cm(self, pr, gt, weight=1): 55 | """获得当前混淆矩阵,并计算当前F1得分,并更新混淆矩阵""" 56 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr) 57 | self.update(val, weight) 58 | current_score = cm2F1(val) 59 | return current_score 60 | 61 | def get_scores(self): 62 | scores_dict = cm2score(self.sum) 63 | return scores_dict 64 | 65 | 66 | 67 | def harmonic_mean(xs): 68 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs) 69 | return harmonic_mean 70 | 71 | 72 | def cm2F1(confusion_matrix): 73 | hist = confusion_matrix 74 | n_class = hist.shape[0] 75 | tp = np.diag(hist) 76 | sum_a1 = hist.sum(axis=1) 77 | sum_a0 = hist.sum(axis=0) 78 | # ---------------------------------------------------------------------- # 79 | # 1. Accuracy & Class Accuracy 80 | # ---------------------------------------------------------------------- # 81 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 82 | 83 | # recall 84 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 85 | # acc_cls = np.nanmean(recall) 86 | 87 | # precision 88 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 89 | 90 | # F1 score 91 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps) 92 | mean_F1 = np.nanmean(F1) 93 | return mean_F1 94 | 95 | 96 | def cm2score(confusion_matrix): 97 | hist = confusion_matrix 98 | n_class = hist.shape[0] 99 | tp = np.diag(hist) 100 | sum_a1 = hist.sum(axis=1) 101 | sum_a0 = hist.sum(axis=0) 102 | # ---------------------------------------------------------------------- # 103 | # 1. Accuracy & Class Accuracy 104 | # ---------------------------------------------------------------------- # 105 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 106 | 107 | # recall 108 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 109 | # acc_cls = np.nanmean(recall) 110 | 111 | # precision 112 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 113 | 114 | # F1 score 115 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps) 116 | mean_F1 = np.nanmean(F1) 117 | # ---------------------------------------------------------------------- # 118 | # 2. 
Frequency weighted Accuracy & Mean IoU 119 | # ---------------------------------------------------------------------- # 120 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) 121 | mean_iu = np.nanmean(iu) 122 | 123 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps) 124 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 125 | 126 | # 127 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu)) 128 | 129 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision)) 130 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall)) 131 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1)) 132 | 133 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1} 134 | score_dict.update(cls_iou) 135 | score_dict.update(cls_F1) 136 | score_dict.update(cls_precision) 137 | score_dict.update(cls_recall) 138 | return score_dict 139 | 140 | 141 | def get_confuse_matrix(num_classes, label_gts, label_preds): 142 | """计算一组预测的混淆矩阵""" 143 | def __fast_hist(label_gt, label_pred): 144 | """ 145 | Collect values for Confusion Matrix 146 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix 147 | :param label_gt: ground-truth 148 | :param label_pred: prediction 149 | :return: values for confusion matrix 150 | """ 151 | mask = (label_gt >= 0) & (label_gt < num_classes) 152 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask], 153 | minlength=num_classes**2).reshape(num_classes, num_classes) 154 | return hist 155 | confusion_matrix = np.zeros((num_classes, num_classes)) 156 | for lt, lp in zip(label_gts, label_preds): 157 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten()) 158 | return confusion_matrix 159 | 160 | 161 | def get_mIoU(num_classes, label_gts, label_preds): 162 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds) 163 | score_dict = cm2score(confusion_matrix) 164 | return score_dict['miou'] 165 | -------------------------------------------------------------------------------- /STADE-CDNet/misc/pyutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import glob 5 | 6 | 7 | def seed_random(seed=2020): 8 | # 加入以下随机种子,数据输入,随机扩充等保持一致 9 | random.seed(seed) 10 | os.environ['PYTHONHASHSEED'] = str(seed) 11 | np.random.seed(seed) 12 | 13 | 14 | def mkdir(path): 15 | """create a single empty directory if it didn't exist 16 | 17 | Parameters: 18 | path (str) -- a single directory path 19 | """ 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | 24 | def get_paths(image_folder_path, suffix='*.png'): 25 | """从文件夹中返回指定格式的文件 26 | :param image_folder_path: str 27 | :param suffix: str 28 | :return: list 29 | """ 30 | paths = sorted(glob.glob(os.path.join(image_folder_path, suffix))) 31 | return paths 32 | 33 | 34 | def get_paths_from_list(image_folder_path, list): 35 | """从image folder中找到list中的文件,返回path list""" 36 | out = [] 37 | for item in list: 38 | path = os.path.join(image_folder_path,item) 39 | out.append(path) 40 | return sorted(out) 41 | 42 | 43 | -------------------------------------------------------------------------------- /STADE-CDNet/misc/torchutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import lr_scheduler 3 | from torch.utils.data import Subset 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import math 7 | import random 8 | 
import os 9 | from torch.nn import MaxPool1d,AvgPool1d 10 | from torch import Tensor 11 | from typing import Iterable, Set, Tuple 12 | 13 | 14 | __all__ = ['cls_accuracy'] 15 | 16 | 17 | 18 | def visualize_imgs(*imgs): 19 | """ 20 | 可视化图像,ndarray格式的图像 21 | :param imgs: ndarray:H*W*C, C=1/3 22 | :return: 23 | """ 24 | import matplotlib.pyplot as plt 25 | nums = len(imgs) 26 | if nums > 1: 27 | fig, axs = plt.subplots(1, nums) 28 | for i, image in enumerate(imgs): 29 | axs[i].imshow(image, cmap='jet') 30 | elif nums == 1: 31 | fig, ax = plt.subplots(1, nums) 32 | for i, image in enumerate(imgs): 33 | ax.imshow(image, cmap='jet') 34 | plt.show() 35 | plt.show() 36 | 37 | def minmax(tensor): 38 | assert tensor.ndim >= 2 39 | shape = tensor.shape 40 | tensor = tensor.view([*shape[:-2], shape[-1]*shape[-2]]) 41 | min_, _ = tensor.min(-1, keepdim=True) 42 | max_, _ = tensor.max(-1, keepdim=True) 43 | return min_, max_ 44 | 45 | def norm_tensor(tensor,min_=None,max_=None, mode='minmax'): 46 | """ 47 | 输入:N*C*H*W / C*H*W / H*W 48 | 输出:在H*W维度的归一化的与原始等大的图 49 | """ 50 | assert tensor.ndim >= 2 51 | shape = tensor.shape 52 | tensor = tensor.view([*shape[:-2], shape[-1]*shape[-2]]) 53 | if mode == 'minmax': 54 | if min_ is None: 55 | min_, _ = tensor.min(-1, keepdim=True) 56 | if max_ is None: 57 | max_, _ = tensor.max(-1, keepdim=True) 58 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001) 59 | elif mode == 'thres': 60 | N = tensor.shape[-1] 61 | thres_a = 0.001 62 | top_k = round(thres_a*N) 63 | max_ = tensor.topk(top_k, dim=-1, largest=True)[0][..., -1] 64 | max_ = max_.unsqueeze(-1) 65 | min_ = tensor.topk(top_k, dim=-1, largest=False)[0][..., -1] 66 | min_ = min_.unsqueeze(-1) 67 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001) 68 | 69 | elif mode == 'std': 70 | mean, std = torch.std_mean(tensor, [-1], keepdim=True) 71 | tensor = (tensor - mean)/std 72 | min_, _ = tensor.min(-1, keepdim=True) 73 | max_, _ = tensor.max(-1, keepdim=True) 74 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001) 75 | elif mode == 'exp': 76 | tai = 1 77 | tensor = torch.nn.functional.softmax(tensor/tai, dim=-1, ) 78 | min_, _ = tensor.min(-1, keepdim=True) 79 | max_, _ = tensor.max(-1, keepdim=True) 80 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001) 81 | else: 82 | raise NotImplementedError 83 | tensor = torch.clamp(tensor, 0, 1) 84 | return tensor.view(shape) 85 | 86 | # if tensor.ndim == 4: 87 | # B, C, H, W = tensor.shape 88 | # tensor = tensor.view([B, C, -1]) 89 | # min_, _ = tensor.min(-1, keepdim=True) 90 | # max_, _ = tensor.max(-1, keepdim=True) 91 | # tensor = (tensor - min_) / (max_ - min_ + 0.00000000001) 92 | # return tensor.view(B, C, H, W) 93 | # elif tensor.ndim == 3: 94 | # C, H, W = tensor.shape 95 | # tensor = tensor.view([C, -1]) 96 | # min_, _ = tensor.min(-1, keepdim=True) 97 | # max_, _ = tensor.max(-1, keepdim=True) 98 | # tensor = (tensor - min_) / (max_ - min_ + 0.00000000001) 99 | # return tensor.view(C, H, W) 100 | # elif tensor.ndim == 2: 101 | # H, W = tensor.shape 102 | # tensor = tensor.view([-1]) 103 | # min_, _ = tensor.min(-1, keepdim=True) 104 | # max_, _ = tensor.max(-1, keepdim=True) 105 | # tensor = (tensor - min_) / (max_ - min_ + 0.00000000001) 106 | # return tensor.view(H, W) 107 | # else: 108 | # raise NotImplementedError 109 | 110 | def visulize_features(features, normalize=False): 111 | """ 112 | 可视化特征图,各维度make grid到一起 113 | """ 114 | from torchvision.utils import make_grid 115 | assert features.ndim == 4 116 | b,c,h,w = features.shape 117 | 
features = features.view((b*c, 1, h, w)) 118 | if normalize: 119 | features = norm_tensor(features) 120 | grid = make_grid(features) 121 | visualize_tensors(grid) 122 | 123 | def visualize_tensors(*tensors): 124 | """ 125 | 可视化tensor,支持单通道特征或3通道图像 126 | :param tensors: tensor: C*H*W, C=1/3 127 | :return: 128 | """ 129 | import matplotlib.pyplot as plt 130 | # from misc.torchutils import tensor2np 131 | images = [] 132 | for tensor in tensors: 133 | assert tensor.ndim == 3 or tensor.ndim==2 134 | if tensor.ndim ==3: 135 | assert tensor.shape[0] == 1 or tensor.shape[0] == 3 136 | images.append(tensor2np(tensor)) 137 | nums = len(images) 138 | if nums>1: 139 | fig, axs = plt.subplots(1, nums) 140 | for i, image in enumerate(images): 141 | axs[i].imshow(image, cmap='jet') 142 | plt.show() 143 | elif nums == 1: 144 | fig, ax = plt.subplots(1, nums) 145 | for i, image in enumerate(images): 146 | ax.imshow(image, cmap='jet') 147 | plt.show() 148 | 149 | 150 | def np_to_tensor(image): 151 | """ 152 | input: nd.array: H*W*C/H*W 153 | """ 154 | if isinstance(image, torch.Tensor): 155 | return image 156 | elif isinstance(image, np.ndarray): 157 | if image.ndim == 3: 158 | if image.shape[2]==3: 159 | image = np.transpose(image,[2,0,1]) 160 | elif image.ndim == 2: 161 | image = np.newaxis(image, 0) 162 | image = torch.from_numpy(image) 163 | return image.unsqueeze(0) 164 | 165 | 166 | def seed_torch(seed=2019): 167 | 168 | # 加入以下随机种子,数据输入,随机扩充等保持一致 169 | random.seed(seed) 170 | os.environ['PYTHONHASHSEED'] = str(seed) 171 | np.random.seed(seed) 172 | torch.manual_seed(seed) 173 | torch.cuda.manual_seed(seed) 174 | # 加入所有随机种子后,模型更新后,中间结果还是不一样, 175 | # 发现这一的现象:前两轮,的结果还是一样;随着模型更新结果会变; 176 | # torch.backends.cudnn.benchmark = False 177 | # torch.backends.cudnn.deterministic = True 178 | 179 | def simplex(t: Tensor, axis=1) -> bool: 180 | _sum = t.sum(axis).type(torch.float32) 181 | _ones = torch.ones_like(_sum, dtype=torch.float32) 182 | return torch.allclose(_sum, _ones) 183 | 184 | 185 | # Assert utils 186 | def uniq(a: Tensor) -> Set: 187 | return set(torch.unique(a.cpu()).numpy()) 188 | 189 | def sset(a: Tensor, sub: Iterable) -> bool: 190 | return uniq(a).issubset(sub) 191 | 192 | def eq(a: Tensor, b) -> bool: 193 | return torch.eq(a, b).all() 194 | 195 | def one_hot(t: Tensor, axis=1) -> bool: 196 | return simplex(t, axis) and sset(t, [0, 1]) 197 | 198 | 199 | def class2one_hot(seg: Tensor, C: int) -> Tensor: 200 | if len(seg.shape) == 2: # Only w, h, used by the dataloader 201 | seg = seg.unsqueeze(dim=0) 202 | assert sset(seg, list(range(C))) 203 | 204 | b, w, h = seg.shape # type: Tuple[int, int, int] 205 | 206 | res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32) 207 | assert res.shape == (b, C, w, h) 208 | assert one_hot(res) 209 | 210 | return res 211 | 212 | class ChannelMaxPool(MaxPool1d): 213 | def forward(self, input): 214 | n, c, w, h = input.size() 215 | input = input.view(n,c,w*h).permute(0,2,1) 216 | pooled = F.max_pool1d(input, self.kernel_size, self.stride, 217 | self.padding, self.dilation, self.ceil_mode, 218 | self.return_indices) 219 | _, _, c = pooled.size() 220 | pooled = pooled.permute(0,2,1) 221 | return pooled.view(n,c,w,h) 222 | 223 | class ChannelAvePool(AvgPool1d): 224 | def forward(self, input): 225 | n, c, w, h = input.size() 226 | input = input.view(n,c,w*h).permute(0,2,1) 227 | pooled = F.avg_pool1d(input, self.kernel_size, self.stride, 228 | self.padding) 229 | _, _, c = pooled.size() 230 | pooled = pooled.permute(0,2,1) 231 | return 
pooled.view(n,c,w,h) 232 | 233 | def cross_entropy(input, target, weight=None, reduction='mean',ignore_index=255): 234 | """ 235 | logSoftmax_with_loss 236 | :param input: torch.Tensor, N*C*H*W 237 | :param target: torch.Tensor, N*1*H*W,/ N*H*W 238 | :param weight: torch.Tensor, C 239 | :return: torch.Tensor [0] 240 | """ 241 | target = target.long() 242 | if target.dim() == 4: 243 | target = torch.squeeze(target, dim=1) 244 | if input.shape[-1] != target.shape[-1]: 245 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True) 246 | 247 | return F.cross_entropy(input=input, target=target, weight=weight, 248 | ignore_index=ignore_index, reduction=reduction) 249 | 250 | def balanced_cross_entropy(input, target, weight=None,ignore_index=255): 251 | """ 252 | 类别均衡的交叉熵损失,暂时只支持2类 253 | TODO: 扩展到多类C>2 254 | """ 255 | if target.dim() == 4: 256 | target = torch.squeeze(target, dim=1) 257 | if input.shape[-1] != target.shape[-1]: 258 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True) 259 | 260 | # print('target.sum',target.sum()) 261 | pos = (target==1).float() 262 | neg = (target==0).float() 263 | pos_num = torch.sum(pos) + 0.0000001 264 | neg_num = torch.sum(neg) + 0.0000001 265 | # print(pos_num) 266 | # print(neg_num) 267 | target_pos = target.float() 268 | target_pos[target_pos!=1] = ignore_index # 忽略不为正样本的区域 269 | target_neg = target.float() 270 | target_neg[target_neg!=0] = ignore_index # 忽略不为负样本的区域 271 | 272 | # print('target.sum',target.sum()) 273 | 274 | loss_pos = cross_entropy(input, target_pos,weight=weight,reduction='sum',ignore_index=ignore_index) 275 | loss_neg = cross_entropy(input, target_neg,weight=weight,reduction='sum',ignore_index=ignore_index) 276 | # print(loss_neg, loss_pos) 277 | loss = 0.5 * loss_pos / pos_num + 0.5 * loss_neg / neg_num 278 | # loss = (loss_pos + loss_neg)/ (pos_num+neg_num) 279 | return loss 280 | 281 | def get_scheduler(optimizer, opt): 282 | """Return a learning rate scheduler 283 | """ 284 | if opt.lr_policy == 'linear': 285 | def lambda_rule(epoch): 286 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 287 | return lr_l 288 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 289 | elif opt.lr_policy == 'poly': 290 | max_step = opt.niter+opt.niter_decay 291 | power = 0.9 292 | def lambda_rule(epoch): 293 | current_step = epoch + opt.epoch_count 294 | lr_l = (1.0 - current_step / (max_step+1)) ** float(power) 295 | return lr_l 296 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 297 | elif opt.lr_policy == 'step': 298 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 299 | else: 300 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 301 | return scheduler 302 | 303 | 304 | def mul_cls_acc(preds, targets, topk=(1,)): 305 | """计算multi-label分类的top-k准确率topk-acc,topk-error=1-topk-acc; 306 | 首先计算每张图的的平均准确率,再计算所有图的平均准确率 307 | :param pred: N * C 308 | :param target: N * C 309 | :param topk: 310 | :return: 311 | """ 312 | with torch.no_grad(): 313 | maxk = max(topk) 314 | bs, C = targets.shape 315 | _, pred = preds.topk(maxk, 1, True, True) 316 | pred += 1 # pred 为类别\in [1,C] 317 | # print('pred: ', pred) 318 | # print('targets: ', targets) 319 | correct = torch.zeros([bs, maxk]).long() # 记录预测正确label数量 320 | if preds.device != torch.device(type='cpu'): 321 | correct = correct.cuda() 322 | for i in range(C): 323 | label = i + 1 324 | target = 
targets[:, i] * label 325 | # print('target.view: ', target.view(-1, 1).expand_as(pred)) 326 | # print('pred: ', pred) 327 | correct = correct + pred.eq(target.view(-1, 1).expand_as(pred)).long() 328 | # print('correct: ', pred.eq(target.view(-1, 1).expand_as(pred)).long()) 329 | n = (targets == 1).long().sum(1) # N*1, 每张图中含有目标的数量 330 | # print(n) 331 | res = [] 332 | for k in topk: 333 | acc_k = correct[:, :k].sum(1).float() / n.float() # 每张图的平均正确率,预测正确目标数/总目标数 334 | # print(correct[:, :k].sum(1).float()) 335 | acc_k = acc_k.sum()/bs 336 | res.append(acc_k) 337 | # print(acc_k) 338 | return res 339 | 340 | 341 | def cls_accuracy(output, target, topk=(1,)): 342 | """ 343 | Computes the accuracy over the k top predictions for the specified values of k 344 | https://github.com/pytorch/examples/blob/ee964a2eeb41e1712fe719b83645c79bcbd0ba1a/imagenet/main.py#L407 345 | """ 346 | 347 | with torch.no_grad(): 348 | maxk = max(topk) 349 | batch_size = target.size(0) 350 | 351 | _, pred = output.topk(maxk, 1, True, True) 352 | pred = pred.t() 353 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 354 | 355 | res = [] 356 | for k in topk: 357 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 358 | res.append(correct_k.mul_(100.0 / batch_size)) 359 | return res 360 | 361 | class PolyOptimizer(torch.optim.SGD): 362 | 363 | def __init__(self, params, lr, weight_decay, max_step, init_step=0, momentum=0.9): 364 | super().__init__(params, lr, weight_decay) 365 | 366 | self.global_step = init_step 367 | print(self.global_step) 368 | self.max_step = max_step 369 | self.momentum = momentum 370 | 371 | self.__initial_lr = [group['lr'] for group in self.param_groups] 372 | 373 | 374 | def step(self, closure=None): 375 | 376 | if self.global_step < self.max_step: 377 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 378 | 379 | for i in range(len(self.param_groups)): 380 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 381 | 382 | super().step(closure) 383 | 384 | self.global_step += 1 385 | 386 | 387 | class PolyAdamOptimizer(torch.optim.Adam): 388 | def __init__(self, params, lr, betas, max_step, momentum=0.9): 389 | super().__init__(params, lr, betas) 390 | 391 | self.global_step = 0 392 | self.max_step = max_step 393 | self.momentum = momentum 394 | 395 | self.__initial_lr = [group['lr'] for group in self.param_groups] 396 | 397 | 398 | def step(self, closure=None): 399 | 400 | if self.global_step < self.max_step: 401 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 402 | 403 | for i in range(len(self.param_groups)): 404 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 405 | 406 | super().step(closure) 407 | self.global_step += 1 408 | # 409 | # from ranger import RangerQH,Ranger 410 | # # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer/blob/master/ranger/rangerqh.py 411 | # 412 | # class PolyRangerOptimizer(RangerQH): 413 | # 414 | # def __init__(self, params, lr, betas, max_step, momentum=0.9): 415 | # super().__init__(params, lr, betas) 416 | # 417 | # self.global_step = 0 418 | # self.max_step = max_step 419 | # self.momentum = momentum 420 | # 421 | # self.__initial_lr = [group['lr'] for group in self.param_groups] 422 | # 423 | # 424 | # def step(self, closure=None): 425 | # 426 | # if self.global_step < self.max_step: 427 | # lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 428 | # 429 | # for i in range(len(self.param_groups)): 430 | # self.param_groups[i]['lr'] = self.__initial_lr[i] * 
lr_mult 431 | # 432 | # super().step(closure) 433 | # self.global_step += 1 434 | 435 | class SGDROptimizer(torch.optim.SGD): 436 | 437 | def __init__(self, params, steps_per_epoch, lr=0, weight_decay=0, epoch_start=1, restart_mult=2): 438 | super().__init__(params, lr, weight_decay) 439 | 440 | self.global_step = 0 441 | self.local_step = 0 442 | self.total_restart = 0 443 | 444 | self.max_step = steps_per_epoch * epoch_start 445 | self.restart_mult = restart_mult 446 | 447 | self.__initial_lr = [group['lr'] for group in self.param_groups] 448 | 449 | 450 | def step(self, closure=None): 451 | 452 | if self.local_step >= self.max_step: 453 | self.local_step = 0 454 | self.max_step *= self.restart_mult 455 | self.total_restart += 1 456 | 457 | lr_mult = (1 + math.cos(math.pi * self.local_step / self.max_step))/2 / (self.total_restart + 1) 458 | 459 | for i in range(len(self.param_groups)): 460 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 461 | 462 | super().step(closure) 463 | 464 | self.local_step += 1 465 | self.global_step += 1 466 | 467 | 468 | def split_dataset(dataset, n_splits): 469 | 470 | return [Subset(dataset, np.arange(i, len(dataset), n_splits)) for i in range(n_splits)] 471 | 472 | 473 | def gap2d(x, keepdims=False): 474 | out = torch.mean(x.view(x.size(0), x.size(1), -1), -1) 475 | if keepdims: 476 | out = out.view(out.size(0), out.size(1), 1, 1) 477 | 478 | return out 479 | 480 | 481 | def decode_seg(label_mask, toTensor=False): 482 | """ 483 | :param label_mask: mask (np.ndarray): (M, N)/ tensor: N*C*H*W 484 | :return: color label: (M, N, 3), 485 | """ 486 | if not isinstance(label_mask, np.ndarray): 487 | if isinstance(label_mask, torch.Tensor): # get the data from a variable 488 | image_tensor = label_mask.data 489 | else: 490 | return label_mask 491 | label_mask = image_tensor[0][0].cpu().numpy() 492 | 493 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3),dtype=np.float) 494 | r = label_mask % 6 495 | g = (label_mask % 36) // 6 496 | b = label_mask // 36 497 | # 归一化到[0-1] 498 | rgb[:, :, 0] = r / 6 499 | rgb[:, :, 1] = g / 6 500 | rgb[:, :, 2] = b / 6 501 | if toTensor: 502 | rgb = torch.from_numpy(rgb.transpose([2,0,1])).unsqueeze(0) 503 | 504 | return rgb 505 | 506 | 507 | def tensor2im(input_image, imtype=np.uint8, normalize=True): 508 | """"Converts a Tensor array into a numpy image array. 
509 | Parameters: 510 | input_image (tensor) -- the input image tensor array 511 | imtype (type) -- the desired type of the converted numpy array 512 | """ 513 | if not isinstance(input_image, np.ndarray): 514 | if isinstance(input_image, torch.Tensor): # get the data from a variable 515 | image_tensor = input_image.data 516 | else: 517 | return input_image 518 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 519 | # if image_numpy.shape[0] == 1: # grayscale to RGB 520 | # image_numpy = np.tile(image_numpy, (3, 1, 1)) 521 | if image_numpy.shape[0] == 3: # if RGB 522 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) 523 | if normalize: 524 | image_numpy = (image_numpy + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 525 | else: # if it is a numpy array, do nothing 526 | image_numpy = input_image 527 | return image_numpy.astype(imtype) 528 | 529 | 530 | def tensor2np(input_image, if_normalize=True): 531 | """ 532 | :param input_image: C*H*W / H*W 533 | :return: ndarray, H*W*C / H*W 534 | """ 535 | if isinstance(input_image, torch.Tensor): # get the data from a variable 536 | image_tensor = input_image.data 537 | image_numpy = image_tensor.cpu().float().numpy() # convert it into a numpy array 538 | 539 | else: 540 | image_numpy = input_image 541 | if image_numpy.ndim == 2: 542 | return image_numpy 543 | elif image_numpy.ndim == 3: 544 | C, H, W = image_numpy.shape 545 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) 546 | # 如果输入为灰度图C==1,则输出array,ndim==2; 547 | if C == 1: 548 | image_numpy = image_numpy[:, :, 0] 549 | if if_normalize and C == 3: 550 | image_numpy = (image_numpy + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 551 | # add to prevent extreme noises in visual images 552 | image_numpy[image_numpy<0]=0 553 | image_numpy[image_numpy>255]=255 554 | image_numpy = image_numpy.astype(np.uint8) 555 | return image_numpy 556 | 557 | 558 | import ntpath 559 | from misc.imutils import save_image 560 | def save_visuals(visuals, img_dir, name, save_one=True, iter='0'): 561 | """ 562 | """ 563 | # save images to the disk 564 | for label, image in visuals.items(): 565 | N = image.shape[0] 566 | if save_one: 567 | N = 1 568 | # 保存各个bz的数据 569 | for j in range(N): 570 | name_ = ntpath.basename(name[j]) 571 | name_ = name_.split(".")[0] 572 | # print(name_) 573 | image_numpy = tensor2np(image[j], if_normalize=True).astype(np.uint8) 574 | # print(image_numpy) 575 | img_path = os.path.join(img_dir, iter+'_%s_%s.png' % (name_, label)) 576 | save_image(image_numpy, img_path) -------------------------------------------------------------------------------- /STADE-CDNet/models/ChangeFormerBaseNetworks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torch 6 | 7 | from torch import nn 8 | from torch.nn import init 9 | from torch.nn import functional as F 10 | from torch.autograd import Function 11 | 12 | from math import sqrt 13 | 14 | import random 15 | 16 | class ConvBlock(torch.nn.Module): 17 | def __init__(self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None): 18 | super(ConvBlock, self).__init__() 19 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 20 | 21 | self.norm = norm 22 | if self.norm =='batch': 23 | self.bn = torch.nn.BatchNorm2d(output_size) 24 | elif self.norm == 'instance': 25 | self.bn = 
torch.nn.InstanceNorm2d(output_size) 26 | 27 | self.activation = activation 28 | if self.activation == 'relu': 29 | self.act = torch.nn.ReLU(True) 30 | elif self.activation == 'prelu': 31 | self.act = torch.nn.PReLU() 32 | elif self.activation == 'lrelu': 33 | self.act = torch.nn.LeakyReLU(0.2, True) 34 | elif self.activation == 'tanh': 35 | self.act = torch.nn.Tanh() 36 | elif self.activation == 'sigmoid': 37 | self.act = torch.nn.Sigmoid() 38 | 39 | def forward(self, x): 40 | if self.norm is not None: 41 | out = self.bn(self.conv(x)) 42 | else: 43 | out = self.conv(x) 44 | 45 | if self.activation != 'no': 46 | return self.act(out) 47 | else: 48 | return out 49 | 50 | class DeconvBlock(torch.nn.Module): 51 | def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm=None): 52 | super(DeconvBlock, self).__init__() 53 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 54 | 55 | self.norm = norm 56 | if self.norm == 'batch': 57 | self.bn = torch.nn.BatchNorm2d(output_size) 58 | elif self.norm == 'instance': 59 | self.bn = torch.nn.InstanceNorm2d(output_size) 60 | 61 | self.activation = activation 62 | if self.activation == 'relu': 63 | self.act = torch.nn.ReLU(True) 64 | elif self.activation == 'prelu': 65 | self.act = torch.nn.PReLU() 66 | elif self.activation == 'lrelu': 67 | self.act = torch.nn.LeakyReLU(0.2, True) 68 | elif self.activation == 'tanh': 69 | self.act = torch.nn.Tanh() 70 | elif self.activation == 'sigmoid': 71 | self.act = torch.nn.Sigmoid() 72 | 73 | def forward(self, x): 74 | if self.norm is not None: 75 | out = self.bn(self.deconv(x)) 76 | else: 77 | out = self.deconv(x) 78 | 79 | if self.activation is not None: 80 | return self.act(out) 81 | else: 82 | return out 83 | 84 | 85 | class ConvLayer(nn.Module): 86 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 87 | super(ConvLayer, self).__init__() 88 | # reflection_padding = kernel_size // 2 89 | # self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 90 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 91 | 92 | def forward(self, x): 93 | # out = self.reflection_pad(x) 94 | out = self.conv2d(x) 95 | return out 96 | 97 | 98 | class UpsampleConvLayer(torch.nn.Module): 99 | def __init__(self, in_channels, out_channels, kernel_size, stride): 100 | super(UpsampleConvLayer, self).__init__() 101 | self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=1) 102 | 103 | def forward(self, x): 104 | out = self.conv2d(x) 105 | return out 106 | 107 | 108 | class ResidualBlock(torch.nn.Module): 109 | def __init__(self, channels): 110 | super(ResidualBlock, self).__init__() 111 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1) 112 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1) 113 | self.relu = nn.ReLU() 114 | 115 | def forward(self, x): 116 | residual = x 117 | out = self.relu(self.conv1(x)) 118 | out = self.conv2(out) * 0.1 119 | out = torch.add(out, residual) 120 | return out 121 | 122 | 123 | 124 | def init_linear(linear): 125 | init.xavier_normal(linear.weight) 126 | linear.bias.data.zero_() 127 | 128 | 129 | def init_conv(conv, glu=True): 130 | init.kaiming_normal(conv.weight) 131 | if conv.bias is not None: 132 | conv.bias.data.zero_() 133 | 134 | 135 | class EqualLR: 136 | def __init__(self, name): 137 | self.name = name 138 | 139 | def 
compute_weight(self, module): 140 | weight = getattr(module, self.name + '_orig') 141 | fan_in = weight.data.size(1) * weight.data[0][0].numel() 142 | 143 | return weight * sqrt(2 / fan_in) 144 | 145 | @staticmethod 146 | def apply(module, name): 147 | fn = EqualLR(name) 148 | 149 | weight = getattr(module, name) 150 | del module._parameters[name] 151 | module.register_parameter(name + '_orig', nn.Parameter(weight.data)) 152 | module.register_forward_pre_hook(fn) 153 | 154 | return fn 155 | 156 | def __call__(self, module, input): 157 | weight = self.compute_weight(module) 158 | setattr(module, self.name, weight) 159 | 160 | 161 | def equal_lr(module, name='weight'): 162 | EqualLR.apply(module, name) 163 | 164 | return module -------------------------------------------------------------------------------- /STADE-CDNet/models/DTCDSCN.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import ResNet 5 | import torch.nn.functional as F 6 | from functools import partial 7 | 8 | 9 | nonlinearity = partial(F.relu,inplace=True) 10 | 11 | class SELayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(SELayer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.fc = nn.Sequential( 16 | nn.Linear(channel, channel // reduction, bias=False), 17 | nn.ReLU(inplace=True), 18 | nn.Linear(channel // reduction, channel, bias=False), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | b, c, _, _ = x.size() 24 | y = self.avg_pool(x).view(b, c) 25 | y = self.fc(y).view(b, c, 1, 1) 26 | return x * y.expand_as(x) 27 | 28 | class Dblock_more_dilate(nn.Module): 29 | def __init__(self, channel): 30 | super(Dblock_more_dilate, self).__init__() 31 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 32 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 33 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 34 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 35 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | 41 | def forward(self, x): 42 | dilate1_out = nonlinearity(self.dilate1(x)) 43 | dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 44 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 45 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 46 | dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 47 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out 48 | return out 49 | class Dblock(nn.Module): 50 | def __init__(self, channel): 51 | super(Dblock, self).__init__() 52 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 53 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 54 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 55 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 56 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | 62 | def forward(self, x): 63 | 
dilate1_out = nonlinearity(self.dilate1(x)) 64 | dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 65 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 66 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 67 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 68 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out 69 | return out 70 | 71 | def conv3x3(in_planes, out_planes, stride=1): 72 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 73 | 74 | class SEBasicBlock(nn.Module): 75 | expansion = 1 76 | 77 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 78 | super(SEBasicBlock, self).__init__() 79 | self.conv1 = conv3x3(inplanes, planes, stride) 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.conv2 = conv3x3(planes, planes, 1) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.se = SELayer(planes, reduction) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | residual = x 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.se(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | class DecoderBlock(nn.Module): 107 | def __init__(self, in_channels, n_filters): 108 | super(DecoderBlock,self).__init__() 109 | 110 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 111 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 112 | self.relu1 = nonlinearity 113 | self.scse = SCSEBlock(in_channels // 4) 114 | 115 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1) 116 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 117 | self.relu2 = nonlinearity 118 | 119 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 120 | self.norm3 = nn.BatchNorm2d(n_filters) 121 | self.relu3 = nonlinearity 122 | 123 | def forward(self, x): 124 | x = self.conv1(x) 125 | x = self.norm1(x) 126 | x = self.relu1(x) 127 | y = self.scse(x) 128 | x = x + y 129 | x = self.deconv2(x) 130 | x = self.norm2(x) 131 | x = self.relu2(x) 132 | x = self.conv3(x) 133 | x = self.norm3(x) 134 | x = self.relu3(x) 135 | return x 136 | 137 | class SCSEBlock(nn.Module): 138 | def __init__(self, channel, reduction=16): 139 | super(SCSEBlock, self).__init__() 140 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 141 | 142 | '''self.channel_excitation = nn.Sequential(nn.(channel, int(channel//reduction)), 143 | nn.ReLU(inplace=True), 144 | nn.Linear(int(channel//reduction), channel), 145 | nn.Sigmoid())''' 146 | self.channel_excitation = nn.Sequential(nn.Conv2d(channel, int(channel//reduction), kernel_size=1, 147 | stride=1, padding=0, bias=False), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(int(channel // reduction), channel,kernel_size=1, 150 | stride=1, padding=0, bias=False), 151 | nn.Sigmoid()) 152 | 153 | self.spatial_se = nn.Sequential(nn.Conv2d(channel, 1, kernel_size=1, 154 | stride=1, padding=0, bias=False), 155 | nn.Sigmoid()) 156 | 157 | def forward(self, x): 158 | bahs, chs, _, _ = x.size() 159 | 160 | # Returns a new tensor with the same data as the self tensor but of a different size. 
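# Channel excitation ("cSE"): global average pooling squeezes each channel to a
# single value, a 1x1-conv bottleneck re-weights the channels, and the sigmoid
# gate rescales the input feature map channel-wise.
# Spatial excitation ("sSE"): a 1x1 conv followed by a sigmoid produces a
# per-pixel gate that rescales the input spatially.
# The two re-calibrated maps are then summed below (note: torch.add(a, 1, b) is
# the legacy alpha overload; in recent PyTorch versions plain `chn_se + spa_se`
# is the equivalent, non-deprecated spelling).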
161 | chn_se = self.avg_pool(x) 162 | chn_se = self.channel_excitation(chn_se) 163 | chn_se = torch.mul(x, chn_se) 164 | spa_se = self.spatial_se(x) 165 | spa_se = torch.mul(x, spa_se) 166 | return torch.add(chn_se, 1, spa_se) 167 | 168 | class CDNet_model(nn.Module): 169 | def __init__(self, in_channels=3, block=SEBasicBlock, layers=[3, 4, 6, 3], num_classes=2): 170 | super(CDNet_model, self).__init__() 171 | 172 | filters = [64, 128, 256, 512] 173 | self.inplanes = 64 174 | self.firstconv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, 175 | bias=False) 176 | self.firstbn = nn.BatchNorm2d(64) 177 | self.firstrelu = nn.ReLU(inplace=True) 178 | self.firstmaxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 179 | self.encoder1 = self._make_layer(block, 64, layers[0]) 180 | self.encoder2 = self._make_layer(block, 128, layers[1], stride=2) 181 | self.encoder3 = self._make_layer(block, 256, layers[2], stride=2) 182 | self.encoder4 = self._make_layer(block, 512, layers[3], stride=2) 183 | 184 | self.decoder4 = DecoderBlock(filters[3], filters[2]) 185 | self.decoder3 = DecoderBlock(filters[2], filters[1]) 186 | self.decoder2 = DecoderBlock(filters[1], filters[0]) 187 | self.decoder1 = DecoderBlock(filters[0], filters[0]) 188 | 189 | self.dblock_master = Dblock(512) 190 | self.dblock = Dblock(512) 191 | 192 | self.decoder4_master = DecoderBlock(filters[3], filters[2]) 193 | self.decoder3_master = DecoderBlock(filters[2], filters[1]) 194 | self.decoder2_master = DecoderBlock(filters[1], filters[0]) 195 | self.decoder1_master = DecoderBlock(filters[0], filters[0]) 196 | 197 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 198 | self.finalrelu1_master = nonlinearity 199 | self.finalconv2_master = nn.Conv2d(32, 32, 3, padding=1) 200 | self.finalrelu2_master = nonlinearity 201 | self.finalconv3_master = nn.Conv2d(32, num_classes, 3, padding=1) 202 | 203 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 204 | self.finalrelu1 = nonlinearity 205 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1) 206 | self.finalrelu2 = nonlinearity 207 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1) 208 | 209 | for m in self.modules(): 210 | if isinstance(m, nn.Conv2d): 211 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 212 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 213 | elif isinstance(m, nn.BatchNorm2d): 214 | m.weight.data.fill_(1) 215 | m.bias.data.zero_() 216 | 217 | def _make_layer(self, block, planes, blocks, stride=1): 218 | downsample = None 219 | if stride != 1 or self.inplanes != planes * block.expansion: 220 | downsample = nn.Sequential( 221 | nn.Conv2d(self.inplanes, planes * block.expansion, 222 | kernel_size=1, stride=stride, bias=False), 223 | nn.BatchNorm2d(planes * block.expansion), 224 | ) 225 | 226 | layers = [] 227 | layers.append(block(self.inplanes, planes, stride, downsample)) 228 | self.inplanes = planes * block.expansion 229 | for i in range(1, blocks): 230 | layers.append(block(self.inplanes, planes)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x, y): 235 | # Encoder_1 236 | x = self.firstconv(x) 237 | x = self.firstbn(x) 238 | x = self.firstrelu(x) 239 | x = self.firstmaxpool(x) 240 | 241 | e1_x = self.encoder1(x) 242 | e2_x = self.encoder2(e1_x) 243 | e3_x = self.encoder3(e2_x) 244 | e4_x = self.encoder4(e3_x) 245 | 246 | # # Center_1 247 | # e4_x_center = self.dblock(e4_x) 248 | 249 | # # Decoder_1 250 | # d4_x = self.decoder4(e4_x_center) + e3_x 251 | # d3_x = self.decoder3(d4_x) + e2_x 252 | # d2_x = self.decoder2(d3_x) + e1_x 253 | # d1_x = self.decoder1(d2_x) 254 | 255 | # out1 = self.finaldeconv1(d1_x) 256 | # out1 = self.finalrelu1(out1) 257 | # out1 = self.finalconv2(out1) 258 | # out1 = self.finalrelu2(out1) 259 | # out1 = self.finalconv3(out1) 260 | 261 | # Encoder_2 262 | y = self.firstconv(y) 263 | y = self.firstbn(y) 264 | y = self.firstrelu(y) 265 | y = self.firstmaxpool(y) 266 | 267 | e1_y = self.encoder1(y) 268 | e2_y = self.encoder2(e1_y) 269 | e3_y = self.encoder3(e2_y) 270 | e4_y = self.encoder4(e3_y) 271 | 272 | # # Center_2 273 | # e4_y_center = self.dblock(e4_y) 274 | 275 | # # Decoder_2 276 | # d4_y = self.decoder4(e4_y_center) + e3_y 277 | # d3_y = self.decoder3(d4_y) + e2_y 278 | # d2_y = self.decoder2(d3_y) + e1_y 279 | # d1_y = self.decoder1(d2_y) 280 | # out2 = self.finaldeconv1(d1_y) 281 | # out2 = self.finalrelu1(out2) 282 | # out2 = self.finalconv2(out2) 283 | # out2 = self.finalrelu2(out2) 284 | # out2 = self.finalconv3(out2) 285 | 286 | # center_master 287 | e4 = self.dblock_master(e4_x - e4_y) 288 | # decoder_master 289 | d4 = self.decoder4_master(e4) + e3_x - e3_y 290 | d3 = self.decoder3_master(d4) + e2_x - e2_y 291 | d2 = self.decoder2_master(d3) + e1_x - e1_y 292 | d1 = self.decoder1_master(d2) 293 | 294 | out = self.finaldeconv1_master(d1) 295 | out = self.finalrelu1_master(out) 296 | out = self.finalconv2_master(out) 297 | out = self.finalrelu2_master(out) 298 | out = self.finalconv3_master(out) 299 | 300 | output = [] 301 | output.append(out) 302 | 303 | return output 304 | 305 | 306 | 307 | def CDNet34(in_channels, **kwargs): 308 | 309 | model = CDNet_model(in_channels, SEBasicBlock, [3, 4, 6, 3], **kwargs) 310 | 311 | return model -------------------------------------------------------------------------------- /STADE-CDNet/models/SiamUnet_conc.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 
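# FC-Siam-conc: the two input images pass through a weight-shared (siamese)
# encoder; at each decoder stage the upsampled features are concatenated with
# the skip connections from *both* encoder branches, which is why the first
# decoder convolutions below take 3x-wide inputs (384 = 128 + 128 + 128, etc.).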
4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_conc(nn.Module): 11 | """SiamUnet_conc segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_conc, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | 
self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | """Forward method.""" 97 | # Stage 1 98 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 99 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 100 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 101 | 102 | 103 | # Stage 2 104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 107 | 108 | # Stage 3 109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 113 | 114 | # Stage 4 115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 119 | 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | # Stage 2 128 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 129 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 130 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 131 | 132 | # Stage 3 133 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 134 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 135 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 136 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 137 | 138 | # Stage 4 139 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 140 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 141 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 142 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 143 | 144 | 145 | #################################################### 146 | # Stage 4d 147 | x4d = self.upconv4(x4p) 148 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 149 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 150 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 151 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 152 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 153 | 154 | # Stage 3d 155 | x3d = self.upconv3(x41d) 156 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 157 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 158 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 159 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 160 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 161 | 162 | # Stage 2d 163 | x2d = self.upconv2(x31d) 164 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 165 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 166 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 167 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 168 | 169 | # Stage 1d 170 | x1d = self.upconv1(x21d) 171 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 172 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 173 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 174 | x11d = self.conv11d(x12d) 175 | 
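# x11d holds raw per-class scores (logits); self.sm (LogSoftmax) defined in
# __init__ is deliberately not applied here, since normalisation is handled in
# the loss. The prediction is returned inside a list so callers can uniformly
# take the last element (e.g. net_G(img1, img2)[-1] in CDEvaluator._forward_pass).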
176 | #Softmax layer is embedded in the loss layer 177 | #out = self.sm(x11d) 178 | output = [] 179 | output.append(x11d) 180 | 181 | return output 182 | 183 | 184 | -------------------------------------------------------------------------------- /STADE-CDNet/models/SiamUnet_diff.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_diff(nn.Module): 11 | """SiamUnet_diff segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_diff, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, 
kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 102 | 103 | 104 | # Stage 2 105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 108 | 109 | # Stage 3 110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 114 | 115 | # Stage 4 116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | 128 | # Stage 2 129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 132 | 133 | # Stage 3 134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 138 | 139 | # Stage 4 140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 144 | 145 | 146 | 147 | # Stage 4d 148 | x4d = self.upconv4(x4p) 149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 150 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) 151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 154 | 155 | # Stage 3d 156 | x3d = self.upconv3(x41d) 157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 158 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) 159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 162 | 163 | # Stage 2d 164 | x2d = 
self.upconv2(x31d) 165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 166 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) 167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 169 | 170 | # Stage 1d 171 | x1d = self.upconv1(x21d) 172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 173 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) 174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 175 | x11d = self.conv11d(x12d) 176 | #out = self.sm(x11d) 177 | 178 | output = [] 179 | output.append(x11d) 180 | 181 | return output 182 | -------------------------------------------------------------------------------- /STADE-CDNet/models/Unet.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class Unet(nn.Module): 11 | """EF segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(Unet, self).__init__() 15 | 16 | self.conv11 = nn.Conv2d(2*input_nbr, 16, kernel_size=3, padding=1) 17 | self.bn11 = nn.BatchNorm2d(16) 18 | self.do11 = nn.Dropout2d(p=0.2) 19 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 20 | self.bn12 = nn.BatchNorm2d(16) 21 | self.do12 = nn.Dropout2d(p=0.2) 22 | 23 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 24 | self.bn21 = nn.BatchNorm2d(32) 25 | self.do21 = nn.Dropout2d(p=0.2) 26 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 27 | self.bn22 = nn.BatchNorm2d(32) 28 | self.do22 = nn.Dropout2d(p=0.2) 29 | 30 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 31 | self.bn31 = nn.BatchNorm2d(64) 32 | self.do31 = nn.Dropout2d(p=0.2) 33 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 34 | self.bn32 = nn.BatchNorm2d(64) 35 | self.do32 = nn.Dropout2d(p=0.2) 36 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 37 | self.bn33 = nn.BatchNorm2d(64) 38 | self.do33 = nn.Dropout2d(p=0.2) 39 | 40 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 41 | self.bn41 = nn.BatchNorm2d(128) 42 | self.do41 = nn.Dropout2d(p=0.2) 43 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 44 | self.bn42 = nn.BatchNorm2d(128) 45 | self.do42 = nn.Dropout2d(p=0.2) 46 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 47 | self.bn43 = nn.BatchNorm2d(128) 48 | self.do43 = nn.Dropout2d(p=0.2) 49 | 50 | 51 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 52 | 53 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 54 | self.bn43d = nn.BatchNorm2d(128) 55 | self.do43d = nn.Dropout2d(p=0.2) 56 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 57 | self.bn42d = nn.BatchNorm2d(128) 58 | self.do42d = nn.Dropout2d(p=0.2) 59 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 60 | self.bn41d = nn.BatchNorm2d(64) 61 | self.do41d = nn.Dropout2d(p=0.2) 62 | 63 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 64 | 65 | 
self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 66 | self.bn33d = nn.BatchNorm2d(64) 67 | self.do33d = nn.Dropout2d(p=0.2) 68 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 69 | self.bn32d = nn.BatchNorm2d(64) 70 | self.do32d = nn.Dropout2d(p=0.2) 71 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 72 | self.bn31d = nn.BatchNorm2d(32) 73 | self.do31d = nn.Dropout2d(p=0.2) 74 | 75 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 76 | 77 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 78 | self.bn22d = nn.BatchNorm2d(32) 79 | self.do22d = nn.Dropout2d(p=0.2) 80 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 81 | self.bn21d = nn.BatchNorm2d(16) 82 | self.do21d = nn.Dropout2d(p=0.2) 83 | 84 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 85 | 86 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 87 | self.bn12d = nn.BatchNorm2d(16) 88 | self.do12d = nn.Dropout2d(p=0.2) 89 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 90 | 91 | self.sm = nn.LogSoftmax(dim=1) 92 | 93 | def forward(self, x1, x2): 94 | 95 | x = torch.cat((x1, x2), 1) 96 | 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x)))) 100 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2) 102 | 103 | # Stage 2 104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 105 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 106 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2) 107 | 108 | # Stage 3 109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 111 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 112 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2) 113 | 114 | # Stage 4 115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 117 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 118 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2) 119 | 120 | 121 | # Stage 4d 122 | x4d = self.upconv4(x4p) 123 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2))) 124 | x4d = torch.cat((pad4(x4d), x43), 1) 125 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 126 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 127 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 128 | 129 | # Stage 3d 130 | x3d = self.upconv3(x41d) 131 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2))) 132 | x3d = torch.cat((pad3(x3d), x33), 1) 133 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 134 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 135 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 136 | 137 | # Stage 2d 138 | x2d = self.upconv2(x31d) 139 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2))) 140 | x2d = torch.cat((pad2(x2d), x22), 1) 141 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 142 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 143 | 144 | # Stage 1d 145 | x1d = self.upconv1(x21d) 146 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2))) 147 | x1d = torch.cat((pad1(x1d), x12), 1) 148 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 
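# Final stage: the upsampled features were padded (ReplicationPad2d) to match
# the stage-1 skip map x12, concatenated with it, and refined; conv11d then maps
# the 16-channel result to label_nbr per-class score maps (logits).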
149 | x11d = self.conv11d(x12d) 150 | 151 | output = [] 152 | output.append(x11d) 153 | 154 | return output 155 | 156 | 157 | -------------------------------------------------------------------------------- /STADE-CDNet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /STADE-CDNet/models/basic_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from misc.imutils import save_image 6 | from models.networks import * 7 | 8 | 9 | class CDEvaluator(): 10 | 11 | def __init__(self, args): 12 | 13 | self.n_class = args.n_class 14 | # define G 15 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids) 16 | 17 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] 18 | if torch.cuda.is_available() and len(args.gpu_ids)>0 19 | else "cpu") 20 | 21 | print(self.device) 22 | 23 | self.checkpoint_dir = '/where' 24 | 25 | self.pred_dir = args.output_folder 26 | os.makedirs(self.pred_dir, exist_ok=True) 27 | 28 | def load_checkpoint(self, checkpoint_name='best_ckpt.pt'): 29 | 30 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)): 31 | # load the entire checkpoint 32 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), 33 | map_location=self.device) 34 | 35 | self.net_G.load_state_dict(checkpoint['model_G_state_dict']) 36 | self.net_G.to(self.device) 37 | # update some other states 38 | self.best_val_acc = checkpoint['best_val_acc'] 39 | self.best_epoch_id = checkpoint['best_epoch_id'] 40 | 41 | else: 42 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name) 43 | return self.net_G 44 | 45 | 46 | def _visualize_pred(self): 47 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True) 48 | pred_vis = pred * 255 49 | return pred_vis 50 | 51 | def _forward_pass(self, batch): 52 | self.batch = batch 53 | img_in1 = batch['A'].to(self.device) 54 | img_in2 = batch['B'].to(self.device) 55 | self.shape_h = img_in1.shape[-2] 56 | self.shape_w = img_in1.shape[-1] 57 | self.G_pred = self.net_G(img_in1, img_in2)[-1] 58 | return self._visualize_pred() 59 | 60 | def eval(self): 61 | self.net_G.eval() 62 | 63 | def _save_predictions(self): 64 | """ 65 | 保存模型输出结果,二分类图像 66 | """ 67 | 68 | preds = self._visualize_pred() 69 | name = self.batch['name'] 70 | for i, pred in enumerate(preds): 71 | file_name = os.path.join( 72 | self.pred_dir, name[i].replace('.jpg', '.png')) 73 | pred = pred[0].cpu().numpy() 74 | save_image(pred, file_name) 75 | 76 | -------------------------------------------------------------------------------- /STADE-CDNet/models/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from models.networks import * 6 | from misc.metric_tool import ConfuseMatrixMeter 7 | from misc.logger_tool import Logger 8 | from utils import de_norm 9 | import utils 10 | 11 | 12 | # Decide which device we want to run on 13 | # torch.cuda.current_device() 14 | 15 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | class CDEvaluator(): 19 | 20 | def __init__(self, args, dataloader): 21 | 22 | self.dataloader = dataloader 23 | 24 | self.n_class = args.n_class 25 | # define G 26 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids) 27 | self.device = torch.device("cuda:%s" % 
args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0 28 | else "cpu") 29 | print(self.device) 30 | 31 | # define some other vars to record the training states 32 | self.running_metric = ConfuseMatrixMeter(n_class=self.n_class) 33 | 34 | # define logger file 35 | logger_path = os.path.join(args.checkpoint_dir, 'log_test.txt') 36 | self.logger = Logger(logger_path) 37 | self.logger.write_dict_str(args.__dict__) 38 | 39 | 40 | # training log 41 | self.epoch_acc = 0 42 | self.best_val_acc = 0.0 43 | self.best_epoch_id = 0 44 | 45 | self.steps_per_epoch = len(dataloader) 46 | 47 | self.G_pred = None 48 | self.pred_vis = None 49 | self.batch = None 50 | self.is_training = False 51 | self.batch_id = 0 52 | self.epoch_id = 0 53 | self.checkpoint_dir = args.checkpoint_dir 54 | self.vis_dir = args.vis_dir 55 | 56 | # check and create model dir 57 | if os.path.exists(self.checkpoint_dir) is False: 58 | os.mkdir(self.checkpoint_dir) 59 | if os.path.exists(self.vis_dir) is False: 60 | os.mkdir(self.vis_dir) 61 | 62 | 63 | def _load_checkpoint(self, checkpoint_name='best_ckpt.pt'): 64 | 65 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)): 66 | self.logger.write('loading last checkpoint...\n') 67 | # load the entire checkpoint 68 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), map_location=self.device) 69 | 70 | self.net_G.load_state_dict(checkpoint['model_G_state_dict']) 71 | 72 | self.net_G.to(self.device) 73 | 74 | # update some other states 75 | self.best_val_acc = checkpoint['best_val_acc'] 76 | self.best_epoch_id = checkpoint['best_epoch_id'] 77 | 78 | self.logger.write('Eval Historical_best_acc = %.4f (at epoch %d)\n' % 79 | (self.best_val_acc, self.best_epoch_id)) 80 | self.logger.write('\n') 81 | 82 | else: 83 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name) 84 | 85 | 86 | def _visualize_pred(self): 87 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True) 88 | pred_vis = pred * 255 89 | return pred_vis 90 | 91 | 92 | def _update_metric(self): 93 | """ 94 | update metric 95 | """ 96 | target = self.batch['L'].to(self.device).detach() 97 | G_pred = self.G_pred.detach() 98 | G_pred = torch.argmax(G_pred, dim=1) 99 | 100 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy()) 101 | return current_score 102 | 103 | def _collect_running_batch_states(self): 104 | 105 | running_acc = self._update_metric() 106 | 107 | m = len(self.dataloader) 108 | 109 | if np.mod(self.batch_id, 1) == 1: 110 | message = 'Is_training: %s. 
[%d,%d], running_mf1: %.5f\n' %\ 111 | (self.is_training, self.batch_id, m, running_acc) 112 | self.logger.write(message) 113 | 114 | if np.mod(self.batch_id, 1) == 1: 115 | vis_input = utils.make_numpy_grid(de_norm(self.batch['A'])) 116 | vis_input2 = utils.make_numpy_grid(de_norm(self.batch['B'])) 117 | 118 | vis_pred = utils.make_numpy_grid(self._visualize_pred()) 119 | 120 | vis_gt = utils.make_numpy_grid(self.batch['L']) 121 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0) 122 | vis = np.clip(vis, a_min=0.0, a_max=1.0) 123 | file_name = os.path.join( 124 | self.vis_dir, 'eval_' + str(self.batch_id)+'.jpg') 125 | plt.imsave(file_name, vis) 126 | 127 | 128 | def _collect_epoch_states(self): 129 | 130 | scores_dict = self.running_metric.get_scores() 131 | 132 | np.save(os.path.join(self.checkpoint_dir, 'scores_dict.npy'), scores_dict) 133 | 134 | self.epoch_acc = scores_dict['mf1'] 135 | 136 | with open(os.path.join(self.checkpoint_dir, '%s.txt' % (self.epoch_acc)), 137 | mode='a') as file: 138 | pass 139 | 140 | message = '' 141 | for k, v in scores_dict.items(): 142 | message += '%s: %.5f ' % (k, v) 143 | self.logger.write('%s\n' % message) # save the message 144 | 145 | self.logger.write('\n') 146 | 147 | def _clear_cache(self): 148 | self.running_metric.clear() 149 | 150 | def _forward_pass(self, batch): 151 | self.batch = batch 152 | img_in1 = batch['A'].to(self.device) 153 | img_in2 = batch['B'].to(self.device) 154 | self.G_pred = self.net_G(img_in1, img_in2)[-1] 155 | 156 | def eval_models(self,checkpoint_name='best_ckpt.pt'): 157 | 158 | self._load_checkpoint(checkpoint_name) 159 | 160 | ################## Eval ################## 161 | ########################################## 162 | self.logger.write('Begin evaluation...\n') 163 | self._clear_cache() 164 | self.is_training = False 165 | self.net_G.eval() 166 | 167 | # Iterate over data. 
168 | for self.batch_id, batch in enumerate(self.dataloader, 0): 169 | with torch.no_grad(): 170 | self._forward_pass(batch) 171 | self._collect_running_batch_states() 172 | self._collect_epoch_states() 173 | -------------------------------------------------------------------------------- /STADE-CDNet/models/help_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | class TwoLayerConv2d(nn.Sequential): 8 | def __init__(self, in_channels, out_channels, kernel_size=3): 9 | super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, 10 | padding=kernel_size // 2, stride=1, bias=False), 11 | nn.BatchNorm2d(in_channels), 12 | nn.ReLU(), 13 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 14 | padding=kernel_size // 2, stride=1) 15 | ) 16 | 17 | 18 | class Residual(nn.Module): 19 | def __init__(self, fn): 20 | super().__init__() 21 | self.fn = fn 22 | def forward(self, x, **kwargs): 23 | return self.fn(x, **kwargs) + x 24 | 25 | 26 | class Residual2(nn.Module): 27 | def __init__(self, fn): 28 | super().__init__() 29 | self.fn = fn 30 | def forward(self, x, x2, **kwargs): 31 | return self.fn(x, x2, **kwargs) + x 32 | 33 | 34 | class PreNorm(nn.Module): 35 | def __init__(self, dim, fn): 36 | super().__init__() 37 | self.norm = nn.LayerNorm(dim) 38 | self.fn = fn 39 | def forward(self, x, **kwargs): 40 | return self.fn(self.norm(x), **kwargs) 41 | 42 | 43 | class PreNorm2(nn.Module): 44 | def __init__(self, dim, fn): 45 | super().__init__() 46 | self.norm = nn.LayerNorm(dim) 47 | self.fn = fn 48 | def forward(self, x, x2, **kwargs): 49 | return self.fn(self.norm(x), self.norm(x2), **kwargs) 50 | 51 | 52 | class FeedForward(nn.Module): 53 | def __init__(self, dim, hidden_dim, dropout = 0.): 54 | super().__init__() 55 | self.net = nn.Sequential( 56 | nn.Linear(dim, hidden_dim), 57 | nn.GELU(), 58 | nn.Dropout(dropout), 59 | nn.Linear(hidden_dim, dim), 60 | nn.Dropout(dropout) 61 | ) 62 | def forward(self, x): 63 | return self.net(x) 64 | 65 | 66 | class Cross_Attention(nn.Module): 67 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True): 68 | super().__init__() 69 | inner_dim = dim_head * heads 70 | self.heads = heads 71 | self.scale = dim ** -0.5 72 | 73 | self.softmax = softmax 74 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 75 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 76 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 77 | 78 | self.to_out = nn.Sequential( 79 | nn.Linear(inner_dim, dim), 80 | nn.Dropout(dropout) 81 | ) 82 | 83 | def forward(self, x, m, mask = None): 84 | 85 | b, n, _, h = *x.shape, self.heads 86 | q = self.to_q(x) 87 | k = self.to_k(m) 88 | v = self.to_v(m) 89 | 90 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) 91 | 92 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 93 | mask_value = -torch.finfo(dots.dtype).max 94 | 95 | if mask is not None: 96 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 97 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 98 | mask = mask[:, None, :] * mask[:, :, None] 99 | dots.masked_fill_(~mask, mask_value) 100 | del mask 101 | 102 | if self.softmax: 103 | attn = dots.softmax(dim=-1) 104 | else: 105 | attn = dots 106 | # attn = dots 107 | # vis_tmp(dots) 108 | 109 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 110 | out = rearrange(out, 'b h n 
d -> b n (h d)') 111 | out = self.to_out(out) 112 | # vis_tmp2(out) 113 | 114 | return out 115 | 116 | 117 | class Attention(nn.Module): 118 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 119 | super().__init__() 120 | inner_dim = dim_head * heads 121 | self.heads = heads 122 | self.scale = dim ** -0.5 123 | 124 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 125 | self.to_out = nn.Sequential( 126 | nn.Linear(inner_dim, dim), 127 | nn.Dropout(dropout) 128 | ) 129 | 130 | def forward(self, x, mask = None): 131 | b, n, _, h = *x.shape, self.heads 132 | qkv = self.to_qkv(x).chunk(3, dim = -1) 133 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 134 | 135 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 136 | mask_value = -torch.finfo(dots.dtype).max 137 | 138 | if mask is not None: 139 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 140 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 141 | mask = mask[:, None, :] * mask[:, :, None] 142 | dots.masked_fill_(~mask, mask_value) 143 | del mask 144 | 145 | attn = dots.softmax(dim=-1) 146 | 147 | 148 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 149 | out = rearrange(out, 'b h n d -> b n (h d)') 150 | out = self.to_out(out) 151 | return out 152 | 153 | 154 | class Transformer(nn.Module): 155 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): 156 | super().__init__() 157 | self.layers = nn.ModuleList([]) 158 | for _ in range(depth): 159 | self.layers.append(nn.ModuleList([ 160 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), 161 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 162 | ])) 163 | def forward(self, x, mask = None): 164 | for attn, ff in self.layers: 165 | x = attn(x, mask = mask) 166 | x = ff(x) 167 | return x 168 | 169 | 170 | class TransformerDecoder(nn.Module): 171 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): 172 | super().__init__() 173 | self.layers = nn.ModuleList([]) 174 | for _ in range(depth): 175 | self.layers.append(nn.ModuleList([ 176 | Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, 177 | dim_head = dim_head, dropout = dropout, 178 | softmax=softmax))), 179 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 180 | ])) 181 | def forward(self, x, m, mask = None): 182 | """target(query), memory""" 183 | for attn, ff in self.layers: 184 | x = attn(x, m, mask = mask) 185 | x = ff(x) 186 | return x 187 | 188 | from scipy.io import savemat 189 | def save_to_mat(x1, x2, fx1, fx2, cp, file_name): 190 | #Save to mat files 191 | x1_np = x1.detach().cpu().numpy() 192 | x2_np = x2.detach().cpu().numpy() 193 | 194 | fx1_0_np = fx1[0].detach().cpu().numpy() 195 | fx2_0_np = fx2[0].detach().cpu().numpy() 196 | fx1_1_np = fx1[1].detach().cpu().numpy() 197 | fx2_1_np = fx2[1].detach().cpu().numpy() 198 | fx1_2_np = fx1[2].detach().cpu().numpy() 199 | fx2_2_np = fx2[2].detach().cpu().numpy() 200 | fx1_3_np = fx1[3].detach().cpu().numpy() 201 | fx2_3_np = fx2[3].detach().cpu().numpy() 202 | fx1_4_np = fx1[4].detach().cpu().numpy() 203 | fx2_4_np = fx2[4].detach().cpu().numpy() 204 | 205 | cp_np = cp[-1].detach().cpu().numpy() 206 | 207 | mdic = {'x1': x1_np, 'x2': x2_np, 208 | 'fx1_0': fx1_0_np, 'fx1_1': fx1_1_np, 'fx1_2': fx1_2_np, 'fx1_3': fx1_3_np, 'fx1_4': fx1_4_np, 209 | 'fx2_0': fx2_0_np, 'fx2_1': fx2_1_np, 'fx2_2': fx2_2_np, 'fx2_3': fx2_3_np, 'fx2_4': fx2_4_np, 210 | 
"final_pred": cp_np} 211 | 212 | savemat("/media/lidan/ssd2/ChangeFormer/vis/mat/"+file_name+".mat", mdic) 213 | 214 | 215 | 216 | -------------------------------------------------------------------------------- /STADE-CDNet/models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | def cross_entropy(input, target, weight=None, reduction='mean',ignore_index=255): 7 | """ 8 | logSoftmax_with_loss 9 | :param input: torch.Tensor, N*C*H*W 10 | :param target: torch.Tensor, N*1*H*W,/ N*H*W 11 | :param weight: torch.Tensor, C 12 | :return: torch.Tensor [0] 13 | """ 14 | target = target.long() 15 | if target.dim() == 4: 16 | target = torch.squeeze(target, dim=1) 17 | if input.shape[-1] != target.shape[-1]: 18 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True) 19 | 20 | return F.cross_entropy(input=input, target=target, weight=weight, 21 | ignore_index=ignore_index, reduction=reduction) 22 | 23 | #Focal Loss 24 | def get_alpha(supervised_loader): 25 | # get number of classes 26 | num_labels = 0 27 | for batch in supervised_loader: 28 | label_batch = batch['L'] 29 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background 30 | l_unique = torch.unique(label_batch.data) 31 | list_unique = [element.item() for element in l_unique.flatten()] 32 | num_labels = max(max(list_unique),num_labels) 33 | num_classes = num_labels + 1 34 | # count class occurrences 35 | alpha = [0 for i in range(num_classes)] 36 | for batch in supervised_loader: 37 | label_batch = batch['L'] 38 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background 39 | l_unique = torch.unique(label_batch.data) 40 | list_unique = [element.item() for element in l_unique.flatten()] 41 | l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480]) 42 | list_count = [count.item() for count in l_unique_count.flatten()] 43 | for index in list_unique: 44 | alpha[index] += list_count[list_unique.index(index)] 45 | return alpha 46 | 47 | # for FocalLoss 48 | def softmax_helper(x): 49 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py 50 | rpt = [1 for _ in range(len(x.size()))] 51 | rpt[1] = x.size(1) 52 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt) 53 | e_x = torch.exp(x - x_max) 54 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt) 55 | 56 | class FocalLoss(nn.Module): 57 | """ 58 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 59 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 60 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 61 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 62 | :param num_class: 63 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 64 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 65 | focus on hard misclassified example 66 | :param smooth: (float,double) smooth value when cross entropy 67 | :param balance_index: (int) balance class index, should be specific when alpha is float 68 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 
69 | """ 70 | 71 | def __init__(self, apply_nonlin=None, alpha=None, gamma=1, balance_index=0, smooth=1e-5, size_average=True): 72 | super(FocalLoss, self).__init__() 73 | self.apply_nonlin = apply_nonlin 74 | self.alpha = alpha 75 | self.gamma = gamma 76 | self.balance_index = balance_index 77 | self.smooth = smooth 78 | self.size_average = size_average 79 | 80 | if self.smooth is not None: 81 | if self.smooth < 0 or self.smooth > 1.0: 82 | raise ValueError('smooth value should be in [0,1]') 83 | 84 | def forward(self, logit, target): 85 | if self.apply_nonlin is not None: 86 | logit = self.apply_nonlin(logit) 87 | num_class = logit.shape[1] 88 | 89 | if logit.dim() > 2: 90 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 91 | logit = logit.view(logit.size(0), logit.size(1), -1) 92 | logit = logit.permute(0, 2, 1).contiguous() 93 | logit = logit.view(-1, logit.size(-1)) 94 | target = torch.squeeze(target, 1) 95 | target = target.view(-1, 1) 96 | 97 | alpha = self.alpha 98 | 99 | if alpha is None: 100 | alpha = torch.ones(num_class, 1) 101 | elif isinstance(alpha, (list, np.ndarray)): 102 | assert len(alpha) == num_class 103 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 104 | alpha = alpha / alpha.sum() 105 | alpha = 1/alpha # inverse of class frequency 106 | elif isinstance(alpha, float): 107 | alpha = torch.ones(num_class, 1) 108 | alpha = alpha * (1 - self.alpha) 109 | alpha[self.balance_index] = self.alpha 110 | 111 | else: 112 | raise TypeError('Not support alpha type') 113 | 114 | if alpha.device != logit.device: 115 | alpha = alpha.to(logit.device) 116 | 117 | idx = target.cpu().long() 118 | 119 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 120 | 121 | # to resolve error in idx in scatter_ 122 | idx[idx==225]=0 123 | 124 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 125 | if one_hot_key.device != logit.device: 126 | one_hot_key = one_hot_key.to(logit.device) 127 | 128 | if self.smooth: 129 | one_hot_key = torch.clamp( 130 | one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth) 131 | pt = (one_hot_key * logit).sum(1) + self.smooth 132 | logpt = pt.log() 133 | 134 | gamma = self.gamma 135 | 136 | alpha = alpha[idx] 137 | alpha = torch.squeeze(alpha) 138 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 139 | 140 | if self.size_average: 141 | loss = loss.mean() 142 | else: 143 | loss = loss.sum() 144 | return loss 145 | 146 | 147 | #miou loss 148 | from torch.autograd import Variable 149 | def to_one_hot_var(tensor, nClasses, requires_grad=False): 150 | 151 | n, h, w = torch.squeeze(tensor, dim=1).size() 152 | one_hot = tensor.new(n, nClasses, h, w).fill_(0) 153 | one_hot = one_hot.scatter_(1, tensor.type(torch.int64).view(n, 1, h, w), 1) 154 | return Variable(one_hot, requires_grad=requires_grad) 155 | 156 | class mIoULoss(nn.Module): 157 | def __init__(self, weight=None, size_average=True, n_classes=2): 158 | super(mIoULoss, self).__init__() 159 | self.classes = n_classes 160 | self.weights = Variable(weight) 161 | 162 | def forward(self, inputs, target, is_target_variable=False): 163 | # inputs => N x Classes x H x W 164 | # target => N x H x W 165 | # target_oneHot => N x Classes x H x W 166 | 167 | N = inputs.size()[0] 168 | if is_target_variable: 169 | target_oneHot = to_one_hot_var(target.data, self.classes).float() 170 | else: 171 | target_oneHot = to_one_hot_var(target, self.classes).float() 172 | 173 | # predicted probabilities for each pixel along channel 174 | inputs = F.softmax(inputs, dim=1) 175 | 176 | # Numerator Product 177 | inter = 
inputs * target_oneHot 178 | ## Sum over all pixels N x C x H x W => N x C 179 | inter = inter.view(N, self.classes, -1).sum(2) 180 | 181 | # Denominator 182 | union = inputs + target_oneHot - (inputs * target_oneHot) 183 | ## Sum over all pixels N x C x H x W => N x C 184 | union = union.view(N, self.classes, -1).sum(2) 185 | 186 | loss = (self.weights * inter) / (union + 1e-8) 187 | 188 | ## Return average loss over classes and batch 189 | return -torch.mean(loss) 190 | 191 | #Minimax iou 192 | class mmIoULoss(nn.Module): 193 | def __init__(self, n_classes=2): 194 | super(mmIoULoss, self).__init__() 195 | self.classes = n_classes 196 | 197 | def forward(self, inputs, target, is_target_variable=False): 198 | # inputs => N x Classes x H x W 199 | # target => N x H x W 200 | # target_oneHot => N x Classes x H x W 201 | 202 | N = inputs.size()[0] 203 | if is_target_variable: 204 | target_oneHot = to_one_hot_var(target.data, self.classes).float() 205 | else: 206 | target_oneHot = to_one_hot_var(target, self.classes).float() 207 | 208 | # predicted probabilities for each pixel along channel 209 | inputs = F.softmax(inputs, dim=1) 210 | 211 | # Numerator Product 212 | inter = inputs * target_oneHot 213 | ## Sum over all pixels N x C x H x W => N x C 214 | inter = inter.view(N, self.classes, -1).sum(2) 215 | 216 | # Denominator 217 | union = inputs + target_oneHot - (inputs * target_oneHot) 218 | ## Sum over all pixels N x C x H x W => N x C 219 | union = union.view(N, self.classes, -1).sum(2) 220 | 221 | iou = inter/ (union + 1e-8) 222 | 223 | #minimum iou of two classes 224 | min_iou = torch.min(iou) 225 | 226 | #loss 227 | loss = -min_iou-torch.mean(iou) 228 | return loss 229 | -------------------------------------------------------------------------------- /STADE-CDNet/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from torch.optim import lr_scheduler 6 | import numpy as np 7 | import functools 8 | from einops import rearrange 9 | import cv2 10 | import models 11 | from models.help_funcs import Transformer, TransformerDecoder, TwoLayerConv2d 12 | from models.ChangeFormer import ChangeFormerV1, ChangeFormerV2, ChangeFormerV3, ChangeFormerV4, ChangeFormerV5, ChangeFormerV6 13 | from models.SiamUnet_diff import SiamUnet_diff 14 | from models.SiamUnet_conc import SiamUnet_conc 15 | from models.Unet import Unet 16 | from models.DTCDSCN import CDNet34 17 | device = "cuda" if torch.cuda.is_available() else "cpu" 18 | 19 | ############################################################################### 20 | # Helper Functions 21 | ############################################################################### 22 | 23 | def get_scheduler(optimizer, args): 24 | """Return a learning rate scheduler 25 | 26 | Parameters: 27 | optimizer -- the optimizer of the network 28 | args (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  29 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 30 | 31 | For 'linear', we keep the same learning rate for the first epochs 32 | and linearly decay the rate to zero over the next epochs. 33 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 34 | See https://pytorch.org/docs/stable/optim.html for more details. 
35 | """ 36 | if args.lr_policy == 'linear': 37 | def lambda_rule(epoch): 38 | lr_l = 1- epoch/ float(args.max_epochs + 1) 39 | return lr_l 40 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 41 | elif args.lr_policy == 'step': 42 | step_size = args.max_epochs//3 43 | # args.lr_decay_iters 44 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1) 45 | else: 46 | return NotImplementedError('learning rate policy [%s] is not implemented', args.lr_policy) 47 | return scheduler 48 | 49 | 50 | class Identity(nn.Module): 51 | def forward(self, x): 52 | return x 53 | 54 | 55 | def get_norm_layer(norm_type='instance'): 56 | """Return a normalization layer 57 | 58 | Parameters: 59 | norm_type (str) -- the name of the normalization layer: batch | instance | none 60 | 61 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 62 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 63 | """ 64 | if norm_type == 'batch': 65 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 66 | elif norm_type == 'instance': 67 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 68 | elif norm_type == 'none': 69 | norm_layer = lambda x: Identity() 70 | else: 71 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 72 | return norm_layer 73 | 74 | 75 | def init_weights(net, init_type='normal', init_gain=0.02): 76 | """Initialize network weights. 77 | 78 | Parameters: 79 | net (network) -- network to be initialized 80 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 81 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 82 | 83 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 84 | work better for some applications. Feel free to try yourself. 85 | """ 86 | def init_func(m): # define the initialization function 87 | classname = m.__class__.__name__ 88 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 89 | if init_type == 'normal': 90 | init.normal_(m.weight.data, 0.0, init_gain) 91 | elif init_type == 'xavier': 92 | init.xavier_normal_(m.weight.data, gain=init_gain) 93 | elif init_type == 'kaiming': 94 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 95 | elif init_type == 'orthogonal': 96 | init.orthogonal_(m.weight.data, gain=init_gain) 97 | else: 98 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 99 | if hasattr(m, 'bias') and m.bias is not None: 100 | init.constant_(m.bias.data, 0.0) 101 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 102 | init.normal_(m.weight.data, 1.0, init_gain) 103 | init.constant_(m.bias.data, 0.0) 104 | 105 | print('initialize network with %s' % init_type) 106 | net.apply(init_func) # apply the initialization function 107 | 108 | 109 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 110 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 111 | Parameters: 112 | net (network) -- the network to be initialized 113 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 114 | gain (float) -- scaling factor for normal, xavier and orthogonal. 
115 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 116 | 117 | Return an initialized network. 118 | """ 119 | if len(gpu_ids) > 0: 120 | assert(torch.cuda.is_available()) 121 | net.to(gpu_ids[0]) 122 | if len(gpu_ids) > 1: 123 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 124 | init_weights(net, init_type, init_gain=init_gain) 125 | return net 126 | 127 | 128 | def define_G(args, init_type='normal', init_gain=0.02, gpu_ids=[]): 129 | if args.net_G == 'base_resnet18': 130 | net = ResNet(input_nc=3, output_nc=2, output_sigmoid=False) 131 | 132 | elif args.net_G == 'base_transformer_pos_s4': 133 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, 134 | with_pos='learned') 135 | 136 | elif args.net_G == 'base_transformer_pos_s4_dd8': 137 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, 138 | with_pos='learned', enc_depth=1, dec_depth=8) 139 | 140 | elif args.net_G == 'base_transformer_pos_s4_dd8_dedim8': 141 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, 142 | with_pos='learned', enc_depth=1, dec_depth=8, decoder_dim_head=8) 143 | 144 | elif args.net_G == 'ChangeFormerV1': 145 | net = ChangeFormerV1() #ChangeFormer with Transformer Encoder and Convolutional Decoder 146 | 147 | elif args.net_G == 'ChangeFormerV2': 148 | net = ChangeFormerV2() #ChangeFormer with Transformer Encoder and Convolutional Decoder 149 | 150 | elif args.net_G == 'ChangeFormerV3': 151 | net = ChangeFormerV3() #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse) 152 | 153 | elif args.net_G == 'ChangeFormerV4': 154 | net = ChangeFormerV4() #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse) 155 | 156 | elif args.net_G == 'ChangeFormerV5': 157 | net = ChangeFormerV5(embed_dim=args.embed_dim) #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse) 158 | 159 | elif args.net_G == 'ChangeFormerV6': 160 | net = ChangeFormerV6(embed_dim=args.embed_dim) #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse) 161 | 162 | elif args.net_G == "SiamUnet_diff": 163 | #Implementation of ``Fully convolutional siamese networks for change detection'' 164 | #Code copied from: https://github.com/rcdaudt/fully_convolutional_change_detection 165 | net = SiamUnet_diff(input_nbr=3, label_nbr=2) 166 | 167 | elif args.net_G == "SiamUnet_conc": 168 | #Implementation of ``Fully convolutional siamese networks for change detection'' 169 | #Code copied from: https://github.com/rcdaudt/fully_convolutional_change_detection 170 | net = SiamUnet_conc(input_nbr=3, label_nbr=2) 171 | 172 | elif args.net_G == "Unet": 173 | #Usually abbreviated as FC-EF = Image Level Concatenation 174 | #Implementation of ``Fully convolutional siamese networks for change detection'' 175 | #Code copied from: https://github.com/rcdaudt/fully_convolutional_change_detection 176 | net = Unet(input_nbr=3, label_nbr=2) 177 | 178 | elif args.net_G == "DTCDSCN": 179 | #The implementation of the paper"Building Change Detection for Remote Sensing Images Using a Dual Task Constrained Deep Siamese Convolutional Network Model " 180 | #Code copied from: https://github.com/fitzpchao/DTCDSCN 181 | net = CDNet34(in_channels=3) 182 | 183 | else: 184 | raise NotImplementedError('Generator model name [%s] is not recognized' % args.net_G) 185 | return init_net(net, init_type, init_gain, gpu_ids) 186 | 187 | 188 | ############################################################################### 189 
| # main Functions 190 | ############################################################################### 191 | 192 | 193 | class ResNet(torch.nn.Module): 194 | def __init__(self, input_nc, output_nc, 195 | resnet_stages_num=5, backbone='resnet18', 196 | output_sigmoid=False, if_upsample_2x=True): 197 | """ 198 | In the constructor we instantiate two nn.Linear modules and assign them as 199 | member variables. 200 | """ 201 | super(ResNet, self).__init__() 202 | expand = 1 203 | if backbone == 'resnet18': 204 | self.resnet = models.resnet18(pretrained=True, 205 | replace_stride_with_dilation=[False,True,True]) 206 | elif backbone == 'resnet34': 207 | self.resnet = models.resnet34(pretrained=True, 208 | replace_stride_with_dilation=[False,True,True]) 209 | elif backbone == 'resnet50': 210 | self.resnet = models.resnet50(pretrained=True, 211 | replace_stride_with_dilation=[False,True,True]) 212 | expand = 4 213 | else: 214 | raise NotImplementedError 215 | self.relu = nn.ReLU() 216 | self.upsamplex2 = nn.Upsample(scale_factor=2) 217 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') 218 | 219 | self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc) 220 | 221 | self.resnet_stages_num = resnet_stages_num 222 | 223 | self.if_upsample_2x = if_upsample_2x 224 | if self.resnet_stages_num == 5: 225 | layers = 512 * expand 226 | elif self.resnet_stages_num == 4: 227 | layers = 256 * expand 228 | elif self.resnet_stages_num == 3: 229 | layers = 128 * expand 230 | else: 231 | raise NotImplementedError 232 | self.conv_pred = nn.Conv2d(layers, 32, kernel_size=3, padding=1) 233 | 234 | self.output_sigmoid = output_sigmoid 235 | self.sigmoid = nn.Sigmoid() 236 | 237 | def forward(self, x1, x2): 238 | x1 = self.forward_single(x1) 239 | x2 = self.forward_single(x2) 240 | x = torch.abs(x1 - x2) 241 | if not self.if_upsample_2x: 242 | x = self.upsamplex2(x) 243 | x = self.upsamplex4(x) 244 | x = self.classifier(x) 245 | 246 | if self.output_sigmoid: 247 | x = self.sigmoid(x) 248 | return x 249 | 250 | def forward_single(self, x): 251 | # resnet layers 252 | x = self.resnet.conv1(x) 253 | x = self.resnet.bn1(x) 254 | x = self.resnet.relu(x) 255 | x = self.resnet.maxpool(x) 256 | 257 | x_4 = self.resnet.layer1(x) # 1/4, in=64, out=64 258 | x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128 259 | 260 | if self.resnet_stages_num > 3: 261 | x_8 = self.resnet.layer3(x_8) # 1/8, in=128, out=256 262 | 263 | if self.resnet_stages_num == 5: 264 | x_8 = self.resnet.layer4(x_8) # 1/32, in=256, out=512 265 | elif self.resnet_stages_num > 5: 266 | raise NotImplementedError 267 | 268 | if self.if_upsample_2x: 269 | x = self.upsamplex2(x_8) 270 | else: 271 | x = x_8 272 | # output layers 273 | x = self.conv_pred(x) 274 | return x 275 | 276 | class Rnn(nn.Module): 277 | def __init__(self, in_dim, hidden_dim, n_layer, n_class): 278 | super(Rnn, self).__init__() 279 | self.n_layer = n_layer 280 | self.hidden_dim = hidden_dim 281 | self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, 282 | batch_first=True) 283 | self.classifier = nn.Linear(hidden_dim, n_class) 284 | 285 | def forward(self, x): 286 | # h0 = Variable(torch.zeros(self.n_layer, x.size(1), 287 | # self.hidden_dim)).cuda() 288 | # c0 = Variable(torch.zeros(self.n_layer, x.size(1), 289 | # self.hidden_dim)).cuda() 290 | x = x.to(device) 291 | out, _ = self.lstm(x) 292 | 293 | return out 294 | 295 | class BASE_Transformer(ResNet): 296 | """ 297 | Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN 298 | """ 299 | 
def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5, 300 | token_len=4, token_trans=True, 301 | enc_depth=1, dec_depth=1, 302 | dim_head=64, decoder_dim_head=64, 303 | tokenizer=True, if_upsample_2x=True, 304 | pool_mode='max', pool_size=2, 305 | backbone='resnet18', 306 | decoder_softmax=True, with_decoder_pos=None, 307 | with_decoder=True): 308 | super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone, 309 | resnet_stages_num=resnet_stages_num, 310 | if_upsample_2x=if_upsample_2x, 311 | ) 312 | self.token_len = token_len 313 | self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1, 314 | padding=0, bias=False) 315 | self.tokenizer = tokenizer 316 | if not self.tokenizer: 317 | # if not use tokenzier,then downsample the feature map into a certain size 318 | self.pooling_size = pool_size 319 | self.pool_mode = pool_mode 320 | self.token_len = self.pooling_size * self.pooling_size 321 | 322 | self.token_trans = token_trans 323 | self.with_decoder = with_decoder 324 | dim = 32 325 | mlp_dim = 2*dim 326 | 327 | self.with_pos = with_pos 328 | if with_pos == 'learned': 329 | self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32)) 330 | decoder_pos_size = 256//4 331 | self.with_decoder_pos = with_decoder_pos 332 | if self.with_decoder_pos == 'learned': 333 | self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32, 334 | decoder_pos_size, 335 | decoder_pos_size)) 336 | self.enc_depth = enc_depth 337 | self.dec_depth = dec_depth 338 | self.dim_head = dim_head 339 | self.decoder_dim_head = decoder_dim_head 340 | self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8, 341 | dim_head=self.dim_head, 342 | mlp_dim=mlp_dim, dropout=0) 343 | self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth, 344 | heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0, 345 | softmax=decoder_softmax) 346 | 347 | def _forward_semantic_tokens(self, x): 348 | b, c, h, w = x.shape 349 | spatial_attention = self.conv_a(x) 350 | spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() 351 | spatial_attention = torch.softmax(spatial_attention, dim=-1) 352 | x = x.view([b, c, -1]).contiguous() 353 | tokens = torch.einsum('bln,bcn->blc', spatial_attention, x) 354 | 355 | return tokens 356 | 357 | def _forward_reshape_tokens(self, x): 358 | # b,c,h,w = x.shape 359 | if self.pool_mode == 'max': 360 | x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size]) 361 | elif self.pool_mode == 'ave': 362 | x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size]) 363 | else: 364 | x = x 365 | tokens = rearrange(x, 'b c h w -> b (h w) c') 366 | return tokens 367 | 368 | def _forward_transformer(self, x): 369 | if self.with_pos: 370 | x += self.pos_embedding 371 | x = self.transformer(x) 372 | return x 373 | 374 | def _forward_transformer_decoder(self, x, m): 375 | b, c, h, w = x.shape 376 | if self.with_decoder_pos == 'fix': 377 | x = x + self.pos_embedding_decoder 378 | elif self.with_decoder_pos == 'learned': 379 | x = x + self.pos_embedding_decoder 380 | x = rearrange(x, 'b c h w -> b (h w) c') 381 | x = self.transformer_decoder(x, m) 382 | x = rearrange(x, 'b (h w) c -> b c h w', h=h) 383 | return x 384 | 385 | def _forward_simple_decoder(self, x, m): 386 | b, c, h, w = x.shape 387 | b, l, c = m.shape 388 | m = m.expand([h,w,b,l,c]) 389 | m = rearrange(m, 'h w b l c -> l b c h w') 390 | m = m.sum(0) 391 | x = x + m 392 | return x 393 | def zl_difference(self, x): 394 | # # 
#zl_CDDM############### 395 | 396 | ##1 step 397 | avg_pooling = torch.nn.AdaptiveMaxPool2d((256, 256)) 398 | _c1_21 = avg_pooling(x) 399 | _c1_21 = rearrange(_c1_21, 'b c h w -> b c (h w)') 400 | ##1.1 step 401 | conv = nn.Conv1d(_c1_21.shape[1], x.shape[2], kernel_size=3, padding=1, groups=1, stride=1, bias=False).cuda() 402 | _c1_22=conv(_c1_21) 403 | nn.ReLU(), 404 | ##1.2 step 405 | conv1_LZ = nn.Conv1d(x.shape[2], x.shape[1], kernel_size=3, padding=1, groups=1, stride=1, bias=False).cuda() 406 | _c1_24=conv1_LZ(_c1_22) 407 | _c1_24=_c1_24.unsqueeze(0) 408 | _c1_24=np.transpose(_c1_24,(1,2,3,0)).cuda() 409 | nn.Sigmoid(), 410 | _c1_24=_c1_24.squeeze(3) 411 | _c1_24 = x.reshape(x.shape[0],x.shape[1],x.shape[2],x.shape[3],) 412 | #M_1J 413 | _c1_25_M_1J=_c1_24*x 414 | #2 step 415 | #2.1 step 416 | conv3_LZ = nn.Conv3d( _c1_25_M_1J.shape[0],_c1_25_M_1J.shape[0], kernel_size=3, padding=1, groups=1, stride=1, bias=False).cuda() 417 | _c1_26_M_1J=conv3_LZ(_c1_25_M_1J) 418 | nn.Sigmoid(), 419 | #2.2 step 420 | conv3_LZ = nn.Conv3d(_c1_26_M_1J.shape[0], _c1_25_M_1J.shape[0], kernel_size=3, padding=1, groups=1, stride=1, bias=False).cuda() 421 | _c1_27_M_1J=conv3_LZ(_c1_26_M_1J) 422 | #M_2J 423 | _c1_28_M_2J=_c1_27_M_1J*_c1_25_M_1J 424 | #result 425 | _c1_29_M_2J =_c1_28_M_2J.squeeze(0) 426 | out=_c1_29_M_2J 427 | # ###CDDM############### 428 | 429 | return out 430 | def forward(self, x1, x2): 431 | # forward backbone resnet 432 | x_n0=abs(x1-x2) 433 | x1 = self.forward_single(x1) 434 | x2 = self.forward_single(x2) 435 | 436 | # forward tokenzier 437 | if self.tokenizer: 438 | token1 = self._forward_semantic_tokens(x1) 439 | token2 = self._forward_semantic_tokens(x2) 440 | else: 441 | token1 = self._forward_reshape_tokens(x1) 442 | token2 = self._forward_reshape_tokens(x2) 443 | # forward transformer encoder 444 | if self.token_trans: 445 | self.tokens_ = torch.cat([token1, token2], dim=1) 446 | self.tokens = self._forward_transformer(self.tokens_) 447 | token1, token2 = self.tokens.chunk(2, dim=1) 448 | # forward transformer decoder 449 | if self.with_decoder: 450 | x1 = self._forward_transformer_decoder(x1, token1) 451 | x2 = self._forward_transformer_decoder(x2, token2) 452 | else: 453 | x1 = self._forward_simple_decoder(x1, token1) 454 | x2 = self._forward_simple_decoder(x2, token2) 455 | x101 = abs(x1-x2) 456 | if not self.if_upsample_2x: 457 | x101 = self.upsamplex2(x101) 458 | x101 = self.upsamplex4(x101) 459 | # forward small cnn 460 | x101 = self.classifier(x101) 461 | 462 | 463 | x11=x1 464 | if not self.if_upsample_2x: 465 | x11 = self.upsamplex2(x11) 466 | x11 = self.upsamplex4(x11) 467 | # forward small cnn 468 | x110 = self.classifier(x11) 469 | 470 | x22=x2 471 | if not self.if_upsample_2x: 472 | x22 = self.upsamplex2(x22) 473 | x22 = self.upsamplex4(x22) 474 | # forward small cnn 475 | x22 = self.classifier(x22) 476 | 477 | #TMM 478 | model = Rnn(x_size, y_size, 2, 2, help="x_size and y_size are the pixel size your set, my pixel size[32, 32]") 479 | model = model.cuda() 480 | #Maxpooling 481 | maxPool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0, dilation=1,return_indices=False, ceil_mode=False) 482 | ###fusion 483 | 484 | 485 | x_add=x1+x2 486 | x_maxpool=maxPool( x_add) 487 | x_before = self.upsamplex2(x_maxpool) 488 | x_after = rearrange(x_before, 'b c h w -> b (h w) c') 489 | x_TMM_0 = model(x_after) 490 | x_TMM_1 = x_TMM_0 .reshape(x101.shape[0],x101.shape[1],x101.shape[2],x101.shape[3],) 491 | 492 | #CDDM 493 | x111=self.zl_difference(x110) 494 | 495 | 
x222=self.zl_difference(x22) 496 | 497 | 498 | x_CDDM=torch.abs( x111- x222 ) 499 | if not self.if_upsample_2x: 500 | x_CDDM = self.upsamplex2(x_CDDM) 501 | 502 | # forward small cnn 503 | 504 | x=x_CDDM*x101+x101+x_TMM_1 505 | 506 | if self.output_sigmoid: 507 | x = self.sigmoid(x) 508 | outputs = [] 509 | outputs.append(x) 510 | return outputs 511 | 512 | 513 | -------------------------------------------------------------------------------- /STADE-CDNet/models/pixel_shuffel_up.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | def icnr(x, scale=2, init=nn.init.kaiming_normal_): 7 | """ 8 | Checkerboard artifact free sub-pixel convolution 9 | https://arxiv.org/abs/1707.02937 10 | """ 11 | ni,nf,h,w = x.shape 12 | ni2 = int(ni/(scale**2)) 13 | k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1) 14 | k = k.contiguous().view(ni2, nf, -1) 15 | k = k.repeat(1, 1, scale**2) 16 | k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1) 17 | x.data.copy_(k) 18 | 19 | 20 | class PixelShuffle(nn.Module): 21 | """ 22 | Real-Time Single Image and Video Super-Resolution 23 | https://arxiv.org/abs/1609.05158 24 | """ 25 | def __init__(self, n_channels, scale): 26 | super(PixelShuffle, self).__init__() 27 | self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1) 28 | icnr(self.conv.weight) 29 | self.shuf = nn.PixelShuffle(scale) 30 | self.relu = nn.ReLU(inplace=True) 31 | 32 | def forward(self,x): 33 | x = self.shuf(self.relu(self.conv(x))) 34 | return x 35 | 36 | 37 | def upsample(in_channels, out_channels, upscale, kernel_size=3): 38 | # A series of x 2 upsamling until we get to the upscale we want 39 | layers = [] 40 | conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 41 | nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu') 42 | layers.append(conv1x1) 43 | for i in range(int(math.log(upscale, 2))): 44 | layers.append(PixelShuffle(out_channels, scale=2)) 45 | return nn.Sequential(*layers) 46 | 47 | 48 | class PS_UP(nn.Module): 49 | def __init__(self, upscale, conv_in_ch, num_classes): 50 | super(PS_UP, self).__init__() 51 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) 52 | 53 | def forward(self, x): 54 | x = self.upsample(x) 55 | return x -------------------------------------------------------------------------------- /STADE-CDNet/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | 5 | # from torchvision.models.utils import load_state_dict_from_url 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes, out_planes, stride=1): 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 41 | base_width=64, dilation=1, norm_layer=None): 42 | super(BasicBlock, self).__init__() 43 | if norm_layer is None: 44 | norm_layer = nn.BatchNorm2d 45 | if groups != 1 or base_width != 64: 46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 47 | if dilation > 1: 48 | dilation = 1 49 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 51 | self.conv1 = conv3x3(inplanes, planes, stride) 52 | self.bn1 = norm_layer(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes) 55 | self.bn2 = norm_layer(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | identity = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 80 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 81 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 82 | # This variant is also known as ResNet V1.5 and improves accuracy according to 83 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
84 | 85 | expansion = 4 86 | 87 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 88 | base_width=64, dilation=1, norm_layer=None): 89 | super(Bottleneck, self).__init__() 90 | if norm_layer is None: 91 | norm_layer = nn.BatchNorm2d 92 | width = int(planes * (base_width / 64.)) * groups 93 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 94 | self.conv1 = conv1x1(inplanes, width) 95 | self.bn1 = norm_layer(width) 96 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 97 | self.bn2 = norm_layer(width) 98 | self.conv3 = conv1x1(width, planes * self.expansion) 99 | self.bn3 = norm_layer(planes * self.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | identity = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | identity = self.downsample(x) 120 | 121 | out += identity 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 130 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 131 | norm_layer=None, strides=None): 132 | super(ResNet, self).__init__() 133 | if norm_layer is None: 134 | norm_layer = nn.BatchNorm2d 135 | self._norm_layer = norm_layer 136 | 137 | self.strides = strides 138 | if self.strides is None: 139 | self.strides = [2, 2, 2, 2, 2] 140 | 141 | self.inplanes = 64 142 | self.dilation = 1 143 | if replace_stride_with_dilation is None: 144 | # each element in the tuple indicates if we should replace 145 | # the 2x2 stride with a dilated convolution instead 146 | replace_stride_with_dilation = [False, False, False] 147 | if len(replace_stride_with_dilation) != 3: 148 | raise ValueError("replace_stride_with_dilation should be None " 149 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 150 | self.groups = groups 151 | self.base_width = width_per_group 152 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3, 153 | bias=False) 154 | self.bn1 = norm_layer(self.inplanes) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1) 157 | self.layer1 = self._make_layer(block, 64, layers[0]) 158 | self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[2], 159 | dilate=replace_stride_with_dilation[0]) 160 | self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[3], 161 | dilate=replace_stride_with_dilation[1]) 162 | self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[4], 163 | dilate=replace_stride_with_dilation[2]) 164 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 165 | self.fc = nn.Linear(512 * block.expansion, num_classes) 166 | 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 170 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 171 | nn.init.constant_(m.weight, 1) 172 | nn.init.constant_(m.bias, 0) 173 | 174 | # Zero-initialize the last BN in each residual branch, 175 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
176 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 177 | if zero_init_residual: 178 | for m in self.modules(): 179 | if isinstance(m, Bottleneck): 180 | nn.init.constant_(m.bn3.weight, 0) 181 | elif isinstance(m, BasicBlock): 182 | nn.init.constant_(m.bn2.weight, 0) 183 | 184 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 185 | norm_layer = self._norm_layer 186 | downsample = None 187 | previous_dilation = self.dilation 188 | if dilate: 189 | self.dilation *= stride 190 | stride = 1 191 | if stride != 1 or self.inplanes != planes * block.expansion: 192 | downsample = nn.Sequential( 193 | conv1x1(self.inplanes, planes * block.expansion, stride), 194 | norm_layer(planes * block.expansion), 195 | ) 196 | 197 | layers = [] 198 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 199 | self.base_width, previous_dilation, norm_layer)) 200 | self.inplanes = planes * block.expansion 201 | for _ in range(1, blocks): 202 | layers.append(block(self.inplanes, planes, groups=self.groups, 203 | base_width=self.base_width, dilation=self.dilation, 204 | norm_layer=norm_layer)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def _forward_impl(self, x): 209 | # See note [TorchScript super()] 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | x = self.maxpool(x) 214 | 215 | x = self.layer1(x) 216 | x = self.layer2(x) 217 | x = self.layer3(x) 218 | x = self.layer4(x) 219 | 220 | x = self.avgpool(x) 221 | x = torch.flatten(x, 1) 222 | x = self.fc(x) 223 | 224 | return x 225 | 226 | def forward(self, x): 227 | return self._forward_impl(x) 228 | 229 | 230 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 231 | model = ResNet(block, layers, **kwargs) 232 | if pretrained: 233 | state_dict = load_state_dict_from_url(model_urls[arch], 234 | progress=progress) 235 | model.load_state_dict(state_dict) 236 | return model 237 | 238 | 239 | def resnet18(pretrained=False, progress=True, **kwargs): 240 | r"""ResNet-18 model from 241 | `"Deep Residual Learning for Image Recognition" `_ 242 | 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet34(pretrained=False, progress=True, **kwargs): 252 | r"""ResNet-34 model from 253 | `"Deep Residual Learning for Image Recognition" `_ 254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | progress (bool): If True, displays a progress bar of the download to stderr 258 | """ 259 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 260 | **kwargs) 261 | 262 | 263 | def resnet50(pretrained=False, progress=True, **kwargs): 264 | r"""ResNet-50 model from 265 | `"Deep Residual Learning for Image Recognition" `_ 266 | 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | progress (bool): If True, displays a progress bar of the download to stderr 270 | """ 271 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 272 | **kwargs) 273 | 274 | 275 | def resnet101(pretrained=False, progress=True, **kwargs): 276 | r"""ResNet-101 model from 277 | `"Deep Residual Learning for Image Recognition" `_ 278 | 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | 
progress (bool): If True, displays a progress bar of the download to stderr 282 | """ 283 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def resnet152(pretrained=False, progress=True, **kwargs): 288 | r"""ResNet-152 model from 289 | `"Deep Residual Learning for Image Recognition" `_ 290 | 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 296 | **kwargs) 297 | 298 | 299 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 300 | r"""ResNeXt-50 32x4d model from 301 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 302 | 303 | Args: 304 | pretrained (bool): If True, returns a model pre-trained on ImageNet 305 | progress (bool): If True, displays a progress bar of the download to stderr 306 | """ 307 | kwargs['groups'] = 32 308 | kwargs['width_per_group'] = 4 309 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 310 | pretrained, progress, **kwargs) 311 | 312 | 313 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 314 | r"""ResNeXt-101 32x8d model from 315 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 316 | 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on ImageNet 319 | progress (bool): If True, displays a progress bar of the download to stderr 320 | """ 321 | kwargs['groups'] = 32 322 | kwargs['width_per_group'] = 8 323 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 324 | pretrained, progress, **kwargs) 325 | 326 | 327 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 328 | r"""Wide ResNet-50-2 model from 329 | `"Wide Residual Networks" `_ 330 | 331 | The model is the same as ResNet except for the bottleneck number of channels 332 | which is twice larger in every block. The number of channels in outer 1x1 333 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 334 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 335 | 336 | Args: 337 | pretrained (bool): If True, returns a model pre-trained on ImageNet 338 | progress (bool): If True, displays a progress bar of the download to stderr 339 | """ 340 | kwargs['width_per_group'] = 64 * 2 341 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 342 | pretrained, progress, **kwargs) 343 | 344 | 345 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 346 | r"""Wide ResNet-101-2 model from 347 | `"Wide Residual Networks" `_ 348 | 349 | The model is the same as ResNet except for the bottleneck number of channels 350 | which is twice larger in every block. The number of channels in outer 1x1 351 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 352 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
353 | 354 | Args: 355 | pretrained (bool): If True, returns a model pre-trained on ImageNet 356 | progress (bool): If True, displays a progress bar of the download to stderr 357 | """ 358 | kwargs['width_per_group'] = 64 * 2 359 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 360 | pretrained, progress, **kwargs) 361 | -------------------------------------------------------------------------------- /STADE-CDNet/models/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | import utils 6 | from models.networks import * 7 | 8 | import torch 9 | import torch.optim as optim 10 | import numpy as np 11 | from misc.metric_tool import ConfuseMatrixMeter 12 | from models.losses import cross_entropy 13 | import models.losses as losses 14 | from models.losses import get_alpha, softmax_helper, FocalLoss, mIoULoss, mmIoULoss 15 | 16 | from misc.logger_tool import Logger, Timer 17 | 18 | from utils import de_norm 19 | 20 | from tqdm import tqdm 21 | 22 | class CDTrainer(): 23 | 24 | def __init__(self, args, dataloaders): 25 | self.args = args 26 | self.dataloaders = dataloaders 27 | 28 | self.n_class = args.n_class 29 | # define G 30 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids) 31 | 32 | 33 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0 34 | else "cpu") 35 | print(self.device) 36 | 37 | # Learning rate and Beta1 for Adam optimizers 38 | self.lr = args.lr 39 | 40 | # define optimizers 41 | if args.optimizer == "sgd": 42 | self.optimizer_G = optim.SGD(self.net_G.parameters(), lr=self.lr, 43 | momentum=0.9, 44 | weight_decay=5e-4) 45 | elif args.optimizer == "adam": 46 | self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr, 47 | weight_decay=0) 48 | elif args.optimizer == "adamw": 49 | self.optimizer_G = optim.AdamW(self.net_G.parameters(), lr=self.lr, 50 | betas=(0.9, 0.999), weight_decay=0.05)  # (0.9, 0.999) are the PyTorch AdamW defaults; adjust if needed.
51 | 52 | # self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr) 53 | 54 | # define lr schedulers 55 | self.exp_lr_scheduler_G = get_scheduler(self.optimizer_G, args) 56 | 57 | self.running_metric = ConfuseMatrixMeter(n_class=2) 58 | 59 | # define logger file 60 | logger_path = os.path.join(args.checkpoint_dir, 'log.txt') 61 | self.logger = Logger(logger_path) 62 | self.logger.write_dict_str(args.__dict__) 63 | # define timer 64 | self.timer = Timer() 65 | self.batch_size = args.batch_size 66 | 67 | # training log 68 | self.epoch_acc = 0 69 | self.best_val_acc = 0.0 70 | self.best_epoch_id = 0 71 | self.epoch_to_start = 0 72 | self.max_num_epochs = args.max_epochs 73 | 74 | self.global_step = 0 75 | self.steps_per_epoch = len(dataloaders['train']) 76 | self.total_steps = (self.max_num_epochs - self.epoch_to_start)*self.steps_per_epoch 77 | 78 | self.G_pred = None 79 | self.pred_vis = None 80 | self.batch = None 81 | self.G_loss = None 82 | self.is_training = False 83 | self.batch_id = 0 84 | self.epoch_id = 0 85 | self.checkpoint_dir = args.checkpoint_dir 86 | self.vis_dir = args.vis_dir 87 | 88 | self.shuffle_AB = args.shuffle_AB 89 | 90 | # define the loss functions 91 | self.multi_scale_train = args.multi_scale_train 92 | self.multi_scale_infer = args.multi_scale_infer 93 | self.weights = tuple(args.multi_pred_weights) 94 | if args.loss == 'ce': 95 | self._pxl_loss = cross_entropy 96 | elif args.loss == 'bce': 97 | self._pxl_loss = losses.binary_ce 98 | elif args.loss == 'fl': 99 | print('\n Calculating alpha in Focal-Loss (FL) ...') 100 | alpha = get_alpha(dataloaders['train']) # calculate class occurrences 101 | print(f"alpha-0 (no-change)={alpha[0]}, alpha-1 (change)={alpha[1]}") 102 | self._pxl_loss = FocalLoss(apply_nonlin = softmax_helper, alpha = alpha, gamma = 2, smooth = 1e-5) 103 | elif args.loss == "miou": 104 | print('\n Calculating class occurrences in the training set...') 105 | alpha = np.asarray(get_alpha(dataloaders['train'])) # calculate class occurrences 106 | alpha = alpha/np.sum(alpha) 107 | # weights = torch.tensor([1.0, 1.0]).cuda() 108 | weights = 1-torch.from_numpy(alpha).cuda() 109 | print(f"Weights = {weights}") 110 | self._pxl_loss = mIoULoss(weight=weights, size_average=True, n_classes=args.n_class).cuda() 111 | elif args.loss == "mmiou": 112 | self._pxl_loss = mmIoULoss(n_classes=args.n_class).cuda() 113 | else: 114 | raise NotImplementedError(args.loss) 115 | 116 | self.VAL_ACC = np.array([], np.float32) 117 | if os.path.exists(os.path.join(self.checkpoint_dir, 'val_acc.npy')): 118 | self.VAL_ACC = np.load(os.path.join(self.checkpoint_dir, 'val_acc.npy')) 119 | self.TRAIN_ACC = np.array([], np.float32) 120 | if os.path.exists(os.path.join(self.checkpoint_dir, 'train_acc.npy')): 121 | self.TRAIN_ACC = np.load(os.path.join(self.checkpoint_dir, 'train_acc.npy')) 122 | 123 | # check and create model dir 124 | if os.path.exists(self.checkpoint_dir) is False: 125 | os.mkdir(self.checkpoint_dir) 126 | if os.path.exists(self.vis_dir) is False: 127 | os.mkdir(self.vis_dir) 128 | 129 | 130 | def _load_checkpoint(self, ckpt_name='last_ckpt.pt'): 131 | print("\n") 132 | if os.path.exists(os.path.join(self.checkpoint_dir, ckpt_name)): 133 | self.logger.write('loading last checkpoint...\n') 134 | # load the entire checkpoint 135 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, ckpt_name), 136 | map_location=self.device) 137 | # update net_G states 138 | self.net_G.load_state_dict(checkpoint['model_G_state_dict'],strict=False)# 139 | 140 | #
self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict']) 141 | self.exp_lr_scheduler_G.load_state_dict( 142 | checkpoint['exp_lr_scheduler_G_state_dict']) 143 | 144 | self.net_G.to(self.device) 145 | 146 | # update some other states 147 | self.epoch_to_start = checkpoint['epoch_id'] + 1 148 | self.best_val_acc = checkpoint['best_val_acc'] 149 | self.best_epoch_id = checkpoint['best_epoch_id'] 150 | 151 | self.total_steps = (self.max_num_epochs - self.epoch_to_start)*self.steps_per_epoch 152 | 153 | self.logger.write('Epoch_to_start = %d, Historical_best_acc = %.4f (at epoch %d)\n' % 154 | (self.epoch_to_start, self.best_val_acc, self.best_epoch_id)) 155 | self.logger.write('\n') 156 | elif self.args.pretrain is not None: 157 | print("Initializing backbone weights from: " + self.args.pretrain) 158 | self.net_G.load_state_dict(torch.load(self.args.pretrain), strict=False) 159 | self.net_G.to(self.device) 160 | self.net_G.eval() 161 | else: 162 | print('training from scratch...') 163 | print("\n") 164 | 165 | def _timer_update(self): 166 | self.global_step = (self.epoch_id-self.epoch_to_start) * self.steps_per_epoch + self.batch_id 167 | 168 | self.timer.update_progress((self.global_step + 1) / self.total_steps) 169 | est = self.timer.estimated_remaining() 170 | imps = (self.global_step + 1) * self.batch_size / self.timer.get_stage_elapsed() 171 | return imps, est 172 | 173 | def _visualize_pred(self): 174 | pred = torch.argmax(self.G_final_pred, dim=1, keepdim=True) 175 | pred_vis = pred * 255 176 | return pred_vis 177 | 178 | def _save_checkpoint(self, ckpt_name): 179 | torch.save({ 180 | 'epoch_id': self.epoch_id, 181 | 'best_val_acc': self.best_val_acc, 182 | 'best_epoch_id': self.best_epoch_id, 183 | 'model_G_state_dict': self.net_G.state_dict(), 184 | 'optimizer_G_state_dict': self.optimizer_G.state_dict(), 185 | 'exp_lr_scheduler_G_state_dict': self.exp_lr_scheduler_G.state_dict(), 186 | }, os.path.join(self.checkpoint_dir, ckpt_name)) 187 | 188 | def _update_lr_schedulers(self): 189 | self.exp_lr_scheduler_G.step() 190 | 191 | def _update_metric(self): 192 | """ 193 | update metric 194 | """ 195 | target = self.batch['L'].to(self.device).detach() 196 | G_pred = self.G_final_pred.detach() 197 | 198 | G_pred = torch.argmax(G_pred, dim=1) 199 | 200 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy()) 201 | return current_score 202 | 203 | def _collect_running_batch_states(self): 204 | 205 | running_acc = self._update_metric() 206 | 207 | m = len(self.dataloaders['train']) 208 | if self.is_training is False: 209 | m = len(self.dataloaders['val']) 210 | 211 | imps, est = self._timer_update() 212 | if np.mod(self.batch_id, 200) == 1: 213 | message = 'Is_training: %s. 
[%d,%d][%d,%d], imps: %.2f, est: %.2fh, G_loss: %.5f, running_mf1: %.5f\n' %\ 214 | (self.is_training, self.epoch_id, self.max_num_epochs-1, self.batch_id, m, 215 | imps*self.batch_size, est, 216 | self.G_loss.item(), running_acc) 217 | self.logger.write(message) 218 | 219 | 220 | if np.mod(self.batch_id, 200) == 1: 221 | vis_input = utils.make_numpy_grid(de_norm(self.batch['A'])) 222 | vis_input2 = utils.make_numpy_grid(de_norm(self.batch['B'])) 223 | 224 | vis_pred = utils.make_numpy_grid(self._visualize_pred()) 225 | 226 | vis_gt = utils.make_numpy_grid(self.batch['L']) 227 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0) 228 | vis = np.clip(vis, a_min=0.0, a_max=1.0) 229 | file_name = os.path.join( 230 | self.vis_dir, 'istrain_'+str(self.is_training)+'_'+ 231 | str(self.epoch_id)+'_'+str(self.batch_id)+'.jpg') 232 | plt.imsave(file_name, vis) 233 | 234 | def _collect_epoch_states(self): 235 | scores = self.running_metric.get_scores() 236 | self.epoch_acc = scores['mf1'] 237 | self.logger.write('Is_training: %s. Epoch %d / %d, epoch_mF1= %.5f\n' % 238 | (self.is_training, self.epoch_id, self.max_num_epochs-1, self.epoch_acc)) 239 | message = '' 240 | for k, v in scores.items(): 241 | message += '%s: %.5f ' % (k, v) 242 | self.logger.write(message+'\n') 243 | self.logger.write('\n') 244 | 245 | def _update_checkpoints(self): 246 | 247 | # save current model 248 | self._save_checkpoint(ckpt_name='last_ckpt.pt') 249 | self.logger.write('Latest model updated. Epoch_acc=%.4f, Historical_best_acc=%.4f (at epoch %d)\n' 250 | % (self.epoch_acc, self.best_val_acc, self.best_epoch_id)) 251 | self.logger.write('\n') 252 | 253 | # update the best model (based on eval acc) 254 | if self.epoch_acc > self.best_val_acc: 255 | self.best_val_acc = self.epoch_acc 256 | self.best_epoch_id = self.epoch_id 257 | self._save_checkpoint(ckpt_name='best_ckpt.pt') 258 | self.logger.write('*' * 10 + 'Best model updated!\n') 259 | self.logger.write('\n') 260 | 261 | def _update_training_acc_curve(self): 262 | # update train acc curve 263 | self.TRAIN_ACC = np.append(self.TRAIN_ACC, [self.epoch_acc]) 264 | np.save(os.path.join(self.checkpoint_dir, 'train_acc.npy'), self.TRAIN_ACC) 265 | 266 | def _update_val_acc_curve(self): 267 | # update val acc curve 268 | self.VAL_ACC = np.append(self.VAL_ACC, [self.epoch_acc]) 269 | np.save(os.path.join(self.checkpoint_dir, 'val_acc.npy'), self.VAL_ACC) 270 | 271 | def _clear_cache(self): 272 | self.running_metric.clear() 273 | 274 | 275 | def _forward_pass(self, batch): 276 | self.batch = batch 277 | img_in1 = batch['A'].to(self.device) 278 | img_in2 = batch['B'].to(self.device) 279 | self.G_pred = self.net_G(img_in1, img_in2) 280 | 281 | if self.multi_scale_infer == "True": 282 | self.G_final_pred = torch.zeros(self.G_pred[-1].size()).to(self.device) 283 | for pred in self.G_pred: 284 | if pred.size(2) != self.G_pred[-1].size(2): 285 | self.G_final_pred = self.G_final_pred + F.interpolate(pred, size=self.G_pred[-1].size(2), mode="nearest") 286 | else: 287 | self.G_final_pred = self.G_final_pred + pred 288 | self.G_final_pred = self.G_final_pred/len(self.G_pred) 289 | else: 290 | self.G_final_pred = self.G_pred[-1] 291 | 292 | 293 | def _backward_G(self): 294 | gt = self.batch['L'].to(self.device).float() 295 | if self.multi_scale_train == "True": 296 | i = 0 297 | temp_loss = 0.0 298 | for pred in self.G_pred: 299 | if pred.size(2) != gt.size(2): 300 | temp_loss = temp_loss + self.weights[i]*self._pxl_loss(pred, F.interpolate(gt, size=pred.size(2),
mode="nearest")) 301 | else: 302 | temp_loss = temp_loss + self.weights[i]*self._pxl_loss(pred, gt) 303 | i+=1 304 | self.G_loss = temp_loss 305 | else: 306 | self.G_loss = self._pxl_loss(self.G_pred[-1], gt) 307 | 308 | self.G_loss.backward() 309 | 310 | 311 | def train_models(self): 312 | 313 | self._load_checkpoint() 314 | 315 | # loop over the dataset multiple times 316 | for self.epoch_id in range(self.epoch_to_start, self.max_num_epochs): 317 | 318 | ################## train ################# 319 | ########################################## 320 | self._clear_cache() 321 | self.is_training = True 322 | self.net_G.train() # Set model to training mode 323 | # Iterate over data. 324 | total = len(self.dataloaders['train']) 325 | l=self.optimizer_G.param_groups[0]['lr'] 326 | #l=0.00001 327 | self.logger.write('lr: %0.07f\n \n' % l)#%0.7f 328 | for self.batch_id, batch in tqdm(enumerate(self.dataloaders['train'], 1), total=total): 329 | self._forward_pass(batch) 330 | # update G 331 | self.optimizer_G.zero_grad() 332 | self._backward_G() 333 | self.optimizer_G.step() 334 | self._collect_running_batch_states() 335 | self._timer_update() 336 | 337 | self._collect_epoch_states() 338 | self._update_training_acc_curve() 339 | self._update_lr_schedulers() 340 | 341 | 342 | ################## Eval ################## 343 | ########################################## 344 | self.logger.write('Begin evaluation...\n') 345 | self._clear_cache() 346 | self.is_training = False 347 | self.net_G.eval() 348 | 349 | # Iterate over data. 350 | for self.batch_id, batch in enumerate(self.dataloaders['val'], 0): 351 | with torch.no_grad(): 352 | self._forward_pass(batch) 353 | self._collect_running_batch_states() 354 | self._collect_epoch_states() 355 | 356 | ########### Update_Checkpoints ########### 357 | ########################################## 358 | self._update_val_acc_curve() 359 | self._update_checkpoints() 360 | 361 | -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/0_2.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/1_1.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/2_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/2_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/3_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/3_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/4_4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/4_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/5_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/5_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/6_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/6_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/7_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/7_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/8_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/8_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/A/9_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/9_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/0_2.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/1_1.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/2_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/2_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/3_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/3_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/4_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/4_4.png -------------------------------------------------------------------------------- 
/STADE-CDNet/samples_DSIFN/B/5_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/5_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/6_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/6_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/7_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/7_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/8_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/8_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/B/9_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/9_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/0_2.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/1_1.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/2_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/2_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/3_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/3_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/4_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/4_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/5_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/5_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/6_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/6_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/7_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/7_4.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/8_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/8_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/label/9_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/9_3.png -------------------------------------------------------------------------------- /STADE-CDNet/samples_DSIFN/list/demo.txt: -------------------------------------------------------------------------------- 1 | 9_3.png 2 | 8_3.png 3 | 7_4.png 4 | 6_3.png 5 | 5_3.png 6 | 4_4.png 7 | 3_4.png 8 | 2_4.png 9 | 1_1.png 10 | 0_2.png 11 | -------------------------------------------------------------------------------- /STADE-CDNet/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torchvision import utils 5 | 6 | import data_config 7 | from datasets.CD_dataset import CDDataset 8 | 9 | 10 | def get_loader(data_name, img_size=256, batch_size=8, split='test',  # batch_size default of 8 is an example; set to your needs 11 | is_train=False, dataset='CDDataset'): 12 | dataConfig = data_config.DataConfig().get_data_config(data_name) 13 | root_dir = dataConfig.root_dir 14 | label_transform = dataConfig.label_transform 15 | 16 | if dataset == 'CDDataset': 17 | data_set = CDDataset(root_dir=root_dir, split=split, 18 | img_size=img_size, is_train=is_train, 19 | label_transform=label_transform) 20 | else: 21 | raise NotImplementedError( 22 | 'Wrong dataset name %s (choose one from [CDDataset])' 23 | % dataset) 24 | 25 | shuffle = is_train 26 | dataloader = DataLoader(data_set, batch_size=batch_size, 27 | shuffle=shuffle, num_workers=4) 28 | 29 | return dataloader 30 | 31 | 32 | def get_loaders(args): 33 | 34 | data_name = args.data_name 35 | dataConfig = data_config.DataConfig().get_data_config(data_name) 36 | root_dir = dataConfig.root_dir 37 | label_transform = dataConfig.label_transform 38 | split = args.split 39 | split_val = 'val' 40 | if hasattr(args, 'split_val'): 41 | split_val = args.split_val 42 | if args.dataset == 'CDDataset': 43 | training_set = CDDataset(root_dir=root_dir, split=split, 44 | img_size=args.img_size,is_train=True, 45 | label_transform=label_transform) 46 | val_set = CDDataset(root_dir=root_dir, split=split_val, 47 |
img_size=args.img_size,is_train=False, 48 | label_transform=label_transform) 49 | else: 50 | raise NotImplementedError( 51 | 'Wrong dataset name %s (choose one from [CDDataset,])' 52 | % args.dataset) 53 | 54 | datasets = {'train': training_set, 'val': val_set} 55 | dataloaders = {x: DataLoader(datasets[x], batch_size=args.batch_size, 56 | shuffle=True, num_workers=args.num_workers) 57 | for x in ['train', 'val']} 58 | 59 | return dataloaders 60 | 61 | 62 | def make_numpy_grid(tensor_data, pad_value=0,padding=0): 63 | tensor_data = tensor_data.detach() 64 | vis = utils.make_grid(tensor_data, pad_value=pad_value,padding=padding) 65 | vis = np.array(vis.cpu()).transpose((1,2,0)) 66 | if vis.shape[2] == 1: 67 | vis = np.stack([vis, vis, vis], axis=-1) 68 | return vis 69 | 70 | 71 | def de_norm(tensor_data): 72 | return tensor_data * 0.5 + 0.5 73 | 74 | 75 | def get_device(args): 76 | # set gpu ids 77 | str_ids = args.gpu_ids.split(',') 78 | args.gpu_ids = [] 79 | for str_id in str_ids: 80 | id = int(str_id) 81 | if id >= 0: 82 | args.gpu_ids.append(id) 83 | if len(args.gpu_ids) > 0: 84 | torch.cuda.set_device(args.gpu_ids[0]) -------------------------------------------------------------------------------- /image/1 (2).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/1 (2).png -------------------------------------------------------------------------------- /image/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/11.png -------------------------------------------------------------------------------- /image/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/16.png -------------------------------------------------------------------------------- /image/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/22.png -------------------------------------------------------------------------------- /image/33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/33.png -------------------------------------------------------------------------------- /image/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/4.png -------------------------------------------------------------------------------- /image/44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/44.png -------------------------------------------------------------------------------- /image/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/5.png -------------------------------------------------------------------------------- /image/55.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/55.jpg -------------------------------------------------------------------------------- /image/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/6.png -------------------------------------------------------------------------------- /image/66.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/66.png -------------------------------------------------------------------------------- /image/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/7.png -------------------------------------------------------------------------------- /image/77.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/77.png --------------------------------------------------------------------------------
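
For quick reference, the minimal sketch below shows how the helpers in `STADE-CDNet/utils.py` are typically wired together: it builds a test-split dataloader with `get_loader`, then uses `de_norm` and `make_numpy_grid` to turn one batch into an image grid, mirroring the periodic visualization logic in `models/trainer.py`. The dataset name `'LEVIR'`, the batch size, and the output filename are placeholders that assume a matching entry in `data_config.py`; the batch keys `'A'`, `'B'` and `'L'` follow the usage in the trainer.

```python
# Hypothetical usage sketch; 'LEVIR', batch_size=4 and 'sample_grid.png' are placeholders.
import numpy as np
import matplotlib.pyplot as plt

from utils import get_loader, de_norm, make_numpy_grid

# Test-split dataloader (assumes 'LEVIR' is registered in data_config.py).
loader = get_loader(data_name='LEVIR', img_size=256, batch_size=4,
                    split='test', is_train=False)

batch = next(iter(loader))                    # dict with 'A', 'B', 'L' tensors
vis_a = make_numpy_grid(de_norm(batch['A']))  # un-normalized pre-change images
vis_b = make_numpy_grid(de_norm(batch['B']))  # un-normalized post-change images
vis_l = make_numpy_grid(batch['L'])           # binary change labels

# Stack the three rows vertically and save, as the trainer does for its snapshots.
grid = np.clip(np.concatenate([vis_a, vis_b, vis_l], axis=0), 0.0, 1.0)
plt.imsave('sample_grid.png', grid)
```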