├── requirements.txt ├── README.md ├── test_sam.py ├── datasets.py ├── metric.py ├── test.py ├── main.py └── models.py /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations 2 | einops 3 | timm 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CDRL-SA 2 | 3 | ## Update Note 4 | * (24.08.19) The paper has been accepted to Remote Sensing Letters (RSL). 5 | 6 | > [**Unsupervised Change Detection Based on Image Reconstruction Loss with Segment Anything**](https://doi.org/10.1080/2150704X.2024.2388851) 7 | > 8 | > [Hyeoncheol Noh](https://scholar.google.co.kr/citations?user=XTmafQgAAAAJ&hl), [Jingi Ju](https://scholar.google.co.kr/citations?user=hlJYrqAAAAAJ&hl), [Yuhyun Kim](), [Minwoo Kim](https://scholar.google.com/citations?user=c7el4JwAAAAJ), [Dong-Geol Choi](https://scholar.google.co.kr/citations?user=1498JWoAAAAJ&hl) 9 | > 10 | > *[Remote Sensing Letters](https://doi.org/10.1080/2150704X.2024.2388851)* 11 | ## Getting Started 12 | 13 | Dataset download links: 14 | * [LEVIR-CD](https://drive.google.com/file/d/18RGfTqPo1atw_IMm6xPOnND-Vl4ok_o3/view?usp=sharing) (Our multi-class LEVIR dataset is included in that link.) 15 | * [LEVIR-CD_A2B_B2A](https://drive.google.com/file/d/1-LERpM7GOxviKna47bbO_mLQON3Q0YcA/view?usp=sharing) 16 | * [CLCD-CD](https://drive.google.com/file/d/1F4RfWSvoghmIrir_2YlBYfgrJt-flzY8/view?usp=sharing) 17 | * [CLCD-CD_A2B_B2A](https://drive.google.com/file/d/1Q9COBNxg7r5PhgNzY60GTugotbS8AzUg/view?usp=sharing) 18 | 19 | SAM download link: 20 | * [SAM weight](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) 21 | 22 | ```text 23 | CDRL-SA 24 | └── datasets 25 | ├── LEVIR-CD 26 | ├── val 27 | ├── test 28 | └── train 29 | ├── A 30 | ├── B 31 | └── label 32 | ├── LEVIR-CD_A2B_B2A 33 | └── train 34 | ├── A 35 | └── B 36 | ├── CLCD-CD 37 | └── CLCD-CD_A2B_B2A 38 | └── pretrain_weight 39 | └── sam_vit_h_4b8939.pth 40 | 41 | ``` 42 | 43 | ## Train 44 | ```bash 45 | python main.py --root_path ./datasets/ --dataset_name LEVIR-CD --save_name levir 46 | ``` 47 | 48 | ## CDRL Difference Map Generation 49 | ```bash 50 | python test.py --root_path ./datasets/ --dataset_name LEVIR-CD --save_name levir 51 | ``` 52 | 53 | ## CDRL-SA Refined Map Generation 54 | ```bash 55 | python test_sam.py --root_path ./datasets/ --dataset_name LEVIR-CD --save_name levir 56 | ``` 57 | -------------------------------------------------------------------------------- /test_sam.py: -------------------------------------------------------------------------------- 1 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from glob import glob 6 | from tqdm import tqdm 7 | import argparse 8 | import os 9 | 10 | 11 | def make_mask(A_sam_masks, B_sam_masks, bool_diff_map, mask_thre, overlap_thre): 12 | h, w = bool_diff_map.shape 13 | refine_mask = torch.tensor(np.zeros((h, w))).cuda() 14 | thre_size = w * h * mask_thre 15 | overlap_size = w * h * overlap_thre 16 | bool_diff_map = torch.tensor(bool_diff_map).cuda() 17 | 18 | for B_sam_mask_dict in B_sam_masks: 19 | B_sam_mask = torch.tensor(B_sam_mask_dict['segmentation']).cuda() 20 | overlap_trigger = 0 21 | _, b_sam_mask_count = torch.unique(B_sam_mask,return_counts=True) 22 | if b_sam_mask_count[-1] > thre_size: 23 | continue 24 | for A_sam_mask_dict in 
A_sam_masks: 25 | A_sam_mask = torch.tensor(A_sam_mask_dict['segmentation']).cuda() 26 | _, a_sam_mask_count = torch.unique(A_sam_mask,return_counts=True) 27 | if a_sam_mask_count[-1] > thre_size: 28 | continue 29 | overlap_ab = B_sam_mask * A_sam_mask 30 | _, overlap_ab_count = torch.unique(overlap_ab,return_counts=True) 31 | if len(overlap_ab_count) == 2: 32 | if overlap_ab_count[-1] > (b_sam_mask_count[-1] * 0.95) and overlap_ab_count[-1] < (b_sam_mask_count[-1] * 1.05): 33 | if overlap_ab_count[-1] > (a_sam_mask_count[-1] * 0.95) and overlap_ab_count[-1] < (a_sam_mask_count[-1] * 1.05): 34 | overlap_trigger = 1 35 | break 36 | if overlap_trigger == 0: 37 | overlap_diff = bool_diff_map * B_sam_mask 38 | _, overlap_diff_count = torch.unique(overlap_diff,return_counts=True) 39 | if len(overlap_diff_count) == 2 and (b_sam_mask_count[-1] * overlap_thre) <= overlap_diff_count[-1]: 40 | refine_mask[B_sam_mask==True] = 255 41 | return refine_mask 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser(description='Make test sam image') 46 | parser.add_argument('--root_path', default='./datasets/', type=str) 47 | parser.add_argument('--dataset_name', default='LEVIR-CD', type=str) 48 | parser.add_argument('--save_name', default='levir', type=str) 49 | parser.add_argument('--mode', default='test', type=str) 50 | parser.add_argument('--mask_thre', default=0.05, type=float) 51 | parser.add_argument('--overlap_thre', default=0.1, type=float) 52 | args = parser.parse_args() 53 | 54 | save_dir = './sam_refine_mask/' + args.save_name + '/' 55 | if not os.path.exists(save_dir): 56 | os.makedirs(save_dir) 57 | 58 | A_img_dir = args.root_path + args.dataset_name + '/' + args.mode + '/A/' 59 | B_img_dir = args.root_path + args.dataset_name + '/' + args.mode + '/B/' 60 | 61 | # CDRL output directory 62 | diff_map_path = './pixel_img_morpho/' + args.save_name + '/' 63 | 64 | A_img_paths = glob(A_img_dir + '*') 65 | A_img_paths.sort() 66 | 67 | sam = sam_model_registry["default"](checkpoint="./pretrain_weight/sam_vit_h_4b8939.pth") 68 | sam.to(device='cuda') 69 | mask_generator = SamAutomaticMaskGenerator(sam) 70 | 71 | for A_img_path in tqdm(A_img_paths): 72 | img_name = A_img_path.split('/')[-1] 73 | A_img = cv2.imread(A_img_path) 74 | B_img = cv2.imread(B_img_dir + img_name) 75 | ori_h, ori_w, _ = B_img.shape 76 | 77 | diff_map = cv2.imread(diff_map_path + img_name, 0) 78 | diff_map = cv2.resize(diff_map,(ori_w,ori_h)) 79 | bool_diff_map = np.where(diff_map>0,True,False) 80 | 81 | A_masks = mask_generator.generate(A_img) 82 | B_masks = mask_generator.generate(B_img) 83 | 84 | refine_mask = make_mask(A_masks, B_masks, bool_diff_map, args.mask_thre, args.overlap_thre) 85 | refine_mask = refine_mask.cpu().numpy() 86 | cv2.imwrite(save_dir + img_name, refine_mask) 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | import random 7 | 8 | 9 | class CutSwap(object): 10 | 11 | def __init__(self, n_holes, length): 12 | self.n_holes = n_holes 13 | self.length = length 14 | 15 | def __call__(self, img, img2): 16 | """ 17 | Args: 18 | img (Tensor): Tensor image of size (C, H, W). 19 | Returns: 20 | Tensor: Image with n_holes of dimension length x length cut out of it. 
21 | """ 22 | h = img.size(1) 23 | w = img.size(2) 24 | img_ = img.clone().detach() 25 | local = [] 26 | for n in range(self.n_holes): 27 | y = np.random.randint(h) 28 | x = np.random.randint(w) 29 | 30 | y1 = np.clip(y - self.length // 2, 0, h) 31 | y2 = np.clip(y + self.length // 2, 0, h) 32 | x1 = np.clip(x - self.length // 2, 0, w) 33 | x2 = np.clip(x + self.length // 2, 0, w) 34 | 35 | img[:,y1: y2, x1: x2] = img2[:,y1: y2, x1: x2] 36 | img2[:,y1: y2, x1: x2] = img_[:,y1: y2, x1: x2] 37 | local.append([y1, y2, x1, x2]) 38 | 39 | return img,img2,local 40 | 41 | 42 | 43 | class CDRL_Dataset_CutSwap(Dataset): 44 | def __init__(self, root_path=None, dataset=None, train_val=None, transforms_A=None, transforms_B=None, n_holes=None): 45 | self.transforms_A = transforms_A 46 | self.transforms_B = transforms_B 47 | self.train_val = train_val 48 | self.files = [] 49 | self.n_holes = n_holes 50 | 51 | for data in dataset.split(','): 52 | if data!='': 53 | self.total_path = os.path.join(root_path, data, train_val) 54 | self.files += sorted(glob.glob(self.total_path + "/A/*.*")) +\ 55 | sorted(glob.glob(self.total_path + "/B/*.*")) 56 | 57 | def __len__(self): 58 | return len(self.files) 59 | 60 | def __getitem__(self, index): 61 | img_name = self.files[index % len(self.files)].split('/')[-1] 62 | img_A = cv2.imread(self.files[index % len(self.files)], cv2.IMREAD_COLOR) 63 | img_ori = img_A.copy() 64 | A2BB2A_path = self.files[index % len(self.files)].split('/'+self.train_val+'/')[0]+'_A2B_B2A/' 65 | if '/A/' in self.files[index % len(self.files)]: 66 | img_B = cv2.imread(A2BB2A_path+self.train_val+ '/A/'+img_name, cv2.IMREAD_COLOR) 67 | elif '/B/' in self.files[index % len(self.files)]: 68 | img_B = cv2.imread(A2BB2A_path+self.train_val+ '/B/'+img_name, cv2.IMREAD_COLOR) 69 | 70 | transformed_A = self.transforms_A(image=img_A) 71 | transformed_B = self.transforms_B(image=img_B) 72 | 73 | img_A = transformed_A["image"] 74 | img_B = transformed_B["image"] 75 | 76 | cutmix_ = CutSwap(n_holes=self.n_holes, length=64) 77 | img_A_cutmix = img_A.clone().detach() 78 | img_B_cutmix = img_B.clone().detach() 79 | img_A_cutmix,img_B_cutmix, local = cutmix_(img_A_cutmix,img_B_cutmix) 80 | 81 | return {"A":img_A , "B": img_B, "A_cutmix": img_A_cutmix,"B_cutmix": img_B_cutmix, "local":local} 82 | 83 | 84 | 85 | class CDRL_Dataset_test(Dataset): 86 | def __init__(self, root_path=None, dataset=None, transforms=None): 87 | self.total_path = os.path.join(root_path, dataset, 'test') 88 | self.transforms = transforms 89 | self.files = sorted(glob.glob(self.total_path + "/A/*.*")) 90 | 91 | def __getitem__(self, index): 92 | name = self.files[index % len(self.files)].split('/')[-1] 93 | 94 | img_A = cv2.imread(self.files[index % len(self.files)], cv2.IMREAD_COLOR) 95 | img_B = cv2.imread(self.files[index % len(self.files)].replace('/A/','/B/'), cv2.IMREAD_COLOR) 96 | 97 | transformed_A = self.transforms(image=img_A) 98 | transformed_B = self.transforms(image=img_B) 99 | 100 | img_A = transformed_A["image"] 101 | img_B = transformed_B["image"] 102 | 103 | return {"A": img_A, "B": img_B, 'NAME': name} 104 | 105 | def __len__(self): 106 | return len(self.files) 107 | 108 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ################### metrics ################### 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | 
def __init__(self): 8 | self.initialized = False 9 | self.val = None 10 | self.avg = None 11 | self.sum = None 12 | self.count = None 13 | 14 | def initialize(self, val, weight): 15 | self.val = val 16 | self.avg = val 17 | self.sum = val * weight 18 | self.count = weight 19 | self.initialized = True 20 | 21 | def update(self, val, weight=1): 22 | if not self.initialized: 23 | self.initialize(val, weight) 24 | else: 25 | self.add(val, weight) 26 | 27 | def add(self, val, weight): 28 | self.val = val 29 | self.sum += val * weight 30 | self.count += weight 31 | self.avg = self.sum / self.count 32 | 33 | def value(self): 34 | return self.val 35 | 36 | def average(self): 37 | return self.avg 38 | 39 | def get_scores(self): 40 | scores_dict = cm2score(self.sum) 41 | return scores_dict 42 | 43 | def clear(self): 44 | self.initialized = False 45 | 46 | 47 | ################### cm metrics ################### 48 | class ConfuseMatrixMeter(AverageMeter): 49 | """Computes and stores the average and current value""" 50 | def __init__(self, n_class): 51 | super(ConfuseMatrixMeter, self).__init__() 52 | self.n_class = n_class 53 | 54 | def update_cm(self, pr, gt, weight=1): 55 | """获得当前混淆矩阵,并计算当前F1得分,并更新混淆矩阵""" 56 | val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr) 57 | self.update(val, weight) 58 | 59 | return val 60 | 61 | def get_scores(self,val_sum): 62 | scores_dict = cm2score(val_sum) 63 | return scores_dict 64 | 65 | 66 | 67 | def harmonic_mean(xs): 68 | harmonic_mean = len(xs) / sum((x+1e-6)**-1 for x in xs) 69 | return harmonic_mean 70 | 71 | 72 | def cm2F1(confusion_matrix): 73 | # print(confusion_matrix.shape) 74 | hist = confusion_matrix 75 | n_class = hist.shape[0] 76 | tp = np.diag(hist) 77 | sum_a1 = hist.sum(axis=1) 78 | sum_a0 = hist.sum(axis=0) 79 | # ---------------------------------------------------------------------- # 80 | # 1. Accuracy & Class Accuracy 81 | # ---------------------------------------------------------------------- # 82 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 83 | 84 | # recall 85 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 86 | # acc_cls = np.nanmean(recall) 87 | 88 | # precision 89 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 90 | 91 | # F1 score 92 | F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps) 93 | mean_F1 = np.nanmean(F1) 94 | return mean_F1 95 | 96 | 97 | def cm2score(confusion_matrix): 98 | hist = confusion_matrix 99 | n_class = hist.shape[0] 100 | tp = np.diag(hist) 101 | sum_a1 = hist.sum(axis=1) 102 | sum_a0 = hist.sum(axis=0) 103 | # ---------------------------------------------------------------------- # 104 | # 1. Accuracy & Class Accuracy 105 | # ---------------------------------------------------------------------- # 106 | acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps) 107 | 108 | # recall 109 | recall = tp / (sum_a1 + np.finfo(np.float32).eps) 110 | # acc_cls = np.nanmean(recall) 111 | 112 | # precision 113 | precision = tp / (sum_a0 + np.finfo(np.float32).eps) 114 | 115 | # F1 score 116 | F1 = 2*recall * precision / (recall + precision + np.finfo(np.float32).eps) 117 | mean_F1 = np.nanmean(F1) 118 | # ---------------------------------------------------------------------- # 119 | # 2. 
Frequency weighted Accuracy & Mean IoU 120 | # ---------------------------------------------------------------------- # 121 | 122 | iu = tp / (sum_a1 + hist.sum(axis=0) - tp + np.finfo(np.float32).eps) 123 | mean_iu = np.nanmean(iu) 124 | 125 | freq = sum_a1 / (hist.sum() + np.finfo(np.float32).eps) 126 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 127 | 128 | cls_iou = dict(zip(['iou_'+str(i) for i in range(n_class)], iu)) 129 | 130 | cls_precision = dict(zip(['precision_'+str(i) for i in range(n_class)], precision)) 131 | cls_recall = dict(zip(['recall_'+str(i) for i in range(n_class)], recall)) 132 | cls_F1 = dict(zip(['F1_'+str(i) for i in range(n_class)], F1)) 133 | 134 | score_dict = {'acc': acc, 'miou': mean_iu, 'mf1':mean_F1} 135 | score_dict.update(cls_iou) 136 | score_dict.update(cls_F1) 137 | score_dict.update(cls_precision) 138 | score_dict.update(cls_recall) 139 | return score_dict 140 | 141 | 142 | def get_confuse_matrix(num_classes, label_gts, label_preds): 143 | """计算一组预测的混淆矩阵""" 144 | def __fast_hist(label_gt, label_pred): 145 | """ 146 | Collect values for Confusion Matrix 147 | For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix 148 | :param label_gt: ground-truth 149 | :param label_pred: prediction 150 | :return: values for confusion matrix 151 | """ 152 | mask = (label_gt >= 0) & (label_gt < num_classes) 153 | hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask], 154 | minlength=num_classes**2).reshape(num_classes, num_classes) 155 | return hist 156 | confusion_matrix = np.zeros((num_classes, num_classes)) 157 | for lt, lp in zip(label_gts, label_preds): 158 | confusion_matrix += __fast_hist(lt.flatten(), lp.flatten()) 159 | return confusion_matrix 160 | 161 | 162 | def get_mIoU(num_classes, label_gts, label_preds): 163 | confusion_matrix = get_confuse_matrix(num_classes, label_gts, label_preds) 164 | score_dict = cm2score(confusion_matrix) 165 | return score_dict['miou'] 166 | 167 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.utils import save_image 8 | from torch.utils.data import DataLoader 9 | from torch.autograd import Variable 10 | from models import * 11 | from datasets import * 12 | import albumentations as A 13 | from albumentations.pytorch.transforms import ToTensorV2 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import to_pil_image 16 | from models import SwinTransformerSys 17 | 18 | import cv2 19 | import glob 20 | 21 | from metric import * 22 | from tqdm import tqdm 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--root_path", type=str, default="./datasets/", help="root path") 26 | parser.add_argument("--dataset_name", type=str, default="LEVIR-CD", help="name of the dataset") 27 | parser.add_argument("--save_name", type=str, default="levir", help="name of the dataset") 28 | parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation") 29 | parser.add_argument("--img_height", type=int, default=256, help="size of image height") 30 | parser.add_argument("--img_width", type=int, default=256, help="size of image width") 31 | opt = parser.parse_args() 32 | print(opt) 33 | 34 | os.makedirs('pixel_img/'+opt.save_name, exist_ok=True) 35 | 
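# 'pixel_img' will hold the thresholded per-pixel reconstruction-error (difference) maps,
# 'gener_img' the images reconstructed by the generator (see pixel_visual below).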
os.makedirs('gener_img/'+opt.save_name, exist_ok=True) 36 | 37 | cuda = True if torch.cuda.is_available() else False 38 | 39 | criterion_GAN = torch.nn.MSELoss() 40 | criterion_pixelwise = torch.nn.L1Loss() 41 | 42 | lambda_pixel = 100 43 | 44 | patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4) 45 | 46 | generator = SwinTransformerSys(img_size=256, 47 | patch_size=4, 48 | in_chans=6, 49 | num_classes=3, 50 | embed_dim=96, 51 | depths=[2, 2, 6, 2], 52 | num_heads=[3, 6, 12, 24], 53 | window_size=8, 54 | mlp_ratio=4., 55 | qkv_bias=True, 56 | qk_scale=None, 57 | drop_rate=0.0, 58 | drop_path_rate=0.1, 59 | ape=False, 60 | patch_norm=True, 61 | use_checkpoint=False) 62 | discriminator = Discriminator() 63 | 64 | if cuda: 65 | generator = generator.cuda() 66 | discriminator = discriminator.cuda() 67 | criterion_GAN.cuda() 68 | criterion_pixelwise.cuda() 69 | 70 | 71 | 72 | generator.load_state_dict(torch.load("saved_models/"+opt.save_name+"/generator_9.pth")) 73 | 74 | transforms_ = A.Compose([ 75 | A.Resize(opt.img_height, opt.img_width), 76 | A.Normalize(), 77 | ToTensorV2() 78 | ]) 79 | 80 | val_dataloader = DataLoader( 81 | CDRL_Dataset_test(opt.root_path, dataset=opt.dataset_name, transforms=transforms_), 82 | batch_size=1, 83 | shuffle=False, 84 | num_workers=opt.n_cpu, 85 | ) 86 | 87 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 88 | 89 | 90 | def pixel_visual(gener_output_, A_ori_, name): 91 | gener_output = gener_output_.cpu().clone().detach().squeeze() 92 | A_ori = A_ori_.cpu().clone().detach().squeeze() 93 | 94 | pixel_loss = to_pil_image(torch.abs(gener_output-A_ori)) 95 | trans = transforms.Compose([ 96 | transforms.Grayscale(), 97 | transforms.ToTensor()]) 98 | pixel_loss = trans(pixel_loss) 99 | 100 | thre_num= 0.7 101 | threshold = nn.Threshold(thre_num, 0.) 
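# nn.Threshold(thre_num, 0.) keeps only grayscale error values strictly greater than 0.7 and
# zeroes the rest, so just the high reconstruction-error pixels survive into the saved difference map.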
102 | pixel_loss = threshold(pixel_loss) 103 | save_image(pixel_loss, 'pixel_img/'+opt.save_name+'/'+str(name[0])) 104 | save_image(gener_output.flip(-3), 'gener_img/'+opt.save_name+'/'+str(name[0]), normalize=True) 105 | 106 | 107 | prev_time = time.time() 108 | 109 | loss_G_total = 0 110 | 111 | generator.eval() 112 | 113 | 114 | with torch.no_grad(): 115 | for i, batch in enumerate(val_dataloader): 116 | 117 | img_A = Variable(batch["A"].type(Tensor)) 118 | img_B = Variable(batch["B"].type(Tensor)) 119 | name = batch["NAME"] 120 | 121 | valid = Variable(Tensor(np.ones((img_A.size(0), *patch))), requires_grad=False) 122 | 123 | img_A = img_A.cuda() 124 | img_B = img_B.cuda() 125 | img_AB = torch.cat([img_A,img_B], dim=1) 126 | gener_output = generator(img_AB) 127 | 128 | pixel_visual(gener_output, img_A, name) 129 | 130 | 131 | loss_pixel = criterion_pixelwise(gener_output, img_B) 132 | 133 | loss_G = lambda_pixel * loss_pixel 134 | 135 | loss_G_total += loss_G 136 | 137 | print('----------------------------total------------------------------') 138 | print('loss_G_total : ', round((loss_G_total/len(val_dataloader)).item(),4)) 139 | 140 | 141 | 142 | paths = glob.glob('./pixel_img/'+opt.save_name+'/*') 143 | 144 | if not os.path.isdir('./pixel_img_morpho'): 145 | os.mkdir('pixel_img_morpho') 146 | if not os.path.isdir('./pixel_img_morpho/'+opt.save_name): 147 | os.mkdir('pixel_img_morpho/'+opt.save_name) 148 | 149 | for path in paths: 150 | 151 | img = cv2.imread(path) 152 | 153 | k = cv2.getStructuringElement(cv2.MORPH_RECT, (7,7)) 154 | img = cv2.dilate(img, k) 155 | k = cv2.getStructuringElement(cv2.MORPH_RECT, (7,7)) 156 | img = cv2.erode(img, k) 157 | k = cv2.getStructuringElement(cv2.MORPH_RECT, (5,5)) 158 | img = cv2.erode(img, k) 159 | k = cv2.getStructuringElement(cv2.MORPH_RECT, (7,7)) 160 | img = cv2.erode(img, k) 161 | k = cv2.getStructuringElement(cv2.MORPH_RECT, (7,7)) 162 | img = cv2.erode(img, k) 163 | 164 | 165 | img_name = path.split('/')[-1] 166 | cv2.imwrite('./pixel_img_morpho/'+opt.save_name+'/'+img_name, img) 167 | 168 | 169 | 170 | con = ConfuseMatrixMeter(2) 171 | pred_path = glob.glob('./pixel_img_morpho/'+opt.save_name+'/*') 172 | 173 | scores_dict = 0. 
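# Reader note: the loop below accumulates one confusion matrix per test image, then averages
# (and integer-casts) the running sum before handing it to get_scores / cm2score. Also note that
# ConfuseMatrixMeter.update_cm is declared as update_cm(pr, gt) while the call passes (gt, pr);
# this transposes the confusion matrix, which leaves acc, mIoU and F1 unchanged for these binary
# maps but appears to swap the reported per-class precision and recall.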
174 | c = 0 175 | 176 | for img_path in tqdm(pred_path): 177 | gt = cv2.imread(opt.root_path + opt.dataset_name + '/test/label/' + img_path.split('/')[-1].replace('jpg','png'),0) 178 | gt = cv2.resize(gt,(256,256)) 179 | gt = np.expand_dims(gt,axis=0) 180 | 181 | pr = np.expand_dims(cv2.imread(img_path,0),axis=0) 182 | 183 | gt[gt>0] = 1 184 | pr[pr>0] = 1 185 | gt = gt.astype(int) 186 | pr = pr.astype(int) 187 | 188 | scores_dict += con.update_cm(gt, pr) 189 | 190 | 191 | 192 | scores_dict = (scores_dict/len(pred_path)).astype(int) 193 | scores_dict = con.get_scores(scores_dict) 194 | 195 | [print(a,' : ', scores_dict[a]) for a in scores_dict] -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import time 5 | import datetime 6 | import sys 7 | 8 | from torchvision.utils import save_image 9 | from torch.utils.data import DataLoader 10 | from torch.autograd import Variable 11 | 12 | from models import * 13 | from datasets import * 14 | 15 | import torch 16 | 17 | import albumentations as A 18 | from albumentations.pytorch.transforms import ToTensorV2 19 | from models import SwinTransformerSys 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from") 23 | parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training") 24 | parser.add_argument("--root_path", type=str, default="./datasets/", help="root path") 25 | parser.add_argument("--dataset_name", type=str, default="LEVIR-CD", help="name of the dataset") 26 | parser.add_argument("--save_name", type=str, default="levir", help="name of the dataset") 27 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches") 28 | parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") 29 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 30 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 31 | parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation") 32 | parser.add_argument("--img_height", type=int, default=256, help="size of image height") 33 | parser.add_argument("--img_width", type=int, default=256, help="size of image width") 34 | parser.add_argument("--sample_interval", type=int, default=1000, help="interval between sampling of images from generators") 35 | parser.add_argument("--n_holes", type=int, default=2, help="size of the batches") 36 | 37 | opt = parser.parse_args() 38 | print(opt) 39 | 40 | os.makedirs("images/%s" % opt.save_name, exist_ok=True) 41 | os.makedirs("saved_models/%s" % opt.save_name, exist_ok=True) 42 | 43 | cuda = True if torch.cuda.is_available() else False 44 | 45 | criterion_GAN = torch.nn.MSELoss() 46 | criterion_pixelwise = torch.nn.L1Loss() 47 | 48 | lambda_pixel = 100 49 | 50 | patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4) 51 | 52 | generator = SwinTransformerSys(img_size=256, 53 | patch_size=4, 54 | in_chans=6, 55 | num_classes=3, 56 | embed_dim=96, 57 | depths=[2, 2, 6, 2], 58 | num_heads=[3, 6, 12, 24], 59 | window_size=8, 60 | mlp_ratio=4., 61 | qkv_bias=True, 62 | qk_scale=None, 63 | drop_rate=0.0, 64 | drop_path_rate=0.1, 65 | ape=False, 66 | patch_norm=True, 67 | 
use_checkpoint=False) 68 | 69 | discriminator = Discriminator() 70 | 71 | if cuda: 72 | generator = generator.cuda() 73 | discriminator = discriminator.cuda() 74 | criterion_GAN.cuda() 75 | criterion_pixelwise.cuda() 76 | 77 | 78 | generator.apply(weights_init_normal) 79 | discriminator.apply(weights_init_normal) 80 | 81 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 82 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 83 | 84 | 85 | transforms_aug = A.Compose([ 86 | A.Resize(opt.img_height, opt.img_width), 87 | # A.ColorJitter(p=0.5), 88 | A.Normalize(), 89 | ToTensorV2() 90 | ]) 91 | 92 | transforms_ori = A.Compose([ 93 | A.Resize(opt.img_height, opt.img_width), 94 | A.Normalize(), 95 | ToTensorV2() 96 | ]) 97 | 98 | 99 | dataloader = DataLoader( 100 | CDRL_Dataset_CutSwap(root_path=opt.root_path, dataset=opt.dataset_name, train_val='train', 101 | transforms_A=transforms_aug, transforms_B=transforms_ori, n_holes=opt.n_holes), 102 | batch_size=opt.batch_size, 103 | shuffle=True, 104 | num_workers=opt.n_cpu, 105 | ) 106 | 107 | val_dataloader = DataLoader( 108 | CDRL_Dataset_CutSwap(root_path=opt.root_path, dataset=opt.dataset_name, train_val='train', 109 | transforms_A=transforms_aug, transforms_B=transforms_ori, n_holes=opt.n_holes), 110 | batch_size=10, 111 | shuffle=False, 112 | num_workers=1, 113 | ) 114 | 115 | 116 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 117 | 118 | 119 | def sample_images(batches_done): 120 | imgs = next(iter(val_dataloader)) 121 | img_A = Variable(imgs["A"].type(Tensor)).cuda() 122 | img_B = Variable(imgs["B"].type(Tensor)).cuda() 123 | img_A_cutmix = Variable(imgs["A_cutmix"].type(Tensor)).cuda() 124 | img_B_cutmix = Variable(imgs["B_cutmix"].type(Tensor)).cuda() 125 | 126 | img_AB = torch.cat([img_A_cutmix,img_B], dim=1) 127 | img_B_fake = generator(img_AB) 128 | img_A = img_A[:, [2,1,0],:,:] 129 | img_A_cutmix = img_A_cutmix[:, [2,1,0],:,:] 130 | img_B_cutmix = img_B_cutmix[:, [2,1,0],:,:] 131 | img_B_fake = img_B_fake[:, [2,1,0],:,:] 132 | img_B = img_B[:, [2,1,0],:,:] 133 | img_sample = torch.cat((img_A.data, img_A_cutmix.data, img_B_fake.data, img_B.data,img_B_cutmix.data), -2) 134 | save_image(img_sample, "images/%s/%s.png" % (opt.save_name, batches_done), nrow=5, normalize=True) 135 | 136 | 137 | prev_time = time.time() 138 | 139 | for epoch in range(opt.epoch, opt.n_epochs): 140 | for i, batch in enumerate(dataloader): 141 | 142 | img_A = Variable(batch["A"].type(Tensor)) 143 | img_B = Variable(batch["B"].type(Tensor)) 144 | img_A_cutmix = Variable(batch["A_cutmix"].type(Tensor)) 145 | img_B_cutmix = Variable(batch["B_cutmix"].type(Tensor)) 146 | local = batch["local"] 147 | 148 | 149 | valid = Variable(Tensor(np.ones((img_A.size(0), *patch))), requires_grad=False) 150 | fake = Variable(Tensor(np.zeros((img_A.size(0), *patch))), requires_grad=False) 151 | 152 | # Generator 153 | optimizer_G.zero_grad() 154 | 155 | img_A = img_A.cuda() 156 | img_B = img_B.cuda() 157 | img_A_cutmix = img_A_cutmix.cuda() 158 | img_B_cutmix = img_B_cutmix.cuda() 159 | img_AB = torch.cat([img_A_cutmix,img_B_cutmix], dim=1) 160 | 161 | gener_output = generator(img_AB) 162 | 163 | gener_output_pred = discriminator(gener_output, img_A) 164 | 165 | loss_GAN = criterion_GAN(gener_output_pred, valid) 166 | 167 | loss_pixel = criterion_pixelwise(gener_output, img_A) 168 | loss_pixel_cutmix = 0 169 | for lo in local: 170 | y1, y2, x1, x2 = lo 171 | loss_pixel_cutmix += 
criterion_pixelwise( 172 | gener_output[:,:,y1.item(): y2.item(), x1.item(): x2.item()], 173 | img_A[:,:,y1.item(): y2.item(), x1.item(): x2.item()]) 174 | loss_pixel_cutmix = loss_pixel_cutmix/len(local) 175 | 176 | loss_pixel = (loss_pixel+loss_pixel_cutmix)/2 177 | loss_G = loss_GAN + lambda_pixel * loss_pixel 178 | 179 | loss_G.backward() 180 | optimizer_G.step() 181 | 182 | 183 | # Discriminator 184 | optimizer_D.zero_grad() 185 | 186 | pred_real = discriminator(img_B, img_A) 187 | loss_real = criterion_GAN(pred_real, valid) 188 | 189 | B_pred_fake = discriminator(gener_output.detach(), img_A) 190 | loss_fake = criterion_GAN(B_pred_fake, fake) 191 | 192 | loss_D = 0.5 * (loss_real + loss_fake) 193 | 194 | loss_D.backward() 195 | optimizer_D.step() 196 | 197 | 198 | batches_done = epoch * len(dataloader) + i 199 | batches_left = opt.n_epochs * len(dataloader) - batches_done 200 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) 201 | prev_time = time.time() 202 | 203 | sys.stdout.write( 204 | "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s" 205 | % ( 206 | epoch, 207 | opt.n_epochs, 208 | i, 209 | len(dataloader), 210 | loss_D.item(), 211 | loss_G.item(), 212 | loss_pixel.item(), 213 | loss_GAN.item(), 214 | time_left, 215 | ) 216 | ) 217 | 218 | if batches_done % opt.sample_interval == 0: 219 | sample_images(batches_done) 220 | 221 | torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.save_name, epoch)) 222 | torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.save_name, epoch)) 223 | 224 | 225 | 226 | 227 | 228 | 229 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from einops import rearrange 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import numpy as np 7 | 8 | def weights_init_normal(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("BasicConv") != -1: 11 | nn.init.normal_(m.conv.weight.data, 0.0, 0.02) 12 | nn.init.normal_(m.bn.weight.data, 1.0, 0.02) 13 | nn.init.constant_(m.bn.bias.data, 0.0) 14 | elif classname.find("Conv") != -1: 15 | nn.init.normal_(m.weight.data, 0.0, 0.02) 16 | elif classname.find("BatchNorm2d") != -1: 17 | nn.init.normal_(m.weight.data, 1.0, 0.02) 18 | nn.init.constant_(m.bias.data, 0.0) 19 | 20 | class Discriminator(nn.Module): 21 | def __init__(self, in_channels=3): 22 | super(Discriminator, self).__init__() 23 | 24 | def discriminator_block(in_filters, out_filters, normalization=True): 25 | """Returns downsampling layers of each discriminator block""" 26 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 27 | if normalization: 28 | layers.append(nn.InstanceNorm2d(out_filters)) 29 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 30 | return layers 31 | 32 | self.model = nn.Sequential( 33 | *discriminator_block(in_channels * 2, 64, normalization=False), 34 | *discriminator_block(64, 128), 35 | *discriminator_block(128, 256), 36 | *discriminator_block(256, 512), 37 | nn.ZeroPad2d((1, 0, 1, 0)), 38 | nn.Conv2d(512, 1, 4, padding=1, bias=False) 39 | ) 40 | 41 | def forward(self, img_A, img_B): 42 | # Concatenate image and condition image by channels to produce input 43 | img_input = torch.cat((img_A, img_B), 1) 44 | return self.model(img_input) 45 | 
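# Note: Discriminator above is a conditional PatchGAN. It concatenates the candidate image with
# the conditioning image (2 x 3 = 6 input channels) and, after four stride-2 blocks plus the final
# convolution, emits a 1 x (H/16) x (W/16) map of patch scores - the shape that the
# `patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)` targets in main.py and test.py match.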
46 | 47 | 48 | class Mlp(nn.Module): 49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 50 | super().__init__() 51 | out_features = out_features or in_features 52 | hidden_features = hidden_features or in_features 53 | self.fc1 = nn.Linear(in_features, hidden_features) 54 | self.act = act_layer() 55 | self.fc2 = nn.Linear(hidden_features, out_features) 56 | self.drop = nn.Dropout(drop) 57 | 58 | def forward(self, x): 59 | x = self.fc1(x) 60 | x = self.act(x) 61 | x = self.drop(x) 62 | x = self.fc2(x) 63 | x = self.drop(x) 64 | return x 65 | 66 | 67 | def window_partition(x, window_size): 68 | """ 69 | Args: 70 | x: (B, H, W, C) 71 | window_size (int): window size 72 | 73 | Returns: 74 | windows: (num_windows*B, window_size, window_size, C) 75 | """ 76 | B, H, W, C = x.shape 77 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 78 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 79 | return windows 80 | 81 | 82 | def window_reverse(windows, window_size, H, W): 83 | """ 84 | Args: 85 | windows: (num_windows*B, window_size, window_size, C) 86 | window_size (int): Window size 87 | H (int): Height of image 88 | W (int): Width of image 89 | 90 | Returns: 91 | x: (B, H, W, C) 92 | """ 93 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 94 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 95 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 96 | return x 97 | 98 | 99 | class WindowAttention(nn.Module): 100 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 101 | It supports both of shifted and non-shifted window. 102 | 103 | Args: 104 | dim (int): Number of input channels. 105 | window_size (tuple[int]): The height and width of the window. 106 | num_heads (int): Number of attention heads. 107 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 108 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 109 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 110 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 111 | """ 112 | 113 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 114 | 115 | super().__init__() 116 | self.dim = dim 117 | self.window_size = window_size # Wh, Ww 118 | self.num_heads = num_heads 119 | head_dim = dim // num_heads 120 | self.scale = qk_scale or head_dim ** -0.5 121 | 122 | # define a parameter table of relative position bias 123 | self.relative_position_bias_table = nn.Parameter( 124 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 125 | 126 | # get pair-wise relative position index for each token inside the window 127 | coords_h = torch.arange(self.window_size[0]) 128 | coords_w = torch.arange(self.window_size[1]) 129 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 130 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 131 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 132 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 133 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 134 | relative_coords[:, :, 1] += self.window_size[1] - 1 135 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 136 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 137 | self.register_buffer("relative_position_index", relative_position_index) 138 | 139 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 140 | self.attn_drop = nn.Dropout(attn_drop) 141 | self.proj = nn.Linear(dim, dim) 142 | self.proj_drop = nn.Dropout(proj_drop) 143 | 144 | trunc_normal_(self.relative_position_bias_table, std=.02) 145 | self.softmax = nn.Softmax(dim=-1) 146 | 147 | def forward(self, x, mask=None): 148 | """ 149 | Args: 150 | x: input features with shape of (num_windows*B, N, C) 151 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 152 | """ 153 | B_, N, C = x.shape 154 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 155 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 156 | 157 | q = q * self.scale 158 | attn = (q @ k.transpose(-2, -1)) 159 | 160 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 161 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 162 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 163 | attn = attn + relative_position_bias.unsqueeze(0) 164 | 165 | if mask is not None: 166 | nW = mask.shape[0] 167 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 168 | attn = attn.view(-1, self.num_heads, N, N) 169 | attn = self.softmax(attn) 170 | else: 171 | attn = self.softmax(attn) 172 | 173 | attn = self.attn_drop(attn) 174 | 175 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 176 | x = self.proj(x) 177 | x = self.proj_drop(x) 178 | return x 179 | 180 | def extra_repr(self) -> str: 181 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 182 | 183 | def flops(self, N): 184 | # calculate flops for 1 window with token length of N 185 | flops = 0 186 | # qkv = self.qkv(x) 187 | flops += N * self.dim * 3 * self.dim 188 | # attn = (q @ k.transpose(-2, -1)) 189 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 190 | # x = (attn @ v) 191 | flops += self.num_heads * N * N 
* (self.dim // self.num_heads) 192 | # x = self.proj(x) 193 | flops += N * self.dim * self.dim 194 | return flops 195 | 196 | 197 | class SwinTransformerBlock(nn.Module): 198 | r""" Swin Transformer Block. 199 | 200 | Args: 201 | dim (int): Number of input channels. 202 | input_resolution (tuple[int]): Input resulotion. 203 | num_heads (int): Number of attention heads. 204 | window_size (int): Window size. 205 | shift_size (int): Shift size for SW-MSA. 206 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 207 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 208 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 209 | drop (float, optional): Dropout rate. Default: 0.0 210 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 211 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 212 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 213 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 214 | """ 215 | 216 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 217 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 218 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 219 | super().__init__() 220 | self.dim = dim 221 | self.input_resolution = input_resolution 222 | self.num_heads = num_heads 223 | self.window_size = window_size 224 | self.shift_size = shift_size 225 | self.mlp_ratio = mlp_ratio 226 | if min(self.input_resolution) <= self.window_size: 227 | # if window size is larger than input resolution, we don't partition windows 228 | self.shift_size = 0 229 | self.window_size = min(self.input_resolution) 230 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 231 | 232 | self.norm1 = norm_layer(dim) 233 | self.attn = WindowAttention( 234 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 235 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 236 | 237 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 238 | self.norm2 = norm_layer(dim) 239 | mlp_hidden_dim = int(dim * mlp_ratio) 240 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 241 | 242 | if self.shift_size > 0: 243 | # calculate attention mask for SW-MSA 244 | H, W = self.input_resolution 245 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 246 | h_slices = (slice(0, -self.window_size), 247 | slice(-self.window_size, -self.shift_size), 248 | slice(-self.shift_size, None)) 249 | w_slices = (slice(0, -self.window_size), 250 | slice(-self.window_size, -self.shift_size), 251 | slice(-self.shift_size, None)) 252 | cnt = 0 253 | for h in h_slices: 254 | for w in w_slices: 255 | img_mask[:, h, w, :] = cnt 256 | cnt += 1 257 | 258 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 259 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 260 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 261 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 262 | else: 263 | attn_mask = None 264 | 265 | self.register_buffer("attn_mask", attn_mask) 266 | 267 | def forward(self, x): 268 | H, W = self.input_resolution 269 | B, L, C = x.shape 270 | assert L == H * W, "input feature has wrong size" 271 | 272 | shortcut = x 273 | x = self.norm1(x) 274 | x = x.view(B, H, W, C) 275 | 276 | # cyclic shift 277 | if self.shift_size > 0: 278 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 279 | else: 280 | shifted_x = x 281 | 282 | # partition windows 283 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 284 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 285 | 286 | # W-MSA/SW-MSA 287 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 288 | 289 | # merge windows 290 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 291 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 292 | 293 | # reverse cyclic shift 294 | if self.shift_size > 0: 295 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 296 | else: 297 | x = shifted_x 298 | x = x.view(B, H * W, C) 299 | 300 | # FFN 301 | x = shortcut + self.drop_path(x) 302 | x = x + self.drop_path(self.mlp(self.norm2(x))) 303 | 304 | return x 305 | 306 | def extra_repr(self) -> str: 307 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 308 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 309 | 310 | def flops(self): 311 | flops = 0 312 | H, W = self.input_resolution 313 | # norm1 314 | flops += self.dim * H * W 315 | # W-MSA/SW-MSA 316 | nW = H * W / self.window_size / self.window_size 317 | flops += nW * self.attn.flops(self.window_size * self.window_size) 318 | # mlp 319 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 320 | # norm2 321 | flops += self.dim * H * W 322 | return flops 323 | 324 | 325 | class PatchMerging(nn.Module): 326 | r""" Patch Merging Layer. 327 | 328 | Args: 329 | input_resolution (tuple[int]): Resolution of input feature. 330 | dim (int): Number of input channels. 331 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 332 | """ 333 | 334 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 335 | super().__init__() 336 | self.input_resolution = input_resolution 337 | self.dim = dim 338 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 339 | self.norm = norm_layer(4 * dim) 340 | 341 | def forward(self, x): 342 | """ 343 | x: B, H*W, C 344 | """ 345 | H, W = self.input_resolution 346 | B, L, C = x.shape 347 | assert L == H * W, "input feature has wrong size" 348 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 349 | 350 | x = x.view(B, H, W, C) 351 | 352 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 353 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 354 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 355 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 356 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 357 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 358 | 359 | x = self.norm(x) 360 | x = self.reduction(x) 361 | 362 | return x 363 | 364 | def extra_repr(self) -> str: 365 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 366 | 367 | def flops(self): 368 | H, W = self.input_resolution 369 | flops = H * W * self.dim 370 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 371 | return flops 372 | 373 | class PatchExpand(nn.Module): 374 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 375 | super().__init__() 376 | self.input_resolution = input_resolution 377 | self.dim = dim 378 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() 379 | self.norm = norm_layer(dim // dim_scale) 380 | 381 | def forward(self, x): 382 | """ 383 | x: B, H*W, C 384 | """ 385 | H, W = self.input_resolution 386 | x = self.expand(x) 387 | B, L, C = x.shape 388 | assert L == H * W, "input feature has wrong size" 389 | 390 | x = x.view(B, H, W, C) 391 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 392 | x = x.view(B,-1,C//4) 393 | x= self.norm(x) 394 | 395 | return x 396 | 397 | class FinalPatchExpand_X4(nn.Module): 398 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 399 | super().__init__() 400 | self.input_resolution = input_resolution 401 | self.dim = dim 402 | self.dim_scale = dim_scale 403 | self.expand = nn.Linear(dim, 16*dim, bias=False) 404 | self.output_dim = dim 405 | self.norm = norm_layer(self.output_dim) 406 | 407 | def forward(self, x): 408 | """ 409 | x: B, H*W, C 410 | """ 411 | H, W = self.input_resolution 412 | x = self.expand(x) 413 | B, L, C = x.shape 414 | assert L == H * W, "input feature has wrong size" 415 | 416 | x = x.view(B, H, W, C) 417 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) 418 | x = x.view(B,-1,self.output_dim) 419 | x= self.norm(x) 420 | 421 | return x 422 | 423 | class BasicLayer(nn.Module): 424 | """ A basic Swin Transformer layer for one stage. 425 | 426 | Args: 427 | dim (int): Number of input channels. 428 | input_resolution (tuple[int]): Input resolution. 429 | depth (int): Number of blocks. 430 | num_heads (int): Number of attention heads. 431 | window_size (int): Local window size. 432 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 433 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 434 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 435 | drop (float, optional): Dropout rate. 
Default: 0.0 436 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 437 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 438 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 439 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 440 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 441 | """ 442 | 443 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 444 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 445 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 446 | 447 | super().__init__() 448 | self.dim = dim 449 | self.input_resolution = input_resolution 450 | self.depth = depth 451 | self.use_checkpoint = use_checkpoint 452 | 453 | # build blocks 454 | self.blocks = nn.ModuleList([ 455 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 456 | num_heads=num_heads, window_size=window_size, 457 | shift_size=0 if (i % 2 == 0) else window_size // 2, 458 | mlp_ratio=mlp_ratio, 459 | qkv_bias=qkv_bias, qk_scale=qk_scale, 460 | drop=drop, attn_drop=attn_drop, 461 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 462 | norm_layer=norm_layer) 463 | for i in range(depth)]) 464 | 465 | # patch merging layer 466 | if downsample is not None: 467 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 468 | else: 469 | self.downsample = None 470 | 471 | def forward(self, x): 472 | for blk in self.blocks: 473 | if self.use_checkpoint: 474 | x = checkpoint.checkpoint(blk, x) 475 | else: 476 | x = blk(x) 477 | if self.downsample is not None: 478 | x = self.downsample(x) 479 | return x 480 | 481 | def extra_repr(self) -> str: 482 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 483 | 484 | def flops(self): 485 | flops = 0 486 | for blk in self.blocks: 487 | flops += blk.flops() 488 | if self.downsample is not None: 489 | flops += self.downsample.flops() 490 | return flops 491 | 492 | class BasicLayer_up(nn.Module): 493 | """ A basic Swin Transformer layer for one stage. 494 | 495 | Args: 496 | dim (int): Number of input channels. 497 | input_resolution (tuple[int]): Input resolution. 498 | depth (int): Number of blocks. 499 | num_heads (int): Number of attention heads. 500 | window_size (int): Local window size. 501 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 502 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 503 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 504 | drop (float, optional): Dropout rate. Default: 0.0 505 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 506 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 507 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 508 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 509 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
510 | """ 511 | 512 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 513 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 514 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 515 | 516 | super().__init__() 517 | self.dim = dim 518 | self.input_resolution = input_resolution 519 | self.depth = depth 520 | self.use_checkpoint = use_checkpoint 521 | 522 | # build blocks 523 | self.blocks = nn.ModuleList([ 524 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 525 | num_heads=num_heads, window_size=window_size, 526 | shift_size=0 if (i % 2 == 0) else window_size // 2, 527 | mlp_ratio=mlp_ratio, 528 | qkv_bias=qkv_bias, qk_scale=qk_scale, 529 | drop=drop, attn_drop=attn_drop, 530 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 531 | norm_layer=norm_layer) 532 | for i in range(depth)]) 533 | 534 | # patch merging layer 535 | if upsample is not None: 536 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 537 | else: 538 | self.upsample = None 539 | 540 | def forward(self, x): 541 | for blk in self.blocks: 542 | if self.use_checkpoint: 543 | x = checkpoint.checkpoint(blk, x) 544 | else: 545 | x = blk(x) 546 | if self.upsample is not None: 547 | x = self.upsample(x) 548 | return x 549 | 550 | class PatchEmbed(nn.Module): 551 | r""" Image to Patch Embedding 552 | 553 | Args: 554 | img_size (int): Image size. Default: 224. 555 | patch_size (int): Patch token size. Default: 4. 556 | in_chans (int): Number of input image channels. Default: 3. 557 | embed_dim (int): Number of linear projection output channels. Default: 96. 558 | norm_layer (nn.Module, optional): Normalization layer. Default: None 559 | """ 560 | 561 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 562 | super().__init__() 563 | img_size = to_2tuple(img_size) 564 | patch_size = to_2tuple(patch_size) 565 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 566 | self.img_size = img_size 567 | self.patch_size = patch_size 568 | self.patches_resolution = patches_resolution 569 | self.num_patches = patches_resolution[0] * patches_resolution[1] 570 | 571 | self.in_chans = in_chans 572 | self.embed_dim = embed_dim 573 | 574 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 575 | if norm_layer is not None: 576 | self.norm = norm_layer(embed_dim) 577 | else: 578 | self.norm = None 579 | 580 | def forward(self, x): 581 | B, C, H, W = x.shape 582 | # FIXME look at relaxing size constraints 583 | assert H == self.img_size[0] and W == self.img_size[1], \ 584 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 585 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 586 | if self.norm is not None: 587 | x = self.norm(x) 588 | 589 | return x 590 | 591 | def flops(self): 592 | Ho, Wo = self.patches_resolution 593 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 594 | if self.norm is not None: 595 | flops += Ho * Wo * self.embed_dim 596 | return flops 597 | 598 | 599 | class SwinTransformerSys(nn.Module): 600 | r""" Swin Transformer 601 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 602 | https://arxiv.org/pdf/2103.14030 603 | 604 | Args: 605 | img_size (int | tuple(int)): Input image size. 
Default 224 606 | patch_size (int | tuple(int)): Patch size. Default: 4 607 | in_chans (int): Number of input image channels. Default: 3 608 | num_classes (int): Number of classes for classification head. Default: 1000 609 | embed_dim (int): Patch embedding dimension. Default: 96 610 | depths (tuple(int)): Depth of each Swin Transformer layer. 611 | num_heads (tuple(int)): Number of attention heads in different layers. 612 | window_size (int): Window size. Default: 7 613 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 614 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 615 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 616 | drop_rate (float): Dropout rate. Default: 0 617 | attn_drop_rate (float): Attention dropout rate. Default: 0 618 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 619 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 620 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 621 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 622 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 623 | """ 624 | 625 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 626 | embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 627 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 628 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 629 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 630 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 631 | super().__init__() 632 | 633 | print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, 634 | depths_decoder,drop_path_rate,num_classes)) 635 | 636 | self.num_classes = num_classes 637 | self.num_layers = len(depths) 638 | self.embed_dim = embed_dim 639 | self.ape = ape 640 | self.patch_norm = patch_norm 641 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 642 | self.num_features_up = int(embed_dim * 2) 643 | self.mlp_ratio = mlp_ratio 644 | self.final_upsample = final_upsample 645 | 646 | # split image into non-overlapping patches 647 | self.patch_embed = PatchEmbed( 648 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 649 | norm_layer=norm_layer if self.patch_norm else None) 650 | num_patches = self.patch_embed.num_patches 651 | patches_resolution = self.patch_embed.patches_resolution 652 | self.patches_resolution = patches_resolution 653 | 654 | # absolute position embedding 655 | if self.ape: 656 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 657 | trunc_normal_(self.absolute_pos_embed, std=.02) 658 | 659 | self.pos_drop = nn.Dropout(p=drop_rate) 660 | 661 | # stochastic depth 662 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 663 | 664 | # build encoder and bottleneck layers 665 | self.layers = nn.ModuleList() 666 | for i_layer in range(self.num_layers): 667 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 668 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 669 | patches_resolution[1] // (2 ** i_layer)), 670 | depth=depths[i_layer], 671 | num_heads=num_heads[i_layer], 672 | window_size=window_size, 673 | mlp_ratio=self.mlp_ratio, 674 | qkv_bias=qkv_bias, qk_scale=qk_scale, 675 | 
drop=drop_rate, attn_drop=attn_drop_rate, 676 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 677 | norm_layer=norm_layer, 678 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 679 | use_checkpoint=use_checkpoint) 680 | self.layers.append(layer) 681 | 682 | # build decoder layers 683 | self.layers_up = nn.ModuleList() 684 | self.concat_back_dim = nn.ModuleList() 685 | for i_layer in range(self.num_layers): 686 | concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), 687 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() 688 | if i_layer ==0 : 689 | layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 690 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) 691 | else: 692 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), 693 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 694 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), 695 | depth=depths[(self.num_layers-1-i_layer)], 696 | num_heads=num_heads[(self.num_layers-1-i_layer)], 697 | window_size=window_size, 698 | mlp_ratio=self.mlp_ratio, 699 | qkv_bias=qkv_bias, qk_scale=qk_scale, 700 | drop=drop_rate, attn_drop=attn_drop_rate, 701 | drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], 702 | norm_layer=norm_layer, 703 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, 704 | use_checkpoint=use_checkpoint) 705 | self.layers_up.append(layer_up) 706 | self.concat_back_dim.append(concat_linear) 707 | 708 | self.norm = norm_layer(self.num_features) 709 | self.norm_up= norm_layer(self.embed_dim) 710 | 711 | if self.final_upsample == "expand_first": 712 | print("---final upsample expand_first---") 713 | self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) 714 | self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) 715 | 716 | self.apply(self._init_weights) 717 | 718 | def _init_weights(self, m): 719 | if isinstance(m, nn.Linear): 720 | trunc_normal_(m.weight, std=.02) 721 | if isinstance(m, nn.Linear) and m.bias is not None: 722 | nn.init.constant_(m.bias, 0) 723 | elif isinstance(m, nn.LayerNorm): 724 | nn.init.constant_(m.bias, 0) 725 | nn.init.constant_(m.weight, 1.0) 726 | 727 | @torch.jit.ignore 728 | def no_weight_decay(self): 729 | return {'absolute_pos_embed'} 730 | 731 | @torch.jit.ignore 732 | def no_weight_decay_keywords(self): 733 | return {'relative_position_bias_table'} 734 | 735 | #Encoder and Bottleneck 736 | def forward_features(self, x): 737 | x = self.patch_embed(x) 738 | if self.ape: 739 | x = x + self.absolute_pos_embed 740 | x = self.pos_drop(x) 741 | x_downsample = [] 742 | 743 | for layer in self.layers: 744 | x_downsample.append(x) 745 | x = layer(x) 746 | 747 | x = self.norm(x) # B L C 748 | 749 | return x, x_downsample 750 | 751 | #Dencoder and Skip connection 752 | def forward_up_features(self, x, x_downsample): 753 | for inx, layer_up in enumerate(self.layers_up): 754 | if inx == 0: 755 | x = layer_up(x) 756 | else: 757 | x = torch.cat([x,x_downsample[3-inx]],-1) 758 | x = self.concat_back_dim[inx](x) 759 | x = layer_up(x) 760 | 761 | x = self.norm_up(x) # B L C 762 | 763 | return x 764 | 765 | def up_x4(self, x): 766 | H, 
W = self.patches_resolution 767 | B, L, C = x.shape 768 | assert L == H*W, "input features has wrong size" 769 | 770 | if self.final_upsample=="expand_first": 771 | x = self.up(x) 772 | x = x.view(B,4*H,4*W,-1) 773 | x = x.permute(0,3,1,2) #B,C,H,W 774 | x = self.output(x) 775 | 776 | return x 777 | 778 | def forward(self, x): 779 | x, x_downsample = self.forward_features(x) 780 | x = self.forward_up_features(x,x_downsample) 781 | x = self.up_x4(x) 782 | 783 | return x 784 | 785 | def flops(self): 786 | flops = 0 787 | flops += self.patch_embed.flops() 788 | for i, layer in enumerate(self.layers): 789 | flops += layer.flops() 790 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 791 | flops += self.num_features * self.num_classes 792 | return flops 793 | 794 | 795 | 796 | 797 | --------------------------------------------------------------------------------
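For orientation, here is a minimal sketch (not part of the repository) of how the generator defined in models.py is wired in main.py/test.py: a pre/post-change pair is concatenated channel-wise, reconstructed by SwinTransformerSys, and the per-pixel reconstruction error is what test.py later thresholds and cleans up into a difference map. Shapes assume the 256x256 configuration used above; the tensors are random stand-ins for real LEVIR-CD images.

```python
import torch
from models import SwinTransformerSys

# Generator configured as in main.py / test.py: 6-channel input (images A and B
# stacked on the channel axis), 3-channel reconstruction as output.
generator = SwinTransformerSys(img_size=256, patch_size=4, in_chans=6, num_classes=3,
                               embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                               window_size=8)
generator.eval()

# Random stand-ins for a normalized pre-change (A) / post-change (B) image pair.
img_A = torch.randn(1, 3, 256, 256)
img_B = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    recon = generator(torch.cat([img_A, img_B], dim=1))  # (1, 3, 256, 256)

# Per-pixel reconstruction error w.r.t. image A - the raw signal that test.py
# converts to grayscale, thresholds at 0.7 and post-processes with morphology.
error_map = (recon - img_A).abs().mean(dim=1)  # (1, 256, 256)
print(recon.shape, error_map.shape)
```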