├── compare
│   ├── Changer.py
│   ├── __pycache__
│   │   ├── A2Net.cpython-38.pyc
│   │   ├── FC_EF.cpython-38.pyc
│   │   ├── IFNet.cpython-38.pyc
│   │   ├── DMINet.cpython-38.pyc
│   │   ├── DTCDSCN.cpython-38.pyc
│   │   ├── SNUNet.cpython-38.pyc
│   │   ├── TFI_GR.cpython-38.pyc
│   │   ├── resnet.cpython-38.pyc
│   │   ├── MobileNet.cpython-38.pyc
│   │   ├── NestedUNet.cpython-38.pyc
│   │   ├── resnet_tfi.cpython-38.pyc
│   │   ├── ChangeFormer.cpython-38.pyc
│   │   ├── FC_Siam_conc.cpython-38.pyc
│   │   └── FC_Siam_diff.cpython-38.pyc
│   ├── resbase.py
│   ├── MobileNet.py
│   ├── NestedUNet.py
│   ├── SNUNet.py
│   ├── FC_EF.py
│   ├── IFNet.py
│   ├── DASNet.py
│   ├── FC_Siam_conc.py
│   ├── FC_Siam_diff.py
│   ├── TFI_GR.py
│   ├── resnet_tfi.py
│   ├── DMINet.py
│   ├── A2Net.py
│   └── DTCDSCN.py
├── gitignore
├── .idea
│   ├── .gitignore
│   ├── misc.xml
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── SEIFNet-main.iml
├── models
│   ├── .idea
│   │   ├── .gitignore
│   │   ├── misc.xml
│   │   ├── vcs.xml
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── modules.xml
│   │   └── models.iml
│   ├── __pycache__
│   │   ├── ASFF.cpython-38.pyc
│   │   ├── CBAM.cpython-38.pyc
│   │   ├── DEFM.cpython-38.pyc
│   │   ├── TAM.cpython-38.pyc
│   │   ├── Models.cpython-38.pyc
│   │   ├── losses.cpython-38.pyc
│   │   ├── resnet.cpython-38.pyc
│   │   ├── FT_loss.cpython-38.pyc
│   │   ├── evaluator.cpython-38.pyc
│   │   ├── networks.cpython-38.pyc
│   │   ├── trainer.cpython-38.pyc
│   │   ├── Focal_loss.cpython-38.pyc
│   │   ├── siamunet_dif.cpython-38.pyc
│   │   └── swin_transformer.cpython-38.pyc
│   ├── CBAM.py
│   └── evaluator.py
├── __pycache__
│   ├── utils_.cpython-38.pyc
│   ├── data_config.cpython-38.pyc
│   └── main_train.cpython-38.pyc
├── utils
│   ├── __pycache__
│   │   ├── helpers.cpython-38.pyc
│   │   ├── losses.cpython-38.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   ├── parser.cpython-38.pyc
│   │   ├── transforms.cpython-38.pyc
│   │   └── dataloaders.cpython-38.pyc
│   ├── losses.py
│   ├── parser.py
│   ├── dataloaders.py
│   ├── helpers.py
│   ├── metrics.py
│   └── transforms.py
├── misc
│   ├── __pycache__
│   │   ├── logger_tool.cpython-38.pyc
│   │   └── metric_tool.cpython-38.pyc
│   ├── pyutils.py
│   ├── logger_tool.py
│   └── metric_tool.py
├── datasets
│   ├── __pycache__
│   │   ├── CD_dataset.cpython-38.pyc
│   │   └── data_utils.cpython-38.pyc
│   ├── CD_dataset.py
│   └── data_utils.py
├── data_config.py
├── README.md
├── eval.py
├── utils_.py
├── eval_cd.py
├── main_train.py
├── checkpoints
│   └── LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200
│       └── log.txt
└── train.py
/compare/Changer.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 | checkpoints
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/models/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/__pycache__/utils_.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/__pycache__/utils_.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/data_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/__pycache__/data_config.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/main_train.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/__pycache__/main_train.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ASFF.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/ASFF.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/CBAM.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/CBAM.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/DEFM.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/DEFM.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/TAM.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/TAM.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/A2Net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/A2Net.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/FC_EF.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/FC_EF.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/IFNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/IFNet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/Models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/Models.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/helpers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/helpers.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/parser.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/parser.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/DMINet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/DMINet.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/DTCDSCN.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/DTCDSCN.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/SNUNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/SNUNet.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/TFI_GR.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/TFI_GR.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/misc/__pycache__/logger_tool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/misc/__pycache__/logger_tool.cpython-38.pyc
--------------------------------------------------------------------------------
/misc/__pycache__/metric_tool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/misc/__pycache__/metric_tool.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/FT_loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/FT_loss.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/evaluator.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/evaluator.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/networks.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/MobileNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/MobileNet.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/NestedUNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/NestedUNet.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/resnet_tfi.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/resnet_tfi.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/Focal_loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/Focal_loss.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataloaders.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/dataloaders.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/ChangeFormer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/ChangeFormer.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/FC_Siam_conc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/FC_Siam_conc.cpython-38.pyc
--------------------------------------------------------------------------------
/compare/__pycache__/FC_Siam_diff.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/FC_Siam_diff.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/CD_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/datasets/__pycache__/CD_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/datasets/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/siamunet_dif.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/siamunet_dif.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/swin_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/swin_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/models/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/models/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/models/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/models/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/models/.idea/models.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/SEIFNet-main.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | from utils.parser import get_parser_with_args
2 | from utils.metrics import FocalLoss, dice_loss
3 |
4 | parser, metadata = get_parser_with_args()
5 | opt = parser.parse_args()
6 |
7 | def hybrid_loss(predictions, target):
8 | """Calculating the loss"""
9 | loss = 0
10 |
11 | # gamma=0, alpha=None --> CE
12 | focal = FocalLoss(gamma=0, alpha=None)
13 |
14 | for prediction in predictions:
15 |
16 | bce = focal(prediction, target)
17 | dice = dice_loss(prediction, target)
18 | loss += bce + dice
19 |
20 | return loss
21 |
22 |
--------------------------------------------------------------------------------
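
Usage note (not part of the repository): hybrid_loss iterates over `predictions`, so it expects an iterable of logit maps, e.g. the multiple outputs of a deeply supervised network. A minimal sketch, assuming FocalLoss and dice_loss in utils/metrics.py take logits of shape (N, C, H, W) and integer labels of shape (N, H, W), and that metadata.json is present so the module-level parser can be built:

    import torch
    from utils.losses import hybrid_loss

    # two dummy prediction heads (as produced with deep supervision) and a binary label map
    preds = [torch.randn(4, 2, 256, 256, requires_grad=True) for _ in range(2)]
    target = torch.randint(0, 2, (4, 256, 256))

    loss = hybrid_loss(preds, target)  # focal (gamma=0, i.e. plain CE) + dice, summed over all heads
    loss.backward()
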
/misc/pyutils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import glob
5 |
6 |
7 | def seed_random(seed=2020):
8 |     # set the following random seeds so that data loading, random augmentation, etc. stay consistent across runs
9 | random.seed(seed)
10 | os.environ['PYTHONHASHSEED'] = str(seed)
11 | np.random.seed(seed)
12 |
13 |
14 | def mkdir(path):
15 | """create a single empty directory if it didn't exist
16 |
17 | Parameters:
18 | path (str) -- a single directory path
19 | """
20 | if not os.path.exists(path):
21 | os.makedirs(path)
22 |
23 |
24 | def get_paths(image_folder_path, suffix='*.png'):
25 |     """Return the paths of files with the given suffix from a folder.
26 | :param image_folder_path: str
27 | :param suffix: str
28 | :return: list
29 | """
30 | paths = sorted(glob.glob(os.path.join(image_folder_path, suffix)))
31 | return paths
32 |
33 |
34 | def get_paths_from_list(image_folder_path, list):
35 |     """Find the files named in `list` under the image folder and return the sorted path list."""
36 | out = []
37 | for item in list:
38 | path = os.path.join(image_folder_path,item)
39 | out.append(path)
40 | return sorted(out)
41 |
42 |
43 |
--------------------------------------------------------------------------------
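
A tiny usage sketch for the helpers above (the dataset path is illustrative, taken from the roots in data_config.py):

    from misc.pyutils import seed_random, mkdir, get_paths

    seed_random(2020)        # fixes the python and numpy RNGs (torch is not seeded here)
    mkdir('./vis/demo')      # created only if it does not exist yet
    label_paths = get_paths('E:/CDDataset/LEVIR/test/label', suffix='*.png')
    print(len(label_paths), 'label images found')
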
/utils/parser.py:
--------------------------------------------------------------------------------
1 | import argparse as ag
2 | import json
3 |
4 | def get_parser_with_args(metadata_json='metadata.json'):
5 | parser = ag.ArgumentParser(description='Training change detection network')
6 |
7 | with open(metadata_json, 'r') as fin:
8 | metadata = json.load(fin)
9 | parser.set_defaults(**metadata)
10 |
11 | #project save
12 |     parser.add_argument('--project_name', default='LEVIR+_ViTAE_BIT_bce_100', type=str)
13 | parser.add_argument('--path', default='checkpoints', type=str, help='path of saved model')
14 | # parser.add_argument('--checkpoint_root', default='checkpoints', type=str)
15 |
16 | #network
17 | parser.add_argument('--backbone', default='vitae', type=str, choices=['resnet','swin','vitae'], help='type of model')
18 |
19 | parser.add_argument('--dataset', default='levir+', type=str, choices=['cdd','levir','levir+'], help='type of dataset')
20 |
21 | parser.add_argument('--mode', default='rsp_100', type=str, choices=['imp','rsp_40', 'rsp_100', 'rsp_120' , 'rsp_300', 'rsp_300_sgd', 'seco'], help='type of pretrn')
22 |
23 |
24 | return parser, metadata
25 |
26 |
--------------------------------------------------------------------------------
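
For reference, get_parser_with_args first loads metadata.json (not included in this dump) and registers its top-level keys as parser defaults via parser.set_defaults(**metadata), then adds the explicit arguments above. A minimal consumption sketch:

    from utils.parser import get_parser_with_args

    # metadata.json must exist in the working directory; its keys become parser defaults
    parser, metadata = get_parser_with_args()
    opt = parser.parse_args([])                 # no CLI overrides
    print(opt.backbone, opt.dataset, opt.mode)  # defaults above: 'vitae', 'levir+', 'rsp_100'
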
/data_config.py:
--------------------------------------------------------------------------------
1 |
2 | class DataConfig:
3 | data_name = ""
4 | root_dir = ""
5 | label_transform = "norm"
6 | def get_data_config(self, data_name):
7 | self.data_name = data_name
8 | if data_name == 'LEVIR':
9 | self.root_dir = 'E:/CDDataset/LEVIR'
10 | elif data_name == 'DSIFN':
11 | self.label_transform = "norm"
12 | self.root_dir = 'E:/CDDataset/DSIFN_256'
13 | elif data_name == 'SYSU-CD':
14 | self.label_transform = "norm"
15 | self.root_dir = 'E:/CDDataset/SYSU-CD'
16 | elif data_name == 'LEVIR+':
17 | self.label_transform = "norm"
18 | self.root_dir = 'E:/CDDataset/LEVIR-CD+_256'
19 | elif data_name == 'BBCD':
20 | self.label_transform = "norm"
21 | self.root_dir = 'E:/CDDataset/Big_Building_ChangeDetection'
22 | elif data_name == 'GZ_CD':
23 | self.label_transform = "norm"
24 | self.root_dir = 'E:/CDDataset/GZ'
25 | elif data_name == 'WHU-CD':
26 | self.label_transform = "norm"
27 | self.root_dir = 'E:/CDDataset/WHU-CUT'
28 | elif data_name == 'test':
29 | self.label_transform = "norm"
30 | self.root_dir = 'E:/CDDataset/att_test_whu'
31 | elif data_name == 'quick_start':
32 | self.root_dir = './samples/'
33 | else:
34 |             raise TypeError('%s has not been defined' % data_name)
35 | return self
36 |
37 |
38 | if __name__ == '__main__':
39 | data = DataConfig().get_data_config(data_name='LEVIR')
40 | print(data.data_name)
41 | print(data.root_dir)
42 | print(data.label_transform)
43 |
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SEIFNet
2 |
3 | Y. Huang, X. Li, Z. Du, and H. Shen, “Spatiotemporal Enhancement and Interlevel Fusion Network for Remote Sensing Images Change Detection,” IEEE Transactions on Geoscience and Remote Sensing, vol. 61, 2024.
4 |
5 | Abstract:
6 | Remote sensing (RS) image change detection (CD) plays a crucial role in monitoring surface dynamics. However, current deep learning (DL)-based CD methods still suffer from pseudo changes and scale variations due to inadequate exploration of temporal differences and under-utilization of multiscale features. Based on these considerations, a spatiotemporal enhancement and interlevel fusion network (SEIFNet) is proposed to improve the feature representation of changing objects. First, multilevel feature maps are acquired from a Siamese hierarchical backbone. To highlight the disparity at the same location at different times, spatiotemporal difference enhancement modules (ST-DEM) are introduced to capture global and local information from the bitemporal feature maps at each level; coordinate attention and cascaded convolutions are adopted in the subtraction and connection branches, respectively. Then, an adaptive context fusion module (ACFM) is designed to integrate interlevel features under the guidance of different semantic information, constituting a progressive decoder. Additionally, a plain refinement module and a concise summation-based prediction head are employed to enhance the boundary details and internal integrity of the CD results. Experimental results validate the superiority of our lightweight network over eight state-of-the-art (SOTA) methods on the LEVIR-CD, SYSU-CD and WHU-CD datasets, in terms of both accuracy and efficiency. The effects of different types of backbones and difference enhancement modules are also discussed in detail in the ablation experiments.
7 |
8 | 
9 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from utils.parser import get_parser_with_args
3 | from utils.helpers import get_test_loaders
4 | from tqdm import tqdm
5 | from sklearn.metrics import confusion_matrix
6 |
7 | # The evaluation method in our paper is slightly different from this file.
8 | # In the paper we use the evaluation method in train.py; specifically, the batch size is taken into account.
9 | # The evaluation method in this file usually produces higher numerical indicators.
10 |
11 | parser, metadata = get_parser_with_args()
12 | opt = parser.parse_args()
13 |
14 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15 |
16 | if opt.dataset == 'cdd':
17 | opt.dataset_dir = '../Dataset/cdd_dataset/'
18 | elif opt.dataset == 'levir':
19 | opt.dataset_dir = '../Dataset/levir_patch_dataset/'
20 |
21 | test_loader = get_test_loaders(opt)
22 |
23 | #path = 'weights/sunet-32.pt' # the path of the model
24 | model = torch.load(opt.path, map_location='cpu')
25 |
26 | model.to(dev)
27 |
28 | c_matrix = {'tn': 0, 'fp': 0, 'fn': 0, 'tp': 0}
29 | model.eval()
30 |
31 | with torch.no_grad():
32 | tbar = tqdm(test_loader)
33 | for batch_img1, batch_img2, labels, fname in tbar:
34 |
35 | batch_img1 = batch_img1.float().to(dev)
36 | batch_img2 = batch_img2.float().to(dev)
37 | labels = labels.long().to(dev)
38 |
39 | cd_preds = model(batch_img1, batch_img2)
40 |         #cd_preds = cd_preds[-1]  # the BIT output is not a tuple
41 | _, cd_preds = torch.max(cd_preds, 1)
42 |
43 | #print(cd_preds.shape, labels.shape, fname)
44 |
45 | tn, fp, fn, tp = confusion_matrix(labels.data.cpu().numpy().flatten(),
46 | cd_preds.data.cpu().numpy().flatten(),labels=[0,1]).ravel()
47 |
48 | c_matrix['tn'] += tn
49 | c_matrix['fp'] += fp
50 | c_matrix['fn'] += fn
51 | c_matrix['tp'] += tp
52 |
53 | tn, fp, fn, tp = c_matrix['tn'], c_matrix['fp'], c_matrix['fn'], c_matrix['tp']
54 | P = tp / (tp + fp)
55 | R = tp / (tp + fn)
56 | F1 = 2 * P * R / (R + P)
57 |
58 | print('Precision: {}\nRecall: {}\nF1-Score: {}'.format(P, R, F1))
59 |
--------------------------------------------------------------------------------
/misc/logger_tool.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 |
4 |
5 | class Logger(object):
6 | def __init__(self, outfile):
7 | self.terminal = sys.stdout
8 | self.log_path = outfile
9 | now = time.strftime("%c")
10 | self.write('================ (%s) ================\n' % now)
11 |
12 | def write(self, message):
13 | self.terminal.write(message)
14 | with open(self.log_path, mode='a') as f:
15 | f.write(message)
16 |
17 | def write_dict(self, dict):
18 | message = ''
19 | for k, v in dict.items():
20 | message += '%s: %.7f ' % (k, v)
21 | self.write(message)
22 |
23 | def write_dict_str(self, dict):
24 | message = ''
25 | for k, v in dict.items():
26 | message += '%s: %s ' % (k, v)
27 | self.write(message)
28 |
29 | def flush(self):
30 | self.terminal.flush()
31 |
32 |
33 | class Timer:
34 | def __init__(self, starting_msg = None):
35 | self.start = time.time()
36 | self.stage_start = self.start
37 |
38 | if starting_msg is not None:
39 | print(starting_msg, time.ctime(time.time()))
40 |
41 | def __enter__(self):
42 | return self
43 |
44 | def __exit__(self, exc_type, exc_val, exc_tb):
45 | return
46 |
47 | def update_progress(self, progress):
48 | self.elapsed = time.time() - self.start
49 | self.est_total = self.elapsed / progress
50 | self.est_remaining = self.est_total - self.elapsed
51 | self.est_finish = int(self.start + self.est_total)
52 |
53 |
54 | def str_estimated_complete(self):
55 | return str(time.ctime(self.est_finish))
56 |
57 | def str_estimated_remaining(self):
58 | return str(self.est_remaining/3600) + 'h'
59 |
60 | def estimated_remaining(self):
61 | return self.est_remaining/3600
62 |
63 | def get_stage_elapsed(self):
64 | return time.time() - self.stage_start
65 |
66 | def reset_stage(self):
67 | self.stage_start = time.time()
68 |
69 | def lapse(self):
70 | out = time.time() - self.stage_start
71 | self.stage_start = time.time()
72 | return out
73 |
74 |
--------------------------------------------------------------------------------
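
A minimal usage sketch (not part of the repository) for Logger and Timer: the logger tees every message to stdout and to a log file, and Timer.update_progress derives the remaining-time estimate from the fraction of work completed. The output filename is illustrative.

    from misc.logger_tool import Logger, Timer

    logger = Logger('run_log.txt')   # writes a timestamped header on creation
    timer = Timer('start training')
    total_steps = 100
    for step in range(1, total_steps + 1):
        # ... one training / validation step ...
        timer.update_progress(step / total_steps)
        logger.write('step %d/%d, est. remaining %s\n'
                     % (step, total_steps, timer.str_estimated_remaining()))
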
/compare/resbase.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: Hang Zhang
3 | # Email: zhang.hang@rutgers.edu
4 | # Copyright (c) 2017
5 | ###########################################################################
6 |
7 | import torch.nn as nn
8 |
9 |
10 | import resnet
11 |
12 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
13 |
14 | __all__ = ['BaseNet']
15 |
16 | class BaseNet(nn.Module):
17 | def __init__(self, nclass, backbone, dilated=True, norm_layer=None,root='./pretrain_models',
18 | multi_grid=False, multi_dilation=None):
19 | super(BaseNet, self).__init__()
20 |
21 | # copying modules from pretrained models
22 | if backbone == 'resnet34':
23 | self.pretrained = resnet.resnet34(pretrained=False, dilated=dilated,
24 | norm_layer=norm_layer, root=root,
25 | multi_grid=multi_grid, multi_dilation=multi_dilation)
26 | elif backbone == 'resnet50':
27 | self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated,
28 | norm_layer=norm_layer, root=root,
29 | multi_grid=multi_grid, multi_dilation=multi_dilation)
30 | elif backbone == 'resnet101':
31 | self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated,
32 | norm_layer=norm_layer, root=root,
33 | multi_grid=multi_grid,multi_dilation=multi_dilation)
34 | elif backbone == 'resnet152':
35 | self.pretrained = resnet.resnet152(pretrained=False, dilated=dilated,
36 | norm_layer=norm_layer, root=root,
37 | multi_grid=multi_grid, multi_dilation=multi_dilation)
38 | else:
39 | raise RuntimeError('unknown backbone: {}'.format(backbone))
40 | # bilinear upsample options
41 | self._up_kwargs = up_kwargs
42 |
43 | def base_forward(self, x):
44 | x = self.pretrained.conv1(x)
45 | x = self.pretrained.bn1(x)
46 | x = self.pretrained.relu(x)
47 | x = self.pretrained.maxpool(x)
48 | c1 = self.pretrained.layer1(x)
49 | c2 = self.pretrained.layer2(c1)
50 | c3 = self.pretrained.layer3(c2)
51 | c4 = self.pretrained.layer4(c3)
52 | return c1, c2, c3, c4
53 |
--------------------------------------------------------------------------------
/utils_.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from torchvision import utils
5 |
6 | import data_config
7 | from datasets.CD_dataset import CDDataset
8 | import argparse
9 |
10 | def get_loader(data_name, img_size=256, batch_size=8, split='test',
11 | is_train=False, dataset='CDDataset'):
12 | dataConfig = data_config.DataConfig().get_data_config(data_name)
13 | root_dir = dataConfig.root_dir
14 | label_transform = dataConfig.label_transform
15 |
16 | if dataset == 'CDDataset':
17 | data_set = CDDataset(root_dir=root_dir, split=split,
18 | img_size=img_size, is_train=is_train,
19 | label_transform=label_transform)
20 | else:
21 | raise NotImplementedError(
22 | 'Wrong dataset name %s (choose one from [CDDataset])'
23 | % dataset)
24 |
25 | shuffle = is_train
26 | dataloader = DataLoader(data_set, batch_size=batch_size,
27 | shuffle=shuffle, num_workers=4)
28 |
29 |
30 | return dataloader
31 |
32 |
33 | def get_loaders(args):
34 |
35 | data_name = args.data_name
36 | dataConfig = data_config.DataConfig().get_data_config(data_name)
37 | root_dir = dataConfig.root_dir
38 | label_transform = dataConfig.label_transform
39 | split = args.split
40 | split_val = 'val'
41 | if hasattr(args, 'split_val'):
42 | split_val = args.split_val
43 | if args.dataset == 'CDDataset':
44 | training_set = CDDataset(root_dir=root_dir, split=split,
45 | img_size=args.img_size,is_train=True,
46 | label_transform=label_transform)
47 | val_set = CDDataset(root_dir=root_dir, split=split_val,
48 | img_size=args.img_size,is_train=False,
49 | label_transform=label_transform)
50 | else:
51 | raise NotImplementedError(
52 | 'Wrong dataset name %s (choose one from [CDDataset,])'
53 | % args.dataset)
54 |
55 | datasets = {'train': training_set, 'val': val_set}
56 | dataloaders = {x: DataLoader(datasets[x], batch_size=args.batch_size,
57 | shuffle=True, num_workers=args.num_workers)
58 | for x in ['train', 'val']}
59 |
60 | return dataloaders
61 |
62 |
63 | def make_numpy_grid(tensor_data, pad_value=0,padding=0):
64 | tensor_data = tensor_data.detach()
65 | vis = utils.make_grid(tensor_data, pad_value=pad_value,padding=padding)
66 | vis = np.array(vis.cpu()).transpose((1,2,0))
67 | if vis.shape[2] == 1:
68 |         vis = np.concatenate([vis, vis, vis], axis=-1)  # expand (H, W, 1) to (H, W, 3)
69 | return vis
70 |
71 |
72 | def de_norm(tensor_data):
73 | return tensor_data * 0.5 + 0.5
74 |
75 |
76 | def get_device(args):
77 | # set gpu ids
78 | str_ids = args.gpu_ids.split(',')
79 | args.gpu_ids = []
80 | for str_id in str_ids:
81 | id = int(str_id)
82 | if id >= 0:
83 | args.gpu_ids.append(id)
84 | if len(args.gpu_ids) > 0:
85 | torch.cuda.set_device(args.gpu_ids[0])
86 |
87 | def str2bool(v):
88 |     if v.lower() in ['true', '1']:
89 |         return True
90 |     elif v.lower() in ['false', '0']:
91 | return False
92 | else:
93 | raise argparse.ArgumentTypeError('Boolean value expected.')
--------------------------------------------------------------------------------
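
A minimal sketch of how the helpers above fit together: get_loader builds a test-split CDDataset DataLoader from the matching data_config entry, and de_norm plus make_numpy_grid turn a normalized image batch back into a plottable H x W x 3 array. This assumes the dataset normalizes images with mean 0.5 and std 0.5 (which is what de_norm undoes) and that matplotlib is available:

    import matplotlib.pyplot as plt
    from utils_ import get_loader, make_numpy_grid, de_norm

    # 'LEVIR' must map to an existing root_dir in data_config.py
    loader = get_loader('LEVIR', img_size=256, batch_size=4, split='test', is_train=False)

    batch = next(iter(loader))                    # dict with keys 'name', 'A', 'B', 'L'
    vis_A = make_numpy_grid(de_norm(batch['A']))  # grid of the pre-change images as (H, W, 3)
    plt.imshow(vis_A)
    plt.axis('off')
    plt.show()
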
/eval_cd.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from models.evaluator import *
4 | from utils_ import str2bool
5 | print(torch.cuda.is_available())
6 |
7 |
8 | """
9 | eval the CD model
10 | """
11 |
12 | def main():
13 | # ------------
14 | # args
15 | # ------------
16 | parser = ArgumentParser()
17 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
18 | parser.add_argument('--project_name', default='LEVIR_SEIFNet_ce_Adamw_0.0001_150', type=str)
19 | parser.add_argument('--print_models', default=False, type=bool, help='print models')
20 | parser.add_argument('--checkpoints_root', default='checkpoints', type=str)
21 | parser.add_argument('--vis_root', default='vis', type=str)
22 |
23 | # data
24 | parser.add_argument('--num_workers', default=8, type=int)
25 | parser.add_argument('--dataset', default='CDDataset', type=str)
26 | parser.add_argument('--data_name', default='LEVIR', type=str)
27 |
28 | parser.add_argument('--batch_size', default=1, type=int)
29 | parser.add_argument('--split', default="test", type=str)
30 |
31 | parser.add_argument('--img_size', default=256, type=int)
32 |
33 | # model
34 | parser.add_argument('--n_class', default=2, type=int)
35 | parser.add_argument('--embed_dim', default=64, type=int)
36 | parser.add_argument('--net_G', default='SEIFNet', type=str,
37 | help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|'
38 | 'vitae_transformer|SEIFNet')
39 | parser.add_argument('--backbone', default='L-Backbone', type=str, choices=['resnet', 'swin', 'vitae'],
40 | help='type of model')
41 | parser.add_argument('--mode', default='None', type=str,
42 | choices=['imp', 'rsp_40', 'rsp_100', 'rsp_120', 'rsp_300', 'rsp_300_sgd', 'seco'],
43 | help='type of pretrn')
44 |     parser.add_argument('--deep_supervision', default=False, type=str2bool)  # True for UNet++ and A2Net, False otherwise
45 | # parser.add_argument('--loss1', default='Focal', type=str,help='ce|BL_ce|Focal|Focal_Dice|Focal_Dice_BL|Focal_BL|BL_Focal|Focal_BF_IOU')
46 | # parser.add_argument('--loss2', default='BL_Focal', type=str,
47 | # help='ce|BL_ce|Focal|Focal_Dice|Focal_Dice_BL|Focal_BL|BL_Focal|Focal_BF_IOU')
48 |     parser.add_argument('--loss_SD', default=True, type=str2bool)  # True only for CD_Net
49 |
50 | parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str)
51 |
52 | args = parser.parse_args()
53 | utils_.get_device(args)
54 | print(args.gpu_ids)
55 |
56 | # checkpoints dir
57 | args.checkpoint_dir = os.path.join(args.checkpoints_root, args.project_name)
58 | os.makedirs(args.checkpoint_dir, exist_ok=True)
59 | # visualize dir
60 | args.vis_dir = os.path.join(args.vis_root, args.project_name)
61 | os.makedirs(args.vis_dir, exist_ok=True)
62 |
63 | dataloader = utils_.get_loader(args.data_name, img_size=args.img_size,
64 | batch_size=args.batch_size, is_train=False,
65 | split=args.split)
66 | model = CDEvaluator(args=args, dataloader=dataloader)
67 |
68 | model.eval_models(args=args,checkpoint_name=args.checkpoint_name)
69 |
70 |
71 | if __name__ == '__main__':
72 | main()
73 |
74 |
--------------------------------------------------------------------------------
/main_train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from models.trainer import *
4 | from utils_ import str2bool
5 | print(torch.cuda.is_available())
6 |
7 | """
8 | the main function for training the CD networks
9 | """
10 |
11 |
12 | def train(args):
13 | dataloaders = utils_.get_loaders(args)
14 | model = CDTrainer(args=args, dataloaders=dataloaders)
15 | model.train_models(args=args)
16 | # model.train_models()
17 |
18 |
19 | def test(args):
20 | from models.evaluator import CDEvaluator
21 | dataloader = utils_.get_loader(args.data_name, img_size=args.img_size,
22 | batch_size=args.batch_size, is_train=False,
23 | split='test')
24 | model = CDEvaluator(args=args, dataloader=dataloader)
25 |
26 | model.eval_models(args)
27 |
28 |
29 | if __name__ == '__main__':
30 | # ------------
31 | # args
32 | # ------------
33 | parser = ArgumentParser()
34 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
35 | parser.add_argument('--project_name', default='LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200', type=str)
36 | #SYSU_res18_coMDE2_AFF_Bcedice_l0_Adamw_0.0001_200
37 | #LEVIR_transformer_CoDEM_AFF_ce_Adamw_0.0001_200_2
38 | parser.add_argument('--checkpoint_root', default='checkpoints', type=str)
39 |
40 | # data
41 | parser.add_argument('--num_workers', default=4, type=int)
42 | parser.add_argument('--dataset', default='CDDataset', type=str)
43 |     parser.add_argument('--data_name', default='LEVIR', type=str, help='ChangeDetection|LEVIR|DSIFN|SYSU-CD|LEVIR+|BBCD|WHU-CD')
44 |
45 | parser.add_argument('--batch_size', default=2, type=int)
46 | parser.add_argument('--split', default="train", type=str)
47 | parser.add_argument('--split_val', default="val", type=str)
48 |
49 | parser.add_argument('--img_size', default=256, type=int)
50 |
51 | # model
52 | parser.add_argument('--n_class', default=2, type=int)
53 | parser.add_argument('--embed_dim', default=64, type=int)
54 | parser.add_argument('--net_G', default='SEIFNet', type=str,
55 | help='FC_EF | FC_Siam_conc | '
56 | 'FC_Siam_diff | UNet++|SNUNet|'
57 | 'DTCDSCN|IFNet|'
58 | 'base_transformer_pos_s4_dd8_dedim8|'
59 | 'ChangeFormer|'
60 | 'A2Net|DMINet|TFI-GR|'
61 | 'SEIFNet')
62 | parser.add_argument('--backbone', default='L-Backbone', type=str, choices=['resnet', 'swin', 'vitae','L-Backbone-cross','L-Backbone','BiFormer'],
63 | help='type of model')
64 | parser.add_argument('--mode', default='None', type=str,
65 | choices=['imp','res18 ','rsp_40', 'rsp_100', 'rsp_120', 'rsp_300', 'rsp_300_sgd', 'seco','None'],
66 | help='type of pretrn')
67 |     parser.add_argument('--deep_supervision', default=False, type=str2bool)  # True for UNet++ and A2Net, False otherwise
68 |
69 |     parser.add_argument('--loss_SD', default=False, type=str2bool)  # True only for IFNet and DMINet
70 | # optimizer
71 | parser.add_argument('--optimizer', default='adamw', type=str)
72 | parser.add_argument('--lr', default=0.0001, type=float)
73 | parser.add_argument('--max_epochs', default=200, type=int) #150
74 | parser.add_argument('--lr_policy', default='linear', type=str,
75 | help='linear | step')
76 | parser.add_argument('--lr_decay_iters', default=200, type=int)
77 |
78 | args = parser.parse_args()
79 | utils_.get_device(args)
80 | print(args.gpu_ids)
81 |
82 | # checkpoints dir
83 | args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name)
84 | os.makedirs(args.checkpoint_dir, exist_ok=True)
85 | # visualize dir
86 | args.vis_dir = os.path.join('vis', args.project_name)
87 | os.makedirs(args.vis_dir, exist_ok=True)
88 |
89 | train(args)
90 |
91 | test(args)
92 |
--------------------------------------------------------------------------------
/checkpoints/LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200/log.txt:
--------------------------------------------------------------------------------
1 | ================ (Tue Jan 23 17:10:54 2024) ================
2 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 8 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000
3 | ================ (Tue Jan 23 17:15:15 2024) ================
4 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 8 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000
5 | ================ (Tue Jan 23 17:16:10 2024) ================
6 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 8 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000
7 | ================ (Tue Jan 23 17:16:37 2024) ================
8 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000
9 | ================ (Tue Jan 23 17:17:16 2024) ================
10 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000
11 | ================ (Tue Jan 23 17:21:21 2024) ================
12 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000
13 | ================ (Tue Jan 23 17:33:05 2024) ================
14 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000
15 | Is_training: True. [0,199][1,3560], imps: 0.83, est: 955.75h, G_loss: 0.60956, running_mf1: 0.47574
16 |
--------------------------------------------------------------------------------
/compare/MobileNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | # from torchvision.models.utils import load_state_dict_from_url
3 | from torch.hub import load_state_dict_from_url
4 | model_urls = {
5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
6 | }
7 |
8 |
9 | class ConvBNReLU(nn.Sequential):
10 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1):
11 | padding = (kernel_size - 1) // 2
12 | if dilation != 1:
13 | padding = dilation
14 | super(ConvBNReLU, self).__init__(
15 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation,
16 | bias=False),
17 | nn.BatchNorm2d(out_planes),
18 | nn.ReLU6(inplace=True)
19 | )
20 |
21 |
22 | class InvertedResidual(nn.Module):
23 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
24 | super(InvertedResidual, self).__init__()
25 | self.stride = stride
26 | assert stride in [1, 2]
27 |
28 | hidden_dim = int(round(inp * expand_ratio))
29 | self.use_res_connect = self.stride == 1 and inp == oup
30 |
31 | layers = []
32 | if expand_ratio != 1:
33 | # pw
34 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
35 | layers.extend([
36 | # dw
37 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation),
38 | # pw-linear
39 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
40 | nn.BatchNorm2d(oup),
41 | ])
42 | self.conv = nn.Sequential(*layers)
43 |
44 | def forward(self, x):
45 | if self.use_res_connect:
46 | return x + self.conv(x)
47 | else:
48 | return self.conv(x)
49 |
50 |
51 | class MobileNetV2(nn.Module):
52 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0):
53 | super(MobileNetV2, self).__init__()
54 | block = InvertedResidual
55 | input_channel = 32
56 | last_channel = 1280
57 | inverted_residual_setting = [
58 | # t, c, n, s, d
59 | [1, 16, 1, 1, 1],
60 | [6, 24, 2, 2, 1],
61 | [6, 32, 3, 2, 1],
62 | [6, 64, 4, 2, 1],
63 | [6, 96, 3, 1, 1],
64 | [6, 160, 3, 2, 1],
65 | [6, 320, 1, 1, 1],
66 | ]
67 |
68 | # building first layer
69 | input_channel = int(input_channel * width_mult)
70 | self.last_channel = int(last_channel * max(1.0, width_mult))
71 | features = [ConvBNReLU(3, input_channel, stride=2)]
72 | # building inverted residual blocks
73 | for t, c, n, s, d in inverted_residual_setting:
74 | output_channel = int(c * width_mult)
75 | for i in range(n):
76 | stride = s if i == 0 else 1
77 | dilation = d if i == 0 else 1
78 |                 features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=dilation))  # use the per-block dilation computed above
79 | input_channel = output_channel
80 | # building last several layers
81 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
82 | # make it nn.Sequential
83 | self.features = nn.Sequential(*features)
84 |
85 | # weight initialization
86 | for m in self.modules():
87 | if isinstance(m, nn.Conv2d):
88 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
89 | if m.bias is not None:
90 | nn.init.zeros_(m.bias)
91 | elif isinstance(m, nn.BatchNorm2d):
92 | nn.init.ones_(m.weight)
93 | nn.init.zeros_(m.bias)
94 | elif isinstance(m, nn.Linear):
95 | nn.init.normal_(m.weight, 0, 0.01)
96 | nn.init.zeros_(m.bias)
97 |
98 | def forward(self, x):
99 | res = []
100 | for idx, m in enumerate(self.features):
101 | x = m(x)
102 | if idx in [1, 3, 6, 13, 17]:
103 | res.append(x)
104 | return res
105 |
106 |
107 | def mobilenet_v2(pretrained=True, progress=True, **kwargs):
108 | model = MobileNetV2(**kwargs)
109 | if pretrained:
110 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
111 | progress=progress)
112 | print("loading imagenet pretrained mobilenetv2")
113 | model.load_state_dict(state_dict, strict=False)
114 | print("loaded imagenet pretrained mobilenetv2")
115 | return model
116 |
--------------------------------------------------------------------------------
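
The forward pass above returns the intermediate feature maps collected after blocks 1, 3, 6, 13 and 17 instead of classification logits, i.e. the multi-scale output a change-detection encoder needs. A quick sanity-check sketch (randomly initialized, so no weight download; it assumes the compare directory is importable, e.g. the script is run from inside it):

    import torch
    from MobileNet import mobilenet_v2

    net = mobilenet_v2(pretrained=False)
    x = torch.randn(1, 3, 256, 256)
    feats = net(x)                 # 5 feature maps at strides 2, 4, 8, 16 and 32
    for f in feats:
        print(tuple(f.shape))
    # (1, 16, 128, 128), (1, 24, 64, 64), (1, 32, 32, 32), (1, 96, 16, 16), (1, 320, 8, 8)
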
/utils/dataloaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | from PIL import Image
4 | from utils import transforms as tr
5 |
6 |
7 | '''
8 | Load all training and validation data paths
9 | '''
10 | def full_path_loader(data_dir):
11 | train_data = [i for i in os.listdir(data_dir + 'train/A/') if not
12 | i.startswith('.')]
13 | train_data.sort()
14 |
15 | valid_data = [i for i in os.listdir(data_dir + 'val/A/') if not
16 | i.startswith('.')]
17 | valid_data.sort()
18 |
19 | train_label_paths = []
20 | val_label_paths = []
21 | for img in train_data:
22 | train_label_paths.append(data_dir + 'train/label/' + img)
23 | for img in valid_data:
24 | val_label_paths.append(data_dir + 'val/label/' + img)
25 |
26 |
27 | train_data_path = []
28 | val_data_path = []
29 |
30 | for img in train_data:
31 | train_data_path.append([data_dir + 'train/', img])
32 | for img in valid_data:
33 | val_data_path.append([data_dir + 'val/', img])
34 |
35 | train_dataset = {}
36 | val_dataset = {}
37 | for cp in range(len(train_data)):
38 | train_dataset[cp] = {'image': train_data_path[cp],
39 | 'label': train_label_paths[cp]}
40 | for cp in range(len(valid_data)):
41 | val_dataset[cp] = {'image': val_data_path[cp],
42 | 'label': val_label_paths[cp]}
43 |
44 |
45 | return train_dataset, val_dataset
46 |
47 | '''
48 | Load all testing data paths
49 | '''
50 | def full_test_loader(data_dir):
51 |
52 | test_data = [i for i in os.listdir(data_dir + 'test/A/') if not
53 | i.startswith('.')]
54 | test_data.sort()
55 |
56 | test_label_paths = []
57 | for img in test_data:
58 | test_label_paths.append(data_dir + 'test/label/' + img)
59 |
60 | test_data_path = []
61 | for img in test_data:
62 | test_data_path.append([data_dir + 'test/', img])
63 |
64 | test_dataset = {}
65 | for cp in range(len(test_data)):
66 | test_dataset[cp] = {'image': test_data_path[cp],
67 | 'label': test_label_paths[cp]}
68 |
69 | return test_dataset
70 |
71 | def cdd_loader(img_path, label_path, aug):
72 | dir = img_path[0]
73 | name = img_path[1]
74 |
75 | img1 = Image.open(dir + 'A/' + name)
76 | img2 = Image.open(dir + 'B/' + name)
77 | label = Image.open(label_path)
78 | sample = {'image': (img1, img2), 'label': label}
79 |
80 | if aug:
81 | sample = tr.train_transforms(sample)
82 | else:
83 | sample = tr.test_transforms(sample)
84 |
85 | return sample['image'][0], sample['image'][1], sample['label'], name
86 |
87 |
88 | class CDDloader(data.Dataset):
89 |
90 | def __init__(self, full_load, flag = 'trn', aug=False):
91 |
92 | self.full_load = full_load
93 | self.loader = cdd_loader
94 | self.aug = aug
95 |
96 | print('load {} cdd {} pairs'.format(len(self.full_load), flag))
97 |
98 | def __getitem__(self, index):
99 |
100 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label']
101 |
102 | return self.loader(img_path,
103 | label_path,
104 | self.aug)
105 |
106 | def __len__(self):
107 | return len(self.full_load)
108 |
109 | class LEVIRloader(data.Dataset):
110 |
111 | def __init__(self, full_load, flag = 'trn', aug=False):
112 |
113 | self.full_load = full_load
114 | self.loader = cdd_loader
115 | self.aug = aug
116 |
117 | print('load {} levir {} pairs'.format(len(self.full_load), flag))
118 |
119 | def __getitem__(self, index):
120 |
121 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label']
122 |
123 | return self.loader(img_path,
124 | label_path,
125 | self.aug)
126 |
127 | def __len__(self):
128 | return len(self.full_load)
129 |
130 | class LEVIRplusloader(data.Dataset):
131 |
132 | def __init__(self, full_load, flag = 'trn', aug=False):
133 |
134 | self.full_load = full_load
135 | self.loader = cdd_loader
136 | self.aug = aug
137 |
138 | print('load {} levir {} pairs'.format(len(self.full_load), flag))
139 |
140 | def __getitem__(self, index):
141 |
142 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label']
143 |
144 | return self.loader(img_path,
145 | label_path,
146 | self.aug)
147 |
148 | def __len__(self):
149 | return len(self.full_load)
--------------------------------------------------------------------------------
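
A minimal sketch of wiring these loaders into PyTorch DataLoaders, assuming the CDD/LEVIR-style layout implied by full_path_loader and cdd_loader above (train, val and test folders each containing A, B and label with identical filenames); the dataset root matches the default used in eval.py:

    from torch.utils.data import DataLoader
    from utils.dataloaders import full_path_loader, CDDloader

    data_dir = '../Dataset/cdd_dataset/'   # note: must end with '/'
    train_full, val_full = full_path_loader(data_dir)

    train_set = CDDloader(train_full, flag='trn', aug=True)   # uses tr.train_transforms
    val_set = CDDloader(val_full, flag='val', aug=False)      # uses tr.test_transforms

    train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=8, shuffle=False, num_workers=4)

    img1, img2, label, name = next(iter(train_loader))        # each sample: (img_A, img_B, label, filename)
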
/datasets/CD_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Change detection dataset
3 | """
4 |
5 | import os
6 | from PIL import Image
7 | from PIL import Image,ImageFile
8 | ImageFile.LOAD_TRUNCATED_IMAGES = True
9 |
10 | import numpy as np
11 |
12 | from torch.utils import data
13 |
14 | from datasets.data_utils import CDDataAugmentation
15 |
16 |
17 | """
18 | CD data set with pixel-level labels;
19 | ├─image
20 | ├─image_post
21 | ├─label
22 | └─list
23 | """
24 | IMG_FOLDER_NAME = "A"
25 | IMG_POST_FOLDER_NAME = 'B'
26 | # IMG_SOBEL_FOLDER_NAME = 'diff'
27 | LIST_FOLDER_NAME = 'list'
28 | ANNOT_FOLDER_NAME = "label"
29 |
30 | IGNORE = 255
31 |
32 | label_suffix='.png' # jpg for gan dataset, others : png
33 |
34 | def load_img_name_list(dataset_path):
35 |     img_name_list = np.loadtxt(dataset_path, dtype=str)  # np.str is removed in newer NumPy; plain str is equivalent
36 | if img_name_list.ndim == 2:
37 | return img_name_list[:, 0]
38 | return img_name_list
39 |
40 |
41 | def load_image_label_list_from_npy(npy_path, img_name_list):
42 | cls_labels_dict = np.load(npy_path, allow_pickle=True).item()
43 | return [cls_labels_dict[img_name] for img_name in img_name_list]
44 |
45 |
46 | def get_img_post_path(root_dir,split,img_name):
47 | return os.path.join(root_dir,split, IMG_POST_FOLDER_NAME, img_name)
48 |
49 | def get_sobel_path(root_dir,split,img_name):
50 | return os.path.join(root_dir,split, IMG_SOBEL_FOLDER_NAME, img_name)
51 |
52 |
53 | def get_img_path(root_dir,split, img_name):
54 | return os.path.join(root_dir, split,IMG_FOLDER_NAME, img_name)
55 |
56 |
57 | def get_label_path(root_dir,split, img_name):
58 | return os.path.join(root_dir, split,ANNOT_FOLDER_NAME, img_name.replace('.png', label_suffix))
59 |
60 |
61 | class ImageDataset(data.Dataset):
62 |     """VOC-style data loader."""
63 | def __init__(self, root_dir, split='train', img_size=256, is_train=True,to_tensor=True):
64 | super(ImageDataset, self).__init__()
65 | self.root_dir = root_dir
66 | self.img_size = img_size
67 | self.split = split # train | train_aug++ | val
68 | # self.list_path = self.root_dir + '/' + LIST_FOLDER_NAME + '/' + self.list + '.txt'
69 | self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split+'.txt')
70 | self.img_name_list = load_img_name_list(self.list_path)
71 |
72 | self.A_size = len(self.img_name_list) # get the size of dataset A
73 | self.to_tensor = to_tensor
74 | if is_train:
75 | self.augm = CDDataAugmentation(
76 | img_size=self.img_size,
77 | with_random_hflip=True,
78 | with_random_vflip=True,
79 | with_scale_random_crop=True,
80 | with_random_blur=True,
81 | )
82 | else:
83 | self.augm = CDDataAugmentation(
84 | img_size=self.img_size
85 | )
86 | def __getitem__(self, index):
87 | name = self.img_name_list[index]
88 |         A_path = get_img_path(self.root_dir, self.split, self.img_name_list[index % self.A_size])  # path of the phase-1 (pre-change) image
89 | B_path = get_img_post_path(self.root_dir, self.split,self.img_name_list[index % self.A_size])
90 |
91 |         img = np.asarray(Image.open(A_path).convert('RGB'))    # phase-1 (pre-change) image
92 |         img_B = np.asarray(Image.open(B_path).convert('RGB'))  # phase-2 (post-change) image
93 |
94 | [img, img_B], _ = self.augm.transform([img, img_B],[], to_tensor=self.to_tensor)
95 |
96 | return {'A': img, 'B': img_B, 'name': name}
97 |
98 | def __len__(self):
99 | """Return the total number of images in the dataset."""
100 | return self.A_size
101 |
102 |
103 | class CDDataset(ImageDataset):
104 |
105 | def __init__(self, root_dir, img_size, split='train', is_train=True, label_transform=None,
106 | to_tensor=True):
107 | super(CDDataset, self).__init__(root_dir, img_size=img_size, split=split, is_train=is_train,
108 | to_tensor=to_tensor)
109 | self.label_transform = label_transform
110 |
111 | def __getitem__(self, index):
112 | name = self.img_name_list[index]
113 | A_path = get_img_path(self.root_dir,self.split,self.img_name_list[index % self.A_size])
114 | B_path = get_img_post_path(self.root_dir,self.split, self.img_name_list[index % self.A_size])
115 | img = np.asarray(Image.open(A_path).convert('RGB'))
116 | img_B = np.asarray(Image.open(B_path).convert('RGB'))
117 |         # Sobel difference feature map (disabled)
118 | # sobel_path = get_sobel_path(self.root_dir,self.split, self.img_name_list[index % self.A_size])
119 | # sobel = np.asarray(Image.open(sobel_path).convert('RGB'))
120 |
121 | L_path = get_label_path(self.root_dir, self.split,self.img_name_list[index % self.A_size])
122 |
123 | label = np.array(Image.open(L_path), dtype=np.uint8)
124 |
125 |         # in the binary case the foreground is annotated as 255
126 | if self.label_transform == 'norm':
127 | label = label // 255
128 | #
129 | # [img, img_B,sobel], [label] = self.augm.transform([img, img_B,sobel], [label], to_tensor=self.to_tensor)
130 | [img, img_B], [label] = self.augm.transform([img, img_B], [label], to_tensor=self.to_tensor)
131 | # print(label.max())
132 | # return {'name': name, 'A': img, 'B': img_B,'L':label, 'S': sobel}
133 | return {'name': name, 'A': img, 'B': img_B, 'L': label}
134 |
135 |
--------------------------------------------------------------------------------
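A minimal usage sketch for the dataset classes above (the root path and loader settings are placeholders, not taken from this repository):

from torch.utils.data import DataLoader
from datasets.CD_dataset import CDDataset

# hypothetical dataset root following the A/, B/, label/, list/ layout described above
dataset = CDDataset(root_dir='path/to/LEVIR-CD', img_size=256, split='train',
                    is_train=True, label_transform='norm', to_tensor=True)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

batch = next(iter(loader))
# batch['A'], batch['B']: (8, 3, 256, 256) normalized image tensors
# batch['L']: (8, 1, 256, 256) uint8 labels in {0, 1} (after the 'norm' transform)
# batch['name']: list of file names
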
/compare/NestedUNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class VGGBlock(nn.Module):
5 | def __init__(self, in_channels, middle_channels, out_channels):
6 | super().__init__()
7 | self.relu = nn.ReLU(inplace=True)
8 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
9 | self.bn1 = nn.BatchNorm2d(middle_channels)
10 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
11 | self.bn2 = nn.BatchNorm2d(out_channels)
12 |
13 | def forward(self, x):
14 | out = self.conv1(x)
15 | out = self.bn1(out)
16 | out = self.relu(out)
17 |
18 | out = self.conv2(out)
19 | out = self.bn2(out)
20 | out = self.relu(out)
21 |
22 | return out
23 |
24 | class standard_unit(nn.Module):
25 | def __init__(self,in_channels,middle_channels,out_channels):
26 | super().__init__()
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
29 | self.bn1 = nn.BatchNorm2d(middle_channels)
30 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
31 | self.bn2 = nn.BatchNorm2d(out_channels)
32 | # self.downsample = downsample
33 |
34 | def forward(self, x):
35 | out = self.conv1(x)
36 | identity=out
37 | # print(identity.size())#(8,32,256,256)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 | output = self.relu(out+identity)
44 | # print(output.size())
45 | # out = self.relu(output)
46 |
47 |
48 |
49 | # out=torch.add(out,out0)
50 |
51 | return output
52 |
53 | class NestedUNet(nn.Module):
54 | def __init__(self, num_classes, input_channels=3, deep_supervision=True, **kwargs):
55 | super().__init__()
56 |
57 | nb_filter = [32, 64, 128, 256, 512]
58 |
59 | self.deep_supervision = deep_supervision
60 |
61 | self.pool = nn.MaxPool2d(2, 2)
62 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
63 |
64 | self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
65 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
66 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
67 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
68 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
69 |
70 | self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
71 | self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
72 | self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
73 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
74 |
75 | self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
76 | self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
77 | self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
78 |
79 | self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
80 | self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
81 |
82 | self.conv0_4 = standard_unit(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
83 |
84 | if self.deep_supervision:
85 | self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
86 | self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
87 | self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
88 | self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
89 | else:
90 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
91 |
92 |
93 | def forward(self, input1,input2):
94 | input = torch.cat((input1, input2), dim=1)
95 | x0_0 = self.conv0_0(input)
96 |
97 | x1_0 = self.conv1_0(self.pool(x0_0))
98 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
99 |
100 |
101 | x2_0 = self.conv2_0(self.pool(x1_0))
102 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
103 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
104 |
105 | x3_0 = self.conv3_0(self.pool(x2_0))
106 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
107 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
108 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
109 |
110 | x4_0 = self.conv4_0(self.pool(x3_0))
111 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
112 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
113 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
114 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
115 |
116 | if self.deep_supervision:
117 | output1 = self.final1(x0_1)
118 |
119 | output2 = self.final2(x0_2)
120 | output3 = self.final3(x0_3)
121 | output4 = self.final4(x0_4)
122 | return [output1, output2, output3, output4]
123 |
124 | else:
125 | output = self.final(x0_4)
126 | return output
--------------------------------------------------------------------------------
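A quick shape check for the network above (a sketch; the two inputs are concatenated internally, so input_channels must be the combined channel count, e.g. 6 for a pair of RGB images):

import torch
from compare.NestedUNet import NestedUNet  # assumes compare/ is on the import path

net = NestedUNet(num_classes=2, input_channels=6, deep_supervision=True)
t1 = torch.randn(1, 3, 256, 256)
t2 = torch.randn(1, 3, 256, 256)
outs = net(t1, t2)                 # list of 4 tensors, one per deep-supervision head
print([o.shape for o in outs])     # each (1, 2, 256, 256)
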
/models/CBAM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class BasicConv(nn.Module):
7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
8 | super(BasicConv, self).__init__()
9 | self.out_channels = out_planes
10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
12 | self.relu = nn.ReLU() if relu else None
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | if self.bn is not None:
17 | x = self.bn(x)
18 | if self.relu is not None:
19 | x = self.relu(x)
20 | return x
21 |
22 | class Flatten(nn.Module):
23 | def forward(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | class ChannelGate(nn.Module):
27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28 | super(ChannelGate, self).__init__()
29 | self.gate_channels = gate_channels
30 | self.mlp = nn.Sequential(
31 | Flatten(),
32 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
33 | nn.ReLU(),
34 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
35 | )
36 | self.pool_types = pool_types
37 | def forward(self, x):
38 | channel_att_sum = None
39 | for pool_type in self.pool_types:
40 | if pool_type=='avg':
41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 | channel_att_raw = self.mlp( avg_pool )
43 | elif pool_type=='max':
44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45 | channel_att_raw = self.mlp( max_pool )
46 | elif pool_type=='lp':
47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48 | channel_att_raw = self.mlp( lp_pool )
49 | elif pool_type=='lse':
50 | # LSE pool only
51 | lse_pool = logsumexp_2d(x)
52 | channel_att_raw = self.mlp( lse_pool )
53 |
54 | if channel_att_sum is None:
55 | channel_att_sum = channel_att_raw
56 | else:
57 | channel_att_sum = channel_att_sum + channel_att_raw
58 |
59 |         scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
60 |         return x * scale  # apply the channel-attention weights to the feature map
61 |
62 | def logsumexp_2d(tensor):
63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 | return outputs
67 |
68 | class ChannelPool(nn.Module):
69 | def forward(self, x):
70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71 |
72 | class SpatialGate(nn.Module):
73 | def __init__(self):
74 | super(SpatialGate, self).__init__()
75 | kernel_size = 7
76 | self.compress = ChannelPool()
77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78 | def forward(self, x):
79 | x_compress = self.compress(x)
80 | x_out = self.spatial(x_compress)
81 |         scale = torch.sigmoid(x_out)  # broadcasting
82 | return x * scale
83 |
84 | # class CBAM(nn.Module):
85 | # def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86 | # super(CBAM, self).__init__()
87 | # self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
88 | # self.no_spatial=no_spatial
89 | # if not no_spatial:
90 | # self.SpatialGate = SpatialGate()
91 | # def forward(self, x):
92 | # x_out = self.ChannelGate(x)
93 | # if not self.no_spatial:
94 | # x_out = self.SpatialGate(x_out)
95 | # return x_out
96 | #
97 | class ChannelAttentionModule(nn.Module):
98 | def __init__(self, channel, ratio=16):
99 | super(ChannelAttentionModule, self).__init__()
100 |         # adaptive pooling shrinks the spatial map to 1x1 while keeping the channel count
101 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
102 | self.max_pool = nn.AdaptiveMaxPool2d(1)
103 |
104 | self.shared_MLP = nn.Sequential(
105 | nn.Conv2d(channel, channel // ratio, 1, bias=False),
106 | nn.ReLU(),
107 | nn.Conv2d(channel // ratio, channel, 1, bias=False)
108 | )
109 | self.sigmoid = nn.Sigmoid()
110 |
111 | def forward(self, x):
112 | avgout = self.shared_MLP(self.avg_pool(x))
113 | maxout = self.shared_MLP(self.max_pool(x))
114 | return self.sigmoid(avgout + maxout)
115 |
116 |
117 | class SpatialAttentionModule(nn.Module):
118 | def __init__(self):
119 | super(SpatialAttentionModule, self).__init__()
120 | self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
121 | self.sigmoid = nn.Sigmoid()
122 |
123 | def forward(self, x):
124 |         # spatial size is kept; channels are reduced to avg/max statistics
125 | avgout = torch.mean(x, dim=1, keepdim=True)
126 | maxout, _ = torch.max(x, dim=1, keepdim=True)
127 | out = torch.cat([avgout, maxout], dim=1)
128 | out = self.sigmoid(self.conv2d(out))
129 | return out
130 |
131 |
132 | class CBAM(nn.Module):
133 | def __init__(self, channel):
134 | super(CBAM, self).__init__()
135 | self.channel_attention = ChannelAttentionModule(channel)
136 | self.spatial_attention = SpatialAttentionModule()
137 |
138 | def forward(self, x):
139 | out = self.channel_attention(x) * x
140 | out = self.spatial_attention(out) * out
141 | return out
--------------------------------------------------------------------------------
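A short sanity check for the CBAM block above (any 4-D feature map works; the shapes here are arbitrary):

import torch
from models.CBAM import CBAM

feat = torch.randn(2, 64, 32, 32)   # (B, C, H, W)
cbam = CBAM(channel=64)
out = cbam(feat)                     # channel attention, then spatial attention
assert out.shape == feat.shape
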
/misc/metric_tool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | ################### metrics ###################
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 | def __init__(self):
8 | self.initialized = False
9 | self.val = None
10 | self.avg = None
11 | self.sum = None
12 | self.count = None
13 |
14 | def initialize(self, val, weight):
15 | self.val = val
16 | self.avg = val
17 | self.sum = val * weight
18 | self.count = weight
19 | self.initialized = True
20 |
21 | def update(self, val, weight=1):
22 | if not self.initialized:
23 | self.initialize(val, weight)
24 | else:
25 | self.add(val, weight)
26 |
27 | def add(self, val, weight):
28 | self.val = val
29 | self.sum += val * weight
30 | self.count += weight
31 | self.avg = self.sum / self.count
32 |
33 | def value(self):
34 | return self.val
35 |
36 | def average(self):
37 | return self.avg
38 |
39 | def get_scores(self):
40 | scores_dict = cm2score(self.sum)
41 | return scores_dict
42 |
43 | def clear(self):
44 | self.initialized = False
45 |
46 |
47 | ################### cm metrics ###################
48 | class ConfuseMatrixMeter(AverageMeter):
49 | """Computes and stores the average and current value"""
50 | def __init__(self, n_class):
51 | super(ConfuseMatrixMeter, self).__init__()
52 | self.n_class = n_class
53 |
54 | def update_cm(self, pr, gt, weight=1):
55 | """获得当前混淆矩阵,并计算当前F1得分,并更新混淆矩阵"""
56 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr)
57 | self.update(val, weight)
58 | current_score = cm2F1(val)
59 | return current_score
60 |
61 | def get_scores(self):
62 | scores_dict = cm2score(self.sum)
63 | return scores_dict
64 |
65 |
66 |
67 | def harmonic_mean(xs):
68 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs)
69 | return harmonic_mean
70 |
71 |
72 | def cm2F1(confusion_matrix):
73 | hist = confusion_matrix
74 | n_class = hist.shape[0]
75 | tp = np.diag(hist)
76 | sum_a1 = hist.sum(axis=1)
77 | sum_a0 = hist.sum(axis=0)
78 | # ---------------------------------------------------------------------- #
79 | # 1. Accuracy & Class Accuracy
80 | # ---------------------------------------------------------------------- #
81 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
82 |
83 | # recall
84 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
85 | # acc_cls = np.nanmean(recall)
86 |
87 | # precision
88 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
89 |
90 | # F1 score
91 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps)
92 |     mean_F1 = np.nanmean(F1)  # ignore NaN entries when averaging
93 | return mean_F1
94 |
95 |
96 | def cm2score(confusion_matrix):
97 | hist = confusion_matrix
98 | n_class = hist.shape[0]
99 | tp = np.diag(hist)
100 | sum_a1 = hist.sum(axis=1)
101 | sum_a0 = hist.sum(axis=0)
102 | # ---------------------------------------------------------------------- #
103 | # 1. Accuracy & Class Accuracy
104 | # ---------------------------------------------------------------------- #
105 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)
106 |
107 | # recall
108 | recall = tp / (sum_a1 + np.finfo(np.float32).eps)
109 | # acc_cls = np.nanmean(recall)
110 |
111 | # precision
112 | precision = tp / (sum_a0 + np.finfo(np.float32).eps)
113 |
114 | # F1 score
115 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps)
116 | mean_F1 = np.nanmean(F1)
117 | # ---------------------------------------------------------------------- #
118 | # 2. Frequency weighted Accuracy & Mean IoU
119 | # ---------------------------------------------------------------------- #
120 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps)
121 | mean_iu = np.nanmean(iu)
122 |
123 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps)
124 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
125 |
126 | #
127 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu))
128 |
129 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision))
130 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall))
131 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1))
132 |
133 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1}
134 | score_dict.update(cls_iou)
135 | score_dict.update(cls_F1)
136 | score_dict.update(cls_precision)
137 | score_dict.update(cls_recall)
138 | return score_dict
139 |
140 |
141 | def get_confuse_matrix(num_classes, label_gts, label_preds):
142 | """计算一组预测的混淆矩阵"""
143 | def __fast_hist(label_gt, label_pred):
144 | """
145 | Collect values for Confusion Matrix
146 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
147 | :param label_gt: ground-truth
148 | :param label_pred: prediction
149 | :return: values for confusion matrix
150 | """
151 | mask = (label_gt >= 0) & (label_gt < num_classes)
152 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask],
153 | minlength=num_classes**2).reshape(num_classes, num_classes)
154 | return hist
155 | confusion_matrix = np.zeros((num_classes, num_classes))
156 | for lt, lp in zip(label_gts, label_preds):
157 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten())
158 | return confusion_matrix
159 |
160 |
161 | def get_mIoU(num_classes, label_gts, label_preds):
162 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds)
163 | score_dict = cm2score(confusion_matrix)
164 | return score_dict['miou']
165 |
--------------------------------------------------------------------------------
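A sketch of how the ConfuseMatrixMeter above is typically driven during evaluation (random predictions and labels, just to illustrate the API):

import numpy as np
from misc.metric_tool import ConfuseMatrixMeter

meter = ConfuseMatrixMeter(n_class=2)
for _ in range(3):                                  # pretend three batches
    gt = np.random.randint(0, 2, size=(4, 256, 256))
    pred = np.random.randint(0, 2, size=(4, 256, 256))
    batch_f1 = meter.update_cm(pr=pred, gt=gt)      # F1 of this batch's confusion matrix

scores = meter.get_scores()   # 'acc', 'miou', 'mf1' plus per-class iou/F1/precision/recall
print(scores['mf1'], scores['iou_1'])
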
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.utils.data
4 | import torch.nn as nn
5 | import numpy as np
6 | from utils.dataloaders import (full_path_loader, full_test_loader, CDDloader, LEVIRloader,LEVIRplusloader)
7 | from utils.metrics import jaccard_loss, dice_loss
8 | from utils.losses import hybrid_loss
9 | from models.Models import Siam_NestedUNet_Conc, SNUNet_ECAM
10 | from models.siamunet_dif import SiamUnet_diff
11 | from models.networks import BASE_Transformer
12 |
13 |
14 | logging.basicConfig(level=logging.INFO)
15 |
16 | def initialize_metrics():
17 | """Generates a dictionary of metrics with metrics as keys
18 | and empty lists as values
19 |
20 | Returns
21 | -------
22 | dict
23 | a dictionary of metrics
24 |
25 | """
26 | metrics = {
27 | 'cd_losses': [],
28 | 'cd_corrects': [],
29 | 'cd_precisions': [],
30 | 'cd_recalls': [],
31 | 'cd_f1scores': [],
32 | 'learning_rate': [],
33 | }
34 |
35 | return metrics
36 |
37 |
38 | def get_mean_metrics(metric_dict):
39 | """takes a dictionary of lists for metrics and returns dict of mean values
40 |
41 | Parameters
42 | ----------
43 | metric_dict : dict
44 | A dictionary of metrics
45 |
46 | Returns
47 | -------
48 | dict
49 | dict of floats that reflect mean metric value
50 |
51 | """
52 | return {k: np.mean(v) for k, v in metric_dict.items()}
53 |
54 |
55 | def set_metrics(metric_dict, cd_loss, cd_corrects, cd_report, lr):
56 | """Updates metric dict with batch metrics
57 |
58 | Parameters
59 | ----------
60 | metric_dict : dict
61 | dict of metrics
62 |     cd_loss : torch.Tensor
63 |         scalar loss value for the batch
64 |     cd_corrects : torch.Tensor
65 |         number of correct predictions (used to derive accuracy)
66 |     cd_report : list
67 |         precision, recall and F1 values
68 |
69 | Returns
70 | -------
71 | dict
72 | dict of updated metrics
73 |
74 |
75 | """
76 | metric_dict['cd_losses'].append(cd_loss.item())
77 | metric_dict['cd_corrects'].append(cd_corrects.item())
78 | metric_dict['cd_precisions'].append(cd_report[0])
79 | metric_dict['cd_recalls'].append(cd_report[1])
80 | metric_dict['cd_f1scores'].append(cd_report[2])
81 | metric_dict['learning_rate'].append(lr)
82 |
83 | return metric_dict
84 |
85 | def set_test_metrics(metric_dict, cd_corrects, cd_report):
86 |
87 | metric_dict['cd_corrects'].append(cd_corrects.item())
88 | metric_dict['cd_precisions'].append(cd_report[0])
89 | metric_dict['cd_recalls'].append(cd_report[1])
90 | metric_dict['cd_f1scores'].append(cd_report[2])
91 |
92 | return metric_dict
93 |
94 |
95 | def get_loaders(opt):
96 |
97 |
98 | logging.info('STARTING Dataset Creation')
99 |
100 | train_full_load, val_full_load = full_path_loader(opt.dataset_dir)
101 |
102 | if opt.dataset == 'cdd':
103 |
104 | train_dataset = CDDloader(train_full_load, flag = 'trn', aug=opt.augmentation)
105 | val_dataset = CDDloader(val_full_load, flag='val', aug=False)
106 |
107 | elif opt.dataset == 'levir':
108 | train_dataset = LEVIRloader(train_full_load, flag = 'trn', aug=opt.augmentation)
109 | val_dataset = LEVIRloader(val_full_load, flag='val', aug=False)
110 | elif opt.dataset == 'levir+':
111 | train_dataset = LEVIRplusloader(train_full_load, flag='trn', aug=opt.augmentation)
112 | val_dataset = LEVIRplusloader(val_full_load, flag='val', aug=False)
113 |
114 | logging.info('STARTING Dataloading')
115 |
116 | train_loader = torch.utils.data.DataLoader(train_dataset,
117 | batch_size=opt.batch_size,
118 | shuffle=True,
119 | num_workers=opt.num_workers,
120 | pin_memory=True)
121 | val_loader = torch.utils.data.DataLoader(val_dataset,
122 | batch_size=opt.batch_size,
123 | shuffle=False,
124 | num_workers=opt.num_workers,
125 | pin_memory=True)
126 | return train_loader, val_loader
127 |
128 | def get_test_loaders(opt, batch_size=None):
129 |
130 | if not batch_size:
131 | batch_size = opt.batch_size
132 |
133 | logging.info('STARTING Dataset Creation')
134 |
135 | test_full_load = full_test_loader(opt.dataset_dir)
136 |
137 | if opt.dataset == 'cdd':
138 |
139 | test_dataset = CDDloader(test_full_load, flag = 'tes', aug=False)
140 |
141 | elif opt.dataset == 'levir':
142 |
143 | test_dataset = LEVIRloader(test_full_load, flag = 'tes', aug=False)
144 |
145 | logging.info('STARTING Dataloading')
146 |
147 |
148 | test_loader = torch.utils.data.DataLoader(test_dataset,
149 | batch_size=batch_size,
150 | shuffle=False,
151 | num_workers=opt.num_workers)
152 | return test_loader
153 |
154 |
155 | def get_criterion(opt):
156 | """get the user selected loss function
157 |
158 | Parameters
159 | ----------
160 | opt : dict
161 | Dictionary of options/flags
162 |
163 | Returns
164 | -------
165 | method
166 | loss function
167 |
168 | """
169 | if opt.loss_function == 'hybrid':
170 | criterion = hybrid_loss
171 |     if opt.loss_function == 'bce':
172 |         criterion = nn.CrossEntropyLoss()  # note: despite the 'bce' name this is multi-class cross-entropy
173 | if opt.loss_function == 'dice':
174 | criterion = dice_loss
175 | if opt.loss_function == 'jaccard':
176 | criterion = jaccard_loss
177 |
178 | return criterion
179 |
180 |
181 | def load_model(opt, device):
182 | """Load the model
183 |
184 | Parameters
185 | ----------
186 | opt : dict
187 | User specified flags/options
188 | device : string
189 | device on which to train model
190 |
191 | """
192 | # device_ids = list(range(opt.num_gpus))
193 | #model = SNUNet_ECAM(opt, opt.num_channel, 2).to(device)
194 | # model = nn.DataParallel(model, device_ids=device_ids)
195 |
196 | model = BASE_Transformer(opt, input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
197 | with_pos='learned', enc_depth=1, dec_depth=8).to(device)
198 |
199 | return model
200 |
--------------------------------------------------------------------------------
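The helper functions above read their configuration from an `opt` object; below is a minimal stand-in with the attributes they actually access (field names come from the code above, the values are only examples):

from argparse import Namespace
from utils.helpers import get_loaders, get_criterion

opt = Namespace(
    dataset='levir',                # 'cdd' | 'levir' | 'levir+'
    dataset_dir='path/to/dataset',  # hypothetical path expected by full_path_loader
    augmentation=True,
    batch_size=8,
    num_workers=4,
    loss_function='hybrid',         # 'hybrid' | 'bce' | 'dice' | 'jaccard'
)
train_loader, val_loader = get_loaders(opt)
criterion = get_criterion(opt)
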
/compare/SNUNet.py:
--------------------------------------------------------------------------------
1 | # Kaiyu Li
2 | # https://github.com/likyoo
3 | #
4 |
5 | import torch.nn as nn
6 | import torch
7 |
8 | class conv_block_nested(nn.Module):
9 | def __init__(self, in_ch, mid_ch, out_ch):
10 | super(conv_block_nested, self).__init__()
11 | self.activation = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
13 | self.bn1 = nn.BatchNorm2d(mid_ch)
14 | self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
15 | self.bn2 = nn.BatchNorm2d(out_ch)
16 |
17 | def forward(self, x):
18 | x = self.conv1(x)
19 | identity = x
20 | x = self.bn1(x)
21 | x = self.activation(x)
22 |
23 | x = self.conv2(x)
24 | x = self.bn2(x)
25 | output = self.activation(x + identity)
26 | return output
27 |
28 |
29 | class up(nn.Module):
30 | def __init__(self, in_ch, bilinear=False):
31 | super(up, self).__init__()
32 |
33 | if bilinear:
34 | self.up = nn.Upsample(scale_factor=2,
35 | mode='bilinear',
36 | align_corners=True)
37 | else:
38 | self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2)
39 |
40 | def forward(self, x):
41 |
42 | x = self.up(x)
43 | return x
44 |
45 |
46 | class ChannelAttention(nn.Module):
47 | def __init__(self, in_channels, ratio = 16):
48 | super(ChannelAttention, self).__init__()
49 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
50 | self.max_pool = nn.AdaptiveMaxPool2d(1)
51 | self.fc1 = nn.Conv2d(in_channels,in_channels//ratio,1,bias=False)
52 | self.relu1 = nn.ReLU()
53 | self.fc2 = nn.Conv2d(in_channels//ratio, in_channels,1,bias=False)
54 |         self.sigmoid = nn.Sigmoid()
55 | def forward(self,x):
56 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
57 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
58 | out = avg_out + max_out
59 |         return self.sigmoid(out)
60 |
61 |
62 |
63 | class SNUNet_ECAM(nn.Module):
64 | # SNUNet-CD with ECAM
65 | def __init__(self, in_ch=3, out_ch=2):
66 | super(SNUNet_ECAM, self).__init__()
67 | torch.nn.Module.dump_patches = True
68 | n1 = 32 # the initial number of channels of feature map
69 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
70 |
71 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
72 |
73 | self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
74 | self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
75 | self.Up1_0 = up(filters[1])
76 | self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
77 | self.Up2_0 = up(filters[2])
78 | self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
79 | self.Up3_0 = up(filters[3])
80 | self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
81 | self.Up4_0 = up(filters[4])
82 |
83 | self.conv0_1 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
84 | self.conv1_1 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
85 | self.Up1_1 = up(filters[1])
86 | self.conv2_1 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])
87 | self.Up2_1 = up(filters[2])
88 | self.conv3_1 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3])
89 | self.Up3_1 = up(filters[3])
90 |
91 | self.conv0_2 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])
92 | self.conv1_2 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])
93 | self.Up1_2 = up(filters[1])
94 | self.conv2_2 = conv_block_nested(filters[2] * 3 + filters[3], filters[2], filters[2])
95 | self.Up2_2 = up(filters[2])
96 |
97 | self.conv0_3 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])
98 | self.conv1_3 = conv_block_nested(filters[1] * 4 + filters[2], filters[1], filters[1])
99 | self.Up1_3 = up(filters[1])
100 |
101 | self.conv0_4 = conv_block_nested(filters[0] * 5 + filters[1], filters[0], filters[0])
102 |
103 | self.ca = ChannelAttention(filters[0] * 4, ratio=16)
104 | self.ca1 = ChannelAttention(filters[0], ratio=16 // 4)
105 |
106 | self.conv_final = nn.Conv2d(filters[0] * 4, out_ch, kernel_size=1)
107 | # self.conv_final = nn.Conv2d(filters[0], out_ch, kernel_size=1)
108 |
109 | for m in self.modules():
110 | if isinstance(m, nn.Conv2d):
111 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
112 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
113 | nn.init.constant_(m.weight, 1)
114 | nn.init.constant_(m.bias, 0)
115 |
116 |
117 | def forward(self, xA, xB):
118 | '''xA'''
119 | x0_0A = self.conv0_0(xA)
120 | x1_0A = self.conv1_0(self.pool(x0_0A))
121 | x2_0A = self.conv2_0(self.pool(x1_0A))
122 | x3_0A = self.conv3_0(self.pool(x2_0A))
123 | # x4_0A = self.conv4_0(self.pool(x3_0A))
124 | '''xB'''
125 | x0_0B = self.conv0_0(xB)
126 | x1_0B = self.conv1_0(self.pool(x0_0B))
127 | x2_0B = self.conv2_0(self.pool(x1_0B))
128 | x3_0B = self.conv3_0(self.pool(x2_0B))
129 | x4_0B = self.conv4_0(self.pool(x3_0B))
130 |
131 | x0_1 = self.conv0_1(torch.cat([x0_0A, x0_0B, self.Up1_0(x1_0B)], 1))
132 | x1_1 = self.conv1_1(torch.cat([x1_0A, x1_0B, self.Up2_0(x2_0B)], 1))
133 | x0_2 = self.conv0_2(torch.cat([x0_0A, x0_0B, x0_1, self.Up1_1(x1_1)], 1))
134 |
135 |
136 | x2_1 = self.conv2_1(torch.cat([x2_0A, x2_0B, self.Up3_0(x3_0B)], 1))
137 | x1_2 = self.conv1_2(torch.cat([x1_0A, x1_0B, x1_1, self.Up2_1(x2_1)], 1))
138 | x0_3 = self.conv0_3(torch.cat([x0_0A, x0_0B, x0_1, x0_2, self.Up1_2(x1_2)], 1))
139 |
140 | x3_1 = self.conv3_1(torch.cat([x3_0A, x3_0B, self.Up4_0(x4_0B)], 1))
141 | x2_2 = self.conv2_2(torch.cat([x2_0A, x2_0B, x2_1, self.Up3_1(x3_1)], 1))
142 | x1_3 = self.conv1_3(torch.cat([x1_0A, x1_0B, x1_1, x1_2, self.Up2_2(x2_2)], 1))
143 | x0_4 = self.conv0_4(torch.cat([x0_0A, x0_0B, x0_1, x0_2, x0_3, self.Up1_3(x1_3)], 1))
144 |
145 | out = torch.cat([x0_1, x0_2, x0_3, x0_4], 1)
146 |
147 | intra = torch.sum(torch.stack((x0_1, x0_2, x0_3, x0_4)), dim=0)
148 | ca1 = self.ca1(intra)
149 | out = self.ca(out) * (out + ca1.repeat(1, 4, 1, 1))
150 | out = self.conv_final(out)
151 |
152 | # return (out, )
153 | return out
154 |
155 |
--------------------------------------------------------------------------------
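A shape check for SNUNet_ECAM above (random bi-temporal inputs; sketch only):

import torch
from compare.SNUNet import SNUNet_ECAM  # assumes compare/ is on the import path

model = SNUNet_ECAM(in_ch=3, out_ch=2)
xA = torch.randn(1, 3, 256, 256)
xB = torch.randn(1, 3, 256, 256)
logits = model(xA, xB)        # (1, 2, 256, 256) change/no-change logits
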
/compare/FC_EF.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 |
11 |
12 | class Unet(nn.Module):
13 | """EF segmentation network."""
14 |
15 | def __init__(self, input_nbr, label_nbr):
16 | super(Unet, self).__init__()
17 |
18 | self.conv11 = nn.Conv2d(2 * input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 | self.do11 = nn.Dropout2d(p=0.2)
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.2)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.2)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.2)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.2)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.2)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.2)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.2)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.2)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.2)
51 |
52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
53 |
54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
55 | self.bn43d = nn.BatchNorm2d(128)
56 | self.do43d = nn.Dropout2d(p=0.2)
57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
58 | self.bn42d = nn.BatchNorm2d(128)
59 | self.do42d = nn.Dropout2d(p=0.2)
60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
61 | self.bn41d = nn.BatchNorm2d(64)
62 | self.do41d = nn.Dropout2d(p=0.2)
63 |
64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
65 |
66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
67 | self.bn33d = nn.BatchNorm2d(64)
68 | self.do33d = nn.Dropout2d(p=0.2)
69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
70 | self.bn32d = nn.BatchNorm2d(64)
71 | self.do32d = nn.Dropout2d(p=0.2)
72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(32)
74 | self.do31d = nn.Dropout2d(p=0.2)
75 |
76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
77 |
78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
79 | self.bn22d = nn.BatchNorm2d(32)
80 | self.do22d = nn.Dropout2d(p=0.2)
81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
82 | self.bn21d = nn.BatchNorm2d(16)
83 | self.do21d = nn.Dropout2d(p=0.2)
84 |
85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
86 |
87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
88 | self.bn12d = nn.BatchNorm2d(16)
89 | self.do12d = nn.Dropout2d(p=0.2)
90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
91 |
92 | self.sm = nn.LogSoftmax(dim=1)
93 |
94 | def forward(self, x1, x2):
95 | x = torch.cat((x1, x2), 1)
96 |
97 | """Forward method."""
98 | # Stage 1
99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
100 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
101 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2)
102 |
103 | # Stage 2
104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
105 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
106 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2)
107 |
108 | # Stage 3
109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
111 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
112 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2)
113 |
114 | # Stage 4
115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
117 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
118 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2)
119 |
120 | # Stage 4d
121 | x4d = self.upconv4(x4p)
122 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2)))
123 | x4d = torch.cat((pad4(x4d), x43), 1)
124 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
125 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
126 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
127 |
128 | # Stage 3d
129 | x3d = self.upconv3(x41d)
130 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2)))
131 | x3d = torch.cat((pad3(x3d), x33), 1)
132 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
133 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
134 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
135 |
136 | # Stage 2d
137 | x2d = self.upconv2(x31d)
138 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2)))
139 | x2d = torch.cat((pad2(x2d), x22), 1)
140 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
141 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
142 |
143 | # Stage 1d
144 | x1d = self.upconv1(x21d)
145 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2)))
146 | x1d = torch.cat((pad1(x1d), x12), 1)
147 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
148 | x11d = self.conv11d(x12d)
149 |
150 | # output = []
151 | # output.append(x11d)
152 | # output = output[-1]
153 |
154 | return x11d
155 |
156 |
157 |
--------------------------------------------------------------------------------
/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 |
4 | from PIL import Image
5 | from PIL import ImageFilter
6 |
7 | import torchvision.transforms.functional as TF
8 | from torchvision import transforms
9 | import torch
10 |
11 | """
12 | 这是数据格式转换和数据增强的代码
13 | """
14 |
15 |
16 | def to_tensor_and_norm(imgs, labels):
17 | # to tensor
18 | imgs = [TF.to_tensor(img) for img in imgs]
19 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0)
20 | for img in labels]
21 |
22 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
23 | for img in imgs]
24 | return imgs, labels
25 |
26 |
27 | class CDDataAugmentation:
28 |
29 | def __init__(
30 | self,
31 | img_size,
32 | with_random_hflip=False,
33 | with_random_vflip=False,
34 | with_random_rot=False,
35 | with_random_crop=False,
36 | with_scale_random_crop=False,
37 | with_random_blur=False,
38 | ):
39 | self.img_size = img_size
40 | if self.img_size is None:
41 | self.img_size_dynamic = True
42 | else:
43 | self.img_size_dynamic = False
44 | self.with_random_hflip = with_random_hflip
45 | self.with_random_vflip = with_random_vflip
46 | self.with_random_rot = with_random_rot
47 | self.with_random_crop = with_random_crop
48 | self.with_scale_random_crop = with_scale_random_crop
49 | self.with_random_blur = with_random_blur
50 | def transform(self, imgs, labels, to_tensor=True):
51 | """
52 | :param imgs: [ndarray,]
53 | :param labels: [ndarray,]
54 | :return: [ndarray,],[ndarray,]
55 | """
56 | # resize image and covert to tensor
57 | imgs = [TF.to_pil_image(img) for img in imgs]
58 |         # when img_size is None (dynamic mode), it is inferred from the first image below
59 |         # instead of being fixed here
60 |
61 | if not self.img_size_dynamic:
62 | if imgs[0].size != (self.img_size, self.img_size):
63 | imgs = [TF.resize(img, [self.img_size, self.img_size], interpolation=3)
64 | for img in imgs]
65 | else:
66 | self.img_size = imgs[0].size[0]
67 |
68 | labels = [TF.to_pil_image(img) for img in labels]
69 | if len(labels) != 0:
70 | if labels[0].size != (self.img_size, self.img_size):
71 | labels = [TF.resize(img, [self.img_size, self.img_size], interpolation=0)
72 | for img in labels]
73 |
74 | random_base = 0.5
75 | if self.with_random_hflip and random.random() > 0.5:
76 | imgs = [TF.hflip(img) for img in imgs]
77 | labels = [TF.hflip(img) for img in labels]
78 |
79 | if self.with_random_vflip and random.random() > 0.5:
80 | imgs = [TF.vflip(img) for img in imgs]
81 | labels = [TF.vflip(img) for img in labels]
82 |
83 | if self.with_random_rot and random.random() > random_base:
84 | angles = [90, 180, 270]
85 | index = random.randint(0, 2)
86 | angle = angles[index]
87 | imgs = [TF.rotate(img, angle) for img in imgs]
88 | labels = [TF.rotate(img, angle) for img in labels]
89 |
90 | if self.with_random_crop and random.random() > 0:
91 | i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). \
92 | get_params(img=imgs[0], scale=(0.8, 1.0), ratio=(1, 1))
93 |
94 | imgs = [TF.resized_crop(img, i, j, h, w,
95 | size=(self.img_size, self.img_size),
96 |                                     interpolation=Image.BICUBIC)  # Image.CUBIC alias was removed in Pillow 10
97 | for img in imgs]
98 |
99 | labels = [TF.resized_crop(img, i, j, h, w,
100 | size=(self.img_size, self.img_size),
101 | interpolation=Image.NEAREST)
102 | for img in labels]
103 |
104 | if self.with_scale_random_crop:
105 | # rescale
106 | scale_range = [1, 1.2]
107 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0])
108 |
109 | imgs = [pil_rescale(img, target_scale, order=3) for img in imgs]
110 | labels = [pil_rescale(img, target_scale, order=0) for img in labels]
111 | # crop
112 |             imgsize = imgs[0].size  # PIL gives (w, h); equivalent here since tiles are square
113 | box = get_random_crop_box(imgsize=imgsize, cropsize=self.img_size)
114 | imgs = [pil_crop(img, box, cropsize=self.img_size, default_value=0)
115 | for img in imgs]
116 | labels = [pil_crop(img, box, cropsize=self.img_size, default_value=255)
117 | for img in labels]
118 |
119 | if self.with_random_blur and random.random() > 0:
120 | radius = random.random()
121 | imgs = [img.filter(ImageFilter.GaussianBlur(radius=radius))
122 | for img in imgs]
123 |
124 | if to_tensor:
125 | # to tensor
126 | imgs = [TF.to_tensor(img) for img in imgs]
127 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0)
128 | for img in labels]
129 |
130 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])
131 | for img in imgs]
132 |
133 | return imgs, labels
134 |
135 |
136 | def pil_crop(image, box, cropsize, default_value):
137 | assert isinstance(image, Image.Image)
138 | img = np.array(image)
139 |
140 | if len(img.shape) == 3:
141 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value
142 | else:
143 | cont = np.ones((cropsize, cropsize), img.dtype)*default_value
144 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
145 |
146 | return Image.fromarray(cont)
147 |
148 |
149 | def get_random_crop_box(imgsize, cropsize):
150 | h, w = imgsize
151 | ch = min(cropsize, h)
152 | cw = min(cropsize, w)
153 |
154 | w_space = w - cropsize
155 | h_space = h - cropsize
156 |
157 | if w_space > 0:
158 | cont_left = 0
159 | img_left = random.randrange(w_space + 1)
160 | else:
161 | cont_left = random.randrange(-w_space + 1)
162 | img_left = 0
163 |
164 | if h_space > 0:
165 | cont_top = 0
166 | img_top = random.randrange(h_space + 1)
167 | else:
168 | cont_top = random.randrange(-h_space + 1)
169 | img_top = 0
170 |
171 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw
172 |
173 |
174 | def pil_rescale(img, scale, order):
175 | assert isinstance(img, Image.Image)
176 | height, width = img.size
177 | target_size = (int(np.round(height*scale)), int(np.round(width*scale)))
178 | return pil_resize(img, target_size, order)
179 |
180 |
181 | def pil_resize(img, size, order):
182 | assert isinstance(img, Image.Image)
183 | if size[0] == img.size[0] and size[1] == img.size[1]:
184 | return img
185 | if order == 3:
186 | resample = Image.BICUBIC
187 | elif order == 0:
188 | resample = Image.NEAREST
189 | return img.resize(size[::-1], resample)
190 |
--------------------------------------------------------------------------------
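A short sketch of the augmentation pipeline above on raw arrays (two RGB tiles and one binary mask with dummy values):

import numpy as np
from datasets.data_utils import CDDataAugmentation

aug = CDDataAugmentation(img_size=256, with_random_hflip=True, with_scale_random_crop=True)
img_A = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
img_B = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (256, 256), dtype=np.uint8)

[img_A_t, img_B_t], [mask_t] = aug.transform([img_A, img_B], [mask], to_tensor=True)
# img_*_t: (3, 256, 256) float tensors normalized with mean/std 0.5; mask_t: (1, 256, 256) uint8
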
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 |
8 |
9 | class FocalLoss(nn.Module):
10 | def __init__(self, gamma=0, alpha=None, size_average=True):
11 | super(FocalLoss, self).__init__()
12 | self.gamma = gamma
13 | self.alpha = alpha
14 | if isinstance(alpha, (float, int)):
15 | self.alpha = torch.Tensor([alpha, 1-alpha])
16 | if isinstance(alpha, list):
17 | self.alpha = torch.Tensor(alpha)
18 | self.size_average = size_average
19 |
20 | def forward(self, input, target):
21 | if input.dim() > 2:
22 | # N,C,H,W => N,C,H*W
23 | input = input.view(input.size(0), input.size(1), -1)
24 |
25 | # N,C,H*W => N,H*W,C
26 | input = input.transpose(1, 2)
27 |
28 | # N,H*W,C => N*H*W,C
29 | input = input.contiguous().view(-1, input.size(2))
30 |
31 | target = target.view(-1, 1)
32 | logpt = F.log_softmax(input,dim=-1)
33 | logpt = logpt.gather(1, target)
34 | logpt = logpt.view(-1)
35 | pt = Variable(logpt.data.exp())
36 |
37 | if self.alpha is not None:
38 | if self.alpha.type() != input.data.type():
39 | self.alpha = self.alpha.type_as(input.data)
40 | at = self.alpha.gather(0, target.data.view(-1))
41 | logpt = logpt * Variable(at)
42 |
43 | loss = -1 * (1-pt)**self.gamma * logpt
44 |
45 | if self.size_average:
46 | return loss.mean()
47 | else:
48 | return loss.sum()
49 |
50 | def dice_loss(logits, true, eps=1e-7):
51 | """Computes the Sørensen–Dice loss.
52 |     Note that PyTorch optimizers minimize a loss. Since we
53 |     want to maximize the Dice coefficient, we return
54 |     one minus it.
55 | Args:
56 | true: a tensor of shape [B, 1, H, W].
57 | logits: a tensor of shape [B, C, H, W]. Corresponds to
58 | the raw output or logits of the model.
59 | eps: added to the denominator for numerical stability.
60 | Returns:
61 | dice_loss: the Sørensen–Dice loss.
62 | """
63 | num_classes = logits.shape[1]
64 | if num_classes == 1:
65 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
66 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
67 | true_1_hot_f = true_1_hot[:, 0:1, :, :]
68 | true_1_hot_s = true_1_hot[:, 1:2, :, :]
69 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
70 | pos_prob = torch.sigmoid(logits)
71 | neg_prob = 1 - pos_prob
72 | probas = torch.cat([pos_prob, neg_prob], dim=1)
73 | else:
74 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
75 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
76 | probas = F.softmax(logits, dim=1)
77 | true_1_hot = true_1_hot.type(logits.type())
78 | dims = (0,) + tuple(range(2, true.ndimension()))
79 | intersection = torch.sum(probas * true_1_hot, dims)
80 | cardinality = torch.sum(probas + true_1_hot, dims)
81 | dice_loss = (2. * intersection / (cardinality + eps)).mean()
82 | return (1 - dice_loss)
83 |
84 |
85 | def jaccard_loss(logits, true, eps=1e-7):
86 | """Computes the Jaccard loss, a.k.a the IoU loss.
87 |     Note that PyTorch optimizers minimize a loss. Since we
88 |     want to maximize the Jaccard index (IoU), we return
89 |     one minus it.
90 | Args:
91 | true: a tensor of shape [B, H, W] or [B, 1, H, W].
92 | logits: a tensor of shape [B, C, H, W]. Corresponds to
93 | the raw output or logits of the model.
94 | eps: added to the denominator for numerical stability.
95 | Returns:
96 | jacc_loss: the Jaccard loss.
97 | """
98 | num_classes = logits.shape[1]
99 | if num_classes == 1:
100 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
101 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
102 | true_1_hot_f = true_1_hot[:, 0:1, :, :]
103 | true_1_hot_s = true_1_hot[:, 1:2, :, :]
104 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
105 | pos_prob = torch.sigmoid(logits)
106 | neg_prob = 1 - pos_prob
107 | probas = torch.cat([pos_prob, neg_prob], dim=1)
108 | else:
109 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
110 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
111 | probas = F.softmax(logits, dim=1)
112 | true_1_hot = true_1_hot.type(logits.type())
113 | dims = (0,) + tuple(range(2, true.ndimension()))
114 | intersection = torch.sum(probas * true_1_hot, dims)
115 | cardinality = torch.sum(probas + true_1_hot, dims)
116 | union = cardinality - intersection
117 | jacc_loss = (intersection / (union + eps)).mean()
118 | return (1 - jacc_loss)
119 |
120 |
121 | class TverskyLoss(nn.Module):
122 | def __init__(self, alpha=0.5, beta=0.5, eps=1e-7, size_average=True):
123 | super(TverskyLoss, self).__init__()
124 | self.alpha = alpha
125 | self.beta = beta
126 | self.size_average = size_average
127 | self.eps = eps
128 |
129 | def forward(self, logits, true):
130 | """Computes the Tversky loss [1].
131 | Args:
132 | true: a tensor of shape [B, H, W] or [B, 1, H, W].
133 | logits: a tensor of shape [B, C, H, W]. Corresponds to
134 | the raw output or logits of the model.
135 | alpha: controls the penalty for false positives.
136 | beta: controls the penalty for false negatives.
137 | eps: added to the denominator for numerical stability.
138 | Returns:
139 | tversky_loss: the Tversky loss.
140 | Notes:
141 | alpha = beta = 0.5 => dice coeff
142 | alpha = beta = 1 => tanimoto coeff
143 | alpha + beta = 1 => F beta coeff
144 | References:
145 | [1]: https://arxiv.org/abs/1706.05721
146 | """
147 | num_classes = logits.shape[1]
148 | if num_classes == 1:
149 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
150 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
151 | true_1_hot_f = true_1_hot[:, 0:1, :, :]
152 | true_1_hot_s = true_1_hot[:, 1:2, :, :]
153 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
154 | pos_prob = torch.sigmoid(logits)
155 | neg_prob = 1 - pos_prob
156 | probas = torch.cat([pos_prob, neg_prob], dim=1)
157 | else:
158 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
159 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
160 | probas = F.softmax(logits, dim=1)
161 |
162 | true_1_hot = true_1_hot.type(logits.type())
163 | dims = (0,) + tuple(range(2, true.ndimension()))
164 | intersection = torch.sum(probas * true_1_hot, dims)
165 | fps = torch.sum(probas * (1 - true_1_hot), dims)
166 | fns = torch.sum((1 - probas) * true_1_hot, dims)
167 | num = intersection
168 | denom = intersection + (self.alpha * fps) + (self.beta * fns)
169 | tversky_loss = (num / (denom + self.eps)).mean()
170 | return (1 - tversky_loss)
171 |
--------------------------------------------------------------------------------
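A tiny sanity check for dice_loss and jaccard_loss above (random logits and binary labels):

import torch
from utils.metrics import dice_loss, jaccard_loss

logits = torch.randn(2, 2, 64, 64)              # (B, C, H, W) raw scores, two classes
target = torch.randint(0, 2, (2, 1, 64, 64))    # (B, 1, H, W) integer labels

print(dice_loss(logits, target).item())         # in [0, 1]; lower means better overlap
print(jaccard_loss(logits, target).item())
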
/compare/IFNet.py:
--------------------------------------------------------------------------------
1 | # credits: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision.models import vgg16
6 | import numpy as np
7 |
8 |
9 | class vgg16_base(nn.Module):
10 | def __init__(self):
11 | super(vgg16_base, self).__init__()
12 | features = list(vgg16(pretrained=True).features)[:30]
13 | self.features = nn.ModuleList(features).eval()
14 |
15 | def forward(self, x):
16 | results = []
17 | for ii, model in enumerate(self.features):
18 | x = model(x)
19 | if ii in {3, 8, 15, 22, 29}:
20 | results.append(x)
21 | return results
22 |
23 |
24 | class ChannelAttention(nn.Module):
25 | def __init__(self, in_channels, ratio=8):
26 | super(ChannelAttention, self).__init__()
27 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
28 | self.max_pool = nn.AdaptiveMaxPool2d(1)
29 | self.fc1 = nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False)
30 | self.relu1 = nn.ReLU()
31 | self.fc2 = nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False)
32 |         self.sigmoid = nn.Sigmoid()
33 |
34 | def forward(self, x):
35 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
36 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
37 | out = avg_out + max_out
38 |         return self.sigmoid(out)
39 |
40 |
41 | class SpatialAttention(nn.Module):
42 | def __init__(self):
43 | super(SpatialAttention, self).__init__()
44 | self.conv1 = nn.Conv2d(2, 1, 7, padding=3, bias=False)
45 | self.sigmoid = nn.Sigmoid()
46 |
47 | def forward(self, x):
48 | avg_out = torch.mean(x, dim=1, keepdim=True)
49 | max_out = torch.max(x, dim=1, keepdim=True, out=None)[0]
50 |
51 | x = torch.cat([avg_out, max_out], dim=1)
52 | x = self.conv1(x)
53 | return self.sigmoid(x)
54 |
55 |
56 | def conv2d_bn(in_channels, out_channels):
57 | return nn.Sequential(
58 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
59 | nn.PReLU(),
60 | nn.BatchNorm2d(out_channels),
61 | nn.Dropout(p=0.6),
62 | )
63 |
64 |
65 | class DSIFN(nn.Module):
66 | def __init__(self):
67 | super().__init__()
68 | self.t1_base = vgg16_base()
69 | self.t2_base = vgg16_base()
70 | self.sa1 = SpatialAttention()
71 | self.sa2 = SpatialAttention()
72 | self.sa3 = SpatialAttention()
73 | self.sa4 = SpatialAttention()
74 | self.sa5 = SpatialAttention()
75 |
76 | self.sigmoid = nn.Sigmoid()
77 |
78 | # branch1
79 | self.ca1 = ChannelAttention(in_channels=1024)
80 | self.bn_ca1 = nn.BatchNorm2d(1024)
81 | self.o1_conv1 = conv2d_bn(1024, 512)
82 | self.o1_conv2 = conv2d_bn(512, 512)
83 | self.bn_sa1 = nn.BatchNorm2d(512)
84 | self.o1_conv3 = nn.Conv2d(512, 2, 1)
85 | self.trans_conv1 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
86 |
87 | # branch 2
88 | self.ca2 = ChannelAttention(in_channels=1536)
89 | self.bn_ca2 = nn.BatchNorm2d(1536)
90 | self.o2_conv1 = conv2d_bn(1536, 512)
91 | self.o2_conv2 = conv2d_bn(512, 256)
92 | self.o2_conv3 = conv2d_bn(256, 256)
93 | self.bn_sa2 = nn.BatchNorm2d(256)
94 | self.o2_conv4 = nn.Conv2d(256, 2, 1)
95 | self.trans_conv2 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
96 |
97 | # branch 3
98 | self.ca3 = ChannelAttention(in_channels=768)
99 | self.o3_conv1 = conv2d_bn(768, 256)
100 | self.o3_conv2 = conv2d_bn(256, 128)
101 | self.o3_conv3 = conv2d_bn(128, 128)
102 | self.bn_sa3 = nn.BatchNorm2d(128)
103 | self.o3_conv4 = nn.Conv2d(128, 2, 1)
104 | self.trans_conv3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
105 |
106 | # branch 4
107 | self.ca4 = ChannelAttention(in_channels=384)
108 | self.o4_conv1 = conv2d_bn(384, 128)
109 | self.o4_conv2 = conv2d_bn(128, 64)
110 | self.o4_conv3 = conv2d_bn(64, 64)
111 | self.bn_sa4 = nn.BatchNorm2d(64)
112 | self.o4_conv4 = nn.Conv2d(64, 2, 1)
113 | self.trans_conv4 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
114 |
115 | # branch 5
116 | self.ca5 = ChannelAttention(in_channels=192)
117 | self.o5_conv1 = conv2d_bn(192, 64)
118 | self.o5_conv2 = conv2d_bn(64, 32)
119 | self.o5_conv3 = conv2d_bn(32, 16)
120 | self.bn_sa5 = nn.BatchNorm2d(16)
121 | self.o5_conv4 = nn.Conv2d(16, 2, 1)
122 |
123 | def forward(self, t1_input, t2_input):
124 | t1_list = self.t1_base(t1_input)
125 | t2_list = self.t2_base(t2_input)
126 |
127 | t1_f_l3, t1_f_l8, t1_f_l15, t1_f_l22, t1_f_l29 = t1_list[0], t1_list[1], t1_list[2], t1_list[3], t1_list[4]
128 |         t2_f_l3, t2_f_l8, t2_f_l15, t2_f_l22, t2_f_l29 = t2_list[0], t2_list[1], t2_list[2], t2_list[3], t2_list[4]
129 |
130 | x = torch.cat((t1_f_l29, t2_f_l29), dim=1)
131 | # optional to use channel attention module in the first combined feature
132 |         # (i.e. the channel attention call below can be enabled or left disabled)
133 | # x = self.ca1(x) * x
134 | x = self.o1_conv1(x)
135 | x = self.o1_conv2(x)
136 | x = self.sa1(x) * x
137 | x = self.bn_sa1(x)
138 |
139 | # branch_1_out = self.sigmoid(self.o1_conv3(x))
140 | branch_1_out = self.o1_conv3(x)
141 |
142 | x = self.trans_conv1(x)
143 | x = torch.cat((x, t1_f_l22, t2_f_l22), dim=1)
144 | x = self.ca2(x) * x
145 | # According to the amount of the training data, appropriately reduce the use of conv layers to prevent overfitting
146 |         # (i.e. with less training data, drop some of these conv layers to avoid overfitting)
147 | x = self.o2_conv1(x)
148 | x = self.o2_conv2(x)
149 | x = self.o2_conv3(x)
150 | x = self.sa2(x) * x
151 | x = self.bn_sa2(x)
152 |
153 | # branch_2_out = self.sigmoid(self.o2_conv4(x))
154 | branch_2_out = self.o2_conv4(x)
155 |
156 | x = self.trans_conv2(x)
157 | x = torch.cat((x, t1_f_l15, t2_f_l15), dim=1)
158 | x = self.ca3(x) * x
159 | x = self.o3_conv1(x)
160 | x = self.o3_conv2(x)
161 | x = self.o3_conv3(x)
162 | x = self.sa3(x) * x
163 | x = self.bn_sa3(x)
164 |
165 | # branch_3_out = self.sigmoid(self.o3_conv4(x))
166 | branch_3_out = self.o3_conv4(x)
167 |
168 | x = self.trans_conv3(x)
169 | x = torch.cat((x, t1_f_l8, t2_f_l8), dim=1)
170 | x = self.ca4(x) * x
171 | x = self.o4_conv1(x)
172 | x = self.o4_conv2(x)
173 | x = self.o4_conv3(x)
174 | x = self.sa4(x) * x
175 | x = self.bn_sa4(x)
176 |
177 | # branch_4_out = self.sigmoid(self.o4_conv4(x))
178 | branch_4_out = self.o4_conv4(x)
179 |
180 | x = self.trans_conv4(x)
181 | x = torch.cat((x, t1_f_l3, t2_f_l3), dim=1)
182 | x = self.ca5(x) * x
183 | x = self.o5_conv1(x)
184 | x = self.o5_conv2(x)
185 | x = self.o5_conv3(x)
186 | x = self.sa5(x) * x
187 | x = self.bn_sa5(x)
188 |
189 | # branch_5_out = self.sigmoid(self.o5_conv4(x))
190 | branch_5_out = self.o5_conv4(x)
191 |
192 | return [branch_5_out, branch_4_out, branch_3_out, branch_2_out, branch_1_out]
--------------------------------------------------------------------------------
/models/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | from models.networks import *
6 | from misc.metric_tool import ConfuseMatrixMeter
7 | from misc.logger_tool import Logger
8 | from utils_ import de_norm
9 | import utils_
10 | from tqdm import tqdm
11 |
12 | # Decide which device we want to run on
13 | # torch.cuda.current_device()
14 |
15 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | class CDEvaluator():
19 |
20 | def __init__(self, args, dataloader):
21 |
22 | self.dataloader = dataloader
23 |
24 | self.n_class = args.n_class
25 | # define G
26 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids)
27 | self.device = torch.device("cuda:%s" % args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0
28 | else "cpu")
29 | print(self.device)
30 |
31 | # define some other vars to record the training states
32 | self.running_metric = ConfuseMatrixMeter(n_class=self.n_class)
33 |
34 | # define logger file
35 | logger_path = os.path.join(args.checkpoint_dir, 'log_test.txt')
36 | self.logger = Logger(logger_path)
37 | self.logger.write_dict_str(args.__dict__)
38 |
39 |
40 | # training log
41 | self.epoch_acc = 0
42 | self.best_val_acc = 0.0
43 | self.best_epoch_id = 0
44 |
45 | self.steps_per_epoch = len(dataloader)
46 |
47 | self.G_pred = None
48 | self.pred_vis = None
49 | self.batch = None
50 | self.is_training = False
51 | self.batch_id = 0
52 | self.epoch_id = 0
53 | self.checkpoint_dir = args.checkpoint_dir
54 | self.vis_dir = args.vis_dir
55 |
56 | # check and create model dir
57 | if os.path.exists(self.checkpoint_dir) is False:
58 | os.mkdir(self.checkpoint_dir)
59 | if os.path.exists(self.vis_dir) is False:
60 | os.mkdir(self.vis_dir)
61 |
62 |
63 | def _load_checkpoint(self, checkpoint_name='best_ckpt.pt'):
64 |
65 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)):
66 | self.logger.write('loading last checkpoint...\n')
67 | # load the entire checkpoint
68 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), map_location=self.device)
69 |
70 | self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
71 |
72 | self.net_G.to(self.device)
73 |
74 | # update some other states
75 | self.best_val_acc = checkpoint['best_val_acc']
76 | self.best_epoch_id = checkpoint['best_epoch_id']
77 |
78 | self.logger.write('Eval Historical_best_acc = %.4f (at epoch %d)\n' %
79 | (self.best_val_acc, self.best_epoch_id))
80 | self.logger.write('\n')
81 |
82 | else:
83 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name)
84 |
85 |
86 | def _visualize_pred(self,args):
87 | # pred = torch.argmax(self.G_pred, dim=1, keepdim=True)
88 |         if args.deep_supervision:
89 |             pred = torch.argmax(self.G_pred[-1], dim=1, keepdim=True)
90 | else:
91 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True)
92 | pred_vis = pred * 255
93 | return pred_vis
94 |
95 |
96 | def _update_metric(self,args):
97 | """
98 | update metric
99 | """
100 | target = self.batch['L'].to(self.device).detach()
101 | # G_pred = self.G_pred.detach()
102 |         if args.deep_supervision:
103 | G_pred = self.G_pred[-1].detach()
104 | else:
105 | G_pred = self.G_pred.detach()
106 | G_pred = torch.argmax(G_pred, dim=1)
107 |
108 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy())
109 | return current_score
110 |
111 | def _collect_running_batch_states(self,args):
112 |
113 |         running_acc = self._update_metric(args)  # running accuracy of the change prediction
114 |
115 | m = len(self.dataloader)
116 | #print('m:',m)
117 | #print('batch_id:',self.batch_id)
118 |
119 |         if np.mod(self.batch_id, 100) == 1:  # log every 100 batches (np.mod result takes the sign of the divisor)
120 | message = 'Is_training: %s. [%d,%d], running_mf1: %.5f\n' %\
121 | (self.is_training, self.batch_id, m, running_acc)
122 | self.logger.write(message)
123 |
124 | if np.mod(self.batch_id, 100) == 1:
125 | vis_input = utils_.make_numpy_grid(de_norm(self.batch['A']))
126 | vis_input2 = utils_.make_numpy_grid(de_norm(self.batch['B']))
127 |
128 | vis_pred = utils_.make_numpy_grid(self._visualize_pred(args))
129 |
130 | vis_gt = utils_.make_numpy_grid(self.batch['L'])
131 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0)
132 |             vis = np.clip(vis, a_min=0.0, a_max=1.0)  # clamp values to [a_min, a_max]: anything above a_max becomes a_max, anything below a_min becomes a_min
133 | file_name = os.path.join(
134 | self.vis_dir, 'eval_' + str(self.batch_id)+'.jpg')
135 | plt.imsave(file_name, vis)
136 |
137 |
138 |
139 | def _collect_epoch_states(self):
140 |
141 | scores_dict = self.running_metric.get_scores()
142 |
143 | np.save(os.path.join(self.checkpoint_dir, 'scores_dict.npy'), scores_dict)
144 |
145 | self.epoch_acc = scores_dict['mf1']
146 |         print(self.epoch_acc)
147 |
148 | with open(os.path.join(self.checkpoint_dir, '%s.txt' % (self.epoch_acc)),
149 | mode='a') as file:
150 |             pass  # touch an empty file named after the accuracy as a quick run marker
151 |
152 | message = ''
153 | for k, v in scores_dict.items():
154 | message += '%s: %.5f ' % (k, v)
155 | self.logger.write('%s\n' % message) # save the message
156 |
157 | self.logger.write('\n')
158 |
159 | def _clear_cache(self):
160 | self.running_metric.clear()
161 |
162 | def _forward_pass(self,batch,args):
163 | self.batch = batch
164 | img_in1 = batch['A'].to(self.device)
165 | img_in2 = batch['B'].to(self.device)
166 | # sobel = batch['S'].to(self.device)
167 | # self.G_pred1, self.G_pred2, self.G_pred3 = self.net_G(img_in1, img_in2)
168 | # self.G_pred = self.G_pred1 + self.G_pred2 + self.G_pred3
169 |         if args.loss_SD:
170 | # self.G_pred1, self.G_pred2, self.G_pred3 = self.net_G(img_in1, img_in2,sobel)
171 | # self.G_pred = self.G_pred1 + self.G_pred2 + self.G_pred3
172 | #cd_net
173 | # self.G_pred0, self.G_pred1, self.G_pred2, self.G_pred3, self.G_pred4 = self.net_G(img_in1, img_in2)
174 | # self.G_pred = self.G_pred0
175 |
176 | #DMINet
177 | self.G_pred0, self.G_pred1, self.G_pred2, self.G_pred3 = self.net_G(img_in1, img_in2)
178 | self.G_pred = self.G_pred0 + self.G_pred1
179 |
180 | else :
181 | # self.G_pred = self.net_G(img_in1, img_in2,sobel)
182 | self.G_pred = self.net_G(img_in1, img_in2)
183 |
184 | def eval_models(self,args,checkpoint_name='best_ckpt.pt'):
185 |
186 | self._load_checkpoint(checkpoint_name)
187 |
188 | ################## Eval ##################
189 | ##########################################
190 | self.logger.write('Begin evaluation...\n')
191 |         self._clear_cache()  # resets running_metric (returns nothing)
192 | self.is_training = False
193 | self.net_G.eval()
194 |
195 |         # Iterate over the evaluation data.
196 | for self.batch_id, batch in enumerate(self.dataloader, 0):
197 | #name = batch['name']
198 | #print('process: %s' % name)
199 | with torch.no_grad():
200 | self._forward_pass(batch,args)
201 | self._collect_running_batch_states(args)
202 | self._collect_epoch_states()
203 |
204 |
205 |
--------------------------------------------------------------------------------
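A minimal sketch of how CDEvaluator is usually driven from a small evaluation script (eval.py / eval_cd.py play this role in the repository); the argument fields below are assumptions inferred from the attributes the class reads, not the exact definitions in utils/parser.py.

import argparse
# from models.evaluator import CDEvaluator  # the class defined above

args = argparse.Namespace(
    n_class=2,                        # number of output classes
    gpu_ids=[0],                      # empty list -> CPU
    checkpoint_dir='checkpoints/run1',
    vis_dir='vis/run1',
    deep_supervision=False,
    loss_SD=False,
    # ...plus whatever fields define_G() needs to build the network
)

# dataloader must yield dicts with keys 'A', 'B' (bi-temporal images) and 'L' (labels)
# evaluator = CDEvaluator(args, dataloader)
# evaluator.eval_models(args, checkpoint_name='best_ckpt.pt')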
/compare/DASNet.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: CASIA IVA
3 | # Email: jliu@nlpr.ia.ac.cn
4 | # Copyright (c) 2018
5 | ###########################################################################
6 | from __future__ import division
7 | import os
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn.functional import upsample, normalize
12 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \
13 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding
14 | # from attention import PAM_Module
15 | # from attention import CAM_Module
16 | from resbase import BaseNet
17 | import torch.nn.functional as F
18 |
19 | __all__ = ['DANet']
20 |
21 | class PAM_Module(Module):
22 | """ Position attention module"""
23 | #Ref from SAGAN
24 | def __init__(self, in_dim):
25 | super(PAM_Module, self).__init__()
26 | self.chanel_in = in_dim
27 |
28 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
29 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
30 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
31 | self.gamma = Parameter(torch.zeros(1))
32 |
33 | self.softmax = Softmax(dim=-1)
34 |
35 | def forward(self, x):
36 | """
37 | inputs :
38 | x : input feature maps( B X C X H X W)
39 | returns :
40 | out : attention value + input feature
41 | attention: B X (HxW) X (HxW)
42 | """
43 | m_batchsize, C, height, width = x.size()
44 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
45 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
46 | energy = torch.bmm(proj_query, proj_key)
47 | attention = self.softmax(energy)
48 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
49 |
50 | out = torch.bmm(proj_value, attention.permute(0, 2, 1))
51 | out = out.view(m_batchsize, C, height, width)
52 |
53 | out = self.gamma*out + x
54 | return out
55 |
56 |
57 | class CAM_Module(Module):
58 | """ Channel attention module"""
59 | def __init__(self, in_dim):
60 | super(CAM_Module, self).__init__()
61 | self.chanel_in = in_dim
62 |
63 |
64 | self.gamma = Parameter(torch.zeros(1))
65 | self.softmax = Softmax(dim=-1)
66 | def forward(self,x):
67 | """
68 | inputs :
69 | x : input feature maps( B X C X H X W)
70 | returns :
71 | out : attention value + input feature
72 | attention: B X C X C
73 | """
74 | m_batchsize, C, height, width = x.size()
75 | proj_query = x.view(m_batchsize, C, -1)
76 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
77 | energy = torch.bmm(proj_query, proj_key)
78 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
79 | attention = self.softmax(energy_new)
80 | proj_value = x.view(m_batchsize, C, -1)
81 |
82 | out = torch.bmm(attention, proj_value)
83 | out = out.view(m_batchsize, C, height, width)
84 |
85 | out = self.gamma*out + x
86 | return out
87 |
88 | class DANet(BaseNet):
89 |     r"""Dual Attention Network (DANet) for scene segmentation.
90 | 
91 |     Parameters
92 |     ----------
93 |     nclass : int
94 |         Number of categories for the training dataset.
95 |     backbone : string
96 |         Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
97 |         'resnet101' or 'resnet152').
98 |     norm_layer : object
99 |         Normalization layer used in the backbone network (default: :class:`nn.BatchNorm2d`).
100 | 
101 | 
102 |     Reference:
103 | 
104 |         Fu, Jun, et al. "Dual Attention Network for Scene Segmentation."
105 |         *CVPR*, 2019
106 | 
107 |     """
108 |
109 | def __init__(self, nclass, backbone, norm_layer=nn.BatchNorm2d, **kwargs):
110 | super(DANet, self).__init__(nclass, backbone, norm_layer=norm_layer, **kwargs)
111 | self.head = DANetHead(2048, nclass, norm_layer)
112 |
113 | def forward(self, x):
114 |
115 | _, _, c3, c4 = self.base_forward(x)
116 |
117 | x = self.head(c4)
118 | x = list(x)
119 |
120 | return x[0],x[1],x[2]
121 |
122 |
123 | class DANetHead(nn.Module):
124 | def __init__(self, in_channels, out_channels, norm_layer):
125 | super(DANetHead, self).__init__()
126 | inter_channels = in_channels // 4
127 | self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
128 | norm_layer(inter_channels),
129 | nn.ReLU())
130 |
131 | self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
132 | norm_layer(inter_channels),
133 | nn.ReLU())
134 |
135 | self.sa = PAM_Module(inter_channels)
136 | self.sc = CAM_Module(inter_channels)
137 | self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
138 | norm_layer(inter_channels),
139 | nn.ReLU())
140 | self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
141 | norm_layer(inter_channels),
142 | nn.ReLU())
143 |
144 | self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))
145 | self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))
146 |
147 | self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))
148 |
149 | def forward(self, x):
150 | feat1 = self.conv5a(x)
151 | sa_feat = self.sa(feat1)
152 | sa_conv = self.conv51(sa_feat)
153 | sa_output = self.conv6(sa_conv)
154 | feat2 = self.conv5c(x)
155 | sc_feat = self.sc(feat2)
156 | sc_conv = self.conv52(sc_feat)
157 | sc_output = self.conv7(sc_conv)
158 |
159 | feat_sum = sa_conv + sc_conv
160 |
161 | sasc_output = self.conv8(feat_sum)
162 |
163 | return sa_output,sc_output,sasc_output
164 |
165 | def cnn():
166 | model = DANet(512, backbone='resnet50')
167 | return model
168 |
169 | class SiameseNet(nn.Module):
170 | def __init__(self,norm_flag = 'l2'):
171 | super(SiameseNet, self).__init__()
172 | self.CNN = cnn()
173 | if norm_flag == 'l2':
174 | self.norm = F.normalize
175 | if norm_flag == 'exp':
176 | self.norm = nn.Softmax2d()
177 |     '''
178 | def forward(self,t0,t1):
179 | out_t0_embedding = self.CNN(t0)
180 | out_t1_embedding = self.CNN(t1)
181 | #out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5),self.norm(out_t1_conv5)
182 | #out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7),self.norm(out_t1_fc7)
183 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding),self.norm(out_t1_embedding)
184 | return [out_t0_embedding_norm,out_t1_embedding_norm]
185 |     '''
186 |
187 | def forward(self,t0,t1):
188 |
189 | out_t0_conv5,out_t0_fc7,out_t0_embedding = self.CNN(t0)
190 | out_t1_conv5,out_t1_fc7,out_t1_embedding = self.CNN(t1)
191 | out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5,2,dim=1),self.norm(out_t1_conv5,2,dim=1)
192 | out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7,2,dim=1),self.norm(out_t1_fc7,2,dim=1)
193 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding,2,dim=1),self.norm(out_t1_embedding,2,dim=1)
194 | return [out_t0_conv5_norm,out_t1_conv5_norm],[out_t0_fc7_norm,out_t1_fc7_norm],[out_t0_embedding_norm,out_t1_embedding_norm]
195 |
196 |
197 | if __name__ == '__main__':
198 | m = SiameseNet()
199 | print('gg')
--------------------------------------------------------------------------------
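PAM_Module and CAM_Module differ only in the axis along which the affinity matrix is built: a (H·W)×(H·W) spatial map versus a C×C channel map. Below is a short standalone shape check with plain torch ops; the 1×1 query/key convolutions of PAM_Module and the max-subtraction inside CAM_Module are omitted for brevity, so only the bmm shapes are illustrated.

import torch

B, C, H, W = 2, 64, 16, 16
x = torch.randn(B, C, H, W)

# spatial affinity as in PAM_Module: B x (HW) x (HW)
q = x.view(B, C, -1).permute(0, 2, 1)      # B x HW x C
k = x.view(B, C, -1)                       # B x C x HW
spatial_attn = torch.softmax(torch.bmm(q, k), dim=-1)
print(spatial_attn.shape)                  # torch.Size([2, 256, 256])

# channel affinity as in CAM_Module: B x C x C
# (CAM_Module additionally subtracts each row's max before the softmax)
q_c = x.view(B, C, -1)                     # B x C x HW
k_c = x.view(B, C, -1).permute(0, 2, 1)    # B x HW x C
channel_attn = torch.softmax(torch.bmm(q_c, k_c), dim=-1)
print(channel_attn.shape)                  # torch.Size([2, 64, 64])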
/compare/FC_Siam_conc.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class SiamUnet_conc(nn.Module):
11 | """SiamUnet_conc segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(SiamUnet_conc, self).__init__()
15 |
16 | self.input_nbr = input_nbr
17 |
18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 | self.do11 = nn.Dropout2d(p=0.2)
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.2)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.2)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.2)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.2)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.2)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.2)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.2)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.2)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.2)
51 |
52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
53 |
54 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1)
55 | self.bn43d = nn.BatchNorm2d(128)
56 | self.do43d = nn.Dropout2d(p=0.2)
57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
58 | self.bn42d = nn.BatchNorm2d(128)
59 | self.do42d = nn.Dropout2d(p=0.2)
60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
61 | self.bn41d = nn.BatchNorm2d(64)
62 | self.do41d = nn.Dropout2d(p=0.2)
63 |
64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
65 |
66 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1)
67 | self.bn33d = nn.BatchNorm2d(64)
68 | self.do33d = nn.Dropout2d(p=0.2)
69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
70 | self.bn32d = nn.BatchNorm2d(64)
71 | self.do32d = nn.Dropout2d(p=0.2)
72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(32)
74 | self.do31d = nn.Dropout2d(p=0.2)
75 |
76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
77 |
78 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1)
79 | self.bn22d = nn.BatchNorm2d(32)
80 | self.do22d = nn.Dropout2d(p=0.2)
81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
82 | self.bn21d = nn.BatchNorm2d(16)
83 | self.do21d = nn.Dropout2d(p=0.2)
84 |
85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
86 |
87 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1)
88 | self.bn12d = nn.BatchNorm2d(16)
89 | self.do12d = nn.Dropout2d(p=0.2)
90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
91 |
92 | self.sm = nn.LogSoftmax(dim=1)
93 |
94 | def forward(self, x1, x2):
95 |
96 | """Forward method."""
97 | # Stage 1
98 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
99 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
100 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
101 |
102 |
103 | # Stage 2
104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
107 |
108 | # Stage 3
109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
113 |
114 | # Stage 4
115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
119 |
120 |
121 | ####################################################
122 | # Stage 1
123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
126 |
127 | # Stage 2
128 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
129 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
130 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
131 |
132 | # Stage 3
133 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
134 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
135 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
136 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
137 |
138 | # Stage 4
139 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
140 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
141 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
142 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
143 |
144 |
145 | ####################################################
146 | # Stage 4d
147 | x4d = self.upconv4(x4p)
148 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
149 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
150 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
151 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
152 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
153 |
154 | # Stage 3d
155 | x3d = self.upconv3(x41d)
156 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
157 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
158 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
159 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
160 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
161 |
162 | # Stage 2d
163 | x2d = self.upconv2(x31d)
164 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
165 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
166 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
167 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
168 |
169 | # Stage 1d
170 | x1d = self.upconv1(x21d)
171 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
172 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
173 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
174 | x11d = self.conv11d(x12d)
175 |
176 | #Softmax layer is embedded in the loss layer
177 | #out = self.sm(x11d)
178 | output = []
179 | output.append(x11d)
180 | output = output[-1]
181 |
182 | return output
183 |
--------------------------------------------------------------------------------
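In the concatenation variant each decoder stage receives the upsampled decoder feature plus the skip features from both encoder streams, which explains the decoder input widths above (384, 192, 96, 48). A small sanity-check sketch, assuming the module can be imported from this file:

import torch
# from compare.FC_Siam_conc import SiamUnet_conc  # the class defined above

# stage 4d: 128 (upsampled) + 128 (t1 skip) + 128 (t2 skip) = 384 -> conv43d
# stage 3d:  64             +  64           +  64           = 192 -> conv33d
# stage 2d:  32             +  32           +  32           =  96 -> conv22d
# stage 1d:  16             +  16           +  16           =  48 -> conv12d

# net = SiamUnet_conc(input_nbr=3, label_nbr=2)
# out = net(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
# print(out.shape)  # expected: torch.Size([1, 2, 256, 256])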
/compare/FC_Siam_diff.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 |
10 | class SiamUnet_diff(nn.Module):
11 | """SiamUnet_diff segmentation network."""
12 |
13 | def __init__(self, input_nbr, label_nbr):
14 | super(SiamUnet_diff, self).__init__()
15 |
16 | self.input_nbr = input_nbr
17 |
18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 | self.bn11 = nn.BatchNorm2d(16)
20 |         self.do11 = nn.Dropout2d(p=0.01)  # original value: p=0.2
21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 | self.bn12 = nn.BatchNorm2d(16)
23 | self.do12 = nn.Dropout2d(p=0.01)
24 |
25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 | self.bn21 = nn.BatchNorm2d(32)
27 | self.do21 = nn.Dropout2d(p=0.01)
28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 | self.bn22 = nn.BatchNorm2d(32)
30 | self.do22 = nn.Dropout2d(p=0.01)
31 |
32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 | self.bn31 = nn.BatchNorm2d(64)
34 | self.do31 = nn.Dropout2d(p=0.01)
35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 | self.bn32 = nn.BatchNorm2d(64)
37 | self.do32 = nn.Dropout2d(p=0.01)
38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 | self.bn33 = nn.BatchNorm2d(64)
40 | self.do33 = nn.Dropout2d(p=0.01)
41 |
42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 | self.bn41 = nn.BatchNorm2d(128)
44 | self.do41 = nn.Dropout2d(p=0.01)
45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 | self.bn42 = nn.BatchNorm2d(128)
47 | self.do42 = nn.Dropout2d(p=0.01)
48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 | self.bn43 = nn.BatchNorm2d(128)
50 | self.do43 = nn.Dropout2d(p=0.01)
51 |
52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
53 |
54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
55 | self.bn43d = nn.BatchNorm2d(128)
56 | self.do43d = nn.Dropout2d(p=0.01)
57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
58 | self.bn42d = nn.BatchNorm2d(128)
59 | self.do42d = nn.Dropout2d(p=0.01)
60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
61 | self.bn41d = nn.BatchNorm2d(64)
62 | self.do41d = nn.Dropout2d(p=0.01)
63 |
64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
65 |
66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
67 | self.bn33d = nn.BatchNorm2d(64)
68 | self.do33d = nn.Dropout2d(p=0.01)
69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
70 | self.bn32d = nn.BatchNorm2d(64)
71 | self.do32d = nn.Dropout2d(p=0.01)
72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
73 | self.bn31d = nn.BatchNorm2d(32)
74 | self.do31d = nn.Dropout2d(p=0.01)
75 |
76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
77 |
78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
79 | self.bn22d = nn.BatchNorm2d(32)
80 | self.do22d = nn.Dropout2d(p=0.01)
81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
82 | self.bn21d = nn.BatchNorm2d(16)
83 | self.do21d = nn.Dropout2d(p=0.01)
84 |
85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
86 |
87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
88 | self.bn12d = nn.BatchNorm2d(16)
89 | self.do12d = nn.Dropout2d(p=0.01)
90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
91 |
92 | self.sm = nn.LogSoftmax(dim=1)
93 |
94 | def forward(self, x1, x2):
95 |
96 |         LeakyReLU = nn.LeakyReLU(negative_slope=5e-2)  # defined but not used; F.relu is applied below
97 | """Forward method."""
98 | # Stage 1
99 |         x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))  # original activation: F.relu
100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
102 |
103 |
104 | # Stage 2
105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
108 |
109 | # Stage 3
110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
114 |
115 | # Stage 4
116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
120 |
121 | ####################################################
122 | # Stage 1
123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
126 |
127 |
128 | # Stage 2
129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
132 |
133 | # Stage 3
134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
138 |
139 | # Stage 4
140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
144 |
145 |
146 |
147 | # Stage 4d
148 | x4d = self.upconv4(x4p)
149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
150 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1)
151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
154 |
155 | # Stage 3d
156 | x3d = self.upconv3(x41d)
157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
158 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1)
159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
162 |
163 | # Stage 2d
164 | x2d = self.upconv2(x31d)
165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
166 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1)
167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
169 |
170 | # Stage 1d
171 | x1d = self.upconv1(x21d)
172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
173 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1)
174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
175 | x11d = self.conv11d(x12d)
176 | #out = self.sm(x11d)
177 |
178 | output = []
179 | output.append(x11d)
180 | output = output[-1]
181 |
182 | return output
--------------------------------------------------------------------------------
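The difference variant replaces the two concatenated skips with their absolute difference, so each decoder stage keeps a single copy of the encoder width (256, 128, 64, 32 instead of 384, 192, 96, 48). A sanity-check sketch, again assuming the module is importable from this file:

import torch
# from compare.FC_Siam_diff import SiamUnet_diff  # the class defined above

# net = SiamUnet_diff(input_nbr=3, label_nbr=2)
# out = net(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
# print(out.shape)  # expected: torch.Size([1, 2, 256, 256])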
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 |
5 | from PIL import Image, ImageOps, ImageFilter
6 | import torchvision.transforms as transforms
7 |
8 | class Normalize(object):
9 | """Normalize a tensor image with mean and standard deviation.
10 | Args:
11 | mean (tuple): means for each channel.
12 | std (tuple): standard deviations for each channel.
13 | """
14 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
15 | self.mean = mean
16 | self.std = std
17 |
18 | def __call__(self, sample):
19 | img = sample['image']
20 | mask = sample['label']
21 | img = np.array(img).astype(np.float32)
22 | mask = np.array(mask).astype(np.float32)
23 | img /= 255.0
24 | img -= self.mean
25 | img /= self.std
26 |
27 | return {'image': img,
28 | 'label': mask}
29 |
30 |
31 | class ToTensor(object):
32 | """Convert ndarrays in sample to Tensors."""
33 |
34 | def __call__(self, sample):
35 | # swap color axis because
36 | # numpy image: H x W x C
37 | # torch image: C X H X W
38 | img1 = sample['image'][0]
39 | img2 = sample['image'][1]
40 | mask = sample['label']
41 | img1 = np.array(img1).astype(np.float32).transpose((2, 0, 1))
42 | img2 = np.array(img2).astype(np.float32).transpose((2, 0, 1))
43 | mask = np.array(mask).astype(np.float32) / 255.0
44 |
45 | img1 = torch.from_numpy(img1).float()
46 | img2 = torch.from_numpy(img2).float()
47 | mask = torch.from_numpy(mask).float()
48 |
49 | return {'image': (img1, img2),
50 | 'label': mask}
51 |
52 |
53 | class RandomHorizontalFlip(object):
54 | def __call__(self, sample):
55 | img1 = sample['image'][0]
56 | img2 = sample['image'][1]
57 | mask = sample['label']
58 | if random.random() < 0.5:
59 | img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
60 | img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
61 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
62 |
63 | return {'image': (img1, img2),
64 | 'label': mask}
65 |
66 | class RandomVerticalFlip(object):
67 | def __call__(self, sample):
68 | img1 = sample['image'][0]
69 | img2 = sample['image'][1]
70 | mask = sample['label']
71 | if random.random() < 0.5:
72 | img1 = img1.transpose(Image.FLIP_TOP_BOTTOM)
73 | img2 = img2.transpose(Image.FLIP_TOP_BOTTOM)
74 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
75 |
76 | return {'image': (img1, img2),
77 | 'label': mask}
78 |
79 | class RandomFixRotate(object):
80 | def __init__(self):
81 | self.degree = [Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270]
82 |
83 | def __call__(self, sample):
84 | img1 = sample['image'][0]
85 | img2 = sample['image'][1]
86 | mask = sample['label']
87 | if random.random() < 0.75:
88 | rotate_degree = random.choice(self.degree)
89 | img1 = img1.transpose(rotate_degree)
90 | img2 = img2.transpose(rotate_degree)
91 | mask = mask.transpose(rotate_degree)
92 |
93 | return {'image': (img1, img2),
94 | 'label': mask}
95 |
96 |
97 | class RandomRotate(object):
98 | def __init__(self, degree):
99 | self.degree = degree
100 |
101 | def __call__(self, sample):
102 | img1 = sample['image'][0]
103 | img2 = sample['image'][1]
104 | mask = sample['label']
105 | rotate_degree = random.uniform(-1*self.degree, self.degree)
106 | img1 = img1.rotate(rotate_degree, Image.BILINEAR)
107 | img2 = img2.rotate(rotate_degree, Image.BILINEAR)
108 | mask = mask.rotate(rotate_degree, Image.NEAREST)
109 |
110 | return {'image': (img1, img2),
111 | 'label': mask}
112 |
113 |
114 | class RandomGaussianBlur(object):
115 | def __call__(self, sample):
116 | img1 = sample['image'][0]
117 | img2 = sample['image'][1]
118 | mask = sample['label']
119 | if random.random() < 0.5:
120 | img1 = img1.filter(ImageFilter.GaussianBlur(
121 | radius=random.random()))
122 | img2 = img2.filter(ImageFilter.GaussianBlur(
123 | radius=random.random()))
124 |
125 | return {'image': (img1, img2),
126 | 'label': mask}
127 |
128 |
129 | class RandomScaleCrop(object):
130 | def __init__(self, base_size, crop_size, fill=0):
131 | self.base_size = base_size
132 | self.crop_size = crop_size
133 | self.fill = fill
134 |
135 | def __call__(self, sample):
136 | img = sample['image']
137 | mask = sample['label']
138 | # random scale (short edge)
139 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
140 | w, h = img.size
141 | if h > w:
142 | ow = short_size
143 | oh = int(1.0 * h * ow / w)
144 | else:
145 | oh = short_size
146 | ow = int(1.0 * w * oh / h)
147 | img = img.resize((ow, oh), Image.BILINEAR)
148 | mask = mask.resize((ow, oh), Image.NEAREST)
149 | # pad crop
150 | if short_size < self.crop_size:
151 | padh = self.crop_size - oh if oh < self.crop_size else 0
152 | padw = self.crop_size - ow if ow < self.crop_size else 0
153 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
154 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
155 | # random crop crop_size
156 | w, h = img.size
157 | x1 = random.randint(0, w - self.crop_size)
158 | y1 = random.randint(0, h - self.crop_size)
159 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
160 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
161 |
162 | return {'image': img,
163 | 'label': mask}
164 |
165 |
166 | class FixScaleCrop(object):
167 | def __init__(self, crop_size):
168 | self.crop_size = crop_size
169 |
170 | def __call__(self, sample):
171 | img = sample['image']
172 | mask = sample['label']
173 | w, h = img.size
174 | if w > h:
175 | oh = self.crop_size
176 | ow = int(1.0 * w * oh / h)
177 | else:
178 | ow = self.crop_size
179 | oh = int(1.0 * h * ow / w)
180 | img = img.resize((ow, oh), Image.BILINEAR)
181 | mask = mask.resize((ow, oh), Image.NEAREST)
182 | # center crop
183 | w, h = img.size
184 | x1 = int(round((w - self.crop_size) / 2.))
185 | y1 = int(round((h - self.crop_size) / 2.))
186 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
187 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
188 |
189 | return {'image': img,
190 | 'label': mask}
191 |
192 | class FixedResize(object):
193 | def __init__(self, size):
194 | self.size = (size, size) # size: (h, w)
195 |
196 | def __call__(self, sample):
197 | img1 = sample['image'][0]
198 | img2 = sample['image'][1]
199 | mask = sample['label']
200 |
201 | assert img1.size == mask.size and img2.size == mask.size
202 |
203 | img1 = img1.resize(self.size, Image.BILINEAR)
204 | img2 = img2.resize(self.size, Image.BILINEAR)
205 | mask = mask.resize(self.size, Image.NEAREST)
206 |
207 | return {'image': (img1, img2),
208 | 'label': mask}
209 |
210 |
211 | '''
212 | We do not apply Normalize here because it has a negative effect on the results.
213 | The ground-truth mask is converted to [0, 1] in the ToTensor() transform.
214 | '''
215 | train_transforms = transforms.Compose([
216 | RandomHorizontalFlip(),
217 | RandomVerticalFlip(),
218 | RandomFixRotate(),
219 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
220 | # RandomGaussianBlur(),
221 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
222 | ToTensor()])
223 |
224 | test_transforms = transforms.Compose([
225 | # RandomHorizontalFlip(),
226 | # RandomVerticalFlip(),
227 | # RandomFixRotate(),
228 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
229 | # RandomGaussianBlur(),
230 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
231 | ToTensor()])
--------------------------------------------------------------------------------
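The composed pipelines above operate on a dict holding a pair of PIL images and a PIL mask; only ToTensor() converts to torch tensors and scales the mask to [0, 1]. A minimal usage sketch with synthetic inputs (the 256×256 size is an arbitrary choice):

from PIL import Image
# from utils.transforms import train_transforms  # composed above

img_a = Image.new('RGB', (256, 256))
img_b = Image.new('RGB', (256, 256))
mask = Image.new('L', (256, 256))

sample = {'image': (img_a, img_b), 'label': mask}
# out = train_transforms(sample)
# out['image'][0].shape  -> torch.Size([3, 256, 256])
# out['label']           -> float tensor with values in [0, 1]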
/compare/TFI_GR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .resnet_tfi import resnet18
5 |
6 |
7 | class TemporalFeatureInteractionModule(nn.Module):
8 | def __init__(self, in_d, out_d):
9 | super(TemporalFeatureInteractionModule, self).__init__()
10 | self.in_d = in_d
11 | self.out_d = out_d
12 | self.conv_sub = nn.Sequential(
13 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
14 | nn.BatchNorm2d(self.in_d),
15 | nn.ReLU(inplace=True)
16 | )
17 | self.conv_diff_enh1 = nn.Sequential(
18 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
19 | nn.BatchNorm2d(self.in_d),
20 | nn.ReLU(inplace=True)
21 | )
22 | self.conv_diff_enh2 = nn.Sequential(
23 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
24 | nn.BatchNorm2d(self.in_d),
25 | nn.ReLU(inplace=True)
26 | )
27 | self.conv_cat = nn.Sequential(
28 | nn.Conv2d(self.in_d * 2, self.in_d, kernel_size=3, stride=1, padding=1),
29 | nn.BatchNorm2d(self.in_d),
30 | nn.ReLU(inplace=True)
31 | )
32 | self.conv_dr = nn.Sequential(
33 | nn.Conv2d(self.in_d, self.out_d, kernel_size=1, bias=True),
34 | nn.BatchNorm2d(self.out_d),
35 | nn.ReLU(inplace=True)
36 | )
37 |
38 | def forward(self, x1, x2):
39 | # difference enhance
40 | x_sub = self.conv_sub(torch.abs(x1 - x2))
41 | x1 = self.conv_diff_enh1(x1.mul(x_sub) + x1)
42 | x2 = self.conv_diff_enh2(x2.mul(x_sub) + x2)
43 | # fusion
44 | x_f = torch.cat([x1, x2], dim=1)
45 | x_f = self.conv_cat(x_f)
46 | x = x_sub + x_f
47 | x = self.conv_dr(x)
48 | return x
49 |
50 |
51 | class ChannelAttention(nn.Module):
52 | def __init__(self, in_planes, ratio=16):
53 | super(ChannelAttention, self).__init__()
54 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
55 | self.max_pool = nn.AdaptiveMaxPool2d(1)
56 |
57 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
58 | self.relu1 = nn.ReLU()
59 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
60 | self.sigmoid = nn.Sigmoid()
61 |
62 | def forward(self, x):
63 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
64 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
65 | out = avg_out + max_out
66 | return self.sigmoid(out)
67 |
68 |
69 | class ChangeInformationExtractionModule(nn.Module):
70 | def __init__(self, in_d, out_d):
71 | super(ChangeInformationExtractionModule, self).__init__()
72 | self.in_d = in_d
73 | self.out_d = out_d
74 | self.ca = ChannelAttention(self.in_d * 4, ratio=16)
75 | self.conv_dr = nn.Sequential(
76 | nn.Conv2d(self.in_d * 4, self.in_d, kernel_size=3, stride=1, padding=1, bias=False),
77 | nn.BatchNorm2d(self.in_d),
78 | nn.ReLU(inplace=True)
79 | )
80 | self.pools_sizes = [2, 4, 8]
81 | self.conv_pool1 = nn.Sequential(
82 | nn.AvgPool2d(kernel_size=self.pools_sizes[0], stride=self.pools_sizes[0]),
83 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False)
84 | )
85 | self.conv_pool2 = nn.Sequential(
86 | nn.AvgPool2d(kernel_size=self.pools_sizes[1], stride=self.pools_sizes[1]),
87 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False)
88 | )
89 | self.conv_pool3 = nn.Sequential(
90 | nn.AvgPool2d(kernel_size=self.pools_sizes[2], stride=self.pools_sizes[2]),
91 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False)
92 | )
93 |
94 | def forward(self, d5, d4, d3, d2):
95 | # upsampling
96 | d5 = F.interpolate(d5, d2.size()[2:], mode='bilinear', align_corners=True)
97 | d4 = F.interpolate(d4, d2.size()[2:], mode='bilinear', align_corners=True)
98 | d3 = F.interpolate(d3, d2.size()[2:], mode='bilinear', align_corners=True)
99 | # fusion
100 | x = torch.cat([d5, d4, d3, d2], dim=1)
101 | x_ca = self.ca(x)
102 | x = x * x_ca
103 | x = self.conv_dr(x)
104 |
105 | # feature = x[0:1, 0:64, 0:64, 0:64]
106 | # vis.visulize_features(feature)
107 |
108 | # pooling
109 | d2 = x
110 | d3 = self.conv_pool1(x)
111 | d4 = self.conv_pool2(x)
112 | d5 = self.conv_pool3(x)
113 |
114 | return d5, d4, d3, d2
115 |
116 |
117 | class GuidedRefinementModule(nn.Module):
118 | def __init__(self, in_d, out_d):
119 | super(GuidedRefinementModule, self).__init__()
120 | self.in_d = in_d
121 | self.out_d = out_d
122 | self.conv_d5 = nn.Sequential(
123 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
124 | nn.BatchNorm2d(self.out_d),
125 | nn.ReLU(inplace=True)
126 | )
127 | self.conv_d4 = nn.Sequential(
128 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
129 | nn.BatchNorm2d(self.out_d),
130 | nn.ReLU(inplace=True)
131 | )
132 | self.conv_d3 = nn.Sequential(
133 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
134 | nn.BatchNorm2d(self.out_d),
135 | nn.ReLU(inplace=True)
136 | )
137 | self.conv_d2 = nn.Sequential(
138 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
139 | nn.BatchNorm2d(self.out_d),
140 | nn.ReLU(inplace=True)
141 | )
142 |
143 | def forward(self, d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p):
144 | # feature refinement
145 | d5 = self.conv_d5(d5_p + d5)
146 | d4 = self.conv_d4(d4_p + d4)
147 | d3 = self.conv_d3(d3_p + d3)
148 | d2 = self.conv_d2(d2_p + d2)
149 |
150 | return d5, d4, d3, d2
151 |
152 |
153 | class Decoder(nn.Module):
154 | def __init__(self, in_d, out_d):
155 | super(Decoder, self).__init__()
156 | self.in_d = in_d
157 | self.out_d = out_d
158 | self.conv_sum1 = nn.Sequential(
159 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
160 | nn.BatchNorm2d(self.in_d),
161 | nn.ReLU(inplace=True)
162 | )
163 | self.conv_sum2 = nn.Sequential(
164 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
165 | nn.BatchNorm2d(self.in_d),
166 | nn.ReLU(inplace=True)
167 | )
168 | self.conv_sum3 = nn.Sequential(
169 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
170 | nn.BatchNorm2d(self.in_d),
171 | nn.ReLU(inplace=True)
172 | )
173 | self.cls = nn.Conv2d(self.in_d, self.out_d, kernel_size=1, bias=False)
174 |
175 | def forward(self, d5, d4, d3, d2):
176 |
177 | d5 = F.interpolate(d5, d4.size()[2:], mode='bilinear', align_corners=True)
178 |         d4 = self.conv_sum1(d4 + d5)
179 |         d4 = F.interpolate(d4, d3.size()[2:], mode='bilinear', align_corners=True)
180 |         d3 = self.conv_sum2(d3 + d4)
181 |         d3 = F.interpolate(d3, d2.size()[2:], mode='bilinear', align_corners=True)
182 |         d2 = self.conv_sum3(d2 + d3)
183 |
184 | mask = self.cls(d2)
185 |
186 | return mask
187 |
188 |
189 | class TFI_GR(nn.Module):
190 | def __init__(self, input_nc, output_nc):
191 | super(TFI_GR, self).__init__()
192 | self.backbone = resnet18(pretrained=True)
193 | self.mid_d = 64
194 | self.TFIM5 = TemporalFeatureInteractionModule(512, self.mid_d)
195 | self.TFIM4 = TemporalFeatureInteractionModule(256, self.mid_d)
196 | self.TFIM3 = TemporalFeatureInteractionModule(128, self.mid_d)
197 | self.TFIM2 = TemporalFeatureInteractionModule(64, self.mid_d)
198 |
199 | self.CIEM1 = ChangeInformationExtractionModule(self.mid_d, output_nc)
200 | self.GRM1 = GuidedRefinementModule(self.mid_d, self.mid_d)
201 |
202 | self.CIEM2 = ChangeInformationExtractionModule(self.mid_d, output_nc)
203 | self.GRM2 = GuidedRefinementModule(self.mid_d, self.mid_d)
204 |
205 | self.decoder = Decoder(self.mid_d, output_nc)
206 |
207 | def forward(self, x1, x2):
208 | # forward backbone resnet
209 | x1_1, x1_2, x1_3, x1_4, x1_5 = self.backbone.base_forward(x1)
210 | x2_1, x2_2, x2_3, x2_4, x2_5 = self.backbone.base_forward(x2)
211 | # feature difference
212 | d5 = self.TFIM5(x1_5, x2_5) # 1/32
213 | d4 = self.TFIM4(x1_4, x2_4) # 1/16
214 | d3 = self.TFIM3(x1_3, x2_3) # 1/8
215 | d2 = self.TFIM2(x1_2, x2_2) # 1/4
216 |
217 | # change information guided refinement 1
218 | d5_p, d4_p, d3_p, d2_p = self.CIEM1(d5, d4, d3, d2)
219 | d5, d4, d3, d2 = self.GRM1(d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p)
220 |
221 | # change information guided refinement 2
222 | d5_p, d4_p, d3_p, d2_p = self.CIEM2(d5, d4, d3, d2)
223 | d5, d4, d3, d2 = self.GRM2(d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p)
224 |
225 | # decoder
226 | mask = self.decoder(d5, d4, d3, d2)
227 | mask = F.interpolate(mask, x1.size()[2:], mode='bilinear', align_corners=True)
228 | # mask = torch.sigmoid(mask)
229 |
230 | return mask
231 |
--------------------------------------------------------------------------------
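TFI_GR reduces the ResNet-18 stages (1/4 to 1/32 resolution) to a common width of 64, applies two rounds of change-information extraction and guided refinement, and upsamples the decoded mask back to the input size. A forward-pass sketch, assuming the file is importable as a module:

import torch
# from compare.TFI_GR import TFI_GR  # the network defined above

# net = TFI_GR(input_nc=3, output_nc=2)
# mask = net(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
# print(mask.shape)  # expected: torch.Size([1, 2, 256, 256])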
/compare/resnet_tfi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from torchvision.models.utils import load_state_dict_from_url
4 | from torch.hub import load_state_dict_from_url
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
24 | padding=dilation, groups=groups, bias=False, dilation=dilation)
25 |
26 |
27 | def conv1x1(in_planes, out_planes, stride=1):
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
29 |
30 |
31 | class BasicBlock(nn.Module):
32 | expansion = 1
33 |
34 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
35 | base_width=64, dilation=1, norm_layer=None):
36 | super(BasicBlock, self).__init__()
37 | if norm_layer is None:
38 | norm_layer = nn.BatchNorm2d
39 | if groups != 1 or base_width != 64:
40 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
41 |
42 | self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
43 | self.bn1 = norm_layer(planes)
44 | self.relu = nn.ReLU(inplace=True)
45 | self.conv2 = conv3x3(planes, planes)
46 | self.bn2 = norm_layer(planes)
47 | self.downsample = downsample
48 | self.stride = stride
49 |
50 | def forward(self, x):
51 | identity = x
52 |
53 | out = self.conv1(x)
54 | out = self.bn1(out)
55 | out = self.relu(out)
56 |
57 | out = self.conv2(out)
58 | out = self.bn2(out)
59 |
60 | if self.downsample is not None:
61 | identity = self.downsample(x)
62 |
63 | out += identity
64 | out = self.relu(out)
65 |
66 | return out
67 |
68 |
69 | class Bottleneck(nn.Module):
70 | expansion = 4
71 |
72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
73 | base_width=64, dilation=1, norm_layer=None):
74 | super(Bottleneck, self).__init__()
75 | if norm_layer is None:
76 | norm_layer = nn.BatchNorm2d
77 | width = int(planes * (base_width / 64.)) * groups
78 |
79 | self.conv1 = conv1x1(inplanes, width)
80 | self.bn1 = norm_layer(width)
81 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
82 | self.bn2 = norm_layer(width)
83 | self.conv3 = conv1x1(width, planes * self.expansion)
84 | self.bn3 = norm_layer(planes * self.expansion)
85 | self.relu = nn.ReLU(inplace=True)
86 | self.downsample = downsample
87 | self.stride = stride
88 |
89 | def forward(self, x):
90 | identity = x
91 |
92 | out = self.conv1(x)
93 | out = self.bn1(out)
94 | out = self.relu(out)
95 |
96 | out = self.conv2(out)
97 | out = self.bn2(out)
98 | out = self.relu(out)
99 |
100 | out = self.conv3(out)
101 | out = self.bn3(out)
102 |
103 | if self.downsample is not None:
104 | identity = self.downsample(x)
105 |
106 | out += identity
107 | out = self.relu(out)
108 |
109 | return out
110 |
111 |
112 | class ResNet(nn.Module):
113 |
114 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1,
115 | width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):
116 | super(ResNet, self).__init__()
117 |
118 | self.channels = [64, 64 * block.expansion, 128 * block.expansion,
119 | 256 * block.expansion, 512 * block.expansion]
120 |
121 | if norm_layer is None:
122 | norm_layer = nn.BatchNorm2d
123 | self._norm_layer = norm_layer
124 |
125 | self.inplanes = 64
126 | self.dilation = 1
127 | if replace_stride_with_dilation is None:
128 | replace_stride_with_dilation = [False, False, False]
129 | if len(replace_stride_with_dilation) != 3:
130 | raise ValueError('replace_stride_with_dilation should be None '
131 | 'or a 3-element tuple, got {}'.format(replace_stride_with_dilation))
132 | self.groups = groups
133 | self.base_width = width_per_group
134 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
135 | bias=False)
136 | self.bn1 = norm_layer(self.inplanes)
137 | self.relu = nn.ReLU(inplace=True)
138 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
139 | self.layer1 = self._make_layer(block, 64, layers[0])
140 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
141 | dilate=replace_stride_with_dilation[0])
142 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
143 | dilate=replace_stride_with_dilation[1])
144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
145 | dilate=replace_stride_with_dilation[2])
146 |
147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148 | self.fc = nn.Linear(512 * block.expansion, num_classes)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
154 | nn.init.constant_(m.weight, 1)
155 | nn.init.constant_(m.bias, 0)
156 |
157 | if zero_init_residual:
158 | for m in self.modules():
159 | if isinstance(m, Bottleneck):
160 | nn.init.constant_(m.bn3.weight, 0)
161 | elif isinstance(m, BasicBlock):
162 | nn.init.constant_(m.bn2.weight, 0)
163 |
164 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
165 | norm_layer = self._norm_layer
166 | downsample = None
167 | previous_dilation = self.dilation
168 | if dilate:
169 | self.dilation *= stride
170 | stride = 1
171 | if stride != 1 or self.inplanes != planes * block.expansion:
172 | downsample = nn.Sequential(
173 | conv1x1(self.inplanes, planes * block.expansion, stride),
174 | norm_layer(planes * block.expansion),
175 | )
176 |
177 | layers = list()
178 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
179 | self.base_width, previous_dilation, norm_layer))
180 | self.inplanes = planes * block.expansion
181 | for _ in range(1, blocks):
182 | layers.append(block(self.inplanes, planes, groups=self.groups,
183 | base_width=self.base_width, dilation=self.dilation,
184 | norm_layer=norm_layer))
185 |
186 | return nn.Sequential(*layers)
187 |
188 | def base_forward(self, x):
189 | x = self.conv1(x)
190 | x = self.bn1(x)
191 | c0 = self.relu(x)
192 | c1 = self.maxpool(c0)
193 |
194 | c1 = self.layer1(c1)
195 | c2 = self.layer2(c1)
196 | c3 = self.layer3(c2)
197 | c4 = self.layer4(c3)
198 |
199 | return c0, c1, c2, c3, c4
200 |
201 | def forward(self, x):
202 | x = self.base_forward(x)[-1]
203 | x = self.avgpool(x)
204 | x = torch.flatten(x, 1)
205 | x = self.fc(x)
206 |
207 | return x
208 |
209 |
210 | def _resnet(arch, block, layers, pretrained, **kwargs):
211 | model = ResNet(block, layers, **kwargs)
212 | if pretrained:
213 | state_dict = load_state_dict_from_url(model_urls[arch],
214 | progress=True)
215 | model.load_state_dict(state_dict)
216 | return model
217 |
218 |
219 | def resnet18(pretrained=False, **kwargs):
220 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained,
221 | replace_stride_with_dilation=[False, False, False], **kwargs)
222 |
223 |
224 | def resnet34(pretrained=False, **kwargs):
225 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained,
226 | replace_stride_with_dilation=[False, True, True], **kwargs)
227 |
228 |
229 | def resnet50(pretrained=False, **kwargs):
230 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained,
231 | replace_stride_with_dilation=[False, True, True], **kwargs)
232 |
233 |
234 | def resnet101(pretrained=False, **kwargs):
235 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
236 | replace_stride_with_dilation=[False, True, True], **kwargs)
237 |
238 |
239 | def resnet152(pretrained=False, **kwargs):
240 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
241 | replace_stride_with_dilation=[False, True, True], **kwargs)
242 |
243 |
244 | def resnext50_32x4d(pretrained=False, **kwargs):
245 | kwargs['groups'] = 32
246 | kwargs['width_per_group'] = 4
247 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
248 | replace_stride_with_dilation=[False, True, True], **kwargs)
249 |
250 |
251 | def resnext101_32x8d(pretrained=False, **kwargs):
252 | kwargs['groups'] = 32
253 | kwargs['width_per_group'] = 8
254 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
255 | replace_stride_with_dilation=[False, True, True], **kwargs)
--------------------------------------------------------------------------------
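base_forward() exposes the five feature stages that TFI_GR consumes; with the default resnet18 settings above (no stride-to-dilation replacement), the strides are 2, 4, 8, 16 and 32. A shape-check sketch, assuming the file is importable and pretrained weights are skipped:

import torch
# from compare.resnet_tfi import resnet18  # factory defined above

# net = resnet18(pretrained=False)
# for f in net.base_forward(torch.randn(1, 3, 256, 256)):
#     print(f.shape)
# expected: (1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 8, 8)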
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import datetime
3 | import torch
4 | from sklearn.metrics import precision_recall_fscore_support as prfs
5 | from utils.parser import get_parser_with_args
6 | from utils.helpers import (get_loaders, get_criterion,
7 | load_model, initialize_metrics, get_mean_metrics,
8 | set_metrics)
9 | import os
10 | import logging
11 | import json
12 | from tensorboardX import SummaryWriter
13 | from tqdm import tqdm
14 | import random
15 | import numpy as np
16 | from torch.optim import lr_scheduler
17 |
18 | def get_scheduler(optimizer, opt, lr_policy):
19 | """Return a learning rate scheduler
20 | Parameters:
21 | optimizer -- the optimizer of the network
22 |         opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
23 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
24 | For 'linear', we keep the same learning rate for the first epochs
25 | and linearly decay the rate to zero over the next epochs.
26 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
27 | See https://pytorch.org/docs/stable/optim.html for more details.
28 | """
29 | if lr_policy == 'linear':
30 | def lambda_rule(epoch):
31 | lr_l = 1.0 - epoch / float(opt.epochs + 1)
32 | return lr_l
33 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
34 | elif lr_policy == 'step':
35 | step_size = opt.epochs//3
36 | # args.lr_decay_iters
37 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
38 | else:
39 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)
40 | return scheduler
41 |
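# Illustrative values for the 'linear' policy above, assuming opt.epochs = 100:
#   lambda_rule(0)   = 1.0 - 0 / 101   = 1.0000
#   lambda_rule(50)  = 1.0 - 50 / 101  ≈ 0.5050
#   lambda_rule(100) = 1.0 - 100 / 101 ≈ 0.0099
# LambdaLR multiplies the base learning rate by this factor at each epoch,
# so the rate decays almost to zero by the final epoch.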
42 |
43 | def seed_torch(seed):
44 | random.seed(seed)
45 | os.environ['PYTHONHASHSEED'] = str(seed)
46 | np.random.seed(seed)
47 | torch.manual_seed(seed)
48 | torch.cuda.manual_seed(seed)
49 | # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
50 | torch.backends.cudnn.benchmark = False
51 | torch.backends.cudnn.deterministic = True
52 |
53 |
54 | if __name__ == '__main__':
55 | """
56 | Initialize Parser and
57 | define arguments
58 | """
59 | parser, metadata = get_parser_with_args()
60 | opt = parser.parse_args()
61 |
62 | opt.epochs = 100
63 | opt.batch_size = 16
64 | opt.loss_function = "bce"
65 |
66 | #checkpoints dir
67 |     opt.checkpoint_dir = os.path.join(opt.path, opt.project_name)
68 |     os.makedirs(opt.checkpoint_dir, exist_ok=True)
69 | # save_path = '.tmp' + '/' + opt.dataset + '_' + opt.backbone + '_' + opt.mode
70 |
71 | """
72 | Initialize experiments log
73 | """
74 | logging.basicConfig(level=logging.INFO)
75 | writer = SummaryWriter(
76 | '.tmp/log/' + opt.dataset + '_' + opt.backbone + '_' + opt.mode + '_' + f'{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')
77 |
78 | """
79 | Set up environment: define paths, download data, and set device
80 | """
81 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
82 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
83 | logging.info('GPU AVAILABLE? ' + str(torch.cuda.is_available()))
84 |
85 | seed_torch(seed=777)
86 |
87 | if opt.dataset == 'cdd':
88 | opt.dataset_dir = '../CDDataset/ChangeDetection/'
89 | elif opt.dataset == 'levir':
90 | opt.dataset_dir = '../CDDataset/LEVIR/'
91 |     elif opt.dataset == 'levir+':
92 | opt.dataset_dir = '../CDDataset/LEVIR-CD+_256'
93 |
94 | train_loader, val_loader = get_loaders(opt)
95 |
96 | """
97 | Load Model then define other aspects of the model
98 | """
99 | logging.info('LOADING Model')
100 | model = load_model(opt, dev)
101 |
102 | criterion = get_criterion(opt)
103 | if opt.backbone == 'resnet':
104 | optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.99,
105 | weight_decay=0.0005) # Be careful when you adjust learning rate, you can refer to the linear scaling rule
106 | # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
107 | scheduler = get_scheduler(optimizer, opt, 'linear')
108 | else:
109 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01)
110 | scheduler = get_scheduler(optimizer, opt, 'linear')
111 |
112 | """
113 | Set starting values
114 | """
115 | best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1}
116 | logging.info('STARTING training')
117 | total_step = -1
118 |
119 | for epoch in range(opt.epochs):
120 | train_metrics = initialize_metrics()
121 | val_metrics = initialize_metrics()
122 |
123 | """
124 | Begin Training
125 | """
126 | model.train()
127 | logging.info('SET model mode to train!')
128 | batch_iter = 0
129 | tbar = tqdm(train_loader)
130 | for batch_img1, batch_img2, labels, fname in tbar:
131 | tbar.set_description(
132 | "epoch {} info ".format(epoch) + str(batch_iter) + " - " + str(batch_iter + opt.batch_size))
133 | batch_iter = batch_iter + opt.batch_size
134 | total_step += 1
135 | # Set variables for training
136 | batch_img1 = batch_img1.float().to(dev)
137 | batch_img2 = batch_img2.float().to(dev)
138 | labels = labels.long().to(dev)
139 |
140 | # Zero the gradient
141 | optimizer.zero_grad()
142 |
143 | # Get model predictions, calculate loss, backprop
144 | cd_preds = model(batch_img1, batch_img2)
145 |
146 | cd_loss = criterion(cd_preds, labels)
147 | loss = cd_loss
148 | loss.backward()
149 | optimizer.step()
150 |
151 |             # cd_preds = cd_preds[-1]  # BIT output is not a tuple
152 | _, cd_preds = torch.max(cd_preds, 1)
153 |
154 | # Calculate and log other batch metrics
155 | cd_corrects = (100 *
156 | (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
157 | (labels.size()[0] * (opt.patch_size ** 2)))
158 |
159 | cd_train_report = prfs(labels.data.cpu().numpy().flatten(),
160 | cd_preds.data.cpu().numpy().flatten(),
161 | average='binary',
162 | pos_label=1, zero_division=0)
163 |
164 | train_metrics = set_metrics(train_metrics,
165 | cd_loss,
166 | cd_corrects,
167 | cd_train_report,
168 | scheduler.get_last_lr())
169 |
170 | # log the batch mean metrics
171 | mean_train_metrics = get_mean_metrics(train_metrics)
172 |
173 | for k, v in mean_train_metrics.items():
174 | writer.add_scalars(str(k), {'train': v}, total_step)
175 |
176 | # clear batch variables from memory
177 | del batch_img1, batch_img2, labels
178 |
179 | scheduler.step()
180 | logging.info("EPOCH {} TRAIN METRICS".format(epoch) + str(mean_train_metrics))
181 |
182 | """
183 | Begin Validation
184 | """
185 | model.eval()
186 | with torch.no_grad():
187 | for batch_img1, batch_img2, labels, fname in val_loader:
188 | # Set variables for training
189 | batch_img1 = batch_img1.float().to(dev)
190 | batch_img2 = batch_img2.float().to(dev)
191 | labels = labels.long().to(dev)
192 |
193 | # Get predictions and calculate loss
194 | cd_preds = model(batch_img1, batch_img2)
195 |
196 | cd_loss = criterion(cd_preds, labels)
197 |
198 | # cd_preds = cd_preds[-1] # BIT输出不是tuple
199 | _, cd_preds = torch.max(cd_preds, 1)
200 |
201 | # Calculate and log other batch metrics
202 | cd_corrects = (100 *
203 | (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
204 | (labels.size()[0] * (opt.patch_size ** 2)))
205 |
206 | cd_val_report = prfs(labels.data.cpu().numpy().flatten(),
207 | cd_preds.data.cpu().numpy().flatten(),
208 | average='binary',
209 | pos_label=1, zero_division=0)
210 |
211 | val_metrics = set_metrics(val_metrics,
212 | cd_loss,
213 | cd_corrects,
214 | cd_val_report,
215 | scheduler.get_last_lr())
216 |
217 | # log the batch mean metrics
218 | mean_val_metrics = get_mean_metrics(val_metrics)
219 |
220 |                 for k, v in mean_val_metrics.items():
221 | writer.add_scalars(str(k), {'val': v}, total_step)
222 |
223 | # clear batch variables from memory
224 | del batch_img1, batch_img2, labels
225 |
226 | logging.info("EPOCH {} VALIDATION METRICS".format(epoch) + str(mean_val_metrics))
227 |
228 | """
229 | Store the weights of good epochs based on validation results
230 | """
231 | if ((mean_val_metrics['cd_precisions'] > best_metrics['cd_precisions'])
232 | or
233 | (mean_val_metrics['cd_recalls'] > best_metrics['cd_recalls'])
234 | or
235 | (mean_val_metrics['cd_f1scores'] > best_metrics['cd_f1scores'])):
236 |
237 | # Insert training and epoch information to metadata dictionary
238 |             logging.info('update the model')
239 | metadata['validation_metrics'] = mean_val_metrics
240 |
241 | # Save model and log
242 |
243 |             if not os.path.exists(opt.checkpoint_dir):
244 |                 os.makedirs(opt.checkpoint_dir)
245 | # with open('./tmp/metadata_epoch_' + str(epoch) + '.json', 'w') as fout:
246 | # json.dump(metadata, fout)
247 |
248 |             torch.save(model, opt.checkpoint_dir + '/checkpoint_epoch_' + str(epoch) + '.pth')
249 |
250 | # comet.log_asset(upload_metadata_file_path)
251 | best_metrics = mean_val_metrics
252 |
253 | print('An epoch finished.')
254 | writer.close() # close tensor board
255 | print('Done!')
256 |
257 |
258 |
259 |
--------------------------------------------------------------------------------
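
As a quick illustration of the 'linear' policy that get_scheduler builds in train.py: LambdaLR multiplies the optimizer's base learning rate by lambda_rule(epoch) = 1 - epoch / (opt.epochs + 1), so the rate decays towards zero over the run. A minimal, self-contained sketch; the base lr of 1e-4, the stand-in model, and the 100-epoch horizon are illustrative values only:

import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler

model = nn.Linear(4, 2)                        # stand-in for the change-detection network
epochs = 100                                   # mirrors opt.epochs in train.py
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1.0 - e / float(epochs + 1))

for epoch in range(3):
    optimizer.step()                           # one (dummy) optimisation step per epoch
    scheduler.step()
    print(epoch, scheduler.get_last_lr())      # the lr shrinks linearly each epoch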
/compare/DMINet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .resnet import resnet18
4 | import torch.nn.functional as F
5 | import math
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import cv2
9 |
10 |
11 | def init_weights(m):
12 | """
13 | Initialize weights of layers using Kaiming Normal (He et al.) as argument of "Apply" function of
14 | "nn.Module"
15 | :param m: Layer to initialize
16 | :return: None
17 | """
18 | if isinstance(m, nn.Conv2d):
19 | '''
20 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
21 | trunc_normal_(m.weight, std=math.sqrt(1.0/fan_in)/.87962566103423978)
22 | if m.bias is not None:
23 | nn.init.zeros_(m.bias)
24 | '''
25 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
26 | if m.bias is not None:
27 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
28 | bound = 1 / math.sqrt(fan_in)
29 | nn.init.uniform_(m.bias, -bound, bound)
30 |
31 | elif isinstance(m, nn.BatchNorm2d):
32 | nn.init.constant_(m.weight, 1)
33 | nn.init.constant_(m.bias, 0)
34 |
35 |
36 | class Conv(nn.Module):
37 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, bias=True):
38 | super(Conv, self).__init__()
39 | self.inp_dim = inp_dim
40 | self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=bias)
41 | self.relu = None
42 | self.bn = None
43 | if relu:
44 | self.relu = nn.ReLU(inplace=True)
45 | if bn:
46 | self.bn = nn.BatchNorm2d(out_dim)
47 |
48 | def forward(self, x):
49 | assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
50 | # print("++",x.size()[1],self.inp_dim,x.size()[1],self.inp_dim)
51 | x = self.conv(x)
52 | if self.bn is not None:
53 | x = self.bn(x)
54 | if self.relu is not None:
55 | x = self.relu(x)
56 | return x
57 |
58 |
59 | class decode(nn.Module):
60 | def __init__(self, in_channel_left, in_channel_down, out_channel, norm_layer=nn.BatchNorm2d):
61 | super(decode, self).__init__()
62 | self.conv_d1 = nn.Conv2d(in_channel_down, out_channel, kernel_size=3, stride=1, padding=1)
63 | self.conv_l = nn.Conv2d(in_channel_left, out_channel, kernel_size=3, stride=1, padding=1)
64 | self.conv3 = nn.Conv2d(out_channel * 2, out_channel, kernel_size=3, stride=1, padding=1)
65 | self.bn3 = norm_layer(out_channel)
66 |
67 | def forward(self, left, down):
68 | down_mask = self.conv_d1(down)
69 | left_mask = self.conv_l(left)
70 | if down.size()[2:] != left.size()[2:]:
71 | down_ = F.interpolate(down, size=left.size()[2:], mode='bilinear')
72 | z1 = F.relu(left_mask * down_, inplace=True)
73 | else:
74 | z1 = F.relu(left_mask * down, inplace=True)
75 |
76 | if down_mask.size()[2:] != left.size()[2:]:
77 | down_mask = F.interpolate(down_mask, size=left.size()[2:], mode='bilinear')
78 |
79 | z2 = F.relu(down_mask * left, inplace=True)
80 |
81 | out = torch.cat((z1, z2), dim=1)
82 | return F.relu(self.bn3(self.conv3(out)), inplace=True)
83 |
84 |
85 | class BasicConv2d(nn.Module):
86 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
87 | super(BasicConv2d, self).__init__()
88 |
89 | self.conv = nn.Conv2d(in_planes, out_planes,
90 | kernel_size=kernel_size, stride=stride,
91 | padding=padding, dilation=dilation, bias=False)
92 | self.bn = nn.BatchNorm2d(out_planes)
93 | self.relu = nn.ReLU(inplace=True)
94 |
95 | def forward(self, x):
96 | x = self.conv(x)
97 | x = self.bn(x)
98 | return x
99 |
100 |
101 | class CrossAtt(nn.Module):
102 | def __init__(self, in_channels, out_channels):
103 | super().__init__()
104 | self.in_channels = in_channels
105 |
106 | self.query1 = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1, stride=1)
107 | self.key1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, stride=1)
108 | self.value1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
109 |
110 | self.query2 = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1, stride=1)
111 | self.key2 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, stride=1)
112 | self.value2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
113 |
114 | self.gamma = nn.Parameter(torch.zeros(1))
115 | self.softmax = nn.Softmax(dim=-1)
116 |
117 | self.conv_cat = nn.Sequential(nn.Conv2d(in_channels * 2, out_channels, 3, padding=1, bias=False),
118 | nn.BatchNorm2d(out_channels),
119 | nn.ReLU()) # conv_f
120 |
121 | def forward(self, input1, input2):
122 | batch_size, channels, height, width = input1.shape
123 | q1 = self.query1(input1)
124 | k1 = self.key1(input1).view(batch_size, -1, height * width)
125 | v1 = self.value1(input1).view(batch_size, -1, height * width)
126 |
127 | q2 = self.query2(input2)
128 | k2 = self.key2(input2).view(batch_size, -1, height * width)
129 | v2 = self.value2(input2).view(batch_size, -1, height * width)
130 |
131 | q = torch.cat([q1, q2], 1).view(batch_size, -1, height * width).permute(0, 2, 1)
132 | attn_matrix1 = torch.bmm(q, k1)
133 | attn_matrix1 = self.softmax(attn_matrix1)
134 | out1 = torch.bmm(v1, attn_matrix1.permute(0, 2, 1))
135 | out1 = out1.view(*input1.shape)
136 | out1 = self.gamma * out1 + input1
137 |
138 | attn_matrix2 = torch.bmm(q, k2)
139 | attn_matrix2 = self.softmax(attn_matrix2)
140 | out2 = torch.bmm(v2, attn_matrix2.permute(0, 2, 1))
141 | out2 = out2.view(*input2.shape)
142 | out2 = self.gamma * out2 + input2
143 |
144 | feat_sum = self.conv_cat(torch.cat([out1, out2], 1))
145 | return feat_sum, out1, out2
146 |
147 | #
148 | # def draw_features(width, height, x, savename):
149 | # # tic=time.time()
150 | # fig = plt.figure(figsize=(60, 60))
151 | # fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)
152 | # for i in range(width * height):
153 | # plt.subplot(height, width, i + 1)
154 | # plt.axis('off')
155 | # img = x[0, i, :, :]
156 | # pmin = np.min(img)
157 | # pmax = np.max(img)
158 | # img = ((img - pmin) / (pmax - pmin + 0.000001)) * 255  # scale floats in [0, 1] to 0-255
159 | # img = img.astype(np.uint8)  # convert to uint8
160 | # img = cv2.applyColorMap(img, cv2.COLORMAP_JET)  # generate a heat map
161 | # img = img[:, :, ::-1]  # cv2 (BGR) and matplotlib (RGB) channel orders are reversed
162 | # plt.imshow(img)
163 | # fig.savefig(savename, dpi=100)
164 | # fig.clf()
165 | # plt.close()
166 | # # print("time:{}".format(time.time()-tic))
167 | #
168 |
169 | class DMINet(nn.Module):
170 | def __init__(self, num_classes=2, drop_rate=0.2, normal_init=True, pretrained=False, show_Feature_Maps=False):
171 | super(DMINet, self).__init__()
172 |
173 | self.show_Feature_Maps = show_Feature_Maps
174 |
175 | self.resnet = resnet18()
176 | self.resnet.load_state_dict(torch.load('./pretrain_model/resnet18-5c106cde.pth'))
177 | self.resnet.layer4 = nn.Identity()
178 |
179 | self.cross2 = CrossAtt(256, 256)
180 | self.cross3 = CrossAtt(128, 128)
181 | self.cross4 = CrossAtt(64, 64)
182 |
183 | self.Translayer2_1 = BasicConv2d(256, 128, 1)
184 | self.fam32_1 = decode(128, 128, 128) # AlignBlock(128) # decode(128,128,128)
185 | self.Translayer3_1 = BasicConv2d(128, 64, 1)
186 | self.fam43_1 = decode(64, 64, 64) # AlignBlock(64) # decode(64,64,64)
187 |
188 | self.Translayer2_2 = BasicConv2d(256, 128, 1)
189 | self.fam32_2 = decode(128, 128, 128)
190 | self.Translayer3_2 = BasicConv2d(128, 64, 1)
191 | self.fam43_2 = decode(64, 64, 64)
192 |
193 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear')
194 | self.upsamplex8 = nn.Upsample(scale_factor=8, mode='bilinear')
195 |
196 | self.final = nn.Sequential(
197 | Conv(64, 32, 3, bn=True, relu=True),
198 | Conv(32, num_classes, 3, bn=False, relu=False)
199 | )
200 | self.final2 = nn.Sequential(
201 | Conv(64, 32, 3, bn=True, relu=True),
202 | Conv(32, num_classes, 3, bn=False, relu=False)
203 | )
204 |
205 | self.final_2 = nn.Sequential(
206 | Conv(128, 32, 3, bn=True, relu=True),
207 | Conv(32, num_classes, 3, bn=False, relu=False)
208 | )
209 | self.final2_2 = nn.Sequential(
210 | Conv(128, 32, 3, bn=True, relu=True),
211 | Conv(32, num_classes, 3, bn=False, relu=False)
212 | )
213 | if normal_init:
214 | self.init_weights()
215 |
216 | def forward(self, imgs1, imgs2, labels=None):
217 |
218 | c0 = self.resnet.conv1(imgs1)
219 | c0 = self.resnet.bn1(c0)
220 | c0 = self.resnet.relu(c0)
221 | c1 = self.resnet.maxpool(c0)
222 | c1 = self.resnet.layer1(c1)
223 | c2 = self.resnet.layer2(c1)
224 | c3 = self.resnet.layer3(c2)
225 |
226 | c0_img2 = self.resnet.conv1(imgs2)
227 | c0_img2 = self.resnet.bn1(c0_img2)
228 | c0_img2 = self.resnet.relu(c0_img2)
229 | c1_img2 = self.resnet.maxpool(c0_img2)
230 | c1_img2 = self.resnet.layer1(c1_img2)
231 | c2_img2 = self.resnet.layer2(c1_img2)
232 | c3_img2 = self.resnet.layer3(c2_img2)
233 |
234 | cross_result2, cur1_2, cur2_2 = self.cross2(c3, c3_img2)
235 | cross_result3, cur1_3, cur2_3 = self.cross3(c2, c2_img2)
236 | cross_result4, cur1_4, cur2_4 = self.cross4(c1, c1_img2)
237 |
238 | out3 = self.fam32_1(cross_result3, self.Translayer2_1(cross_result2))
239 | out4 = self.fam43_1(cross_result4, self.Translayer3_1(out3))
240 |
241 | out3_2 = self.fam32_2(torch.abs(cur1_3 - cur2_3), self.Translayer2_2(torch.abs(cur1_2 - cur2_2)))
242 | out4_2 = self.fam43_2(torch.abs(cur1_4 - cur2_4), self.Translayer3_2(out3_2))
243 |
244 | out4_up = self.upsamplex4(out4)
245 | out4_2_up = self.upsamplex4(out4_2)
246 | out_1 = self.final(out4_up)
247 | out_2 = self.final2(out4_2_up)
248 |
249 | out_1_2 = self.final_2(self.upsamplex8(out3))
250 | out_2_2 = self.final2_2(self.upsamplex8(out3_2))
251 |
252 | # if self.show_Feature_Maps:
253 | # savepath = r'temp'
254 | # draw_features(8, 8, (F.interpolate(c1, scale_factor=4, mode='bilinear')).cpu().detach().numpy(),
255 | # "{}/c1_img1.png".format(savepath))
256 | # draw_features(8, 16, (F.interpolate(c2, scale_factor=8, mode='bilinear')).cpu().detach().numpy(),
257 | # "{}/c2_img1.png".format(savepath))
258 | # draw_features(16, 16, (F.interpolate(c3, scale_factor=8, mode='bilinear')).cpu().detach().numpy(),
259 | # "{}/c3_img1.png".format(savepath))
260 | # draw_features(8, 8, (F.interpolate(c1_img2, scale_factor=4, mode='bilinear')).cpu().detach().numpy(),
261 | # "{}/c1_img2.png".format(savepath))
262 | # draw_features(8, 16, (F.interpolate(c2_img2, scale_factor=8, mode='bilinear')).cpu().detach().numpy(),
263 | # "{}/c2_img2.png".format(savepath))
264 | # draw_features(16, 16, (F.interpolate(c3_img2, scale_factor=8, mode='bilinear')).cpu().detach().numpy(),
265 | # "{}/c3_img2.png".format(savepath))
266 | # # You can show more.
267 |
268 | return out_1, out_2, out_1_2, out_2_2
269 |
270 | def init_weights(self):
271 | self.cross2.apply(init_weights)
272 | self.cross3.apply(init_weights)
273 | self.cross4.apply(init_weights)
274 |
275 | self.fam32_1.apply(init_weights)
276 | self.Translayer2_1.apply(init_weights)
277 | self.fam43_1.apply(init_weights)
278 | self.Translayer3_1.apply(init_weights)
279 |
280 | self.fam32_2.apply(init_weights)
281 | self.Translayer2_2.apply(init_weights)
282 | self.fam43_2.apply(init_weights)
283 | self.Translayer3_2.apply(init_weights)
284 |
285 | self.final.apply(init_weights)
286 | self.final2.apply(init_weights)
287 | self.final_2.apply(init_weights)
288 | self.final2_2.apply(init_weights)
289 |
--------------------------------------------------------------------------------
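
A minimal forward-shape sketch for DMINet, assuming it is importable as compare.DMINet and that the ResNet-18 weights referenced in __init__ exist at ./pretrain_model/resnet18-5c106cde.pth (construction fails otherwise). With 256x256 inputs, each of the four outputs is a num_classes-channel logit map at full resolution:

import torch
from compare.DMINet import DMINet  # assumed import path

model = DMINet(num_classes=2).eval()
t1 = torch.randn(1, 3, 256, 256)    # pre-change image
t2 = torch.randn(1, 3, 256, 256)    # post-change image
with torch.no_grad():
    out_1, out_2, out_1_2, out_2_2 = model(t1, t2)
print(out_1.shape)                  # expected: torch.Size([1, 2, 256, 256])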
/compare/A2Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from compare.MobileNet import mobilenet_v2
5 |
6 | class NeighborFeatureAggregation(nn.Module):
7 | def __init__(self, in_d=None, out_d=64):
8 | super(NeighborFeatureAggregation, self).__init__()
9 | if in_d is None:
10 | in_d = [16, 24, 32, 96, 320]
11 | self.in_d = in_d
12 | self.mid_d = out_d // 2
13 | self.out_d = out_d
14 | # scale 2
15 | self.conv_scale2_c2 = nn.Sequential(
16 | nn.Conv2d(self.in_d[1], self.mid_d, kernel_size=3, stride=1, padding=1),
17 | nn.BatchNorm2d(self.mid_d),
18 | nn.ReLU(inplace=True)
19 | )
20 | self.conv_scale2_c3 = nn.Sequential(
21 | nn.Conv2d(self.in_d[2], self.mid_d, kernel_size=3, stride=1, padding=1),
22 | nn.BatchNorm2d(self.mid_d),
23 | nn.ReLU(inplace=True)
24 | )
25 | self.conv_aggregation_s2 = FeatureFusionModule(self.mid_d * 2, self.in_d[1], self.out_d)
26 | # scale 3
27 | self.conv_scale3_c2 = nn.Sequential(
28 | nn.MaxPool2d(kernel_size=2, stride=2),
29 | nn.Conv2d(self.in_d[1], self.mid_d, kernel_size=3, stride=1, padding=1),
30 | nn.BatchNorm2d(self.mid_d),
31 | nn.ReLU(inplace=True)
32 | )
33 | self.conv_scale3_c3 = nn.Sequential(
34 | nn.Conv2d(self.in_d[2], self.mid_d, kernel_size=3, stride=1, padding=1),
35 | nn.BatchNorm2d(self.mid_d),
36 | nn.ReLU(inplace=True)
37 | )
38 | self.conv_scale3_c4 = nn.Sequential(
39 | nn.Conv2d(self.in_d[3], self.mid_d, kernel_size=3, stride=1, padding=1),
40 | nn.BatchNorm2d(self.mid_d),
41 | nn.ReLU(inplace=True)
42 | )
43 | self.conv_aggregation_s3 = FeatureFusionModule(self.mid_d * 3, self.in_d[2], self.out_d)
44 | # scale 4
45 | self.conv_scale4_c3 = nn.Sequential(
46 | nn.MaxPool2d(kernel_size=2, stride=2),
47 | nn.Conv2d(self.in_d[2], self.mid_d, kernel_size=3, stride=1, padding=1),
48 | nn.BatchNorm2d(self.mid_d),
49 | nn.ReLU(inplace=True)
50 | )
51 | self.conv_scale4_c4 = nn.Sequential(
52 | nn.Conv2d(self.in_d[3], self.mid_d, kernel_size=3, stride=1, padding=1),
53 | nn.BatchNorm2d(self.mid_d),
54 | nn.ReLU(inplace=True)
55 | )
56 | self.conv_scale4_c5 = nn.Sequential(
57 | nn.Conv2d(self.in_d[4], self.mid_d, kernel_size=3, stride=1, padding=1),
58 | nn.BatchNorm2d(self.mid_d),
59 | nn.ReLU(inplace=True)
60 | )
61 | self.conv_aggregation_s4 = FeatureFusionModule(self.mid_d * 3, self.in_d[3], self.out_d)
62 | # scale 5
63 | self.conv_scale5_c4 = nn.Sequential(
64 | nn.MaxPool2d(kernel_size=2, stride=2),
65 | nn.Conv2d(self.in_d[3], self.mid_d, kernel_size=3, stride=1, padding=1),
66 | nn.BatchNorm2d(self.mid_d),
67 | nn.ReLU(inplace=True)
68 | )
69 | self.conv_scale5_c5 = nn.Sequential(
70 | nn.Conv2d(self.in_d[4], self.mid_d, kernel_size=3, stride=1, padding=1),
71 | nn.BatchNorm2d(self.mid_d),
72 | nn.ReLU(inplace=True)
73 | )
74 | self.conv_aggregation_s5 = FeatureFusionModule(self.mid_d * 2, self.in_d[4], self.out_d)
75 |
76 | def forward(self, c2, c3, c4, c5):
77 | # scale 2
78 | c2_s2 = self.conv_scale2_c2(c2)
79 |
80 | c3_s2 = self.conv_scale2_c3(c3)
81 | c3_s2 = F.interpolate(c3_s2, scale_factor=(2, 2), mode='bilinear')
82 |
83 | s2 = self.conv_aggregation_s2(torch.cat([c2_s2, c3_s2], dim=1), c2)
84 | # scale 3
85 | c2_s3 = self.conv_scale3_c2(c2)
86 |
87 | c3_s3 = self.conv_scale3_c3(c3)
88 |
89 | c4_s3 = self.conv_scale3_c4(c4)
90 | c4_s3 = F.interpolate(c4_s3, scale_factor=(2, 2), mode='bilinear')
91 |
92 | s3 = self.conv_aggregation_s3(torch.cat([c2_s3, c3_s3, c4_s3], dim=1), c3)
93 | # scale 4
94 | c3_s4 = self.conv_scale4_c3(c3)
95 |
96 | c4_s4 = self.conv_scale4_c4(c4)
97 |
98 | c5_s4 = self.conv_scale4_c5(c5)
99 | c5_s4 = F.interpolate(c5_s4, scale_factor=(2, 2), mode='bilinear')
100 |
101 | s4 = self.conv_aggregation_s4(torch.cat([c3_s4, c4_s4, c5_s4], dim=1), c4)
102 | # scale 5
103 | c4_s5 = self.conv_scale5_c4(c4)
104 |
105 | c5_s5 = self.conv_scale5_c5(c5)
106 |
107 | s5 = self.conv_aggregation_s5(torch.cat([c4_s5, c5_s5], dim=1), c5)
108 |
109 | return s2, s3, s4, s5
110 |
111 |
112 | class FeatureFusionModule(nn.Module):
113 | def __init__(self, fuse_d, id_d, out_d):
114 | super(FeatureFusionModule, self).__init__()
115 | self.fuse_d = fuse_d
116 | self.id_d = id_d
117 | self.out_d = out_d
118 | self.conv_fuse = nn.Sequential(
119 | nn.Conv2d(self.fuse_d, self.out_d, kernel_size=3, stride=1, padding=1),
120 | nn.BatchNorm2d(self.out_d),
121 | nn.ReLU(inplace=True),
122 | nn.Conv2d(self.out_d, self.out_d, kernel_size=3, stride=1, padding=1),
123 | nn.BatchNorm2d(self.out_d)
124 | )
125 | self.conv_identity = nn.Conv2d(self.id_d, self.out_d, kernel_size=1)
126 | self.relu = nn.ReLU(inplace=True)
127 |
128 | def forward(self, c_fuse, c):
129 | c_fuse = self.conv_fuse(c_fuse)
130 | c_out = self.relu(c_fuse + self.conv_identity(c))
131 |
132 | return c_out
133 |
134 |
135 | class TemporalFeatureFusionModule(nn.Module):
136 | def __init__(self, in_d, out_d):
137 | super(TemporalFeatureFusionModule, self).__init__()
138 | self.in_d = in_d
139 | self.out_d = out_d
140 | self.relu = nn.ReLU(inplace=True)
141 | # branch 1
142 | self.conv_branch1 = nn.Sequential(
143 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=7, dilation=7),
144 | nn.BatchNorm2d(self.in_d)
145 | )
146 | # branch 2
147 | self.conv_branch2 = nn.Conv2d(self.in_d, self.in_d, kernel_size=1)
148 | self.conv_branch2_f = nn.Sequential(
149 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=5, dilation=5),
150 | nn.BatchNorm2d(self.in_d)
151 | )
152 | # branch 3
153 | self.conv_branch3 = nn.Conv2d(self.in_d, self.in_d, kernel_size=1)
154 | self.conv_branch3_f = nn.Sequential(
155 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=3, dilation=3),
156 | nn.BatchNorm2d(self.in_d)
157 | )
158 | # branch 4
159 | self.conv_branch4 = nn.Conv2d(self.in_d, self.in_d, kernel_size=1)
160 | self.conv_branch4_f = nn.Sequential(
161 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1, dilation=1),
162 | nn.BatchNorm2d(self.out_d)
163 | )
164 | self.conv_branch5 = nn.Conv2d(self.in_d, self.out_d, kernel_size=1)
165 |
166 | def forward(self, x1, x2):
167 | # temporal fusion
168 | x = torch.abs(x1 - x2)
169 | # branch 1
170 | x_branch1 = self.conv_branch1(x)
171 | # branch 2
172 | x_branch2 = self.relu(self.conv_branch2(x) + x_branch1)
173 | x_branch2 = self.conv_branch2_f(x_branch2)
174 | # branch 3
175 | x_branch3 = self.relu(self.conv_branch3(x) + x_branch2)
176 | x_branch3 = self.conv_branch3_f(x_branch3)
177 | # branch 4
178 | x_branch4 = self.relu(self.conv_branch4(x) + x_branch3)
179 | x_branch4 = self.conv_branch4_f(x_branch4)
180 | x_out = self.relu(self.conv_branch5(x) + x_branch4)
181 |
182 | return x_out
183 |
184 |
185 | class TemporalFusionModule(nn.Module):
186 | def __init__(self, in_d=32, out_d=32):
187 | super(TemporalFusionModule, self).__init__()
188 | self.in_d = in_d
189 | self.out_d = out_d
190 | # fusion
191 | self.tffm_x2 = TemporalFeatureFusionModule(self.in_d, self.out_d)
192 | self.tffm_x3 = TemporalFeatureFusionModule(self.in_d, self.out_d)
193 | self.tffm_x4 = TemporalFeatureFusionModule(self.in_d, self.out_d)
194 | self.tffm_x5 = TemporalFeatureFusionModule(self.in_d, self.out_d)
195 |
196 | def forward(self, x1_2, x1_3, x1_4, x1_5, x2_2, x2_3, x2_4, x2_5):
197 | # temporal fusion
198 | c2 = self.tffm_x2(x1_2, x2_2)
199 | c3 = self.tffm_x3(x1_3, x2_3)
200 | c4 = self.tffm_x4(x1_4, x2_4)
201 | c5 = self.tffm_x5(x1_5, x2_5)
202 |
203 | return c2, c3, c4, c5
204 |
205 | class SupervisedAttentionModule(nn.Module):
206 | def __init__(self, mid_d):
207 | super(SupervisedAttentionModule, self).__init__()
208 | self.mid_d = mid_d
209 | # fusion
210 | self.cls = nn.Conv2d(self.mid_d, 2, kernel_size=1)
211 | self.conv_context = nn.Sequential(
212 | nn.Conv2d(4, self.mid_d, kernel_size=1),
213 | nn.BatchNorm2d(self.mid_d),
214 | nn.ReLU(inplace=True)
215 | )
216 | self.conv2 = nn.Sequential(
217 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1),
218 | nn.BatchNorm2d(self.mid_d),
219 | nn.ReLU(inplace=True)
220 | )
221 |
222 | def forward(self, x):
223 | mask = self.cls(x)
224 | mask_f = torch.sigmoid(mask)
225 | mask_b = 1 - mask_f
226 | context = torch.cat([mask_f, mask_b], dim=1)
227 | context = self.conv_context(context)
228 | x = x.mul(context)
229 | x_out = self.conv2(x)
230 |
231 | return x_out, mask
232 |
233 |
234 | class Decoder(nn.Module):
235 | def __init__(self, mid_d=320):
236 | super(Decoder, self).__init__()
237 | self.mid_d = mid_d #64
238 | # fusion
239 | self.sam_p5 = SupervisedAttentionModule(self.mid_d)
240 | self.sam_p4 = SupervisedAttentionModule(self.mid_d)
241 | self.sam_p3 = SupervisedAttentionModule(self.mid_d)
242 | self.conv_p4 = nn.Sequential(
243 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1),
244 | nn.BatchNorm2d(self.mid_d),
245 | nn.ReLU(inplace=True)
246 | )
247 | self.conv_p3 = nn.Sequential(
248 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1),
249 | nn.BatchNorm2d(self.mid_d),
250 | nn.ReLU(inplace=True)
251 | )
252 | self.conv_p2 = nn.Sequential(
253 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1),
254 | nn.BatchNorm2d(self.mid_d),
255 | nn.ReLU(inplace=True)
256 | )
257 | self.cls = nn.Conv2d(self.mid_d, 2, kernel_size=1)
258 |
259 | def forward(self, d2, d3, d4, d5):
260 | # high-level
261 | p5, mask_p5 = self.sam_p5(d5)
262 | p4 = self.conv_p4(d4 + F.interpolate(p5, scale_factor=(2, 2), mode='bilinear'))
263 |
264 | p4, mask_p4 = self.sam_p4(p4)
265 | p3 = self.conv_p3(d3 + F.interpolate(p4, scale_factor=(2, 2), mode='bilinear'))
266 |
267 | p3, mask_p3 = self.sam_p3(p3)
268 | p2 = self.conv_p2(d2 + F.interpolate(p3, scale_factor=(2, 2), mode='bilinear'))
269 | mask_p2 = self.cls(p2)
270 |
271 | return p2, p3, p4, p5, mask_p2, mask_p3, mask_p4, mask_p5
272 |
273 | class A2Net(nn.Module):
274 | def __init__(self, input_nc=3, output_nc=1):
275 | super(A2Net, self).__init__()
276 | self.backbone = mobilenet_v2(pretrained=True)
277 |         channels = [16, 24, 32, 96, 320]
278 | self.en_d = 32
279 | self.mid_d = self.en_d * 2
280 |         self.swa = NeighborFeatureAggregation(channels, self.mid_d)
281 | self.tfm = TemporalFusionModule(self.mid_d, self.en_d * 2)
282 | self.decoder = Decoder(self.en_d * 2)
283 |
284 | def forward(self, x1, x2):
285 | # forward backbone resnet
286 | x1_1, x1_2, x1_3, x1_4, x1_5 = self.backbone(x1)
287 | x2_1, x2_2, x2_3, x2_4, x2_5 = self.backbone(x2)
288 | # aggregation
289 | x1_2, x1_3, x1_4, x1_5 = self.swa(x1_2, x1_3, x1_4, x1_5)
290 | x2_2, x2_3, x2_4, x2_5 = self.swa(x2_2, x2_3, x2_4, x2_5)
291 | # temporal fusion
292 | c2, c3, c4, c5 = self.tfm(x1_2, x1_3, x1_4, x1_5, x2_2, x2_3, x2_4, x2_5)
293 | # fpn
294 | p2, p3, p4, p5, mask_p2, mask_p3, mask_p4, mask_p5 = self.decoder(c2, c3, c4, c5)
295 |
296 | # change map
297 | mask_p2 = F.interpolate(mask_p2, scale_factor=(4, 4), mode='bilinear')
298 | # mask_p2 = torch.sigmoid(mask_p2)
299 | mask_p3 = F.interpolate(mask_p3, scale_factor=(8, 8), mode='bilinear')
300 | # mask_p3 = torch.sigmoid(mask_p3)
301 | mask_p4 = F.interpolate(mask_p4, scale_factor=(16, 16), mode='bilinear')
302 | # mask_p4 = torch.sigmoid(mask_p4)
303 | mask_p5 = F.interpolate(mask_p5, scale_factor=(32, 32), mode='bilinear')
304 | # mask_p5 = torch.sigmoid(mask_p5)
305 |
306 | return [mask_p5, mask_p4, mask_p3, mask_p2]
307 |
--------------------------------------------------------------------------------
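
A similar shape-check sketch for A2Net, assuming it is importable as compare.A2Net; note that mobilenet_v2(pretrained=True) in __init__ expects the ImageNet weights to be available locally or downloadable. The forward pass returns four 2-channel change maps, coarse to fine, all upsampled to the input resolution:

import torch
from compare.A2Net import A2Net  # assumed import path

model = A2Net().eval()
x1 = torch.randn(1, 3, 256, 256)
x2 = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    mask_p5, mask_p4, mask_p3, mask_p2 = model(x1, x2)
print(mask_p2.shape)             # expected: torch.Size([1, 2, 256, 256])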
/compare/DTCDSCN.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.models import ResNet
5 | import torch.nn.functional as F
6 | from functools import partial
7 |
8 |
9 | nonlinearity = partial(F.relu,inplace=True)
10 |
11 | class SELayer(nn.Module):
12 | def __init__(self, channel, reduction=16):
13 | super(SELayer, self).__init__()
14 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
15 | self.fc = nn.Sequential(
16 | nn.Linear(channel, channel // reduction, bias=False),
17 | nn.ReLU(inplace=True),
18 | nn.Linear(channel // reduction, channel, bias=False),
19 | nn.Sigmoid()
20 | )
21 |
22 | def forward(self, x):
23 | b, c, _, _ = x.size()
24 | y = self.avg_pool(x).view(b, c)
25 | y = self.fc(y).view(b, c, 1, 1)
26 | return x * y.expand_as(x)
27 |
28 | class Dblock_more_dilate(nn.Module):
29 | def __init__(self, channel):
30 | super(Dblock_more_dilate, self).__init__()
31 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
32 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
33 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
34 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
35 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
36 | for m in self.modules():
37 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
38 | if m.bias is not None:
39 | m.bias.data.zero_()
40 |
41 | def forward(self, x):
42 | dilate1_out = nonlinearity(self.dilate1(x))
43 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
44 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
45 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
46 | dilate5_out = nonlinearity(self.dilate5(dilate4_out))
47 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out
48 | return out
49 | class Dblock(nn.Module):
50 | def __init__(self, channel):
51 | super(Dblock, self).__init__()
52 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
53 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
54 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
55 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
56 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
59 | if m.bias is not None:
60 | m.bias.data.zero_()
61 |
62 | def forward(self, x):
63 | dilate1_out = nonlinearity(self.dilate1(x))
64 | dilate2_out = nonlinearity(self.dilate2(dilate1_out))
65 | dilate3_out = nonlinearity(self.dilate3(dilate2_out))
66 | dilate4_out = nonlinearity(self.dilate4(dilate3_out))
67 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out))
68 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out
69 | return out
70 |
71 | def conv3x3(in_planes, out_planes, stride=1):
72 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
73 |
74 | class SEBasicBlock(nn.Module):
75 | expansion = 1
76 |
77 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
78 | super(SEBasicBlock, self).__init__()
79 | self.conv1 = conv3x3(inplanes, planes, stride)
80 | self.bn1 = nn.BatchNorm2d(planes)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.conv2 = conv3x3(planes, planes, 1)
83 | self.bn2 = nn.BatchNorm2d(planes)
84 | self.se = SELayer(planes, reduction)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | residual = x
90 | out = self.conv1(x)
91 | out = self.bn1(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv2(out)
95 | out = self.bn2(out)
96 | out = self.se(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 | class DecoderBlock(nn.Module):
107 | def __init__(self, in_channels, n_filters):
108 | super(DecoderBlock,self).__init__()
109 |
110 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
111 | self.norm1 = nn.BatchNorm2d(in_channels // 4)
112 | self.relu1 = nonlinearity
113 | self.scse = SCSEBlock(in_channels // 4)
114 |
115 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
116 | self.norm2 = nn.BatchNorm2d(in_channels // 4)
117 | self.relu2 = nonlinearity
118 |
119 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
120 | self.norm3 = nn.BatchNorm2d(n_filters)
121 | self.relu3 = nonlinearity
122 |
123 | def forward(self, x):
124 | x = self.conv1(x)
125 | x = self.norm1(x)
126 | x = self.relu1(x)
127 | y = self.scse(x)
128 | x = x + y
129 | x = self.deconv2(x)
130 | x = self.norm2(x)
131 | x = self.relu2(x)
132 | x = self.conv3(x)
133 | x = self.norm3(x)
134 | x = self.relu3(x)
135 | return x
136 |
137 | class SCSEBlock(nn.Module):
138 | def __init__(self, channel, reduction=16):
139 | super(SCSEBlock, self).__init__()
140 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
141 |
142 |         '''self.channel_excitation = nn.Sequential(nn.Linear(channel, int(channel//reduction)),
143 | nn.ReLU(inplace=True),
144 | nn.Linear(int(channel//reduction), channel),
145 | nn.Sigmoid())'''
146 | self.channel_excitation = nn.Sequential(nn.Conv2d(channel, int(channel//reduction), kernel_size=1,
147 | stride=1, padding=0, bias=False),
148 | nn.ReLU(inplace=True),
149 | nn.Conv2d(int(channel // reduction), channel,kernel_size=1,
150 | stride=1, padding=0, bias=False),
151 | nn.Sigmoid())
152 |
153 | self.spatial_se = nn.Sequential(nn.Conv2d(channel, 1, kernel_size=1,
154 | stride=1, padding=0, bias=False),
155 | nn.Sigmoid())
156 |
157 | def forward(self, x):
158 | bahs, chs, _, _ = x.size()
159 |
160 | # Returns a new tensor with the same data as the self tensor but of a different size.
161 | chn_se = self.avg_pool(x)
162 | chn_se = self.channel_excitation(chn_se)
163 | chn_se = torch.mul(x, chn_se)
164 | spa_se = self.spatial_se(x)
165 | spa_se = torch.mul(x, spa_se)
166 |         return torch.add(chn_se, spa_se)
167 |
168 | class CDNet_model(nn.Module):
169 | def __init__(self, in_channels=3, block=SEBasicBlock, layers=[3, 4, 6, 3], num_classes=2):
170 | super(CDNet_model, self).__init__()
171 |
172 | filters = [64, 128, 256, 512]
173 | self.inplanes = 64
174 | self.firstconv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
175 | bias=False)
176 | self.firstbn = nn.BatchNorm2d(64)
177 | self.firstrelu = nn.ReLU(inplace=True)
178 | self.firstmaxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
179 | self.encoder1 = self._make_layer(block, 64, layers[0])
180 | self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
181 | self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
182 | self.encoder4 = self._make_layer(block, 512, layers[3], stride=2)
183 |
184 | self.decoder4 = DecoderBlock(filters[3], filters[2])
185 | self.decoder3 = DecoderBlock(filters[2], filters[1])
186 | self.decoder2 = DecoderBlock(filters[1], filters[0])
187 | self.decoder1 = DecoderBlock(filters[0], filters[0])
188 |
189 | self.dblock_master = Dblock(512)
190 | self.dblock = Dblock(512)
191 |
192 | self.decoder4_master = DecoderBlock(filters[3], filters[2])
193 | self.decoder3_master = DecoderBlock(filters[2], filters[1])
194 | self.decoder2_master = DecoderBlock(filters[1], filters[0])
195 | self.decoder1_master = DecoderBlock(filters[0], filters[0])
196 |
197 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
198 | self.finalrelu1_master = nonlinearity
199 | self.finalconv2_master = nn.Conv2d(32, 32, 3, padding=1)
200 | self.finalrelu2_master = nonlinearity
201 | self.finalconv3_master = nn.Conv2d(32, num_classes, 3, padding=1)
202 |
203 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
204 | self.finalrelu1 = nonlinearity
205 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
206 | self.finalrelu2 = nonlinearity
207 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
208 |
209 | for m in self.modules():
210 | if isinstance(m, nn.Conv2d):
211 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
212 | m.weight.data.normal_(0, math.sqrt(2. / n))
213 | elif isinstance(m, nn.BatchNorm2d):
214 | m.weight.data.fill_(1)
215 | m.bias.data.zero_()
216 |
217 | def _make_layer(self, block, planes, blocks, stride=1):
218 | downsample = None
219 | if stride != 1 or self.inplanes != planes * block.expansion:
220 | downsample = nn.Sequential(
221 | nn.Conv2d(self.inplanes, planes * block.expansion,
222 | kernel_size=1, stride=stride, bias=False),
223 | nn.BatchNorm2d(planes * block.expansion),
224 | )
225 |
226 | layers = []
227 | layers.append(block(self.inplanes, planes, stride, downsample))
228 | self.inplanes = planes * block.expansion
229 | for i in range(1, blocks):
230 | layers.append(block(self.inplanes, planes))
231 |
232 | return nn.Sequential(*layers)
233 |
234 | def forward(self, x, y):
235 | # Encoder_1
236 | x = self.firstconv(x)
237 | x = self.firstbn(x)
238 | x = self.firstrelu(x)
239 | x = self.firstmaxpool(x)
240 |
241 | e1_x = self.encoder1(x)
242 | e2_x = self.encoder2(e1_x)
243 | e3_x = self.encoder3(e2_x)
244 | e4_x = self.encoder4(e3_x)
245 |
246 | # # Center_1
247 | # e4_x_center = self.dblock(e4_x)
248 |
249 | # # Decoder_1
250 | # d4_x = self.decoder4(e4_x_center) + e3_x
251 | # d3_x = self.decoder3(d4_x) + e2_x
252 | # d2_x = self.decoder2(d3_x) + e1_x
253 | # d1_x = self.decoder1(d2_x)
254 |
255 | # out1 = self.finaldeconv1(d1_x)
256 | # out1 = self.finalrelu1(out1)
257 | # out1 = self.finalconv2(out1)
258 | # out1 = self.finalrelu2(out1)
259 | # out1 = self.finalconv3(out1)
260 |
261 | # Encoder_2
262 | y = self.firstconv(y)
263 | y = self.firstbn(y)
264 | y = self.firstrelu(y)
265 | y = self.firstmaxpool(y)
266 |
267 | e1_y = self.encoder1(y)
268 | e2_y = self.encoder2(e1_y)
269 | e3_y = self.encoder3(e2_y)
270 | e4_y = self.encoder4(e3_y)
271 |
272 | # # Center_2
273 | # e4_y_center = self.dblock(e4_y)
274 |
275 | # # Decoder_2
276 | # d4_y = self.decoder4(e4_y_center) + e3_y
277 | # d3_y = self.decoder3(d4_y) + e2_y
278 | # d2_y = self.decoder2(d3_y) + e1_y
279 | # d1_y = self.decoder1(d2_y)
280 | # out2 = self.finaldeconv1(d1_y)
281 | # out2 = self.finalrelu1(out2)
282 | # out2 = self.finalconv2(out2)
283 | # out2 = self.finalrelu2(out2)
284 | # out2 = self.finalconv3(out2)
285 |
286 | # center_master
287 | e4 = self.dblock_master(e4_x - e4_y)
288 | # decoder_master
289 | d4 = self.decoder4_master(e4) + e3_x - e3_y
290 | d3 = self.decoder3_master(d4) + e2_x - e2_y
291 | d2 = self.decoder2_master(d3) + e1_x - e1_y
292 | d1 = self.decoder1_master(d2)
293 |
294 | out = self.finaldeconv1_master(d1)
295 | out = self.finalrelu1_master(out)
296 | out = self.finalconv2_master(out)
297 | out = self.finalrelu2_master(out)
298 | output = self.finalconv3_master(out)
299 |
300 | # output = []
301 | # output.append(out)
302 |
303 | return output
304 |
305 |
306 |
307 | def CDNet34(in_channels, **kwargs):
308 |
309 | model = CDNet_model(in_channels, SEBasicBlock, [3, 4, 6, 3], **kwargs)
310 |
311 | return model
--------------------------------------------------------------------------------
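
Finally, a shape-check sketch for the CDNet34 builder above, assuming the file is importable as compare.DTCDSCN; no pretrained weights are loaded, so it runs standalone. The forward pass takes the two temporal images and returns a single num_classes-channel change map at the input resolution:

import torch
from compare.DTCDSCN import CDNet34  # assumed import path

model = CDNet34(in_channels=3, num_classes=2).eval()
t1 = torch.randn(1, 3, 256, 256)
t2 = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    change_logits = model(t1, t2)
print(change_logits.shape)           # expected: torch.Size([1, 2, 256, 256])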