├── compare ├── Changer.py ├── __pycache__ │ ├── A2Net.cpython-38.pyc │ ├── FC_EF.cpython-38.pyc │ ├── IFNet.cpython-38.pyc │ ├── DMINet.cpython-38.pyc │ ├── DTCDSCN.cpython-38.pyc │ ├── SNUNet.cpython-38.pyc │ ├── TFI_GR.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ ├── MobileNet.cpython-38.pyc │ ├── NestedUNet.cpython-38.pyc │ ├── resnet_tfi.cpython-38.pyc │ ├── ChangeFormer.cpython-38.pyc │ ├── FC_Siam_conc.cpython-38.pyc │ └── FC_Siam_diff.cpython-38.pyc ├── resbase.py ├── MobileNet.py ├── NestedUNet.py ├── SNUNet.py ├── FC_EF.py ├── IFNet.py ├── DASNet.py ├── FC_Siam_conc.py ├── FC_Siam_diff.py ├── TFI_GR.py ├── resnet_tfi.py ├── DMINet.py ├── A2Net.py └── DTCDSCN.py ├── gitignore ├── .idea ├── .gitignore ├── misc.xml ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── SEIFNet-main.iml ├── models ├── .idea │ ├── .gitignore │ ├── misc.xml │ ├── vcs.xml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── modules.xml │ └── models.iml ├── __pycache__ │ ├── ASFF.cpython-38.pyc │ ├── CBAM.cpython-38.pyc │ ├── DEFM.cpython-38.pyc │ ├── TAM.cpython-38.pyc │ ├── Models.cpython-38.pyc │ ├── losses.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ ├── FT_loss.cpython-38.pyc │ ├── evaluator.cpython-38.pyc │ ├── networks.cpython-38.pyc │ ├── trainer.cpython-38.pyc │ ├── Focal_loss.cpython-38.pyc │ ├── siamunet_dif.cpython-38.pyc │ └── swin_transformer.cpython-38.pyc ├── CBAM.py └── evaluator.py ├── __pycache__ ├── utils_.cpython-38.pyc ├── data_config.cpython-38.pyc └── main_train.cpython-38.pyc ├── utils ├── __pycache__ │ ├── helpers.cpython-38.pyc │ ├── losses.cpython-38.pyc │ ├── metrics.cpython-38.pyc │ ├── parser.cpython-38.pyc │ ├── transforms.cpython-38.pyc │ └── dataloaders.cpython-38.pyc ├── losses.py ├── parser.py ├── dataloaders.py ├── helpers.py ├── metrics.py └── transforms.py ├── misc ├── __pycache__ │ ├── logger_tool.cpython-38.pyc │ └── metric_tool.cpython-38.pyc ├── pyutils.py ├── logger_tool.py └── metric_tool.py ├── datasets ├── __pycache__ │ ├── CD_dataset.cpython-38.pyc │ └── data_utils.cpython-38.pyc ├── CD_dataset.py └── data_utils.py ├── data_config.py ├── README.md ├── eval.py ├── utils_.py ├── eval_cd.py ├── main_train.py ├── checkpoints └── LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 │ └── log.txt └── train.py /compare/Changer.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | checkpoints -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /models/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /__pycache__/utils_.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/__pycache__/utils_.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/data_config.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/__pycache__/data_config.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/main_train.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/__pycache__/main_train.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/ASFF.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/ASFF.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/CBAM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/CBAM.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/DEFM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/DEFM.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/TAM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/TAM.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/A2Net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/A2Net.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/FC_EF.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/FC_EF.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/IFNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/IFNet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/Models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/Models.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/resnet.cpython-38.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/helpers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/helpers.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/parser.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/DMINet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/DMINet.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/DTCDSCN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/DTCDSCN.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/SNUNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/SNUNet.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/TFI_GR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/TFI_GR.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /misc/__pycache__/logger_tool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/misc/__pycache__/logger_tool.cpython-38.pyc -------------------------------------------------------------------------------- /misc/__pycache__/metric_tool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/misc/__pycache__/metric_tool.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/FT_loss.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/FT_loss.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/evaluator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/evaluator.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/networks.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/trainer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/MobileNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/MobileNet.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/NestedUNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/NestedUNet.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/resnet_tfi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/resnet_tfi.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/Focal_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/Focal_loss.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloaders.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/utils/__pycache__/dataloaders.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/ChangeFormer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/ChangeFormer.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/FC_Siam_conc.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/FC_Siam_conc.cpython-38.pyc -------------------------------------------------------------------------------- /compare/__pycache__/FC_Siam_diff.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/compare/__pycache__/FC_Siam_diff.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/CD_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/datasets/__pycache__/CD_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/datasets/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/siamunet_dif.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/siamunet_dif.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/swin_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinghua5540/SEIFNet/HEAD/models/__pycache__/swin_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /models/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /models/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /models/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /models/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /models/.idea/models.iml: 
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 |
--------------------------------------------------------------------------------
/.idea/SEIFNet-main.iml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | from utils.parser import get_parser_with_args
2 | from utils.metrics import FocalLoss, dice_loss
3 | 
4 | parser, metadata = get_parser_with_args()
5 | opt = parser.parse_args()
6 | 
7 | def hybrid_loss(predictions, target):
8 |     """Calculate the hybrid (focal + dice) loss over all predictions."""
9 |     loss = 0
10 | 
11 |     # gamma=0, alpha=None --> CE
12 |     focal = FocalLoss(gamma=0, alpha=None)
13 | 
14 |     for prediction in predictions:
15 | 
16 |         bce = focal(prediction, target)
17 |         dice = dice_loss(prediction, target)
18 |         loss += bce + dice
19 | 
20 |     return loss
21 | 
22 | 
--------------------------------------------------------------------------------
/misc/pyutils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import glob
5 | 
6 | 
7 | def seed_random(seed=2020):
8 |     # Set the random seeds below so that data loading, random augmentation, etc. stay consistent
9 |     random.seed(seed)
10 |     os.environ['PYTHONHASHSEED'] = str(seed)
11 |     np.random.seed(seed)
12 | 
13 | 
14 | def mkdir(path):
15 |     """create a single empty directory if it didn't exist
16 | 
17 |     Parameters:
18 |         path (str) -- a single directory path
19 |     """
20 |     if not os.path.exists(path):
21 |         os.makedirs(path)
22 | 
23 | 
24 | def get_paths(image_folder_path, suffix='*.png'):
25 |     """Return the files with the given suffix from a folder
26 |     :param image_folder_path: str
27 |     :param suffix: str
28 |     :return: list
29 |     """
30 |     paths = sorted(glob.glob(os.path.join(image_folder_path, suffix)))
31 |     return paths
32 | 
33 | 
34 | def get_paths_from_list(image_folder_path, list):
35 |     """Find the files named in `list` under the image folder and return the sorted path list"""
36 |     out = []
37 |     for item in list:
38 |         path = os.path.join(image_folder_path, item)
39 |         out.append(path)
40 |     return sorted(out)
41 | 
42 | 
--------------------------------------------------------------------------------
/utils/parser.py:
--------------------------------------------------------------------------------
1 | import argparse as ag
2 | import json
3 | 
4 | def get_parser_with_args(metadata_json='metadata.json'):
5 |     parser = ag.ArgumentParser(description='Training change detection network')
6 | 
7 |     with open(metadata_json, 'r') as fin:
8 |         metadata = json.load(fin)
9 |         parser.set_defaults(**metadata)
10 | 
11 |     # project save
12 |     parser.add_argument('--project_name', default='LEVIR+_ViTAE_BIT_bce_100', type=str)
13 |     parser.add_argument('--path', default='checkpoints', type=str, help='path of saved model')
14 |     # parser.add_argument('--checkpoint_root', default='checkpoints', type=str)
15 | 
16 |     # network
17 |     parser.add_argument('--backbone', default='vitae', type=str, choices=['resnet', 'swin', 'vitae'], help='type of model')
18 | 
19 |     parser.add_argument('--dataset', default='levir+', type=str, choices=['cdd', 'levir', 'levir+'], help='type of dataset')
20 | 
21 |     parser.add_argument('--mode', default='rsp_100', type=str, choices=['imp', 'rsp_40', 'rsp_100', 'rsp_120', 'rsp_300', 'rsp_300_sgd', 'seco'], help='type of pretrain')
22 | 
23 | 
24 |     return parser, metadata
25 | 
26 | 
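Note that get_parser_with_args() above reads default hyperparameters from metadata.json and registers them with set_defaults, so every key in that file surfaces as an attribute on the parsed options. A minimal sketch of that mechanism follows; the JSON keys shown are illustrative assumptions, not the repository's actual metadata.json.

# Illustrative only: how metadata.json defaults surface on the parsed options.
# Example metadata.json (assumed keys): {"epochs": 200, "batch_size": 8, "learning_rate": 1e-4}
from utils.parser import get_parser_with_args

parser, metadata = get_parser_with_args('metadata.json')
opt = parser.parse_args(['--dataset', 'levir', '--backbone', 'resnet'])

print(opt.dataset, opt.backbone)    # values supplied on the command line
print(opt.epochs, opt.batch_size)   # defaults injected from metadata.json (assumed keys above)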
--------------------------------------------------------------------------------
/data_config.py:
--------------------------------------------------------------------------------
1 | 
2 | class DataConfig:
3 |     data_name = ""
4 |     root_dir = ""
5 |     label_transform = "norm"
6 |     def get_data_config(self, data_name):
7 |         self.data_name = data_name
8 |         if data_name == 'LEVIR':
9 |             self.root_dir = 'E:/CDDataset/LEVIR'
10 |         elif data_name == 'DSIFN':
11 |             self.label_transform = "norm"
12 |             self.root_dir = 'E:/CDDataset/DSIFN_256'
13 |         elif data_name == 'SYSU-CD':
14 |             self.label_transform = "norm"
15 |             self.root_dir = 'E:/CDDataset/SYSU-CD'
16 |         elif data_name == 'LEVIR+':
17 |             self.label_transform = "norm"
18 |             self.root_dir = 'E:/CDDataset/LEVIR-CD+_256'
19 |         elif data_name == 'BBCD':
20 |             self.label_transform = "norm"
21 |             self.root_dir = 'E:/CDDataset/Big_Building_ChangeDetection'
22 |         elif data_name == 'GZ_CD':
23 |             self.label_transform = "norm"
24 |             self.root_dir = 'E:/CDDataset/GZ'
25 |         elif data_name == 'WHU-CD':
26 |             self.label_transform = "norm"
27 |             self.root_dir = 'E:/CDDataset/WHU-CUT'
28 |         elif data_name == 'test':
29 |             self.label_transform = "norm"
30 |             self.root_dir = 'E:/CDDataset/att_test_whu'
31 |         elif data_name == 'quick_start':
32 |             self.root_dir = './samples/'
33 |         else:
34 |             raise TypeError('%s has not been defined' % data_name)
35 |         return self
36 | 
37 | 
38 | if __name__ == '__main__':
39 |     data = DataConfig().get_data_config(data_name='LEVIR')
40 |     print(data.data_name)
41 |     print(data.root_dir)
42 |     print(data.label_transform)
43 | 
44 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SEIFNet
2 | 
3 | Y. Huang, X. Li, Z. Du, and H. Shen, “Spatiotemporal Enhancement and Interlevel Fusion Network for Remote Sensing Images Change Detection,” IEEE Transactions on Geoscience and Remote Sensing, vol. 61, 2024.
4 | 
5 | Abstract:
6 | Remote sensing (RS) image change detection (CD) plays a crucial role in monitoring surface dynamics. However, current deep learning (DL)-based CD methods still suffer from pseudo changes and scale variations due to inadequate exploration of temporal differences and under-utilization of multiscale features. Based on the aforementioned considerations, a spatiotemporal enhancement and interlevel fusion network (SEIFNet) is proposed to improve the feature representation of changing objects. First, multilevel feature maps are acquired from a Siamese hierarchical backbone. To highlight the disparity at the same location across different times, spatiotemporal difference enhancement modules (ST-DEM) are introduced to capture global and local information from the bitemporal feature maps at each level. Coordinate attention and cascaded convolutions are adopted in the subtraction and connection branches, respectively. Then, an adaptive context fusion module (ACFM) is designed to integrate interlevel features under the guidance of different semantic information, constituting a progressive decoder. Additionally, a plain refinement module and a concise summation-based prediction head are employed to enhance the boundary details and internal integrity of the CD results. The experimental results validate the superiority of our lightweight network over 8 state-of-the-art (SOTA) methods on the LEVIR-CD, SYSU-CD and WHU-CD datasets, in both accuracy and efficiency. The effects of different types of backbones and difference enhancement modules are also discussed in detail in the ablation experiments.
7 | 
8 | ![Fig1](https://github.com/lixinghua5540/SEIFNet/assets/75232301/3149f35a-4cca-4111-b03f-17492bf82cef)
9 | 
--------------------------------------------------------------------------------
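The ST-DEM implementation itself is not included in this listing, but the idea described in the abstract, a coordinate-attention-gated subtraction branch combined with a cascaded-convolution connection branch, can be sketched as below. This is an illustrative approximation under stated assumptions, not the SEIFNet module; every class and parameter name here is invented for the example.

# Illustrative sketch only: a difference-enhancement block in the spirit of the abstract.
import torch
import torch.nn as nn

class CoordGate(nn.Module):
    """Simplified coordinate-attention gate: direction-aware pooling -> shared conv -> two gates."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        mid = max(8, channels // reduction)
        self.conv1 = nn.Sequential(
            nn.Conv2d(channels, mid, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
        )
        self.conv_h = nn.Conv2d(mid, channels, kernel_size=1)
        self.conv_w = nn.Conv2d(mid, channels, kernel_size=1)

    def forward(self, x):
        n, c, h, w = x.size()
        x_h = x.mean(dim=3, keepdim=True)                          # N, C, H, 1
        x_w = x.mean(dim=2, keepdim=True).permute(0, 1, 3, 2)      # N, C, W, 1
        y = self.conv1(torch.cat([x_h, x_w], dim=2))               # N, mid, H+W, 1
        y_h, y_w = torch.split(y, [h, w], dim=2)
        a_h = torch.sigmoid(self.conv_h(y_h))                      # N, C, H, 1
        a_w = torch.sigmoid(self.conv_w(y_w.permute(0, 1, 3, 2)))  # N, C, 1, W
        return x * a_h * a_w

class DiffEnhance(nn.Module):
    """Subtraction branch (coordinate-gated |f1 - f2|) + connection branch (cascaded 3x3 convs)."""
    def __init__(self, channels):
        super().__init__()
        self.gate = CoordGate(channels)
        self.fuse = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, f1, f2):
        sub = self.gate(torch.abs(f1 - f2))             # emphasise where the two dates disagree
        cat = self.fuse(torch.cat([f1, f2], dim=1))     # keep appearance context that subtraction discards
        return sub + cat

if __name__ == '__main__':
    f1, f2 = torch.randn(2, 64, 64, 64), torch.randn(2, 64, 64, 64)
    print(DiffEnhance(64)(f1, f2).shape)  # torch.Size([2, 64, 64, 64])

The two branches are complementary: the gated absolute difference highlights changed locations, while the concatenation branch preserves bitemporal context, which matches the global/local motivation given above.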
/eval.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from utils.parser import get_parser_with_args
3 | from utils.helpers import get_test_loaders
4 | from tqdm import tqdm
5 | from sklearn.metrics import confusion_matrix
6 | 
7 | # The evaluation method in our paper is slightly different from this file.
8 | # In the paper we use the evaluation method in train.py; specifically, the batch size is considered.
9 | # The evaluation method in this file usually produces higher numerical indicators.
10 | 
11 | parser, metadata = get_parser_with_args()
12 | opt = parser.parse_args()
13 | 
14 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15 | 
16 | if opt.dataset == 'cdd':
17 |     opt.dataset_dir = '../Dataset/cdd_dataset/'
18 | elif opt.dataset == 'levir':
19 |     opt.dataset_dir = '../Dataset/levir_patch_dataset/'
20 | 
21 | test_loader = get_test_loaders(opt)
22 | 
23 | # path = 'weights/sunet-32.pt'  # the path of the model
24 | model = torch.load(opt.path, map_location='cpu')
25 | 
26 | model.to(dev)
27 | 
28 | c_matrix = {'tn': 0, 'fp': 0, 'fn': 0, 'tp': 0}
29 | model.eval()
30 | 
31 | with torch.no_grad():
32 |     tbar = tqdm(test_loader)
33 |     for batch_img1, batch_img2, labels, fname in tbar:
34 | 
35 |         batch_img1 = batch_img1.float().to(dev)
36 |         batch_img2 = batch_img2.float().to(dev)
37 |         labels = labels.long().to(dev)
38 | 
39 |         cd_preds = model(batch_img1, batch_img2)
40 |         # cd_preds = cd_preds[-1]  # the BIT output is not a tuple
41 |         _, cd_preds = torch.max(cd_preds, 1)
42 | 
43 |         # print(cd_preds.shape, labels.shape, fname)
44 | 
45 |         tn, fp, fn, tp = confusion_matrix(labels.data.cpu().numpy().flatten(),
46 |                                           cd_preds.data.cpu().numpy().flatten(), labels=[0, 1]).ravel()
47 | 
48 |         c_matrix['tn'] += tn
49 |         c_matrix['fp'] += fp
50 |         c_matrix['fn'] += fn
51 |         c_matrix['tp'] += tp
52 | 
53 | tn, fp, fn, tp = c_matrix['tn'], c_matrix['fp'], c_matrix['fn'], c_matrix['tp']
54 | P = tp / (tp + fp)
55 | R = tp / (tp + fn)
56 | F1 = 2 * P * R / (R + P)
57 | 
58 | print('Precision: {}\nRecall: {}\nF1-Score: {}'.format(P, R, F1))
59 | 
--------------------------------------------------------------------------------
/misc/logger_tool.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | 
4 | 
5 | class Logger(object):
6 |     def __init__(self, outfile):
7 |         self.terminal = sys.stdout
8 |         self.log_path = outfile
9 |         now = time.strftime("%c")
10 |         self.write('================ (%s) ================\n' % now)
11 | 
12 |     def write(self, message):
13 |         self.terminal.write(message)
14 |         with open(self.log_path, mode='a') as f:
15 |             f.write(message)
16 | 
17 |     def write_dict(self, dict):
18 |         message = ''
19 |         for k, v in dict.items():
20 |             message += '%s: %.7f ' % (k, v)
21 |         self.write(message)
22 | 
23 |     def write_dict_str(self, dict):
24 |         message = ''
25 |         for k, v in dict.items():
26 |             message += '%s: %s ' % (k, v)
27 |         self.write(message)
28 | 
29 |     def flush(self):
30 |         self.terminal.flush()
31 | 
32 | 
33 | class Timer:
34 |     def __init__(self, starting_msg = None):
35 |         self.start = time.time()
36 |         self.stage_start = self.start
37 | 
38 |         if starting_msg is not
None: 39 | print(starting_msg, time.ctime(time.time())) 40 | 41 | def __enter__(self): 42 | return self 43 | 44 | def __exit__(self, exc_type, exc_val, exc_tb): 45 | return 46 | 47 | def update_progress(self, progress): 48 | self.elapsed = time.time() - self.start 49 | self.est_total = self.elapsed / progress 50 | self.est_remaining = self.est_total - self.elapsed 51 | self.est_finish = int(self.start + self.est_total) 52 | 53 | 54 | def str_estimated_complete(self): 55 | return str(time.ctime(self.est_finish)) 56 | 57 | def str_estimated_remaining(self): 58 | return str(self.est_remaining/3600) + 'h' 59 | 60 | def estimated_remaining(self): 61 | return self.est_remaining/3600 62 | 63 | def get_stage_elapsed(self): 64 | return time.time() - self.stage_start 65 | 66 | def reset_stage(self): 67 | self.stage_start = time.time() 68 | 69 | def lapse(self): 70 | out = time.time() - self.stage_start 71 | self.stage_start = time.time() 72 | return out 73 | 74 | -------------------------------------------------------------------------------- /compare/resbase.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | 7 | import torch.nn as nn 8 | 9 | 10 | import resnet 11 | 12 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 13 | 14 | __all__ = ['BaseNet'] 15 | 16 | class BaseNet(nn.Module): 17 | def __init__(self, nclass, backbone, dilated=True, norm_layer=None,root='./pretrain_models', 18 | multi_grid=False, multi_dilation=None): 19 | super(BaseNet, self).__init__() 20 | 21 | # copying modules from pretrained models 22 | if backbone == 'resnet34': 23 | self.pretrained = resnet.resnet34(pretrained=False, dilated=dilated, 24 | norm_layer=norm_layer, root=root, 25 | multi_grid=multi_grid, multi_dilation=multi_dilation) 26 | elif backbone == 'resnet50': 27 | self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, 28 | norm_layer=norm_layer, root=root, 29 | multi_grid=multi_grid, multi_dilation=multi_dilation) 30 | elif backbone == 'resnet101': 31 | self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, 32 | norm_layer=norm_layer, root=root, 33 | multi_grid=multi_grid,multi_dilation=multi_dilation) 34 | elif backbone == 'resnet152': 35 | self.pretrained = resnet.resnet152(pretrained=False, dilated=dilated, 36 | norm_layer=norm_layer, root=root, 37 | multi_grid=multi_grid, multi_dilation=multi_dilation) 38 | else: 39 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 40 | # bilinear upsample options 41 | self._up_kwargs = up_kwargs 42 | 43 | def base_forward(self, x): 44 | x = self.pretrained.conv1(x) 45 | x = self.pretrained.bn1(x) 46 | x = self.pretrained.relu(x) 47 | x = self.pretrained.maxpool(x) 48 | c1 = self.pretrained.layer1(x) 49 | c2 = self.pretrained.layer2(c1) 50 | c3 = self.pretrained.layer3(c2) 51 | c4 = self.pretrained.layer4(c3) 52 | return c1, c2, c3, c4 53 | -------------------------------------------------------------------------------- /utils_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torchvision import utils 5 | 6 | import data_config 7 | from datasets.CD_dataset import CDDataset 8 | import argparse 9 | 10 | def 
get_loader(data_name, img_size=256, batch_size=8, split='test',
11 |                is_train=False, dataset='CDDataset'):
12 |     dataConfig = data_config.DataConfig().get_data_config(data_name)
13 |     root_dir = dataConfig.root_dir
14 |     label_transform = dataConfig.label_transform
15 | 
16 |     if dataset == 'CDDataset':
17 |         data_set = CDDataset(root_dir=root_dir, split=split,
18 |                              img_size=img_size, is_train=is_train,
19 |                              label_transform=label_transform)
20 |     else:
21 |         raise NotImplementedError(
22 |             'Wrong dataset name %s (choose one from [CDDataset])'
23 |             % dataset)
24 | 
25 |     shuffle = is_train
26 |     dataloader = DataLoader(data_set, batch_size=batch_size,
27 |                             shuffle=shuffle, num_workers=4)
28 | 
29 | 
30 |     return dataloader
31 | 
32 | 
33 | def get_loaders(args):
34 | 
35 |     data_name = args.data_name
36 |     dataConfig = data_config.DataConfig().get_data_config(data_name)
37 |     root_dir = dataConfig.root_dir
38 |     label_transform = dataConfig.label_transform
39 |     split = args.split
40 |     split_val = 'val'
41 |     if hasattr(args, 'split_val'):
42 |         split_val = args.split_val
43 |     if args.dataset == 'CDDataset':
44 |         training_set = CDDataset(root_dir=root_dir, split=split,
45 |                                  img_size=args.img_size, is_train=True,
46 |                                  label_transform=label_transform)
47 |         val_set = CDDataset(root_dir=root_dir, split=split_val,
48 |                             img_size=args.img_size, is_train=False,
49 |                             label_transform=label_transform)
50 |     else:
51 |         raise NotImplementedError(
52 |             'Wrong dataset name %s (choose one from [CDDataset,])'
53 |             % args.dataset)
54 | 
55 |     datasets = {'train': training_set, 'val': val_set}
56 |     dataloaders = {x: DataLoader(datasets[x], batch_size=args.batch_size,
57 |                                  shuffle=True, num_workers=args.num_workers)
58 |                    for x in ['train', 'val']}
59 | 
60 |     return dataloaders
61 | 
62 | 
63 | def make_numpy_grid(tensor_data, pad_value=0, padding=0):
64 |     tensor_data = tensor_data.detach()
65 |     vis = utils.make_grid(tensor_data, pad_value=pad_value, padding=padding)
66 |     vis = np.array(vis.cpu()).transpose((1, 2, 0))
67 |     if vis.shape[2] == 1:
68 |         vis = np.stack([vis, vis, vis], axis=-1)
69 |     return vis
70 | 
71 | 
72 | def de_norm(tensor_data):
73 |     return tensor_data * 0.5 + 0.5
74 | 
75 | 
76 | def get_device(args):
77 |     # set gpu ids
78 |     str_ids = args.gpu_ids.split(',')
79 |     args.gpu_ids = []
80 |     for str_id in str_ids:
81 |         id = int(str_id)
82 |         if id >= 0:
83 |             args.gpu_ids.append(id)
84 |     if len(args.gpu_ids) > 0:
85 |         torch.cuda.set_device(args.gpu_ids[0])
86 | 
87 | def str2bool(v):
88 |     if v.lower() in ['true', 1]:
89 |         return True
90 |     elif v.lower() in ['false', 0]:
91 |         return False
92 |     else:
93 |         raise argparse.ArgumentTypeError('Boolean value expected.')
--------------------------------------------------------------------------------
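As a quick usage sketch for the loader helpers above (illustrative only; it assumes the 'LEVIR' entry in data_config.py points at a locally available dataset, and the exact label shape depends on CDDataAugmentation):

# Illustrative only: build a LEVIR test loader with utils_.get_loader and inspect one batch.
from utils_ import get_loader

test_loader = get_loader('LEVIR', img_size=256, batch_size=8, split='test', is_train=False)
batch = next(iter(test_loader))
print(batch['name'][0])                    # file name of the first sample
print(batch['A'].shape, batch['B'].shape)  # bitemporal images, e.g. [8, 3, 256, 256]
print(batch['L'].shape)                    # change label (shape depends on the transform)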
/eval_cd.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from models.evaluator import *
4 | from utils_ import str2bool
5 | print(torch.cuda.is_available())
6 | 
7 | 
8 | """
9 | eval the CD model
10 | """
11 | 
12 | def main():
13 |     # ------------
14 |     # args
15 |     # ------------
16 |     parser = ArgumentParser()
17 |     parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
18 |     parser.add_argument('--project_name', default='LEVIR_SEIFNet_ce_Adamw_0.0001_150', type=str)
19 |     parser.add_argument('--print_models', default=False, type=bool, help='print models')
20 |     parser.add_argument('--checkpoints_root', default='checkpoints', type=str)
21 |     parser.add_argument('--vis_root', default='vis', type=str)
22 | 
23 |     # data
24 |     parser.add_argument('--num_workers', default=8, type=int)
25 |     parser.add_argument('--dataset', default='CDDataset', type=str)
26 |     parser.add_argument('--data_name', default='LEVIR', type=str)
27 | 
28 |     parser.add_argument('--batch_size', default=1, type=int)
29 |     parser.add_argument('--split', default="test", type=str)
30 | 
31 |     parser.add_argument('--img_size', default=256, type=int)
32 | 
33 |     # model
34 |     parser.add_argument('--n_class', default=2, type=int)
35 |     parser.add_argument('--embed_dim', default=64, type=int)
36 |     parser.add_argument('--net_G', default='SEIFNet', type=str,
37 |                         help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|'
38 |                              'vitae_transformer|SEIFNet')
39 |     parser.add_argument('--backbone', default='L-Backbone', type=str, choices=['resnet', 'swin', 'vitae'],
40 |                         help='type of model')
41 |     parser.add_argument('--mode', default='None', type=str,
42 |                         choices=['imp', 'rsp_40', 'rsp_100', 'rsp_120', 'rsp_300', 'rsp_300_sgd', 'seco'],
43 |                         help='type of pretrain')
44 |     parser.add_argument('--deep_supervision', default=False, type=str2bool)  # True for UNet++ and A2Net, False otherwise
45 |     # parser.add_argument('--loss1', default='Focal', type=str, help='ce|BL_ce|Focal|Focal_Dice|Focal_Dice_BL|Focal_BL|BL_Focal|Focal_BF_IOU')
46 |     # parser.add_argument('--loss2', default='BL_Focal', type=str,
47 |     #                     help='ce|BL_ce|Focal|Focal_Dice|Focal_Dice_BL|Focal_BL|BL_Focal|Focal_BF_IOU')
48 |     parser.add_argument('--loss_SD', default=True, type=str2bool)  # True only for CD_Net
49 | 
50 |     parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str)
51 | 
52 |     args = parser.parse_args()
53 |     utils_.get_device(args)
54 |     print(args.gpu_ids)
55 | 
56 |     # checkpoints dir
57 |     args.checkpoint_dir = os.path.join(args.checkpoints_root, args.project_name)
58 |     os.makedirs(args.checkpoint_dir, exist_ok=True)
59 |     # visualize dir
60 |     args.vis_dir = os.path.join(args.vis_root, args.project_name)
61 |     os.makedirs(args.vis_dir, exist_ok=True)
62 | 
63 |     dataloader = utils_.get_loader(args.data_name, img_size=args.img_size,
64 |                                    batch_size=args.batch_size, is_train=False,
65 |                                    split=args.split)
66 |     model = CDEvaluator(args=args, dataloader=dataloader)
67 | 
68 |     model.eval_models(args=args, checkpoint_name=args.checkpoint_name)
69 | 
70 | 
71 | if __name__ == '__main__':
72 |     main()
73 | 
74 | 
--------------------------------------------------------------------------------
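The evaluation scripts above accumulate a binary confusion matrix (tn, fp, fn, tp) over the test set and report precision, recall and F1 (see eval.py). For reference, the usual change-detection metrics can be derived from those four counts as in the standalone sketch below; IoU and overall accuracy are added here for completeness and are not printed by eval.py.

# Illustrative only: standard CD metrics from an accumulated binary confusion matrix.
def cd_metrics(tn, fp, fn, tp, eps=1e-10):
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)            # IoU of the "change" class
    oa = (tp + tn) / (tp + tn + fp + fn + eps) # overall accuracy
    return {'precision': precision, 'recall': recall, 'F1': f1, 'IoU': iou, 'OA': oa}

print(cd_metrics(tn=980_000, fp=4_000, fn=6_000, tp=10_000))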
/main_train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import torch
3 | from models.trainer import *
4 | from utils_ import str2bool
5 | print(torch.cuda.is_available())
6 | 
7 | """
8 | the main function for training the CD networks
9 | """
10 | 
11 | 
12 | def train(args):
13 |     dataloaders = utils_.get_loaders(args)
14 |     model = CDTrainer(args=args, dataloaders=dataloaders)
15 |     model.train_models(args=args)
16 |     # model.train_models()
17 | 
18 | 
19 | def test(args):
20 |     from models.evaluator import CDEvaluator
21 |     dataloader = utils_.get_loader(args.data_name, img_size=args.img_size,
22 |                                    batch_size=args.batch_size, is_train=False,
23 |                                    split='test')
24 |     model = CDEvaluator(args=args, dataloader=dataloader)
25 | 
26 |     model.eval_models(args)
27 | 
28 | 
29 | if __name__ == '__main__':
30 |     # ------------
31 |     # args
32 |     # ------------
33 |     parser = ArgumentParser()
34 |     parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
35 |     parser.add_argument('--project_name', default='LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200', type=str)
36 |     # SYSU_res18_coMDE2_AFF_Bcedice_l0_Adamw_0.0001_200
37 |     # LEVIR_transformer_CoDEM_AFF_ce_Adamw_0.0001_200_2
38 |     parser.add_argument('--checkpoint_root', default='checkpoints', type=str)
39 | 
40 |     # data
41 |     parser.add_argument('--num_workers', default=4, type=int)
42 |     parser.add_argument('--dataset', default='CDDataset', type=str)
43 |     parser.add_argument('--data_name', default='LEVIR', type=str, help='ChangeDetection|LEVIR|DSIFN|SYSU-CD|LEVIR+|BBCD|WHU-CD')
44 | 
45 |     parser.add_argument('--batch_size', default=2, type=int)
46 |     parser.add_argument('--split', default="train", type=str)
47 |     parser.add_argument('--split_val', default="val", type=str)
48 | 
49 |     parser.add_argument('--img_size', default=256, type=int)
50 | 
51 |     # model
52 |     parser.add_argument('--n_class', default=2, type=int)
53 |     parser.add_argument('--embed_dim', default=64, type=int)
54 |     parser.add_argument('--net_G', default='SEIFNet', type=str,
55 |                         help='FC_EF | FC_Siam_conc | '
56 |                              'FC_Siam_diff | UNet++|SNUNet|'
57 |                              'DTCDSCN|IFNet|'
58 |                              'base_transformer_pos_s4_dd8_dedim8|'
59 |                              'ChangeFormer|'
60 |                              'A2Net|DMINet|TFI-GR|'
61 |                              'SEIFNet')
62 |     parser.add_argument('--backbone', default='L-Backbone', type=str, choices=['resnet', 'swin', 'vitae', 'L-Backbone-cross', 'L-Backbone', 'BiFormer'],
63 |                         help='type of model')
64 |     parser.add_argument('--mode', default='None', type=str,
65 |                         choices=['imp', 'res18', 'rsp_40', 'rsp_100', 'rsp_120', 'rsp_300', 'rsp_300_sgd', 'seco', 'None'],
66 |                         help='type of pretrain')
67 |     parser.add_argument('--deep_supervision', default=False, type=str2bool)  # True for UNet++ and A2Net, False otherwise
68 | 
69 |     parser.add_argument('--loss_SD', default=False, type=str2bool)  # True only for IFNet and DMINet
70 |     # optimizer
71 |     parser.add_argument('--optimizer', default='adamw', type=str)
72 |     parser.add_argument('--lr', default=0.0001, type=float)
73 |     parser.add_argument('--max_epochs', default=200, type=int)  # 150
74 |     parser.add_argument('--lr_policy', default='linear', type=str,
75 |                         help='linear | step')
76 |     parser.add_argument('--lr_decay_iters', default=200, type=int)
77 | 
78 |     args = parser.parse_args()
79 |     utils_.get_device(args)
80 |     print(args.gpu_ids)
81 | 
82 |     # checkpoints dir
83 |     args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name)
84 |     os.makedirs(args.checkpoint_dir, exist_ok=True)
85 |     # visualize dir
86 |     args.vis_dir = os.path.join('vis', args.project_name)
87 |     os.makedirs(args.vis_dir, exist_ok=True)
88 | 
89 |     train(args)
90 | 
91 |     test(args)
92 | 
--------------------------------------------------------------------------------
/checkpoints/LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200/log.txt:
--------------------------------------------------------------------------------
1 | ================ (Tue Jan 23 17:10:54 2024) ================
2 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 8 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw
lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000 3 | ================ (Tue Jan 23 17:15:15 2024) ================ 4 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 8 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000 5 | ================ (Tue Jan 23 17:16:10 2024) ================ 6 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 8 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000 7 | ================ (Tue Jan 23 17:16:37 2024) ================ 8 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000 9 | ================ (Tue Jan 23 17:17:16 2024) ================ 10 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000 11 | ================ (Tue Jan 23 17:21:21 2024) ================ 12 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000 13 | ================ (Tue Jan 23 17:33:05 2024) ================ 14 | gpu_ids: [0] project_name: LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 checkpoint_root: checkpoints num_workers: 4 dataset: CDDataset data_name: LEVIR batch_size: 2 split: train split_val: val img_size: 256 n_class: 2 embed_dim: 64 net_G: SEIFNet backbone: L-Backbone mode: None deep_supervision: False loss_SD: False optimizer: adamw lr: 
0.0001 max_epochs: 200 lr_policy: linear lr_decay_iters: 200 checkpoint_dir: checkpoints\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 vis_dir: vis\LEVIR-CD_SEIFNet_ce_Adamw_0.0001_200 lr: 0.0001000 15 | Is_training: True. [0,199][1,3560], imps: 0.83, est: 955.75h, G_loss: 0.60956, running_mf1: 0.47574 16 | -------------------------------------------------------------------------------- /compare/MobileNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # from torchvision.models.utils import load_state_dict_from_url 3 | from torch.hub import load_state_dict_from_url 4 | model_urls = { 5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 6 | } 7 | 8 | 9 | class ConvBNReLU(nn.Sequential): 10 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1): 11 | padding = (kernel_size - 1) // 2 12 | if dilation != 1: 13 | padding = dilation 14 | super(ConvBNReLU, self).__init__( 15 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation, 16 | bias=False), 17 | nn.BatchNorm2d(out_planes), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | class InvertedResidual(nn.Module): 23 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1): 24 | super(InvertedResidual, self).__init__() 25 | self.stride = stride 26 | assert stride in [1, 2] 27 | 28 | hidden_dim = int(round(inp * expand_ratio)) 29 | self.use_res_connect = self.stride == 1 and inp == oup 30 | 31 | layers = [] 32 | if expand_ratio != 1: 33 | # pw 34 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 35 | layers.extend([ 36 | # dw 37 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation), 38 | # pw-linear 39 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 40 | nn.BatchNorm2d(oup), 41 | ]) 42 | self.conv = nn.Sequential(*layers) 43 | 44 | def forward(self, x): 45 | if self.use_res_connect: 46 | return x + self.conv(x) 47 | else: 48 | return self.conv(x) 49 | 50 | 51 | class MobileNetV2(nn.Module): 52 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0): 53 | super(MobileNetV2, self).__init__() 54 | block = InvertedResidual 55 | input_channel = 32 56 | last_channel = 1280 57 | inverted_residual_setting = [ 58 | # t, c, n, s, d 59 | [1, 16, 1, 1, 1], 60 | [6, 24, 2, 2, 1], 61 | [6, 32, 3, 2, 1], 62 | [6, 64, 4, 2, 1], 63 | [6, 96, 3, 1, 1], 64 | [6, 160, 3, 2, 1], 65 | [6, 320, 1, 1, 1], 66 | ] 67 | 68 | # building first layer 69 | input_channel = int(input_channel * width_mult) 70 | self.last_channel = int(last_channel * max(1.0, width_mult)) 71 | features = [ConvBNReLU(3, input_channel, stride=2)] 72 | # building inverted residual blocks 73 | for t, c, n, s, d in inverted_residual_setting: 74 | output_channel = int(c * width_mult) 75 | for i in range(n): 76 | stride = s if i == 0 else 1 77 | dilation = d if i == 0 else 1 78 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d)) 79 | input_channel = output_channel 80 | # building last several layers 81 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 82 | # make it nn.Sequential 83 | self.features = nn.Sequential(*features) 84 | 85 | # weight initialization 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 89 | if m.bias is not None: 90 | nn.init.zeros_(m.bias) 91 | elif isinstance(m, nn.BatchNorm2d): 92 | nn.init.ones_(m.weight) 93 
| nn.init.zeros_(m.bias) 94 | elif isinstance(m, nn.Linear): 95 | nn.init.normal_(m.weight, 0, 0.01) 96 | nn.init.zeros_(m.bias) 97 | 98 | def forward(self, x): 99 | res = [] 100 | for idx, m in enumerate(self.features): 101 | x = m(x) 102 | if idx in [1, 3, 6, 13, 17]: 103 | res.append(x) 104 | return res 105 | 106 | 107 | def mobilenet_v2(pretrained=True, progress=True, **kwargs): 108 | model = MobileNetV2(**kwargs) 109 | if pretrained: 110 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 111 | progress=progress) 112 | print("loading imagenet pretrained mobilenetv2") 113 | model.load_state_dict(state_dict, strict=False) 114 | print("loaded imagenet pretrained mobilenetv2") 115 | return model 116 | -------------------------------------------------------------------------------- /utils/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | from PIL import Image 4 | from utils import transforms as tr 5 | 6 | 7 | ''' 8 | Load all training and validation data paths 9 | ''' 10 | def full_path_loader(data_dir): 11 | train_data = [i for i in os.listdir(data_dir + 'train/A/') if not 12 | i.startswith('.')] 13 | train_data.sort() 14 | 15 | valid_data = [i for i in os.listdir(data_dir + 'val/A/') if not 16 | i.startswith('.')] 17 | valid_data.sort() 18 | 19 | train_label_paths = [] 20 | val_label_paths = [] 21 | for img in train_data: 22 | train_label_paths.append(data_dir + 'train/label/' + img) 23 | for img in valid_data: 24 | val_label_paths.append(data_dir + 'val/label/' + img) 25 | 26 | 27 | train_data_path = [] 28 | val_data_path = [] 29 | 30 | for img in train_data: 31 | train_data_path.append([data_dir + 'train/', img]) 32 | for img in valid_data: 33 | val_data_path.append([data_dir + 'val/', img]) 34 | 35 | train_dataset = {} 36 | val_dataset = {} 37 | for cp in range(len(train_data)): 38 | train_dataset[cp] = {'image': train_data_path[cp], 39 | 'label': train_label_paths[cp]} 40 | for cp in range(len(valid_data)): 41 | val_dataset[cp] = {'image': val_data_path[cp], 42 | 'label': val_label_paths[cp]} 43 | 44 | 45 | return train_dataset, val_dataset 46 | 47 | ''' 48 | Load all testing data paths 49 | ''' 50 | def full_test_loader(data_dir): 51 | 52 | test_data = [i for i in os.listdir(data_dir + 'test/A/') if not 53 | i.startswith('.')] 54 | test_data.sort() 55 | 56 | test_label_paths = [] 57 | for img in test_data: 58 | test_label_paths.append(data_dir + 'test/label/' + img) 59 | 60 | test_data_path = [] 61 | for img in test_data: 62 | test_data_path.append([data_dir + 'test/', img]) 63 | 64 | test_dataset = {} 65 | for cp in range(len(test_data)): 66 | test_dataset[cp] = {'image': test_data_path[cp], 67 | 'label': test_label_paths[cp]} 68 | 69 | return test_dataset 70 | 71 | def cdd_loader(img_path, label_path, aug): 72 | dir = img_path[0] 73 | name = img_path[1] 74 | 75 | img1 = Image.open(dir + 'A/' + name) 76 | img2 = Image.open(dir + 'B/' + name) 77 | label = Image.open(label_path) 78 | sample = {'image': (img1, img2), 'label': label} 79 | 80 | if aug: 81 | sample = tr.train_transforms(sample) 82 | else: 83 | sample = tr.test_transforms(sample) 84 | 85 | return sample['image'][0], sample['image'][1], sample['label'], name 86 | 87 | 88 | class CDDloader(data.Dataset): 89 | 90 | def __init__(self, full_load, flag = 'trn', aug=False): 91 | 92 | self.full_load = full_load 93 | self.loader = cdd_loader 94 | self.aug = aug 95 | 96 | print('load {} cdd {} 
pairs'.format(len(self.full_load), flag)) 97 | 98 | def __getitem__(self, index): 99 | 100 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label'] 101 | 102 | return self.loader(img_path, 103 | label_path, 104 | self.aug) 105 | 106 | def __len__(self): 107 | return len(self.full_load) 108 | 109 | class LEVIRloader(data.Dataset): 110 | 111 | def __init__(self, full_load, flag = 'trn', aug=False): 112 | 113 | self.full_load = full_load 114 | self.loader = cdd_loader 115 | self.aug = aug 116 | 117 | print('load {} levir {} pairs'.format(len(self.full_load), flag)) 118 | 119 | def __getitem__(self, index): 120 | 121 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label'] 122 | 123 | return self.loader(img_path, 124 | label_path, 125 | self.aug) 126 | 127 | def __len__(self): 128 | return len(self.full_load) 129 | 130 | class LEVIRplusloader(data.Dataset): 131 | 132 | def __init__(self, full_load, flag = 'trn', aug=False): 133 | 134 | self.full_load = full_load 135 | self.loader = cdd_loader 136 | self.aug = aug 137 | 138 | print('load {} levir {} pairs'.format(len(self.full_load), flag)) 139 | 140 | def __getitem__(self, index): 141 | 142 | img_path, label_path = self.full_load[index]['image'], self.full_load[index]['label'] 143 | 144 | return self.loader(img_path, 145 | label_path, 146 | self.aug) 147 | 148 | def __len__(self): 149 | return len(self.full_load) -------------------------------------------------------------------------------- /datasets/CD_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | 变化检测数据集 3 | """ 4 | 5 | import os 6 | from PIL import Image 7 | from PIL import Image,ImageFile 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | 10 | import numpy as np 11 | 12 | from torch.utils import data 13 | 14 | from datasets.data_utils import CDDataAugmentation 15 | 16 | 17 | """ 18 | CD data set with pixel-level labels; 19 | ├─image 20 | ├─image_post 21 | ├─label 22 | └─list 23 | """ 24 | IMG_FOLDER_NAME = "A" 25 | IMG_POST_FOLDER_NAME = 'B' 26 | # IMG_SOBEL_FOLDER_NAME = 'diff' 27 | LIST_FOLDER_NAME = 'list' 28 | ANNOT_FOLDER_NAME = "label" 29 | 30 | IGNORE = 255 31 | 32 | label_suffix='.png' # jpg for gan dataset, others : png 33 | 34 | def load_img_name_list(dataset_path): 35 | img_name_list = np.loadtxt(dataset_path, dtype=np.str) 36 | if img_name_list.ndim == 2: 37 | return img_name_list[:, 0] 38 | return img_name_list 39 | 40 | 41 | def load_image_label_list_from_npy(npy_path, img_name_list): 42 | cls_labels_dict = np.load(npy_path, allow_pickle=True).item() 43 | return [cls_labels_dict[img_name] for img_name in img_name_list] 44 | 45 | 46 | def get_img_post_path(root_dir,split,img_name): 47 | return os.path.join(root_dir,split, IMG_POST_FOLDER_NAME, img_name) 48 | 49 | def get_sobel_path(root_dir,split,img_name): 50 | return os.path.join(root_dir,split, IMG_SOBEL_FOLDER_NAME, img_name) 51 | 52 | 53 | def get_img_path(root_dir,split, img_name): 54 | return os.path.join(root_dir, split,IMG_FOLDER_NAME, img_name) 55 | 56 | 57 | def get_label_path(root_dir,split, img_name): 58 | return os.path.join(root_dir, split,ANNOT_FOLDER_NAME, img_name.replace('.png', label_suffix)) 59 | 60 | 61 | class ImageDataset(data.Dataset): 62 | """VOCdataloder""" 63 | def __init__(self, root_dir, split='train', img_size=256, is_train=True,to_tensor=True): 64 | super(ImageDataset, self).__init__() 65 | self.root_dir = root_dir 66 | self.img_size = img_size 67 | self.split = split # train | 
train_aug++ | val 68 | # self.list_path = self.root_dir + '/' + LIST_FOLDER_NAME + '/' + self.list + '.txt' 69 | self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split+'.txt') 70 | self.img_name_list = load_img_name_list(self.list_path) 71 | 72 | self.A_size = len(self.img_name_list) # get the size of dataset A 73 | self.to_tensor = to_tensor 74 | if is_train: 75 | self.augm = CDDataAugmentation( 76 | img_size=self.img_size, 77 | with_random_hflip=True, 78 | with_random_vflip=True, 79 | with_scale_random_crop=True, 80 | with_random_blur=True, 81 | ) 82 | else: 83 | self.augm = CDDataAugmentation( 84 | img_size=self.img_size 85 | ) 86 | def __getitem__(self, index): 87 | name = self.img_name_list[index] 88 | A_path = get_img_path(self.root_dir,self.split, self.img_name_list[index % self.A_size])#得到时相1影像的文件路径 89 | B_path = get_img_post_path(self.root_dir, self.split,self.img_name_list[index % self.A_size]) 90 | 91 | img = np.asarray(Image.open(A_path).convert('RGB'))#时相1 92 | img_B = np.asarray(Image.open(B_path).convert('RGB'))#时相2 93 | 94 | [img, img_B], _ = self.augm.transform([img, img_B],[], to_tensor=self.to_tensor) 95 | 96 | return {'A': img, 'B': img_B, 'name': name} 97 | 98 | def __len__(self): 99 | """Return the total number of images in the dataset.""" 100 | return self.A_size 101 | 102 | 103 | class CDDataset(ImageDataset): 104 | 105 | def __init__(self, root_dir, img_size, split='train', is_train=True, label_transform=None, 106 | to_tensor=True): 107 | super(CDDataset, self).__init__(root_dir, img_size=img_size, split=split, is_train=is_train, 108 | to_tensor=to_tensor) 109 | self.label_transform = label_transform 110 | 111 | def __getitem__(self, index): 112 | name = self.img_name_list[index] 113 | A_path = get_img_path(self.root_dir,self.split,self.img_name_list[index % self.A_size]) 114 | B_path = get_img_post_path(self.root_dir,self.split, self.img_name_list[index % self.A_size]) 115 | img = np.asarray(Image.open(A_path).convert('RGB')) 116 | img_B = np.asarray(Image.open(B_path).convert('RGB')) 117 | #sobel差分特征图 118 | # sobel_path = get_sobel_path(self.root_dir,self.split, self.img_name_list[index % self.A_size]) 119 | # sobel = np.asarray(Image.open(sobel_path).convert('RGB')) 120 | 121 | L_path = get_label_path(self.root_dir, self.split,self.img_name_list[index % self.A_size]) 122 | 123 | label = np.array(Image.open(L_path), dtype=np.uint8) 124 | 125 | # 二分类中,前景标注为255 126 | if self.label_transform == 'norm': 127 | label = label // 255 128 | # 129 | # [img, img_B,sobel], [label] = self.augm.transform([img, img_B,sobel], [label], to_tensor=self.to_tensor) 130 | [img, img_B], [label] = self.augm.transform([img, img_B], [label], to_tensor=self.to_tensor) 131 | # print(label.max()) 132 | # return {'name': name, 'A': img, 'B': img_B,'L':label, 'S': sobel} 133 | return {'name': name, 'A': img, 'B': img_B, 'L': label} 134 | 135 | -------------------------------------------------------------------------------- /compare/NestedUNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class VGGBlock(nn.Module): 5 | def __init__(self, in_channels, middle_channels, out_channels): 6 | super().__init__() 7 | self.relu = nn.ReLU(inplace=True) 8 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) 9 | self.bn1 = nn.BatchNorm2d(middle_channels) 10 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) 11 | self.bn2 = nn.BatchNorm2d(out_channels) 12 | 13 | def 
forward(self, x): 14 | out = self.conv1(x) 15 | out = self.bn1(out) 16 | out = self.relu(out) 17 | 18 | out = self.conv2(out) 19 | out = self.bn2(out) 20 | out = self.relu(out) 21 | 22 | return out 23 | 24 | class standard_unit(nn.Module): 25 | def __init__(self,in_channels,middle_channels,out_channels): 26 | super().__init__() 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) 29 | self.bn1 = nn.BatchNorm2d(middle_channels) 30 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) 31 | self.bn2 = nn.BatchNorm2d(out_channels) 32 | # self.downsample = downsample 33 | 34 | def forward(self, x): 35 | out = self.conv1(x) 36 | identity=out 37 | # print(identity.size())#(8,32,256,256) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | output = self.relu(out+identity) 44 | # print(output.size()) 45 | # out = self.relu(output) 46 | 47 | 48 | 49 | # out=torch.add(out,out0) 50 | 51 | return output 52 | 53 | class NestedUNet(nn.Module): 54 | def __init__(self, num_classes, input_channels=3, deep_supervision=True, **kwargs): 55 | super().__init__() 56 | 57 | nb_filter = [32, 64, 128, 256, 512] 58 | 59 | self.deep_supervision = deep_supervision 60 | 61 | self.pool = nn.MaxPool2d(2, 2) 62 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 63 | 64 | self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0]) 65 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1]) 66 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2]) 67 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3]) 68 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4]) 69 | 70 | self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0]) 71 | self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1]) 72 | self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2]) 73 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3]) 74 | 75 | self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0]) 76 | self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1]) 77 | self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2]) 78 | 79 | self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0]) 80 | self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1]) 81 | 82 | self.conv0_4 = standard_unit(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0]) 83 | 84 | if self.deep_supervision: 85 | self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 86 | self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 87 | self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 88 | self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 89 | else: 90 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 91 | 92 | 93 | def forward(self, input1,input2): 94 | input = torch.cat((input1, input2), dim=1) 95 | x0_0 = self.conv0_0(input) 96 | 97 | x1_0 = self.conv1_0(self.pool(x0_0)) 98 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 99 | 100 | 101 | x2_0 = self.conv2_0(self.pool(x1_0)) 102 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 103 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 104 | 105 | x3_0 = self.conv3_0(self.pool(x2_0)) 106 | x2_1 = 
self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) 107 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) 108 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 109 | 110 | x4_0 = self.conv4_0(self.pool(x3_0)) 111 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 112 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) 113 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1)) 114 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) 115 | 116 | if self.deep_supervision: 117 | output1 = self.final1(x0_1) 118 | 119 | output2 = self.final2(x0_2) 120 | output3 = self.final3(x0_3) 121 | output4 = self.final4(x0_4) 122 | return [output1, output2, output3, output4] 123 | 124 | else: 125 | output = self.final(x0_4) 126 | return output -------------------------------------------------------------------------------- /models/CBAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class BasicConv(nn.Module): 7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 8 | super(BasicConv, self).__init__() 9 | self.out_channels = out_planes 10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 11 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 12 | self.relu = nn.ReLU() if relu else None 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | if self.bn is not None: 17 | x = self.bn(x) 18 | if self.relu is not None: 19 | x = self.relu(x) 20 | return x 21 | 22 | class Flatten(nn.Module): 23 | def forward(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | class ChannelGate(nn.Module): 27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 28 | super(ChannelGate, self).__init__() 29 | self.gate_channels = gate_channels 30 | self.mlp = nn.Sequential( 31 | Flatten(), 32 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 33 | nn.ReLU(), 34 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 35 | ) 36 | self.pool_types = pool_types 37 | def forward(self, x): 38 | channel_att_sum = None 39 | for pool_type in self.pool_types: 40 | if pool_type=='avg': 41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 42 | channel_att_raw = self.mlp( avg_pool ) 43 | elif pool_type=='max': 44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 45 | channel_att_raw = self.mlp( max_pool ) 46 | elif pool_type=='lp': 47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 48 | channel_att_raw = self.mlp( lp_pool ) 49 | elif pool_type=='lse': 50 | # LSE pool only 51 | lse_pool = logsumexp_2d(x) 52 | channel_att_raw = self.mlp( lse_pool ) 53 | 54 | if channel_att_sum is None: 55 | channel_att_sum = channel_att_raw 56 | else: 57 | channel_att_sum = channel_att_sum + channel_att_raw 58 | 59 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 60 | return x * scale #返回通道注意力图(特征图*权重) 61 | 62 | def logsumexp_2d(tensor): 63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, 
keepdim=True).log() 66 | return outputs 67 | 68 | class ChannelPool(nn.Module): 69 | def forward(self, x): 70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 71 | 72 | class SpatialGate(nn.Module): 73 | def __init__(self): 74 | super(SpatialGate, self).__init__() 75 | kernel_size = 7 76 | self.compress = ChannelPool() 77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 78 | def forward(self, x): 79 | x_compress = self.compress(x) 80 | x_out = self.spatial(x_compress) 81 | scale = F.sigmoid(x_out) # broadcasting 82 | return x * scale 83 | 84 | # class CBAM(nn.Module): 85 | # def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 86 | # super(CBAM, self).__init__() 87 | # self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 88 | # self.no_spatial=no_spatial 89 | # if not no_spatial: 90 | # self.SpatialGate = SpatialGate() 91 | # def forward(self, x): 92 | # x_out = self.ChannelGate(x) 93 | # if not self.no_spatial: 94 | # x_out = self.SpatialGate(x_out) 95 | # return x_out 96 | # 97 | class ChannelAttentionModule(nn.Module): 98 | def __init__(self, channel, ratio=16): 99 | super(ChannelAttentionModule, self).__init__() 100 | # 使用自适应池化缩减map的大小,保持通道不变 101 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 102 | self.max_pool = nn.AdaptiveMaxPool2d(1) 103 | 104 | self.shared_MLP = nn.Sequential( 105 | nn.Conv2d(channel, channel // ratio, 1, bias=False), 106 | nn.ReLU(), 107 | nn.Conv2d(channel // ratio, channel, 1, bias=False) 108 | ) 109 | self.sigmoid = nn.Sigmoid() 110 | 111 | def forward(self, x): 112 | avgout = self.shared_MLP(self.avg_pool(x)) 113 | maxout = self.shared_MLP(self.max_pool(x)) 114 | return self.sigmoid(avgout + maxout) 115 | 116 | 117 | class SpatialAttentionModule(nn.Module): 118 | def __init__(self): 119 | super(SpatialAttentionModule, self).__init__() 120 | self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3) 121 | self.sigmoid = nn.Sigmoid() 122 | 123 | def forward(self, x): 124 | # map尺寸不变,缩减通道 125 | avgout = torch.mean(x, dim=1, keepdim=True) 126 | maxout, _ = torch.max(x, dim=1, keepdim=True) 127 | out = torch.cat([avgout, maxout], dim=1) 128 | out = self.sigmoid(self.conv2d(out)) 129 | return out 130 | 131 | 132 | class CBAM(nn.Module): 133 | def __init__(self, channel): 134 | super(CBAM, self).__init__() 135 | self.channel_attention = ChannelAttentionModule(channel) 136 | self.spatial_attention = SpatialAttentionModule() 137 | 138 | def forward(self, x): 139 | out = self.channel_attention(x) * x 140 | out = self.spatial_attention(out) * out 141 | return out -------------------------------------------------------------------------------- /misc/metric_tool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ################### metrics ################### 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.initialized = False 9 | self.val = None 10 | self.avg = None 11 | self.sum = None 12 | self.count = None 13 | 14 | def initialize(self, val, weight): 15 | self.val = val 16 | self.avg = val 17 | self.sum = val * weight 18 | self.count = weight 19 | self.initialized = True 20 | 21 | def update(self, val, weight=1): 22 | if not self.initialized: 23 | self.initialize(val, weight) 24 | else: 25 | self.add(val, weight) 26 | 27 | def add(self, val, 
weight): 28 | self.val = val 29 | self.sum += val * weight 30 | self.count += weight 31 | self.avg = self.sum / self.count 32 | 33 | def value(self): 34 | return self.val 35 | 36 | def average(self): 37 | return self.avg 38 | 39 | def get_scores(self): 40 | scores_dict = cm2score(self.sum) 41 | return scores_dict 42 | 43 | def clear(self): 44 | self.initialized = False 45 | 46 | 47 | ################### cm metrics ################### 48 | class ConfuseMatrixMeter(AverageMeter): 49 | """Computes and stores the average and current value""" 50 | def __init__(self, n_class): 51 | super(ConfuseMatrixMeter, self).__init__() 52 | self.n_class = n_class 53 | 54 | def update_cm(self, pr, gt, weight=1): 55 | """获得当前混淆矩阵,并计算当前F1得分,并更新混淆矩阵""" 56 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr) 57 | self.update(val, weight) 58 | current_score = cm2F1(val) 59 | return current_score 60 | 61 | def get_scores(self): 62 | scores_dict = cm2score(self.sum) 63 | return scores_dict 64 | 65 | 66 | 67 | def harmonic_mean(xs): 68 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs) 69 | return harmonic_mean 70 | 71 | 72 | def cm2F1(confusion_matrix): 73 | hist = confusion_matrix 74 | n_class = hist.shape[0] 75 | tp = np.diag(hist) 76 | sum_a1 = hist.sum(axis=1) 77 | sum_a0 = hist.sum(axis=0) 78 | # ---------------------------------------------------------------------- # 79 | # 1. Accuracy & Class Accuracy 80 | # ---------------------------------------------------------------------- # 81 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 82 | 83 | # recall 84 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 85 | # acc_cls = np.nanmean(recall) 86 | 87 | # precision 88 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 89 | 90 | # F1 score 91 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps) 92 | mean_F1 = np.nanmean(F1)#求均值的时候忽略nan值 93 | return mean_F1 94 | 95 | 96 | def cm2score(confusion_matrix): 97 | hist = confusion_matrix 98 | n_class = hist.shape[0] 99 | tp = np.diag(hist) 100 | sum_a1 = hist.sum(axis=1) 101 | sum_a0 = hist.sum(axis=0) 102 | # ---------------------------------------------------------------------- # 103 | # 1. Accuracy & Class Accuracy 104 | # ---------------------------------------------------------------------- # 105 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 106 | 107 | # recall 108 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 109 | # acc_cls = np.nanmean(recall) 110 | 111 | # precision 112 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 113 | 114 | # F1 score 115 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps) 116 | mean_F1 = np.nanmean(F1) 117 | # ---------------------------------------------------------------------- # 118 | # 2. 
Frequency weighted Accuracy & Mean IoU 119 | # ---------------------------------------------------------------------- # 120 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) 121 | mean_iu = np.nanmean(iu) 122 | 123 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps) 124 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 125 | 126 | # 127 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu)) 128 | 129 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision)) 130 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall)) 131 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1)) 132 | 133 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1} 134 | score_dict.update(cls_iou) 135 | score_dict.update(cls_F1) 136 | score_dict.update(cls_precision) 137 | score_dict.update(cls_recall) 138 | return score_dict 139 | 140 | 141 | def get_confuse_matrix(num_classes, label_gts, label_preds): 142 | """计算一组预测的混淆矩阵""" 143 | def __fast_hist(label_gt, label_pred): 144 | """ 145 | Collect values for Confusion Matrix 146 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix 147 | :param label_gt: ground-truth 148 | :param label_pred: prediction 149 | :return: values for confusion matrix 150 | """ 151 | mask = (label_gt >= 0) & (label_gt < num_classes) 152 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask], 153 | minlength=num_classes**2).reshape(num_classes, num_classes) 154 | return hist 155 | confusion_matrix = np.zeros((num_classes, num_classes)) 156 | for lt, lp in zip(label_gts, label_preds): 157 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten()) 158 | return confusion_matrix 159 | 160 | 161 | def get_mIoU(num_classes, label_gts, label_preds): 162 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds) 163 | score_dict = cm2score(confusion_matrix) 164 | return score_dict['miou'] 165 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.utils.data 4 | import torch.nn as nn 5 | import numpy as np 6 | from utils.dataloaders import (full_path_loader, full_test_loader, CDDloader, LEVIRloader,LEVIRplusloader) 7 | from utils.metrics import jaccard_loss, dice_loss 8 | from utils.losses import hybrid_loss 9 | from models.Models import Siam_NestedUNet_Conc, SNUNet_ECAM 10 | from models.siamunet_dif import SiamUnet_diff 11 | from models.networks import BASE_Transformer 12 | 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | def initialize_metrics(): 17 | """Generates a dictionary of metrics with metrics as keys 18 | and empty lists as values 19 | 20 | Returns 21 | ------- 22 | dict 23 | a dictionary of metrics 24 | 25 | """ 26 | metrics = { 27 | 'cd_losses': [], 28 | 'cd_corrects': [], 29 | 'cd_precisions': [], 30 | 'cd_recalls': [], 31 | 'cd_f1scores': [], 32 | 'learning_rate': [], 33 | } 34 | 35 | return metrics 36 | 37 | 38 | def get_mean_metrics(metric_dict): 39 | """takes a dictionary of lists for metrics and returns dict of mean values 40 | 41 | Parameters 42 | ---------- 43 | metric_dict : dict 44 | A dictionary of metrics 45 | 46 | Returns 47 | ------- 48 | dict 49 | dict of floats that reflect mean metric value 50 | 51 | """ 52 | return {k: np.mean(v) for k, v in metric_dict.items()} 53 | 54 | 55 | def set_metrics(metric_dict, cd_loss, 
cd_corrects, cd_report, lr): 56 | """Updates metric dict with batch metrics 57 | 58 | Parameters 59 | ---------- 60 | metric_dict : dict 61 | dict of metrics 62 | cd_loss : dict(?) 63 | loss value 64 | cd_corrects : dict(?) 65 | number of correct results (to generate accuracy 66 | cd_report : list 67 | precision, recall, f1 values 68 | 69 | Returns 70 | ------- 71 | dict 72 | dict of updated metrics 73 | 74 | 75 | """ 76 | metric_dict['cd_losses'].append(cd_loss.item()) 77 | metric_dict['cd_corrects'].append(cd_corrects.item()) 78 | metric_dict['cd_precisions'].append(cd_report[0]) 79 | metric_dict['cd_recalls'].append(cd_report[1]) 80 | metric_dict['cd_f1scores'].append(cd_report[2]) 81 | metric_dict['learning_rate'].append(lr) 82 | 83 | return metric_dict 84 | 85 | def set_test_metrics(metric_dict, cd_corrects, cd_report): 86 | 87 | metric_dict['cd_corrects'].append(cd_corrects.item()) 88 | metric_dict['cd_precisions'].append(cd_report[0]) 89 | metric_dict['cd_recalls'].append(cd_report[1]) 90 | metric_dict['cd_f1scores'].append(cd_report[2]) 91 | 92 | return metric_dict 93 | 94 | 95 | def get_loaders(opt): 96 | 97 | 98 | logging.info('STARTING Dataset Creation') 99 | 100 | train_full_load, val_full_load = full_path_loader(opt.dataset_dir) 101 | 102 | if opt.dataset == 'cdd': 103 | 104 | train_dataset = CDDloader(train_full_load, flag = 'trn', aug=opt.augmentation) 105 | val_dataset = CDDloader(val_full_load, flag='val', aug=False) 106 | 107 | elif opt.dataset == 'levir': 108 | train_dataset = LEVIRloader(train_full_load, flag = 'trn', aug=opt.augmentation) 109 | val_dataset = LEVIRloader(val_full_load, flag='val', aug=False) 110 | elif opt.dataset == 'levir+': 111 | train_dataset = LEVIRplusloader(train_full_load, flag='trn', aug=opt.augmentation) 112 | val_dataset = LEVIRplusloader(val_full_load, flag='val', aug=False) 113 | 114 | logging.info('STARTING Dataloading') 115 | 116 | train_loader = torch.utils.data.DataLoader(train_dataset, 117 | batch_size=opt.batch_size, 118 | shuffle=True, 119 | num_workers=opt.num_workers, 120 | pin_memory=True) 121 | val_loader = torch.utils.data.DataLoader(val_dataset, 122 | batch_size=opt.batch_size, 123 | shuffle=False, 124 | num_workers=opt.num_workers, 125 | pin_memory=True) 126 | return train_loader, val_loader 127 | 128 | def get_test_loaders(opt, batch_size=None): 129 | 130 | if not batch_size: 131 | batch_size = opt.batch_size 132 | 133 | logging.info('STARTING Dataset Creation') 134 | 135 | test_full_load = full_test_loader(opt.dataset_dir) 136 | 137 | if opt.dataset == 'cdd': 138 | 139 | test_dataset = CDDloader(test_full_load, flag = 'tes', aug=False) 140 | 141 | elif opt.dataset == 'levir': 142 | 143 | test_dataset = LEVIRloader(test_full_load, flag = 'tes', aug=False) 144 | 145 | logging.info('STARTING Dataloading') 146 | 147 | 148 | test_loader = torch.utils.data.DataLoader(test_dataset, 149 | batch_size=batch_size, 150 | shuffle=False, 151 | num_workers=opt.num_workers) 152 | return test_loader 153 | 154 | 155 | def get_criterion(opt): 156 | """get the user selected loss function 157 | 158 | Parameters 159 | ---------- 160 | opt : dict 161 | Dictionary of options/flags 162 | 163 | Returns 164 | ------- 165 | method 166 | loss function 167 | 168 | """ 169 | if opt.loss_function == 'hybrid': 170 | criterion = hybrid_loss 171 | if opt.loss_function == 'bce': 172 | criterion = nn.CrossEntropyLoss() 173 | if opt.loss_function == 'dice': 174 | criterion = dice_loss 175 | if opt.loss_function == 'jaccard': 176 | criterion = jaccard_loss 177 | 
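    # Editorial note: if opt.loss_function is not one of 'hybrid', 'bce', 'dice'
    # or 'jaccard', `criterion` is never assigned and the `return` below raises
    # UnboundLocalError; note also that the 'bce' option actually maps to
    # nn.CrossEntropyLoss(), not a binary cross-entropy loss. A defensive guard
    # could look like this (illustrative sketch, not part of the original code):
    #
    #     if opt.loss_function not in ('hybrid', 'bce', 'dice', 'jaccard'):
    #         raise ValueError(f'unsupported loss function: {opt.loss_function}')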
178 | return criterion 179 | 180 | 181 | def load_model(opt, device): 182 | """Load the model 183 | 184 | Parameters 185 | ---------- 186 | opt : dict 187 | User specified flags/options 188 | device : string 189 | device on which to train model 190 | 191 | """ 192 | # device_ids = list(range(opt.num_gpus)) 193 | #model = SNUNet_ECAM(opt, opt.num_channel, 2).to(device) 194 | # model = nn.DataParallel(model, device_ids=device_ids) 195 | 196 | model = BASE_Transformer(opt, input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, 197 | with_pos='learned', enc_depth=1, dec_depth=8).to(device) 198 | 199 | return model 200 | -------------------------------------------------------------------------------- /compare/SNUNet.py: -------------------------------------------------------------------------------- 1 | # Kaiyu Li 2 | # https://github.com/likyoo 3 | # 4 | 5 | import torch.nn as nn 6 | import torch 7 | 8 | class conv_block_nested(nn.Module): 9 | def __init__(self, in_ch, mid_ch, out_ch): 10 | super(conv_block_nested, self).__init__() 11 | self.activation = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True) 13 | self.bn1 = nn.BatchNorm2d(mid_ch) 14 | self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True) 15 | self.bn2 = nn.BatchNorm2d(out_ch) 16 | 17 | def forward(self, x): 18 | x = self.conv1(x) 19 | identity = x 20 | x = self.bn1(x) 21 | x = self.activation(x) 22 | 23 | x = self.conv2(x) 24 | x = self.bn2(x) 25 | output = self.activation(x + identity) 26 | return output 27 | 28 | 29 | class up(nn.Module): 30 | def __init__(self, in_ch, bilinear=False): 31 | super(up, self).__init__() 32 | 33 | if bilinear: 34 | self.up = nn.Upsample(scale_factor=2, 35 | mode='bilinear', 36 | align_corners=True) 37 | else: 38 | self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2) 39 | 40 | def forward(self, x): 41 | 42 | x = self.up(x) 43 | return x 44 | 45 | 46 | class ChannelAttention(nn.Module): 47 | def __init__(self, in_channels, ratio = 16): 48 | super(ChannelAttention, self).__init__() 49 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 50 | self.max_pool = nn.AdaptiveMaxPool2d(1) 51 | self.fc1 = nn.Conv2d(in_channels,in_channels//ratio,1,bias=False) 52 | self.relu1 = nn.ReLU() 53 | self.fc2 = nn.Conv2d(in_channels//ratio, in_channels,1,bias=False) 54 | self.sigmod = nn.Sigmoid() 55 | def forward(self,x): 56 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 57 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 58 | out = avg_out + max_out 59 | return self.sigmod(out) 60 | 61 | 62 | 63 | class SNUNet_ECAM(nn.Module): 64 | # SNUNet-CD with ECAM 65 | def __init__(self, in_ch=3, out_ch=2): 66 | super(SNUNet_ECAM, self).__init__() 67 | torch.nn.Module.dump_patches = True 68 | n1 = 32 # the initial number of channels of feature map 69 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] 70 | 71 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 72 | 73 | self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0]) 74 | self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1]) 75 | self.Up1_0 = up(filters[1]) 76 | self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2]) 77 | self.Up2_0 = up(filters[2]) 78 | self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3]) 79 | self.Up3_0 = up(filters[3]) 80 | self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4]) 81 | self.Up4_0 = up(filters[4]) 82 | 83 | self.conv0_1 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], 
filters[0]) 84 | self.conv1_1 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1]) 85 | self.Up1_1 = up(filters[1]) 86 | self.conv2_1 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2]) 87 | self.Up2_1 = up(filters[2]) 88 | self.conv3_1 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3]) 89 | self.Up3_1 = up(filters[3]) 90 | 91 | self.conv0_2 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0]) 92 | self.conv1_2 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1]) 93 | self.Up1_2 = up(filters[1]) 94 | self.conv2_2 = conv_block_nested(filters[2] * 3 + filters[3], filters[2], filters[2]) 95 | self.Up2_2 = up(filters[2]) 96 | 97 | self.conv0_3 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0]) 98 | self.conv1_3 = conv_block_nested(filters[1] * 4 + filters[2], filters[1], filters[1]) 99 | self.Up1_3 = up(filters[1]) 100 | 101 | self.conv0_4 = conv_block_nested(filters[0] * 5 + filters[1], filters[0], filters[0]) 102 | 103 | self.ca = ChannelAttention(filters[0] * 4, ratio=16) 104 | self.ca1 = ChannelAttention(filters[0], ratio=16 // 4) 105 | 106 | self.conv_final = nn.Conv2d(filters[0] * 4, out_ch, kernel_size=1) 107 | # self.conv_final = nn.Conv2d(filters[0], out_ch, kernel_size=1) 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 112 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 113 | nn.init.constant_(m.weight, 1) 114 | nn.init.constant_(m.bias, 0) 115 | 116 | 117 | def forward(self, xA, xB): 118 | '''xA''' 119 | x0_0A = self.conv0_0(xA) 120 | x1_0A = self.conv1_0(self.pool(x0_0A)) 121 | x2_0A = self.conv2_0(self.pool(x1_0A)) 122 | x3_0A = self.conv3_0(self.pool(x2_0A)) 123 | # x4_0A = self.conv4_0(self.pool(x3_0A)) 124 | '''xB''' 125 | x0_0B = self.conv0_0(xB) 126 | x1_0B = self.conv1_0(self.pool(x0_0B)) 127 | x2_0B = self.conv2_0(self.pool(x1_0B)) 128 | x3_0B = self.conv3_0(self.pool(x2_0B)) 129 | x4_0B = self.conv4_0(self.pool(x3_0B)) 130 | 131 | x0_1 = self.conv0_1(torch.cat([x0_0A, x0_0B, self.Up1_0(x1_0B)], 1)) 132 | x1_1 = self.conv1_1(torch.cat([x1_0A, x1_0B, self.Up2_0(x2_0B)], 1)) 133 | x0_2 = self.conv0_2(torch.cat([x0_0A, x0_0B, x0_1, self.Up1_1(x1_1)], 1)) 134 | 135 | 136 | x2_1 = self.conv2_1(torch.cat([x2_0A, x2_0B, self.Up3_0(x3_0B)], 1)) 137 | x1_2 = self.conv1_2(torch.cat([x1_0A, x1_0B, x1_1, self.Up2_1(x2_1)], 1)) 138 | x0_3 = self.conv0_3(torch.cat([x0_0A, x0_0B, x0_1, x0_2, self.Up1_2(x1_2)], 1)) 139 | 140 | x3_1 = self.conv3_1(torch.cat([x3_0A, x3_0B, self.Up4_0(x4_0B)], 1)) 141 | x2_2 = self.conv2_2(torch.cat([x2_0A, x2_0B, x2_1, self.Up3_1(x3_1)], 1)) 142 | x1_3 = self.conv1_3(torch.cat([x1_0A, x1_0B, x1_1, x1_2, self.Up2_2(x2_2)], 1)) 143 | x0_4 = self.conv0_4(torch.cat([x0_0A, x0_0B, x0_1, x0_2, x0_3, self.Up1_3(x1_3)], 1)) 144 | 145 | out = torch.cat([x0_1, x0_2, x0_3, x0_4], 1) 146 | 147 | intra = torch.sum(torch.stack((x0_1, x0_2, x0_3, x0_4)), dim=0) 148 | ca1 = self.ca1(intra) 149 | out = self.ca(out) * (out + ca1.repeat(1, 4, 1, 1)) 150 | out = self.conv_final(out) 151 | 152 | # return (out, ) 153 | return out 154 | 155 | -------------------------------------------------------------------------------- /compare/FC_EF.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. 
"Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | 11 | 12 | class Unet(nn.Module): 13 | """EF segmentation network.""" 14 | 15 | def __init__(self, input_nbr, label_nbr): 16 | super(Unet, self).__init__() 17 | 18 | self.conv11 = nn.Conv2d(2 * input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(32, 16, 
kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | x = torch.cat((x1, x2), 1) 96 | 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x)))) 100 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2) 102 | 103 | # Stage 2 104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 105 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 106 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2) 107 | 108 | # Stage 3 109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 111 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 112 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2) 113 | 114 | # Stage 4 115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 117 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 118 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2) 119 | 120 | # Stage 4d 121 | x4d = self.upconv4(x4p) 122 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2))) 123 | x4d = torch.cat((pad4(x4d), x43), 1) 124 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 125 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 126 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 127 | 128 | # Stage 3d 129 | x3d = self.upconv3(x41d) 130 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2))) 131 | x3d = torch.cat((pad3(x3d), x33), 1) 132 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 133 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 134 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 135 | 136 | # Stage 2d 137 | x2d = self.upconv2(x31d) 138 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2))) 139 | x2d = torch.cat((pad2(x2d), x22), 1) 140 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 141 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 142 | 143 | # Stage 1d 144 | x1d = self.upconv1(x21d) 145 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2))) 146 | x1d = torch.cat((pad1(x1d), x12), 1) 147 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 148 | x11d = self.conv11d(x12d) 149 | 150 | # output = [] 151 | # output.append(x11d) 152 | # output = output[-1] 153 | 154 | return x11d 155 | 156 | 157 | -------------------------------------------------------------------------------- /datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | from PIL import Image 5 | from PIL import ImageFilter 6 | 7 | import torchvision.transforms.functional as TF 8 | from torchvision import transforms 9 | import torch 10 | 11 | """ 12 | 这是数据格式转换和数据增强的代码 13 | """ 14 | 15 | 16 | def to_tensor_and_norm(imgs, labels): 17 | # to tensor 18 | imgs = [TF.to_tensor(img) for img in imgs] 19 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0) 20 | for img in labels] 21 | 22 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 23 | for img in imgs] 24 | return imgs, labels 25 | 26 | 27 | class CDDataAugmentation: 28 | 29 | def __init__( 30 | 
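            # Editorial note: each with_* flag below enables one augmentation in
            # transform(); a typical training setup (illustrative, mirroring the
            # is_train branch of ImageDataset above) is
            # CDDataAugmentation(img_size=256, with_random_hflip=True,
            #                    with_random_vflip=True,
            #                    with_scale_random_crop=True,
            #                    with_random_blur=True)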
self, 31 | img_size, 32 | with_random_hflip=False, 33 | with_random_vflip=False, 34 | with_random_rot=False, 35 | with_random_crop=False, 36 | with_scale_random_crop=False, 37 | with_random_blur=False, 38 | ): 39 | self.img_size = img_size 40 | if self.img_size is None: 41 | self.img_size_dynamic = True 42 | else: 43 | self.img_size_dynamic = False 44 | self.with_random_hflip = with_random_hflip 45 | self.with_random_vflip = with_random_vflip 46 | self.with_random_rot = with_random_rot 47 | self.with_random_crop = with_random_crop 48 | self.with_scale_random_crop = with_scale_random_crop 49 | self.with_random_blur = with_random_blur 50 | def transform(self, imgs, labels, to_tensor=True): 51 | """ 52 | :param imgs: [ndarray,] 53 | :param labels: [ndarray,] 54 | :return: [ndarray,],[ndarray,] 55 | """ 56 | # resize image and covert to tensor 57 | imgs = [TF.to_pil_image(img) for img in imgs] 58 | if self.img_size is None: 59 | self.img_size = None 60 | 61 | if not self.img_size_dynamic: 62 | if imgs[0].size != (self.img_size, self.img_size): 63 | imgs = [TF.resize(img, [self.img_size, self.img_size], interpolation=3) 64 | for img in imgs] 65 | else: 66 | self.img_size = imgs[0].size[0] 67 | 68 | labels = [TF.to_pil_image(img) for img in labels] 69 | if len(labels) != 0: 70 | if labels[0].size != (self.img_size, self.img_size): 71 | labels = [TF.resize(img, [self.img_size, self.img_size], interpolation=0) 72 | for img in labels] 73 | 74 | random_base = 0.5 75 | if self.with_random_hflip and random.random() > 0.5: 76 | imgs = [TF.hflip(img) for img in imgs] 77 | labels = [TF.hflip(img) for img in labels] 78 | 79 | if self.with_random_vflip and random.random() > 0.5: 80 | imgs = [TF.vflip(img) for img in imgs] 81 | labels = [TF.vflip(img) for img in labels] 82 | 83 | if self.with_random_rot and random.random() > random_base: 84 | angles = [90, 180, 270] 85 | index = random.randint(0, 2) 86 | angle = angles[index] 87 | imgs = [TF.rotate(img, angle) for img in imgs] 88 | labels = [TF.rotate(img, angle) for img in labels] 89 | 90 | if self.with_random_crop and random.random() > 0: 91 | i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). 
\ 92 | get_params(img=imgs[0], scale=(0.8, 1.0), ratio=(1, 1)) 93 | 94 | imgs = [TF.resized_crop(img, i, j, h, w, 95 | size=(self.img_size, self.img_size), 96 | interpolation=Image.CUBIC) 97 | for img in imgs] 98 | 99 | labels = [TF.resized_crop(img, i, j, h, w, 100 | size=(self.img_size, self.img_size), 101 | interpolation=Image.NEAREST) 102 | for img in labels] 103 | 104 | if self.with_scale_random_crop: 105 | # rescale 106 | scale_range = [1, 1.2] 107 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0]) 108 | 109 | imgs = [pil_rescale(img, target_scale, order=3) for img in imgs] 110 | labels = [pil_rescale(img, target_scale, order=0) for img in labels] 111 | # crop 112 | imgsize = imgs[0].size # h, w 113 | box = get_random_crop_box(imgsize=imgsize, cropsize=self.img_size) 114 | imgs = [pil_crop(img, box, cropsize=self.img_size, default_value=0) 115 | for img in imgs] 116 | labels = [pil_crop(img, box, cropsize=self.img_size, default_value=255) 117 | for img in labels] 118 | 119 | if self.with_random_blur and random.random() > 0: 120 | radius = random.random() 121 | imgs = [img.filter(ImageFilter.GaussianBlur(radius=radius)) 122 | for img in imgs] 123 | 124 | if to_tensor: 125 | # to tensor 126 | imgs = [TF.to_tensor(img) for img in imgs] 127 | labels = [torch.from_numpy(np.array(img, np.uint8)).unsqueeze(dim=0) 128 | for img in labels] 129 | 130 | imgs = [TF.normalize(img, mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]) 131 | for img in imgs] 132 | 133 | return imgs, labels 134 | 135 | 136 | def pil_crop(image, box, cropsize, default_value): 137 | assert isinstance(image, Image.Image) 138 | img = np.array(image) 139 | 140 | if len(img.shape) == 3: 141 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value 142 | else: 143 | cont = np.ones((cropsize, cropsize), img.dtype)*default_value 144 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]] 145 | 146 | return Image.fromarray(cont) 147 | 148 | 149 | def get_random_crop_box(imgsize, cropsize): 150 | h, w = imgsize 151 | ch = min(cropsize, h) 152 | cw = min(cropsize, w) 153 | 154 | w_space = w - cropsize 155 | h_space = h - cropsize 156 | 157 | if w_space > 0: 158 | cont_left = 0 159 | img_left = random.randrange(w_space + 1) 160 | else: 161 | cont_left = random.randrange(-w_space + 1) 162 | img_left = 0 163 | 164 | if h_space > 0: 165 | cont_top = 0 166 | img_top = random.randrange(h_space + 1) 167 | else: 168 | cont_top = random.randrange(-h_space + 1) 169 | img_top = 0 170 | 171 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw 172 | 173 | 174 | def pil_rescale(img, scale, order): 175 | assert isinstance(img, Image.Image) 176 | height, width = img.size 177 | target_size = (int(np.round(height*scale)), int(np.round(width*scale))) 178 | return pil_resize(img, target_size, order) 179 | 180 | 181 | def pil_resize(img, size, order): 182 | assert isinstance(img, Image.Image) 183 | if size[0] == img.size[0] and size[1] == img.size[1]: 184 | return img 185 | if order == 3: 186 | resample = Image.BICUBIC 187 | elif order == 0: 188 | resample = Image.NEAREST 189 | return img.resize(size[::-1], resample) 190 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import 
Variable 6 | 7 | 8 | 9 | class FocalLoss(nn.Module): 10 | def __init__(self, gamma=0, alpha=None, size_average=True): 11 | super(FocalLoss, self).__init__() 12 | self.gamma = gamma 13 | self.alpha = alpha 14 | if isinstance(alpha, (float, int)): 15 | self.alpha = torch.Tensor([alpha, 1-alpha]) 16 | if isinstance(alpha, list): 17 | self.alpha = torch.Tensor(alpha) 18 | self.size_average = size_average 19 | 20 | def forward(self, input, target): 21 | if input.dim() > 2: 22 | # N,C,H,W => N,C,H*W 23 | input = input.view(input.size(0), input.size(1), -1) 24 | 25 | # N,C,H*W => N,H*W,C 26 | input = input.transpose(1, 2) 27 | 28 | # N,H*W,C => N*H*W,C 29 | input = input.contiguous().view(-1, input.size(2)) 30 | 31 | target = target.view(-1, 1) 32 | logpt = F.log_softmax(input,dim=-1) 33 | logpt = logpt.gather(1, target) 34 | logpt = logpt.view(-1) 35 | pt = Variable(logpt.data.exp()) 36 | 37 | if self.alpha is not None: 38 | if self.alpha.type() != input.data.type(): 39 | self.alpha = self.alpha.type_as(input.data) 40 | at = self.alpha.gather(0, target.data.view(-1)) 41 | logpt = logpt * Variable(at) 42 | 43 | loss = -1 * (1-pt)**self.gamma * logpt 44 | 45 | if self.size_average: 46 | return loss.mean() 47 | else: 48 | return loss.sum() 49 | 50 | def dice_loss(logits, true, eps=1e-7): 51 | """Computes the Sørensen–Dice loss. 52 | Note that PyTorch optimizers minimize a loss. In this 53 | case, we would like to maximize the dice loss so we 54 | return the negated dice loss. 55 | Args: 56 | true: a tensor of shape [B, 1, H, W]. 57 | logits: a tensor of shape [B, C, H, W]. Corresponds to 58 | the raw output or logits of the model. 59 | eps: added to the denominator for numerical stability. 60 | Returns: 61 | dice_loss: the Sørensen–Dice loss. 62 | """ 63 | num_classes = logits.shape[1] 64 | if num_classes == 1: 65 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] 66 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 67 | true_1_hot_f = true_1_hot[:, 0:1, :, :] 68 | true_1_hot_s = true_1_hot[:, 1:2, :, :] 69 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) 70 | pos_prob = torch.sigmoid(logits) 71 | neg_prob = 1 - pos_prob 72 | probas = torch.cat([pos_prob, neg_prob], dim=1) 73 | else: 74 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)] 75 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 76 | probas = F.softmax(logits, dim=1) 77 | true_1_hot = true_1_hot.type(logits.type()) 78 | dims = (0,) + tuple(range(2, true.ndimension())) 79 | intersection = torch.sum(probas * true_1_hot, dims) 80 | cardinality = torch.sum(probas + true_1_hot, dims) 81 | dice_loss = (2. * intersection / (cardinality + eps)).mean() 82 | return (1 - dice_loss) 83 | 84 | 85 | def jaccard_loss(logits, true, eps=1e-7): 86 | """Computes the Jaccard loss, a.k.a the IoU loss. 87 | Note that PyTorch optimizers minimize a loss. In this 88 | case, we would like to maximize the jaccard loss so we 89 | return the negated jaccard loss. 90 | Args: 91 | true: a tensor of shape [B, H, W] or [B, 1, H, W]. 92 | logits: a tensor of shape [B, C, H, W]. Corresponds to 93 | the raw output or logits of the model. 94 | eps: added to the denominator for numerical stability. 95 | Returns: 96 | jacc_loss: the Jaccard loss. 
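    Example (illustrative sketch, not from the original source; relies on the
    torch import at the top of this module):
        logits = torch.randn(4, 2, 256, 256)            # raw model output
        target = torch.randint(0, 2, (4, 1, 256, 256))  # binary ground truth
        loss = jaccard_loss(logits, target)             # scalar, 1 - mean soft IoU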
97 | """ 98 | num_classes = logits.shape[1] 99 | if num_classes == 1: 100 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] 101 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 102 | true_1_hot_f = true_1_hot[:, 0:1, :, :] 103 | true_1_hot_s = true_1_hot[:, 1:2, :, :] 104 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) 105 | pos_prob = torch.sigmoid(logits) 106 | neg_prob = 1 - pos_prob 107 | probas = torch.cat([pos_prob, neg_prob], dim=1) 108 | else: 109 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)] 110 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 111 | probas = F.softmax(logits, dim=1) 112 | true_1_hot = true_1_hot.type(logits.type()) 113 | dims = (0,) + tuple(range(2, true.ndimension())) 114 | intersection = torch.sum(probas * true_1_hot, dims) 115 | cardinality = torch.sum(probas + true_1_hot, dims) 116 | union = cardinality - intersection 117 | jacc_loss = (intersection / (union + eps)).mean() 118 | return (1 - jacc_loss) 119 | 120 | 121 | class TverskyLoss(nn.Module): 122 | def __init__(self, alpha=0.5, beta=0.5, eps=1e-7, size_average=True): 123 | super(TverskyLoss, self).__init__() 124 | self.alpha = alpha 125 | self.beta = beta 126 | self.size_average = size_average 127 | self.eps = eps 128 | 129 | def forward(self, logits, true): 130 | """Computes the Tversky loss [1]. 131 | Args: 132 | true: a tensor of shape [B, H, W] or [B, 1, H, W]. 133 | logits: a tensor of shape [B, C, H, W]. Corresponds to 134 | the raw output or logits of the model. 135 | alpha: controls the penalty for false positives. 136 | beta: controls the penalty for false negatives. 137 | eps: added to the denominator for numerical stability. 138 | Returns: 139 | tversky_loss: the Tversky loss. 140 | Notes: 141 | alpha = beta = 0.5 => dice coeff 142 | alpha = beta = 1 => tanimoto coeff 143 | alpha + beta = 1 => F beta coeff 144 | References: 145 | [1]: https://arxiv.org/abs/1706.05721 146 | """ 147 | num_classes = logits.shape[1] 148 | if num_classes == 1: 149 | true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] 150 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 151 | true_1_hot_f = true_1_hot[:, 0:1, :, :] 152 | true_1_hot_s = true_1_hot[:, 1:2, :, :] 153 | true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) 154 | pos_prob = torch.sigmoid(logits) 155 | neg_prob = 1 - pos_prob 156 | probas = torch.cat([pos_prob, neg_prob], dim=1) 157 | else: 158 | true_1_hot = torch.eye(num_classes)[true.squeeze(1)] 159 | true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() 160 | probas = F.softmax(logits, dim=1) 161 | 162 | true_1_hot = true_1_hot.type(logits.type()) 163 | dims = (0,) + tuple(range(2, true.ndimension())) 164 | intersection = torch.sum(probas * true_1_hot, dims) 165 | fps = torch.sum(probas * (1 - true_1_hot), dims) 166 | fns = torch.sum((1 - probas) * true_1_hot, dims) 167 | num = intersection 168 | denom = intersection + (self.alpha * fps) + (self.beta * fns) 169 | tversky_loss = (num / (denom + self.eps)).mean() 170 | return (1 - tversky_loss) 171 | -------------------------------------------------------------------------------- /compare/IFNet.py: -------------------------------------------------------------------------------- 1 | # credits: https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision.models import vgg16 6 | import numpy as np 7 | 8 | 9 | class vgg16_base(nn.Module): 10 | def __init__(self): 11 | 
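        # Editorial note: vgg16_base truncates a pretrained VGG-16 to its first
        # 30 feature layers; forward() below returns the activations after
        # layers 3, 8, 15, 22 and 29 (the outputs of the five conv stages),
        # which DSIFN consumes as its multi-scale siamese features.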
super(vgg16_base, self).__init__() 12 | features = list(vgg16(pretrained=True).features)[:30] 13 | self.features = nn.ModuleList(features).eval() 14 | 15 | def forward(self, x): 16 | results = [] 17 | for ii, model in enumerate(self.features): 18 | x = model(x) 19 | if ii in {3, 8, 15, 22, 29}: 20 | results.append(x) 21 | return results 22 | 23 | 24 | class ChannelAttention(nn.Module): 25 | def __init__(self, in_channels, ratio=8): 26 | super(ChannelAttention, self).__init__() 27 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 28 | self.max_pool = nn.AdaptiveMaxPool2d(1) 29 | self.fc1 = nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False) 30 | self.relu1 = nn.ReLU() 31 | self.fc2 = nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False) 32 | self.sigmod = nn.Sigmoid() 33 | 34 | def forward(self, x): 35 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 36 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 37 | out = avg_out + max_out 38 | return self.sigmod(out) 39 | 40 | 41 | class SpatialAttention(nn.Module): 42 | def __init__(self): 43 | super(SpatialAttention, self).__init__() 44 | self.conv1 = nn.Conv2d(2, 1, 7, padding=3, bias=False) 45 | self.sigmoid = nn.Sigmoid() 46 | 47 | def forward(self, x): 48 | avg_out = torch.mean(x, dim=1, keepdim=True) 49 | max_out = torch.max(x, dim=1, keepdim=True, out=None)[0] 50 | 51 | x = torch.cat([avg_out, max_out], dim=1) 52 | x = self.conv1(x) 53 | return self.sigmoid(x) 54 | 55 | 56 | def conv2d_bn(in_channels, out_channels): 57 | return nn.Sequential( 58 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 59 | nn.PReLU(), 60 | nn.BatchNorm2d(out_channels), 61 | nn.Dropout(p=0.6), 62 | ) 63 | 64 | 65 | class DSIFN(nn.Module): 66 | def __init__(self): 67 | super().__init__() 68 | self.t1_base = vgg16_base() 69 | self.t2_base = vgg16_base() 70 | self.sa1 = SpatialAttention() 71 | self.sa2 = SpatialAttention() 72 | self.sa3 = SpatialAttention() 73 | self.sa4 = SpatialAttention() 74 | self.sa5 = SpatialAttention() 75 | 76 | self.sigmoid = nn.Sigmoid() 77 | 78 | # branch1 79 | self.ca1 = ChannelAttention(in_channels=1024) 80 | self.bn_ca1 = nn.BatchNorm2d(1024) 81 | self.o1_conv1 = conv2d_bn(1024, 512) 82 | self.o1_conv2 = conv2d_bn(512, 512) 83 | self.bn_sa1 = nn.BatchNorm2d(512) 84 | self.o1_conv3 = nn.Conv2d(512, 2, 1) 85 | self.trans_conv1 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2) 86 | 87 | # branch 2 88 | self.ca2 = ChannelAttention(in_channels=1536) 89 | self.bn_ca2 = nn.BatchNorm2d(1536) 90 | self.o2_conv1 = conv2d_bn(1536, 512) 91 | self.o2_conv2 = conv2d_bn(512, 256) 92 | self.o2_conv3 = conv2d_bn(256, 256) 93 | self.bn_sa2 = nn.BatchNorm2d(256) 94 | self.o2_conv4 = nn.Conv2d(256, 2, 1) 95 | self.trans_conv2 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2) 96 | 97 | # branch 3 98 | self.ca3 = ChannelAttention(in_channels=768) 99 | self.o3_conv1 = conv2d_bn(768, 256) 100 | self.o3_conv2 = conv2d_bn(256, 128) 101 | self.o3_conv3 = conv2d_bn(128, 128) 102 | self.bn_sa3 = nn.BatchNorm2d(128) 103 | self.o3_conv4 = nn.Conv2d(128, 2, 1) 104 | self.trans_conv3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2) 105 | 106 | # branch 4 107 | self.ca4 = ChannelAttention(in_channels=384) 108 | self.o4_conv1 = conv2d_bn(384, 128) 109 | self.o4_conv2 = conv2d_bn(128, 64) 110 | self.o4_conv3 = conv2d_bn(64, 64) 111 | self.bn_sa4 = nn.BatchNorm2d(64) 112 | self.o4_conv4 = nn.Conv2d(64, 2, 1) 113 | self.trans_conv4 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) 114 | 115 | # 
branch 5 116 | self.ca5 = ChannelAttention(in_channels=192) 117 | self.o5_conv1 = conv2d_bn(192, 64) 118 | self.o5_conv2 = conv2d_bn(64, 32) 119 | self.o5_conv3 = conv2d_bn(32, 16) 120 | self.bn_sa5 = nn.BatchNorm2d(16) 121 | self.o5_conv4 = nn.Conv2d(16, 2, 1) 122 | 123 | def forward(self, t1_input, t2_input): 124 | t1_list = self.t1_base(t1_input) 125 | t2_list = self.t2_base(t2_input) 126 | 127 | t1_f_l3, t1_f_l8, t1_f_l15, t1_f_l22, t1_f_l29 = t1_list[0], t1_list[1], t1_list[2], t1_list[3], t1_list[4] 128 | t2_f_l3, t2_f_l8, t2_f_l15, t2_f_l22, t2_f_l29, = t2_list[0], t2_list[1], t2_list[2], t2_list[3], t2_list[4] 129 | 130 | x = torch.cat((t1_f_l29, t2_f_l29), dim=1) 131 | # optional to use channel attention module in the first combined feature 132 | # 在第一个深度特征叠加层之后可以选择使用或者不使用通道注意力模块 133 | # x = self.ca1(x) * x 134 | x = self.o1_conv1(x) 135 | x = self.o1_conv2(x) 136 | x = self.sa1(x) * x 137 | x = self.bn_sa1(x) 138 | 139 | # branch_1_out = self.sigmoid(self.o1_conv3(x)) 140 | branch_1_out = self.o1_conv3(x) 141 | 142 | x = self.trans_conv1(x) 143 | x = torch.cat((x, t1_f_l22, t2_f_l22), dim=1) 144 | x = self.ca2(x) * x 145 | # According to the amount of the training data, appropriately reduce the use of conv layers to prevent overfitting 146 | # 根据训练数据的大小,适当减少conv层的使用来防止过拟合 147 | x = self.o2_conv1(x) 148 | x = self.o2_conv2(x) 149 | x = self.o2_conv3(x) 150 | x = self.sa2(x) * x 151 | x = self.bn_sa2(x) 152 | 153 | # branch_2_out = self.sigmoid(self.o2_conv4(x)) 154 | branch_2_out = self.o2_conv4(x) 155 | 156 | x = self.trans_conv2(x) 157 | x = torch.cat((x, t1_f_l15, t2_f_l15), dim=1) 158 | x = self.ca3(x) * x 159 | x = self.o3_conv1(x) 160 | x = self.o3_conv2(x) 161 | x = self.o3_conv3(x) 162 | x = self.sa3(x) * x 163 | x = self.bn_sa3(x) 164 | 165 | # branch_3_out = self.sigmoid(self.o3_conv4(x)) 166 | branch_3_out = self.o3_conv4(x) 167 | 168 | x = self.trans_conv3(x) 169 | x = torch.cat((x, t1_f_l8, t2_f_l8), dim=1) 170 | x = self.ca4(x) * x 171 | x = self.o4_conv1(x) 172 | x = self.o4_conv2(x) 173 | x = self.o4_conv3(x) 174 | x = self.sa4(x) * x 175 | x = self.bn_sa4(x) 176 | 177 | # branch_4_out = self.sigmoid(self.o4_conv4(x)) 178 | branch_4_out = self.o4_conv4(x) 179 | 180 | x = self.trans_conv4(x) 181 | x = torch.cat((x, t1_f_l3, t2_f_l3), dim=1) 182 | x = self.ca5(x) * x 183 | x = self.o5_conv1(x) 184 | x = self.o5_conv2(x) 185 | x = self.o5_conv3(x) 186 | x = self.sa5(x) * x 187 | x = self.bn_sa5(x) 188 | 189 | # branch_5_out = self.sigmoid(self.o5_conv4(x)) 190 | branch_5_out = self.o5_conv4(x) 191 | 192 | return [branch_5_out, branch_4_out, branch_3_out, branch_2_out, branch_1_out] -------------------------------------------------------------------------------- /models/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from models.networks import * 6 | from misc.metric_tool import ConfuseMatrixMeter 7 | from misc.logger_tool import Logger 8 | from utils_ import de_norm 9 | import utils_ 10 | from tqdm import tqdm 11 | 12 | # Decide which device we want to run on 13 | # torch.cuda.current_device() 14 | 15 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | class CDEvaluator(): 19 | 20 | def __init__(self, args, dataloader): 21 | 22 | self.dataloader = dataloader 23 | 24 | self.n_class = args.n_class 25 | # define G 26 | self.net_G = define_G(args=args, gpu_ids=args.gpu_ids) 27 | self.device = 
torch.device("cuda:%s" % args.gpu_ids[0] if torch.cuda.is_available() and len(args.gpu_ids)>0 28 | else "cpu") 29 | print(self.device) 30 | 31 | # define some other vars to record the training states 32 | self.running_metric = ConfuseMatrixMeter(n_class=self.n_class) 33 | 34 | # define logger file 35 | logger_path = os.path.join(args.checkpoint_dir, 'log_test.txt') 36 | self.logger = Logger(logger_path) 37 | self.logger.write_dict_str(args.__dict__) 38 | 39 | 40 | # training log 41 | self.epoch_acc = 0 42 | self.best_val_acc = 0.0 43 | self.best_epoch_id = 0 44 | 45 | self.steps_per_epoch = len(dataloader) 46 | 47 | self.G_pred = None 48 | self.pred_vis = None 49 | self.batch = None 50 | self.is_training = False 51 | self.batch_id = 0 52 | self.epoch_id = 0 53 | self.checkpoint_dir = args.checkpoint_dir 54 | self.vis_dir = args.vis_dir 55 | 56 | # check and create model dir 57 | if os.path.exists(self.checkpoint_dir) is False: 58 | os.mkdir(self.checkpoint_dir) 59 | if os.path.exists(self.vis_dir) is False: 60 | os.mkdir(self.vis_dir) 61 | 62 | 63 | def _load_checkpoint(self, checkpoint_name='best_ckpt.pt'): 64 | 65 | if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)): 66 | self.logger.write('loading last checkpoint...\n') 67 | # load the entire checkpoint 68 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), map_location=self.device) 69 | 70 | self.net_G.load_state_dict(checkpoint['model_G_state_dict']) 71 | 72 | self.net_G.to(self.device) 73 | 74 | # update some other states 75 | self.best_val_acc = checkpoint['best_val_acc'] 76 | self.best_epoch_id = checkpoint['best_epoch_id'] 77 | 78 | self.logger.write('Eval Historical_best_acc = %.4f (at epoch %d)\n' % 79 | (self.best_val_acc, self.best_epoch_id)) 80 | self.logger.write('\n') 81 | 82 | else: 83 | raise FileNotFoundError('no such checkpoint %s' % checkpoint_name) 84 | 85 | 86 | def _visualize_pred(self,args): 87 | # pred = torch.argmax(self.G_pred, dim=1, keepdim=True) 88 | if args.deep_supervision== True: 89 | pred =torch.argmax(self.G_pred[-1], dim=1, keepdim=True) 90 | else: 91 | pred = torch.argmax(self.G_pred, dim=1, keepdim=True) 92 | pred_vis = pred * 255 93 | return pred_vis 94 | 95 | 96 | def _update_metric(self,args): 97 | """ 98 | update metric 99 | """ 100 | target = self.batch['L'].to(self.device).detach() 101 | # G_pred = self.G_pred.detach() 102 | if args.deep_supervision == True: 103 | G_pred = self.G_pred[-1].detach() 104 | else: 105 | G_pred = self.G_pred.detach() 106 | G_pred = torch.argmax(G_pred, dim=1) 107 | 108 | current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy()) 109 | return current_score 110 | 111 | def _collect_running_batch_states(self,args): 112 | 113 | running_acc = self._update_metric(args)#变化的精确度 114 | 115 | m = len(self.dataloader) 116 | #print('m:',m) 117 | #print('batch_id:',self.batch_id) 118 | 119 | if np.mod(self.batch_id, 100) == 1:#取模运算,同正为正,同负为负 120 | message = 'Is_training: %s. 
[%d,%d], running_mf1: %.5f\n' %\ 121 | (self.is_training, self.batch_id, m, running_acc) 122 | self.logger.write(message) 123 | 124 | if np.mod(self.batch_id, 100) == 1: 125 | vis_input = utils_.make_numpy_grid(de_norm(self.batch['A'])) 126 | vis_input2 = utils_.make_numpy_grid(de_norm(self.batch['B'])) 127 | 128 | vis_pred = utils_.make_numpy_grid(self._visualize_pred(args)) 129 | 130 | vis_gt = utils_.make_numpy_grid(self.batch['L']) 131 | vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0) 132 | vis = np.clip(vis, a_min=0.0, a_max=1.0)#限制数组中的值,也就是说clip这个函数将数组中的元素限制在a_min, a_max之间,大于a_max的就使得它等于 a_max,小于a_min,的就使得它等于a_min 133 | file_name = os.path.join( 134 | self.vis_dir, 'eval_' + str(self.batch_id)+'.jpg') 135 | plt.imsave(file_name, vis) 136 | 137 | 138 | 139 | def _collect_epoch_states(self): 140 | 141 | scores_dict = self.running_metric.get_scores() 142 | 143 | np.save(os.path.join(self.checkpoint_dir, 'scores_dict.npy'), scores_dict) 144 | 145 | self.epoch_acc = scores_dict['mf1'] 146 | print(self.epoch_acc ) 147 | 148 | with open(os.path.join(self.checkpoint_dir, '%s.txt' % (self.epoch_acc)), 149 | mode='a') as file: 150 | pass 151 | 152 | message = '' 153 | for k, v in scores_dict.items(): 154 | message += '%s: %.5f ' % (k, v) 155 | self.logger.write('%s\n' % message) # save the message 156 | 157 | self.logger.write('\n') 158 | 159 | def _clear_cache(self): 160 | self.running_metric.clear() 161 | 162 | def _forward_pass(self,batch,args): 163 | self.batch = batch 164 | img_in1 = batch['A'].to(self.device) 165 | img_in2 = batch['B'].to(self.device) 166 | # sobel = batch['S'].to(self.device) 167 | # self.G_pred1, self.G_pred2, self.G_pred3 = self.net_G(img_in1, img_in2) 168 | # self.G_pred = self.G_pred1 + self.G_pred2 + self.G_pred3 169 | if args.loss_SD== True : 170 | # self.G_pred1, self.G_pred2, self.G_pred3 = self.net_G(img_in1, img_in2,sobel) 171 | # self.G_pred = self.G_pred1 + self.G_pred2 + self.G_pred3 172 | #cd_net 173 | # self.G_pred0, self.G_pred1, self.G_pred2, self.G_pred3, self.G_pred4 = self.net_G(img_in1, img_in2) 174 | # self.G_pred = self.G_pred0 175 | 176 | #DMINet 177 | self.G_pred0, self.G_pred1, self.G_pred2, self.G_pred3 = self.net_G(img_in1, img_in2) 178 | self.G_pred = self.G_pred0 + self.G_pred1 179 | 180 | else : 181 | # self.G_pred = self.net_G(img_in1, img_in2,sobel) 182 | self.G_pred = self.net_G(img_in1, img_in2) 183 | 184 | def eval_models(self,args,checkpoint_name='best_ckpt.pt'): 185 | 186 | self._load_checkpoint(checkpoint_name) 187 | 188 | ################## Eval ################## 189 | ########################################## 190 | self.logger.write('Begin evaluation...\n') 191 | self._clear_cache()#不返回任何值,清空running_metric 192 | self.is_training = False 193 | self.net_G.eval() 194 | 195 | #Iterate over data.遍历数据(原始) 196 | for self.batch_id, batch in enumerate(self.dataloader, 0): 197 | #name = batch['name'] 198 | #print('process: %s' % name) 199 | with torch.no_grad(): 200 | self._forward_pass(batch,args) 201 | self._collect_running_batch_states(args) 202 | self._collect_epoch_states() 203 | 204 | 205 | -------------------------------------------------------------------------------- /compare/DASNet.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: CASIA IVA 3 | # Email: jliu@nlpr.ia.ac.cn 4 | # Copyright (c) 2018 5 | 
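# DASNet comparison baseline: a Siamese network whose DANet-style head applies position (PAM) and channel (CAM) attention on ResNet features; both temporal branches below share the same CNN.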
########################################################################### 6 | from __future__ import division 7 | import os 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn.functional import upsample, normalize 12 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 13 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 14 | # from attention import PAM_Module 15 | # from attention import CAM_Module 16 | from resbase import BaseNet 17 | import torch.nn.functional as F 18 | 19 | __all__ = ['DANet'] 20 | 21 | class PAM_Module(Module): 22 | """ Position attention module""" 23 | #Ref from SAGAN 24 | def __init__(self, in_dim): 25 | super(PAM_Module, self).__init__() 26 | self.chanel_in = in_dim 27 | 28 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 29 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 30 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 31 | self.gamma = Parameter(torch.zeros(1)) 32 | 33 | self.softmax = Softmax(dim=-1) 34 | 35 | def forward(self, x): 36 | """ 37 | inputs : 38 | x : input feature maps( B X C X H X W) 39 | returns : 40 | out : attention value + input feature 41 | attention: B X (HxW) X (HxW) 42 | """ 43 | m_batchsize, C, height, width = x.size() 44 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 45 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 46 | energy = torch.bmm(proj_query, proj_key) 47 | attention = self.softmax(energy) 48 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) 49 | 50 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 51 | out = out.view(m_batchsize, C, height, width) 52 | 53 | out = self.gamma*out + x 54 | return out 55 | 56 | 57 | class CAM_Module(Module): 58 | """ Channel attention module""" 59 | def __init__(self, in_dim): 60 | super(CAM_Module, self).__init__() 61 | self.chanel_in = in_dim 62 | 63 | 64 | self.gamma = Parameter(torch.zeros(1)) 65 | self.softmax = Softmax(dim=-1) 66 | def forward(self,x): 67 | """ 68 | inputs : 69 | x : input feature maps( B X C X H X W) 70 | returns : 71 | out : attention value + input feature 72 | attention: B X C X C 73 | """ 74 | m_batchsize, C, height, width = x.size() 75 | proj_query = x.view(m_batchsize, C, -1) 76 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 77 | energy = torch.bmm(proj_query, proj_key) 78 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 79 | attention = self.softmax(energy_new) 80 | proj_value = x.view(m_batchsize, C, -1) 81 | 82 | out = torch.bmm(attention, proj_value) 83 | out = out.view(m_batchsize, C, height, width) 84 | 85 | out = self.gamma*out + x 86 | return out 87 | 88 | class DANet(BaseNet): 89 | r"""Fully Convolutional Networks for Semantic Segmentation 90 | 91 | Parameters 92 | ---------- 93 | nclass : int 94 | Number of categories for the training dataset. 95 | backbone : string 96 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 97 | 'resnet101' or 'resnet152'). 98 | norm_layer : object 99 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 100 | 101 | 102 | Reference: 103 | 104 | Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks 105 | for semantic segmentation." 
*CVPR*, 2015 106 | 107 | """ 108 | 109 | def __init__(self, nclass, backbone, norm_layer=nn.BatchNorm2d, **kwargs): 110 | super(DANet, self).__init__(nclass, backbone, norm_layer=norm_layer, **kwargs) 111 | self.head = DANetHead(2048, nclass, norm_layer) 112 | 113 | def forward(self, x): 114 | 115 | _, _, c3, c4 = self.base_forward(x) 116 | 117 | x = self.head(c4) 118 | x = list(x) 119 | 120 | return x[0],x[1],x[2] 121 | 122 | 123 | class DANetHead(nn.Module): 124 | def __init__(self, in_channels, out_channels, norm_layer): 125 | super(DANetHead, self).__init__() 126 | inter_channels = in_channels // 4 127 | self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 128 | norm_layer(inter_channels), 129 | nn.ReLU()) 130 | 131 | self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 132 | norm_layer(inter_channels), 133 | nn.ReLU()) 134 | 135 | self.sa = PAM_Module(inter_channels) 136 | self.sc = CAM_Module(inter_channels) 137 | self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 138 | norm_layer(inter_channels), 139 | nn.ReLU()) 140 | self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 141 | norm_layer(inter_channels), 142 | nn.ReLU()) 143 | 144 | self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1)) 145 | self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1)) 146 | 147 | self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1)) 148 | 149 | def forward(self, x): 150 | feat1 = self.conv5a(x) 151 | sa_feat = self.sa(feat1) 152 | sa_conv = self.conv51(sa_feat) 153 | sa_output = self.conv6(sa_conv) 154 | feat2 = self.conv5c(x) 155 | sc_feat = self.sc(feat2) 156 | sc_conv = self.conv52(sc_feat) 157 | sc_output = self.conv7(sc_conv) 158 | 159 | feat_sum = sa_conv + sc_conv 160 | 161 | sasc_output = self.conv8(feat_sum) 162 | 163 | return sa_output,sc_output,sasc_output 164 | 165 | def cnn(): 166 | model = DANet(512, backbone='resnet50') 167 | return model 168 | 169 | class SiameseNet(nn.Module): 170 | def __init__(self,norm_flag = 'l2'): 171 | super(SiameseNet, self).__init__() 172 | self.CNN = cnn() 173 | if norm_flag == 'l2': 174 | self.norm = F.normalize 175 | if norm_flag == 'exp': 176 | self.norm = nn.Softmax2d() 177 | ''''''''' 178 | def forward(self,t0,t1): 179 | out_t0_embedding = self.CNN(t0) 180 | out_t1_embedding = self.CNN(t1) 181 | #out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5),self.norm(out_t1_conv5) 182 | #out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7),self.norm(out_t1_fc7) 183 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding),self.norm(out_t1_embedding) 184 | return [out_t0_embedding_norm,out_t1_embedding_norm] 185 | ''''''''' 186 | 187 | def forward(self,t0,t1): 188 | 189 | out_t0_conv5,out_t0_fc7,out_t0_embedding = self.CNN(t0) 190 | out_t1_conv5,out_t1_fc7,out_t1_embedding = self.CNN(t1) 191 | out_t0_conv5_norm,out_t1_conv5_norm = self.norm(out_t0_conv5,2,dim=1),self.norm(out_t1_conv5,2,dim=1) 192 | out_t0_fc7_norm,out_t1_fc7_norm = self.norm(out_t0_fc7,2,dim=1),self.norm(out_t1_fc7,2,dim=1) 193 | out_t0_embedding_norm,out_t1_embedding_norm = self.norm(out_t0_embedding,2,dim=1),self.norm(out_t1_embedding,2,dim=1) 194 | return [out_t0_conv5_norm,out_t1_conv5_norm],[out_t0_fc7_norm,out_t1_fc7_norm],[out_t0_embedding_norm,out_t1_embedding_norm] 195 | 196 | 197 | if __name__ == 
'__main__': 198 | m = SiameseNet() 199 | print('gg') -------------------------------------------------------------------------------- /compare/FC_Siam_conc.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_conc(nn.Module): 11 | """SiamUnet_conc segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_conc, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, 
padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | """Forward method.""" 97 | # Stage 1 98 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 99 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 100 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 101 | 102 | 103 | # Stage 2 104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 107 | 108 | # Stage 3 109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 113 | 114 | # Stage 4 115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 119 | 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | # Stage 2 128 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 129 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 130 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 131 | 132 | # Stage 3 133 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 134 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 135 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 136 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 137 | 138 | # Stage 4 139 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 140 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 141 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 142 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 143 | 144 | 145 | #################################################### 146 | # Stage 4d 147 | x4d = self.upconv4(x4p) 148 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 149 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 150 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 151 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 152 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 153 | 154 | # Stage 3d 155 | x3d = self.upconv3(x41d) 156 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 157 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 158 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 159 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 160 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 161 | 162 | # Stage 2d 163 | x2d = self.upconv2(x31d) 164 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) 
- x2d.size(2))) 165 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 166 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 167 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 168 | 169 | # Stage 1d 170 | x1d = self.upconv1(x21d) 171 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 172 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 173 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 174 | x11d = self.conv11d(x12d) 175 | 176 | #Softmax layer is embedded in the loss layer 177 | #out = self.sm(x11d) 178 | output = [] 179 | output.append(x11d) 180 | output = output[-1] 181 | 182 | return output 183 | -------------------------------------------------------------------------------- /compare/FC_Siam_diff.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_diff(nn.Module): 11 | """SiamUnet_diff segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_diff, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.01) #原始p=0.2 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.01) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.01) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.01) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.01) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.01) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.01) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.01) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.01) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.01) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.01) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.01) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.01) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, 
output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.01) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.01) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.01) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.01) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.01) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.01) 90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | LeakyReLU =nn.LeakyReLU(negative_slope=5e-2) 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) #原始为F.relu 100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 102 | 103 | 104 | # Stage 2 105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 108 | 109 | # Stage 3 110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 114 | 115 | # Stage 4 116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | 128 | # Stage 2 129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 132 | 133 | # Stage 3 134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 138 | 139 | # Stage 4 140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 144 | 145 | 146 | 147 | # Stage 4d 148 | x4d = self.upconv4(x4p) 149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 150 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) 151 | x43d = 
self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 154 | 155 | # Stage 3d 156 | x3d = self.upconv3(x41d) 157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 158 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) 159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 162 | 163 | # Stage 2d 164 | x2d = self.upconv2(x31d) 165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 166 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) 167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 169 | 170 | # Stage 1d 171 | x1d = self.upconv1(x21d) 172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 173 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) 174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 175 | x11d = self.conv11d(x12d) 176 | #out = self.sm(x11d) 177 | 178 | output = [] 179 | output.append(x11d) 180 | output = output[-1] 181 | 182 | return output -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | import torchvision.transforms as transforms 7 | 8 | class Normalize(object): 9 | """Normalize a tensor image with mean and standard deviation. 10 | Args: 11 | mean (tuple): means for each channel. 12 | std (tuple): standard deviations for each channel. 
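Note: the image is first scaled to [0, 1] (division by 255) and then standardized per channel with the given mean and std; the label mask is returned unchanged.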
13 | """ 14 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def __call__(self, sample): 19 | img = sample['image'] 20 | mask = sample['label'] 21 | img = np.array(img).astype(np.float32) 22 | mask = np.array(mask).astype(np.float32) 23 | img /= 255.0 24 | img -= self.mean 25 | img /= self.std 26 | 27 | return {'image': img, 28 | 'label': mask} 29 | 30 | 31 | class ToTensor(object): 32 | """Convert ndarrays in sample to Tensors.""" 33 | 34 | def __call__(self, sample): 35 | # swap color axis because 36 | # numpy image: H x W x C 37 | # torch image: C X H X W 38 | img1 = sample['image'][0] 39 | img2 = sample['image'][1] 40 | mask = sample['label'] 41 | img1 = np.array(img1).astype(np.float32).transpose((2, 0, 1)) 42 | img2 = np.array(img2).astype(np.float32).transpose((2, 0, 1)) 43 | mask = np.array(mask).astype(np.float32) / 255.0 44 | 45 | img1 = torch.from_numpy(img1).float() 46 | img2 = torch.from_numpy(img2).float() 47 | mask = torch.from_numpy(mask).float() 48 | 49 | return {'image': (img1, img2), 50 | 'label': mask} 51 | 52 | 53 | class RandomHorizontalFlip(object): 54 | def __call__(self, sample): 55 | img1 = sample['image'][0] 56 | img2 = sample['image'][1] 57 | mask = sample['label'] 58 | if random.random() < 0.5: 59 | img1 = img1.transpose(Image.FLIP_LEFT_RIGHT) 60 | img2 = img2.transpose(Image.FLIP_LEFT_RIGHT) 61 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 62 | 63 | return {'image': (img1, img2), 64 | 'label': mask} 65 | 66 | class RandomVerticalFlip(object): 67 | def __call__(self, sample): 68 | img1 = sample['image'][0] 69 | img2 = sample['image'][1] 70 | mask = sample['label'] 71 | if random.random() < 0.5: 72 | img1 = img1.transpose(Image.FLIP_TOP_BOTTOM) 73 | img2 = img2.transpose(Image.FLIP_TOP_BOTTOM) 74 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) 75 | 76 | return {'image': (img1, img2), 77 | 'label': mask} 78 | 79 | class RandomFixRotate(object): 80 | def __init__(self): 81 | self.degree = [Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270] 82 | 83 | def __call__(self, sample): 84 | img1 = sample['image'][0] 85 | img2 = sample['image'][1] 86 | mask = sample['label'] 87 | if random.random() < 0.75: 88 | rotate_degree = random.choice(self.degree) 89 | img1 = img1.transpose(rotate_degree) 90 | img2 = img2.transpose(rotate_degree) 91 | mask = mask.transpose(rotate_degree) 92 | 93 | return {'image': (img1, img2), 94 | 'label': mask} 95 | 96 | 97 | class RandomRotate(object): 98 | def __init__(self, degree): 99 | self.degree = degree 100 | 101 | def __call__(self, sample): 102 | img1 = sample['image'][0] 103 | img2 = sample['image'][1] 104 | mask = sample['label'] 105 | rotate_degree = random.uniform(-1*self.degree, self.degree) 106 | img1 = img1.rotate(rotate_degree, Image.BILINEAR) 107 | img2 = img2.rotate(rotate_degree, Image.BILINEAR) 108 | mask = mask.rotate(rotate_degree, Image.NEAREST) 109 | 110 | return {'image': (img1, img2), 111 | 'label': mask} 112 | 113 | 114 | class RandomGaussianBlur(object): 115 | def __call__(self, sample): 116 | img1 = sample['image'][0] 117 | img2 = sample['image'][1] 118 | mask = sample['label'] 119 | if random.random() < 0.5: 120 | img1 = img1.filter(ImageFilter.GaussianBlur( 121 | radius=random.random())) 122 | img2 = img2.filter(ImageFilter.GaussianBlur( 123 | radius=random.random())) 124 | 125 | return {'image': (img1, img2), 126 | 'label': mask} 127 | 128 | 129 | class RandomScaleCrop(object): 130 | def __init__(self, base_size, crop_size, fill=0): 131 | 
self.base_size = base_size 132 | self.crop_size = crop_size 133 | self.fill = fill 134 | 135 | def __call__(self, sample): 136 | img = sample['image'] 137 | mask = sample['label'] 138 | # random scale (short edge) 139 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 140 | w, h = img.size 141 | if h > w: 142 | ow = short_size 143 | oh = int(1.0 * h * ow / w) 144 | else: 145 | oh = short_size 146 | ow = int(1.0 * w * oh / h) 147 | img = img.resize((ow, oh), Image.BILINEAR) 148 | mask = mask.resize((ow, oh), Image.NEAREST) 149 | # pad crop 150 | if short_size < self.crop_size: 151 | padh = self.crop_size - oh if oh < self.crop_size else 0 152 | padw = self.crop_size - ow if ow < self.crop_size else 0 153 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 154 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 155 | # random crop crop_size 156 | w, h = img.size 157 | x1 = random.randint(0, w - self.crop_size) 158 | y1 = random.randint(0, h - self.crop_size) 159 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 160 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 161 | 162 | return {'image': img, 163 | 'label': mask} 164 | 165 | 166 | class FixScaleCrop(object): 167 | def __init__(self, crop_size): 168 | self.crop_size = crop_size 169 | 170 | def __call__(self, sample): 171 | img = sample['image'] 172 | mask = sample['label'] 173 | w, h = img.size 174 | if w > h: 175 | oh = self.crop_size 176 | ow = int(1.0 * w * oh / h) 177 | else: 178 | ow = self.crop_size 179 | oh = int(1.0 * h * ow / w) 180 | img = img.resize((ow, oh), Image.BILINEAR) 181 | mask = mask.resize((ow, oh), Image.NEAREST) 182 | # center crop 183 | w, h = img.size 184 | x1 = int(round((w - self.crop_size) / 2.)) 185 | y1 = int(round((h - self.crop_size) / 2.)) 186 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 187 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 188 | 189 | return {'image': img, 190 | 'label': mask} 191 | 192 | class FixedResize(object): 193 | def __init__(self, size): 194 | self.size = (size, size) # size: (h, w) 195 | 196 | def __call__(self, sample): 197 | img1 = sample['image'][0] 198 | img2 = sample['image'][1] 199 | mask = sample['label'] 200 | 201 | assert img1.size == mask.size and img2.size == mask.size 202 | 203 | img1 = img1.resize(self.size, Image.BILINEAR) 204 | img2 = img2.resize(self.size, Image.BILINEAR) 205 | mask = mask.resize(self.size, Image.NEAREST) 206 | 207 | return {'image': (img1, img2), 208 | 'label': mask} 209 | 210 | 211 | ''' 212 | We don't use Normalize here, because it will bring negative effects. 213 | the mask of ground truth is converted to [0,1] in ToTensor() function. 
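The train pipeline below therefore only applies geometric augmentations (random horizontal/vertical flips and fixed 90/180/270-degree rotations) jointly to both temporal images and the mask before ToTensor().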
214 | ''' 215 | train_transforms = transforms.Compose([ 216 | RandomHorizontalFlip(), 217 | RandomVerticalFlip(), 218 | RandomFixRotate(), 219 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 220 | # RandomGaussianBlur(), 221 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 222 | ToTensor()]) 223 | 224 | test_transforms = transforms.Compose([ 225 | # RandomHorizontalFlip(), 226 | # RandomVerticalFlip(), 227 | # RandomFixRotate(), 228 | # RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 229 | # RandomGaussianBlur(), 230 | # Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 231 | ToTensor()]) -------------------------------------------------------------------------------- /compare/TFI_GR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .resnet_tfi import resnet18 5 | 6 | 7 | class TemporalFeatureInteractionModule(nn.Module): 8 | def __init__(self, in_d, out_d): 9 | super(TemporalFeatureInteractionModule, self).__init__() 10 | self.in_d = in_d 11 | self.out_d = out_d 12 | self.conv_sub = nn.Sequential( 13 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1), 14 | nn.BatchNorm2d(self.in_d), 15 | nn.ReLU(inplace=True) 16 | ) 17 | self.conv_diff_enh1 = nn.Sequential( 18 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1), 19 | nn.BatchNorm2d(self.in_d), 20 | nn.ReLU(inplace=True) 21 | ) 22 | self.conv_diff_enh2 = nn.Sequential( 23 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1), 24 | nn.BatchNorm2d(self.in_d), 25 | nn.ReLU(inplace=True) 26 | ) 27 | self.conv_cat = nn.Sequential( 28 | nn.Conv2d(self.in_d * 2, self.in_d, kernel_size=3, stride=1, padding=1), 29 | nn.BatchNorm2d(self.in_d), 30 | nn.ReLU(inplace=True) 31 | ) 32 | self.conv_dr = nn.Sequential( 33 | nn.Conv2d(self.in_d, self.out_d, kernel_size=1, bias=True), 34 | nn.BatchNorm2d(self.out_d), 35 | nn.ReLU(inplace=True) 36 | ) 37 | 38 | def forward(self, x1, x2): 39 | # difference enhance 40 | x_sub = self.conv_sub(torch.abs(x1 - x2)) 41 | x1 = self.conv_diff_enh1(x1.mul(x_sub) + x1) 42 | x2 = self.conv_diff_enh2(x2.mul(x_sub) + x2) 43 | # fusion 44 | x_f = torch.cat([x1, x2], dim=1) 45 | x_f = self.conv_cat(x_f) 46 | x = x_sub + x_f 47 | x = self.conv_dr(x) 48 | return x 49 | 50 | 51 | class ChannelAttention(nn.Module): 52 | def __init__(self, in_planes, ratio=16): 53 | super(ChannelAttention, self).__init__() 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.max_pool = nn.AdaptiveMaxPool2d(1) 56 | 57 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) 58 | self.relu1 = nn.ReLU() 59 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 60 | self.sigmoid = nn.Sigmoid() 61 | 62 | def forward(self, x): 63 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 64 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 65 | out = avg_out + max_out 66 | return self.sigmoid(out) 67 | 68 | 69 | class ChangeInformationExtractionModule(nn.Module): 70 | def __init__(self, in_d, out_d): 71 | super(ChangeInformationExtractionModule, self).__init__() 72 | self.in_d = in_d 73 | self.out_d = out_d 74 | self.ca = ChannelAttention(self.in_d * 4, ratio=16) 75 | self.conv_dr = nn.Sequential( 76 | nn.Conv2d(self.in_d * 4, self.in_d, kernel_size=3, stride=1, padding=1, bias=False), 77 | nn.BatchNorm2d(self.in_d), 78 | 
nn.ReLU(inplace=True) 79 | ) 80 | self.pools_sizes = [2, 4, 8] 81 | self.conv_pool1 = nn.Sequential( 82 | nn.AvgPool2d(kernel_size=self.pools_sizes[0], stride=self.pools_sizes[0]), 83 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False) 84 | ) 85 | self.conv_pool2 = nn.Sequential( 86 | nn.AvgPool2d(kernel_size=self.pools_sizes[1], stride=self.pools_sizes[1]), 87 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False) 88 | ) 89 | self.conv_pool3 = nn.Sequential( 90 | nn.AvgPool2d(kernel_size=self.pools_sizes[2], stride=self.pools_sizes[2]), 91 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False) 92 | ) 93 | 94 | def forward(self, d5, d4, d3, d2): 95 | # upsampling 96 | d5 = F.interpolate(d5, d2.size()[2:], mode='bilinear', align_corners=True) 97 | d4 = F.interpolate(d4, d2.size()[2:], mode='bilinear', align_corners=True) 98 | d3 = F.interpolate(d3, d2.size()[2:], mode='bilinear', align_corners=True) 99 | # fusion 100 | x = torch.cat([d5, d4, d3, d2], dim=1) 101 | x_ca = self.ca(x) 102 | x = x * x_ca 103 | x = self.conv_dr(x) 104 | 105 | # feature = x[0:1, 0:64, 0:64, 0:64] 106 | # vis.visulize_features(feature) 107 | 108 | # pooling 109 | d2 = x 110 | d3 = self.conv_pool1(x) 111 | d4 = self.conv_pool2(x) 112 | d5 = self.conv_pool3(x) 113 | 114 | return d5, d4, d3, d2 115 | 116 | 117 | class GuidedRefinementModule(nn.Module): 118 | def __init__(self, in_d, out_d): 119 | super(GuidedRefinementModule, self).__init__() 120 | self.in_d = in_d 121 | self.out_d = out_d 122 | self.conv_d5 = nn.Sequential( 123 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1), 124 | nn.BatchNorm2d(self.out_d), 125 | nn.ReLU(inplace=True) 126 | ) 127 | self.conv_d4 = nn.Sequential( 128 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1), 129 | nn.BatchNorm2d(self.out_d), 130 | nn.ReLU(inplace=True) 131 | ) 132 | self.conv_d3 = nn.Sequential( 133 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1), 134 | nn.BatchNorm2d(self.out_d), 135 | nn.ReLU(inplace=True) 136 | ) 137 | self.conv_d2 = nn.Sequential( 138 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1), 139 | nn.BatchNorm2d(self.out_d), 140 | nn.ReLU(inplace=True) 141 | ) 142 | 143 | def forward(self, d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p): 144 | # feature refinement 145 | d5 = self.conv_d5(d5_p + d5) 146 | d4 = self.conv_d4(d4_p + d4) 147 | d3 = self.conv_d3(d3_p + d3) 148 | d2 = self.conv_d2(d2_p + d2) 149 | 150 | return d5, d4, d3, d2 151 | 152 | 153 | class Decoder(nn.Module): 154 | def __init__(self, in_d, out_d): 155 | super(Decoder, self).__init__() 156 | self.in_d = in_d 157 | self.out_d = out_d 158 | self.conv_sum1 = nn.Sequential( 159 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1), 160 | nn.BatchNorm2d(self.in_d), 161 | nn.ReLU(inplace=True) 162 | ) 163 | self.conv_sum2 = nn.Sequential( 164 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1), 165 | nn.BatchNorm2d(self.in_d), 166 | nn.ReLU(inplace=True) 167 | ) 168 | self.conv_sum3 = nn.Sequential( 169 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1), 170 | nn.BatchNorm2d(self.in_d), 171 | nn.ReLU(inplace=True) 172 | ) 173 | self.cls = nn.Conv2d(self.in_d, self.out_d, kernel_size=1, bias=False) 174 | 175 | def forward(self, d5, d4, d3, d2): 176 | 177 | d5 = F.interpolate(d5, d4.size()[2:], mode='bilinear', align_corners=True) 178 | d4 = self.conv_sum1(d4 + d5) 179 | d4 = 
F.interpolate(d4, d3.size()[2:], mode='bilinear', align_corners=True) 180 | d3 = self.conv_sum1(d3 + d4) 181 | d3 = F.interpolate(d3, d2.size()[2:], mode='bilinear', align_corners=True) 182 | d2 = self.conv_sum1(d2 + d3) 183 | 184 | mask = self.cls(d2) 185 | 186 | return mask 187 | 188 | 189 | class TFI_GR(nn.Module): 190 | def __init__(self, input_nc, output_nc): 191 | super(TFI_GR, self).__init__() 192 | self.backbone = resnet18(pretrained=True) 193 | self.mid_d = 64 194 | self.TFIM5 = TemporalFeatureInteractionModule(512, self.mid_d) 195 | self.TFIM4 = TemporalFeatureInteractionModule(256, self.mid_d) 196 | self.TFIM3 = TemporalFeatureInteractionModule(128, self.mid_d) 197 | self.TFIM2 = TemporalFeatureInteractionModule(64, self.mid_d) 198 | 199 | self.CIEM1 = ChangeInformationExtractionModule(self.mid_d, output_nc) 200 | self.GRM1 = GuidedRefinementModule(self.mid_d, self.mid_d) 201 | 202 | self.CIEM2 = ChangeInformationExtractionModule(self.mid_d, output_nc) 203 | self.GRM2 = GuidedRefinementModule(self.mid_d, self.mid_d) 204 | 205 | self.decoder = Decoder(self.mid_d, output_nc) 206 | 207 | def forward(self, x1, x2): 208 | # forward backbone resnet 209 | x1_1, x1_2, x1_3, x1_4, x1_5 = self.backbone.base_forward(x1) 210 | x2_1, x2_2, x2_3, x2_4, x2_5 = self.backbone.base_forward(x2) 211 | # feature difference 212 | d5 = self.TFIM5(x1_5, x2_5) # 1/32 213 | d4 = self.TFIM4(x1_4, x2_4) # 1/16 214 | d3 = self.TFIM3(x1_3, x2_3) # 1/8 215 | d2 = self.TFIM2(x1_2, x2_2) # 1/4 216 | 217 | # change information guided refinement 1 218 | d5_p, d4_p, d3_p, d2_p = self.CIEM1(d5, d4, d3, d2) 219 | d5, d4, d3, d2 = self.GRM1(d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p) 220 | 221 | # change information guided refinement 2 222 | d5_p, d4_p, d3_p, d2_p = self.CIEM2(d5, d4, d3, d2) 223 | d5, d4, d3, d2 = self.GRM2(d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p) 224 | 225 | # decoder 226 | mask = self.decoder(d5, d4, d3, d2) 227 | mask = F.interpolate(mask, x1.size()[2:], mode='bilinear', align_corners=True) 228 | # mask = torch.sigmoid(mask) 229 | 230 | return mask 231 | -------------------------------------------------------------------------------- /compare/resnet_tfi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from torchvision.models.utils import load_state_dict_from_url 4 | from torch.hub import load_state_dict_from_url 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=dilation, groups=groups, 
bias=False, dilation=dilation) 25 | 26 | 27 | def conv1x1(in_planes, out_planes, stride=1): 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 35 | base_width=64, dilation=1, norm_layer=None): 36 | super(BasicBlock, self).__init__() 37 | if norm_layer is None: 38 | norm_layer = nn.BatchNorm2d 39 | if groups != 1 or base_width != 64: 40 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 41 | 42 | self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation) 43 | self.bn1 = norm_layer(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = norm_layer(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | if self.downsample is not None: 61 | identity = self.downsample(x) 62 | 63 | out += identity 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | 69 | class Bottleneck(nn.Module): 70 | expansion = 4 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 73 | base_width=64, dilation=1, norm_layer=None): 74 | super(Bottleneck, self).__init__() 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | width = int(planes * (base_width / 64.)) * groups 78 | 79 | self.conv1 = conv1x1(inplanes, width) 80 | self.bn1 = norm_layer(width) 81 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 82 | self.bn2 = norm_layer(width) 83 | self.conv3 = conv1x1(width, planes * self.expansion) 84 | self.bn3 = norm_layer(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | identity = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, 115 | width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): 116 | super(ResNet, self).__init__() 117 | 118 | self.channels = [64, 64 * block.expansion, 128 * block.expansion, 119 | 256 * block.expansion, 512 * block.expansion] 120 | 121 | if norm_layer is None: 122 | norm_layer = nn.BatchNorm2d 123 | self._norm_layer = norm_layer 124 | 125 | self.inplanes = 64 126 | self.dilation = 1 127 | if replace_stride_with_dilation is None: 128 | replace_stride_with_dilation = [False, False, False] 129 | if len(replace_stride_with_dilation) != 3: 130 | raise ValueError('replace_stride_with_dilation should be None ' 131 | 'or a 3-element tuple, got {}'.format(replace_stride_with_dilation)) 132 | self.groups = groups 133 | self.base_width = width_per_group 134 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 135 | bias=False) 136 | self.bn1 = norm_layer(self.inplanes) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, 
padding=1) 139 | self.layer1 = self._make_layer(block, 64, layers[0]) 140 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 141 | dilate=replace_stride_with_dilation[0]) 142 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 143 | dilate=replace_stride_with_dilation[1]) 144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 145 | dilate=replace_stride_with_dilation[2]) 146 | 147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 148 | self.fc = nn.Linear(512 * block.expansion, num_classes) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | if zero_init_residual: 158 | for m in self.modules(): 159 | if isinstance(m, Bottleneck): 160 | nn.init.constant_(m.bn3.weight, 0) 161 | elif isinstance(m, BasicBlock): 162 | nn.init.constant_(m.bn2.weight, 0) 163 | 164 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 165 | norm_layer = self._norm_layer 166 | downsample = None 167 | previous_dilation = self.dilation 168 | if dilate: 169 | self.dilation *= stride 170 | stride = 1 171 | if stride != 1 or self.inplanes != planes * block.expansion: 172 | downsample = nn.Sequential( 173 | conv1x1(self.inplanes, planes * block.expansion, stride), 174 | norm_layer(planes * block.expansion), 175 | ) 176 | 177 | layers = list() 178 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 179 | self.base_width, previous_dilation, norm_layer)) 180 | self.inplanes = planes * block.expansion 181 | for _ in range(1, blocks): 182 | layers.append(block(self.inplanes, planes, groups=self.groups, 183 | base_width=self.base_width, dilation=self.dilation, 184 | norm_layer=norm_layer)) 185 | 186 | return nn.Sequential(*layers) 187 | 188 | def base_forward(self, x): 189 | x = self.conv1(x) 190 | x = self.bn1(x) 191 | c0 = self.relu(x) 192 | c1 = self.maxpool(c0) 193 | 194 | c1 = self.layer1(c1) 195 | c2 = self.layer2(c1) 196 | c3 = self.layer3(c2) 197 | c4 = self.layer4(c3) 198 | 199 | return c0, c1, c2, c3, c4 200 | 201 | def forward(self, x): 202 | x = self.base_forward(x)[-1] 203 | x = self.avgpool(x) 204 | x = torch.flatten(x, 1) 205 | x = self.fc(x) 206 | 207 | return x 208 | 209 | 210 | def _resnet(arch, block, layers, pretrained, **kwargs): 211 | model = ResNet(block, layers, **kwargs) 212 | if pretrained: 213 | state_dict = load_state_dict_from_url(model_urls[arch], 214 | progress=True) 215 | model.load_state_dict(state_dict) 216 | return model 217 | 218 | 219 | def resnet18(pretrained=False, **kwargs): 220 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, 221 | replace_stride_with_dilation=[False, False, False], **kwargs) 222 | 223 | 224 | def resnet34(pretrained=False, **kwargs): 225 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, 226 | replace_stride_with_dilation=[False, True, True], **kwargs) 227 | 228 | 229 | def resnet50(pretrained=False, **kwargs): 230 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, 231 | replace_stride_with_dilation=[False, True, True], **kwargs) 232 | 233 | 234 | def resnet101(pretrained=False, **kwargs): 235 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, 236 | replace_stride_with_dilation=[False, True, True], **kwargs) 237 | 238 | 239 | def resnet152(pretrained=False, **kwargs): 240 | return 
_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, 241 | replace_stride_with_dilation=[False, True, True], **kwargs) 242 | 243 | 244 | def resnext50_32x4d(pretrained=False, **kwargs): 245 | kwargs['groups'] = 32 246 | kwargs['width_per_group'] = 4 247 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, 248 | replace_stride_with_dilation=[False, True, True], **kwargs) 249 | 250 | 251 | def resnext101_32x8d(pretrained=False, **kwargs): 252 | kwargs['groups'] = 32 253 | kwargs['width_per_group'] = 8 254 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, 255 | replace_stride_with_dilation=[False, True, True], **kwargs) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import datetime 3 | import torch 4 | from sklearn.metrics import precision_recall_fscore_support as prfs 5 | from utils.parser import get_parser_with_args 6 | from utils.helpers import (get_loaders, get_criterion, 7 | load_model, initialize_metrics, get_mean_metrics, 8 | set_metrics) 9 | import os 10 | import logging 11 | import json 12 | from tensorboardX import SummaryWriter 13 | from tqdm import tqdm 14 | import random 15 | import numpy as np 16 | from torch.optim import lr_scheduler 17 | 18 | def get_scheduler(optimizer, opt, lr_policy): 19 | """Return a learning rate scheduler 20 | Parameters: 21 | optimizer -- the optimizer of the network 22 | args (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  23 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 24 | For 'linear', we keep the same learning rate for the first epochs 25 | and linearly decay the rate to zero over the next epochs. 26 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 27 | See https://pytorch.org/docs/stable/optim.html for more details. 28 | """ 29 | if lr_policy == 'linear': 30 | def lambda_rule(epoch): 31 | lr_l = 1.0 - epoch / float(opt.epochs + 1) 32 | return lr_l 33 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 34 | elif lr_policy == 'step': 35 | step_size = opt.epochs//3 36 | # args.lr_decay_iters 37 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1) 38 | else: 39 | return NotImplementedError('learning rate policy [%s] is not implemented', lr_policy) 40 | return scheduler 41 | 42 | 43 | def seed_torch(seed): 44 | random.seed(seed) 45 | os.environ['PYTHONHASHSEED'] = str(seed) 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed(seed) 49 | # torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
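# Disabling cudnn benchmarking and enforcing deterministic kernels below trades some speed for run-to-run reproducibility of the seeded experiments.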
50 | torch.backends.cudnn.benchmark = False 51 | torch.backends.cudnn.deterministic = True 52 | 53 | 54 | if __name__ == '__main__': 55 | """ 56 | Initialize Parser and 57 | define arguments 58 | """ 59 | parser, metadata = get_parser_with_args() 60 | opt = parser.parse_args() 61 | 62 | opt.epochs = 100 63 | opt.batch_size = 16 64 | opt.loss_function = "bce" 65 | 66 | #checkpoints dir 67 | opt.checkpoint_dir= os.path.join(opt.path,opt.project_name) 68 | os.makedirs(opt.checkpoint_dir,exist_ok=True) 69 | save_path = '.tmp' + '/' + opt.dataset + '_' + opt.backbone + '_' + opt.mode  # where the best checkpoints are written below 70 | 71 | """ 72 | Initialize experiments log 73 | """ 74 | logging.basicConfig(level=logging.INFO) 75 | writer = SummaryWriter( 76 | '.tmp/log/' + opt.dataset + '_' + opt.backbone + '_' + opt.mode + '_' + f'{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}') 77 | 78 | """ 79 | Set up environment: define paths, download data, and set device 80 | """ 81 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 82 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 83 | logging.info('GPU AVAILABLE? ' + str(torch.cuda.is_available())) 84 | 85 | seed_torch(seed=777) 86 | 87 | if opt.dataset == 'cdd': 88 | opt.dataset_dir = '../CDDataset/ChangeDetection/' 89 | elif opt.dataset == 'levir': 90 | opt.dataset_dir = '../CDDataset/LEVIR/' 91 | elif opt.dataset == 'levir+': 92 | opt.dataset_dir = '../CDDataset/LEVIR-CD+_256' 93 | 94 | train_loader, val_loader = get_loaders(opt) 95 | 96 | """ 97 | Load Model then define other aspects of the model 98 | """ 99 | logging.info('LOADING Model') 100 | model = load_model(opt, dev) 101 | 102 | criterion = get_criterion(opt) 103 | if opt.backbone == 'resnet': 104 | optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.99, 105 | weight_decay=0.0005) # Be careful when you adjust learning rate, you can refer to the linear scaling rule 106 | # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5) 107 | scheduler = get_scheduler(optimizer, opt, 'linear') 108 | else: 109 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01) 110 | scheduler = get_scheduler(optimizer, opt, 'linear') 111 | 112 | """ 113 | Set starting values 114 | """ 115 | best_metrics = {'cd_f1scores': -1, 'cd_recalls': -1, 'cd_precisions': -1} 116 | logging.info('STARTING training') 117 | total_step = -1 118 | 119 | for epoch in range(opt.epochs): 120 | train_metrics = initialize_metrics() 121 | val_metrics = initialize_metrics() 122 | 123 | """ 124 | Begin Training 125 | """ 126 | model.train() 127 | logging.info('SET model mode to train!') 128 | batch_iter = 0 129 | tbar = tqdm(train_loader) 130 | for batch_img1, batch_img2, labels, fname in tbar: 131 | tbar.set_description( 132 | "epoch {} info ".format(epoch) + str(batch_iter) + " - " + str(batch_iter + opt.batch_size)) 133 | batch_iter = batch_iter + opt.batch_size 134 | total_step += 1 135 | # Set variables for training 136 | batch_img1 = batch_img1.float().to(dev) 137 | batch_img2 = batch_img2.float().to(dev) 138 | labels = labels.long().to(dev) 139 | 140 | # Zero the gradient 141 | optimizer.zero_grad() 142 | 143 | # Get model predictions, calculate loss, backprop 144 | cd_preds = model(batch_img1, batch_img2) 145 | 146 | cd_loss = criterion(cd_preds, labels) 147 | loss = cd_loss 148 | loss.backward() 149 | optimizer.step() 150 | 151 | # cd_preds = cd_preds[-1] # the BIT output is not a tuple 152 | _, cd_preds = torch.max(cd_preds, 1) 153 | 154 | # Calculate and log other batch metrics
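# cd_corrects below is the percentage of correctly classified pixels in the batch, assuming each label map has opt.patch_size x opt.patch_size pixels.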
155 |             cd_corrects = (100 *
156 |                            (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
157 |                            (labels.size()[0] * (opt.patch_size ** 2)))
158 | 
159 |             cd_train_report = prfs(labels.data.cpu().numpy().flatten(),
160 |                                    cd_preds.data.cpu().numpy().flatten(),
161 |                                    average='binary',
162 |                                    pos_label=1, zero_division=0)
163 | 
164 |             train_metrics = set_metrics(train_metrics,
165 |                                         cd_loss,
166 |                                         cd_corrects,
167 |                                         cd_train_report,
168 |                                         scheduler.get_last_lr())
169 | 
170 |             # log the batch mean metrics
171 |             mean_train_metrics = get_mean_metrics(train_metrics)
172 | 
173 |             for k, v in mean_train_metrics.items():
174 |                 writer.add_scalars(str(k), {'train': v}, total_step)
175 | 
176 |             # clear batch variables from memory
177 |             del batch_img1, batch_img2, labels
178 | 
179 |         scheduler.step()
180 |         logging.info("EPOCH {} TRAIN METRICS".format(epoch) + str(mean_train_metrics))
181 | 
182 |         """
183 |         Begin Validation
184 |         """
185 |         model.eval()
186 |         with torch.no_grad():
187 |             for batch_img1, batch_img2, labels, fname in val_loader:
188 |                 # Set variables for validation
189 |                 batch_img1 = batch_img1.float().to(dev)
190 |                 batch_img2 = batch_img2.float().to(dev)
191 |                 labels = labels.long().to(dev)
192 | 
193 |                 # Get predictions and calculate loss
194 |                 cd_preds = model(batch_img1, batch_img2)
195 | 
196 |                 cd_loss = criterion(cd_preds, labels)
197 | 
198 |                 # cd_preds = cd_preds[-1]  # the BIT output is not a tuple
199 |                 _, cd_preds = torch.max(cd_preds, 1)
200 | 
201 |                 # Calculate and log other batch metrics
202 |                 cd_corrects = (100 *
203 |                                (cd_preds.squeeze().byte() == labels.squeeze().byte()).sum() /
204 |                                (labels.size()[0] * (opt.patch_size ** 2)))
205 | 
206 |                 cd_val_report = prfs(labels.data.cpu().numpy().flatten(),
207 |                                      cd_preds.data.cpu().numpy().flatten(),
208 |                                      average='binary',
209 |                                      pos_label=1, zero_division=0)
210 | 
211 |                 val_metrics = set_metrics(val_metrics,
212 |                                           cd_loss,
213 |                                           cd_corrects,
214 |                                           cd_val_report,
215 |                                           scheduler.get_last_lr())
216 | 
217 |                 # log the batch mean metrics
218 |                 mean_val_metrics = get_mean_metrics(val_metrics)
219 | 
220 |                 for k, v in mean_val_metrics.items():
221 |                     writer.add_scalars(str(k), {'val': v}, total_step)
222 | 
223 |                 # clear batch variables from memory
224 |                 del batch_img1, batch_img2, labels
225 | 
226 |         logging.info("EPOCH {} VALIDATION METRICS".format(epoch) + str(mean_val_metrics))
227 | 
228 |         """
229 |         Store the weights of good epochs based on validation results
230 |         """
231 |         if ((mean_val_metrics['cd_precisions'] > best_metrics['cd_precisions'])
232 |                 or
233 |                 (mean_val_metrics['cd_recalls'] > best_metrics['cd_recalls'])
234 |                 or
235 |                 (mean_val_metrics['cd_f1scores'] > best_metrics['cd_f1scores'])):
236 | 
237 |             # Insert training and epoch information to metadata dictionary
238 |             logging.info('update the model')
239 |             metadata['validation_metrics'] = mean_val_metrics
240 | 
241 |             # Save model and log
242 | 
243 |             if not os.path.exists(opt.checkpoint_dir):
244 |                 os.makedirs(opt.checkpoint_dir)
245 |             # with open('./tmp/metadata_epoch_' + str(epoch) + '.json', 'w') as fout:
246 |             #     json.dump(metadata, fout)
247 | 
248 |             torch.save(model, opt.checkpoint_dir + '/checkpoint_epoch_' + str(epoch) + '.pth')
249 | 
250 |             # comet.log_asset(upload_metadata_file_path)
251 |             best_metrics = mean_val_metrics
252 | 
253 |         print('An epoch finished.')
254 |     writer.close()  # close the tensorboard writer
255 |     print('Done!')
256 | 
257 | 
258 | 
259 | 
--------------------------------------------------------------------------------
/compare/DMINet.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .resnet import resnet18 4 | import torch.nn.functional as F 5 | import math 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def init_weights(m): 12 | """ 13 | Initialize weights of layers using Kaiming Normal (He et al.) as argument of "Apply" function of 14 | "nn.Module" 15 | :param m: Layer to initialize 16 | :return: None 17 | """ 18 | if isinstance(m, nn.Conv2d): 19 | ''' 20 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 21 | trunc_normal_(m.weight, std=math.sqrt(1.0/fan_in)/.87962566103423978) 22 | if m.bias is not None: 23 | nn.init.zeros_(m.bias) 24 | ''' 25 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 26 | if m.bias is not None: 27 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 28 | bound = 1 / math.sqrt(fan_in) 29 | nn.init.uniform_(m.bias, -bound, bound) 30 | 31 | elif isinstance(m, nn.BatchNorm2d): 32 | nn.init.constant_(m.weight, 1) 33 | nn.init.constant_(m.bias, 0) 34 | 35 | 36 | class Conv(nn.Module): 37 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, bias=True): 38 | super(Conv, self).__init__() 39 | self.inp_dim = inp_dim 40 | self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=bias) 41 | self.relu = None 42 | self.bn = None 43 | if relu: 44 | self.relu = nn.ReLU(inplace=True) 45 | if bn: 46 | self.bn = nn.BatchNorm2d(out_dim) 47 | 48 | def forward(self, x): 49 | assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim) 50 | # print("++",x.size()[1],self.inp_dim,x.size()[1],self.inp_dim) 51 | x = self.conv(x) 52 | if self.bn is not None: 53 | x = self.bn(x) 54 | if self.relu is not None: 55 | x = self.relu(x) 56 | return x 57 | 58 | 59 | class decode(nn.Module): 60 | def __init__(self, in_channel_left, in_channel_down, out_channel, norm_layer=nn.BatchNorm2d): 61 | super(decode, self).__init__() 62 | self.conv_d1 = nn.Conv2d(in_channel_down, out_channel, kernel_size=3, stride=1, padding=1) 63 | self.conv_l = nn.Conv2d(in_channel_left, out_channel, kernel_size=3, stride=1, padding=1) 64 | self.conv3 = nn.Conv2d(out_channel * 2, out_channel, kernel_size=3, stride=1, padding=1) 65 | self.bn3 = norm_layer(out_channel) 66 | 67 | def forward(self, left, down): 68 | down_mask = self.conv_d1(down) 69 | left_mask = self.conv_l(left) 70 | if down.size()[2:] != left.size()[2:]: 71 | down_ = F.interpolate(down, size=left.size()[2:], mode='bilinear') 72 | z1 = F.relu(left_mask * down_, inplace=True) 73 | else: 74 | z1 = F.relu(left_mask * down, inplace=True) 75 | 76 | if down_mask.size()[2:] != left.size()[2:]: 77 | down_mask = F.interpolate(down_mask, size=left.size()[2:], mode='bilinear') 78 | 79 | z2 = F.relu(down_mask * left, inplace=True) 80 | 81 | out = torch.cat((z1, z2), dim=1) 82 | return F.relu(self.bn3(self.conv3(out)), inplace=True) 83 | 84 | 85 | class BasicConv2d(nn.Module): 86 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 87 | super(BasicConv2d, self).__init__() 88 | 89 | self.conv = nn.Conv2d(in_planes, out_planes, 90 | kernel_size=kernel_size, stride=stride, 91 | padding=padding, dilation=dilation, bias=False) 92 | self.bn = nn.BatchNorm2d(out_planes) 93 | self.relu = nn.ReLU(inplace=True) 94 | 95 | def forward(self, x): 96 | x = self.conv(x) 97 | x = self.bn(x) 98 | return x 99 | 100 | 101 | class 
CrossAtt(nn.Module): 102 | def __init__(self, in_channels, out_channels): 103 | super().__init__() 104 | self.in_channels = in_channels 105 | 106 | self.query1 = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1, stride=1) 107 | self.key1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, stride=1) 108 | self.value1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1) 109 | 110 | self.query2 = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1, stride=1) 111 | self.key2 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, stride=1) 112 | self.value2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1) 113 | 114 | self.gamma = nn.Parameter(torch.zeros(1)) 115 | self.softmax = nn.Softmax(dim=-1) 116 | 117 | self.conv_cat = nn.Sequential(nn.Conv2d(in_channels * 2, out_channels, 3, padding=1, bias=False), 118 | nn.BatchNorm2d(out_channels), 119 | nn.ReLU()) # conv_f 120 | 121 | def forward(self, input1, input2): 122 | batch_size, channels, height, width = input1.shape 123 | q1 = self.query1(input1) 124 | k1 = self.key1(input1).view(batch_size, -1, height * width) 125 | v1 = self.value1(input1).view(batch_size, -1, height * width) 126 | 127 | q2 = self.query2(input2) 128 | k2 = self.key2(input2).view(batch_size, -1, height * width) 129 | v2 = self.value2(input2).view(batch_size, -1, height * width) 130 | 131 | q = torch.cat([q1, q2], 1).view(batch_size, -1, height * width).permute(0, 2, 1) 132 | attn_matrix1 = torch.bmm(q, k1) 133 | attn_matrix1 = self.softmax(attn_matrix1) 134 | out1 = torch.bmm(v1, attn_matrix1.permute(0, 2, 1)) 135 | out1 = out1.view(*input1.shape) 136 | out1 = self.gamma * out1 + input1 137 | 138 | attn_matrix2 = torch.bmm(q, k2) 139 | attn_matrix2 = self.softmax(attn_matrix2) 140 | out2 = torch.bmm(v2, attn_matrix2.permute(0, 2, 1)) 141 | out2 = out2.view(*input2.shape) 142 | out2 = self.gamma * out2 + input2 143 | 144 | feat_sum = self.conv_cat(torch.cat([out1, out2], 1)) 145 | return feat_sum, out1, out2 146 | 147 | # 148 | # def draw_features(width=8, height=8, x, savename): 149 | # # tic=time.time() 150 | # fig = plt.figure(figsize=(60, 60)) 151 | # fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05) 152 | # for i in range(width * height): 153 | # plt.subplot(height, width, i + 1) 154 | # plt.axis('off') 155 | # img = x[0, i, :, :] 156 | # pmin = np.min(img) 157 | # pmax = np.max(img) 158 | # img = ((img - pmin) / (pmax - pmin + 0.000001)) * 255 # float在[0,1]之间,转换成0-255 159 | # img = img.astype(np.uint8) # 转成unit8 160 | # img = cv2.applyColorMap(img, cv2.COLORMAP_JET) # 生成heat map 161 | # img = img[:, :, ::-1] # 注意cv2(BGR)和matplotlib(RGB)通道是相反的 162 | # plt.imshow(img) 163 | # fig.savefig(savename, dpi=100) 164 | # fig.clf() 165 | # plt.close() 166 | # # print("time:{}".format(time.time()-tic)) 167 | # 168 | 169 | class DMINet(nn.Module): 170 | def __init__(self, num_classes=2, drop_rate=0.2, normal_init=True, pretrained=False, show_Feature_Maps=False): 171 | super(DMINet, self).__init__() 172 | 173 | self.show_Feature_Maps = show_Feature_Maps 174 | 175 | self.resnet = resnet18() 176 | self.resnet.load_state_dict(torch.load('./pretrain_model/resnet18-5c106cde.pth')) 177 | self.resnet.layer4 = nn.Identity() 178 | 179 | self.cross2 = CrossAtt(256, 256) 180 | self.cross3 = CrossAtt(128, 128) 181 | self.cross4 = CrossAtt(64, 64) 182 | 183 | self.Translayer2_1 = BasicConv2d(256, 128, 1) 184 | self.fam32_1 = decode(128, 128, 128) # AlignBlock(128) # decode(128,128,128) 185 | self.Translayer3_1 = 
BasicConv2d(128, 64, 1) 186 | self.fam43_1 = decode(64, 64, 64) # AlignBlock(64) # decode(64,64,64) 187 | 188 | self.Translayer2_2 = BasicConv2d(256, 128, 1) 189 | self.fam32_2 = decode(128, 128, 128) 190 | self.Translayer3_2 = BasicConv2d(128, 64, 1) 191 | self.fam43_2 = decode(64, 64, 64) 192 | 193 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode='bilinear') 194 | self.upsamplex8 = nn.Upsample(scale_factor=8, mode='bilinear') 195 | 196 | self.final = nn.Sequential( 197 | Conv(64, 32, 3, bn=True, relu=True), 198 | Conv(32, num_classes, 3, bn=False, relu=False) 199 | ) 200 | self.final2 = nn.Sequential( 201 | Conv(64, 32, 3, bn=True, relu=True), 202 | Conv(32, num_classes, 3, bn=False, relu=False) 203 | ) 204 | 205 | self.final_2 = nn.Sequential( 206 | Conv(128, 32, 3, bn=True, relu=True), 207 | Conv(32, num_classes, 3, bn=False, relu=False) 208 | ) 209 | self.final2_2 = nn.Sequential( 210 | Conv(128, 32, 3, bn=True, relu=True), 211 | Conv(32, num_classes, 3, bn=False, relu=False) 212 | ) 213 | if normal_init: 214 | self.init_weights() 215 | 216 | def forward(self, imgs1, imgs2, labels=None): 217 | 218 | c0 = self.resnet.conv1(imgs1) 219 | c0 = self.resnet.bn1(c0) 220 | c0 = self.resnet.relu(c0) 221 | c1 = self.resnet.maxpool(c0) 222 | c1 = self.resnet.layer1(c1) 223 | c2 = self.resnet.layer2(c1) 224 | c3 = self.resnet.layer3(c2) 225 | 226 | c0_img2 = self.resnet.conv1(imgs2) 227 | c0_img2 = self.resnet.bn1(c0_img2) 228 | c0_img2 = self.resnet.relu(c0_img2) 229 | c1_img2 = self.resnet.maxpool(c0_img2) 230 | c1_img2 = self.resnet.layer1(c1_img2) 231 | c2_img2 = self.resnet.layer2(c1_img2) 232 | c3_img2 = self.resnet.layer3(c2_img2) 233 | 234 | cross_result2, cur1_2, cur2_2 = self.cross2(c3, c3_img2) 235 | cross_result3, cur1_3, cur2_3 = self.cross3(c2, c2_img2) 236 | cross_result4, cur1_4, cur2_4 = self.cross4(c1, c1_img2) 237 | 238 | out3 = self.fam32_1(cross_result3, self.Translayer2_1(cross_result2)) 239 | out4 = self.fam43_1(cross_result4, self.Translayer3_1(out3)) 240 | 241 | out3_2 = self.fam32_2(torch.abs(cur1_3 - cur2_3), self.Translayer2_2(torch.abs(cur1_2 - cur2_2))) 242 | out4_2 = self.fam43_2(torch.abs(cur1_4 - cur2_4), self.Translayer3_2(out3_2)) 243 | 244 | out4_up = self.upsamplex4(out4) 245 | out4_2_up = self.upsamplex4(out4_2) 246 | out_1 = self.final(out4_up) 247 | out_2 = self.final2(out4_2_up) 248 | 249 | out_1_2 = self.final_2(self.upsamplex8(out3)) 250 | out_2_2 = self.final2_2(self.upsamplex8(out3_2)) 251 | 252 | # if self.show_Feature_Maps: 253 | # savepath = r'temp' 254 | # draw_features(8, 8, (F.interpolate(c1, scale_factor=4, mode='bilinear')).cpu().detach().numpy(), 255 | # "{}/c1_img1.png".format(savepath)) 256 | # draw_features(8, 16, (F.interpolate(c2, scale_factor=8, mode='bilinear')).cpu().detach().numpy(), 257 | # "{}/c2_img1.png".format(savepath)) 258 | # draw_features(16, 16, (F.interpolate(c3, scale_factor=8, mode='bilinear')).cpu().detach().numpy(), 259 | # "{}/c3_img1.png".format(savepath)) 260 | # draw_features(8, 8, (F.interpolate(c1_img2, scale_factor=4, mode='bilinear')).cpu().detach().numpy(), 261 | # "{}/c1_img2.png".format(savepath)) 262 | # draw_features(8, 16, (F.interpolate(c2_img2, scale_factor=8, mode='bilinear')).cpu().detach().numpy(), 263 | # "{}/c2_img2.png".format(savepath)) 264 | # draw_features(16, 16, (F.interpolate(c3_img2, scale_factor=8, mode='bilinear')).cpu().detach().numpy(), 265 | # "{}/c3_img2.png".format(savepath)) 266 | # # You can show more. 
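        # The four maps returned below are, in order: the prediction from the cross-attention
        # decoder branch (out_1), the prediction from the feature-difference branch (out_2),
        # and the corresponding x8-upsampled auxiliary predictions from the deeper decoder
        # stage (out_1_2, out_2_2), typically used for deep supervision.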
267 | 268 | return out_1, out_2, out_1_2, out_2_2 269 | 270 | def init_weights(self): 271 | self.cross2.apply(init_weights) 272 | self.cross3.apply(init_weights) 273 | self.cross4.apply(init_weights) 274 | 275 | self.fam32_1.apply(init_weights) 276 | self.Translayer2_1.apply(init_weights) 277 | self.fam43_1.apply(init_weights) 278 | self.Translayer3_1.apply(init_weights) 279 | 280 | self.fam32_2.apply(init_weights) 281 | self.Translayer2_2.apply(init_weights) 282 | self.fam43_2.apply(init_weights) 283 | self.Translayer3_2.apply(init_weights) 284 | 285 | self.final.apply(init_weights) 286 | self.final2.apply(init_weights) 287 | self.final_2.apply(init_weights) 288 | self.final2_2.apply(init_weights) 289 | -------------------------------------------------------------------------------- /compare/A2Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from compare.MobileNet import mobilenet_v2 5 | 6 | class NeighborFeatureAggregation(nn.Module): 7 | def __init__(self, in_d=None, out_d=64): 8 | super(NeighborFeatureAggregation, self).__init__() 9 | if in_d is None: 10 | in_d = [16, 24, 32, 96, 320] 11 | self.in_d = in_d 12 | self.mid_d = out_d // 2 13 | self.out_d = out_d 14 | # scale 2 15 | self.conv_scale2_c2 = nn.Sequential( 16 | nn.Conv2d(self.in_d[1], self.mid_d, kernel_size=3, stride=1, padding=1), 17 | nn.BatchNorm2d(self.mid_d), 18 | nn.ReLU(inplace=True) 19 | ) 20 | self.conv_scale2_c3 = nn.Sequential( 21 | nn.Conv2d(self.in_d[2], self.mid_d, kernel_size=3, stride=1, padding=1), 22 | nn.BatchNorm2d(self.mid_d), 23 | nn.ReLU(inplace=True) 24 | ) 25 | self.conv_aggregation_s2 = FeatureFusionModule(self.mid_d * 2, self.in_d[1], self.out_d) 26 | # scale 3 27 | self.conv_scale3_c2 = nn.Sequential( 28 | nn.MaxPool2d(kernel_size=2, stride=2), 29 | nn.Conv2d(self.in_d[1], self.mid_d, kernel_size=3, stride=1, padding=1), 30 | nn.BatchNorm2d(self.mid_d), 31 | nn.ReLU(inplace=True) 32 | ) 33 | self.conv_scale3_c3 = nn.Sequential( 34 | nn.Conv2d(self.in_d[2], self.mid_d, kernel_size=3, stride=1, padding=1), 35 | nn.BatchNorm2d(self.mid_d), 36 | nn.ReLU(inplace=True) 37 | ) 38 | self.conv_scale3_c4 = nn.Sequential( 39 | nn.Conv2d(self.in_d[3], self.mid_d, kernel_size=3, stride=1, padding=1), 40 | nn.BatchNorm2d(self.mid_d), 41 | nn.ReLU(inplace=True) 42 | ) 43 | self.conv_aggregation_s3 = FeatureFusionModule(self.mid_d * 3, self.in_d[2], self.out_d) 44 | # scale 4 45 | self.conv_scale4_c3 = nn.Sequential( 46 | nn.MaxPool2d(kernel_size=2, stride=2), 47 | nn.Conv2d(self.in_d[2], self.mid_d, kernel_size=3, stride=1, padding=1), 48 | nn.BatchNorm2d(self.mid_d), 49 | nn.ReLU(inplace=True) 50 | ) 51 | self.conv_scale4_c4 = nn.Sequential( 52 | nn.Conv2d(self.in_d[3], self.mid_d, kernel_size=3, stride=1, padding=1), 53 | nn.BatchNorm2d(self.mid_d), 54 | nn.ReLU(inplace=True) 55 | ) 56 | self.conv_scale4_c5 = nn.Sequential( 57 | nn.Conv2d(self.in_d[4], self.mid_d, kernel_size=3, stride=1, padding=1), 58 | nn.BatchNorm2d(self.mid_d), 59 | nn.ReLU(inplace=True) 60 | ) 61 | self.conv_aggregation_s4 = FeatureFusionModule(self.mid_d * 3, self.in_d[3], self.out_d) 62 | # scale 5 63 | self.conv_scale5_c4 = nn.Sequential( 64 | nn.MaxPool2d(kernel_size=2, stride=2), 65 | nn.Conv2d(self.in_d[3], self.mid_d, kernel_size=3, stride=1, padding=1), 66 | nn.BatchNorm2d(self.mid_d), 67 | nn.ReLU(inplace=True) 68 | ) 69 | self.conv_scale5_c5 = nn.Sequential( 70 | nn.Conv2d(self.in_d[4], self.mid_d, 
kernel_size=3, stride=1, padding=1), 71 | nn.BatchNorm2d(self.mid_d), 72 | nn.ReLU(inplace=True) 73 | ) 74 | self.conv_aggregation_s5 = FeatureFusionModule(self.mid_d * 2, self.in_d[4], self.out_d) 75 | 76 | def forward(self, c2, c3, c4, c5): 77 | # scale 2 78 | c2_s2 = self.conv_scale2_c2(c2) 79 | 80 | c3_s2 = self.conv_scale2_c3(c3) 81 | c3_s2 = F.interpolate(c3_s2, scale_factor=(2, 2), mode='bilinear') 82 | 83 | s2 = self.conv_aggregation_s2(torch.cat([c2_s2, c3_s2], dim=1), c2) 84 | # scale 3 85 | c2_s3 = self.conv_scale3_c2(c2) 86 | 87 | c3_s3 = self.conv_scale3_c3(c3) 88 | 89 | c4_s3 = self.conv_scale3_c4(c4) 90 | c4_s3 = F.interpolate(c4_s3, scale_factor=(2, 2), mode='bilinear') 91 | 92 | s3 = self.conv_aggregation_s3(torch.cat([c2_s3, c3_s3, c4_s3], dim=1), c3) 93 | # scale 4 94 | c3_s4 = self.conv_scale4_c3(c3) 95 | 96 | c4_s4 = self.conv_scale4_c4(c4) 97 | 98 | c5_s4 = self.conv_scale4_c5(c5) 99 | c5_s4 = F.interpolate(c5_s4, scale_factor=(2, 2), mode='bilinear') 100 | 101 | s4 = self.conv_aggregation_s4(torch.cat([c3_s4, c4_s4, c5_s4], dim=1), c4) 102 | # scale 5 103 | c4_s5 = self.conv_scale5_c4(c4) 104 | 105 | c5_s5 = self.conv_scale5_c5(c5) 106 | 107 | s5 = self.conv_aggregation_s5(torch.cat([c4_s5, c5_s5], dim=1), c5) 108 | 109 | return s2, s3, s4, s5 110 | 111 | 112 | class FeatureFusionModule(nn.Module): 113 | def __init__(self, fuse_d, id_d, out_d): 114 | super(FeatureFusionModule, self).__init__() 115 | self.fuse_d = fuse_d 116 | self.id_d = id_d 117 | self.out_d = out_d 118 | self.conv_fuse = nn.Sequential( 119 | nn.Conv2d(self.fuse_d, self.out_d, kernel_size=3, stride=1, padding=1), 120 | nn.BatchNorm2d(self.out_d), 121 | nn.ReLU(inplace=True), 122 | nn.Conv2d(self.out_d, self.out_d, kernel_size=3, stride=1, padding=1), 123 | nn.BatchNorm2d(self.out_d) 124 | ) 125 | self.conv_identity = nn.Conv2d(self.id_d, self.out_d, kernel_size=1) 126 | self.relu = nn.ReLU(inplace=True) 127 | 128 | def forward(self, c_fuse, c): 129 | c_fuse = self.conv_fuse(c_fuse) 130 | c_out = self.relu(c_fuse + self.conv_identity(c)) 131 | 132 | return c_out 133 | 134 | 135 | class TemporalFeatureFusionModule(nn.Module): 136 | def __init__(self, in_d, out_d): 137 | super(TemporalFeatureFusionModule, self).__init__() 138 | self.in_d = in_d 139 | self.out_d = out_d 140 | self.relu = nn.ReLU(inplace=True) 141 | # branch 1 142 | self.conv_branch1 = nn.Sequential( 143 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=7, dilation=7), 144 | nn.BatchNorm2d(self.in_d) 145 | ) 146 | # branch 2 147 | self.conv_branch2 = nn.Conv2d(self.in_d, self.in_d, kernel_size=1) 148 | self.conv_branch2_f = nn.Sequential( 149 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=5, dilation=5), 150 | nn.BatchNorm2d(self.in_d) 151 | ) 152 | # branch 3 153 | self.conv_branch3 = nn.Conv2d(self.in_d, self.in_d, kernel_size=1) 154 | self.conv_branch3_f = nn.Sequential( 155 | nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=3, dilation=3), 156 | nn.BatchNorm2d(self.in_d) 157 | ) 158 | # branch 4 159 | self.conv_branch4 = nn.Conv2d(self.in_d, self.in_d, kernel_size=1) 160 | self.conv_branch4_f = nn.Sequential( 161 | nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1, dilation=1), 162 | nn.BatchNorm2d(self.out_d) 163 | ) 164 | self.conv_branch5 = nn.Conv2d(self.in_d, self.out_d, kernel_size=1) 165 | 166 | def forward(self, x1, x2): 167 | # temporal fusion 168 | x = torch.abs(x1 - x2) 169 | # branch 1 170 | x_branch1 = self.conv_branch1(x) 171 | # branch 2 172 | 
x_branch2 = self.relu(self.conv_branch2(x) + x_branch1) 173 | x_branch2 = self.conv_branch2_f(x_branch2) 174 | # branch 3 175 | x_branch3 = self.relu(self.conv_branch3(x) + x_branch2) 176 | x_branch3 = self.conv_branch3_f(x_branch3) 177 | # branch 4 178 | x_branch4 = self.relu(self.conv_branch4(x) + x_branch3) 179 | x_branch4 = self.conv_branch4_f(x_branch4) 180 | x_out = self.relu(self.conv_branch5(x) + x_branch4) 181 | 182 | return x_out 183 | 184 | 185 | class TemporalFusionModule(nn.Module): 186 | def __init__(self, in_d=32, out_d=32): 187 | super(TemporalFusionModule, self).__init__() 188 | self.in_d = in_d 189 | self.out_d = out_d 190 | # fusion 191 | self.tffm_x2 = TemporalFeatureFusionModule(self.in_d, self.out_d) 192 | self.tffm_x3 = TemporalFeatureFusionModule(self.in_d, self.out_d) 193 | self.tffm_x4 = TemporalFeatureFusionModule(self.in_d, self.out_d) 194 | self.tffm_x5 = TemporalFeatureFusionModule(self.in_d, self.out_d) 195 | 196 | def forward(self, x1_2, x1_3, x1_4, x1_5, x2_2, x2_3, x2_4, x2_5): 197 | # temporal fusion 198 | c2 = self.tffm_x2(x1_2, x2_2) 199 | c3 = self.tffm_x3(x1_3, x2_3) 200 | c4 = self.tffm_x4(x1_4, x2_4) 201 | c5 = self.tffm_x5(x1_5, x2_5) 202 | 203 | return c2, c3, c4, c5 204 | 205 | class SupervisedAttentionModule(nn.Module): 206 | def __init__(self, mid_d): 207 | super(SupervisedAttentionModule, self).__init__() 208 | self.mid_d = mid_d 209 | # fusion 210 | self.cls = nn.Conv2d(self.mid_d, 2, kernel_size=1) 211 | self.conv_context = nn.Sequential( 212 | nn.Conv2d(4, self.mid_d, kernel_size=1), 213 | nn.BatchNorm2d(self.mid_d), 214 | nn.ReLU(inplace=True) 215 | ) 216 | self.conv2 = nn.Sequential( 217 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1), 218 | nn.BatchNorm2d(self.mid_d), 219 | nn.ReLU(inplace=True) 220 | ) 221 | 222 | def forward(self, x): 223 | mask = self.cls(x) 224 | mask_f = torch.sigmoid(mask) 225 | mask_b = 1 - mask_f 226 | context = torch.cat([mask_f, mask_b], dim=1) 227 | context = self.conv_context(context) 228 | x = x.mul(context) 229 | x_out = self.conv2(x) 230 | 231 | return x_out, mask 232 | 233 | 234 | class Decoder(nn.Module): 235 | def __init__(self, mid_d=320): 236 | super(Decoder, self).__init__() 237 | self.mid_d = mid_d #64 238 | # fusion 239 | self.sam_p5 = SupervisedAttentionModule(self.mid_d) 240 | self.sam_p4 = SupervisedAttentionModule(self.mid_d) 241 | self.sam_p3 = SupervisedAttentionModule(self.mid_d) 242 | self.conv_p4 = nn.Sequential( 243 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1), 244 | nn.BatchNorm2d(self.mid_d), 245 | nn.ReLU(inplace=True) 246 | ) 247 | self.conv_p3 = nn.Sequential( 248 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1), 249 | nn.BatchNorm2d(self.mid_d), 250 | nn.ReLU(inplace=True) 251 | ) 252 | self.conv_p2 = nn.Sequential( 253 | nn.Conv2d(self.mid_d, self.mid_d, kernel_size=3, stride=1, padding=1), 254 | nn.BatchNorm2d(self.mid_d), 255 | nn.ReLU(inplace=True) 256 | ) 257 | self.cls = nn.Conv2d(self.mid_d, 2, kernel_size=1) 258 | 259 | def forward(self, d2, d3, d4, d5): 260 | # high-level 261 | p5, mask_p5 = self.sam_p5(d5) 262 | p4 = self.conv_p4(d4 + F.interpolate(p5, scale_factor=(2, 2), mode='bilinear')) 263 | 264 | p4, mask_p4 = self.sam_p4(p4) 265 | p3 = self.conv_p3(d3 + F.interpolate(p4, scale_factor=(2, 2), mode='bilinear')) 266 | 267 | p3, mask_p3 = self.sam_p3(p3) 268 | p2 = self.conv_p2(d2 + F.interpolate(p3, scale_factor=(2, 2), mode='bilinear')) 269 | mask_p2 = self.cls(p2) 270 | 271 | return p2, p3, p4, 
p5, mask_p2, mask_p3, mask_p4, mask_p5 272 | 273 | class A2Net(nn.Module): 274 | def __init__(self, input_nc=3, output_nc=1): 275 | super(A2Net, self).__init__() 276 | self.backbone = mobilenet_v2(pretrained=True) 277 | channles = [16, 24, 32, 96, 320] 278 | self.en_d = 32 279 | self.mid_d = self.en_d * 2 280 | self.swa = NeighborFeatureAggregation(channles, self.mid_d) 281 | self.tfm = TemporalFusionModule(self.mid_d, self.en_d * 2) 282 | self.decoder = Decoder(self.en_d * 2) 283 | 284 | def forward(self, x1, x2): 285 | # forward backbone resnet 286 | x1_1, x1_2, x1_3, x1_4, x1_5 = self.backbone(x1) 287 | x2_1, x2_2, x2_3, x2_4, x2_5 = self.backbone(x2) 288 | # aggregation 289 | x1_2, x1_3, x1_4, x1_5 = self.swa(x1_2, x1_3, x1_4, x1_5) 290 | x2_2, x2_3, x2_4, x2_5 = self.swa(x2_2, x2_3, x2_4, x2_5) 291 | # temporal fusion 292 | c2, c3, c4, c5 = self.tfm(x1_2, x1_3, x1_4, x1_5, x2_2, x2_3, x2_4, x2_5) 293 | # fpn 294 | p2, p3, p4, p5, mask_p2, mask_p3, mask_p4, mask_p5 = self.decoder(c2, c3, c4, c5) 295 | 296 | # change map 297 | mask_p2 = F.interpolate(mask_p2, scale_factor=(4, 4), mode='bilinear') 298 | # mask_p2 = torch.sigmoid(mask_p2) 299 | mask_p3 = F.interpolate(mask_p3, scale_factor=(8, 8), mode='bilinear') 300 | # mask_p3 = torch.sigmoid(mask_p3) 301 | mask_p4 = F.interpolate(mask_p4, scale_factor=(16, 16), mode='bilinear') 302 | # mask_p4 = torch.sigmoid(mask_p4) 303 | mask_p5 = F.interpolate(mask_p5, scale_factor=(32, 32), mode='bilinear') 304 | # mask_p5 = torch.sigmoid(mask_p5) 305 | 306 | return [mask_p5, mask_p4, mask_p3, mask_p2] 307 | -------------------------------------------------------------------------------- /compare/DTCDSCN.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import ResNet 5 | import torch.nn.functional as F 6 | from functools import partial 7 | 8 | 9 | nonlinearity = partial(F.relu,inplace=True) 10 | 11 | class SELayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(SELayer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.fc = nn.Sequential( 16 | nn.Linear(channel, channel // reduction, bias=False), 17 | nn.ReLU(inplace=True), 18 | nn.Linear(channel // reduction, channel, bias=False), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | b, c, _, _ = x.size() 24 | y = self.avg_pool(x).view(b, c) 25 | y = self.fc(y).view(b, c, 1, 1) 26 | return x * y.expand_as(x) 27 | 28 | class Dblock_more_dilate(nn.Module): 29 | def __init__(self, channel): 30 | super(Dblock_more_dilate, self).__init__() 31 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 32 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 33 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 34 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 35 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | 41 | def forward(self, x): 42 | dilate1_out = nonlinearity(self.dilate1(x)) 43 | dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 44 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 45 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 46 | dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 47 
| out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out 48 | return out 49 | class Dblock(nn.Module): 50 | def __init__(self, channel): 51 | super(Dblock, self).__init__() 52 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 53 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 54 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 55 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 56 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | 62 | def forward(self, x): 63 | dilate1_out = nonlinearity(self.dilate1(x)) 64 | dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 65 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 66 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 67 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 68 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out 69 | return out 70 | 71 | def conv3x3(in_planes, out_planes, stride=1): 72 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 73 | 74 | class SEBasicBlock(nn.Module): 75 | expansion = 1 76 | 77 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 78 | super(SEBasicBlock, self).__init__() 79 | self.conv1 = conv3x3(inplanes, planes, stride) 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.conv2 = conv3x3(planes, planes, 1) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.se = SELayer(planes, reduction) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | residual = x 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.se(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | class DecoderBlock(nn.Module): 107 | def __init__(self, in_channels, n_filters): 108 | super(DecoderBlock,self).__init__() 109 | 110 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 111 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 112 | self.relu1 = nonlinearity 113 | self.scse = SCSEBlock(in_channels // 4) 114 | 115 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1) 116 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 117 | self.relu2 = nonlinearity 118 | 119 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 120 | self.norm3 = nn.BatchNorm2d(n_filters) 121 | self.relu3 = nonlinearity 122 | 123 | def forward(self, x): 124 | x = self.conv1(x) 125 | x = self.norm1(x) 126 | x = self.relu1(x) 127 | y = self.scse(x) 128 | x = x + y 129 | x = self.deconv2(x) 130 | x = self.norm2(x) 131 | x = self.relu2(x) 132 | x = self.conv3(x) 133 | x = self.norm3(x) 134 | x = self.relu3(x) 135 | return x 136 | 137 | class SCSEBlock(nn.Module): 138 | def __init__(self, channel, reduction=16): 139 | super(SCSEBlock, self).__init__() 140 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 141 | 142 | '''self.channel_excitation = nn.Sequential(nn.(channel, int(channel//reduction)), 143 | nn.ReLU(inplace=True), 144 | 
nn.Linear(int(channel//reduction), channel), 145 | nn.Sigmoid())''' 146 | self.channel_excitation = nn.Sequential(nn.Conv2d(channel, int(channel//reduction), kernel_size=1, 147 | stride=1, padding=0, bias=False), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(int(channel // reduction), channel,kernel_size=1, 150 | stride=1, padding=0, bias=False), 151 | nn.Sigmoid()) 152 | 153 | self.spatial_se = nn.Sequential(nn.Conv2d(channel, 1, kernel_size=1, 154 | stride=1, padding=0, bias=False), 155 | nn.Sigmoid()) 156 | 157 | def forward(self, x): 158 | bahs, chs, _, _ = x.size() 159 | 160 | # Returns a new tensor with the same data as the self tensor but of a different size. 161 | chn_se = self.avg_pool(x) 162 | chn_se = self.channel_excitation(chn_se) 163 | chn_se = torch.mul(x, chn_se) 164 | spa_se = self.spatial_se(x) 165 | spa_se = torch.mul(x, spa_se) 166 | return torch.add(chn_se, 1, spa_se) 167 | 168 | class CDNet_model(nn.Module): 169 | def __init__(self, in_channels=3, block=SEBasicBlock, layers=[3, 4, 6, 3], num_classes=2): 170 | super(CDNet_model, self).__init__() 171 | 172 | filters = [64, 128, 256, 512] 173 | self.inplanes = 64 174 | self.firstconv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, 175 | bias=False) 176 | self.firstbn = nn.BatchNorm2d(64) 177 | self.firstrelu = nn.ReLU(inplace=True) 178 | self.firstmaxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 179 | self.encoder1 = self._make_layer(block, 64, layers[0]) 180 | self.encoder2 = self._make_layer(block, 128, layers[1], stride=2) 181 | self.encoder3 = self._make_layer(block, 256, layers[2], stride=2) 182 | self.encoder4 = self._make_layer(block, 512, layers[3], stride=2) 183 | 184 | self.decoder4 = DecoderBlock(filters[3], filters[2]) 185 | self.decoder3 = DecoderBlock(filters[2], filters[1]) 186 | self.decoder2 = DecoderBlock(filters[1], filters[0]) 187 | self.decoder1 = DecoderBlock(filters[0], filters[0]) 188 | 189 | self.dblock_master = Dblock(512) 190 | self.dblock = Dblock(512) 191 | 192 | self.decoder4_master = DecoderBlock(filters[3], filters[2]) 193 | self.decoder3_master = DecoderBlock(filters[2], filters[1]) 194 | self.decoder2_master = DecoderBlock(filters[1], filters[0]) 195 | self.decoder1_master = DecoderBlock(filters[0], filters[0]) 196 | 197 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 198 | self.finalrelu1_master = nonlinearity 199 | self.finalconv2_master = nn.Conv2d(32, 32, 3, padding=1) 200 | self.finalrelu2_master = nonlinearity 201 | self.finalconv3_master = nn.Conv2d(32, num_classes, 3, padding=1) 202 | 203 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 204 | self.finalrelu1 = nonlinearity 205 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1) 206 | self.finalrelu2 = nonlinearity 207 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1) 208 | 209 | for m in self.modules(): 210 | if isinstance(m, nn.Conv2d): 211 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 212 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 213 | elif isinstance(m, nn.BatchNorm2d): 214 | m.weight.data.fill_(1) 215 | m.bias.data.zero_() 216 | 217 | def _make_layer(self, block, planes, blocks, stride=1): 218 | downsample = None 219 | if stride != 1 or self.inplanes != planes * block.expansion: 220 | downsample = nn.Sequential( 221 | nn.Conv2d(self.inplanes, planes * block.expansion, 222 | kernel_size=1, stride=stride, bias=False), 223 | nn.BatchNorm2d(planes * block.expansion), 224 | ) 225 | 226 | layers = [] 227 | layers.append(block(self.inplanes, planes, stride, downsample)) 228 | self.inplanes = planes * block.expansion 229 | for i in range(1, blocks): 230 | layers.append(block(self.inplanes, planes)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x, y): 235 | # Encoder_1 236 | x = self.firstconv(x) 237 | x = self.firstbn(x) 238 | x = self.firstrelu(x) 239 | x = self.firstmaxpool(x) 240 | 241 | e1_x = self.encoder1(x) 242 | e2_x = self.encoder2(e1_x) 243 | e3_x = self.encoder3(e2_x) 244 | e4_x = self.encoder4(e3_x) 245 | 246 | # # Center_1 247 | # e4_x_center = self.dblock(e4_x) 248 | 249 | # # Decoder_1 250 | # d4_x = self.decoder4(e4_x_center) + e3_x 251 | # d3_x = self.decoder3(d4_x) + e2_x 252 | # d2_x = self.decoder2(d3_x) + e1_x 253 | # d1_x = self.decoder1(d2_x) 254 | 255 | # out1 = self.finaldeconv1(d1_x) 256 | # out1 = self.finalrelu1(out1) 257 | # out1 = self.finalconv2(out1) 258 | # out1 = self.finalrelu2(out1) 259 | # out1 = self.finalconv3(out1) 260 | 261 | # Encoder_2 262 | y = self.firstconv(y) 263 | y = self.firstbn(y) 264 | y = self.firstrelu(y) 265 | y = self.firstmaxpool(y) 266 | 267 | e1_y = self.encoder1(y) 268 | e2_y = self.encoder2(e1_y) 269 | e3_y = self.encoder3(e2_y) 270 | e4_y = self.encoder4(e3_y) 271 | 272 | # # Center_2 273 | # e4_y_center = self.dblock(e4_y) 274 | 275 | # # Decoder_2 276 | # d4_y = self.decoder4(e4_y_center) + e3_y 277 | # d3_y = self.decoder3(d4_y) + e2_y 278 | # d2_y = self.decoder2(d3_y) + e1_y 279 | # d1_y = self.decoder1(d2_y) 280 | # out2 = self.finaldeconv1(d1_y) 281 | # out2 = self.finalrelu1(out2) 282 | # out2 = self.finalconv2(out2) 283 | # out2 = self.finalrelu2(out2) 284 | # out2 = self.finalconv3(out2) 285 | 286 | # center_master 287 | e4 = self.dblock_master(e4_x - e4_y) 288 | # decoder_master 289 | d4 = self.decoder4_master(e4) + e3_x - e3_y 290 | d3 = self.decoder3_master(d4) + e2_x - e2_y 291 | d2 = self.decoder2_master(d3) + e1_x - e1_y 292 | d1 = self.decoder1_master(d2) 293 | 294 | out = self.finaldeconv1_master(d1) 295 | out = self.finalrelu1_master(out) 296 | out = self.finalconv2_master(out) 297 | out = self.finalrelu2_master(out) 298 | output = self.finalconv3_master(out) 299 | 300 | # output = [] 301 | # output.append(out) 302 | 303 | return output 304 | 305 | 306 | 307 | def CDNet34(in_channels, **kwargs): 308 | 309 | model = CDNet_model(in_channels, SEBasicBlock, [3, 4, 6, 3], **kwargs) 310 | 311 | return model --------------------------------------------------------------------------------
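A quick way to sanity-check the comparison models collected above is to push random bi-temporal tensors through them and inspect the output shapes. The sketch below is not part of the repository: it assumes the project root is on PYTHONPATH so the `compare` package imports resolve, that DMINet's hard-coded `./pretrain_model/resnet18-5c106cde.pth` file exists relative to the working directory, and that the ImageNet MobileNetV2 weights used by A2Net are available (downloaded or cached). It feeds random 256x256 patch pairs through each network and prints the output shapes.

import torch

from compare.DMINet import DMINet
from compare.A2Net import A2Net
from compare.DTCDSCN import CDNet34

# two random "bi-temporal" RGB patches, batch size 2, 256x256 pixels
t1 = torch.randn(2, 3, 256, 256)
t2 = torch.randn(2, 3, 256, 256)

models = {
    # DMINet loads ./pretrain_model/resnet18-5c106cde.pth inside its constructor,
    # so that file has to be present relative to the working directory.
    'DMINet': DMINet(num_classes=2),
    # A2Net builds mobilenet_v2(pretrained=True), which needs the ImageNet weights.
    'A2Net': A2Net(),
    # CDNet34 is the DTCDSCN factory; it returns a single change map.
    'DTCDSCN': CDNet34(in_channels=3),
}

with torch.no_grad():
    for name, model in models.items():
        model.eval()
        out = model(t1, t2)
        # DMINet and A2Net use deep supervision and return several prediction maps.
        outs = out if isinstance(out, (list, tuple)) else [out]
        print(name, [tuple(o.shape) for o in outs])

Each model should report two-channel (change / no-change) logits at the input resolution, with DMINet and A2Net returning several maps for deep supervision; a file-not-found error at construction time usually points at the pretrained-weight paths noted in the comments.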