├── .gitignore
├── README.md
├── STADE-CDNet
│   ├── .gitignore
│   ├── data_config.py
│   ├── data_preparation
│   │   ├── dsifn_cd_256.m
│   │   ├── find_mean_std.py
│   │   └── levir_cd_256.m
│   ├── datasets
│   │   ├── CD_dataset.py
│   │   └── data_utils.py
│   ├── eval_zl.py
│   ├── main_zl.py
│   ├── misc
│   │   ├── imutils.py
│   │   ├── logger_tool.py
│   │   ├── metric_tool.py
│   │   ├── pyutils.py
│   │   └── torchutils.py
│   ├── models
│   │   ├── ChangeFormer.py
│   │   ├── ChangeFormerBaseNetworks.py
│   │   ├── DTCDSCN.py
│   │   ├── SiamUnet_conc.py
│   │   ├── SiamUnet_diff.py
│   │   ├── Unet.py
│   │   ├── __init__.py
│   │   ├── basic_model.py
│   │   ├── evaluator.py
│   │   ├── help_funcs.py
│   │   ├── losses.py
│   │   ├── networks.py
│   │   ├── pixel_shuffel_up.py
│   │   ├── resnet.py
│   │   └── trainer.py
│   ├── samples_DSIFN
│   │   ├── A
│   │   │   ├── 0_2.png
│   │   │   ├── 1_1.png
│   │   │   ├── 2_4.png
│   │   │   ├── 3_4.png
│   │   │   ├── 4_4.png
│   │   │   ├── 5_3.png
│   │   │   ├── 6_3.png
│   │   │   ├── 7_4.png
│   │   │   ├── 8_3.png
│   │   │   └── 9_3.png
│   │   ├── B
│   │   │   ├── 0_2.png
│   │   │   ├── 1_1.png
│   │   │   ├── 2_4.png
│   │   │   ├── 3_4.png
│   │   │   ├── 4_4.png
│   │   │   ├── 5_3.png
│   │   │   ├── 6_3.png
│   │   │   ├── 7_4.png
│   │   │   ├── 8_3.png
│   │   │   └── 9_3.png
│   │   ├── label
│   │   │   ├── 0_2.png
│   │   │   ├── 1_1.png
│   │   │   ├── 2_4.png
│   │   │   ├── 3_4.png
│   │   │   ├── 4_4.png
│   │   │   ├── 5_3.png
│   │   │   ├── 6_3.png
│   │   │   ├── 7_4.png
│   │   │   ├── 8_3.png
│   │   │   └── 9_3.png
│   │   └── list
│   │       └── demo.txt
│   └── utils.py
└── image
    ├── 1 (2).png
    ├── 11.png
    ├── 16.png
    ├── 22.png
    ├── 33.png
    ├── 4.png
    ├── 44.png
    ├── 5.png
    ├── 55.jpg
    ├── 6.png
    ├── 66.png
    ├── 7.png
    └── 77.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # STADE-CDNet: Spatial-Temporal Attention with Difference Enhancement-Based Network for Remote Sensing Image Change Detection
2 | ## Requirements
3 | 

4 |
5 | Python 3.8.0
6 | pytorch 1.10.1
7 | torchvision 0.11.2
8 | einops 0.3.2
9 |
10 | 

11 | ## Installation
12 | Clone this repo:
13 | ```bash
14 | git clone https://github.com/LiLisaZhi/STADE-CDNet.git
15 | cd STADE-CDNet
16 | ```
17 |
18 |
19 | ## Dataset Preparation
20 |
21 |
22 | ```
23 | """
24 | Change detection data set with pixel-level binary labels;
25 | ├─A
26 | ├─B
27 | ├─label
28 | └─list
29 | """
30 | ```
31 | `A`: pre-change images;
32 | `B`: post-change images;
33 | `label`: binary change label maps;
34 | `list`: contains train.txt, val.txt and test.txt; each file records the image names (XXX.png) in the change detection dataset. A sample list file is shown below.
35 |
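A minimal list file is just one image name per line. For example, the bundled `samples_DSIFN/list/demo.txt` would list the sample tiles, e.g.:

```
0_2.png
1_1.png
2_4.png
```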
36 | ## Links to download processed datasets
37 | - LEVIR-CD:[`click here to download`](https://justchenhao.github.io/LEVIR/)
38 | - DSIFN-CD: [`click here to download`](https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset)
39 | ## References
40 | We appreciate the work of the following repositories:
41 | ```
42 | https://github.com/justchenhao/BIT_CD
43 | ```
44 |
45 | ```
46 | https://github.com/wgcban/ChangeFormer
47 | ```
48 |
49 | (The code implementation of our STADE-CDNet method references these repositories.)
50 | ## Contact
51 | lisa_zhi@foxmail.com
52 |
53 |
--------------------------------------------------------------------------------
/STADE-CDNet/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.DS_Store
3 | ./sha256
4 | ./sha256.pub
5 |
--------------------------------------------------------------------------------
/STADE-CDNet/data_config.py:
--------------------------------------------------------------------------------
1 |
2 | class DataConfig:
3 | data_name = ""
4 | root_dir = ""
5 | label_transform = ""
6 | def get_data_config(self, data_name):
7 | self.data_name = data_name
8 | if data_name == 'LEVIR':
9 | self.label_transform = "norm"
10 | self.root_dir = '/where'
11 | elif data_name == 'DSIFN':
12 | self.label_transform = "norm"
13 | self.root_dir = '/where'
14 | elif data_name == 'WHU':
15 | self.label_transform = "norm"
16 | self.root_dir = '/where'
17 | elif data_name == 'CDD':
18 | self.label_transform = "norm"
19 | self.root_dir = '/where'
20 | elif data_name == 'TYPO':
21 | self.label_transform = "norm"
22 | self.root_dir = '/where'
23 | elif data_name == 'quick_start_LEVIR':
24 | self.root_dir = '/where'
25 | elif data_name == 'quick_start_DSIFN':
26 | self.root_dir = '/where'
27 | else:
28 | raise TypeError('%s has not been defined' % data_name)
29 | return self
30 |
31 |
32 |
33 |
34 | if __name__ == '__main__':
35 | data = DataConfig().get_data_config(data_name='LEVIR')
36 | print(data.data_name)
37 | print(data.root_dir)
38 | print(data.label_transform)
39 |
40 |
--------------------------------------------------------------------------------
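The `'/where'` values above are placeholders. Before training, point `root_dir` at your processed data, either by editing the file or by overriding it at runtime, as in this small sketch (the path is hypothetical):

```python
from data_config import DataConfig

cfg = DataConfig().get_data_config(data_name='LEVIR')
cfg.root_dir = '/data/LEVIR-CD256'   # hypothetical path to the 256x256 patches
print(cfg.data_name, cfg.root_dir, cfg.label_transform)
```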
/STADE-CDNet/data_preparation/dsifn_cd_256.m:
--------------------------------------------------------------------------------
1 | %Dataset preparation code for the DSIFN dataset (MATLAB)
2 | %Download the DSIFN dataset here: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset
3 | %This code generates the 256x256 image patches required for the train/val/test splits
4 | %Please create folders according to the following format.
5 | %DSIFN_256
6 | %------(train)
7 | % |---> A
8 | % |---> B
9 | % |---> label
10 | %------(val)
11 | % |---> A
12 | % |---> B
13 | % |---> label
14 | %------(test)
15 | % |---> A
16 | % |---> B
17 | % |---> label
18 | %Then run this code
19 | %Then copy all images in train-A, val-A, test-A to a folder named A
20 | %Then copy all images in train-B, val-B, test-B to a folder named B
21 | %Then copy all images in train-label, val-label, test-label to a folder named label
22 |
23 |
24 |
25 |
26 | clear all;
27 | close all;
28 | clc;
29 |
30 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
31 | %Train-A
32 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/t1/*.jpg'));
33 | for i=1:1:length(imgs_name)
34 | img_file_name = imgs_name{1,i};
35 | temp = imread(strcat('DSIFN/download/Archive/train/t1/', img_file_name));
36 | c=1;
37 | for j=1:2
38 | for k=1:2
39 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
40 | imwrite(patch, strcat('DSIFN_256/train/A/', img_file_name(1:end-4), '_', num2str(c), '.png'));
41 | c=c+1;
42 | end
43 | end
44 |
45 | end
46 |
47 | %Train-B
48 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/t2/*.jpg'));
49 | for i=1:1:length(imgs_name)
50 | img_file_name = imgs_name{1,i};
51 | temp = imread(strcat('DSIFN/download/Archive/train/t2/', img_file_name));
52 | c=1;
53 | for j=1:2
54 | for k=1:2
55 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
56 | imwrite(patch, strcat('DSIFN_256/train/B/', img_file_name(1:end-4), '_', num2str(c), '.png'));
57 | c=c+1;
58 | end
59 | end
60 |
61 | end
62 |
63 | %Train-label
64 | imgs_name = struct2cell(dir('DSIFN/download/Archive/train/mask/*.png'));
65 | for i=1:1:length(imgs_name)
66 | img_file_name = imgs_name{1,i};
67 | temp = imread(strcat('DSIFN/download/Archive/train/mask/',img_file_name));
68 | c=1;
69 | for j=1:2
70 | for k=1:2
71 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
72 | imwrite(patch, strcat('DSIFN_256/train/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
73 | c=c+1;
74 | end
75 | end
76 |
77 | end
78 |
79 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
80 | %test-A
81 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/t1/*.jpg'));
82 | for i=1:1:length(imgs_name)
83 | img_file_name = imgs_name{1,i};
84 | temp = imread(strcat('DSIFN/download/Archive/test/t1/', img_file_name));
85 | c=1;
86 | for j=1:2
87 | for k=1:2
88 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
89 | imwrite(patch, strcat('DSIFN_256/test/A/', img_file_name(1:end-4), '_', num2str(c), '.png'));
90 | c=c+1;
91 | end
92 | end
93 |
94 | end
95 |
96 | %test-B
97 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/t2/*.jpg'));
98 | for i=1:1:length(imgs_name)
99 | img_file_name = imgs_name{1,i};
100 | temp = imread(strcat('DSIFN/download/Archive/test/t2/', img_file_name));
101 | c=1;
102 | for j=1:2
103 | for k=1:2
104 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
105 | imwrite(patch, strcat('DSIFN_256/test/B/', img_file_name(1:end-4), '_', num2str(c), '.png'));
106 | c=c+1;
107 | end
108 | end
109 |
110 | end
111 |
112 | %test-label
113 | imgs_name = struct2cell(dir('DSIFN/download/Archive/test/mask/*.png'));
114 | for i=1:1:length(imgs_name)
115 | img_file_name = imgs_name{1,i};
116 | temp = imread(strcat('DSIFN/download/Archive/test/mask/',img_file_name));
117 | c=1;
118 | for j=1:2
119 | for k=1:2
120 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
121 | imwrite(patch, strcat('DSIFN_256/test/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
122 | c=c+1;
123 | end
124 | end
125 |
126 | end
127 |
128 |
129 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
130 | %val-A
131 | imgs_name = struct2cell(dir('DSIFN/download/Archive/val/t1/*.jpg'));
132 | for i=1:1:length(imgs_name)
133 | img_file_name = imgs_name{1,i};
134 | temp = imread(strcat('DSIFN/download/Archive/val/t1/', img_file_name));
135 | c=1;
136 | for j=1:2
137 | for k=1:2
138 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
139 | imwrite(patch, strcat('DSIFN_256/val/A/', img_file_name(1:end-4), '_', num2str(c), '.png'));
140 | c=c+1;
141 | end
142 | end
143 |
144 | end
145 |
146 | %val-B
147 | imgs_name = struct2cell(dir('DSIFN/download/Archive/val/t2/*.jpg'));
148 | for i=1:1:length(imgs_name)
149 | img_file_name = imgs_name{1,i};
150 | temp = imread(strcat('DSIFN/download/Archive/val/t2/', img_file_name));
151 | c=1;
152 | for j=1:2
153 | for k=1:2
154 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
155 | imwrite(patch, strcat('DSIFN_256/val/B/', img_file_name(1:end-4), '_', num2str(c), '.png'));
156 | c=c+1;
157 | end
158 | end
159 |
160 | end
161 |
162 | %val-label
163 | imgs_name = struct2cell(dir('DSIFN/download/Archive/val/mask/*.png'));
164 | for i=1:1:length(imgs_name)
165 | img_file_name = imgs_name{1,i};
166 | temp = imread(strcat('DSIFN/download/Archive/val/mask/',img_file_name));
167 | c=1;
168 | for j=1:2
169 | for k=1:2
170 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
171 | imwrite(patch, strcat('DSIFN_256/val/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
172 | c=c+1;
173 | end
174 | end
175 |
176 | end
177 |
--------------------------------------------------------------------------------
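For reference, the same 2x2 patching can be sketched in Python without MATLAB. This is a hypothetical equivalent, not part of the repository; it assumes the 512x512 DSIFN archive layout used above and reproduces the `name_c.png` output convention:

```python
import os
from PIL import Image

def make_patches(src_dir, dst_dir, grid=2, patch=256, suffix='.jpg'):
    """Split each (grid*patch)x(grid*patch) image into grid**2 patches of patch x patch."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in sorted(os.listdir(src_dir)):
        if not name.endswith(suffix):
            continue
        img = Image.open(os.path.join(src_dir, name))
        stem = os.path.splitext(name)[0]
        c = 1
        for j in range(grid):          # rows, matching the MATLAB loop order
            for k in range(grid):      # columns
                box = (k * patch, j * patch, (k + 1) * patch, (j + 1) * patch)  # (left, upper, right, lower)
                img.crop(box).save(os.path.join(dst_dir, '%s_%d.png' % (stem, c)))
                c += 1

# e.g. make_patches('DSIFN/download/Archive/train/t1', 'DSIFN_256/train/A')
```

The LEVIR-CD script below is the same procedure with a 4x4 grid (1024x1024 inputs) and `suffix='.png'`.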
/STADE-CDNet/data_preparation/find_mean_std.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 |
6 | if __name__ == '__main__':
7 | filepath = r"/where" # Dataset directory
8 | pathDir = os.listdir(filepath) # Images in the dataset directory
9 | num = len(pathDir) # Number of images (each image here is 512x512)
10 |
11 | print("Computing mean...")
12 | data_mean = np.zeros(3)
13 | for idx in range(len(pathDir)):
14 | filename = pathDir[idx]
15 | img = Image.open(os.path.join(filepath, filename))
16 | img = np.array(img) / 255.0
17 | print(img.shape)
18 | data_mean += img.mean(axis=(0, 1)) # accumulate per-channel (R, G, B) means
19 | # (np.mean(img) would average all channels into one scalar; for grayscale inputs that would suffice)
20 | data_mean = data_mean / num
21 |
22 | print("Computing var...")
23 | data_std = 0.
24 | for idx in range(len(pathDir)):
25 | filename = pathDir[idx]
26 | img = Image.open(os.path.join(filepath, filename)).convert('L').resize((256, 256))
27 | img = np.array(img) / 255.0
28 | data_std += np.std(img)
29 |
30 | data_std = data_std / num
31 | print("mean:{}".format(data_mean))
32 | print("std:{}".format(data_std))
--------------------------------------------------------------------------------
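Note that the script above computes the mean over the original RGB images but the std over grayscale, resized copies. A consistent per-channel sketch (an approximation: it averages per-image statistics rather than pooling all pixels; the folder path is hypothetical):

```python
import os
import numpy as np
from PIL import Image

def channel_stats(folder):
    """Average per-image, per-channel mean and std over a folder of RGB images."""
    names = os.listdir(folder)
    mean = np.zeros(3)
    std = np.zeros(3)
    for name in names:
        img = np.asarray(Image.open(os.path.join(folder, name)).convert('RGB')) / 255.0
        mean += img.mean(axis=(0, 1))
        std += img.std(axis=(0, 1))
    return mean / len(names), std / len(names)

# mean, std = channel_stats('/data/LEVIR-CD256/A')  # hypothetical path
```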
/STADE-CDNet/data_preparation/levir_cd_256.m:
--------------------------------------------------------------------------------
1 | %Dataset preparation code for the LEVIR-CD dataset (MATLAB)
2 | %Download the LEVIR dataset here: https://www.dropbox.com/s/h9jl2ygznsaeg5d/LEVIR-CD-256.zip
3 | %This code generates the 256x256 image patches required for the train/val/test splits
4 | %Please create folders according to the following format.
5 | %LEVIR-CD256
6 | %------(train)
7 | % |---> A
8 | % |---> B
9 | % |---> label
10 | %------(val)
11 | % |---> A
12 | % |---> B
13 | % |---> label
14 | %------(test)
15 | % |---> A
16 | % |---> B
17 | % |---> label
18 | %Then run this code
19 | %Then copy all images in train-A, val-A, test-A to a folder named A
20 | %Then copy all images in train-B, val-B, test-B to a folder named B
21 | %Then copy all images in train-label, val-label, test-label to a folder named label
22 |
23 | clear all;
24 | close all;
25 | clc;
26 |
27 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
28 | %Train-A
29 | imgs_name = struct2cell(dir('LEVIR-CD/train/A/*.png'));
30 | for i=1:1:length(imgs_name)
31 | img_file_name = imgs_name{1,i};
32 | temp = imread(strcat('LEVIR-CD/train/A/',img_file_name));
33 | c=1;
34 | for j=1:4
35 | for k=1:4
36 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
37 | imwrite(patch, strcat('LEVIR-CD256/train/A/', img_file_name(1:end-4), '_', num2str(c), '.png'));
38 | c=c+1;
39 | end
40 | end
41 |
42 | end
43 |
44 | %Train-B
45 | imgs_name = struct2cell(dir('LEVIR-CD/train/B/*.png'));
46 | for i=1:1:length(imgs_name)
47 | img_file_name = imgs_name{1,i};
48 | temp = imread(strcat('LEVIR-CD/train/B/',img_file_name));
49 | c=1;
50 | for j=1:4
51 | for k=1:4
52 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
53 | imwrite(patch, strcat('LEVIR-CD256/train/B/', img_file_name(1:end-4), '_', num2str(c), '.png'));
54 | c=c+1;
55 | end
56 | end
57 |
58 | end
59 |
60 | %Train-label
61 | imgs_name = struct2cell(dir('LEVIR-CD/train/label/*.png'));
62 | for i=1:1:length(imgs_name)
63 | img_file_name = imgs_name{1,i};
64 | temp = imread(strcat('LEVIR-CD/train/label/',img_file_name));
65 | c=1;
66 | for j=1:4
67 | for k=1:4
68 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
69 | imwrite(patch, strcat('LEVIR-CD256/train/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
70 | c=c+1;
71 | end
72 | end
73 |
74 | end
75 |
76 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
77 | %Test-A
78 | imgs_name = struct2cell(dir('LEVIR-CD/test/A/*.png'));
79 | for i=1:1:length(imgs_name)
80 | img_file_name = imgs_name{1,i};
81 | temp = imread(strcat('LEVIR-CD/test/A/',img_file_name));
82 | c=1;
83 | for j=1:4
84 | for k=1:4
85 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
86 | imwrite(patch, strcat('LEVIR-CD256/test/A/', img_file_name(1:end-4), '_', num2str(c), '.png'));
87 | c=c+1;
88 | end
89 | end
90 |
91 | end
92 |
93 | %Test-B
94 | imgs_name = struct2cell(dir('LEVIR-CD/test/B/*.png'));
95 | for i=1:1:length(imgs_name)
96 | img_file_name = imgs_name{1,i};
97 | temp = imread(strcat('LEVIR-CD/test/B/',img_file_name));
98 | c=1;
99 | for j=1:4
100 | for k=1:4
101 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
102 | imwrite(patch, strcat('LEVIR-CD256/test/B/', img_file_name(1:end-4), '_', num2str(c), '.png'));
103 | c=c+1;
104 | end
105 | end
106 |
107 | end
108 |
109 | %Test-label
110 | imgs_name = struct2cell(dir('LEVIR-CD/test/label/*.png'));
111 | for i=1:1:length(imgs_name)
112 | img_file_name = imgs_name{1,i};
113 | temp = imread(strcat('LEVIR-CD/test/label/',img_file_name));
114 | c=1;
115 | for j=1:4
116 | for k=1:4
117 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
118 | imwrite(patch, strcat('LEVIR-CD256/test/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
119 | c=c+1;
120 | end
121 | end
122 |
123 | end
124 |
125 |
126 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
127 | %val-A
128 | imgs_name = struct2cell(dir('LEVIR-CD/val/A/*.png'));
129 | for i=1:1:length(imgs_name)
130 | img_file_name = imgs_name{1,i};
131 | temp = imread(strcat('LEVIR-CD/val/A/',img_file_name));
132 | c=1;
133 | for j=1:4
134 | for k=1:4
135 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
136 | imwrite(patch, strcat('LEVIR-CD256/val/A/', img_file_name(1:end-4), '_', num2str(c), '.png'));
137 | c=c+1;
138 | end
139 | end
140 |
141 | end
142 |
143 | %val-B
144 | imgs_name = struct2cell(dir('LEVIR-CD/val/B/*.png'));
145 | for i=1:1:length(imgs_name)
146 | img_file_name = imgs_name{1,i};
147 | temp = imread(strcat('LEVIR-CD/val/B/',img_file_name));
148 | c=1;
149 | for j=1:4
150 | for k=1:4
151 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
152 | imwrite(patch, strcat('LEVIR-CD256/val/B/', img_file_name(1:end-4), '_', num2str(c), '.png'));
153 | c=c+1;
154 | end
155 | end
156 |
157 | end
158 |
159 | %val-label
160 | imgs_name = struct2cell(dir('LEVIR-CD/val/label/*.png'));
161 | for i=1:1:length(imgs_name)
162 | img_file_name = imgs_name{1,i};
163 | temp = imread(strcat('LEVIR-CD/val/label/',img_file_name));
164 | c=1;
165 | for j=1:4
166 | for k=1:4
167 | patch = temp((j-1)*256+1: j*256, (k-1)*256+1: k*256, :);
168 | imwrite(patch, strcat('LEVIR-CD256/val/label/', img_file_name(1:end-4), '_', num2str(c), '.png'));
169 | c=c+1;
170 | end
171 | end
172 |
173 | end
174 |
--------------------------------------------------------------------------------
/STADE-CDNet/datasets/CD_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | change detection data
3 | """
4 |
5 | import os
6 | from PIL import Image
7 | import numpy as np
8 |
9 | from torch.utils import data
10 |
11 | from datasets.data_utils import CDDataAugmentation
12 |
13 |
14 | """
15 | CD data set with pixel-level labels;
16 | ├─image
17 | ├─image_post
18 | ├─label
19 | └─list
20 | """
21 | IMG_FOLDER_NAME = "A"
22 | IMG_POST_FOLDER_NAME = 'B'
23 | LIST_FOLDER_NAME = 'list'
24 | ANNOT_FOLDER_NAME = "label"
25 |
26 | IGNORE = 255
27 |
28 | label_suffix = '.png'  # use '.jpg' for the GAN dataset, '.png' otherwise
29 |
30 | def load_img_name_list(dataset_path):
31 | img_name_list = np.loadtxt(dataset_path, dtype=str)  # np.str was removed in NumPy 1.24
32 | if img_name_list.ndim == 2:
33 | return img_name_list[:, 0]
34 | return img_name_list
35 |
36 |
37 | def load_image_label_list_from_npy(npy_path, img_name_list):
38 | cls_labels_dict = np.load(npy_path, allow_pickle=True).item()
39 | return [cls_labels_dict[img_name] for img_name in img_name_list]
40 |
41 |
42 | def get_img_post_path(root_dir,img_name):
43 | return os.path.join(root_dir, IMG_POST_FOLDER_NAME, img_name)
44 |
45 |
46 | def get_img_path(root_dir, img_name):
47 | return os.path.join(root_dir, IMG_FOLDER_NAME, img_name)
48 |
49 |
50 | def get_label_path(root_dir, img_name):
51 | return os.path.join(root_dir, ANNOT_FOLDER_NAME, img_name.replace('.jpg', label_suffix))
52 |
53 |
54 | class ImageDataset(data.Dataset):
55 | """VOCdataloder"""
56 | def __init__(self, root_dir, split='train', img_size=256, is_train=True,to_tensor=True):
57 | super(ImageDataset, self).__init__()
58 | self.root_dir = root_dir
59 | self.img_size = img_size
60 | self.split = split #train | train_aug | val
61 | # self.list_path = self.root_dir + '/' + LIST_FOLDER_NAME + '/' + self.list + '.txt'
62 | self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split+'.txt')
63 | self.img_name_list = load_img_name_list(self.list_path)
64 |
65 | self.A_size = len(self.img_name_list) # get the size of dataset A
66 | self.to_tensor = to_tensor
67 | if is_train:
68 | self.augm = CDDataAugmentation(
69 | img_size=self.img_size,
70 | with_random_hflip=True,
71 | with_random_vflip=True,
72 | with_scale_random_crop=True,
73 | with_random_blur=True,
74 | random_color_tf=True
75 | )
76 | else:
77 | self.augm = CDDataAugmentation(
78 | img_size=self.img_size
79 | )
80 | def __getitem__(self, index):
81 | name = self.img_name_list[index]
82 | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.A_size])
83 | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.A_size])
84 |
85 | img = np.asarray(Image.open(A_path).convert('RGB'))
86 | img_B = np.asarray(Image.open(B_path).convert('RGB'))
87 |
88 | [img, img_B], _ = self.augm.transform([img, img_B],[], to_tensor=self.to_tensor)
89 |
90 | return {'A': img, 'B': img_B, 'name': name}
91 |
92 | def __len__(self):
93 | """Return the total number of images in the dataset."""
94 | return self.A_size
95 |
96 |
97 | class CDDataset(ImageDataset):
98 |
99 | def __init__(self, root_dir, img_size, split='train', is_train=True, label_transform=None,
100 | to_tensor=True):
101 | super(CDDataset, self).__init__(root_dir, img_size=img_size, split=split, is_train=is_train,
102 | to_tensor=to_tensor)
103 | self.label_transform = label_transform
104 |
105 | def __getitem__(self, index):
106 | name = self.img_name_list[index]
107 | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.A_size])
108 | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.A_size])
109 | img = np.asarray(Image.open(A_path).convert('RGB'))
110 | img_B = np.asarray(Image.open(B_path).convert('RGB'))
111 | L_path = get_label_path(self.root_dir, self.img_name_list[index % self.A_size])
112 |
113 | label = np.array(Image.open(L_path), dtype=np.uint8)
114 | # if you get an error from a dimension mismatch, add [:, :, 0] at the end
115 |
116 | # in binary change detection, foreground pixels are labeled 255
117 | if self.label_transform == 'norm':
118 | label = label // 255
119 |
120 | [img, img_B], [label] = self.augm.transform([img, img_B], [label], to_tensor=self.to_tensor)
121 | # print(label.max())
122 |
123 | return {'name': name, 'A': img, 'B': img_B, 'L': label}
124 |
125 |
--------------------------------------------------------------------------------
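A minimal usage sketch, run from the STADE-CDNet directory (the dataset root is a hypothetical path; the layout must follow the A/B/label/list convention documented above):

```python
from torch.utils.data import DataLoader
from datasets.CD_dataset import CDDataset

dataset = CDDataset(root_dir='/data/LEVIR-CD256',   # hypothetical path
                    img_size=256, split='train', is_train=True,
                    label_transform='norm')          # maps {0, 255} labels to {0, 1}
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch['A'].shape, batch['B'].shape, batch['L'].shape)
```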
/STADE-CDNet/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 |
4 | from PIL import Image
5 | from PIL import ImageFilter
6 |
7 | import torchvision.transforms.functional as TF
8 | from torchvision import transforms
9 | import torch
10 |
11 |
12 | def to_tensor_and_norm(imgs, labels):
13 | # to tensor
14 | imgs = [TF.to_tensor(img) for img in imgs]
15 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0)
16 | for img in labels]
17 |
18 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
19 | for img in imgs]
20 | return imgs, labels
21 |
22 |
23 | class CDDataAugmentation:
24 |
25 | def __init__(
26 | self,
27 | img_size,
28 | with_random_hflip=False,
29 | with_random_vflip=False,
30 | with_random_rot=False,
31 | with_random_crop=False,
32 | with_scale_random_crop=False,
33 | with_random_blur=False,
34 | random_color_tf=False
35 | ):
36 | self.img_size = img_size
37 | if self.img_size is None:
38 | self.img_size_dynamic = True
39 | else:
40 | self.img_size_dynamic = False
41 | self.with_random_hflip = with_random_hflip
42 | self.with_random_vflip = with_random_vflip
43 | self.with_random_rot = with_random_rot
44 | self.with_random_crop = with_random_crop
45 | self.with_scale_random_crop = with_scale_random_crop
46 | self.with_random_blur = with_random_blur
47 | self.random_color_tf=random_color_tf
48 | def transform(self, imgs, labels, to_tensor=True):
49 | """
50 | :param imgs: [ndarray,]
51 | :param labels: [ndarray,]
52 | :return: [ndarray,],[ndarray,]
53 | """
54 | # resize images and convert to tensors
55 | imgs = [TF.to_pil_image(img) for img in imgs]
56 | # self.img_size may be None here; the dynamic branch below then
57 | # picks the size up from the first input image
58 |
59 | if not self.img_size_dynamic:
60 | if imgs[0].size != (self.img_size, self.img_size):
61 | imgs = [TF.resize(img, [self.img_size, self.img_size], interpolation=3)
62 | for img in imgs]
63 | else:
64 | self.img_size = imgs[0].size[0]
65 |
66 | labels = [TF.to_pil_image(img) for img in labels]
67 | if len(labels) != 0:
68 | if labels[0].size != (self.img_size, self.img_size):
69 | labels = [TF.resize(img, [self.img_size, self.img_size], interpolation=0)
70 | for img in labels]
71 |
72 | random_base = 0.5
73 | if self.with_random_hflip and random.random() > 0.5:
74 | imgs = [TF.hflip(img) for img in imgs]
75 | labels = [TF.hflip(img) for img in labels]
76 |
77 | if self.with_random_vflip and random.random() > 0.5:
78 | imgs = [TF.vflip(img) for img in imgs]
79 | labels = [TF.vflip(img) for img in labels]
80 |
81 | if self.with_random_rot and random.random() > random_base:
82 | angles = [90, 180, 270]
83 | index = random.randint(0, 2)
84 | angle = angles[index]
85 | imgs = [TF.rotate(img, angle) for img in imgs]
86 | labels = [TF.rotate(img, angle) for img in labels]
87 |
88 | if self.with_random_crop and random.random() > 0:
89 | i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). \
90 | get_params(img=imgs[0], scale=(0.8, 1.2), ratio=(1, 1))
91 |
92 | imgs = [TF.resized_crop(img, i, j, h, w,
93 | size=(self.img_size, self.img_size),
94 | interpolation=Image.BICUBIC)  # Image.CUBIC was an alias removed in Pillow 10
95 | for img in imgs]
96 |
97 | labels = [TF.resized_crop(img, i, j, h, w,
98 | size=(self.img_size, self.img_size),
99 | interpolation=Image.NEAREST)
100 | for img in labels]
101 |
102 | if self.with_scale_random_crop:
103 | # rescale
104 | scale_range = [1, 1.2]
105 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0])
106 |
107 | imgs = [pil_rescale(img, target_scale, order=3) for img in imgs]
108 | labels = [pil_rescale(img, target_scale, order=0) for img in labels]
109 | # crop
110 | imgsize = imgs[0].size[::-1]  # PIL .size is (w, h); get_random_crop_box expects (h, w)
111 | box = get_random_crop_box(imgsize=imgsize, cropsize=self.img_size)
112 | imgs = [pil_crop(img, box, cropsize=self.img_size, default_value=0)
113 | for img in imgs]
114 | labels = [pil_crop(img, box, cropsize=self.img_size, default_value=255)
115 | for img in labels]
116 |
117 | if self.with_random_blur and random.random() > 0:
118 | radius = random.random()
119 | imgs = [img.filter(ImageFilter.GaussianBlur(radius=radius))
120 | for img in imgs]
121 |
122 | if self.random_color_tf:
123 | color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)
124 | imgs_tf = []
125 | for img in imgs:
126 | tf = transforms.ColorJitter(
127 | color_jitter.brightness,
128 | color_jitter.contrast,
129 | color_jitter.saturation,
130 | color_jitter.hue)
131 | imgs_tf.append(tf(img))
132 | imgs = imgs_tf
133 |
134 | if to_tensor:
135 | # to tensor
136 | imgs = [TF.to_tensor(img) for img in imgs]
137 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0)
138 | for img in labels]
139 |
140 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
141 | for img in imgs]
142 |
143 | return imgs, labels
144 |
145 |
146 | def pil_crop(image, box, cropsize, default_value):
147 | assert isinstance(image, Image.Image)
148 | img = np.array(image)
149 |
150 | if len(img.shape) == 3:
151 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value
152 | else:
153 | cont = np.ones((cropsize, cropsize), img.dtype)*default_value
154 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
155 |
156 | return Image.fromarray(cont)
157 |
158 |
159 | def get_random_crop_box(imgsize, cropsize):
160 | h, w = imgsize
161 | ch = min(cropsize, h)
162 | cw = min(cropsize, w)
163 |
164 | w_space = w - cropsize
165 | h_space = h - cropsize
166 |
167 | if w_space > 0:
168 | cont_left = 0
169 | img_left = random.randrange(w_space + 1)
170 | else:
171 | cont_left = random.randrange(-w_space + 1)
172 | img_left = 0
173 |
174 | if h_space > 0:
175 | cont_top = 0
176 | img_top = random.randrange(h_space + 1)
177 | else:
178 | cont_top = random.randrange(-h_space + 1)
179 | img_top = 0
180 |
181 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw
182 |
183 |
184 | def pil_rescale(img, scale, order):
185 |     assert isinstance(img, Image.Image)
186 |     width, height = img.size  # PIL reports (width, height)
187 |     target_size = (int(np.round(height*scale)), int(np.round(width*scale)))  # (h, w)
188 |     return pil_resize(img, target_size, order)
189 |
190 |
191 | def pil_resize(img, size, order):
192 |     assert isinstance(img, Image.Image)
193 |     if size[0] == img.size[1] and size[1] == img.size[0]:  # size is (h, w); img.size is (w, h)
194 |         return img
195 |     if order == 3:
196 |         resample = Image.BICUBIC
197 |     elif order == 0:
198 |         resample = Image.NEAREST
199 |     else:
200 |         raise ValueError('unsupported interpolation order: %s' % order)
201 |     return img.resize(size[::-1], resample)
202 |
--------------------------------------------------------------------------------
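The key property of `CDDataAugmentation.transform` is that a single random draw drives each geometric op, so the two bi-temporal images and the label stay pixel-aligned. A small self-contained sketch with dummy arrays:

```python
import numpy as np
from datasets.data_utils import CDDataAugmentation

augm = CDDataAugmentation(img_size=256, with_random_hflip=True, with_random_vflip=True)
img_A = np.zeros((256, 256, 3), dtype=np.uint8)   # dummy pre-change image
img_B = np.zeros((256, 256, 3), dtype=np.uint8)   # dummy post-change image
label = np.zeros((256, 256), dtype=np.uint8)      # dummy change mask
[img_A, img_B], [label] = augm.transform([img_A, img_B], [label], to_tensor=True)
print(img_A.shape, label.shape)  # torch.Size([3, 256, 256]) torch.Size([1, 256, 256])
```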
/STADE-CDNet/eval_zl.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from models.evaluator import *
4 |
5 | print(torch.cuda.is_available())
6 |
7 |
8 | """
9 | eval the CD model
10 | """
11 |
12 | def main():
13 | # ------------
14 | # args
15 | # ------------
16 | parser = ArgumentParser()
17 | parser.add_argument('--gpu_ids', type=str, default="your need", help='gpu ids, e.g. "0", "0,1,2", or "0,2"; use -1 for CPU')
18 | parser.add_argument('--project_name', default='test', type=str)
19 | parser.add_argument('--print_models', default=False, type=bool, help='print models')
20 | parser.add_argument('--checkpoints_root', default='checkpoints', type=str)
21 | parser.add_argument('--vis_root', default='vis', type=str)
22 |
23 | # data
24 | parser.add_argument('--num_workers', default="your need", type=int)
25 | parser.add_argument('--dataset', default='CDDataset', type=str)
26 | parser.add_argument('--data_name', default='LEVIR', type=str)
27 |
28 | parser.add_argument('--batch_size', default=1, type=int)
29 | parser.add_argument('--split', default="test", type=str)
30 |
31 | parser.add_argument('--img_size', default="your data need", type=int)
32 |
33 | # model
34 | parser.add_argument('--n_class', default=2, type=int)
35 | parser.add_argument('--embed_dim', default="your need", type=int)
36 | parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8_dedim8', type=str,
37 | help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|')
38 |
39 | parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str)
40 |
41 | args = parser.parse_args()
42 | utils.get_device(args)
43 | print(args.gpu_ids)
44 |
45 | # checkpoints dir
46 | args.checkpoint_dir = os.path.join(args.checkpoints_root, args.project_name)
47 | os.makedirs(args.checkpoint_dir, exist_ok=True)
48 | # visualize dir
49 | args.vis_dir = os.path.join(args.vis_root, args.project_name)
50 | os.makedirs(args.vis_dir, exist_ok=True)
51 |
52 | dataloader = utils.get_loader(args.data_name, img_size=args.img_size,
53 | batch_size=args.batch_size, is_train=False,
54 | split=args.split)
55 | model = CDEvaluator(args=args, dataloader=dataloader)
56 |
57 | model.eval_models(checkpoint_name=args.checkpoint_name)
58 |
59 |
60 | if __name__ == '__main__':
61 | main()
62 |
63 |
--------------------------------------------------------------------------------
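A typical evaluation command might look like the following (all values are placeholders; `--img_size` and `--embed_dim` have no usable defaults in the script and must match the settings the checkpoint was trained with):

```
python eval_zl.py --gpu_ids 0 --num_workers 2 --data_name LEVIR \
    --img_size 256 --embed_dim 64 \
    --net_G base_transformer_pos_s4_dd8_dedim8 \
    --project_name test --checkpoint_name best_ckpt.pt
```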
/STADE-CDNet/main_zl.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from models.trainer import *
4 |
5 | print(torch.cuda.is_available())
6 |
7 | """
8 | the main function for training the CD networks
9 | """
10 |
11 |
12 | def train(args):
13 | dataloaders = utils.get_loaders(args)
14 | model = CDTrainer(args=args, dataloaders=dataloaders)
15 | model.train_models()
16 |
17 |
18 | def test(args):
19 | from models.evaluator import CDEvaluator
20 | dataloader = utils.get_loader(args.data_name, img_size=args.img_size,
21 | batch_size=args.batch_size, is_train=False,
22 | split='test')
23 | model = CDEvaluator(args=args, dataloader=dataloader)
24 |
25 | model.eval_models()
26 |
27 |
28 | if __name__ == '__main__':
29 | # ------------
30 | # args
31 | # ------------
32 | parser = ArgumentParser()
33 | parser.add_argument('--gpu_ids', type=str, default='your need', help='gpu ids, e.g. "0", "0,1,2", or "0,2"; use -1 for CPU')
34 | parser.add_argument('--project_name', default='STADE-CD', type=str)
35 | parser.add_argument('--checkpoint_root', default='where', type=str)
36 | parser.add_argument('--vis_root', default='where', type=str)
37 | parser.add_argument('--output_folder', default='samples_LEVIR/predict_STADE-CD', type=str)
38 | # data
39 | parser.add_argument('--num_workers', default=2, type=int)
40 | parser.add_argument('--dataset', default='CDDataset', type=str)
41 | parser.add_argument('--data_name', default='DSIFN', type=str)
42 |
43 | parser.add_argument('--batch_size', default="your need", type=int, help='batch size; the value the authors used was 8')
44 | parser.add_argument('--split', default="train", type=str)
45 | parser.add_argument('--split_val', default="val", type=str)
46 |
47 | parser.add_argument('--img_size', default=256, type=int)
48 | parser.add_argument('--shuffle_AB', default=False, type=str)
49 |
50 | # model
51 | parser.add_argument('--n_class', default=2, type=int)
52 | parser.add_argument('--embed_dim', default=64, type=int)
53 | parser.add_argument('--pretrain', default=None, type=str)
54 | parser.add_argument('--multi_scale_train', default=False, type=str)
55 | parser.add_argument('--multi_scale_infer', default=False, type=str)
56 | parser.add_argument('--multi_pred_weights', nargs = '+', type = float, default = "your need",)
57 |
58 | parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8', type=str,
59 | help='base_resnet18 | base_transformer_pos_s4 | '
60 | 'base_transformer_pos_s4_dd8 | '
61 | 'base_transformer_pos_s4_dd8_dedim8|ChangeFormerV5|SiamUnet_diff')
62 | parser.add_argument('--loss', default='ce', type=str)
63 |
64 | # optimizer
65 | parser.add_argument('--optimizer', default='adamw', type=str)
66 | parser.add_argument('--lr', default="your need", type=float, help='learning rate; the value the authors used was 0.00009567')
67 | parser.add_argument('--max_epochs', default=406, type=int)
68 | parser.add_argument('--lr_policy', default='linear', type=str,
69 | help='linear | step')
70 | parser.add_argument('--lr_decay_iters', default=100, type=int)
71 |
72 | args = parser.parse_args()
73 | utils.get_device(args)
74 | print(args.gpu_ids)
75 |
76 | # checkpoints dir
77 | args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name)
78 | os.makedirs(args.checkpoint_dir, exist_ok=True)
79 | # visualize dir
80 | args.vis_dir = os.path.join(args.vis_root, args.project_name)
81 | os.makedirs(args.vis_dir, exist_ok=True)
82 |
83 | train(args)
84 |
85 | test(args)
86 |
--------------------------------------------------------------------------------
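A training run mirroring the defaults and the hints in the help strings above (batch size 8, lr 0.00009567) might look like this; the paths are placeholders, and depending on the chosen `--net_G` you may also need `--multi_pred_weights`:

```
python main_zl.py --gpu_ids 0 --project_name STADE-CD \
    --checkpoint_root checkpoints --vis_root vis \
    --data_name DSIFN --batch_size 8 --img_size 256 \
    --net_G base_transformer_pos_s4_dd8 --loss ce \
    --optimizer adamw --lr 0.00009567 --max_epochs 406 --lr_policy linear
```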
/STADE-CDNet/misc/imutils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 | from PIL import Image
5 | from PIL import ImageFilter
6 | import PIL
7 | import tifffile
8 |
9 |
10 | def cv_rotate(image, angle, borderValue):
11 | """
12 | rot angle, fill with borderValue
13 | """
14 | # grab the dimensions of the image and then determine the
15 | # center
16 | (h, w) = image.shape[:2]
17 | (cX, cY) = (w // 2, h // 2)
18 |
19 | # grab the rotation matrix (applying the negative of the
20 | # angle to rotate clockwise), then grab the sine and cosine
21 | # (i.e., the rotation components of the matrix)
22 | # a negative angle rotates clockwise; the third argument (scale=1.0) rescales the image, 0.75 is a suggested value
23 | M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
24 | cos = np.abs(M[0, 0])
25 | sin = np.abs(M[0, 1])
26 |
27 | # compute the new bounding dimensions of the image
28 | nW = int((h * sin) + (w * cos))
29 | nH = int((h * cos) + (w * sin))
30 |
31 | # adjust the rotation matrix to take into account translation
32 | M[0, 2] += (nW / 2) - cX
33 | M[1, 2] += (nH / 2) - cY
34 | if isinstance(borderValue, int):
35 | values = (borderValue, borderValue, borderValue)
36 | else:
37 | values = borderValue
38 | # perform the actual rotation and return the image
39 | return cv2.warpAffine(image, M, (nW, nH), borderValue=values)
40 |
41 |
42 | def pil_resize(img, size, order):
43 | if size[0] == img.shape[0] and size[1] == img.shape[1]:
44 | return img
45 |
46 | if order == 3:
47 | resample = Image.BICUBIC
48 | elif order == 0:
49 | resample = Image.NEAREST
50 |
51 | return np.asarray(Image.fromarray(img).resize(size[::-1], resample))
52 |
53 |
54 | def pil_rescale(img, scale, order):
55 | height, width = img.shape[:2]
56 | target_size = (int(np.round(height*scale)), int(np.round(width*scale)))
57 | return pil_resize(img, target_size, order)
58 |
59 |
60 | def pil_rotate(img, degree, default_value):
61 | if isinstance(default_value, tuple):
62 | values = (default_value[0], default_value[1], default_value[2], 0)
63 | else:
64 | values = (default_value, default_value, default_value,0)
65 | img = Image.fromarray(img)
66 | if img.mode =='RGB':
67 | # set img padding == default_value
68 | img2 = img.convert('RGBA')
69 | rot = img2.rotate(degree, expand=1)
70 | fff = Image.new('RGBA', rot.size, values) # gray fill
71 | out = Image.composite(rot, fff, rot)
72 | img = out.convert(img.mode)
73 |
74 | else:
75 | # set label padding == default_value
76 | img2 = img.convert('RGBA')
77 | rot = img2.rotate(degree, expand=1)
78 | # a white image same size as rotated image
79 | fff = Image.new('RGBA', rot.size, values)
80 | # create a composite image using the alpha layer of rot as a mask
81 | out = Image.composite(rot, fff, rot)
82 | img = out.convert(img.mode)
83 |
84 | return np.asarray(img)
85 |
86 |
87 | def random_resize_long_image_list(img_list, min_long, max_long):
88 | target_long = random.randint(min_long, max_long)
89 | h, w = img_list[0].shape[:2]
90 | if w < h:
91 | scale = target_long / h
92 | else:
93 | scale = target_long / w
94 | out = []
95 | for img in img_list:
96 | out.append(pil_rescale(img, scale, 3) )
97 | return out
98 |
99 |
100 | def random_resize_long(img, min_long, max_long):
101 | target_long = random.randint(min_long, max_long)
102 | h, w = img.shape[:2]
103 |
104 | if w < h:
105 | scale = target_long / h
106 | else:
107 | scale = target_long / w
108 |
109 | return pil_rescale(img, scale, 3)
110 |
111 |
112 | def random_scale_list(img_list, scale_range, order):
113 | """
114 | 输入:图像列表
115 | """
116 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0])
117 |
118 | if isinstance(img_list, tuple):
119 | assert img_list.__len__() == 2
120 | img1 = []
121 | img2 = []
122 | for img in img_list[0]:
123 | img1.append(pil_rescale(img, target_scale, order[0]))
124 | for img in img_list[1]:
125 | img2.append(pil_rescale(img, target_scale, order[1]))
126 | return (img1, img2)
127 | else:
128 | out = []
129 | for img in img_list:
130 | out.append(pil_rescale(img, target_scale, order))
131 | return out
132 |
133 |
134 | def random_scale(img, scale_range, order):
135 |
136 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0])
137 |
138 | if isinstance(img, tuple):
139 | return (pil_rescale(img[0], target_scale, order[0]), pil_rescale(img[1], target_scale, order[1]))
140 | else:
141 | return pil_rescale(img, target_scale, order)
142 |
143 |
144 | def random_rotate_list(img_list, max_degree, default_values):
145 | degree = random.random() * max_degree
146 | if isinstance(img_list, tuple):
147 | assert img_list.__len__() == 2
148 | img1 = []
149 | img2 = []
150 | for img in img_list[0]:
151 | assert isinstance(img, np.ndarray)
152 | img1.append((pil_rotate(img, degree, default_values[0])))
153 | for img in img_list[1]:
154 | img2.append((pil_rotate(img, degree, default_values[1])))
155 | return (img1, img2)
156 | else:
157 | out = []
158 | for img in img_list:
159 | out.append(pil_rotate(img, degree, default_values))
160 | return out
161 |
162 |
163 | def random_rotate(img, max_degree, default_values):
164 | degree = random.random() * max_degree
165 | if isinstance(img, tuple):
166 | return (pil_rotate(img[0], degree, default_values[0]),
167 | pil_rotate(img[1], degree, default_values[1]))
168 | else:
169 | return pil_rotate(img, degree, default_values)
170 |
171 |
172 | def random_lr_flip_list(img_list):
173 |
174 | if bool(random.getrandbits(1)):
175 | if isinstance(img_list, tuple):
176 | assert img_list.__len__()==2
177 | img1=list((np.fliplr(m) for m in img_list[0]))
178 | img2=list((np.fliplr(m) for m in img_list[1]))
179 |
180 | return (img1, img2)
181 | else:
182 | return list([np.fliplr(m) for m in img_list])
183 | else:
184 | return img_list
185 |
186 |
187 | def random_lr_flip(img):
188 |
189 | if bool(random.getrandbits(1)):
190 | if isinstance(img, tuple):
191 | return tuple([np.fliplr(m) for m in img])
192 | else:
193 | return np.fliplr(img)
194 | else:
195 | return img
196 |
197 |
198 | def get_random_crop_box(imgsize, cropsize):
199 | h, w = imgsize
200 |
201 | ch = min(cropsize, h)
202 | cw = min(cropsize, w)
203 |
204 | w_space = w - cropsize
205 | h_space = h - cropsize
206 |
207 | if w_space > 0:
208 | cont_left = 0
209 | img_left = random.randrange(w_space + 1)
210 | else:
211 | cont_left = random.randrange(-w_space + 1)
212 | img_left = 0
213 |
214 | if h_space > 0:
215 | cont_top = 0
216 | img_top = random.randrange(h_space + 1)
217 | else:
218 | cont_top = random.randrange(-h_space + 1)
219 | img_top = 0
220 |
221 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw
222 |
223 |
224 | def random_crop_list(images_list, cropsize, default_values):
225 |
226 | if isinstance(images_list, tuple):
227 | imgsize = images_list[0][0].shape[:2]
228 | elif isinstance(images_list, list):
229 | imgsize = images_list[0].shape[:2]
230 | else:
231 | raise RuntimeError('do not support the type of image_list')
232 | if isinstance(default_values, int): default_values = (default_values,)
233 |
234 | box = get_random_crop_box(imgsize, cropsize)
235 | if isinstance(images_list, tuple):
236 | assert images_list.__len__()==2
237 | img1 = []
238 | img2 = []
239 | for img in images_list[0]:
240 | f = default_values[0]
241 | if len(img.shape) == 3:
242 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f
243 | else:
244 | cont = np.ones((cropsize, cropsize), img.dtype)*f
245 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
246 | img1.append(cont)
247 | for img in images_list[1]:
248 | f = default_values[1]
249 | if len(img.shape) == 3:
250 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f
251 | else:
252 | cont = np.ones((cropsize, cropsize), img.dtype)*f
253 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
254 | img2.append(cont)
255 | return (img1, img2)
256 | else:
257 | out = []
258 | for img in images_list:
259 | f = default_values
260 | if len(img.shape) == 3:
261 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype) * f
262 | else:
263 | cont = np.ones((cropsize, cropsize), img.dtype) * f
264 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
265 | out.append(cont)
266 | return out
267 |
268 |
269 | def random_crop(images, cropsize, default_values):
270 |
271 | if isinstance(images, np.ndarray): images = (images,)
272 | if isinstance(default_values, int): default_values = (default_values,)
273 |
274 | imgsize = images[0].shape[:2]
275 | box = get_random_crop_box(imgsize, cropsize)
276 |
277 | new_images = []
278 | for img, f in zip(images, default_values):
279 |
280 | if len(img.shape) == 3:
281 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f
282 | else:
283 | cont = np.ones((cropsize, cropsize), img.dtype)*f
284 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
285 | new_images.append(cont)
286 |
287 | if len(new_images) == 1:
288 | new_images = new_images[0]
289 |
290 | return new_images
291 |
292 |
293 | def top_left_crop(img, cropsize, default_value):
294 |
295 | h, w = img.shape[:2]
296 |
297 | ch = min(cropsize, h)
298 | cw = min(cropsize, w)
299 |
300 | if len(img.shape) == 2:
301 | container = np.ones((cropsize, cropsize), img.dtype)*default_value
302 | else:
303 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value
304 |
305 | container[:ch, :cw] = img[:ch, :cw]
306 |
307 | return container
308 |
309 |
310 | def center_crop(img, cropsize, default_value=0):
311 |
312 | h, w = img.shape[:2]
313 |
314 | ch = min(cropsize, h)
315 | cw = min(cropsize, w)
316 |
317 | sh = h - cropsize
318 | sw = w - cropsize
319 |
320 | if sw > 0:
321 | cont_left = 0
322 | img_left = int(round(sw / 2))
323 | else:
324 | cont_left = int(round(-sw / 2))
325 | img_left = 0
326 |
327 | if sh > 0:
328 | cont_top = 0
329 | img_top = int(round(sh / 2))
330 | else:
331 | cont_top = int(round(-sh / 2))
332 | img_top = 0
333 |
334 | if len(img.shape) == 2:
335 | container = np.ones((cropsize, cropsize), img.dtype)*default_value
336 | else:
337 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value
338 |
339 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
340 | img[img_top:img_top+ch, img_left:img_left+cw]
341 |
342 | return container
343 |
344 |
345 | def HWC_to_CHW(img):
346 | return np.transpose(img, (2, 0, 1))
347 |
348 |
349 | def pil_blur(img, radius):
350 | return np.array(Image.fromarray(img).filter(ImageFilter.GaussianBlur(radius=radius)))
351 |
352 |
353 | def random_blur(img):
354 | radius = random.random()
355 | # print('add blur: ', radius)
356 | if isinstance(img, list):
357 | out = []
358 | for im in img:
359 | out.append(pil_blur(im, radius))
360 | return out
361 | elif isinstance(img, np.ndarray):
362 | return pil_blur(img, radius)
363 | else:
364 | print(img)
365 | raise RuntimeError("do not support the input image type!")
366 |
367 |
368 | def save_image(image_numpy, image_path):
369 | """Save a numpy image to the disk
370 | Parameters:
371 | image_numpy (numpy array) -- input numpy array
372 | image_path (str) -- the path of the image
373 | """
374 | image_pil = Image.fromarray(np.array(image_numpy,dtype=np.uint8))
375 | image_pil.save(image_path)
376 |
377 |
378 | def im2arr(img_path, mode=1, dtype=np.uint8):
379 | """
380 | :param img_path:
381 | :param mode:
382 | :return: numpy.ndarray, shape: H*W*C
383 | """
384 | if mode==1:
385 | img = PIL.Image.open(img_path)
386 | arr = np.asarray(img, dtype=dtype)
387 | else:
388 | arr = tifffile.imread(img_path)
389 | if arr.ndim == 3:
390 | a, b, c = arr.shape
391 | if a < b and a < c: # when arr is C*H*W, transpose the channel order to H*W*C
392 | arr = arr.transpose([1,2,0])
393 | # print('shape: ', arr.shape, 'dytpe: ',arr.dtype)
394 | return arr
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
--------------------------------------------------------------------------------
/STADE-CDNet/misc/logger_tool.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 |
4 |
5 | class Logger(object):
6 | def __init__(self, outfile):
7 | self.terminal = sys.stdout
8 | self.log_path = outfile
9 | now = time.strftime("%c")
10 | self.write('================ (%s) ================\n' % now)
11 |
12 | def write(self, message):
13 | self.terminal.write(message)
14 | with open(self.log_path, mode='a') as f:
15 | f.write(message)
16 |
17 | def write_dict(self, dict):
18 | message = ''
19 | for k, v in dict.items():
20 | message += '%s: %.7f ' % (k, v)
21 | self.write(message)
22 |
23 | def write_dict_str(self, dict):
24 | message = ''
25 | for k, v in dict.items():
26 | message += '%s: %s ' % (k, v)
27 | self.write(message)
28 |
29 | def flush(self):
30 | self.terminal.flush()
31 |
32 |
33 | class Timer:
34 | def __init__(self, starting_msg = None):
35 | self.start = time.time()
36 | self.stage_start = self.start
37 |
38 | if starting_msg is not None:
39 | print(starting_msg, time.ctime(time.time()))
40 |
41 | def __enter__(self):
42 | return self
43 |
44 | def __exit__(self, exc_type, exc_val, exc_tb):
45 | return
46 |
47 | def update_progress(self, progress):
48 | self.elapsed = time.time() - self.start
49 | self.est_total = self.elapsed / progress
50 | self.est_remaining = self.est_total - self.elapsed
51 | self.est_finish = int(self.start + self.est_total)
52 |
53 |
54 | def str_estimated_complete(self):
55 | return str(time.ctime(self.est_finish))
56 |
57 | def str_estimated_remaining(self):
58 | return str(self.est_remaining/3600) + 'h'
59 |
60 | def estimated_remaining(self):
61 | return self.est_remaining/3600
62 |
63 | def get_stage_elapsed(self):
64 | return time.time() - self.stage_start
65 |
66 | def reset_stage(self):
67 | self.stage_start = time.time()
68 |
69 | def lapse(self):
70 | out = time.time() - self.stage_start
71 | self.stage_start = time.time()
72 | return out
73 |
74 |
--------------------------------------------------------------------------------
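`Logger` tees writes to both the terminal and a log file, so it can stand in for `sys.stdout`. A minimal usage sketch (the log path is hypothetical):

```python
import sys
from misc.logger_tool import Logger

logger = Logger('train_log.txt')   # hypothetical log path
sys.stdout = logger                # every print() now also lands in the file
print('epoch 1 | loss 0.123')
logger.write_dict({'acc': 0.91, 'mf1': 0.88})
```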
/STADE-CDNet/misc/metric_tool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | ################### metrics ###################
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 | def __init__(self):
8 | self.initialized = False
9 | self.val = None
10 | self.avg = None
11 | self.sum = None
12 | self.count = None
13 |
14 | def initialize(self, val, weight):
15 | self.val = val
16 | self.avg = val
17 | self.sum = val * weight
18 | self.count = weight
19 | self.initialized = True
20 |
21 | def update(self, val, weight=1):
22 | if not self.initialized:
23 | self.initialize(val, weight)
24 | else:
25 | self.add(val, weight)
26 |
27 | def add(self, val, weight):
28 | self.val = val
29 | self.sum += val * weight
30 | self.count += weight
31 | self.avg = self.sum / self.count
32 |
33 | def value(self):
34 | return self.val
35 |
36 | def average(self):
37 | return self.avg
38 |
39 | def get_scores(self):
40 | scores_dict = cm2score(self.sum)
41 | return scores_dict
42 |
43 | def clear(self):
44 | self.initialized = False
45 |
46 |
47 | ################### cm metrics ###################
48 | class ConfuseMatrixMeter(AverageMeter):
49 | """Computes and stores the average and current value"""
50 | def __init__(self, n_class):
51 | super(ConfuseMatrixMeter, self).__init__()
52 | self.n_class = n_class
53 |
54 | def update_cm(self, pr, gt, weight=1):
55 | """获得当前混淆矩阵,并计算当前F1得分,并更新混淆矩阵"""
56 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr)
57 | self.update(val, weight)
58 | current_score = cm2F1(val)
59 | return current_score
60 |
61 | def get_scores(self):
62 | scores_dict = cm2score(self.sum)
63 | return scores_dict
64 |
65 |
66 |
67 | def harmonic_mean(xs):
68 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs)
69 | return harmonic_mean
70 |
71 |
72 | def cm2F1(confusion_matrix):
73 | hist = confusion_matrix
74 | n_class = hist.shape[0]
75 | tp = np.diag(hist)
76 | sum_a1 = hist.sum(axis=1)
77 | sum_a0 = hist.sum(axis=0)
78 | # ---------------------------------------------------------------------- #
79 | # 1. Accuracy & Class Accuracy
80 | # ---------------------------------------------------------------------- #
81 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
82 |
83 | # recall
84 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
85 | # acc_cls = np.nanmean(recall)
86 |
87 | # precision
88 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
89 |
90 | # F1 score
91 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps)
92 | mean_F1 = np.nanmean(F1)
93 | return mean_F1
94 |
95 |
96 | def cm2score(confusion_matrix):
97 | hist = confusion_matrix
98 | n_class = hist.shape[0]
99 | tp = np.diag(hist)
100 | sum_a1 = hist.sum(axis=1)
101 | sum_a0 = hist.sum(axis=0)
102 | # ---------------------------------------------------------------------- #
103 | # 1. Accuracy & Class Accuracy
104 | # ---------------------------------------------------------------------- #
105 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
106 |
107 | # recall
108 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
109 | # acc_cls = np.nanmean(recall)
110 |
111 | # precision
112 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
113 |
114 | # F1 score
115 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps)
116 | mean_F1 = np.nanmean(F1)
117 | # ---------------------------------------------------------------------- #
118 | # 2. Frequency weighted Accuracy & Mean IoU
119 | # ---------------------------------------------------------------------- #
120 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
121 | mean_iu = np.nanmean(iu)
122 |
123 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps)
124 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
125 |
126 | #
127 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu))
128 |
129 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision))
130 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall))
131 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1))
132 |
133 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1}
134 | score_dict.update(cls_iou)
135 | score_dict.update(cls_F1)
136 | score_dict.update(cls_precision)
137 | score_dict.update(cls_recall)
138 | return score_dict
139 |
140 |
141 | def get_confuse_matrix(num_classes, label_gts, label_preds):
142 | """Compute the confusion matrix over a set of predictions."""
143 | def __fast_hist(label_gt, label_pred):
144 | """
145 | Collect values for Confusion Matrix
146 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
147 | :param label_gt: ground-truth
148 | :param label_pred: prediction
149 | :return: values for confusion matrix
150 | """
151 | mask = (label_gt >= 0) & (label_gt < num_classes)
152 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask],
153 | minlength=num_classes**2).reshape(num_classes, num_classes)
154 | return hist
155 | confusion_matrix = np.zeros((num_classes, num_classes))
156 | for lt, lp in zip(label_gts, label_preds):
157 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten())
158 | return confusion_matrix
159 |
160 |
161 | def get_mIoU(num_classes, label_gts, label_preds):
162 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds)
163 | score_dict = cm2score(confusion_matrix)
164 | return score_dict['miou']
165 |
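
A minimal usage sketch of the metric helpers above (not part of the repository; the toy 4x4 binary change maps are invented for illustration):

```python
import numpy as np

# Toy 4x4 binary change maps: 1 = changed pixel, 0 = unchanged.
gt   = np.array([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0]])
pred = np.array([[0, 0, 1, 0],
                 [0, 0, 1, 1],
                 [0, 1, 0, 0],
                 [0, 0, 0, 0]])

cm = get_confuse_matrix(num_classes=2, label_gts=[gt], label_preds=[pred])
scores = cm2score(cm)          # overall accuracy, mIoU, per-class precision/recall/F1
print(scores['acc'], scores['mf1'], scores['miou'])
```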
--------------------------------------------------------------------------------
/STADE-CDNet/misc/pyutils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import glob
5 |
6 |
7 | def seed_random(seed=2020):
8 |     # Fix the random seeds below so that data loading, random augmentation, etc. stay consistent across runs
9 | random.seed(seed)
10 | os.environ['PYTHONHASHSEED'] = str(seed)
11 | np.random.seed(seed)
12 |
13 |
14 | def mkdir(path):
15 |     """Create a single directory if it does not already exist.
16 |
17 | Parameters:
18 | path (str) -- a single directory path
19 | """
20 | if not os.path.exists(path):
21 | os.makedirs(path)
22 |
23 |
24 | def get_paths(image_folder_path, suffix='*.png'):
25 |     """Return the files matching the given suffix pattern in a folder, sorted.
26 | :param image_folder_path: str
27 | :param suffix: str
28 | :return: list
29 | """
30 | paths = sorted(glob.glob(os.path.join(image_folder_path, suffix)))
31 | return paths
32 |
33 |
34 | def get_paths_from_list(image_folder_path, file_list):
35 |     """Join each file name in file_list with the image folder and return the sorted path list."""
36 |     out = []
37 |     for item in file_list:
38 | path = os.path.join(image_folder_path,item)
39 | out.append(path)
40 | return sorted(out)
41 |
42 |
43 |
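
A short sketch of how these helpers compose; the './output' path is only an example, and './samples_DSIFN/A' refers to the sample folder shipped with the repository:

```python
seed_random(2020)                                    # reproducible shuffling/augmentation
mkdir('./output')                                    # created only if missing
paths = get_paths('./samples_DSIFN/A', suffix='*.png')
subset = get_paths_from_list('./samples_DSIFN/A', ['0_2.png', '1_1.png'])
```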
--------------------------------------------------------------------------------
/STADE-CDNet/misc/torchutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import lr_scheduler
3 | from torch.utils.data import Subset
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import math
7 | import random
8 | import os
9 | from torch.nn import MaxPool1d,AvgPool1d
10 | from torch import Tensor
11 | from typing import Iterable, Set, Tuple
12 |
13 |
14 | __all__ = ['cls_accuracy']
15 |
16 |
17 |
18 | def visualize_imgs(*imgs):
19 | """
20 |     Visualize one or more ndarray images.
21 | :param imgs: ndarray:H*W*C, C=1/3
22 | :return:
23 | """
24 | import matplotlib.pyplot as plt
25 | nums = len(imgs)
26 | if nums > 1:
27 | fig, axs = plt.subplots(1, nums)
28 | for i, image in enumerate(imgs):
29 | axs[i].imshow(image, cmap='jet')
30 | elif nums == 1:
31 | fig, ax = plt.subplots(1, nums)
32 | for i, image in enumerate(imgs):
33 | ax.imshow(image, cmap='jet')
34 | plt.show()
35 |
36 |
37 | def minmax(tensor):
38 | assert tensor.ndim >= 2
39 | shape = tensor.shape
40 | tensor = tensor.view([*shape[:-2], shape[-1]*shape[-2]])
41 | min_, _ = tensor.min(-1, keepdim=True)
42 | max_, _ = tensor.max(-1, keepdim=True)
43 | return min_, max_
44 |
45 | def norm_tensor(tensor,min_=None,max_=None, mode='minmax'):
46 | """
47 |     Input: N*C*H*W / C*H*W / H*W
48 |     Output: a map of the same size as the input, normalized over the H*W dimensions
49 | """
50 | assert tensor.ndim >= 2
51 | shape = tensor.shape
52 | tensor = tensor.view([*shape[:-2], shape[-1]*shape[-2]])
53 | if mode == 'minmax':
54 | if min_ is None:
55 | min_, _ = tensor.min(-1, keepdim=True)
56 | if max_ is None:
57 | max_, _ = tensor.max(-1, keepdim=True)
58 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001)
59 | elif mode == 'thres':
60 | N = tensor.shape[-1]
61 | thres_a = 0.001
62 | top_k = round(thres_a*N)
63 | max_ = tensor.topk(top_k, dim=-1, largest=True)[0][..., -1]
64 | max_ = max_.unsqueeze(-1)
65 | min_ = tensor.topk(top_k, dim=-1, largest=False)[0][..., -1]
66 | min_ = min_.unsqueeze(-1)
67 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001)
68 |
69 | elif mode == 'std':
70 | mean, std = torch.std_mean(tensor, [-1], keepdim=True)
71 | tensor = (tensor - mean)/std
72 | min_, _ = tensor.min(-1, keepdim=True)
73 | max_, _ = tensor.max(-1, keepdim=True)
74 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001)
75 | elif mode == 'exp':
76 |         tau = 1  # softmax temperature
77 |         tensor = torch.nn.functional.softmax(tensor / tau, dim=-1)
78 | min_, _ = tensor.min(-1, keepdim=True)
79 | max_, _ = tensor.max(-1, keepdim=True)
80 | tensor = (tensor - min_) / (max_ - min_ + 0.00000000001)
81 | else:
82 | raise NotImplementedError
83 | tensor = torch.clamp(tensor, 0, 1)
84 | return tensor.view(shape)
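
For reference, a quick sketch of the default min-max branch (assumes the imports at the top of this file; the tensor values are random):

```python
x = torch.randn(2, 3, 8, 8)              # N*C*H*W
y = norm_tensor(x, mode='minmax')        # each H*W map rescaled into [0, 1]
print(y.shape, float(y.min()), float(y.max()))
```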
85 |
109 |
110 | def visulize_features(features, normalize=False):
111 | """
112 |     Visualize a feature map: every channel is tiled into a single grid.
113 | """
114 | from torchvision.utils import make_grid
115 | assert features.ndim == 4
116 | b,c,h,w = features.shape
117 | features = features.view((b*c, 1, h, w))
118 | if normalize:
119 | features = norm_tensor(features)
120 | grid = make_grid(features)
121 | visualize_tensors(grid)
122 |
123 | def visualize_tensors(*tensors):
124 | """
125 |     Visualize tensors; supports single-channel features or 3-channel images.
126 | :param tensors: tensor: C*H*W, C=1/3
127 | :return:
128 | """
129 | import matplotlib.pyplot as plt
130 | # from misc.torchutils import tensor2np
131 | images = []
132 | for tensor in tensors:
133 | assert tensor.ndim == 3 or tensor.ndim==2
134 | if tensor.ndim ==3:
135 | assert tensor.shape[0] == 1 or tensor.shape[0] == 3
136 | images.append(tensor2np(tensor))
137 | nums = len(images)
138 | if nums>1:
139 | fig, axs = plt.subplots(1, nums)
140 | for i, image in enumerate(images):
141 | axs[i].imshow(image, cmap='jet')
142 | plt.show()
143 | elif nums == 1:
144 | fig, ax = plt.subplots(1, nums)
145 | for i, image in enumerate(images):
146 | ax.imshow(image, cmap='jet')
147 | plt.show()
148 |
149 |
150 | def np_to_tensor(image):
151 | """
152 | input: nd.array: H*W*C/H*W
153 | """
154 | if isinstance(image, torch.Tensor):
155 | return image
156 | elif isinstance(image, np.ndarray):
157 | if image.ndim == 3:
158 | if image.shape[2]==3:
159 | image = np.transpose(image,[2,0,1])
160 | elif image.ndim == 2:
161 |             image = np.expand_dims(image, 0)  # np.newaxis is an alias, not callable; add the leading axis explicitly
162 | image = torch.from_numpy(image)
163 | return image.unsqueeze(0)
164 |
165 |
166 | def seed_torch(seed=2019):
167 |
168 |     # Fix the random seeds below so that data loading, random augmentation, etc. stay consistent across runs
169 | random.seed(seed)
170 | os.environ['PYTHONHASHSEED'] = str(seed)
171 | np.random.seed(seed)
172 | torch.manual_seed(seed)
173 | torch.cuda.manual_seed(seed)
174 |     # Even with all seeds fixed, intermediate results still diverge once the model is updated;
175 |     # observed behavior: the first two epochs match, then results drift as the model keeps updating.
176 | # torch.backends.cudnn.benchmark = False
177 | # torch.backends.cudnn.deterministic = True
178 |
179 | def simplex(t: Tensor, axis=1) -> bool:
180 | _sum = t.sum(axis).type(torch.float32)
181 | _ones = torch.ones_like(_sum, dtype=torch.float32)
182 | return torch.allclose(_sum, _ones)
183 |
184 |
185 | # Assert utils
186 | def uniq(a: Tensor) -> Set:
187 | return set(torch.unique(a.cpu()).numpy())
188 |
189 | def sset(a: Tensor, sub: Iterable) -> bool:
190 | return uniq(a).issubset(sub)
191 |
192 | def eq(a: Tensor, b) -> bool:
193 | return torch.eq(a, b).all()
194 |
195 | def one_hot(t: Tensor, axis=1) -> bool:
196 | return simplex(t, axis) and sset(t, [0, 1])
197 |
198 |
199 | def class2one_hot(seg: Tensor, C: int) -> Tensor:
200 | if len(seg.shape) == 2: # Only w, h, used by the dataloader
201 | seg = seg.unsqueeze(dim=0)
202 | assert sset(seg, list(range(C)))
203 |
204 | b, w, h = seg.shape # type: Tuple[int, int, int]
205 |
206 | res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
207 | assert res.shape == (b, C, w, h)
208 | assert one_hot(res)
209 |
210 | return res
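
A toy sketch of class2one_hot on a 2x2, 2-class mask (values invented for illustration):

```python
seg = torch.tensor([[0, 1],
                    [1, 0]])             # H*W mask with labels in {0, 1}
oh = class2one_hot(seg, C=2)             # -> 1*2*2*2; channel c holds (seg == c)
print(oh.shape)                          # torch.Size([1, 2, 2, 2])
```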
211 |
212 | class ChannelMaxPool(MaxPool1d):
213 | def forward(self, input):
214 | n, c, w, h = input.size()
215 | input = input.view(n,c,w*h).permute(0,2,1)
216 | pooled = F.max_pool1d(input, self.kernel_size, self.stride,
217 | self.padding, self.dilation, self.ceil_mode,
218 | self.return_indices)
219 | _, _, c = pooled.size()
220 | pooled = pooled.permute(0,2,1)
221 | return pooled.view(n,c,w,h)
222 |
223 | class ChannelAvePool(AvgPool1d):
224 | def forward(self, input):
225 | n, c, w, h = input.size()
226 | input = input.view(n,c,w*h).permute(0,2,1)
227 | pooled = F.avg_pool1d(input, self.kernel_size, self.stride,
228 | self.padding)
229 | _, _, c = pooled.size()
230 | pooled = pooled.permute(0,2,1)
231 | return pooled.view(n,c,w,h)
232 |
233 | def cross_entropy(input, target, weight=None, reduction='mean',ignore_index=255):
234 | """
235 | logSoftmax_with_loss
236 | :param input: torch.Tensor, N*C*H*W
237 | :param target: torch.Tensor, N*1*H*W,/ N*H*W
238 | :param weight: torch.Tensor, C
239 | :return: torch.Tensor [0]
240 | """
241 | target = target.long()
242 | if target.dim() == 4:
243 | target = torch.squeeze(target, dim=1)
244 | if input.shape[-1] != target.shape[-1]:
245 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True)
246 |
247 | return F.cross_entropy(input=input, target=target, weight=weight,
248 | ignore_index=ignore_index, reduction=reduction)
249 |
250 | def balanced_cross_entropy(input, target, weight=None,ignore_index=255):
251 | """
252 |     Class-balanced cross-entropy loss; currently only supports 2 classes.
253 |     TODO: extend to multi-class C > 2
254 | """
255 | if target.dim() == 4:
256 | target = torch.squeeze(target, dim=1)
257 | if input.shape[-1] != target.shape[-1]:
258 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True)
259 |
260 | # print('target.sum',target.sum())
261 | pos = (target==1).float()
262 | neg = (target==0).float()
263 | pos_num = torch.sum(pos) + 0.0000001
264 | neg_num = torch.sum(neg) + 0.0000001
265 | # print(pos_num)
266 | # print(neg_num)
267 | target_pos = target.float()
268 |     target_pos[target_pos!=1] = ignore_index  # ignore regions that are not positive samples
269 |     target_neg = target.float()
270 |     target_neg[target_neg!=0] = ignore_index  # ignore regions that are not negative samples
271 |
272 | # print('target.sum',target.sum())
273 |
274 | loss_pos = cross_entropy(input, target_pos,weight=weight,reduction='sum',ignore_index=ignore_index)
275 | loss_neg = cross_entropy(input, target_neg,weight=weight,reduction='sum',ignore_index=ignore_index)
276 | # print(loss_neg, loss_pos)
277 | loss = 0.5 * loss_pos / pos_num + 0.5 * loss_neg / neg_num
278 | # loss = (loss_pos + loss_neg)/ (pos_num+neg_num)
279 | return loss
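
A minimal sketch of the balanced loss on an imbalanced toy mask (shapes follow the docstring; the values are made up):

```python
logits = torch.randn(1, 2, 4, 4)          # N*C*H*W raw scores
target = torch.zeros(1, 4, 4).long()      # N*H*W labels
target[0, :2, :2] = 1                     # only 4 of 16 pixels are positive
loss = balanced_cross_entropy(logits, target)
print(float(loss))                        # positives and negatives each weighted 0.5
```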
280 |
281 | def get_scheduler(optimizer, opt):
282 | """Return a learning rate scheduler
283 | """
284 | if opt.lr_policy == 'linear':
285 | def lambda_rule(epoch):
286 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
287 | return lr_l
288 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
289 | elif opt.lr_policy == 'poly':
290 | max_step = opt.niter+opt.niter_decay
291 | power = 0.9
292 | def lambda_rule(epoch):
293 | current_step = epoch + opt.epoch_count
294 | lr_l = (1.0 - current_step / (max_step+1)) ** float(power)
295 | return lr_l
296 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
297 | elif opt.lr_policy == 'step':
298 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
299 | else:
300 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
301 | return scheduler
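
A sketch of wiring get_scheduler to an optimizer; the SimpleNamespace stands in for the real argparse options, and its field values are hypothetical:

```python
from types import SimpleNamespace

net = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
opt = SimpleNamespace(lr_policy='linear', epoch_count=1, niter=100, niter_decay=100)
scheduler = get_scheduler(optimizer, opt)
for epoch in range(3):
    optimizer.step()                      # normally called after backward() on a real loss
    scheduler.step()                      # lr starts decaying linearly after opt.niter epochs
```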
302 |
303 |
304 | def mul_cls_acc(preds, targets, topk=(1,)):
305 |     """Compute top-k accuracy (topk-acc) for multi-label classification; topk-error = 1 - topk-acc.
306 |     First compute the average accuracy per image, then average over all images.
307 | :param pred: N * C
308 | :param target: N * C
309 | :param topk:
310 | :return:
311 | """
312 | with torch.no_grad():
313 | maxk = max(topk)
314 | bs, C = targets.shape
315 | _, pred = preds.topk(maxk, 1, True, True)
316 |         pred += 1  # shift predictions to class labels in [1, C]
317 | # print('pred: ', pred)
318 | # print('targets: ', targets)
319 |         correct = torch.zeros([bs, maxk]).long()  # counts of correctly predicted labels
320 |         if preds.device.type != 'cpu':
321 |             correct = correct.to(preds.device)  # keep the counter on the same device as the predictions
322 | for i in range(C):
323 | label = i + 1
324 | target = targets[:, i] * label
325 | # print('target.view: ', target.view(-1, 1).expand_as(pred))
326 | # print('pred: ', pred)
327 | correct = correct + pred.eq(target.view(-1, 1).expand_as(pred)).long()
328 | # print('correct: ', pred.eq(target.view(-1, 1).expand_as(pred)).long())
329 |         n = (targets == 1).long().sum(1)  # N*1, number of targets present in each image
330 | # print(n)
331 | res = []
332 | for k in topk:
333 |             acc_k = correct[:, :k].sum(1).float() / n.float()  # per-image accuracy: correctly predicted targets / total targets
334 | # print(correct[:, :k].sum(1).float())
335 | acc_k = acc_k.sum()/bs
336 | res.append(acc_k)
337 | # print(acc_k)
338 | return res
339 |
340 |
341 | def cls_accuracy(output, target, topk=(1,)):
342 | """
343 | Computes the accuracy over the k top predictions for the specified values of k
344 | https://github.com/pytorch/examples/blob/ee964a2eeb41e1712fe719b83645c79bcbd0ba1a/imagenet/main.py#L407
345 | """
346 |
347 | with torch.no_grad():
348 | maxk = max(topk)
349 | batch_size = target.size(0)
350 |
351 | _, pred = output.topk(maxk, 1, True, True)
352 | pred = pred.t()
353 | correct = pred.eq(target.view(1, -1).expand_as(pred))
354 |
355 | res = []
356 | for k in topk:
357 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice may be non-contiguous
358 | res.append(correct_k.mul_(100.0 / batch_size))
359 | return res
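
A quick sketch on random logits (any N*C score tensor works):

```python
logits = torch.randn(8, 10)               # N*C class scores
labels = torch.randint(0, 10, (8,))       # N ground-truth indices
top1, top5 = cls_accuracy(logits, labels, topk=(1, 5))
print(float(top1), float(top5))           # percentages in [0, 100]
```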
360 |
361 | class PolyOptimizer(torch.optim.SGD):
362 |
363 | def __init__(self, params, lr, weight_decay, max_step, init_step=0, momentum=0.9):
364 |         super().__init__(params, lr, weight_decay=weight_decay)  # pass by keyword: SGD's third positional parameter is momentum, not weight_decay
365 |
366 | self.global_step = init_step
367 | print(self.global_step)
368 | self.max_step = max_step
369 | self.momentum = momentum
370 |
371 | self.__initial_lr = [group['lr'] for group in self.param_groups]
372 |
373 |
374 | def step(self, closure=None):
375 |
376 | if self.global_step < self.max_step:
377 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
378 |
379 | for i in range(len(self.param_groups)):
380 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
381 |
382 | super().step(closure)
383 |
384 | self.global_step += 1
385 |
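
A sketch of a single PolyOptimizer step; the hyper-parameter values are placeholders:

```python
net = torch.nn.Linear(4, 2)
poly_opt = PolyOptimizer(net.parameters(), lr=0.01, weight_decay=5e-4, max_step=100)
net(torch.randn(3, 4)).sum().backward()
poly_opt.step()                           # applies lr * (1 - step/max_step)**0.9, then the SGD update
print(poly_opt.param_groups[0]['lr'])     # 0.01 at step 0; decays toward 0 at max_step
```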
386 |
387 | class PolyAdamOptimizer(torch.optim.Adam):
388 | def __init__(self, params, lr, betas, max_step, momentum=0.9):
389 | super().__init__(params, lr, betas)
390 |
391 | self.global_step = 0
392 | self.max_step = max_step
393 | self.momentum = momentum
394 |
395 | self.__initial_lr = [group['lr'] for group in self.param_groups]
396 |
397 |
398 | def step(self, closure=None):
399 |
400 | if self.global_step < self.max_step:
401 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
402 |
403 | for i in range(len(self.param_groups)):
404 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
405 |
406 | super().step(closure)
407 | self.global_step += 1
408 | #
409 | # from ranger import RangerQH,Ranger
410 | # # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer/blob/master/ranger/rangerqh.py
411 | #
412 | # class PolyRangerOptimizer(RangerQH):
413 | #
414 | # def __init__(self, params, lr, betas, max_step, momentum=0.9):
415 | # super().__init__(params, lr, betas)
416 | #
417 | # self.global_step = 0
418 | # self.max_step = max_step
419 | # self.momentum = momentum
420 | #
421 | # self.__initial_lr = [group['lr'] for group in self.param_groups]
422 | #
423 | #
424 | # def step(self, closure=None):
425 | #
426 | # if self.global_step < self.max_step:
427 | # lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
428 | #
429 | # for i in range(len(self.param_groups)):
430 | # self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
431 | #
432 | # super().step(closure)
433 | # self.global_step += 1
434 |
435 | class SGDROptimizer(torch.optim.SGD):
436 |
437 | def __init__(self, params, steps_per_epoch, lr=0, weight_decay=0, epoch_start=1, restart_mult=2):
438 |         super().__init__(params, lr, weight_decay=weight_decay)  # pass by keyword: SGD's third positional parameter is momentum
439 |
440 | self.global_step = 0
441 | self.local_step = 0
442 | self.total_restart = 0
443 |
444 | self.max_step = steps_per_epoch * epoch_start
445 | self.restart_mult = restart_mult
446 |
447 | self.__initial_lr = [group['lr'] for group in self.param_groups]
448 |
449 |
450 | def step(self, closure=None):
451 |
452 | if self.local_step >= self.max_step:
453 | self.local_step = 0
454 | self.max_step *= self.restart_mult
455 | self.total_restart += 1
456 |
457 | lr_mult = (1 + math.cos(math.pi * self.local_step / self.max_step))/2 / (self.total_restart + 1)
458 |
459 | for i in range(len(self.param_groups)):
460 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
461 |
462 | super().step(closure)
463 |
464 | self.local_step += 1
465 | self.global_step += 1
466 |
467 |
468 | def split_dataset(dataset, n_splits):
469 |
470 | return [Subset(dataset, np.arange(i, len(dataset), n_splits)) for i in range(n_splits)]
471 |
472 |
473 | def gap2d(x, keepdims=False):
474 | out = torch.mean(x.view(x.size(0), x.size(1), -1), -1)
475 | if keepdims:
476 | out = out.view(out.size(0), out.size(1), 1, 1)
477 |
478 | return out
479 |
480 |
481 | def decode_seg(label_mask, toTensor=False):
482 | """
483 | :param label_mask: mask (np.ndarray): (M, N)/ tensor: N*C*H*W
484 | :return: color label: (M, N, 3),
485 | """
486 | if not isinstance(label_mask, np.ndarray):
487 | if isinstance(label_mask, torch.Tensor): # get the data from a variable
488 | image_tensor = label_mask.data
489 | else:
490 | return label_mask
491 | label_mask = image_tensor[0][0].cpu().numpy()
492 |
493 |     rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3), dtype=np.float32)  # the np.float alias was removed in NumPy 1.24
494 | r = label_mask % 6
495 | g = (label_mask % 36) // 6
496 | b = label_mask // 36
497 |     # normalize to [0, 1]
498 | rgb[:, :, 0] = r / 6
499 | rgb[:, :, 1] = g / 6
500 | rgb[:, :, 2] = b / 6
501 | if toTensor:
502 | rgb = torch.from_numpy(rgb.transpose([2,0,1])).unsqueeze(0)
503 |
504 | return rgb
505 |
506 |
507 | def tensor2im(input_image, imtype=np.uint8, normalize=True):
508 |     """Converts a Tensor array into a numpy image array.
509 | Parameters:
510 | input_image (tensor) -- the input image tensor array
511 | imtype (type) -- the desired type of the converted numpy array
512 | """
513 | if not isinstance(input_image, np.ndarray):
514 | if isinstance(input_image, torch.Tensor): # get the data from a variable
515 | image_tensor = input_image.data
516 | else:
517 | return input_image
518 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
519 | # if image_numpy.shape[0] == 1: # grayscale to RGB
520 | # image_numpy = np.tile(image_numpy, (3, 1, 1))
521 | if image_numpy.shape[0] == 3: # if RGB
522 | image_numpy = np.transpose(image_numpy, (1, 2, 0))
523 | if normalize:
524 |             image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
525 | else: # if it is a numpy array, do nothing
526 | image_numpy = input_image
527 | return image_numpy.astype(imtype)
528 |
529 |
530 | def tensor2np(input_image, if_normalize=True):
531 | """
532 | :param input_image: C*H*W / H*W
533 | :return: ndarray, H*W*C / H*W
534 | """
535 | if isinstance(input_image, torch.Tensor): # get the data from a variable
536 | image_tensor = input_image.data
537 | image_numpy = image_tensor.cpu().float().numpy() # convert it into a numpy array
538 |
539 | else:
540 | image_numpy = input_image
541 | if image_numpy.ndim == 2:
542 | return image_numpy
543 | elif image_numpy.ndim == 3:
544 | C, H, W = image_numpy.shape
545 | image_numpy = np.transpose(image_numpy, (1, 2, 0))
546 |         # if the input is a grayscale image (C == 1), return an array with ndim == 2
547 | if C == 1:
548 | image_numpy = image_numpy[:, :, 0]
549 | if if_normalize and C == 3:
550 |             image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
551 | # add to prevent extreme noises in visual images
552 | image_numpy[image_numpy<0]=0
553 | image_numpy[image_numpy>255]=255
554 | image_numpy = image_numpy.astype(np.uint8)
555 | return image_numpy
556 |
557 |
558 | import ntpath
559 | from misc.imutils import save_image
560 | def save_visuals(visuals, img_dir, name, save_one=True, iter='0'):
561 |     """Save a dict of named image batches to img_dir as PNG files.
562 |     """
563 | # save images to the disk
564 | for label, image in visuals.items():
565 | N = image.shape[0]
566 | if save_one:
567 | N = 1
568 |         # save each item of the batch
569 | for j in range(N):
570 | name_ = ntpath.basename(name[j])
571 | name_ = name_.split(".")[0]
572 | # print(name_)
573 | image_numpy = tensor2np(image[j], if_normalize=True).astype(np.uint8)
574 | # print(image_numpy)
575 | img_path = os.path.join(img_dir, iter+'_%s_%s.png' % (name_, label))
576 | save_image(image_numpy, img_path)
--------------------------------------------------------------------------------
/STADE-CDNet/models/ChangeFormerBaseNetworks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 |
8 | from torch.nn import init
9 |
10 | from torch.autograd import Function
11 |
12 | from math import sqrt
13 |
14 | import random
15 |
16 | class ConvBlock(torch.nn.Module):
17 | def __init__(self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None):
18 | super(ConvBlock, self).__init__()
19 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
20 |
21 | self.norm = norm
22 | if self.norm =='batch':
23 | self.bn = torch.nn.BatchNorm2d(output_size)
24 | elif self.norm == 'instance':
25 | self.bn = torch.nn.InstanceNorm2d(output_size)
26 |
27 | self.activation = activation
28 | if self.activation == 'relu':
29 | self.act = torch.nn.ReLU(True)
30 | elif self.activation == 'prelu':
31 | self.act = torch.nn.PReLU()
32 | elif self.activation == 'lrelu':
33 | self.act = torch.nn.LeakyReLU(0.2, True)
34 | elif self.activation == 'tanh':
35 | self.act = torch.nn.Tanh()
36 | elif self.activation == 'sigmoid':
37 | self.act = torch.nn.Sigmoid()
38 |
39 | def forward(self, x):
40 | if self.norm is not None:
41 | out = self.bn(self.conv(x))
42 | else:
43 | out = self.conv(x)
44 |
45 | if self.activation != 'no':
46 | return self.act(out)
47 | else:
48 | return out
49 |
50 | class DeconvBlock(torch.nn.Module):
51 | def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm=None):
52 | super(DeconvBlock, self).__init__()
53 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
54 |
55 | self.norm = norm
56 | if self.norm == 'batch':
57 | self.bn = torch.nn.BatchNorm2d(output_size)
58 | elif self.norm == 'instance':
59 | self.bn = torch.nn.InstanceNorm2d(output_size)
60 |
61 | self.activation = activation
62 | if self.activation == 'relu':
63 | self.act = torch.nn.ReLU(True)
64 | elif self.activation == 'prelu':
65 | self.act = torch.nn.PReLU()
66 | elif self.activation == 'lrelu':
67 | self.act = torch.nn.LeakyReLU(0.2, True)
68 | elif self.activation == 'tanh':
69 | self.act = torch.nn.Tanh()
70 | elif self.activation == 'sigmoid':
71 | self.act = torch.nn.Sigmoid()
72 |
73 | def forward(self, x):
74 | if self.norm is not None:
75 | out = self.bn(self.deconv(x))
76 | else:
77 | out = self.deconv(x)
78 |
79 | if self.activation is not None:
80 | return self.act(out)
81 | else:
82 | return out
83 |
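
A shape-only sketch of the two blocks above (the channel sizes are arbitrary):

```python
x = torch.randn(1, 16, 32, 32)
down = ConvBlock(16, 32, kernel_size=3, stride=2, padding=1, norm='batch')
up = DeconvBlock(32, 16, kernel_size=4, stride=2, padding=1, norm='batch')
print(down(x).shape)                      # torch.Size([1, 32, 16, 16])
print(up(down(x)).shape)                  # torch.Size([1, 16, 32, 32])
```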
84 |
85 | class ConvLayer(nn.Module):
86 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
87 | super(ConvLayer, self).__init__()
88 | # reflection_padding = kernel_size // 2
89 | # self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
90 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
91 |
92 | def forward(self, x):
93 | # out = self.reflection_pad(x)
94 | out = self.conv2d(x)
95 | return out
96 |
97 |
98 | class UpsampleConvLayer(torch.nn.Module):
99 | def __init__(self, in_channels, out_channels, kernel_size, stride):
100 | super(UpsampleConvLayer, self).__init__()
101 | self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=1)
102 |
103 | def forward(self, x):
104 | out = self.conv2d(x)
105 | return out
106 |
107 |
108 | class ResidualBlock(torch.nn.Module):
109 | def __init__(self, channels):
110 | super(ResidualBlock, self).__init__()
111 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
112 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
113 | self.relu = nn.ReLU()
114 |
115 | def forward(self, x):
116 | residual = x
117 | out = self.relu(self.conv1(x))
118 | out = self.conv2(out) * 0.1
119 | out = torch.add(out, residual)
120 | return out
121 |
122 |
123 |
124 | def init_linear(linear):
125 |     init.xavier_normal_(linear.weight)  # use the in-place variant; init.xavier_normal is deprecated
126 | linear.bias.data.zero_()
127 |
128 |
129 | def init_conv(conv, glu=True):
130 |     init.kaiming_normal_(conv.weight)  # use the in-place variant; init.kaiming_normal is deprecated
131 | if conv.bias is not None:
132 | conv.bias.data.zero_()
133 |
134 |
135 | class EqualLR:
136 | def __init__(self, name):
137 | self.name = name
138 |
139 | def compute_weight(self, module):
140 | weight = getattr(module, self.name + '_orig')
141 | fan_in = weight.data.size(1) * weight.data[0][0].numel()
142 |
143 | return weight * sqrt(2 / fan_in)
144 |
145 | @staticmethod
146 | def apply(module, name):
147 | fn = EqualLR(name)
148 |
149 | weight = getattr(module, name)
150 | del module._parameters[name]
151 | module.register_parameter(name + '_orig', nn.Parameter(weight.data))
152 | module.register_forward_pre_hook(fn)
153 |
154 | return fn
155 |
156 | def __call__(self, module, input):
157 | weight = self.compute_weight(module)
158 | setattr(module, self.name, weight)
159 |
160 |
161 | def equal_lr(module, name='weight'):
162 | EqualLR.apply(module, name)
163 |
164 | return module
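
A minimal sketch of applying the equalized-learning-rate hook to a layer:

```python
lin = equal_lr(nn.Linear(8, 4))           # weight is re-registered as 'weight_orig'
out = lin(torch.randn(2, 8))              # the pre-hook rescales the weight by sqrt(2 / fan_in)
print(out.shape)                          # torch.Size([2, 4])
```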
--------------------------------------------------------------------------------
/STADE-CDNet/models/DTCDSCN.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.models import ResNet
5 | import torch.nn.functional as F
6 | from functools import partial
7 |
8 |
9 | nonlinearity = partial(F.relu,inplace=True)
10 |
11 | class SELayer(nn.Module):
12 | def __init__(self, channel, reduction=16):
13 | super(SELayer, self).__init__()
14 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
15 | self.fc = nn.Sequential(
16 | nn.Linear(channel, channel // reduction, bias=False),
17 | nn.ReLU(inplace=True),
18 | nn.Linear(channel // reduction, channel, bias=False),
19 | nn.Sigmoid()
20 | )
21 |
22 | def forward(self, x):
23 | b, c, _, _ = x.size()
24 | y = self.avg_pool(x).view(b, c)
25 | y = self.fc(y).view(b, c, 1, 1)
26 | return x * y.expand_as(x)
27 |
28 | class Dblock_more_dilate(nn.Module):
29 | def __init__(self, channel):
30 | super(Dblock_more_dilate, self).__init__()
31 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
32 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
33 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
34 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
35 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
36 | for m in self.modules():
37 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
38 | if m.bias is not None:
39 | m.bias.data.zero_()
40 |
41 | def forward(self, x):
42 | dilate1_out = nonlinearity(self.dilate1(x))
43 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
44 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
45 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
46 | dilate5_out = nonlinearity(self.dilate5(dilate4_out))
47 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out
48 | return out
49 | class Dblock(nn.Module):
50 | def __init__(self, channel):
51 | super(Dblock, self).__init__()
52 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
53 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
54 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
55 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
56 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
59 | if m.bias is not None:
60 | m.bias.data.zero_()
61 |
62 | def forward(self, x):
63 | dilate1_out = nonlinearity(self.dilate1(x))
64 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
65 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
66 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
67 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out))
68 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out
69 | return out
70 |
71 | def conv3x3(in_planes, out_planes, stride=1):
72 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
73 |
74 | class SEBasicBlock(nn.Module):
75 | expansion = 1
76 |
77 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
78 | super(SEBasicBlock, self).__init__()
79 | self.conv1 = conv3x3(inplanes, planes, stride)
80 | self.bn1 = nn.BatchNorm2d(planes)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.conv2 = conv3x3(planes, planes, 1)
83 | self.bn2 = nn.BatchNorm2d(planes)
84 | self.se = SELayer(planes, reduction)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | residual = x
90 | out = self.conv1(x)
91 | out = self.bn1(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv2(out)
95 | out = self.bn2(out)
96 | out = self.se(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 | class DecoderBlock(nn.Module):
107 | def __init__(self, in_channels, n_filters):
108 | super(DecoderBlock,self).__init__()
109 |
110 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
111 | self.norm1 = nn.BatchNorm2d(in_channels // 4)
112 | self.relu1 = nonlinearity
113 | self.scse = SCSEBlock(in_channels // 4)
114 |
115 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
116 | self.norm2 = nn.BatchNorm2d(in_channels // 4)
117 | self.relu2 = nonlinearity
118 |
119 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
120 | self.norm3 = nn.BatchNorm2d(n_filters)
121 | self.relu3 = nonlinearity
122 |
123 | def forward(self, x):
124 | x = self.conv1(x)
125 | x = self.norm1(x)
126 | x = self.relu1(x)
127 | y = self.scse(x)
128 | x = x + y
129 | x = self.deconv2(x)
130 | x = self.norm2(x)
131 | x = self.relu2(x)
132 | x = self.conv3(x)
133 | x = self.norm3(x)
134 | x = self.relu3(x)
135 | return x
136 |
137 | class SCSEBlock(nn.Module):
138 | def __init__(self, channel, reduction=16):
139 | super(SCSEBlock, self).__init__()
140 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
141 |
146 | self.channel_excitation = nn.Sequential(nn.Conv2d(channel, int(channel//reduction), kernel_size=1,
147 | stride=1, padding=0, bias=False),
148 | nn.ReLU(inplace=True),
149 | nn.Conv2d(int(channel // reduction), channel,kernel_size=1,
150 | stride=1, padding=0, bias=False),
151 | nn.Sigmoid())
152 |
153 | self.spatial_se = nn.Sequential(nn.Conv2d(channel, 1, kernel_size=1,
154 | stride=1, padding=0, bias=False),
155 | nn.Sigmoid())
156 |
157 | def forward(self, x):
158 | bahs, chs, _, _ = x.size()
159 |
160 | # Returns a new tensor with the same data as the self tensor but of a different size.
161 | chn_se = self.avg_pool(x)
162 | chn_se = self.channel_excitation(chn_se)
163 | chn_se = torch.mul(x, chn_se)
164 | spa_se = self.spatial_se(x)
165 | spa_se = torch.mul(x, spa_se)
166 |         return torch.add(chn_se, spa_se, alpha=1)  # modern form of the deprecated torch.add(input, alpha, other)
167 |
168 | class CDNet_model(nn.Module):
169 | def __init__(self, in_channels=3, block=SEBasicBlock, layers=[3, 4, 6, 3], num_classes=2):
170 | super(CDNet_model, self).__init__()
171 |
172 | filters = [64, 128, 256, 512]
173 | self.inplanes = 64
174 | self.firstconv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
175 | bias=False)
176 | self.firstbn = nn.BatchNorm2d(64)
177 | self.firstrelu = nn.ReLU(inplace=True)
178 | self.firstmaxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
179 | self.encoder1 = self._make_layer(block, 64, layers[0])
180 | self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
181 | self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
182 | self.encoder4 = self._make_layer(block, 512, layers[3], stride=2)
183 |
184 | self.decoder4 = DecoderBlock(filters[3], filters[2])
185 | self.decoder3 = DecoderBlock(filters[2], filters[1])
186 | self.decoder2 = DecoderBlock(filters[1], filters[0])
187 | self.decoder1 = DecoderBlock(filters[0], filters[0])
188 |
189 | self.dblock_master = Dblock(512)
190 | self.dblock = Dblock(512)
191 |
192 | self.decoder4_master = DecoderBlock(filters[3], filters[2])
193 | self.decoder3_master = DecoderBlock(filters[2], filters[1])
194 | self.decoder2_master = DecoderBlock(filters[1], filters[0])
195 | self.decoder1_master = DecoderBlock(filters[0], filters[0])
196 |
197 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
198 | self.finalrelu1_master = nonlinearity
199 | self.finalconv2_master = nn.Conv2d(32, 32, 3, padding=1)
200 | self.finalrelu2_master = nonlinearity
201 | self.finalconv3_master = nn.Conv2d(32, num_classes, 3, padding=1)
202 |
203 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
204 | self.finalrelu1 = nonlinearity
205 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
206 | self.finalrelu2 = nonlinearity
207 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
208 |
209 | for m in self.modules():
210 | if isinstance(m, nn.Conv2d):
211 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
212 | m.weight.data.normal_(0, math.sqrt(2. / n))
213 | elif isinstance(m, nn.BatchNorm2d):
214 | m.weight.data.fill_(1)
215 | m.bias.data.zero_()
216 |
217 | def _make_layer(self, block, planes, blocks, stride=1):
218 | downsample = None
219 | if stride != 1 or self.inplanes != planes * block.expansion:
220 | downsample = nn.Sequential(
221 | nn.Conv2d(self.inplanes, planes * block.expansion,
222 | kernel_size=1, stride=stride, bias=False),
223 | nn.BatchNorm2d(planes * block.expansion),
224 | )
225 |
226 | layers = []
227 | layers.append(block(self.inplanes, planes, stride, downsample))
228 | self.inplanes = planes * block.expansion
229 | for i in range(1, blocks):
230 | layers.append(block(self.inplanes, planes))
231 |
232 | return nn.Sequential(*layers)
233 |
234 | def forward(self, x, y):
235 | # Encoder_1
236 | x = self.firstconv(x)
237 | x = self.firstbn(x)
238 | x = self.firstrelu(x)
239 | x = self.firstmaxpool(x)
240 |
241 | e1_x = self.encoder1(x)
242 | e2_x = self.encoder2(e1_x)
243 | e3_x = self.encoder3(e2_x)
244 | e4_x = self.encoder4(e3_x)
245 |
246 | # # Center_1
247 | # e4_x_center = self.dblock(e4_x)
248 |
249 | # # Decoder_1
250 | # d4_x = self.decoder4(e4_x_center) + e3_x
251 | # d3_x = self.decoder3(d4_x) + e2_x
252 | # d2_x = self.decoder2(d3_x) + e1_x
253 | # d1_x = self.decoder1(d2_x)
254 |
255 | # out1 = self.finaldeconv1(d1_x)
256 | # out1 = self.finalrelu1(out1)
257 | # out1 = self.finalconv2(out1)
258 | # out1 = self.finalrelu2(out1)
259 | # out1 = self.finalconv3(out1)
260 |
261 | # Encoder_2
262 | y = self.firstconv(y)
263 | y = self.firstbn(y)
264 | y = self.firstrelu(y)
265 | y = self.firstmaxpool(y)
266 |
267 | e1_y = self.encoder1(y)
268 | e2_y = self.encoder2(e1_y)
269 | e3_y = self.encoder3(e2_y)
270 | e4_y = self.encoder4(e3_y)
271 |
272 | # # Center_2
273 | # e4_y_center = self.dblock(e4_y)
274 |
275 | # # Decoder_2
276 | # d4_y = self.decoder4(e4_y_center) + e3_y
277 | # d3_y = self.decoder3(d4_y) + e2_y
278 | # d2_y = self.decoder2(d3_y) + e1_y
279 | # d1_y = self.decoder1(d2_y)
280 | # out2 = self.finaldeconv1(d1_y)
281 | # out2 = self.finalrelu1(out2)
282 | # out2 = self.finalconv2(out2)
283 | # out2 = self.finalrelu2(out2)
284 | # out2 = self.finalconv3(out2)
285 |
286 | # center_master
287 | e4 = self.dblock_master(e4_x - e4_y)
288 | # decoder_master
289 | d4 = self.decoder4_master(e4) + e3_x - e3_y
290 | d3 = self.decoder3_master(d4) + e2_x - e2_y
291 | d2 = self.decoder2_master(d3) + e1_x - e1_y
292 | d1 = self.decoder1_master(d2)
293 |
294 | out = self.finaldeconv1_master(d1)
295 | out = self.finalrelu1_master(out)
296 | out = self.finalconv2_master(out)
297 | out = self.finalrelu2_master(out)
298 | out = self.finalconv3_master(out)
299 |
300 | output = []
301 | output.append(out)
302 |
303 | return output
304 |
305 |
306 |
307 | def CDNet34(in_channels, **kwargs):
308 |
309 | model = CDNet_model(in_channels, SEBasicBlock, [3, 4, 6, 3], **kwargs)
310 |
311 | return model
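
A dummy bi-temporal forward pass through CDNet34 (the 256x256 input size matches the prepared crops; batch size and values are arbitrary):

```python
if __name__ == '__main__':
    model = CDNet34(in_channels=3, num_classes=2)
    t1 = torch.randn(1, 3, 256, 256)      # image at time 1
    t2 = torch.randn(1, 3, 256, 256)      # image at time 2
    out = model(t1, t2)                   # list holding one N*2*256*256 change map
    print(out[0].shape)
```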
--------------------------------------------------------------------------------
/STADE-CDNet/models/SiamUnet_conc.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class SiamUnet_conc(nn.Module):
11 | """SiamUnet_conc segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(SiamUnet_conc, self).__init__()
15 |
16 | self.input_nbr = input_nbr
17 |
18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 | self.do11 = nn.Dropout2d(p=0.2)
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.2)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.2)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.2)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.2)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.2)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.2)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.2)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.2)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.2)
51 |
52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
53 |
54 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1)
55 | self.bn43d = nn.BatchNorm2d(128)
56 | self.do43d = nn.Dropout2d(p=0.2)
57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
58 | self.bn42d = nn.BatchNorm2d(128)
59 | self.do42d = nn.Dropout2d(p=0.2)
60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
61 | self.bn41d = nn.BatchNorm2d(64)
62 | self.do41d = nn.Dropout2d(p=0.2)
63 |
64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
65 |
66 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1)
67 | self.bn33d = nn.BatchNorm2d(64)
68 | self.do33d = nn.Dropout2d(p=0.2)
69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
70 | self.bn32d = nn.BatchNorm2d(64)
71 | self.do32d = nn.Dropout2d(p=0.2)
72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(32)
74 | self.do31d = nn.Dropout2d(p=0.2)
75 |
76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
77 |
78 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1)
79 | self.bn22d = nn.BatchNorm2d(32)
80 | self.do22d = nn.Dropout2d(p=0.2)
81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
82 | self.bn21d = nn.BatchNorm2d(16)
83 | self.do21d = nn.Dropout2d(p=0.2)
84 |
85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
86 |
87 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1)
88 | self.bn12d = nn.BatchNorm2d(16)
89 | self.do12d = nn.Dropout2d(p=0.2)
90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
91 |
92 | self.sm = nn.LogSoftmax(dim=1)
93 |
94 | def forward(self, x1, x2):
95 |
96 | """Forward method."""
97 | # Stage 1
98 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
99 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
100 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
101 |
102 |
103 | # Stage 2
104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
107 |
108 | # Stage 3
109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
113 |
114 | # Stage 4
115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
119 |
120 |
121 | ####################################################
122 | # Stage 1
123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
126 |
127 | # Stage 2
128 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
129 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
130 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
131 |
132 | # Stage 3
133 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
134 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
135 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
136 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
137 |
138 | # Stage 4
139 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
140 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
141 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
142 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
143 |
144 |
145 | ####################################################
146 | # Stage 4d
147 | x4d = self.upconv4(x4p)
148 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
149 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
150 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
151 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
152 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
153 |
154 | # Stage 3d
155 | x3d = self.upconv3(x41d)
156 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
157 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
158 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
159 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
160 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
161 |
162 | # Stage 2d
163 | x2d = self.upconv2(x31d)
164 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
165 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
166 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
167 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
168 |
169 | # Stage 1d
170 | x1d = self.upconv1(x21d)
171 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
172 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
173 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
174 | x11d = self.conv11d(x12d)
175 |
176 | #Softmax layer is embedded in the loss layer
177 | #out = self.sm(x11d)
178 | output = []
179 | output.append(x11d)
180 |
181 | return output
182 |
183 |
184 |
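
A dummy forward pass through the Siamese-concatenation network (64x64 inputs chosen only to keep the example light):

```python
if __name__ == '__main__':
    net = SiamUnet_conc(input_nbr=3, label_nbr=2)
    t1, t2 = torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64)
    out = net(t1, t2)                     # skip connections concatenate features from both dates
    print(out[0].shape)                   # torch.Size([1, 2, 64, 64])
```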
--------------------------------------------------------------------------------
/STADE-CDNet/models/SiamUnet_diff.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class SiamUnet_diff(nn.Module):
11 | """SiamUnet_diff segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(SiamUnet_diff, self).__init__()
15 |
16 | self.input_nbr = input_nbr
17 |
18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 | self.do11 = nn.Dropout2d(p=0.2)
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.2)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.2)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.2)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.2)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.2)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.2)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.2)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.2)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.2)
51 |
52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
53 |
54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
55 | self.bn43d = nn.BatchNorm2d(128)
56 | self.do43d = nn.Dropout2d(p=0.2)
57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
58 | self.bn42d = nn.BatchNorm2d(128)
59 | self.do42d = nn.Dropout2d(p=0.2)
60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
61 | self.bn41d = nn.BatchNorm2d(64)
62 | self.do41d = nn.Dropout2d(p=0.2)
63 |
64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
65 |
66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
67 | self.bn33d = nn.BatchNorm2d(64)
68 | self.do33d = nn.Dropout2d(p=0.2)
69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
70 | self.bn32d = nn.BatchNorm2d(64)
71 | self.do32d = nn.Dropout2d(p=0.2)
72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(32)
74 | self.do31d = nn.Dropout2d(p=0.2)
75 |
76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
77 |
78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
79 | self.bn22d = nn.BatchNorm2d(32)
80 | self.do22d = nn.Dropout2d(p=0.2)
81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
82 | self.bn21d = nn.BatchNorm2d(16)
83 | self.do21d = nn.Dropout2d(p=0.2)
84 |
85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
86 |
87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
88 | self.bn12d = nn.BatchNorm2d(16)
89 | self.do12d = nn.Dropout2d(p=0.2)
90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
91 |
92 | self.sm = nn.LogSoftmax(dim=1)
93 |
94 | def forward(self, x1, x2):
95 |
96 |
97 | """Forward method."""
98 | # Stage 1
99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
102 |
103 |
104 | # Stage 2
105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
108 |
109 | # Stage 3
110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
114 |
115 | # Stage 4
116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
120 |
121 | ####################################################
122 | # Stage 1
123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
126 |
127 |
128 | # Stage 2
129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
132 |
133 | # Stage 3
134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
138 |
139 | # Stage 4
140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
144 |
145 |
146 |
147 | # Stage 4d
148 | x4d = self.upconv4(x4p)
149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
150 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1)
151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
154 |
155 | # Stage 3d
156 | x3d = self.upconv3(x41d)
157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
158 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1)
159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
162 |
163 | # Stage 2d
164 | x2d = self.upconv2(x31d)
165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
166 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1)
167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
169 |
170 | # Stage 1d
171 | x1d = self.upconv1(x21d)
172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
173 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1)
174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
175 | x11d = self.conv11d(x12d)
176 | #out = self.sm(x11d)
177 |
178 | output = []
179 | output.append(x11d)
180 |
181 | return output
182 |
--------------------------------------------------------------------------------
/STADE-CDNet/models/Unet.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class Unet(nn.Module):
11 | """EF segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(Unet, self).__init__()
15 |
16 | self.conv11 = nn.Conv2d(2*input_nbr, 16, kernel_size=3, padding=1)
17 | self.bn11 = nn.BatchNorm2d(16)
18 | self.do11 = nn.Dropout2d(p=0.2)
19 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
20 | self.bn12 = nn.BatchNorm2d(16)
21 | self.do12 = nn.Dropout2d(p=0.2)
22 |
23 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
24 | self.bn21 = nn.BatchNorm2d(32)
25 | self.do21 = nn.Dropout2d(p=0.2)
26 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
27 | self.bn22 = nn.BatchNorm2d(32)
28 | self.do22 = nn.Dropout2d(p=0.2)
29 |
30 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
31 | self.bn31 = nn.BatchNorm2d(64)
32 | self.do31 = nn.Dropout2d(p=0.2)
33 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
34 | self.bn32 = nn.BatchNorm2d(64)
35 | self.do32 = nn.Dropout2d(p=0.2)
36 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
37 | self.bn33 = nn.BatchNorm2d(64)
38 | self.do33 = nn.Dropout2d(p=0.2)
39 |
40 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
41 | self.bn41 = nn.BatchNorm2d(128)
42 | self.do41 = nn.Dropout2d(p=0.2)
43 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
44 | self.bn42 = nn.BatchNorm2d(128)
45 | self.do42 = nn.Dropout2d(p=0.2)
46 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
47 | self.bn43 = nn.BatchNorm2d(128)
48 | self.do43 = nn.Dropout2d(p=0.2)
49 |
50 |
51 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
52 |
53 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
54 | self.bn43d = nn.BatchNorm2d(128)
55 | self.do43d = nn.Dropout2d(p=0.2)
56 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
57 | self.bn42d = nn.BatchNorm2d(128)
58 | self.do42d = nn.Dropout2d(p=0.2)
59 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
60 | self.bn41d = nn.BatchNorm2d(64)
61 | self.do41d = nn.Dropout2d(p=0.2)
62 |
63 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
64 |
65 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
66 | self.bn33d = nn.BatchNorm2d(64)
67 | self.do33d = nn.Dropout2d(p=0.2)
68 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
69 | self.bn32d = nn.BatchNorm2d(64)
70 | self.do32d = nn.Dropout2d(p=0.2)
71 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
72 | self.bn31d = nn.BatchNorm2d(32)
73 | self.do31d = nn.Dropout2d(p=0.2)
74 |
75 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
76 |
77 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
78 | self.bn22d = nn.BatchNorm2d(32)
79 | self.do22d = nn.Dropout2d(p=0.2)
80 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
81 | self.bn21d = nn.BatchNorm2d(16)
82 | self.do21d = nn.Dropout2d(p=0.2)
83 |
84 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
85 |
86 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
87 | self.bn12d = nn.BatchNorm2d(16)
88 | self.do12d = nn.Dropout2d(p=0.2)
89 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
90 |
91 | self.sm = nn.LogSoftmax(dim=1)
92 |
93 |     def forward(self, x1, x2):
94 |         """Forward method."""
95 | 
96 |         x = torch.cat((x1, x2), 1)
97 | 
98 | # Stage 1
99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
100 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
101 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2)
102 |
103 | # Stage 2
104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
105 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
106 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2)
107 |
108 | # Stage 3
109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
111 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
112 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2)
113 |
114 | # Stage 4
115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
117 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
118 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2)
119 |
120 |
121 | # Stage 4d
122 | x4d = self.upconv4(x4p)
123 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2)))
124 | x4d = torch.cat((pad4(x4d), x43), 1)
125 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
126 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
127 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
128 |
129 | # Stage 3d
130 | x3d = self.upconv3(x41d)
131 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2)))
132 | x3d = torch.cat((pad3(x3d), x33), 1)
133 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
134 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
135 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
136 |
137 | # Stage 2d
138 | x2d = self.upconv2(x31d)
139 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2)))
140 | x2d = torch.cat((pad2(x2d), x22), 1)
141 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
142 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
143 |
144 | # Stage 1d
145 | x1d = self.upconv1(x21d)
146 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2)))
147 | x1d = torch.cat((pad1(x1d), x12), 1)
148 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
149 | x11d = self.conv11d(x12d)
150 |
151 | output = []
152 | output.append(x11d)
153 |
154 | return output
155 |
156 |
157 |
--------------------------------------------------------------------------------
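A minimal sketch of the early-fusion Unet above; since the two images are concatenated internally, only the pair of inputs is needed. Shapes assume 256x256 inputs (any size divisible by 16 works, given the four pooling stages):

    import torch
    from models.Unet import Unet

    net = Unet(input_nbr=3, label_nbr=2)   # conv11 expects 2*input_nbr channels after concatenation
    logits = net(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))[-1]
    print(logits.shape)                    # torch.Size([1, 2, 256, 256])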
/STADE-CDNet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 |
--------------------------------------------------------------------------------
/STADE-CDNet/models/basic_model.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | from misc.imutils import save_image
6 | from models.networks import *
7 |
8 |
9 | class CDEvaluator():
10 |
11 | def __init__(self, args):
12 |
13 | self.n_class = args.n_class
14 | # define G
15 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids)
16 |
17 | self.device = torch.device("cuda:%s" % args.gpu_ids[0]
18 | if torch.cuda.is_available() and len(args.gpu_ids)>0
19 | else "cpu")
20 |
21 | print(self.device)
22 |
23 |         self.checkpoint_dir = args.checkpoint_dir  # directory expected to hold best_ckpt.pt
24 |
25 | self.pred_dir = args.output_folder
26 | os.makedirs(self.pred_dir, exist_ok=True)
27 |
28 | def load_checkpoint(self, checkpoint_name='best_ckpt.pt'):
29 |
30 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)):
31 | # load the entire checkpoint
32 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name),
33 | map_location=self.device)
34 |
35 | self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
36 | self.net_G.to(self.device)
37 | # update some other states
38 | self.best_val_acc = checkpoint['best_val_acc']
39 | self.best_epoch_id = checkpoint['best_epoch_id']
40 |
41 | else:
42 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name)
43 | return self.net_G
44 |
45 |
46 | def _visualize_pred(self):
47 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True)
48 | pred_vis = pred * 255
49 | return pred_vis
50 |
51 | def _forward_pass(self, batch):
52 | self.batch = batch
53 | img_in1 = batch['A'].to(self.device)
54 | img_in2 = batch['B'].to(self.device)
55 | self.shape_h = img_in1.shape[-2]
56 | self.shape_w = img_in1.shape[-1]
57 | self.G_pred = self.net_G(img_in1, img_in2)[-1]
58 | return self._visualize_pred()
59 |
60 | def eval(self):
61 | self.net_G.eval()
62 |
63 | def _save_predictions(self):
64 | """
65 | 保存模型输出结果,二分类图像
66 | """
67 |
68 | preds = self._visualize_pred()
69 | name = self.batch['name']
70 | for i, pred in enumerate(preds):
71 | file_name = os.path.join(
72 | self.pred_dir, name[i].replace('.jpg', '.png'))
73 | pred = pred[0].cpu().numpy()
74 | save_image(pred, file_name)
75 |
76 |
--------------------------------------------------------------------------------
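A hedged inference sketch for the lightweight CDEvaluator above. The Namespace is a hypothetical stand-in for the real argparse options; only the fields this class actually reads are filled in, and a trained best_ckpt.pt is assumed to exist under checkpoint_dir:

    import torch
    from argparse import Namespace
    from models.basic_model import CDEvaluator

    args = Namespace(n_class=2, gpu_ids=[], output_folder='./predictions',
                     checkpoint_dir='./checkpoints',
                     net_G='SiamUnet_diff')           # any name define_G recognises
    ev = CDEvaluator(args)
    ev.load_checkpoint('best_ckpt.pt')                # looked up under ev.checkpoint_dir
    ev.eval()
    batch = {'A': torch.randn(1, 3, 256, 256),
             'B': torch.randn(1, 3, 256, 256),
             'name': ['scene_0.jpg']}                 # '.jpg' is rewritten to '.png' on save
    with torch.no_grad():
        ev._forward_pass(batch)                       # stores G_pred for _save_predictions()
    ev._save_predictions()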
/STADE-CDNet/models/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | from models.networks import *
6 | from misc.metric_tool import ConfuseMatrixMeter
7 | from misc.logger_tool import Logger
8 | from utils import de_norm
9 | import utils
10 |
11 |
12 | # Decide which device we want to run on
13 | # torch.cuda.current_device()
14 |
15 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | class CDEvaluator():
19 |
20 | def __init__(self, args, dataloader):
21 |
22 | self.dataloader = dataloader
23 |
24 | self.n_class = args.n_class
25 | # define G
26 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids)
27 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0
28 | else "cpu")
29 | print(self.device)
30 |
31 | # define some other vars to record the training states
32 | self.running_metric = ConfuseMatrixMeter(n_class=self.n_class)
33 |
34 | # define logger file
35 | logger_path = os.path.join(args.checkpoint_dir, 'log_test.txt')
36 | self.logger = Logger(logger_path)
37 | self.logger.write_dict_str(args.__dict__)
38 |
39 |
40 | # training log
41 | self.epoch_acc = 0
42 | self.best_val_acc = 0.0
43 | self.best_epoch_id = 0
44 |
45 | self.steps_per_epoch = len(dataloader)
46 |
47 | self.G_pred = None
48 | self.pred_vis = None
49 | self.batch = None
50 | self.is_training = False
51 | self.batch_id = 0
52 | self.epoch_id = 0
53 | self.checkpoint_dir = args.checkpoint_dir
54 | self.vis_dir = args.vis_dir
55 |
56 | # check and create model dir
57 | if os.path.exists(self.checkpoint_dir) is False:
58 | os.mkdir(self.checkpoint_dir)
59 | if os.path.exists(self.vis_dir) is False:
60 | os.mkdir(self.vis_dir)
61 |
62 |
63 | def _load_checkpoint(self, checkpoint_name='best_ckpt.pt'):
64 |
65 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)):
66 | self.logger.write('loading last checkpoint...\n')
67 | # load the entire checkpoint
68 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), map_location=self.device)
69 |
70 | self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
71 |
72 | self.net_G.to(self.device)
73 |
74 | # update some other states
75 | self.best_val_acc = checkpoint['best_val_acc']
76 | self.best_epoch_id = checkpoint['best_epoch_id']
77 |
78 | self.logger.write('Eval Historical_best_acc = %.4f (at epoch %d)\n' %
79 | (self.best_val_acc, self.best_epoch_id))
80 | self.logger.write('\n')
81 |
82 | else:
83 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name)
84 |
85 |
86 | def _visualize_pred(self):
87 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True)
88 | pred_vis = pred * 255
89 | return pred_vis
90 |
91 |
92 | def _update_metric(self):
93 | """
94 | update metric
95 | """
96 | target = self.batch['L'].to(self.device).detach()
97 | G_pred = self.G_pred.detach()
98 | G_pred = torch.argmax(G_pred, dim=1)
99 |
100 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy())
101 | return current_score
102 |
103 | def _collect_running_batch_states(self):
104 |
105 | running_acc = self._update_metric()
106 |
107 | m = len(self.dataloader)
108 |
109 |         if np.mod(self.batch_id, 1) == 0:  # log every batch; raise the modulus to log less often
110 | message = 'Is_training: %s. [%d,%d], running_mf1: %.5f\n' %\
111 | (self.is_training, self.batch_id, m, running_acc)
112 | self.logger.write(message)
113 |
114 |         if np.mod(self.batch_id, 1) == 0:  # likewise, dump a visualization for every batch
115 | vis_input = utils.make_numpy_grid(de_norm(self.batch['A']))
116 | vis_input2 = utils.make_numpy_grid(de_norm(self.batch['B']))
117 |
118 | vis_pred = utils.make_numpy_grid(self._visualize_pred())
119 |
120 | vis_gt = utils.make_numpy_grid(self.batch['L'])
121 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0)
122 | vis = np.clip(vis, a_min=0.0, a_max=1.0)
123 | file_name = os.path.join(
124 | self.vis_dir, 'eval_' + str(self.batch_id)+'.jpg')
125 | plt.imsave(file_name, vis)
126 |
127 |
128 | def _collect_epoch_states(self):
129 |
130 | scores_dict = self.running_metric.get_scores()
131 |
132 | np.save(os.path.join(self.checkpoint_dir, 'scores_dict.npy'), scores_dict)
133 |
134 | self.epoch_acc = scores_dict['mf1']
135 |
136 |         with open(os.path.join(self.checkpoint_dir, '%s.txt' % (self.epoch_acc)),
137 |                   mode='a') as file:
138 |             pass  # touch an empty marker file named after the epoch mF1
139 |
140 | message = ''
141 | for k, v in scores_dict.items():
142 | message += '%s: %.5f ' % (k, v)
143 | self.logger.write('%s\n' % message) # save the message
144 |
145 | self.logger.write('\n')
146 |
147 | def _clear_cache(self):
148 | self.running_metric.clear()
149 |
150 | def _forward_pass(self, batch):
151 | self.batch = batch
152 | img_in1 = batch['A'].to(self.device)
153 | img_in2 = batch['B'].to(self.device)
154 | self.G_pred = self.net_G(img_in1, img_in2)[-1]
155 |
156 | def eval_models(self,checkpoint_name='best_ckpt.pt'):
157 |
158 | self._load_checkpoint(checkpoint_name)
159 |
160 | ################## Eval ##################
161 | ##########################################
162 | self.logger.write('Begin evaluation...\n')
163 | self._clear_cache()
164 | self.is_training = False
165 | self.net_G.eval()
166 |
167 | # Iterate over data.
168 | for self.batch_id, batch in enumerate(self.dataloader, 0):
169 | with torch.no_grad():
170 | self._forward_pass(batch)
171 | self._collect_running_batch_states()
172 | self._collect_epoch_states()
173 |
--------------------------------------------------------------------------------
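The evaluator's metric bookkeeping, exercised in isolation. Only the three calls this file makes on ConfuseMatrixMeter (update_cm, get_scores, clear) are used, so the sketch stays within the API that is visible here:

    import numpy as np
    from misc.metric_tool import ConfuseMatrixMeter

    meter = ConfuseMatrixMeter(n_class=2)
    pred = np.random.randint(0, 2, size=(4, 256, 256))   # argmax of G_pred, as in _update_metric
    gt = np.random.randint(0, 2, size=(4, 256, 256))     # the 'L' label batch
    running_f1 = meter.update_cm(pr=pred, gt=gt)         # returns the running score used for logging
    print(meter.get_scores()['mf1'])                     # 'mf1' is the key read by _collect_epoch_states
    meter.clear()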
/STADE-CDNet/models/help_funcs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 | from torch import nn
5 |
6 |
7 | class TwoLayerConv2d(nn.Sequential):
8 | def __init__(self, in_channels, out_channels, kernel_size=3):
9 | super().__init__(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
10 | padding=kernel_size // 2, stride=1, bias=False),
11 | nn.BatchNorm2d(in_channels),
12 | nn.ReLU(),
13 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
14 | padding=kernel_size // 2, stride=1)
15 | )
16 |
17 |
18 | class Residual(nn.Module):
19 | def __init__(self, fn):
20 | super().__init__()
21 | self.fn = fn
22 | def forward(self, x, **kwargs):
23 | return self.fn(x, **kwargs) + x
24 |
25 |
26 | class Residual2(nn.Module):
27 | def __init__(self, fn):
28 | super().__init__()
29 | self.fn = fn
30 | def forward(self, x, x2, **kwargs):
31 | return self.fn(x, x2, **kwargs) + x
32 |
33 |
34 | class PreNorm(nn.Module):
35 | def __init__(self, dim, fn):
36 | super().__init__()
37 | self.norm = nn.LayerNorm(dim)
38 | self.fn = fn
39 | def forward(self, x, **kwargs):
40 | return self.fn(self.norm(x), **kwargs)
41 |
42 |
43 | class PreNorm2(nn.Module):
44 | def __init__(self, dim, fn):
45 | super().__init__()
46 | self.norm = nn.LayerNorm(dim)
47 | self.fn = fn
48 | def forward(self, x, x2, **kwargs):
49 | return self.fn(self.norm(x), self.norm(x2), **kwargs)
50 |
51 |
52 | class FeedForward(nn.Module):
53 | def __init__(self, dim, hidden_dim, dropout = 0.):
54 | super().__init__()
55 | self.net = nn.Sequential(
56 | nn.Linear(dim, hidden_dim),
57 | nn.GELU(),
58 | nn.Dropout(dropout),
59 | nn.Linear(hidden_dim, dim),
60 | nn.Dropout(dropout)
61 | )
62 | def forward(self, x):
63 | return self.net(x)
64 |
65 |
66 | class Cross_Attention(nn.Module):
67 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True):
68 | super().__init__()
69 | inner_dim = dim_head * heads
70 | self.heads = heads
71 | self.scale = dim ** -0.5
72 |
73 | self.softmax = softmax
74 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
75 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
76 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
77 |
78 | self.to_out = nn.Sequential(
79 | nn.Linear(inner_dim, dim),
80 | nn.Dropout(dropout)
81 | )
82 |
83 | def forward(self, x, m, mask = None):
84 |
85 | b, n, _, h = *x.shape, self.heads
86 | q = self.to_q(x)
87 | k = self.to_k(m)
88 | v = self.to_v(m)
89 |
90 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v])
91 |
92 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
93 | mask_value = -torch.finfo(dots.dtype).max
94 |
95 | if mask is not None:
96 | mask = F.pad(mask.flatten(1), (1, 0), value = True)
97 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
98 | mask = mask[:, None, :] * mask[:, :, None]
99 | dots.masked_fill_(~mask, mask_value)
100 | del mask
101 |
102 | if self.softmax:
103 | attn = dots.softmax(dim=-1)
104 | else:
105 | attn = dots
106 | # attn = dots
107 | # vis_tmp(dots)
108 |
109 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
110 | out = rearrange(out, 'b h n d -> b n (h d)')
111 | out = self.to_out(out)
112 | # vis_tmp2(out)
113 |
114 | return out
115 |
116 |
117 | class Attention(nn.Module):
118 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
119 | super().__init__()
120 | inner_dim = dim_head * heads
121 | self.heads = heads
122 | self.scale = dim ** -0.5
123 |
124 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
125 | self.to_out = nn.Sequential(
126 | nn.Linear(inner_dim, dim),
127 | nn.Dropout(dropout)
128 | )
129 |
130 | def forward(self, x, mask = None):
131 | b, n, _, h = *x.shape, self.heads
132 | qkv = self.to_qkv(x).chunk(3, dim = -1)
133 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
134 |
135 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
136 | mask_value = -torch.finfo(dots.dtype).max
137 |
138 | if mask is not None:
139 | mask = F.pad(mask.flatten(1), (1, 0), value = True)
140 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
141 | mask = mask[:, None, :] * mask[:, :, None]
142 | dots.masked_fill_(~mask, mask_value)
143 | del mask
144 |
145 | attn = dots.softmax(dim=-1)
146 |
147 |
148 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
149 | out = rearrange(out, 'b h n d -> b n (h d)')
150 | out = self.to_out(out)
151 | return out
152 |
153 |
154 | class Transformer(nn.Module):
155 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
156 | super().__init__()
157 | self.layers = nn.ModuleList([])
158 | for _ in range(depth):
159 | self.layers.append(nn.ModuleList([
160 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
161 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
162 | ]))
163 | def forward(self, x, mask = None):
164 | for attn, ff in self.layers:
165 | x = attn(x, mask = mask)
166 | x = ff(x)
167 | return x
168 |
169 |
170 | class TransformerDecoder(nn.Module):
171 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True):
172 | super().__init__()
173 | self.layers = nn.ModuleList([])
174 | for _ in range(depth):
175 | self.layers.append(nn.ModuleList([
176 | Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads,
177 | dim_head = dim_head, dropout = dropout,
178 | softmax=softmax))),
179 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
180 | ]))
181 | def forward(self, x, m, mask = None):
182 | """target(query), memory"""
183 | for attn, ff in self.layers:
184 | x = attn(x, m, mask = mask)
185 | x = ff(x)
186 | return x
187 |
188 | from scipy.io import savemat
189 | def save_to_mat(x1, x2, fx1, fx2, cp, file_name):
190 | #Save to mat files
191 | x1_np = x1.detach().cpu().numpy()
192 | x2_np = x2.detach().cpu().numpy()
193 |
194 | fx1_0_np = fx1[0].detach().cpu().numpy()
195 | fx2_0_np = fx2[0].detach().cpu().numpy()
196 | fx1_1_np = fx1[1].detach().cpu().numpy()
197 | fx2_1_np = fx2[1].detach().cpu().numpy()
198 | fx1_2_np = fx1[2].detach().cpu().numpy()
199 | fx2_2_np = fx2[2].detach().cpu().numpy()
200 | fx1_3_np = fx1[3].detach().cpu().numpy()
201 | fx2_3_np = fx2[3].detach().cpu().numpy()
202 | fx1_4_np = fx1[4].detach().cpu().numpy()
203 | fx2_4_np = fx2[4].detach().cpu().numpy()
204 |
205 | cp_np = cp[-1].detach().cpu().numpy()
206 |
207 | mdic = {'x1': x1_np, 'x2': x2_np,
208 | 'fx1_0': fx1_0_np, 'fx1_1': fx1_1_np, 'fx1_2': fx1_2_np, 'fx1_3': fx1_3_np, 'fx1_4': fx1_4_np,
209 | 'fx2_0': fx2_0_np, 'fx2_1': fx2_1_np, 'fx2_2': fx2_2_np, 'fx2_3': fx2_3_np, 'fx2_4': fx2_4_np,
210 | "final_pred": cp_np}
211 |
212 |     savemat("/media/lidan/ssd2/ChangeFormer/vis/mat/"+file_name+".mat", mdic)  # hard-coded debug path from the original authors' machine
213 |
214 |
215 |
216 |
--------------------------------------------------------------------------------
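A shape walkthrough for the encoder/decoder pair above, using the same hyper-parameters BASE_Transformer passes in (dim=32, 8 heads); the token and query sizes are illustrative:

    import torch
    from models.help_funcs import Transformer, TransformerDecoder

    enc = Transformer(dim=32, depth=1, heads=8, dim_head=64, mlp_dim=64, dropout=0.)
    dec = TransformerDecoder(dim=32, depth=8, heads=8, dim_head=8, mlp_dim=64,
                             dropout=0., softmax=True)

    tokens = torch.randn(2, 8, 32)       # (batch, 2*token_len, dim), as concatenated in networks.py
    tokens = enc(tokens)                 # self-attention over the joint token set
    queries = torch.randn(2, 4096, 32)   # a flattened 64x64 feature map
    out = dec(queries, tokens[:, :4])    # cross-attend pixels (queries) to one image's tokens (memory)
    print(out.shape)                     # torch.Size([2, 4096, 32])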
/STADE-CDNet/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import torch.nn as nn
5 |
6 | def cross_entropy(input, target, weight=None, reduction='mean',ignore_index=255):
7 | """
8 | logSoftmax_with_loss
9 | :param input: torch.Tensor, N*C*H*W
10 | :param target: torch.Tensor, N*1*H*W,/ N*H*W
11 | :param weight: torch.Tensor, C
12 | :return: torch.Tensor [0]
13 | """
14 | target = target.long()
15 | if target.dim() == 4:
16 | target = torch.squeeze(target, dim=1)
17 | if input.shape[-1] != target.shape[-1]:
18 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True)
19 |
20 | return F.cross_entropy(input=input, target=target, weight=weight,
21 | ignore_index=ignore_index, reduction=reduction)
22 |
23 | #Focal Loss
24 | def get_alpha(supervised_loader):
25 | # get number of classes
26 | num_labels = 0
27 | for batch in supervised_loader:
28 | label_batch = batch['L']
29 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
30 | l_unique = torch.unique(label_batch.data)
31 | list_unique = [element.item() for element in l_unique.flatten()]
32 | num_labels = max(max(list_unique),num_labels)
33 | num_classes = num_labels + 1
34 | # count class occurrences
35 | alpha = [0 for i in range(num_classes)]
36 | for batch in supervised_loader:
37 | label_batch = batch['L']
38 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
39 | l_unique = torch.unique(label_batch.data)
40 | list_unique = [element.item() for element in l_unique.flatten()]
41 | l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480])
42 | list_count = [count.item() for count in l_unique_count.flatten()]
43 | for index in list_unique:
44 | alpha[index] += list_count[list_unique.index(index)]
45 | return alpha
46 |
47 | # for FocalLoss
48 | def softmax_helper(x):
49 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
50 | rpt = [1 for _ in range(len(x.size()))]
51 | rpt[1] = x.size(1)
52 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
53 | e_x = torch.exp(x - x_max)
54 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
55 |
56 | class FocalLoss(nn.Module):
57 | """
58 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
59 |     This is an implementation of Focal Loss with smooth-label cross entropy support, as proposed in
60 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
61 | Focal_Loss= -1*alpha*(1-pt)*log(pt)
62 | :param num_class:
63 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
64 |     :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5), putting more
65 |                    focus on hard, misclassified examples
66 | :param smooth: (float,double) smooth value when cross entropy
67 | :param balance_index: (int) balance class index, should be specific when alpha is float
68 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
69 | """
70 |
71 | def __init__(self, apply_nonlin=None, alpha=None, gamma=1, balance_index=0, smooth=1e-5, size_average=True):
72 | super(FocalLoss, self).__init__()
73 | self.apply_nonlin = apply_nonlin
74 | self.alpha = alpha
75 | self.gamma = gamma
76 | self.balance_index = balance_index
77 | self.smooth = smooth
78 | self.size_average = size_average
79 |
80 | if self.smooth is not None:
81 | if self.smooth < 0 or self.smooth > 1.0:
82 | raise ValueError('smooth value should be in [0,1]')
83 |
84 | def forward(self, logit, target):
85 | if self.apply_nonlin is not None:
86 | logit = self.apply_nonlin(logit)
87 | num_class = logit.shape[1]
88 |
89 | if logit.dim() > 2:
90 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
91 | logit = logit.view(logit.size(0), logit.size(1), -1)
92 | logit = logit.permute(0, 2, 1).contiguous()
93 | logit = logit.view(-1, logit.size(-1))
94 | target = torch.squeeze(target, 1)
95 | target = target.view(-1, 1)
96 |
97 | alpha = self.alpha
98 |
99 | if alpha is None:
100 | alpha = torch.ones(num_class, 1)
101 | elif isinstance(alpha, (list, np.ndarray)):
102 | assert len(alpha) == num_class
103 | alpha = torch.FloatTensor(alpha).view(num_class, 1)
104 | alpha = alpha / alpha.sum()
105 | alpha = 1/alpha # inverse of class frequency
106 | elif isinstance(alpha, float):
107 | alpha = torch.ones(num_class, 1)
108 | alpha = alpha * (1 - self.alpha)
109 | alpha[self.balance_index] = self.alpha
110 |
111 | else:
112 | raise TypeError('Not support alpha type')
113 |
114 | if alpha.device != logit.device:
115 | alpha = alpha.to(logit.device)
116 |
117 | idx = target.cpu().long()
118 |
119 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
120 |
121 |         # map the ignore label (255) to background to avoid an out-of-range index in scatter_
122 |         idx[idx==255]=0
123 |
124 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
125 | if one_hot_key.device != logit.device:
126 | one_hot_key = one_hot_key.to(logit.device)
127 |
128 | if self.smooth:
129 | one_hot_key = torch.clamp(
130 | one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
131 | pt = (one_hot_key * logit).sum(1) + self.smooth
132 | logpt = pt.log()
133 |
134 | gamma = self.gamma
135 |
136 | alpha = alpha[idx]
137 | alpha = torch.squeeze(alpha)
138 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
139 |
140 | if self.size_average:
141 | loss = loss.mean()
142 | else:
143 | loss = loss.sum()
144 | return loss
145 |
146 |
147 | #miou loss
148 | from torch.autograd import Variable
149 | def to_one_hot_var(tensor, nClasses, requires_grad=False):
150 |
151 | n, h, w = torch.squeeze(tensor, dim=1).size()
152 | one_hot = tensor.new(n, nClasses, h, w).fill_(0)
153 | one_hot = one_hot.scatter_(1, tensor.type(torch.int64).view(n, 1, h, w), 1)
154 | return Variable(one_hot, requires_grad=requires_grad)
155 |
156 | class mIoULoss(nn.Module):
157 | def __init__(self, weight=None, size_average=True, n_classes=2):
158 | super(mIoULoss, self).__init__()
159 | self.classes = n_classes
160 | self.weights = Variable(weight)
161 |
162 | def forward(self, inputs, target, is_target_variable=False):
163 | # inputs => N x Classes x H x W
164 | # target => N x H x W
165 | # target_oneHot => N x Classes x H x W
166 |
167 | N = inputs.size()[0]
168 | if is_target_variable:
169 | target_oneHot = to_one_hot_var(target.data, self.classes).float()
170 | else:
171 | target_oneHot = to_one_hot_var(target, self.classes).float()
172 |
173 | # predicted probabilities for each pixel along channel
174 | inputs = F.softmax(inputs, dim=1)
175 |
176 | # Numerator Product
177 | inter = inputs * target_oneHot
178 | ## Sum over all pixels N x C x H x W => N x C
179 | inter = inter.view(N, self.classes, -1).sum(2)
180 |
181 | # Denominator
182 | union = inputs + target_oneHot - (inputs * target_oneHot)
183 | ## Sum over all pixels N x C x H x W => N x C
184 | union = union.view(N, self.classes, -1).sum(2)
185 |
186 | loss = (self.weights * inter) / (union + 1e-8)
187 |
188 | ## Return average loss over classes and batch
189 | return -torch.mean(loss)
190 |
191 | #Minimax iou
192 | class mmIoULoss(nn.Module):
193 | def __init__(self, n_classes=2):
194 | super(mmIoULoss, self).__init__()
195 | self.classes = n_classes
196 |
197 | def forward(self, inputs, target, is_target_variable=False):
198 | # inputs => N x Classes x H x W
199 | # target => N x H x W
200 | # target_oneHot => N x Classes x H x W
201 |
202 | N = inputs.size()[0]
203 | if is_target_variable:
204 | target_oneHot = to_one_hot_var(target.data, self.classes).float()
205 | else:
206 | target_oneHot = to_one_hot_var(target, self.classes).float()
207 |
208 | # predicted probabilities for each pixel along channel
209 | inputs = F.softmax(inputs, dim=1)
210 |
211 | # Numerator Product
212 | inter = inputs * target_oneHot
213 | ## Sum over all pixels N x C x H x W => N x C
214 | inter = inter.view(N, self.classes, -1).sum(2)
215 |
216 | # Denominator
217 | union = inputs + target_oneHot - (inputs * target_oneHot)
218 | ## Sum over all pixels N x C x H x W => N x C
219 | union = union.view(N, self.classes, -1).sum(2)
220 |
221 | iou = inter/ (union + 1e-8)
222 |
223 | #minimum iou of two classes
224 | min_iou = torch.min(iou)
225 |
226 | #loss
227 | loss = -min_iou-torch.mean(iou)
228 | return loss
229 |
--------------------------------------------------------------------------------
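A hedged sketch tying the pieces above together: cross_entropy on raw logits, and FocalLoss fed through softmax_helper (the non-linearity it expects). The alpha here is a hand-picked class weighting rather than one computed by get_alpha:

    import torch
    from models.losses import FocalLoss, softmax_helper, cross_entropy

    logits = torch.randn(2, 2, 64, 64)            # N x C x H x W
    target = torch.randint(0, 2, (2, 64, 64))     # N x H x W
    ce = cross_entropy(logits, target)
    focal = FocalLoss(apply_nonlin=softmax_helper, alpha=[0.75, 0.25], gamma=2)
    print(ce.item(), focal(logits, target).item())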
/STADE-CDNet/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | from torch.optim import lr_scheduler
6 | import numpy as np
7 | import functools
8 | from einops import rearrange
9 | import cv2
10 | import models
11 | from models.help_funcs import Transformer, TransformerDecoder, TwoLayerConv2d
12 | from models.ChangeFormer import ChangeFormerV1, ChangeFormerV2, ChangeFormerV3, ChangeFormerV4, ChangeFormerV5, ChangeFormerV6
13 | from models.SiamUnet_diff import SiamUnet_diff
14 | from models.SiamUnet_conc import SiamUnet_conc
15 | from models.Unet import Unet
16 | from models.DTCDSCN import CDNet34
17 | device = "cuda" if torch.cuda.is_available() else "cpu"
18 |
19 | ###############################################################################
20 | # Helper Functions
21 | ###############################################################################
22 |
23 | def get_scheduler(optimizer, args):
24 | """Return a learning rate scheduler
25 |
26 | Parameters:
27 | optimizer -- the optimizer of the network
28 | args (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
29 |                               args.lr_policy is the name of the learning rate policy: linear | step
30 | 
31 |     For 'linear', the learning rate decays linearly from its initial value
32 |     to zero over max_epochs.
33 |     For 'step', we use PyTorch's StepLR with step_size = max_epochs // 3 and gamma = 0.1.
34 |     See https://pytorch.org/docs/stable/optim.html for more details.
35 | """
36 | if args.lr_policy == 'linear':
37 | def lambda_rule(epoch):
38 |             lr_l = 1 - epoch / float(args.max_epochs + 1)
39 | return lr_l
40 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
41 | elif args.lr_policy == 'step':
42 | step_size = args.max_epochs//3
43 | # args.lr_decay_iters
44 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
45 |     else:
46 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % args.lr_policy)
47 | return scheduler
48 |
49 |
50 | class Identity(nn.Module):
51 | def forward(self, x):
52 | return x
53 |
54 |
55 | def get_norm_layer(norm_type='instance'):
56 | """Return a normalization layer
57 |
58 | Parameters:
59 | norm_type (str) -- the name of the normalization layer: batch | instance | none
60 |
61 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
62 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
63 | """
64 | if norm_type == 'batch':
65 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
66 | elif norm_type == 'instance':
67 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
68 | elif norm_type == 'none':
69 | norm_layer = lambda x: Identity()
70 | else:
71 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
72 | return norm_layer
73 |
74 |
75 | def init_weights(net, init_type='normal', init_gain=0.02):
76 | """Initialize network weights.
77 |
78 | Parameters:
79 | net (network) -- network to be initialized
80 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
81 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
82 |
83 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
84 | work better for some applications. Feel free to try yourself.
85 | """
86 | def init_func(m): # define the initialization function
87 | classname = m.__class__.__name__
88 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
89 | if init_type == 'normal':
90 | init.normal_(m.weight.data, 0.0, init_gain)
91 | elif init_type == 'xavier':
92 | init.xavier_normal_(m.weight.data, gain=init_gain)
93 | elif init_type == 'kaiming':
94 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
95 | elif init_type == 'orthogonal':
96 | init.orthogonal_(m.weight.data, gain=init_gain)
97 | else:
98 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
99 | if hasattr(m, 'bias') and m.bias is not None:
100 | init.constant_(m.bias.data, 0.0)
101 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
102 | init.normal_(m.weight.data, 1.0, init_gain)
103 | init.constant_(m.bias.data, 0.0)
104 |
105 | print('initialize network with %s' % init_type)
106 | net.apply(init_func) # apply the initialization function
107 |
108 |
109 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
110 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
111 | Parameters:
112 | net (network) -- the network to be initialized
113 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
114 | gain (float) -- scaling factor for normal, xavier and orthogonal.
115 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
116 |
117 | Return an initialized network.
118 | """
119 | if len(gpu_ids) > 0:
120 | assert(torch.cuda.is_available())
121 | net.to(gpu_ids[0])
122 | if len(gpu_ids) > 1:
123 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
124 | init_weights(net, init_type, init_gain=init_gain)
125 | return net
126 |
127 |
128 | def define_G(args, init_type='normal', init_gain=0.02, gpu_ids=[]):
129 | if args.net_G == 'base_resnet18':
130 | net = ResNet(input_nc=3, output_nc=2, output_sigmoid=False)
131 |
132 | elif args.net_G == 'base_transformer_pos_s4':
133 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
134 | with_pos='learned')
135 |
136 | elif args.net_G == 'base_transformer_pos_s4_dd8':
137 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
138 | with_pos='learned', enc_depth=1, dec_depth=8)
139 |
140 | elif args.net_G == 'base_transformer_pos_s4_dd8_dedim8':
141 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
142 | with_pos='learned', enc_depth=1, dec_depth=8, decoder_dim_head=8)
143 |
144 | elif args.net_G == 'ChangeFormerV1':
145 | net = ChangeFormerV1() #ChangeFormer with Transformer Encoder and Convolutional Decoder
146 |
147 | elif args.net_G == 'ChangeFormerV2':
148 | net = ChangeFormerV2() #ChangeFormer with Transformer Encoder and Convolutional Decoder
149 |
150 | elif args.net_G == 'ChangeFormerV3':
151 | net = ChangeFormerV3() #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse)
152 |
153 | elif args.net_G == 'ChangeFormerV4':
154 | net = ChangeFormerV4() #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse)
155 |
156 | elif args.net_G == 'ChangeFormerV5':
157 | net = ChangeFormerV5(embed_dim=args.embed_dim) #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse)
158 |
159 | elif args.net_G == 'ChangeFormerV6':
160 | net = ChangeFormerV6(embed_dim=args.embed_dim) #ChangeFormer with Transformer Encoder and Convolutional Decoder (Fuse)
161 |
162 | elif args.net_G == "SiamUnet_diff":
163 | #Implementation of ``Fully convolutional siamese networks for change detection''
164 | #Code copied from: https://github.com/rcdaudt/fully_convolutional_change_detection
165 | net = SiamUnet_diff(input_nbr=3, label_nbr=2)
166 |
167 | elif args.net_G == "SiamUnet_conc":
168 | #Implementation of ``Fully convolutional siamese networks for change detection''
169 | #Code copied from: https://github.com/rcdaudt/fully_convolutional_change_detection
170 | net = SiamUnet_conc(input_nbr=3, label_nbr=2)
171 |
172 | elif args.net_G == "Unet":
173 | #Usually abbreviated as FC-EF = Image Level Concatenation
174 | #Implementation of ``Fully convolutional siamese networks for change detection''
175 | #Code copied from: https://github.com/rcdaudt/fully_convolutional_change_detection
176 | net = Unet(input_nbr=3, label_nbr=2)
177 |
178 | elif args.net_G == "DTCDSCN":
179 | #The implementation of the paper"Building Change Detection for Remote Sensing Images Using a Dual Task Constrained Deep Siamese Convolutional Network Model "
180 | #Code copied from: https://github.com/fitzpchao/DTCDSCN
181 | net = CDNet34(in_channels=3)
182 |
183 | else:
184 | raise NotImplementedError('Generator model name [%s] is not recognized' % args.net_G)
185 | return init_net(net, init_type, init_gain, gpu_ids)
186 |
187 |
188 | ###############################################################################
189 | # main Functions
190 | ###############################################################################
191 |
192 |
193 | class ResNet(torch.nn.Module):
194 | def __init__(self, input_nc, output_nc,
195 | resnet_stages_num=5, backbone='resnet18',
196 | output_sigmoid=False, if_upsample_2x=True):
197 | """
198 | In the constructor we instantiate two nn.Linear modules and assign them as
199 | member variables.
200 | """
201 | super(ResNet, self).__init__()
202 | expand = 1
203 | if backbone == 'resnet18':
204 | self.resnet = models.resnet18(pretrained=True,
205 | replace_stride_with_dilation=[False,True,True])
206 | elif backbone == 'resnet34':
207 | self.resnet = models.resnet34(pretrained=True,
208 | replace_stride_with_dilation=[False,True,True])
209 | elif backbone == 'resnet50':
210 | self.resnet = models.resnet50(pretrained=True,
211 | replace_stride_with_dilation=[False,True,True])
212 | expand = 4
213 | else:
214 | raise NotImplementedError
215 | self.relu = nn.ReLU()
216 | self.upsamplex2 = nn.Upsample(scale_factor=2)
217 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear')
218 |
219 | self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc)
220 |
221 | self.resnet_stages_num = resnet_stages_num
222 |
223 | self.if_upsample_2x = if_upsample_2x
224 | if self.resnet_stages_num == 5:
225 | layers = 512 * expand
226 | elif self.resnet_stages_num == 4:
227 | layers = 256 * expand
228 | elif self.resnet_stages_num == 3:
229 | layers = 128 * expand
230 | else:
231 | raise NotImplementedError
232 | self.conv_pred = nn.Conv2d(layers, 32, kernel_size=3, padding=1)
233 |
234 | self.output_sigmoid = output_sigmoid
235 | self.sigmoid = nn.Sigmoid()
236 |
237 | def forward(self, x1, x2):
238 | x1 = self.forward_single(x1)
239 | x2 = self.forward_single(x2)
240 | x = torch.abs(x1 - x2)
241 | if not self.if_upsample_2x:
242 | x = self.upsamplex2(x)
243 | x = self.upsamplex4(x)
244 | x = self.classifier(x)
245 |
246 | if self.output_sigmoid:
247 | x = self.sigmoid(x)
248 | return x
249 |
250 | def forward_single(self, x):
251 | # resnet layers
252 | x = self.resnet.conv1(x)
253 | x = self.resnet.bn1(x)
254 | x = self.resnet.relu(x)
255 | x = self.resnet.maxpool(x)
256 |
257 | x_4 = self.resnet.layer1(x) # 1/4, in=64, out=64
258 | x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128
259 |
260 | if self.resnet_stages_num > 3:
261 | x_8 = self.resnet.layer3(x_8) # 1/8, in=128, out=256
262 |
263 | if self.resnet_stages_num == 5:
264 | x_8 = self.resnet.layer4(x_8) # 1/32, in=256, out=512
265 | elif self.resnet_stages_num > 5:
266 | raise NotImplementedError
267 |
268 | if self.if_upsample_2x:
269 | x = self.upsamplex2(x_8)
270 | else:
271 | x = x_8
272 | # output layers
273 | x = self.conv_pred(x)
274 | return x
275 |
276 | class Rnn(nn.Module):
277 | def __init__(self, in_dim, hidden_dim, n_layer, n_class):
278 | super(Rnn, self).__init__()
279 | self.n_layer = n_layer
280 | self.hidden_dim = hidden_dim
281 | self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer,
282 | batch_first=True)
283 |         self.classifier = nn.Linear(hidden_dim, n_class)  # defined but unused; forward returns the raw LSTM output
284 |
285 | def forward(self, x):
286 | # h0 = Variable(torch.zeros(self.n_layer, x.size(1),
287 | # self.hidden_dim)).cuda()
288 | # c0 = Variable(torch.zeros(self.n_layer, x.size(1),
289 | # self.hidden_dim)).cuda()
290 | x = x.to(device)
291 | out, _ = self.lstm(x)
292 |
293 | return out
294 |
295 | class BASE_Transformer(ResNet):
296 | """
297 | Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN
298 | """
299 | def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5,
300 | token_len=4, token_trans=True,
301 | enc_depth=1, dec_depth=1,
302 | dim_head=64, decoder_dim_head=64,
303 | tokenizer=True, if_upsample_2x=True,
304 | pool_mode='max', pool_size=2,
305 | backbone='resnet18',
306 | decoder_softmax=True, with_decoder_pos=None,
307 | with_decoder=True):
308 | super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone,
309 | resnet_stages_num=resnet_stages_num,
310 | if_upsample_2x=if_upsample_2x,
311 | )
312 | self.token_len = token_len
313 | self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1,
314 | padding=0, bias=False)
315 | self.tokenizer = tokenizer
316 | if not self.tokenizer:
317 |             # if the tokenizer is not used, downsample the feature map to a fixed size instead
318 | self.pooling_size = pool_size
319 | self.pool_mode = pool_mode
320 | self.token_len = self.pooling_size * self.pooling_size
321 |
322 | self.token_trans = token_trans
323 | self.with_decoder = with_decoder
324 | dim = 32
325 | mlp_dim = 2*dim
326 |
327 | self.with_pos = with_pos
328 | if with_pos == 'learned':
329 | self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32))
330 | decoder_pos_size = 256//4
331 | self.with_decoder_pos = with_decoder_pos
332 | if self.with_decoder_pos == 'learned':
333 | self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32,
334 | decoder_pos_size,
335 | decoder_pos_size))
336 | self.enc_depth = enc_depth
337 | self.dec_depth = dec_depth
338 | self.dim_head = dim_head
339 | self.decoder_dim_head = decoder_dim_head
340 | self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8,
341 | dim_head=self.dim_head,
342 | mlp_dim=mlp_dim, dropout=0)
343 | self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth,
344 | heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0,
345 | softmax=decoder_softmax)
346 |
347 | def _forward_semantic_tokens(self, x):
348 | b, c, h, w = x.shape
349 | spatial_attention = self.conv_a(x)
350 | spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous()
351 | spatial_attention = torch.softmax(spatial_attention, dim=-1)
352 | x = x.view([b, c, -1]).contiguous()
353 | tokens = torch.einsum('bln,bcn->blc', spatial_attention, x)
354 |
355 | return tokens
356 |
357 | def _forward_reshape_tokens(self, x):
358 | # b,c,h,w = x.shape
359 | if self.pool_mode == 'max':
360 | x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size])
361 | elif self.pool_mode == 'ave':
362 | x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size])
363 | else:
364 | x = x
365 | tokens = rearrange(x, 'b c h w -> b (h w) c')
366 | return tokens
367 |
368 | def _forward_transformer(self, x):
369 | if self.with_pos:
370 | x += self.pos_embedding
371 | x = self.transformer(x)
372 | return x
373 |
374 | def _forward_transformer_decoder(self, x, m):
375 | b, c, h, w = x.shape
376 | if self.with_decoder_pos == 'fix':
377 | x = x + self.pos_embedding_decoder
378 | elif self.with_decoder_pos == 'learned':
379 | x = x + self.pos_embedding_decoder
380 | x = rearrange(x, 'b c h w -> b (h w) c')
381 | x = self.transformer_decoder(x, m)
382 | x = rearrange(x, 'b (h w) c -> b c h w', h=h)
383 | return x
384 |
385 | def _forward_simple_decoder(self, x, m):
386 | b, c, h, w = x.shape
387 | b, l, c = m.shape
388 | m = m.expand([h,w,b,l,c])
389 | m = rearrange(m, 'h w b l c -> l b c h w')
390 | m = m.sum(0)
391 | x = x + m
392 | return x
393 |     def zl_difference(self, x):
394 |         # # #zl_CDDM###############
395 |         # Channel-difference attention (CDDM). The 1D/3D convolutions below are
396 |         # re-created with random weights on every forward call, as in the
397 |         # original implementation, so they carry no trained parameters.
398 | 
399 |         ##1 step: pool to a fixed 256x256 grid and flatten the spatial dims
400 |         max_pooling = torch.nn.AdaptiveMaxPool2d((256, 256))
401 |         _c1_21 = max_pooling(x)
402 |         _c1_21 = rearrange(_c1_21, 'b c h w -> b c (h w)')
403 |         ##1.1 step: channel squeeze via a 1D convolution, then ReLU
404 |         conv = nn.Conv1d(_c1_21.shape[1], x.shape[2], kernel_size=3, padding=1,
405 |                          groups=1, stride=1, bias=False).cuda()
406 |         _c1_22 = F.relu(conv(_c1_21))
407 |         ##1.2 step: channel excitation back to the input width, then Sigmoid
408 |         conv1_LZ = nn.Conv1d(x.shape[2], x.shape[1], kernel_size=3, padding=1,
409 |                              groups=1, stride=1, bias=False).cuda()
410 |         _c1_24 = torch.sigmoid(conv1_LZ(_c1_22))
411 |         # reshape the attention map back to the input layout (b, c, h, w)
412 |         _c1_24 = _c1_24.reshape(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
413 |         #M_1J: apply the attention map to the input
414 |         _c1_25_M_1J = _c1_24 * x
415 |         #2 step: two gated 3D convolutions over the attended features
416 |         #2.1 step (the (b, c, h, w) tensor is consumed by Conv3d as one unbatched volume)
417 |         conv3_LZ = nn.Conv3d(_c1_25_M_1J.shape[0], _c1_25_M_1J.shape[0], kernel_size=3,
418 |                              padding=1, groups=1, stride=1, bias=False).cuda()
419 |         _c1_26_M_1J = torch.sigmoid(conv3_LZ(_c1_25_M_1J))
420 |         #2.2 step
421 |         conv3_LZ = nn.Conv3d(_c1_26_M_1J.shape[0], _c1_25_M_1J.shape[0], kernel_size=3,
422 |                              padding=1, groups=1, stride=1, bias=False).cuda()
423 |         _c1_27_M_1J = conv3_LZ(_c1_26_M_1J)
424 |         #M_2J: second gating stage
425 |         _c1_28_M_2J = _c1_27_M_1J * _c1_25_M_1J
426 |         #result (the batch dimension is kept intact)
427 |         out = _c1_28_M_2J
428 |         # ###CDDM###############
429 |         return out
430 | def forward(self, x1, x2):
431 | # forward backbone resnet
432 |         x_n0 = abs(x1 - x2)  # currently unused
433 | x1 = self.forward_single(x1)
434 | x2 = self.forward_single(x2)
435 |
436 | # forward tokenzier
437 | if self.tokenizer:
438 | token1 = self._forward_semantic_tokens(x1)
439 | token2 = self._forward_semantic_tokens(x2)
440 | else:
441 | token1 = self._forward_reshape_tokens(x1)
442 | token2 = self._forward_reshape_tokens(x2)
443 | # forward transformer encoder
444 | if self.token_trans:
445 | self.tokens_ = torch.cat([token1, token2], dim=1)
446 | self.tokens = self._forward_transformer(self.tokens_)
447 | token1, token2 = self.tokens.chunk(2, dim=1)
448 | # forward transformer decoder
449 | if self.with_decoder:
450 | x1 = self._forward_transformer_decoder(x1, token1)
451 | x2 = self._forward_transformer_decoder(x2, token2)
452 | else:
453 | x1 = self._forward_simple_decoder(x1, token1)
454 | x2 = self._forward_simple_decoder(x2, token2)
455 | x101 = abs(x1-x2)
456 | if not self.if_upsample_2x:
457 | x101 = self.upsamplex2(x101)
458 | x101 = self.upsamplex4(x101)
459 | # forward small cnn
460 | x101 = self.classifier(x101)
461 |
462 |
463 | x11=x1
464 | if not self.if_upsample_2x:
465 | x11 = self.upsamplex2(x11)
466 | x11 = self.upsamplex4(x11)
467 | # forward small cnn
468 | x110 = self.classifier(x11)
469 |
470 | x22=x2
471 | if not self.if_upsample_2x:
472 | x22 = self.upsamplex2(x22)
473 | x22 = self.upsamplex4(x22)
474 | # forward small cnn
475 | x22 = self.classifier(x22)
476 |
477 |         #TMM: temporal memory module. The in/hidden size 32 matches the
478 |         # 32-channel feature maps; the LSTM is re-created (untrained) per call.
479 |         model = Rnn(32, 32, 2, 2).cuda()
480 |         #Maxpooling
481 |         maxPool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0, dilation=1,
482 |                                return_indices=False, ceil_mode=False)
483 |         ###fusion
484 | 
485 |         x_add = x1 + x2
486 |         x_maxpool = maxPool(x_add)
487 |         x_before = self.upsamplex2(x_maxpool)
488 |         x_after = rearrange(x_before, 'b c h w -> b (h w) c')
489 |         x_TMM_0 = model(x_after)
490 |         x_TMM_1 = x_TMM_0.reshape(x101.shape[0], x101.shape[1], x101.shape[2], x101.shape[3])
491 |
492 |         #CDDM: gate each temporal branch with the channel-difference module
493 |         x111 = self.zl_difference(x110)
494 | 
495 |         x222 = self.zl_difference(x22)
496 | 
497 | 
498 |         x_CDDM = torch.abs(x111 - x222)
499 | if not self.if_upsample_2x:
500 | x_CDDM = self.upsamplex2(x_CDDM)
501 |
502 | # forward small cnn
503 |
504 |         x = x_CDDM * x101 + x101 + x_TMM_1  # fuse the CDDM gate, the raw difference branch, and the TMM branch
505 |
506 | if self.output_sigmoid:
507 | x = self.sigmoid(x)
508 | outputs = []
509 | outputs.append(x)
510 | return outputs
511 |
512 |
513 |
--------------------------------------------------------------------------------
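Constructing a generator through the define_G factory above. The Namespace fields are hypothetical stand-ins for the real command-line options; 'SiamUnet_diff' is chosen because its forward path runs on CPU (the BASE_Transformer path calls .cuda() internally and needs a GPU):

    import torch
    from argparse import Namespace
    from models.networks import define_G

    args = Namespace(net_G='SiamUnet_diff', embed_dim=256, n_class=2, gpu_ids=[])
    net = define_G(args, gpu_ids=args.gpu_ids)    # also applies init_weights('normal')
    out = net(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))[-1]
    print(out.shape)                              # torch.Size([1, 2, 256, 256])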
/STADE-CDNet/models/pixel_shuffel_up.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | def icnr(x, scale=2, init=nn.init.kaiming_normal_):
7 | """
8 | Checkerboard artifact free sub-pixel convolution
9 | https://arxiv.org/abs/1707.02937
10 | """
11 | ni,nf,h,w = x.shape
12 | ni2 = int(ni/(scale**2))
13 | k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
14 | k = k.contiguous().view(ni2, nf, -1)
15 | k = k.repeat(1, 1, scale**2)
16 | k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
17 | x.data.copy_(k)
18 |
19 |
20 | class PixelShuffle(nn.Module):
21 | """
22 | Real-Time Single Image and Video Super-Resolution
23 | https://arxiv.org/abs/1609.05158
24 | """
25 | def __init__(self, n_channels, scale):
26 | super(PixelShuffle, self).__init__()
27 | self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1)
28 | icnr(self.conv.weight)
29 | self.shuf = nn.PixelShuffle(scale)
30 | self.relu = nn.ReLU(inplace=True)
31 |
32 | def forward(self,x):
33 | x = self.shuf(self.relu(self.conv(x)))
34 | return x
35 |
36 |
37 | def upsample(in_channels, out_channels, upscale, kernel_size=3):
38 |     # A series of x2 upsampling stages until we reach the requested upscale
39 | layers = []
40 | conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
41 | nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu')
42 | layers.append(conv1x1)
43 | for i in range(int(math.log(upscale, 2))):
44 | layers.append(PixelShuffle(out_channels, scale=2))
45 | return nn.Sequential(*layers)
46 |
47 |
48 | class PS_UP(nn.Module):
49 | def __init__(self, upscale, conv_in_ch, num_classes):
50 | super(PS_UP, self).__init__()
51 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
52 |
53 | def forward(self, x):
54 | x = self.upsample(x)
55 | return x
--------------------------------------------------------------------------------
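A usage sketch for PS_UP above: upscale should be a power of two, since upsample() stacks one ICNR-initialised PixelShuffle stage per factor of two:

    import torch
    from models.pixel_shuffel_up import PS_UP

    up = PS_UP(upscale=4, conv_in_ch=64, num_classes=2)
    x = torch.randn(1, 64, 32, 32)
    print(up(x).shape)   # torch.Size([1, 2, 128, 128]) -- 2 classes, 4x spatial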
/STADE-CDNet/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
4 |
5 | # from torchvision.models.utils import load_state_dict_from_url
6 |
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
10 | 'wide_resnet50_2', 'wide_resnet101_2']
11 |
12 |
13 | model_urls = {
14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
23 | }
24 |
25 |
26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=dilation, groups=groups, bias=False, dilation=dilation)
30 |
31 |
32 | def conv1x1(in_planes, out_planes, stride=1):
33 | """1x1 convolution"""
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
35 |
36 |
37 | class BasicBlock(nn.Module):
38 | expansion = 1
39 |
40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
41 | base_width=64, dilation=1, norm_layer=None):
42 | super(BasicBlock, self).__init__()
43 | if norm_layer is None:
44 | norm_layer = nn.BatchNorm2d
45 | if groups != 1 or base_width != 64:
46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
47 | if dilation > 1:
48 | dilation = 1
49 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
51 | self.conv1 = conv3x3(inplanes, planes, stride)
52 | self.bn1 = norm_layer(planes)
53 | self.relu = nn.ReLU(inplace=True)
54 | self.conv2 = conv3x3(planes, planes)
55 | self.bn2 = norm_layer(planes)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | identity = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | identity = self.downsample(x)
71 |
72 | out += identity
73 | out = self.relu(out)
74 |
75 | return out
76 |
77 |
78 | class Bottleneck(nn.Module):
79 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
80 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
81 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
82 | # This variant is also known as ResNet V1.5 and improves accuracy according to
83 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
84 |
85 | expansion = 4
86 |
87 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
88 | base_width=64, dilation=1, norm_layer=None):
89 | super(Bottleneck, self).__init__()
90 | if norm_layer is None:
91 | norm_layer = nn.BatchNorm2d
92 | width = int(planes * (base_width / 64.)) * groups
93 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
94 | self.conv1 = conv1x1(inplanes, width)
95 | self.bn1 = norm_layer(width)
96 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
97 | self.bn2 = norm_layer(width)
98 | self.conv3 = conv1x1(width, planes * self.expansion)
99 | self.bn3 = norm_layer(planes * self.expansion)
100 | self.relu = nn.ReLU(inplace=True)
101 | self.downsample = downsample
102 | self.stride = stride
103 |
104 | def forward(self, x):
105 | identity = x
106 |
107 | out = self.conv1(x)
108 | out = self.bn1(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv2(out)
112 | out = self.bn2(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv3(out)
116 | out = self.bn3(out)
117 |
118 | if self.downsample is not None:
119 | identity = self.downsample(x)
120 |
121 | out += identity
122 | out = self.relu(out)
123 |
124 | return out
125 |
126 |
127 | class ResNet(nn.Module):
128 |
129 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
130 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
131 | norm_layer=None, strides=None):
132 | super(ResNet, self).__init__()
133 | if norm_layer is None:
134 | norm_layer = nn.BatchNorm2d
135 | self._norm_layer = norm_layer
136 |
137 | self.strides = strides
138 | if self.strides is None:
139 | self.strides = [2, 2, 2, 2, 2]
140 |
141 | self.inplanes = 64
142 | self.dilation = 1
143 | if replace_stride_with_dilation is None:
144 | # each element in the tuple indicates if we should replace
145 | # the 2x2 stride with a dilated convolution instead
146 | replace_stride_with_dilation = [False, False, False]
147 | if len(replace_stride_with_dilation) != 3:
148 | raise ValueError("replace_stride_with_dilation should be None "
149 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
150 | self.groups = groups
151 | self.base_width = width_per_group
152 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3,
153 | bias=False)
154 | self.bn1 = norm_layer(self.inplanes)
155 | self.relu = nn.ReLU(inplace=True)
156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1)
157 | self.layer1 = self._make_layer(block, 64, layers[0])
158 | self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[2],
159 | dilate=replace_stride_with_dilation[0])
160 | self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[3],
161 | dilate=replace_stride_with_dilation[1])
162 | self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[4],
163 | dilate=replace_stride_with_dilation[2])
164 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
165 | self.fc = nn.Linear(512 * block.expansion, num_classes)
166 |
167 | for m in self.modules():
168 | if isinstance(m, nn.Conv2d):
169 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
170 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
171 | nn.init.constant_(m.weight, 1)
172 | nn.init.constant_(m.bias, 0)
173 |
174 | # Zero-initialize the last BN in each residual branch,
175 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
176 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
177 | if zero_init_residual:
178 | for m in self.modules():
179 | if isinstance(m, Bottleneck):
180 | nn.init.constant_(m.bn3.weight, 0)
181 | elif isinstance(m, BasicBlock):
182 | nn.init.constant_(m.bn2.weight, 0)
183 |
184 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
185 | norm_layer = self._norm_layer
186 | downsample = None
187 | previous_dilation = self.dilation
188 | if dilate:
189 | self.dilation *= stride
190 | stride = 1
191 | if stride != 1 or self.inplanes != planes * block.expansion:
192 | downsample = nn.Sequential(
193 | conv1x1(self.inplanes, planes * block.expansion, stride),
194 | norm_layer(planes * block.expansion),
195 | )
196 |
197 | layers = []
198 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
199 | self.base_width, previous_dilation, norm_layer))
200 | self.inplanes = planes * block.expansion
201 | for _ in range(1, blocks):
202 | layers.append(block(self.inplanes, planes, groups=self.groups,
203 | base_width=self.base_width, dilation=self.dilation,
204 | norm_layer=norm_layer))
205 |
206 | return nn.Sequential(*layers)
207 |
208 | def _forward_impl(self, x):
209 | # See note [TorchScript super()]
210 | x = self.conv1(x)
211 | x = self.bn1(x)
212 | x = self.relu(x)
213 | x = self.maxpool(x)
214 |
215 | x = self.layer1(x)
216 | x = self.layer2(x)
217 | x = self.layer3(x)
218 | x = self.layer4(x)
219 |
220 | x = self.avgpool(x)
221 | x = torch.flatten(x, 1)
222 | x = self.fc(x)
223 |
224 | return x
225 |
226 | def forward(self, x):
227 | return self._forward_impl(x)
228 |
229 |
230 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
231 | model = ResNet(block, layers, **kwargs)
232 | if pretrained:
233 | state_dict = load_state_dict_from_url(model_urls[arch],
234 | progress=progress)
235 | model.load_state_dict(state_dict)
236 | return model
237 |
238 |
239 | def resnet18(pretrained=False, progress=True, **kwargs):
240 | r"""ResNet-18 model from
241 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
242 |
243 | Args:
244 | pretrained (bool): If True, returns a model pre-trained on ImageNet
245 | progress (bool): If True, displays a progress bar of the download to stderr
246 | """
247 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
248 | **kwargs)
249 |
250 |
251 | def resnet34(pretrained=False, progress=True, **kwargs):
252 | r"""ResNet-34 model from
253 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
254 |
255 | Args:
256 | pretrained (bool): If True, returns a model pre-trained on ImageNet
257 | progress (bool): If True, displays a progress bar of the download to stderr
258 | """
259 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
260 | **kwargs)
261 |
262 |
263 | def resnet50(pretrained=False, progress=True, **kwargs):
264 | r"""ResNet-50 model from
265 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
266 |
267 | Args:
268 | pretrained (bool): If True, returns a model pre-trained on ImageNet
269 | progress (bool): If True, displays a progress bar of the download to stderr
270 | """
271 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
272 | **kwargs)
273 |
274 |
275 | def resnet101(pretrained=False, progress=True, **kwargs):
276 | r"""ResNet-101 model from
277 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
278 |
279 | Args:
280 | pretrained (bool): If True, returns a model pre-trained on ImageNet
281 | progress (bool): If True, displays a progress bar of the download to stderr
282 | """
283 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
284 | **kwargs)
285 |
286 |
287 | def resnet152(pretrained=False, progress=True, **kwargs):
288 | r"""ResNet-152 model from
289 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
290 |
291 | Args:
292 | pretrained (bool): If True, returns a model pre-trained on ImageNet
293 | progress (bool): If True, displays a progress bar of the download to stderr
294 | """
295 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
296 | **kwargs)
297 |
298 |
299 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
300 | r"""ResNeXt-50 32x4d model from
301 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
302 |
303 | Args:
304 | pretrained (bool): If True, returns a model pre-trained on ImageNet
305 | progress (bool): If True, displays a progress bar of the download to stderr
306 | """
307 | kwargs['groups'] = 32
308 | kwargs['width_per_group'] = 4
309 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
310 | pretrained, progress, **kwargs)
311 |
312 |
313 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
314 | r"""ResNeXt-101 32x8d model from
315 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
316 |
317 | Args:
318 | pretrained (bool): If True, returns a model pre-trained on ImageNet
319 | progress (bool): If True, displays a progress bar of the download to stderr
320 | """
321 | kwargs['groups'] = 32
322 | kwargs['width_per_group'] = 8
323 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
324 | pretrained, progress, **kwargs)
325 |
326 |
327 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
328 | r"""Wide ResNet-50-2 model from
329 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
330 |
331 |     The model is the same as ResNet except for the bottleneck number of channels,
332 |     which is twice as large in every block. The number of channels in outer 1x1
333 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
334 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
335 |
336 | Args:
337 | pretrained (bool): If True, returns a model pre-trained on ImageNet
338 | progress (bool): If True, displays a progress bar of the download to stderr
339 | """
340 | kwargs['width_per_group'] = 64 * 2
341 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
342 | pretrained, progress, **kwargs)
343 |
344 |
345 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
346 | r"""Wide ResNet-101-2 model from
347 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
348 |
349 |     The model is the same as ResNet except for the bottleneck number of channels,
350 |     which is twice as large in every block. The number of channels in outer 1x1
351 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
352 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
353 |
354 | Args:
355 | pretrained (bool): If True, returns a model pre-trained on ImageNet
356 | progress (bool): If True, displays a progress bar of the download to stderr
357 | """
358 | kwargs['width_per_group'] = 64 * 2
359 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
360 | pretrained, progress, **kwargs)
361 |
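A minimal usage sketch for the extra strides hook this file adds on top of the
stock torchvision signature (assumption: the module is importable as
models.resnet; the five strides entries map to conv1, maxpool, and layers 2-4):

    import torch
    from models.resnet import resnet18

    # keep the final stage at stride 1 for a 2x larger final feature map
    backbone = resnet18(pretrained=False, strides=[2, 2, 2, 2, 1])
    logits = backbone(torch.randn(1, 3, 256, 256))  # -> shape (1, 1000)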
--------------------------------------------------------------------------------
/STADE-CDNet/models/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 |
5 | import utils
6 | from models.networks import *
7 |
8 | import torch
9 | import torch.nn.functional as F  # F.interpolate is used in _forward_pass/_backward_G
10 | import torch.optim as optim
11 | from misc.metric_tool import ConfuseMatrixMeter
12 | from models.losses import cross_entropy
13 | import models.losses as losses
14 | from models.losses import get_alpha, softmax_helper, FocalLoss, mIoULoss, mmIoULoss
15 |
16 | from misc.logger_tool import Logger, Timer
17 |
18 | from utils import de_norm
19 |
20 | from tqdm import tqdm
21 |
22 | class CDTrainer():
23 |
24 | def __init__(self, args, dataloaders):
25 | self.args = args
26 | self.dataloaders = dataloaders
27 |
28 | self.n_class = args.n_class
29 | # define G
30 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids)
31 |
32 |
33 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0
34 | else "cpu")
35 | print(self.device)
36 |
37 | # Learning rate and Beta1 for Adam optimizers
38 | self.lr = args.lr
39 |
40 | # define optimizers
41 | if args.optimizer == "sgd":
42 | self.optimizer_G = optim.SGD(self.net_G.parameters(), lr=self.lr,
43 | momentum=0.9,
44 | weight_decay=5e-4)
45 | elif args.optimizer == "adam":
46 | self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr,
47 | weight_decay=0)
48 | elif args.optimizer == "adamw":
49 | self.optimizer_G = optim.AdamW(self.net_G.parameters(), lr=self.lr,
50 |                                            betas=(0.9, 0.999), weight_decay=0.05)  # betas here are placeholder defaults; set them for your run
51 |
52 | # self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr)
53 |
54 | # define lr schedulers
55 | self.exp_lr_scheduler_G = get_scheduler(self.optimizer_G, args)
56 |
57 | self.running_metric = ConfuseMatrixMeter(n_class=2)
58 |
59 | # define logger file
60 | logger_path = os.path.join(args.checkpoint_dir, 'log.txt')
61 | self.logger = Logger(logger_path)
62 | self.logger.write_dict_str(args.__dict__)
63 | # define timer
64 | self.timer = Timer()
65 | self.batch_size = args.batch_size
66 |
67 | # training log
68 | self.epoch_acc = 0
69 | self.best_val_acc = 0.0
70 | self.best_epoch_id = 0
71 | self.epoch_to_start = 0
72 | self.max_num_epochs = args.max_epochs
73 |
74 | self.global_step = 0
75 | self.steps_per_epoch = len(dataloaders['train'])
76 | self.total_steps = (self.max_num_epochs - self.epoch_to_start)*self.steps_per_epoch
77 |
78 | self.G_pred = None
79 | self.pred_vis = None
80 | self.batch = None
81 | self.G_loss = None
82 | self.is_training = False
83 | self.batch_id = 0
84 | self.epoch_id = 0
85 | self.checkpoint_dir = args.checkpoint_dir
86 | self.vis_dir = args.vis_dir
87 |
88 | self.shuffle_AB = args.shuffle_AB
89 |
90 | # define the loss functions
91 | self.multi_scale_train = args.multi_scale_train
92 | self.multi_scale_infer = args.multi_scale_infer
93 | self.weights = tuple(args.multi_pred_weights)
94 | if args.loss == 'ce':
95 | self._pxl_loss = cross_entropy
96 | elif args.loss == 'bce':
97 | self._pxl_loss = losses.binary_ce
98 | elif args.loss == 'fl':
99 | print('\n Calculating alpha in Focal-Loss (FL) ...')
100 |             alpha = get_alpha(dataloaders['train'])  # calculate class occurrences
101 | print(f"alpha-0 (no-change)={alpha[0]}, alpha-1 (change)={alpha[1]}")
102 | self._pxl_loss = FocalLoss(apply_nonlin = softmax_helper, alpha = alpha, gamma = 2, smooth = 1e-5)
103 | elif args.loss == "miou":
104 |             print('\n Calculating class occurrences in training set...')
105 |             alpha = np.asarray(get_alpha(dataloaders['train']))  # calculate class occurrences
106 | alpha = alpha/np.sum(alpha)
107 | # weights = torch.tensor([1.0, 1.0]).cuda()
108 | weights = 1-torch.from_numpy(alpha).cuda()
109 | print(f"Weights = {weights}")
110 | self._pxl_loss = mIoULoss(weight=weights, size_average=True, n_classes=args.n_class).cuda()
111 | elif args.loss == "mmiou":
112 | self._pxl_loss = mmIoULoss(n_classes=args.n_class).cuda()
113 | else:
114 |             raise NotImplementedError(args.loss)
115 |
116 | self.VAL_ACC = np.array([], np.float32)
117 | if os.path.exists(os.path.join(self.checkpoint_dir, 'val_acc.npy')):
118 | self.VAL_ACC = np.load(os.path.join(self.checkpoint_dir, 'val_acc.npy'))
119 | self.TRAIN_ACC = np.array([], np.float32)
120 | if os.path.exists(os.path.join(self.checkpoint_dir, 'train_acc.npy')):
121 | self.TRAIN_ACC = np.load(os.path.join(self.checkpoint_dir, 'train_acc.npy'))
122 |
123 | # check and create model dir
124 |         if not os.path.exists(self.checkpoint_dir):
125 |             os.mkdir(self.checkpoint_dir)
126 |         if not os.path.exists(self.vis_dir):
127 |             os.mkdir(self.vis_dir)
128 |
129 |
130 | def _load_checkpoint(self, ckpt_name='last_ckpt.pt'):
131 | print("\n")
132 | if os.path.exists(os.path.join(self.checkpoint_dir, ckpt_name)):
133 | self.logger.write('loading last checkpoint...\n')
134 | # load the entire checkpoint
135 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, ckpt_name),
136 | map_location=self.device)
137 | # update net_G states
138 |             self.net_G.load_state_dict(checkpoint['model_G_state_dict'], strict=False)
139 |
140 | # self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
141 | self.exp_lr_scheduler_G.load_state_dict(
142 | checkpoint['exp_lr_scheduler_G_state_dict'])
143 |
144 | self.net_G.to(self.device)
145 |
146 | # update some other states
147 | self.epoch_to_start = checkpoint['epoch_id'] + 1
148 | self.best_val_acc = checkpoint['best_val_acc']
149 | self.best_epoch_id = checkpoint['best_epoch_id']
150 |
151 | self.total_steps = (self.max_num_epochs - self.epoch_to_start)*self.steps_per_epoch
152 |
153 | self.logger.write('Epoch_to_start = %d, Historical_best_acc = %.4f (at epoch %d)\n' %
154 | (self.epoch_to_start, self.best_val_acc, self.best_epoch_id))
155 | self.logger.write('\n')
156 | elif self.args.pretrain is not None:
157 | print("Initializing backbone weights from: " + self.args.pretrain)
158 | self.net_G.load_state_dict(torch.load(self.args.pretrain), strict=False)
159 | self.net_G.to(self.device)
160 | self.net_G.eval()
161 | else:
162 | print('training from scratch...')
163 | print("\n")
164 |
165 | def _timer_update(self):
166 | self.global_step = (self.epoch_id-self.epoch_to_start) * self.steps_per_epoch + self.batch_id
167 |
168 | self.timer.update_progress((self.global_step + 1) / self.total_steps)
169 | est = self.timer.estimated_remaining()
170 | imps = (self.global_step + 1) * self.batch_size / self.timer.get_stage_elapsed()
171 | return imps, est
172 |
173 | def _visualize_pred(self):
174 | pred = torch.argmax(self.G_final_pred, dim=1, keepdim=True)
175 | pred_vis = pred * 255
176 | return pred_vis
177 |
178 | def _save_checkpoint(self, ckpt_name):
179 | torch.save({
180 | 'epoch_id': self.epoch_id,
181 | 'best_val_acc': self.best_val_acc,
182 | 'best_epoch_id': self.best_epoch_id,
183 | 'model_G_state_dict': self.net_G.state_dict(),
184 | 'optimizer_G_state_dict': self.optimizer_G.state_dict(),
185 | 'exp_lr_scheduler_G_state_dict': self.exp_lr_scheduler_G.state_dict(),
186 | }, os.path.join(self.checkpoint_dir, ckpt_name))
187 |
188 | def _update_lr_schedulers(self):
189 | self.exp_lr_scheduler_G.step()
190 |
191 | def _update_metric(self):
192 | """
193 | update metric
194 | """
195 | target = self.batch['L'].to(self.device).detach()
196 | G_pred = self.G_final_pred.detach()
197 |
198 | G_pred = torch.argmax(G_pred, dim=1)
199 |
200 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy())
201 | return current_score
202 |
203 | def _collect_running_batch_states(self):
204 |
205 | running_acc = self._update_metric()
206 |
207 | m = len(self.dataloaders['train'])
208 | if self.is_training is False:
209 | m = len(self.dataloaders['val'])
210 |
211 | imps, est = self._timer_update()
212 | if np.mod(self.batch_id, 200) == 1:
213 | message = 'Is_training: %s. [%d,%d][%d,%d], imps: %.2f, est: %.2fh, G_loss: %.5f, running_mf1: %.5f\n' %\
214 | (self.is_training, self.epoch_id, self.max_num_epochs-1, self.batch_id, m,
215 | imps*self.batch_size, est,
216 | self.G_loss.item(), running_acc)
217 | self.logger.write(message)
218 |
219 |
220 | if np.mod(self.batch_id, 200) == 1:
221 | vis_input = utils.make_numpy_grid(de_norm(self.batch['A']))
222 | vis_input2 = utils.make_numpy_grid(de_norm(self.batch['B']))
223 |
224 | vis_pred = utils.make_numpy_grid(self._visualize_pred())
225 |
226 | vis_gt = utils.make_numpy_grid(self.batch['L'])
227 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0)
228 | vis = np.clip(vis, a_min=0.0, a_max=1.0)
229 |             file_name = os.path.join(
230 |                 self.vis_dir, 'istrain_'+str(self.is_training)+'_'+
231 |                 str(self.epoch_id)+'_'+str(self.batch_id)+'.jpg')
232 | plt.imsave(file_name, vis)
233 |
234 | def _collect_epoch_states(self):
235 | scores = self.running_metric.get_scores()
236 | self.epoch_acc = scores['mf1']
237 | self.logger.write('Is_training: %s. Epoch %d / %d, epoch_mF1= %.5f\n' %
238 | (self.is_training, self.epoch_id, self.max_num_epochs-1, self.epoch_acc))
239 | message = ''
240 | for k, v in scores.items():
241 | message += '%s: %.5f ' % (k, v)
242 | self.logger.write(message+'\n')
243 | self.logger.write('\n')
244 |
245 | def _update_checkpoints(self):
246 |
247 | # save current model
248 | self._save_checkpoint(ckpt_name='last_ckpt.pt')
249 |         self.logger.write('Latest model updated. Epoch_acc=%.4f, Historical_best_acc=%.4f (at epoch %d)\n'
250 | % (self.epoch_acc, self.best_val_acc, self.best_epoch_id))
251 | self.logger.write('\n')
252 |
253 | # update the best model (based on eval acc)
254 | if self.epoch_acc > self.best_val_acc:
255 | self.best_val_acc = self.epoch_acc
256 | self.best_epoch_id = self.epoch_id
257 | self._save_checkpoint(ckpt_name='best_ckpt.pt')
258 | self.logger.write('*' * 10 + 'Best model updated!\n')
259 | self.logger.write('\n')
260 |
261 | def _update_training_acc_curve(self):
262 | # update train acc curve
263 | self.TRAIN_ACC = np.append(self.TRAIN_ACC, [self.epoch_acc])
264 | np.save(os.path.join(self.checkpoint_dir, 'train_acc.npy'), self.TRAIN_ACC)
265 |
266 | def _update_val_acc_curve(self):
267 | # update val acc curve
268 | self.VAL_ACC = np.append(self.VAL_ACC, [self.epoch_acc])
269 | np.save(os.path.join(self.checkpoint_dir, 'val_acc.npy'), self.VAL_ACC)
270 |
271 | def _clear_cache(self):
272 | self.running_metric.clear()
273 |
274 |
275 | def _forward_pass(self, batch):
276 | self.batch = batch
277 | img_in1 = batch['A'].to(self.device)
278 | img_in2 = batch['B'].to(self.device)
279 | self.G_pred = self.net_G(img_in1, img_in2)
280 |
281 | if self.multi_scale_infer == "True":
282 | self.G_final_pred = torch.zeros(self.G_pred[-1].size()).to(self.device)
283 | for pred in self.G_pred:
284 | if pred.size(2) != self.G_pred[-1].size(2):
285 | self.G_final_pred = self.G_final_pred + F.interpolate(pred, size=self.G_pred[-1].size(2), mode="nearest")
286 | else:
287 | self.G_final_pred = self.G_final_pred + pred
288 | self.G_final_pred = self.G_final_pred/len(self.G_pred)
289 | else:
290 | self.G_final_pred = self.G_pred[-1]
291 |
292 |
293 | def _backward_G(self):
294 | gt = self.batch['L'].to(self.device).float()
295 | if self.multi_scale_train == "True":
296 | i = 0
297 | temp_loss = 0.0
298 | for pred in self.G_pred:
299 | if pred.size(2) != gt.size(2):
300 | temp_loss = temp_loss + self.weights[i]*self._pxl_loss(pred, F.interpolate(gt, size=pred.size(2), mode="nearest"))
301 | else:
302 | temp_loss = temp_loss + self.weights[i]*self._pxl_loss(pred, gt)
303 | i+=1
304 | self.G_loss = temp_loss
305 | else:
306 | self.G_loss = self._pxl_loss(self.G_pred[-1], gt)
307 |
308 | self.G_loss.backward()
309 |
310 |
311 | def train_models(self):
312 |
313 | self._load_checkpoint()
314 |
315 | # loop over the dataset multiple times
316 | for self.epoch_id in range(self.epoch_to_start, self.max_num_epochs):
317 |
318 | ################## train #################
319 | ##########################################
320 | self._clear_cache()
321 | self.is_training = True
322 | self.net_G.train() # Set model to training mode
323 | # Iterate over data.
324 | total = len(self.dataloaders['train'])
325 |             lr = self.optimizer_G.param_groups[0]['lr']
326 |             # log the learning rate in effect for this epoch
327 |             self.logger.write('lr: %0.7f\n\n' % lr)
328 | for self.batch_id, batch in tqdm(enumerate(self.dataloaders['train'], 1), total=total):
329 | self._forward_pass(batch)
330 | # update G
331 | self.optimizer_G.zero_grad()
332 | self._backward_G()
333 | self.optimizer_G.step()
334 | self._collect_running_batch_states()
335 | self._timer_update()
336 |
337 | self._collect_epoch_states()
338 | self._update_training_acc_curve()
339 | self._update_lr_schedulers()
340 |
341 |
342 | ################## Eval ##################
343 | ##########################################
344 | self.logger.write('Begin evaluation...\n')
345 | self._clear_cache()
346 | self.is_training = False
347 | self.net_G.eval()
348 |
349 | # Iterate over data.
350 | for self.batch_id, batch in enumerate(self.dataloaders['val'], 0):
351 | with torch.no_grad():
352 | self._forward_pass(batch)
353 | self._collect_running_batch_states()
354 | self._collect_epoch_states()
355 |
356 | ########### Update_Checkpoints ###########
357 | ##########################################
358 | self._update_val_acc_curve()
359 | self._update_checkpoints()
360 |
361 |
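A minimal driver sketch. The Namespace fields mirror the attributes CDTrainer and
get_loaders read; the values are placeholders, and define_G in models/networks.py
may require further fields (e.g. the network name):

    from argparse import Namespace
    from models.trainer import CDTrainer
    from utils import get_loaders

    args = Namespace(gpu_ids=[0], lr=0.01, optimizer='sgd', loss='ce',
                     n_class=2, batch_size=8, max_epochs=200,
                     checkpoint_dir='./checkpoints', vis_dir='./vis',
                     multi_scale_train='False', multi_scale_infer='False',
                     multi_pred_weights=[0.5, 0.5, 0.5, 0.8, 1.0],
                     shuffle_AB=False, pretrain=None,
                     data_name='LEVIR', dataset='CDDataset', split='train',
                     split_val='val', img_size=256, num_workers=4)

    trainer = CDTrainer(args=args, dataloaders=get_loaders(args))
    trainer.train_models()

Checkpoints land in checkpoint_dir as last_ckpt.pt and best_ckpt.pt; rerunning
train_models resumes from last_ckpt.pt automatically via _load_checkpoint.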
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/0_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/0_2.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/1_1.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/2_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/2_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/3_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/3_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/4_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/4_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/5_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/5_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/6_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/6_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/7_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/7_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/8_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/8_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/A/9_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/A/9_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/0_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/0_2.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/1_1.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/2_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/2_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/3_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/3_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/4_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/4_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/5_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/5_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/6_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/6_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/7_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/7_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/8_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/8_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/B/9_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/B/9_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/0_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/0_2.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/1_1.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/2_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/2_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/3_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/3_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/4_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/4_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/5_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/5_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/6_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/6_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/7_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/7_4.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/8_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/8_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/label/9_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/STADE-CDNet/samples_DSIFN/label/9_3.png
--------------------------------------------------------------------------------
/STADE-CDNet/samples_DSIFN/list/demo.txt:
--------------------------------------------------------------------------------
1 | 9_3.png
2 | 8_3.png
3 | 7_4.png
4 | 6_3.png
5 | 5_3.png
6 | 4_4.png
7 | 3_4.png
8 | 2_4.png
9 | 1_1.png
10 | 0_2.png
11 |
--------------------------------------------------------------------------------
/STADE-CDNet/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from torchvision import utils
5 |
6 | import data_config
7 | from datasets.CD_dataset import CDDataset
8 |
9 |
10 | def get_loader(data_name, img_size=256, batch_size=8, split='test',
11 |                is_train=False, dataset='CDDataset'):  # batch_size default of 8 is a placeholder; override as needed
12 | dataConfig = data_config.DataConfig().get_data_config(data_name)
13 | root_dir = dataConfig.root_dir
14 | label_transform = dataConfig.label_transform
15 |
16 | if dataset == 'CDDataset':
17 | data_set = CDDataset(root_dir=root_dir, split=split,
18 | img_size=img_size, is_train=is_train,
19 | label_transform=label_transform)
20 | else:
21 | raise NotImplementedError(
22 | 'Wrong dataset name %s (choose one from [CDDataset])'
23 | % dataset)
24 |
25 | shuffle = is_train
26 | dataloader = DataLoader(data_set, batch_size=batch_size,
27 | shuffle=shuffle, num_workers=4)
28 |
29 | return dataloader
30 |
31 |
32 | def get_loaders(args):
33 |
34 | data_name = args.data_name
35 | dataConfig = data_config.DataConfig().get_data_config(data_name)
36 | root_dir = dataConfig.root_dir
37 | label_transform = dataConfig.label_transform
38 | split = args.split
39 | split_val = 'val'
40 | if hasattr(args, 'split_val'):
41 | split_val = args.split_val
42 | if args.dataset == 'CDDataset':
43 | training_set = CDDataset(root_dir=root_dir, split=split,
44 | img_size=args.img_size,is_train=True,
45 | label_transform=label_transform)
46 | val_set = CDDataset(root_dir=root_dir, split=split_val,
47 | img_size=args.img_size,is_train=False,
48 | label_transform=label_transform)
49 | else:
50 | raise NotImplementedError(
51 |             'Wrong dataset name %s (choose one from [CDDataset])'
52 | % args.dataset)
53 |
54 | datasets = {'train': training_set, 'val': val_set}
55 |     dataloaders = {x: DataLoader(datasets[x], batch_size=args.batch_size,
56 |                                  shuffle=True, num_workers=args.num_workers)
57 |                    for x in ['train', 'val']}
58 |
59 | return dataloaders
60 |
61 |
62 | def make_numpy_grid(tensor_data, pad_value=0,padding=0):
63 | tensor_data = tensor_data.detach()
64 | vis = utils.make_grid(tensor_data, pad_value=pad_value,padding=padding)
65 | vis = np.array(vis.cpu()).transpose((1,2,0))
66 |     if vis.shape[2] == 1:
67 |         vis = np.concatenate([vis, vis, vis], axis=2)  # replicate single channel to RGB (np.stack would add a 4th axis)
68 | return vis
69 |
70 |
71 | def de_norm(tensor_data):
72 |     return tensor_data * 0.5 + 0.5  # inverse of Normalize(mean=0.5, std=0.5)
73 |
74 |
75 | def get_device(args):
76 | # set gpu ids
77 | str_ids = args.gpu_ids.split(',')
78 | args.gpu_ids = []
79 | for str_id in str_ids:
80 |         gpu_id = int(str_id)  # avoid shadowing the built-in id()
81 |         if gpu_id >= 0:
82 |             args.gpu_ids.append(gpu_id)
83 | if len(args.gpu_ids) > 0:
84 | torch.cuda.set_device(args.gpu_ids[0])
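A short usage sketch for get_loader ('LEVIR' stands in for any dataset name
registered in data_config.py):

    from utils import get_loader

    loader = get_loader('LEVIR', img_size=256, batch_size=8, split='test')
    batch = next(iter(loader))  # CDDataset yields dicts; trainer.py reads keys 'A', 'B', 'L'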
--------------------------------------------------------------------------------
/image/1 (2).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/1 (2).png
--------------------------------------------------------------------------------
/image/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/11.png
--------------------------------------------------------------------------------
/image/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/16.png
--------------------------------------------------------------------------------
/image/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/22.png
--------------------------------------------------------------------------------
/image/33.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/33.png
--------------------------------------------------------------------------------
/image/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/4.png
--------------------------------------------------------------------------------
/image/44.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/44.png
--------------------------------------------------------------------------------
/image/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/5.png
--------------------------------------------------------------------------------
/image/55.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/55.jpg
--------------------------------------------------------------------------------
/image/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/6.png
--------------------------------------------------------------------------------
/image/66.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/66.png
--------------------------------------------------------------------------------
/image/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/7.png
--------------------------------------------------------------------------------
/image/77.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiLisaZhi/STADE-CDNet/946b0bc29ff630ce30c72ed16cfb18bf5b6151e3/image/77.png
--------------------------------------------------------------------------------