├── utils
│   ├── bimap.py
│   ├── check.py
│   └── split_whu.py
├── test_swin.py
├── pytorch_iou
│   └── __init__.py
├── my_scheduler.py
├── deal_evaluation.py
├── data
│   ├── dataset_swin_whu.py
│   ├── dataset_swin_GZ.py
│   ├── dataset_swin_levir.py
│   └── dataset_swin_sysu.py
├── README.md
├── pytorch_ssim
│   └── __init__.py
├── train_swin_gz.py
├── train_swin_whu.py
├── train_swin_sysu.py
├── train_swin_levir.py
├── YTYAttention.py
└── swin_ynet.py

/utils/bimap.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import shutil
5 | 
6 | # if the masks are incorrect, then use cv2.threshold to process the images
7 | refile = '../label'
8 | outPath = '../label_deal'
9 | 
10 | if os.path.exists(outPath):
11 |     shutil.rmtree(outPath)
12 | os.mkdir(outPath)
13 | name = os.listdir(refile)
14 | 
15 | for i in range(len(name)):
16 |     label_file = os.path.join(refile, name[i])
17 |     a = cv2.imread(label_file, 0)
18 |     b = 2. * np.mean(a)
19 |     b, photo = cv2.threshold(a, 6, 255, cv2.THRESH_BINARY)
20 |     cv2.imwrite(outPath + '/' + name[i], photo)
21 | 
--------------------------------------------------------------------------------
/utils/check.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os
3 | 
4 | Files_path = r"../try"
5 | labels_num = len(os.listdir(Files_path))
6 | print(labels_num)
7 | import numpy as np
8 | 
9 | # check if the masks are correct
10 | for i in range(labels_num):
11 |     image_dir = os.path.join(Files_path, str(os.listdir(Files_path)[i]))
12 |     image_path = os.path.join(image_dir)
13 |     img = Image.open(image_path)
14 |     print(np.array(img).shape)
15 |     print(np.array(img).shape)
16 |     # if the mask is abnormal, the item will be printed
17 |     if np.array(img).sum() % 5 != 0:
18 |         print(np.array(img).sum())
19 | 
--------------------------------------------------------------------------------
/test_swin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import shutil
4 | from torch.utils.data import DataLoader
5 | from data.dataset_swin_levir import MyTestData
6 | from swin_ynet import Encoder
7 | import torchvision
8 | 
9 | model = Encoder().cuda()
10 | model.load_state_dict(torch.load('levir_swin.pth'))
11 | 
12 | test_loader = DataLoader(MyTestData(), shuffle=False, batch_size=1)
13 | 
14 | outPath = 'levir_swin'
15 | if os.path.exists(outPath):
16 |     shutil.rmtree(outPath)
17 | os.mkdir(outPath)
18 | 
19 | with torch.no_grad():
20 |     model = model.eval()
21 |     for i, (im1, im2, label_name) in enumerate(test_loader):
22 |         im1 = im1.cuda()
23 |         im2 = im2.cuda()
24 |         label_name = label_name[0]
25 | 
26 |         outputs = model(im1, im2)
27 |         outputs = outputs[0][0]
28 |         a = outputs[0].unsqueeze(0)
29 | 
30 |         torchvision.utils.save_image(a, outPath + '/%s' % label_name)
31 | 
--------------------------------------------------------------------------------
/utils/split_whu.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import shutil
4 | 
5 | Files_path = r"A"  # original whu dataset file path, change to B/label to process each item
6 | labels_num = len(os.listdir(Files_path))
7 | print(labels_num)
8 | 
9 | outpath = 'A_224'
10 | if os.path.exists(outpath):
11 |     shutil.rmtree(outpath)
12 | os.mkdir(outpath)
13 | 
14 | 
15 | def split(img, img_name):
16 |     print(img_name)
17 |     size = img.shape
18 |     index = 0
19 |     for i in range(size[0] // 224):
20 |         for 
j in range(size[1] // 224): 21 | crop_img = img[i * 224:(i + 1) * 224, j * 224:(j + 1) * 224] 22 | cv2.imwrite(outpath + '/' + str(index) + '.png', crop_img) 23 | index = index + 1 24 | 25 | 26 | for i in range(labels_num): 27 | image_dir = os.path.join(Files_path, str(os.listdir(Files_path)[i])) 28 | image_path = os.path.join(image_dir) 29 | img = cv2.imread(image_path) 30 | split(img, str(os.listdir(Files_path)[i])) 31 | -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def _iou(pred, target, size_average = True): 7 | #print(pred.shape,target.shape) 8 | pred = pred[:,1,:,:] 9 | #print(pred.shape,'\n','------------------------','\n',target.shape) 10 | b = pred.shape[0] 11 | IoU = 0.0 12 | for i in range(0,b): 13 | #compute the IoU of the foreground 14 | Iand1 = torch.sum(target[i,:,:]*pred[i,:,:]) 15 | Ior1 = torch.sum(target[i,:,:]) + torch.sum(pred[i,:,:])-Iand1 16 | IoU1 = Iand1/Ior1 17 | 18 | #IoU loss is (1-IoU1) 19 | IoU = IoU + (1-IoU1) 20 | 21 | return IoU/b 22 | 23 | class IOU(torch.nn.Module): 24 | def __init__(self, size_average = True): 25 | super(IOU, self).__init__() 26 | self.size_average = size_average 27 | 28 | def forward(self, pred, target): 29 | 30 | return _iou(pred, target, self.size_average) 31 | -------------------------------------------------------------------------------- /my_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class LR_Scheduler(object): 5 | """Learning Rate Scheduler 6 | 7 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 8 | 9 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 10 | 11 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 12 | 13 | Args: 14 | args: :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 15 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 16 | :attr:`args.lr_step` 17 | 18 | iters_per_epoch: number of iterations per epoch 19 | """ 20 | 21 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 22 | lr_step=0, warmup_epochs=0): 23 | self.mode = mode 24 | print('Using {} LR Scheduler!'.format(self.mode)) 25 | self.lr = base_lr 26 | if mode == 'step': 27 | assert lr_step 28 | self.lr_step = lr_step 29 | self.iters_per_epoch = iters_per_epoch 30 | self.N = num_epochs * iters_per_epoch 31 | self.epoch = -1 32 | self.warmup_iters = warmup_epochs * iters_per_epoch 33 | 34 | def __call__(self, optimizer, i, epoch, best_pred=0.0): 35 | T = epoch * self.iters_per_epoch + i 36 | if self.mode == 'cos': 37 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * (T - self.warmup_iters) / (self.N - self.warmup_iters) * math.pi)) 38 | elif self.mode == 'poly': 39 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 40 | elif self.mode == 'step': 41 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 42 | elif self.mode == 'linear': 43 | if T < self.warmup_iters: 44 | lr = self.lr 45 | else: 46 | lr = self.lr * 1.0 * (2 - T / self.warmup_iters) 47 | else: 48 | raise NotImplemented 49 | 50 | if self.warmup_iters > 0 and T < self.warmup_iters: 51 | lr = lr * 1.0 * T / self.warmup_iters 52 | if epoch > self.epoch: 53 | self.epoch = epoch 54 | assert lr >= 0 55 | self._adjust_learning_rate(optimizer, lr) 56 | 57 | def _adjust_learning_rate(self, 
optimizer, lr): 58 | if len(optimizer.param_groups) == 1: 59 | optimizer.param_groups[0]['lr'] = lr 60 | else: 61 | optimizer.param_groups[0]['lr'] = lr * 0.1 62 | for i in range(1, len(optimizer.param_groups)): 63 | optimizer.param_groups[i]['lr'] = lr 64 | -------------------------------------------------------------------------------- /deal_evaluation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import shutil 5 | 6 | refile = 'levir_swin' # file obtained by test_swin.py 7 | outPath = 'levir_swin_deal' # temp file 8 | if os.path.exists(outPath): 9 | shutil.rmtree(outPath) 10 | os.mkdir(outPath) 11 | name = os.listdir(refile) 12 | 13 | for i in range(len(name)): 14 | label_file = os.path.join(refile, name[i]) 15 | a = cv2.imread(label_file, 0) 16 | 17 | b = 2. * np.mean(a) 18 | b, photo = cv2.threshold(a, 20, 255, cv2.THRESH_BINARY) 19 | cv2.imwrite(outPath + '/' + name[i], photo) 20 | 21 | Files_path = outPath 22 | labels_num = len(os.listdir(Files_path)) 23 | print(labels_num) 24 | 25 | outPath = 'try' # temp file2 26 | if os.path.exists(outPath): 27 | shutil.rmtree(outPath) 28 | os.mkdir(outPath) 29 | 30 | for i in range(labels_num): 31 | image_dir = os.path.join(Files_path, str(os.listdir(Files_path)[i])) 32 | image_path = os.path.join(image_dir) 33 | img = cv2.imread(image_path) 34 | img = img[:, :, 0] 35 | lb0 = cv2.merge([img * 255.]) 36 | cv2.imwrite(outPath + '/' + str(os.listdir(Files_path)[i]), lb0) 37 | 38 | gt_path = 'xxx/label' # corresponding gt path 39 | pred_path = outPath 40 | labels_num = len(os.listdir(gt_path)) 41 | 42 | print(labels_num) 43 | 44 | 45 | def P_R_IoU(gt, pred): 46 | predict_precision = 0 47 | predict_recall = 0 48 | tp = 0 49 | tn = 0 50 | for k in range(len(gt)): 51 | 52 | img1 = gt[k] 53 | img2 = pred[k] 54 | for i in range(224): 55 | for j in range(224): 56 | if not (int(img1[i, j]) - int(img2[i, j])) and img2[i, j] == 255: 57 | tp += 1 # TP value 58 | if img1[i, j] == img2[i, j] and img2[i, j] == 0: 59 | tn += 1 # TN value 60 | 61 | predict_precision += np.sum(np.reshape(img2, (img2.size,))) / 255 62 | predict_recall += np.sum(np.reshape(img1, (img1.size,))) / 255 63 | print(k) 64 | predict_iou = predict_precision + predict_recall - tp 65 | return tp / predict_precision, tp / predict_recall, tp / predict_iou, (tn + tp) / (predict_iou + tn) 66 | 67 | 68 | def f1_score(precision, recall): 69 | return 2 * precision * recall / (precision + recall) 70 | 71 | 72 | def get_average(list_): 73 | sum_ = 0 74 | for _ in list_: 75 | sum_ += _ 76 | return sum_ / len(list_) 77 | 78 | 79 | gt_list = [] 80 | pred_list = [] 81 | 82 | for i in range(labels_num): 83 | gt = os.path.join(gt_path, str(os.listdir(gt_path)[i])) 84 | pred = os.path.join(pred_path, str(os.listdir(pred_path)[i])) 85 | gt_path1 = os.path.join(gt) 86 | pred_path1 = os.path.join(pred) 87 | gt1 = cv2.imread(gt_path1, flags=0) 88 | pred1 = cv2.imread(pred_path1, flags=0) 89 | gt_list.append(gt1) 90 | pred_list.append(pred1) 91 | 92 | precision_res, recall_res, iou_res, OA = P_R_IoU(gt_list, pred_list) 93 | 94 | f1_score = f1_score(precision_res, recall_res) 95 | print('precision:', precision_res, 'recall:', recall_res, '\n', 'f1:', f1_score, 'IoU:', iou_res, 'OA:', OA) 96 | -------------------------------------------------------------------------------- /data/dataset_swin_whu.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 
from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | 6 | 7 | class MyData(Dataset): 8 | def __init__(self): 9 | super(MyData, self).__init__() 10 | self.train_im_path1 = 'data/whu224_train/A' 11 | self.train_im_path2 = 'data/whu224_train/B' 12 | self.train_lb_path = 'data/whu224_train/label' 13 | self.train_im_num = 8874 14 | self.train_imgs1 = os.listdir(self.train_im_path1) 15 | self.train_imgs2 = os.listdir(self.train_im_path2) 16 | self.train_labels = os.listdir(self.train_lb_path) 17 | 18 | def __len__(self): 19 | return self.train_im_num 20 | 21 | def __getitem__(self, index): 22 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 23 | img1 = Image.open(img_file1) 24 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 25 | img2 = Image.open(img_file2) 26 | label_file = os.path.join(self.train_lb_path, self.train_labels[index]) 27 | labels = Image.open(label_file) 28 | 29 | im1, im2, lb0, lb1, lb2, lb3 = self.transform(img1, img2, labels) 30 | lb0 = 1. - lb0[0] 31 | lb1 = 1. - lb1[0] 32 | lb2 = 1. - lb2[0] 33 | lb3 = 1. - lb3[0] 34 | 35 | return im1, im2, lb0, lb1, lb2, lb3 # ,lb4,lb5,lb6,lb7#,edge0,edge1,edge2,edge3 36 | 37 | def transform(self, img1, img2, label): 38 | transform_img = transforms.Compose([ 39 | transforms.ToTensor(), 40 | ]) 41 | transform_img_4 = transforms.Compose([transforms.Resize((56, 56), Image.NEAREST), 42 | transforms.ToTensor(), 43 | ]) 44 | transform_img_8 = transforms.Compose([transforms.Resize((28, 28), Image.NEAREST), 45 | transforms.ToTensor(), 46 | ]) 47 | transform_img_16 = transforms.Compose([transforms.Resize((14, 14), Image.NEAREST), 48 | transforms.ToTensor(), 49 | ]) 50 | 51 | transform_img_2 = transforms.Compose([ 52 | transforms.ToTensor(), 53 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 54 | ]) 55 | im1 = transform_img_2(img1) 56 | im2 = transform_img_2(img2) 57 | label0 = transform_img(label) 58 | label_4 = transform_img_4(label) 59 | label_8 = transform_img_8(label) 60 | label_16 = transform_img_16(label) 61 | return im1, im2, label0, label_4, label_8, label_16 62 | 63 | 64 | class MyTestData(Dataset): 65 | def __init__(self): 66 | super(MyTestData, self).__init__() 67 | self.train_im_path1 = 'data/WHU-CD/whu224_test/A' 68 | self.train_im_path2 = 'data/WHU-CD/whu224_test/B' 69 | self.train_lb_path = 'data/WHU-CD/whu224_test/label' 70 | self.train_im_num = 986 71 | self.train_imgs1 = os.listdir(self.train_im_path1) 72 | self.train_imgs2 = os.listdir(self.train_im_path2) 73 | self.train_labels = os.listdir(self.train_lb_path) 74 | 75 | def __len__(self): 76 | return self.train_im_num 77 | 78 | def __getitem__(self, index): 79 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 80 | img1 = Image.open(img_file1) 81 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 82 | img2 = Image.open(img_file2) 83 | label_file = str(self.train_labels[index][:-4]) + '.png' 84 | 85 | im1, im2 = self.transform(img1, img2) 86 | 87 | return im1, im2, label_file 88 | 89 | def transform(self, img1, img2): 90 | transform_img_2 = transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 93 | ]) 94 | im1 = transform_img_2(img1) 95 | im2 = transform_img_2(img2) 96 | return im1, im2 97 | -------------------------------------------------------------------------------- /data/dataset_swin_GZ.py: -------------------------------------------------------------------------------- 1 
| from torch.utils.data import Dataset 2 | from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | 6 | 7 | class MyData(Dataset): 8 | def __init__(self): 9 | super(MyData, self).__init__() 10 | self.train_im_path1 = 'data/CD_Data_GZ/GZ_train_224/A' 11 | self.train_im_path2 = 'data/CD_Data_GZ/GZ_train_224/B' 12 | self.train_lb_path = 'data/CD_Data_GZ/GZ_train_224/label' 13 | self.train_im_num = 3743 14 | self.train_imgs1 = os.listdir(self.train_im_path1) 15 | self.train_imgs2 = os.listdir(self.train_im_path2) 16 | self.train_labels = os.listdir(self.train_lb_path) 17 | 18 | def __len__(self): 19 | return self.train_im_num 20 | 21 | def __getitem__(self, index): 22 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 23 | img1 = Image.open(img_file1) 24 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 25 | img2 = Image.open(img_file2) 26 | label_file = os.path.join(self.train_lb_path, self.train_labels[index]) 27 | labels = Image.open(label_file) 28 | 29 | im1, im2, lb0, lb1, lb2, lb3 = self.transform(img1, img2, labels) 30 | lb0 = 1. - lb0[0] 31 | lb1 = 1. - lb1[0] 32 | lb2 = 1. - lb2[0] 33 | lb3 = 1. - lb3[0] 34 | 35 | return im1, im2, lb0, lb1, lb2, lb3 # ,lb4,lb5,lb6,lb7#,edge0,edge1,edge2,edge3 36 | 37 | def transform(self, img1, img2, label): 38 | transform_img = transforms.Compose([ 39 | transforms.ToTensor(), 40 | ]) 41 | transform_img_4 = transforms.Compose([transforms.Resize((56, 56), Image.NEAREST), 42 | transforms.ToTensor(), 43 | ]) 44 | transform_img_8 = transforms.Compose([transforms.Resize((28, 28), Image.NEAREST), 45 | transforms.ToTensor(), 46 | ]) 47 | transform_img_16 = transforms.Compose([transforms.Resize((14, 14), Image.NEAREST), 48 | transforms.ToTensor(), 49 | ]) 50 | 51 | transform_img_2 = transforms.Compose([ 52 | transforms.ToTensor(), 53 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 54 | ]) 55 | im1 = transform_img_2(img1) 56 | im2 = transform_img_2(img2) 57 | label0 = transform_img(label) 58 | label_4 = transform_img_4(label) 59 | label_8 = transform_img_8(label) 60 | label_16 = transform_img_16(label) 61 | return im1, im2, label0, label_4, label_8, label_16 62 | 63 | 64 | class MyTestData(Dataset): 65 | def __init__(self): 66 | super(MyTestData, self).__init__() 67 | self.train_im_path1 = 'data/CD_Data_GZ/GZ_test_224/A' 68 | self.train_im_path2 = 'data/CD_Data_GZ/GZ_test_224/B' 69 | self.train_lb_path = 'data/CD_Data_GZ/GZ_test_224/label' 70 | self.train_im_num = 415 71 | self.train_imgs1 = os.listdir(self.train_im_path1) 72 | self.train_imgs2 = os.listdir(self.train_im_path2) 73 | self.train_labels = os.listdir(self.train_lb_path) 74 | 75 | def __len__(self): 76 | return self.train_im_num 77 | 78 | def __getitem__(self, index): 79 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 80 | img1 = Image.open(img_file1) 81 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 82 | img2 = Image.open(img_file2) 83 | label_file = str(self.train_labels[index][:-4]) + '.png' 84 | 85 | im1, im2 = self.transform(img1, img2) 86 | 87 | return im1, im2, label_file 88 | 89 | def transform(self, img1, img2): 90 | transform_img_2 = transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 93 | ]) 94 | im1 = transform_img_2(img1) 95 | im2 = transform_img_2(img2) 96 | return im1, im2 97 | -------------------------------------------------------------------------------- 
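All four dataset modules in `data/` follow the same pattern: `MyData.__getitem__` returns the two co-registered images (ImageNet-normalized) together with the change mask at four scales (224, 56, 28 and 14 pixels), and the masks are inverted (`1. - label`) so that changed pixels become class 0 and the majority unchanged pixels class 1, matching the IoU-loss note in the README. A minimal sketch of consuming one of these datasets; the batch size here is illustrative (the training scripts below use 10), and the hard-coded paths inside `MyData` must already exist:

```python
# Sketch only: assumes data/CD_Data_GZ/GZ_train_224/{A,B,label} are in place.
from torch.utils.data import DataLoader
from data.dataset_swin_GZ import MyData

loader = DataLoader(MyData(), batch_size=4, shuffle=True)
im1, im2, lb0, lb1, lb2, lb3 = next(iter(loader))
# im1, im2: B x 3 x 224 x 224 normalized image pairs
# lb0..lb3: inverted masks at 224 / 56 / 28 / 14 resolution (changed = 0, unchanged = 1)
print(im1.shape, lb0.shape, lb3.shape)
```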
/data/dataset_swin_levir.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | 6 | 7 | class MyData(Dataset): 8 | def __init__(self): 9 | super(MyData, self).__init__() 10 | self.train_im_path1 = 'data/LEVIR-CD/levir224/train_224/A' 11 | self.train_im_path2 = 'data/LEVIR-CD/levir224/train_224/B' 12 | self.train_lb_path = 'data/LEVIR-CD/levir224/train_224/label' 13 | self.train_im_num = 12725 14 | self.train_imgs1 = os.listdir(self.train_im_path1) 15 | self.train_imgs2 = os.listdir(self.train_im_path2) 16 | self.train_labels = os.listdir(self.train_lb_path) 17 | 18 | def __len__(self): 19 | return self.train_im_num 20 | 21 | def __getitem__(self, index): 22 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 23 | img1 = Image.open(img_file1) 24 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 25 | img2 = Image.open(img_file2) 26 | label_file = os.path.join(self.train_lb_path, self.train_labels[index]) 27 | labels = Image.open(label_file) 28 | 29 | im1, im2, lb0, lb1, lb2, lb3 = self.transform(img1, img2, labels) 30 | lb0 = 1. - lb0[0] 31 | lb1 = 1. - lb1[0] 32 | lb2 = 1. - lb2[0] 33 | lb3 = 1. - lb3[0] 34 | return im1, im2, lb0, lb1, lb2, lb3 # ,lb4,lb5,lb6,lb7#,edge0,edge1,edge2,edge3 35 | 36 | def transform(self, img1, img2, label): 37 | transform_img = transforms.Compose([ 38 | transforms.ToTensor(), 39 | ]) 40 | transform_img_4 = transforms.Compose([transforms.Resize((56, 56), Image.NEAREST), 41 | transforms.ToTensor(), 42 | ]) 43 | transform_img_8 = transforms.Compose([transforms.Resize((28, 28), Image.NEAREST), 44 | transforms.ToTensor(), 45 | ]) 46 | transform_img_16 = transforms.Compose([transforms.Resize((14, 14), Image.NEAREST), 47 | transforms.ToTensor(), 48 | ]) 49 | 50 | transform_img_2 = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 53 | ]) 54 | im1 = transform_img_2(img1) 55 | im2 = transform_img_2(img2) 56 | label0 = transform_img(label) 57 | label_4 = transform_img_4(label) 58 | label_8 = transform_img_8(label) 59 | label_16 = transform_img_16(label) 60 | return im1, im2, label0, label_4, label_8, label_16 61 | 62 | 63 | class MyTestData(Dataset): 64 | def __init__(self): 65 | super(MyTestData, self).__init__() 66 | self.train_im_path1 = 'data/LEVIR-CD/levir224/test_224/A' 67 | self.train_im_path2 = 'data/LEVIR-CD/levir224/test_224/B' 68 | self.train_lb_path = 'data/LEVIR-CD/levir224/test_224/label' 69 | self.train_im_num = 3200 70 | self.train_imgs1 = os.listdir(self.train_im_path1) 71 | self.train_imgs2 = os.listdir(self.train_im_path2) 72 | self.train_labels = os.listdir(self.train_lb_path) 73 | 74 | def __len__(self): 75 | return self.train_im_num 76 | 77 | def __getitem__(self, index): 78 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 79 | img1 = Image.open(img_file1) 80 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 81 | img2 = Image.open(img_file2) 82 | label_file = str(self.train_labels[index][:-4]) + '.png' 83 | 84 | im1, im2 = self.transform(img1, img2) 85 | 86 | return im1, im2, label_file 87 | 88 | def transform(self, img1, img2): 89 | transform_img_2 = transforms.Compose([ 90 | transforms.ToTensor(), 91 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 92 | ]) 93 | im1 = transform_img_2(img1) 94 | im2 = 
transform_img_2(img2) 95 | return im1, im2 96 | -------------------------------------------------------------------------------- /data/dataset_swin_sysu.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | 6 | 7 | class MyData(Dataset): 8 | def __init__(self): 9 | super(MyData, self).__init__() 10 | self.train_im_path1 = 'data/SYSU256_train/A' 11 | self.train_im_path2 = 'data/SYSU256_train/B' 12 | self.train_lb_path = 'data/SYSU256_train/label' 13 | self.train_im_num = 16000 14 | self.train_imgs1 = os.listdir(self.train_im_path1) 15 | self.train_imgs2 = os.listdir(self.train_im_path2) 16 | self.train_labels = os.listdir(self.train_lb_path) 17 | 18 | def __len__(self): 19 | return self.train_im_num 20 | 21 | def __getitem__(self, index): 22 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 23 | img1 = Image.open(img_file1) 24 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 25 | img2 = Image.open(img_file2) 26 | label_file = os.path.join(self.train_lb_path, self.train_labels[index]) 27 | labels = Image.open(label_file) 28 | 29 | im1, im2, lb0, lb1, lb2, lb3 = self.transform(img1, img2, labels) 30 | lb0 = 1. - lb0[0] 31 | lb1 = 1. - lb1[0] 32 | lb2 = 1. - lb2[0] 33 | lb3 = 1. - lb3[0] 34 | 35 | return im1, im2, lb0, lb1, lb2, lb3 # ,lb4,lb5,lb6,lb7#,edge0,edge1,edge2,edge3 36 | 37 | def transform(self, img1, img2, label): 38 | transform_img = transforms.Compose([transforms.Resize((224, 224), Image.NEAREST), 39 | transforms.ToTensor(), 40 | ]) 41 | transform_img_4 = transforms.Compose([transforms.Resize((56, 56), Image.NEAREST), 42 | transforms.ToTensor(), 43 | ]) 44 | transform_img_8 = transforms.Compose([transforms.Resize((28, 28), Image.NEAREST), 45 | transforms.ToTensor(), 46 | ]) 47 | transform_img_16 = transforms.Compose([transforms.Resize((14, 14), Image.NEAREST), 48 | transforms.ToTensor(), 49 | ]) 50 | 51 | transform_img_2 = transforms.Compose([ 52 | transforms.Resize((224, 224)), 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 55 | ]) 56 | im1 = transform_img_2(img1) 57 | im2 = transform_img_2(img2) 58 | label0 = transform_img(label) 59 | label_4 = transform_img_4(label) 60 | label_8 = transform_img_8(label) 61 | label_16 = transform_img_16(label) 62 | return im1, im2, label0, label_4, label_8, label_16 63 | 64 | 65 | class MyTestData(Dataset): 66 | def __init__(self): 67 | super(MyTestData, self).__init__() 68 | self.train_im_path1 = 'data/sysu_test256/A256' 69 | self.train_im_path2 = 'data/sysu_test256/B256' 70 | self.train_lb_path = 'data/sysu_test256/label256' 71 | self.train_im_num = 4000 72 | self.train_imgs1 = os.listdir(self.train_im_path1) 73 | self.train_imgs2 = os.listdir(self.train_im_path2) 74 | self.train_labels = os.listdir(self.train_lb_path) 75 | 76 | def __len__(self): 77 | return self.train_im_num 78 | 79 | def __getitem__(self, index): 80 | img_file1 = os.path.join(self.train_im_path1, self.train_imgs1[index]) 81 | img1 = Image.open(img_file1) 82 | img_file2 = os.path.join(self.train_im_path2, self.train_imgs2[index]) 83 | img2 = Image.open(img_file2) 84 | label_file = str(self.train_labels[index][:-4]) + '.png' 85 | im1, im2 = self.transform(img1, img2) 86 | return im1, im2, label_file 87 | 88 | def transform(self, img1, img2): 89 | transform_img_2 = transforms.Compose([ 90 | transforms.Resize((224, 224)), 91 | 
transforms.ToTensor(),
92 |             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
93 |         ])
94 |         im1 = transform_img_2(img1)
95 |         im2 = transform_img_2(img2)
96 |         return im1, im2
97 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fully Transformer Network for Change Detection of Remote Sensing Images
2 | ****
3 | 
4 | Paper Link: [Fully Transformer Network for Change Detection of Remote Sensing Images
5 | ](https://openaccess.thecvf.com/content/ACCV2022/html/Yan_Fully_Transformer_Network_for_Change_Detection_of_Remote_Sensing_Images_ACCV_2022_paper.html)
6 | 
7 | by [Tianyu Yan](), [Zifu Wan](), [Pingping Zhang*](https://scholar.google.com/citations?user=MfbIbuEAAAAJ&hl=zh-CN).
8 | 
9 | ## Introduction
10 | ****
11 | Recently, change detection (CD) of remote sensing images has achieved great progress with the advances of deep learning. However, current methods generally deliver incomplete CD regions and irregular CD boundaries due to the limited representation ability of the extracted visual features. To relieve these issues, in this work we propose a novel learning framework named Fully Transformer Network (FTN) for remote sensing image CD, which improves the feature extraction from a global view and combines multi-level visual features in a pyramid manner. More specifically, the proposed framework first utilizes the advantages of Transformers in long-range dependency modeling. It can help to learn more discriminative global-level features and obtain complete CD regions. Then, we introduce a pyramid structure to aggregate multi-level visual features from Transformers for feature enhancement. The pyramid structure grafted with a Progressive Attention Module (PAM) can improve the feature representation ability with additional interdependencies through channel attentions. Finally, to better train the framework, we utilize deeply-supervised learning with multiple boundary-aware loss functions. Extensive experiments demonstrate that our proposed method achieves new state-of-the-art performance on four public CD benchmarks.
12 | 
13 | ## Update
14 | ****
15 | 
16 | * 03/17/2023: The code has been updated.
17 | 
18 | ## Requirements
19 | ****
20 | * python 3.5+
21 | * PyTorch 1.1+
22 | * torchvision
23 | * Numpy
24 | * tqdm
25 | * OpenCV
26 | 
27 | ## Preparations
28 | ****
29 | 
30 | To use the code, please download the public change detection datasets
31 | (more details are provided in the paper):
32 | * LEVIR-CD
33 | * WHU-CD
34 | * SYSU-CD
35 | * Google-CD
36 | 
37 | The processed datasets can be downloaded at this [link](https://drive.google.com/drive/folders/1Knqdxb6g8_7NFKqeHgnp-iUEMRemGCNh?usp=share_link).
38 | 
39 | Then, run the following code with your GPUs, and you can reproduce the results reported in the paper.
40 | 
41 | 
42 | ## Usage
43 | ****
44 | 
45 | ### 1. Download pre-trained Swin Transformer models
46 | * [Get models in this link](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth): SwinB pre-trained on ImageNet22K
47 | 
48 | 
49 | ### 2. Prepare data
50 | 
51 | * Please use *utils/split_whu.py* to split the images into 224×224 patches first.
52 | * Use *utils/check.py* to check whether the labels are in binary form. Info will be printed if your label format is incorrect.
53 | * Use *utils/bimap.py* if the labels are not binary.
54 | * You may need to move the aforementioned files to the corresponding places.
55 | 
56 | ### 3. Train/Test
57 | 
58 | - For training, run:
59 | 
60 | ```bash
61 | python train_(name of the dataset).py
62 | ```
63 | 
64 | [//]: # (- If you want to use the SSIM and IOU loss functions together with the CrossEntropy loss function, you just need to remove the comments in train.py (below the CrossEntropy loss) and add the loss terms where the loss is calculated.)
65 | [//]: # (- In particular, when you calculate the IOU loss, you need to invert the images (convert 0->1, 1->0), because the image pixel values are mostly 0 and this would influence the IOU loss calculation (based on the characteristics of the IOU loss).)
66 | 
67 | - For prediction, run:
68 | ```bash
69 | python test_swin.py
70 | ```
71 | 
72 | - For evaluation, run:
73 | ```bash
74 | python deal_evaluation.py
75 | ```
76 | 
77 | ## Reference
78 | ****
79 | 
80 | * [Swin Transformer](https://github.com/microsoft/Swin-Transformer)
81 | 
82 | ## Contact
83 | ****
84 | 
85 | If you have any problems, please contact us:
86 | 
87 | QQ: 1580329199
88 | 
89 | Email: tianyuyan2001@gmail.com or wanzifu2000@gmail.com
90 | 
91 | ## Citation
92 | ****
93 | 
94 | If you find our work helpful to your research, please cite:
95 | 
96 | ```bibtex
97 | @InProceedings{Yan_2022_ACCV,
98 |     author = {Yan, Tianyu and Wan, Zifu and Zhang, Pingping},
99 |     title = {Fully Transformer Network for Change Detection of Remote Sensing Images},
100 |     booktitle = {Proceedings of the Asian Conference on Computer Vision (ACCV)},
101 |     month = {December},
102 |     year = {2022},
103 |     pages = {1691-1708}
104 | }
105 | ```
106 | 
--------------------------------------------------------------------------------
/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 | from math import exp
7 | 
8 | def gaussian(window_size, sigma):
9 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 |     return gauss/gauss.sum()
11 | 
12 | def create_window(window_size, channel):
13 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
16 |     return window
17 | 
18 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
19 |     mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
20 |     mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
21 | 
22 |     mu1_sq = mu1.pow(2)
23 |     mu2_sq = mu2.pow(2)
24 |     mu1_mu2 = mu1*mu2
25 | 
26 |     sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
27 |     sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
28 |     sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
29 | 
30 |     C1 = 0.01**2
31 |     C2 = 0.03**2
32 | 
33 |     ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
34 | 
35 |     if size_average:
36 |         return ssim_map.mean()
37 |     else:
38 |         return ssim_map.mean(1).mean(1).mean(1)
39 | 
40 | class SSIM(torch.nn.Module):
41 |     def __init__(self, window_size = 11, size_average = True):
42 |         super(SSIM, self).__init__()
43 
| self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | img1 = img1[:,1,:,:].unsqueeze(1) 50 | img2 = img2.unsqueeze(1) 51 | (_, channel, _, _) = img1.size() 52 | 53 | if channel == self.channel and self.window.data.type() == img1.data.type(): 54 | window = self.window 55 | else: 56 | window = create_window(self.window_size, channel) 57 | 58 | if img1.is_cuda: 59 | window = window.cuda(img1.get_device()) 60 | window = window.type_as(img1) 61 | 62 | self.window = window 63 | self.channel = channel 64 | 65 | 66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 67 | 68 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 69 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 70 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 71 | 72 | mu1_sq = mu1.pow(2) 73 | mu2_sq = mu2.pow(2) 74 | mu1_mu2 = mu1*mu2 75 | 76 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 77 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 78 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 79 | 80 | C1 = 0.01**2 81 | C2 = 0.03**2 82 | 83 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 84 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 85 | ssim_map = -torch.log(ssim_map + 1e-8) 86 | 87 | if size_average: 88 | return ssim_map.mean() 89 | else: 90 | return ssim_map.mean(1).mean(1).mean(1) 91 | 92 | class LOGSSIM(torch.nn.Module): 93 | def __init__(self, window_size = 11, size_average = True): 94 | super(LOGSSIM, self).__init__() 95 | self.window_size = window_size 96 | self.size_average = size_average 97 | self.channel = 1 98 | self.window = create_window(window_size, self.channel) 99 | 100 | def forward(self, img1, img2): 101 | (_, channel, _, _) = img1.size() 102 | 103 | if channel == self.channel and self.window.data.type() == img1.data.type(): 104 | window = self.window 105 | else: 106 | window = create_window(self.window_size, channel) 107 | 108 | if img1.is_cuda: 109 | window = window.cuda(img1.get_device()) 110 | window = window.type_as(img1) 111 | 112 | self.window = window 113 | self.channel = channel 114 | 115 | 116 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average) 117 | 118 | 119 | def ssim(img1, img2, window_size = 11, size_average = True): 120 | (_, channel, _, _) = img1.size() 121 | window = create_window(window_size, channel) 122 | 123 | if img1.is_cuda: 124 | window = window.cuda(img1.get_device()) 125 | window = window.type_as(img1) 126 | 127 | return _ssim(img1, img2, window, window_size, channel, size_average) 128 | -------------------------------------------------------------------------------- /train_swin_gz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tqdm 4 | from my_scheduler import LR_Scheduler 5 | from swin_ynet import Encoder 6 | from data.dataset_swin_GZ import MyData 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | import warnings 10 | 11 | warnings.filterwarnings("ignore") 12 | model = Encoder().cuda() 13 | 14 | import pytorch_iou 15 | import pytorch_ssim 16 | 17 | 
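# `deal` applies a channel-wise softmax to the model's 2-channel change logits.
# In the loss computation below, nn.CrossEntropyLoss takes the raw logits with the
# inverted masks as class indices, while the SSIM and IoU terms take the softmax
# probabilities produced by deal(...).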
deal = nn.Softmax(dim=1) 18 | 19 | 20 | def all_loss(pred, gt): 21 | ce_loss = nn.CrossEntropyLoss() 22 | ssim_loss = pytorch_ssim.SSIM(window_size=11, size_average=True).cuda() 23 | iou_loss = pytorch_iou.IOU().cuda() 24 | ce_out = ce_loss(pred, gt.long()) 25 | ssim_out = 1 - ssim_loss(deal(pred), gt) 26 | iou_out = iou_loss(deal(pred), gt) 27 | loss = ce_out + ssim_out + iou_out 28 | return loss 29 | 30 | 31 | model = model.train() 32 | ce_loss = nn.CrossEntropyLoss() 33 | ssim_loss = pytorch_ssim.SSIM(window_size=7, size_average=True).cuda() 34 | iou_loss = pytorch_iou.IOU().cuda() 35 | LR = 0.01 36 | LR_VGG = 0.00001 37 | EPOCH = 80 38 | scheduler = LR_Scheduler('cos', LR, EPOCH, 3743 // 10 + 1) 39 | 40 | optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=0.0005, nesterov=False) 41 | 42 | 43 | def make_optimizer(LR, model): 44 | params = [] 45 | for key, value in model.named_parameters(): 46 | if not value.requires_grad: 47 | continue 48 | if "encoder1" in key: 49 | lr = LR * 0.1 50 | else: 51 | lr = LR 52 | params += [{"params": [value], "lr": lr}] 53 | optimizer = getattr(torch.optim, "SGD")(params, momentum=0.9, weight_decay=0.0005, nesterov=False) 54 | return optimizer 55 | 56 | 57 | train_loader = DataLoader(MyData(), 58 | shuffle=True, 59 | batch_size=10, 60 | pin_memory=True, 61 | num_workers=16, 62 | ) 63 | 64 | losses0 = 0 65 | losses1 = 0 66 | losses2 = 0 67 | losses3 = 0 68 | losses4 = 0 69 | losses5 = 0 70 | losses6 = 0 71 | losses7 = 0 72 | losses8 = 0 73 | losses9 = 0 74 | losses10 = 0 75 | losses11 = 0 76 | 77 | print(len(train_loader)) 78 | 79 | 80 | def adjust_learning_rate(optimizer, epoch, start_lr): 81 | if epoch % 20 == 0: # epoch != 0 and 82 | for param_group in optimizer.param_groups: 83 | param_group["lr"] = param_group["lr"] * 0.1 84 | print(param_group["lr"]) 85 | 86 | 87 | loss_least = 100000 88 | for epoch_num in range(EPOCH): 89 | print(epoch_num) 90 | adjust_learning_rate(optimizer, epoch_num, LR) 91 | print('LR is:', optimizer.state_dict()['param_groups'][0]['lr']) 92 | show_dict = {'epoch': epoch_num} 93 | 94 | loss_all = 0 95 | for i_batch, (im1, im2, label0, label1, label2, label3) in enumerate( 96 | tqdm.tqdm(train_loader, ncols=60, postfix=show_dict)): # ,edge0,edge1,edge2,edge3 97 | im1 = im1.cuda() 98 | im2 = im2.cuda() 99 | label0 = label0.cuda() 100 | label1 = label1.cuda() 101 | label2 = label2.cuda() 102 | label3 = label3.cuda() 103 | 104 | outputs = model(im1, im2) 105 | 106 | loss0 = ce_loss(outputs[0], label0.long()) 107 | loss1 = ce_loss(outputs[1], label1.long()) 108 | loss2 = ce_loss(outputs[2], label2.long()) 109 | loss3 = ce_loss(outputs[3], label3.long()) 110 | 111 | loss4 = 1. - ssim_loss(deal(outputs[0]), label0) 112 | loss5 = 1. - ssim_loss(deal(outputs[1]), label1) 113 | loss6 = 1. - ssim_loss(deal(outputs[2]), label2) 114 | loss7 = 1. 
- ssim_loss(deal(outputs[3]), label3) 115 | 116 | loss8 = iou_loss(deal(outputs[0]), label0) 117 | loss9 = iou_loss(deal(outputs[1]), label1) 118 | loss10 = iou_loss(deal(outputs[2]), label2) 119 | loss11 = iou_loss(deal(outputs[3]), label3) 120 | 121 | loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + loss9 + loss10 + loss11 122 | loss_all += loss 123 | 124 | losses0 += loss0 125 | losses1 += loss1 126 | losses2 += loss2 127 | losses3 += loss3 128 | losses4 += loss4 129 | losses5 += loss5 130 | losses6 += loss6 131 | losses7 += loss7 132 | losses8 += loss8 133 | losses9 += loss9 134 | losses10 += loss10 135 | losses11 += loss11 136 | 137 | optimizer.zero_grad() 138 | loss.backward() 139 | optimizer.step() 140 | if i_batch % 100 == 0: 141 | print(i_batch, '|', 'losses0: {:.3f}'.format(losses0.data), '|', 'losses1: {:.3f}'.format(losses1.data), 142 | '|', 'losses2: {:.3f}'.format(losses2.data), '|', 'losses3: {:.3f}'.format(losses3.data), '|', 143 | 'losses4: {:.3f}'.format(losses4.data), '|', 'losses5: {:.3f}'.format(losses5.data), '|', 144 | 'losses6: {:.3f}'.format(losses6.data), '|', 'losses7: {:.3f}'.format(losses7.data), 145 | 'losses8: {:.3f}'.format(losses8.data), '|', 'losses9: {:.3f}'.format(losses9.data), '|', 146 | 'losses10: {:.3f}'.format(losses10.data), '|', 'losses11: {:.3f}'.format(losses11.data)) 147 | 148 | losses0 = 0 149 | losses1 = 0 150 | losses2 = 0 151 | losses3 = 0 152 | losses4 = 0 153 | losses5 = 0 154 | losses6 = 0 155 | losses7 = 0 156 | losses8 = 0 157 | losses9 = 0 158 | losses10 = 0 159 | losses11 = 0 160 | 161 | if loss_all <= loss_least: 162 | loss_least = loss_all 163 | torch.save(model.state_dict(), 'new_try3.pth') 164 | print('\n', 'epoch:', epoch_num, 'epoch loss:', loss_all) 165 | -------------------------------------------------------------------------------- /train_swin_whu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tqdm 4 | from my_scheduler import LR_Scheduler 5 | from swin_ynet import Encoder 6 | from data.dataset_swin_whu import MyData 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | import warnings 10 | 11 | warnings.filterwarnings("ignore") 12 | model = Encoder().cuda() 13 | 14 | import pytorch_iou 15 | import pytorch_ssim 16 | 17 | deal = nn.Softmax(dim=1) 18 | 19 | 20 | def all_loss(pred, gt): 21 | ce_loss = nn.CrossEntropyLoss() 22 | ssim_loss = pytorch_ssim.SSIM(window_size=11, size_average=True).cuda() 23 | iou_loss = pytorch_iou.IOU().cuda() 24 | ce_out = ce_loss(pred, gt.long()) 25 | ssim_out = 1 - ssim_loss(deal(pred), gt) 26 | iou_out = iou_loss(deal(pred), gt) 27 | loss = ce_out + ssim_out + iou_out 28 | return loss 29 | 30 | 31 | model = model.train() 32 | ce_loss = nn.CrossEntropyLoss() 33 | ssim_loss = pytorch_ssim.SSIM(window_size=7, size_average=True).cuda() 34 | iou_loss = pytorch_iou.IOU().cuda() 35 | LR = 0.01 36 | LR_VGG = 0.00001 37 | EPOCH = 80 38 | scheduler = LR_Scheduler('cos', LR, EPOCH, 8874 // 10 + 1) 39 | 40 | optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=0.0005, nesterov=False) 41 | 42 | 43 | def make_optimizer(LR, model): 44 | params = [] 45 | for key, value in model.named_parameters(): 46 | if not value.requires_grad: 47 | continue 48 | if "encoder1" in key: 49 | lr = LR * 0.1 50 | else: 51 | lr = LR 52 | params += [{"params": [value], "lr": lr}] 53 | optimizer = getattr(torch.optim, "SGD")(params, momentum=0.9, 
weight_decay=0.0005, nesterov=False) 54 | return optimizer 55 | 56 | 57 | train_loader = DataLoader(MyData(), 58 | shuffle=True, 59 | batch_size=10, 60 | pin_memory=True, 61 | num_workers=16, 62 | ) 63 | 64 | losses0 = 0 65 | losses1 = 0 66 | losses2 = 0 67 | losses3 = 0 68 | losses4 = 0 69 | losses5 = 0 70 | losses6 = 0 71 | losses7 = 0 72 | losses8 = 0 73 | losses9 = 0 74 | losses10 = 0 75 | losses11 = 0 76 | 77 | print(len(train_loader)) 78 | 79 | def adjust_learning_rate(optimizer, epoch, start_lr): 80 | if epoch % 20 == 0: # epoch != 0 and 81 | for param_group in optimizer.param_groups: 82 | param_group["lr"] = param_group["lr"] * 0.1 83 | print(param_group["lr"]) 84 | 85 | 86 | loss_least = 100000 87 | for epoch_num in range(EPOCH): 88 | print(epoch_num) 89 | adjust_learning_rate(optimizer, epoch_num, LR) 90 | print('LR is:', optimizer.state_dict()['param_groups'][0]['lr']) 91 | show_dict = {'epoch': epoch_num} 92 | 93 | loss_all = 0 94 | for i_batch, (im1, im2, label0, label1, label2, label3) in enumerate( 95 | tqdm.tqdm(train_loader, ncols=60, postfix=show_dict)): # ,edge0,edge1,edge2,edge3 96 | im1 = im1.cuda() 97 | im2 = im2.cuda() 98 | label0 = label0.cuda() 99 | label1 = label1.cuda() 100 | label2 = label2.cuda() 101 | label3 = label3.cuda() 102 | 103 | outputs = model(im1, im2) 104 | 105 | loss0 = ce_loss(outputs[0], label0.long()) 106 | loss1 = ce_loss(outputs[1], label1.long()) 107 | loss2 = ce_loss(outputs[2], label2.long()) 108 | loss3 = ce_loss(outputs[3], label3.long()) 109 | 110 | loss4 = 1. - ssim_loss(deal(outputs[0]), label0) 111 | loss5 = 1. - ssim_loss(deal(outputs[1]), label1) 112 | loss6 = 1. - ssim_loss(deal(outputs[2]), label2) 113 | loss7 = 1. - ssim_loss(deal(outputs[3]), label3) 114 | 115 | loss8 = iou_loss(deal(outputs[0]), label0) 116 | loss9 = iou_loss(deal(outputs[1]), label1) 117 | loss10 = iou_loss(deal(outputs[2]), label2) 118 | loss11 = iou_loss(deal(outputs[3]), label3) 119 | 120 | loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + loss9 + loss10 + loss11 121 | loss_all += loss 122 | 123 | losses0 += loss0 124 | losses1 += loss1 125 | losses2 += loss2 126 | losses3 += loss3 127 | losses4 += loss4 128 | losses5 += loss5 129 | losses6 += loss6 130 | losses7 += loss7 131 | losses8 += loss8 132 | losses9 += loss9 133 | losses10 += loss10 134 | losses11 += loss11 135 | 136 | optimizer.zero_grad() 137 | 138 | loss.backward() 139 | optimizer.step() 140 | if i_batch % 100 == 0: 141 | print(i_batch, '|', 'losses0: {:.3f}'.format(losses0.data), '|', 'losses1: {:.3f}'.format(losses1.data), 142 | '|', 'losses2: {:.3f}'.format(losses2.data), '|', 'losses3: {:.3f}'.format(losses3.data), '|', 143 | 'losses4: {:.3f}'.format(losses4.data), '|', 'losses5: {:.3f}'.format(losses5.data), '|', 144 | 'losses6: {:.3f}'.format(losses6.data), '|', 'losses7: {:.3f}'.format(losses7.data), 145 | 'losses8: {:.3f}'.format(losses8.data), '|', 'losses9: {:.3f}'.format(losses9.data), '|', 146 | 'losses10: {:.3f}'.format(losses10.data), '|', 'losses11: {:.3f}'.format(losses11.data)) 147 | 148 | losses0 = 0 149 | losses1 = 0 150 | losses2 = 0 151 | losses3 = 0 152 | losses4 = 0 153 | losses5 = 0 154 | losses6 = 0 155 | losses7 = 0 156 | losses8 = 0 157 | losses9 = 0 158 | losses10 = 0 159 | losses11 = 0 160 | 161 | if loss_all <= loss_least: 162 | loss_least = loss_all 163 | torch.save(model.state_dict(), 'whu_swin.pth') 164 | print('\n', 'epoch:', epoch_num, 'epoch loss:', loss_all) 165 | 166 | 
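The script above keeps the checkpoint with the lowest epoch loss as `whu_swin.pth`. `test_swin.py` in the repository root is written for the LEVIR checkpoint and dataset; a minimal sketch of the same inference loop pointed at the WHU test set follows (the output folder name `whu_swin_pred` is illustrative, not part of the repository):

```python
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from data.dataset_swin_whu import MyTestData
from swin_ynet import Encoder

model = Encoder().cuda()
model.load_state_dict(torch.load('whu_swin.pth'))
model.eval()

os.makedirs('whu_swin_pred', exist_ok=True)
with torch.no_grad():
    for im1, im2, name in DataLoader(MyTestData(), batch_size=1, shuffle=False):
        out = model(im1.cuda(), im2.cuda())[0][0]   # full-resolution 2 x 224 x 224 logits
        torchvision.utils.save_image(out[0].unsqueeze(0), 'whu_swin_pred/' + name[0])
```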
-------------------------------------------------------------------------------- /train_swin_sysu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tqdm 4 | from my_scheduler import LR_Scheduler 5 | from swin_ynet import Encoder 6 | from data.dataset_swin_sysu import MyData 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | import warnings 10 | 11 | warnings.filterwarnings("ignore") 12 | model = Encoder().cuda() 13 | 14 | import pytorch_iou 15 | import pytorch_ssim 16 | 17 | deal = nn.Softmax(dim=1) 18 | 19 | 20 | def all_loss(pred, gt): 21 | ce_loss = nn.CrossEntropyLoss() 22 | ssim_loss = pytorch_ssim.SSIM(window_size=11, size_average=True).cuda() 23 | iou_loss = pytorch_iou.IOU().cuda() 24 | ce_out = ce_loss(pred, gt.long()) 25 | ssim_out = 1 - ssim_loss(deal(pred), gt) 26 | iou_out = iou_loss(deal(pred), gt) 27 | loss = ce_out + ssim_out + iou_out 28 | return loss 29 | 30 | 31 | model = model.train() 32 | ce_loss = nn.CrossEntropyLoss() 33 | ssim_loss = pytorch_ssim.SSIM(window_size=7, size_average=True).cuda() 34 | iou_loss = pytorch_iou.IOU().cuda() 35 | LR = 0.01 36 | LR_VGG = 0.00001 37 | EPOCH = 80 38 | scheduler = LR_Scheduler('cos', LR, EPOCH, 16000 // 10) 39 | 40 | optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=0.0005, nesterov=False) 41 | 42 | 43 | def make_optimizer(LR, model): 44 | params = [] 45 | for key, value in model.named_parameters(): 46 | # print(key) 47 | if not value.requires_grad: 48 | continue 49 | if "encoder1" in key: 50 | lr = LR * 0.1 51 | else: 52 | lr = LR 53 | params += [{"params": [value], "lr": lr}] 54 | optimizer = getattr(torch.optim, "SGD")(params, momentum=0.9, weight_decay=0.0005, nesterov=False) 55 | return optimizer 56 | 57 | 58 | train_loader = DataLoader(MyData(), 59 | shuffle=True, 60 | batch_size=10, 61 | pin_memory=True, 62 | num_workers=16, 63 | ) 64 | 65 | losses0 = 0 66 | losses1 = 0 67 | losses2 = 0 68 | losses3 = 0 69 | losses4 = 0 70 | losses5 = 0 71 | losses6 = 0 72 | losses7 = 0 73 | losses8 = 0 74 | losses9 = 0 75 | losses10 = 0 76 | losses11 = 0 77 | 78 | print(len(train_loader)) 79 | 80 | 81 | def adjust_learning_rate(optimizer, epoch, start_lr): 82 | if epoch % 20 == 0: # epoch != 0 and 83 | for param_group in optimizer.param_groups: 84 | param_group["lr"] = param_group["lr"] * 0.1 85 | print(param_group["lr"]) 86 | 87 | 88 | loss_least = 100000 89 | for epoch_num in range(EPOCH): 90 | print(epoch_num) 91 | adjust_learning_rate(optimizer, epoch_num, LR) 92 | print('LR is:', optimizer.state_dict()['param_groups'][0]['lr']) 93 | show_dict = {'epoch': epoch_num} 94 | 95 | loss_all = 0 96 | for i_batch, (im1, im2, label0, label1, label2, label3) in enumerate( 97 | tqdm.tqdm(train_loader, ncols=60, postfix=show_dict)): # ,edge0,edge1,edge2,edge3 98 | im1 = im1.cuda() 99 | im2 = im2.cuda() 100 | label0 = label0.cuda() 101 | label1 = label1.cuda() 102 | label2 = label2.cuda() 103 | label3 = label3.cuda() 104 | 105 | outputs = model(im1, im2) 106 | 107 | loss0 = ce_loss(outputs[0], label0.long()) 108 | loss1 = ce_loss(outputs[1], label1.long()) 109 | loss2 = ce_loss(outputs[2], label2.long()) 110 | loss3 = ce_loss(outputs[3], label3.long()) 111 | 112 | loss4 = 1. - ssim_loss(deal(outputs[0]), label0) 113 | loss5 = 1. - ssim_loss(deal(outputs[1]), label1) 114 | loss6 = 1. - ssim_loss(deal(outputs[2]), label2) 115 | loss7 = 1. 
- ssim_loss(deal(outputs[3]), label3) 116 | 117 | loss8 = iou_loss(deal(outputs[0]), label0) 118 | loss9 = iou_loss(deal(outputs[1]), label1) 119 | loss10 = iou_loss(deal(outputs[2]), label2) 120 | loss11 = iou_loss(deal(outputs[3]), label3) 121 | 122 | loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + loss9 + loss10 + loss11 123 | loss_all += loss 124 | 125 | losses0 += loss0 126 | losses1 += loss1 127 | losses2 += loss2 128 | losses3 += loss3 129 | losses4 += loss4 130 | losses5 += loss5 131 | losses6 += loss6 132 | losses7 += loss7 133 | losses8 += loss8 134 | losses9 += loss9 135 | losses10 += loss10 136 | losses11 += loss11 137 | 138 | optimizer.zero_grad() 139 | loss.backward() 140 | optimizer.step() 141 | if i_batch % 100 == 0: 142 | print(i_batch, '|', 'losses0: {:.3f}'.format(losses0.data), '|', 'losses1: {:.3f}'.format(losses1.data), 143 | '|', 'losses2: {:.3f}'.format(losses2.data), '|', 'losses3: {:.3f}'.format(losses3.data), '|', 144 | 'losses4: {:.3f}'.format(losses4.data), '|', 'losses5: {:.3f}'.format(losses5.data), '|', 145 | 'losses6: {:.3f}'.format(losses6.data), '|', 'losses7: {:.3f}'.format(losses7.data), 146 | 'losses8: {:.3f}'.format(losses8.data), '|', 'losses9: {:.3f}'.format(losses9.data), '|', 147 | 'losses10: {:.3f}'.format(losses10.data), '|', 'losses11: {:.3f}'.format(losses11.data)) 148 | 149 | losses0 = 0 150 | losses1 = 0 151 | losses2 = 0 152 | losses3 = 0 153 | losses4 = 0 154 | losses5 = 0 155 | losses6 = 0 156 | losses7 = 0 157 | losses8 = 0 158 | losses9 = 0 159 | losses10 = 0 160 | losses11 = 0 161 | 162 | if loss_all <= loss_least: 163 | loss_least = loss_all 164 | torch.save(model.state_dict(), 'sysu_swin.pth') 165 | print('\n', 'epoch:', epoch_num, 'epoch loss:', loss_all) 166 | -------------------------------------------------------------------------------- /train_swin_levir.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tqdm 4 | from my_scheduler import LR_Scheduler 5 | from swin_ynet import Encoder 6 | from data.dataset_swin_levir import MyData 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | import warnings 10 | 11 | warnings.filterwarnings("ignore") 12 | model = Encoder().cuda() 13 | 14 | import pytorch_iou 15 | import pytorch_ssim 16 | 17 | deal = nn.Softmax(dim=1) 18 | 19 | 20 | def all_loss(pred, gt): 21 | ce_loss = nn.CrossEntropyLoss() 22 | ssim_loss = pytorch_ssim.SSIM(window_size=11, size_average=True).cuda() 23 | iou_loss = pytorch_iou.IOU().cuda() 24 | ce_out = ce_loss(pred, gt.long()) 25 | ssim_out = 1 - ssim_loss(deal(pred), gt) 26 | iou_out = iou_loss(deal(pred), gt) 27 | loss = ce_out + ssim_out + iou_out 28 | return loss 29 | 30 | 31 | model = model.train() 32 | ce_loss = nn.CrossEntropyLoss() 33 | ssim_loss = pytorch_ssim.SSIM(window_size=7, size_average=True).cuda() 34 | iou_loss = pytorch_iou.IOU().cuda() 35 | LR = 0.01 36 | LR_VGG = 0.00001 37 | EPOCH = 80 38 | scheduler = LR_Scheduler('cos', LR, EPOCH, 12725 // 10 + 1) 39 | 40 | optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=0.0005, nesterov=False) 41 | 42 | 43 | def make_optimizer(LR, model): 44 | params = [] 45 | for key, value in model.named_parameters(): 46 | if not value.requires_grad: 47 | continue 48 | if "encoder1" in key: 49 | lr = LR * 0.1 50 | else: 51 | lr = LR 52 | params += [{"params": [value], "lr": lr}] 53 | optimizer = getattr(torch.optim, "SGD")(params, momentum=0.9, 
weight_decay=0.0005, nesterov=False) 54 | return optimizer 55 | 56 | 57 | train_loader = DataLoader(MyData(), 58 | shuffle=True, 59 | batch_size=10, 60 | pin_memory=True, 61 | num_workers=16, 62 | ) 63 | 64 | losses0 = 0 65 | losses1 = 0 66 | losses2 = 0 67 | losses3 = 0 68 | losses4 = 0 69 | losses5 = 0 70 | losses6 = 0 71 | losses7 = 0 72 | losses8 = 0 73 | losses9 = 0 74 | losses10 = 0 75 | losses11 = 0 76 | 77 | print(len(train_loader)) 78 | 79 | 80 | def adjust_learning_rate(optimizer, epoch, start_lr): 81 | if epoch % 20 == 0: # epoch != 0 and 82 | for param_group in optimizer.param_groups: 83 | param_group["lr"] = param_group["lr"] * 0.1 84 | print(param_group["lr"]) 85 | 86 | 87 | loss_least = 100000 88 | for epoch_num in range(EPOCH): 89 | print(epoch_num) 90 | adjust_learning_rate(optimizer, epoch_num, LR) 91 | print('LR is:', optimizer.state_dict()['param_groups'][0]['lr']) 92 | show_dict = {'epoch': epoch_num} 93 | 94 | loss_all = 0 95 | for i_batch, (im1, im2, label0, label1, label2, label3) in enumerate( 96 | tqdm.tqdm(train_loader, ncols=60, postfix=show_dict)): # ,edge0,edge1,edge2,edge3 97 | im1 = im1.cuda() 98 | im2 = im2.cuda() 99 | label0 = label0.cuda() 100 | label1 = label1.cuda() 101 | label2 = label2.cuda() 102 | label3 = label3.cuda() 103 | 104 | outputs = model(im1, im2) 105 | 106 | loss0 = ce_loss(outputs[0], label0.long()) 107 | loss1 = ce_loss(outputs[1], label1.long()) 108 | loss2 = ce_loss(outputs[2], label2.long()) 109 | loss3 = ce_loss(outputs[3], label3.long()) 110 | 111 | loss4 = 1. - ssim_loss(deal(outputs[0]), label0) 112 | loss5 = 1. - ssim_loss(deal(outputs[1]), label1) 113 | loss6 = 1. - ssim_loss(deal(outputs[2]), label2) 114 | loss7 = 1. - ssim_loss(deal(outputs[3]), label3) 115 | 116 | loss8 = iou_loss(deal(outputs[0]), label0) 117 | loss9 = iou_loss(deal(outputs[1]), label1) 118 | loss10 = iou_loss(deal(outputs[2]), label2) 119 | loss11 = iou_loss(deal(outputs[3]), label3) 120 | 121 | loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + loss9 + loss10 + loss11 122 | loss_all += loss 123 | 124 | losses0 += loss0 125 | losses1 += loss1 126 | losses2 += loss2 127 | losses3 += loss3 128 | losses4 += loss4 129 | losses5 += loss5 130 | losses6 += loss6 131 | losses7 += loss7 132 | losses8 += loss8 133 | losses9 += loss9 134 | losses10 += loss10 135 | losses11 += loss11 136 | 137 | optimizer.zero_grad() 138 | # scheduler(optimizer,i_batch,epoch_num) 139 | loss.backward() 140 | optimizer.step() 141 | if i_batch % 100 == 0: 142 | print(i_batch, '|', 'losses0: {:.3f}'.format(losses0.data), '|', 'losses1: {:.3f}'.format(losses1.data), 143 | '|', 'losses2: {:.3f}'.format(losses2.data), '|', 'losses3: {:.3f}'.format(losses3.data), '|', 144 | 'losses4: {:.3f}'.format(losses4.data), '|', 'losses5: {:.3f}'.format(losses5.data), '|', 145 | 'losses6: {:.3f}'.format(losses6.data), '|', 'losses7: {:.3f}'.format(losses7.data), 146 | 'losses8: {:.3f}'.format(losses8.data), '|', 'losses9: {:.3f}'.format(losses9.data), '|', 147 | 'losses10: {:.3f}'.format(losses10.data), '|', 'losses11: {:.3f}'.format(losses11.data)) 148 | 149 | losses0 = 0 150 | losses1 = 0 151 | losses2 = 0 152 | losses3 = 0 153 | losses4 = 0 154 | losses5 = 0 155 | losses6 = 0 156 | losses7 = 0 157 | losses8 = 0 158 | losses9 = 0 159 | losses10 = 0 160 | losses11 = 0 161 | 162 | if loss_all <= loss_least: 163 | loss_least = loss_all 164 | torch.save(model.state_dict(), 'levir_swin.pth') 165 | print('\n', 'epoch:', epoch_num, 'epoch loss:', loss_all) 166 | 
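All four training scripts construct a cosine `LR_Scheduler` from `my_scheduler.py` but never invoke it (this script keeps the call commented out at line 138) and instead decay the learning rate with the step policy in `adjust_learning_rate`. If the cosine schedule is preferred, a minimal sketch of re-enabling it, based on the `__call__(optimizer, i, epoch)` signature in `my_scheduler.py`; the `adjust_learning_rate` call should then be removed so the two policies do not overlap:

```python
# Inside the batch loop of any train_swin_*.py, before the backward pass:
scheduler(optimizer, i_batch, epoch_num)   # sets this iteration's LR on every param group
optimizer.zero_grad()
loss.backward()
optimizer.step()
```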
-------------------------------------------------------------------------------- /YTYAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | import math 6 | 7 | import math 8 | 9 | 10 | def autopad(k, p=None): # kernel, padding 11 | # Pad to 'same' 12 | if p is None: 13 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad 14 | return p 15 | 16 | 17 | class Conv(nn.Module): 18 | # Standard convolution 19 | def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups 20 | super().__init__() 21 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) 22 | self.bn = nn.BatchNorm2d(c2) 23 | self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 24 | 25 | def forward(self, x): 26 | return self.act(self.bn(self.conv(x))) 27 | 28 | def fuseforward(self, x): 29 | return self.act(self.conv(x)) 30 | 31 | 32 | class CrossConv(nn.Module): 33 | # Cross Convolution Downsample 34 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 35 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 36 | super().__init__() 37 | c_ = int(c2 * e) # hidden channels 38 | self.cv1 = Conv(c1, c_) # , (1, k), (1, s)) 39 | self.cv2 = Conv(c_, c2) # , (k, 1), (s, 1), g=g) 40 | self.add = shortcut and c1 == c2 41 | 42 | def forward(self, x): 43 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 44 | 45 | 46 | class ppattention(nn.Module): 47 | def __init__(self, in_planes, ratio=16): 48 | super().__init__() 49 | self.conv = CrossConv(2 * in_planes, in_planes) 50 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # 输出最后两维1*1 51 | self.max_pool = nn.AdaptiveMaxPool2d(1) 52 | # b,h,w,c --- n*c #n,,c,h,w ---- n*c 53 | self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), 54 | nn.SiLU(), 55 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) 56 | self.sigmoid = nn.Sigmoid() 57 | self.bnnorm = nn.BatchNorm2d(in_planes) 58 | 59 | def forward(self, x): 60 | x = self.conv(x) # 2-->1,CROSS CONV 61 | res = x 62 | avg_out = self.fc(self.avg_pool(x)) 63 | max_out = self.fc(self.max_pool(x)) 64 | out = avg_out + max_out 65 | attn = self.sigmoid(out) 66 | result = x * attn + res 67 | return result 68 | 69 | 70 | class ppattention_wan(nn.Module): 71 | def __init__(self, in_planes, ratio=16): 72 | super().__init__() 73 | self.conv = CrossConv(2 * in_planes, in_planes) 74 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # 输出最后两维1*1 75 | self.max_pool = nn.AdaptiveMaxPool2d(1) 76 | # b,h,w,c --- n*c #n,,c,h,w ---- n*c 77 | self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), 78 | nn.SiLU(), 79 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) 80 | self.fc2 = nn.Conv2d(1, 1, 1, bias=False) 81 | self.sigmoid = nn.Sigmoid() 82 | self.bnnorm = nn.BatchNorm2d(in_planes) 83 | 84 | def forward(self, x): 85 | x = self.conv(x) # 2-->1,CROSS CONV 86 | res = x 87 | avg_out = self.fc(self.avg_pool(x)) 88 | # max_out = self.fc(self.max_pool(x)) 89 | # out = avg_out + max_out 90 | attn = self.sigmoid(avg_out) 91 | x_channel_summation = torch.sum(x, dim=1, keepdim=True) 92 | attn_channel_summation = self.sigmoid(self.fc2(x_channel_summation)) 93 | result = x * attn + res + attn_channel_summation * res 94 | return result 95 | 96 | 97 | class DFE(nn.Module): 98 | def __init__(self, in_planes): 99 | 
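        # DFE applies a 1x1 convolution followed by BatchNorm and SiLU,
        # halving the channel dimension of the incoming feature map (in_planes -> in_planes // 2).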
super().__init__() 100 | self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 2, 1, bias=False), 101 | nn.BatchNorm2d(in_planes // 2), 102 | nn.SiLU()) 103 | 104 | def forward(self, x): 105 | result = self.fc(x) 106 | return result 107 | 108 | 109 | class YTYAttention(nn.Module): 110 | 111 | def __init__(self, channel=512, reduction=16, im_channel=49): 112 | super().__init__() 113 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 114 | self.fc = nn.Sequential( 115 | nn.Linear(channel, channel // reduction, bias=False), 116 | nn.SiLU(), 117 | nn.Linear(channel // reduction, channel, bias=False), 118 | ) 119 | self.sigmoid = nn.Sigmoid() 120 | self.fc1 = nn.Linear(im_channel, 1, bias=False) 121 | self.fc2 = nn.Linear(im_channel, 1, bias=False) 122 | 123 | def init_weights(self): 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | init.kaiming_normal_(m.weight, mode='fan_out') 127 | if m.bias is not None: 128 | init.constant_(m.bias, 0) 129 | elif isinstance(m, nn.BatchNorm2d): 130 | init.constant_(m.weight, 1) 131 | init.constant_(m.bias, 0) 132 | elif isinstance(m, nn.Linear): 133 | init.normal_(m.weight, std=0.001) 134 | if m.bias is not None: 135 | init.constant_(m.bias, 0) 136 | 137 | def forward(self, im1, im2): 138 | # b, c, _, _ = x.size() 139 | img = torch.cat([im1, im2], dim=2) 140 | origin = img 141 | im1 = im1.transpose(1, 2) 142 | im2 = im2.transpose(1, 2) 143 | im1 = self.fc1(im1) # 1,512,1 144 | im2 = self.fc2(im2) # 1,512,1 145 | 146 | im = torch.cat([im1, im2], dim=2) # 1,512,2 147 | im = torch.transpose(im, 1, 2) # 1,2,512 148 | im = self.fc(im) # 1,2,512 149 | im1 = im[:, 0, :].unsqueeze(1) 150 | im2 = im[:, 1, :].unsqueeze(1) 151 | im = torch.cat([im1, im2], dim=2) 152 | im = im.transpose(1, 2) 153 | im = self.sigmoid(im) 154 | img = img.transpose(1, 2) 155 | res = img * im.expand_as(img) 156 | res = res.transpose(1, 2) 157 | return res 158 | 159 | 160 | class TYAttention(nn.Module): 161 | def __init__(self, in_channel=1024, im_channel=49, gamma=2, b=1): 162 | super().__init__() 163 | self.fc1 = nn.Linear(im_channel, 1, bias=False) 164 | self.fc2 = nn.Linear(im_channel, 1, bias=False) 165 | self.SiLU = nn.SiLU() 166 | self.bn = nn.BatchNorm1d(2) 167 | self.gamma = gamma 168 | self.b = b 169 | self.sigmoid = nn.Sigmoid() 170 | self.in_channel = in_channel 171 | t = int(abs((math.log(self.in_channel, 2) + self.b) / self.gamma)) 172 | k = t if t % 2 else t + 1 173 | self.conv = nn.Conv1d(2, 2, kernel_size=k, padding=int(k / 2), bias=False) 174 | self.conv2 = nn.Conv1d(2, 2, kernel_size=k, padding=int(k / 2), bias=False) 175 | 176 | def init_weights(self): 177 | for m in self.modules(): 178 | if isinstance(m, nn.Conv1d): 179 | init.kaiming_normal_(m.weight, mode='fan_out') 180 | if m.bias is not None: 181 | init.constant_(m.bias, 0) 182 | elif isinstance(m, nn.BatchNorm2d): 183 | init.constant_(m.weight, 1) 184 | init.constant_(m.bias, 0) 185 | elif isinstance(m, nn.Linear): 186 | init.normal_(m.weight, std=0.001) 187 | if m.bias is not None: 188 | init.constant_(m.bias, 0) 189 | 190 | def forward(self, im1, im2): 191 | img = torch.cat([im1, im1], dim=2) # B,49,2048 192 | img = img.transpose(1, 2) 193 | origin = img 194 | im1 = im1.transpose(-1, -2) # B,1024,49 195 | im2 = im2.transpose(-1, -2) 196 | 197 | im1 = self.fc1(im1) # B,1024,1 198 | im2 = self.fc2(im2) 199 | im = torch.cat([im1, im2], dim=2) # B,1024,2 200 | im = self.conv(im.transpose(-1, -2)) # B,2,1024 201 | im = self.SiLU(im) 202 | # im = self.bn(im) 203 | im = self.conv2(im) 204 | im1 = im[:, 0, 
:].unsqueeze(1) 205 | im2 = im[:, 1, :].unsqueeze(1) 206 | im = torch.cat([im1, im2], dim=2) # B,1,2048 207 | im = im.transpose(1, 2) 208 | im = self.sigmoid(im) 209 | res = img * im.expand_as(img) + origin 210 | return res.transpose(1, 2) 211 | 212 | 213 | class ChannelAttention(nn.Module): 214 | def __init__(self, in_planes, ratio=16): 215 | super(ChannelAttention, self).__init__() 216 | self.conv = nn.Conv2d(3 * in_planes, in_planes, 1, 1, 0) 217 | self.bn = nn.BatchNorm2d(in_planes) 218 | self.SiLU = nn.SiLU() 219 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # 输出最后两维1*1 220 | self.max_pool = nn.AdaptiveMaxPool2d(1) 221 | 222 | self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), 223 | nn.SiLU(), 224 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) 225 | self.sigmoid = nn.Sigmoid() 226 | self.bnnorm = nn.BatchNorm2d(in_planes) 227 | 228 | def forward(self, x): 229 | x = self.conv(x) # 2-1 230 | x = self.bn(x) 231 | x = self.SiLU(x) 232 | res = x 233 | avg_out = self.fc(self.avg_pool(x)) 234 | max_out = self.fc(self.max_pool(x)) 235 | out = avg_out + max_out 236 | attn = self.sigmoid(out) 237 | out = x * attn + res 238 | return out 239 | 240 | 241 | class ChannelAttention_1(nn.Module): 242 | def __init__(self, in_planes, ratio=16): 243 | super(ChannelAttention_1, self).__init__() 244 | self.conv = nn.Conv2d(2 * in_planes, in_planes, 3, 1, 1) 245 | self.bn = nn.BatchNorm2d(in_planes) 246 | self.SiLU = nn.SiLU() 247 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # 输出最后两维1*1 248 | self.max_pool = nn.AdaptiveMaxPool2d(1) 249 | 250 | self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False), 251 | nn.SiLU(), 252 | nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)) 253 | self.sigmoid = nn.Sigmoid() 254 | 255 | def forward(self, x): 256 | x = self.conv(x) 257 | x = self.bn(x) 258 | x = self.SiLU(x) 259 | res = x 260 | avg_out = self.fc(self.avg_pool(x)) 261 | max_out = self.fc(self.max_pool(x)) 262 | out = avg_out + max_out 263 | result = x * self.sigmoid(out) + res 264 | return result 265 | 266 | 267 | class SpatialAttention(nn.Module): 268 | def __init__(self, kernel_size=7): 269 | super(SpatialAttention, self).__init__() 270 | 271 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) 272 | self.sigmoid = nn.Sigmoid() 273 | 274 | def forward(self, x): 275 | avg_out = torch.mean(x, dim=1, keepdim=True) 276 | max_out, _ = torch.max(x, dim=1, keepdim=True) 277 | x = torch.cat([avg_out, max_out], dim=1) 278 | x = self.conv1(x) 279 | return self.sigmoid(x) 280 | 281 | 282 | class CBAM(nn.Module): 283 | def __init__(self, in_planes): 284 | super(CBAM, self).__init__() 285 | self.channelattention = ChannelAttention_1(in_planes) 286 | self.spatialattention = SpatialAttention() 287 | 288 | def transpose(self, x): 289 | B, HW, C = x.size() 290 | H = int(math.sqrt(HW)) 291 | x = x.transpose(1, 2) 292 | x = x.view(B, C, H, H) 293 | return x 294 | 295 | def transpose_verse(self, x): 296 | B, C, H, W = x.size() 297 | HW = H * W 298 | x = x.view(B, C, HW) 299 | x = x.transpose(1, 2) 300 | return x 301 | 302 | def forward(self, x1, x2): 303 | x1 = self.transpose(x1) 304 | x2 = self.transpose(x2) 305 | x = torch.cat([x1, x2], dim=1) 306 | 307 | channel_attn = self.channelattention(x) 308 | x1 = x1 * channel_attn 309 | x2 = x2 * channel_attn 310 | x1 = self.transpose_verse(x1) 311 | x2 = self.transpose_verse(x2) 312 | return x1, x2 313 | 314 | 315 | class MixConv2d(nn.Module): 316 | # Mixed Depthwise Conv 
https://arxiv.org/abs/1907.09595 317 | def __init__(self, c1, c2, k=(1, 3, 5, 7), s=1, equal_ch=True): 318 | super().__init__() 319 | groups = len(k) 320 | if equal_ch: # equal c_ per group 321 | i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices 322 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 323 | else: # equal weight.numel() per group 324 | b = [c2] + [0] * groups 325 | a = np.eye(groups + 1, groups, k=-1) 326 | a -= np.roll(a, 1, axis=1) 327 | a *= np.array(k) ** 2 328 | a[0] = 1 329 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 330 | 331 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 332 | self.bn = nn.BatchNorm2d(c2) 333 | self.act = nn.SiLU() 334 | 335 | def forward(self, x): 336 | return self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 337 | -------------------------------------------------------------------------------- /swin_ynet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.checkpoint as checkpoint 3 | from einops import rearrange 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | import copy 6 | from YTYAttention import * 7 | 8 | 9 | class Mlp(nn.Module): 10 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 11 | super().__init__() 12 | out_features = out_features or in_features 13 | hidden_features = hidden_features or in_features 14 | self.fc1 = nn.Linear(in_features, hidden_features) 15 | self.act = act_layer() 16 | self.fc2 = nn.Linear(hidden_features, out_features) 17 | self.drop = nn.Dropout(drop) 18 | 19 | def forward(self, x): 20 | x = self.fc1(x) 21 | x = self.act(x) 22 | x = self.drop(x) 23 | x = self.fc2(x) 24 | x = self.drop(x) 25 | return x 26 | 27 | 28 | def window_partition(x, window_size): 29 | """ 30 | Args: 31 | x: (B, H, W, C) 32 | window_size (int): window size 33 | 34 | Returns: 35 | windows: (num_windows*B, window_size, window_size, C) 36 | """ 37 | B, H, W, C = x.shape 38 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) # 如果换成384输入,这里无法被整除 39 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 40 | return windows 41 | 42 | 43 | def window_reverse(windows, window_size, H, W): 44 | """ 45 | Args: 46 | windows: (num_windows*B, window_size, window_size, C) 47 | window_size (int): Window size 48 | H (int): Height of image 49 | W (int): Width of image 50 | 51 | Returns: 52 | x: (B, H, W, C) 53 | """ 54 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 55 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 56 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 57 | return x 58 | 59 | 60 | class WindowAttention(nn.Module): 61 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 62 | It supports both of shifted and non-shifted window. 63 | 64 | Args: 65 | dim (int): Number of input channels. 66 | window_size (tuple[int]): The height and width of the window. 67 | num_heads (int): Number of attention heads. 68 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 69 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 70 | attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 71 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 72 | """ 73 | 74 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 75 | 76 | super().__init__() 77 | self.dim = dim 78 | self.window_size = window_size # Wh, Ww 79 | self.num_heads = num_heads 80 | head_dim = dim // num_heads 81 | self.scale = qk_scale or head_dim ** -0.5 82 | 83 | # define a parameter table of relative position bias 84 | self.relative_position_bias_table = nn.Parameter( 85 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 86 | 87 | # get pair-wise relative position index for each token inside the window 88 | coords_h = torch.arange(self.window_size[0]) 89 | coords_w = torch.arange(self.window_size[1]) 90 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 91 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 92 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 93 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 94 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 95 | relative_coords[:, :, 1] += self.window_size[1] - 1 96 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 97 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 98 | self.register_buffer("relative_position_index", relative_position_index) 99 | 100 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 101 | self.attn_drop = nn.Dropout(attn_drop) 102 | self.proj = nn.Linear(dim, dim) 103 | self.proj_drop = nn.Dropout(proj_drop) 104 | 105 | trunc_normal_(self.relative_position_bias_table, std=.02) 106 | self.softmax = nn.Softmax(dim=-1) 107 | 108 | def forward(self, x, mask=None): 109 | """ 110 | Args: 111 | x: input features with shape of (num_windows*B, N, C) 112 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 113 | """ 114 | B_, N, C = x.shape 115 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 116 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 117 | 118 | q = q * self.scale 119 | attn = (q @ k.transpose(-2, -1)) 120 | 121 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 122 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 123 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 124 | attn = attn + relative_position_bias.unsqueeze(0) 125 | 126 | if mask is not None: 127 | nW = mask.shape[0] 128 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 129 | attn = attn.view(-1, self.num_heads, N, N) 130 | attn = self.softmax(attn) 131 | else: 132 | attn = self.softmax(attn) 133 | 134 | attn = self.attn_drop(attn) 135 | 136 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 137 | x = self.proj(x) 138 | x = self.proj_drop(x) 139 | return x 140 | 141 | def extra_repr(self) -> str: 142 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 143 | 144 | def flops(self, N): 145 | # calculate flops for 1 window with token length of N 146 | flops = 0 147 | # qkv = self.qkv(x) 148 | flops += N * self.dim * 3 * self.dim 149 | # attn = (q @ k.transpose(-2, -1)) 150 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 151 | # x = 
(attn @ v) 152 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 153 | # x = self.proj(x) 154 | flops += N * self.dim * self.dim 155 | return flops 156 | 157 | 158 | class SwinTransformerBlock(nn.Module): 159 | r""" Swin Transformer Block. 160 | 161 | Args: 162 | dim (int): Number of input channels. 163 | input_resolution (tuple[int]): Input resulotion. 164 | num_heads (int): Number of attention heads. 165 | window_size (int): Window size. 166 | shift_size (int): Shift size for SW-MSA. 167 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 168 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 169 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 170 | drop (float, optional): Dropout rate. Default: 0.0 171 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 172 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 173 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 174 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 175 | """ 176 | 177 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 178 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 179 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 180 | super().__init__() 181 | self.dim = dim 182 | self.input_resolution = input_resolution 183 | self.num_heads = num_heads 184 | self.window_size = window_size 185 | self.shift_size = shift_size 186 | self.mlp_ratio = mlp_ratio 187 | if min(self.input_resolution) <= self.window_size: 188 | # if window size is larger than input resolution, we don't partition windows 189 | self.shift_size = 0 190 | self.window_size = min(self.input_resolution) 191 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 192 | 193 | self.norm1 = norm_layer(dim) 194 | self.attn = WindowAttention( 195 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 196 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 197 | 198 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 199 | self.norm2 = norm_layer(dim) 200 | mlp_hidden_dim = int(dim * mlp_ratio) 201 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 202 | 203 | if self.shift_size > 0: 204 | # calculate attention mask for SW-MSA 205 | H, W = self.input_resolution 206 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 207 | h_slices = (slice(0, -self.window_size), 208 | slice(-self.window_size, -self.shift_size), 209 | slice(-self.shift_size, None)) 210 | w_slices = (slice(0, -self.window_size), 211 | slice(-self.window_size, -self.shift_size), 212 | slice(-self.shift_size, None)) 213 | cnt = 0 214 | for h in h_slices: 215 | for w in w_slices: 216 | img_mask[:, h, w, :] = cnt 217 | cnt += 1 218 | 219 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 220 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 221 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 222 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 223 | else: 224 | attn_mask = None 225 | 226 | self.register_buffer("attn_mask", attn_mask) 227 | 228 | def forward(self, x): 229 | H, W = self.input_resolution 230 | B, L, C = x.shape 231 | assert L == H * W, "input feature has wrong size" 232 | 233 | shortcut = x 234 | x = self.norm1(x) 235 | x = x.view(B, H, W, C) 236 | 237 | # cyclic shift 238 | if self.shift_size > 0: 239 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 240 | else: 241 | shifted_x = x 242 | 243 | # partition windows 244 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 245 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 246 | 247 | # W-MSA/SW-MSA 248 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 249 | 250 | # merge windows 251 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 252 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 253 | 254 | # reverse cyclic shift 255 | if self.shift_size > 0: 256 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 257 | else: 258 | x = shifted_x 259 | x = x.view(B, H * W, C) 260 | 261 | # FFN 262 | x = shortcut + self.drop_path(x) 263 | x = x + self.drop_path(self.mlp(self.norm2(x))) 264 | 265 | return x 266 | 267 | def extra_repr(self) -> str: 268 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 269 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 270 | 271 | def flops(self): 272 | flops = 0 273 | H, W = self.input_resolution 274 | # norm1 275 | flops += self.dim * H * W 276 | # W-MSA/SW-MSA 277 | nW = H * W / self.window_size / self.window_size 278 | flops += nW * self.attn.flops(self.window_size * self.window_size) 279 | # mlp 280 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 281 | # norm2 282 | flops += self.dim * H * W 283 | return flops 284 | 285 | 286 | class PatchEmbed(nn.Module): 287 | r""" Image to Patch Embedding 288 | 289 | Args: 290 | img_size (int): Image size. Default: 224. 291 | patch_size (int): Patch token size. Default: 4. 292 | in_chans (int): Number of input image channels. Default: 3. 293 | embed_dim (int): Number of linear projection output channels. Default: 96. 
294 | norm_layer (nn.Module, optional): Normalization layer. Default: None
295 | """
296 | 
297 | def __init__(self, img_size=256, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
298 | super().__init__()
299 | img_size = to_2tuple(img_size) # e.g. 224 -> (224, 224)
300 | patch_size = to_2tuple(patch_size)
301 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # integer division
302 | self.img_size = img_size
303 | self.patch_size = patch_size
304 | self.patches_resolution = patches_resolution
305 | self.num_patches = patches_resolution[0] * patches_resolution[1]
306 | 
307 | self.in_chans = in_chans
308 | self.embed_dim = embed_dim # 96 in the Swin-T / Swin-S configurations
309 | 
310 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
311 | if norm_layer is not None:
312 | self.norm = norm_layer(embed_dim)
313 | else:
314 | self.norm = None
315 | 
316 | def forward(self, x):
317 | B, C, H, W = x.shape
318 | # FIXME look at relaxing size constraints
319 | assert H == self.img_size[0] and W == self.img_size[1], \
320 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
321 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
322 | if self.norm is not None:
323 | x = self.norm(x)
324 | return x
325 | 
326 | def flops(self):
327 | Ho, Wo = self.patches_resolution
328 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
329 | if self.norm is not None:
330 | flops += Ho * Wo * self.embed_dim
331 | return flops
332 | 
333 | 
334 | class PatchMerging(nn.Module):
335 | r""" Patch Merging Layer.
336 | 
337 | Args:
338 | input_resolution (tuple[int]): Resolution of input feature.
339 | dim (int): Number of input channels.
340 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
341 | """
342 | 
343 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
344 | super().__init__()
345 | self.input_resolution = input_resolution
346 | self.dim = dim
347 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
348 | self.norm = norm_layer(4 * dim)
349 | 
350 | def forward(self, x):
351 | """
352 | x: B, H*W, C
353 | """
354 | H, W = self.input_resolution
355 | B, L, C = x.shape
356 | assert L == H * W, "input feature has wrong size"
357 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
358 | 359 | x = x.view(B, H, W, C) 360 | 361 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 362 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 363 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 364 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 365 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 366 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 367 | 368 | x = self.norm(x) 369 | x = self.reduction(x) 370 | 371 | return x 372 | 373 | def extra_repr(self) -> str: 374 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 375 | 376 | def flops(self): 377 | H, W = self.input_resolution 378 | flops = H * W * self.dim 379 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 380 | return flops 381 | 382 | 383 | class PatchExpand(nn.Module): 384 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 385 | super().__init__() 386 | self.input_resolution = input_resolution 387 | self.dim = dim 388 | self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() 389 | self.norm = norm_layer(dim // dim_scale) 390 | 391 | def forward(self, x): 392 | """ 393 | x: B, H*W, C 394 | """ 395 | H, W = self.input_resolution 396 | x = self.expand(x) 397 | B, L, C = x.shape 398 | assert L == H * W, "input feature has wrong size" 399 | 400 | x = x.view(B, H, W, C) 401 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) 402 | x = x.view(B, -1, C // 4) 403 | x = self.norm(x) 404 | 405 | return x 406 | 407 | 408 | class FinalPatchExpand_X4(nn.Module): 409 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm, patchsize=4): 410 | super().__init__() 411 | self.input_resolution = input_resolution 412 | self.dim = dim 413 | self.dim_scale = dim_scale 414 | self.expand = nn.Linear(dim, patchsize * patchsize * dim, bias=False) 415 | self.output_dim = dim 416 | self.norm = norm_layer(self.output_dim) 417 | 418 | def forward(self, x): 419 | """ 420 | x: B, H*W, C 421 | """ 422 | H, W = self.input_resolution 423 | x = self.expand(x) 424 | B, L, C = x.shape 425 | assert L == H * W, "input feature has wrong size" 426 | 427 | x = x.view(B, H, W, C) 428 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, 429 | c=C // (self.dim_scale ** 2)) 430 | x = x.view(B, -1, self.output_dim) 431 | x = self.norm(x) 432 | 433 | return x 434 | 435 | 436 | class FinalPatchExpand_X4_1(nn.Module): 437 | def __init__(self, input_resolution, dim, dim_scale=1, norm_layer=nn.LayerNorm, patchsize=1): 438 | super().__init__() 439 | self.input_resolution = input_resolution 440 | self.dim = dim 441 | self.dim_scale = dim_scale 442 | # self.expand = nn.Linear(dim, patchsize*patchsize*dim, bias=False) 443 | self.output_dim = dim 444 | self.norm = norm_layer(self.output_dim) 445 | 446 | def forward(self, x): 447 | """ 448 | x: B, H*W, C 449 | """ 450 | H, W = self.input_resolution 451 | # x = self.expand(x) 452 | B, L, C = x.shape 453 | assert L == H * W, "input feature has wrong size" 454 | 455 | x = x.view(B, H, W, C) 456 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, 457 | c=C // (self.dim_scale ** 2)) 458 | x = x.view(B, -1, self.output_dim) 459 | x = self.norm(x) 460 | 461 | return x 462 | 463 | 464 | class BasicLayer(nn.Module): 465 | """ A basic Swin Transformer layer for one stage. 466 | 467 | Args: 468 | dim (int): Number of input channels. 469 | input_resolution (tuple[int]): Input resolution. 470 | depth (int): Number of blocks. 
471 | num_heads (int): Number of attention heads. 472 | window_size (int): Local window size. 473 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 474 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 475 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 476 | drop (float, optional): Dropout rate. Default: 0.0 477 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 478 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 479 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 480 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 481 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 482 | """ 483 | 484 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 485 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 486 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 487 | 488 | super().__init__() 489 | self.dim = dim 490 | self.input_resolution = input_resolution 491 | self.depth = depth 492 | self.use_checkpoint = use_checkpoint 493 | 494 | # build blocks 495 | self.blocks = nn.ModuleList([ 496 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 497 | num_heads=num_heads, window_size=window_size, 498 | shift_size=0 if (i % 2 == 0) else window_size // 2, # 相邻两个trans blocks, 一个shift,一个不用 499 | mlp_ratio=mlp_ratio, 500 | qkv_bias=qkv_bias, qk_scale=qk_scale, 501 | drop=drop, attn_drop=attn_drop, 502 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 503 | norm_layer=norm_layer) 504 | for i in range(depth)]) 505 | 506 | # patch merging layer 507 | if downsample is not None: 508 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 509 | else: 510 | self.downsample = None 511 | 512 | def forward(self, x): 513 | for blk in self.blocks: 514 | if self.use_checkpoint: 515 | x = checkpoint.checkpoint(blk, x) 516 | else: 517 | x = blk(x) 518 | if self.downsample is not None: 519 | x = self.downsample(x) 520 | return x 521 | 522 | def extra_repr(self) -> str: 523 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 524 | 525 | def flops(self): 526 | flops = 0 527 | for blk in self.blocks: 528 | flops += blk.flops() 529 | if self.downsample is not None: 530 | flops += self.downsample.flops() 531 | return flops 532 | 533 | 534 | class BasicLayer_up(nn.Module): 535 | """ A basic Swin Transformer layer for one stage. 536 | 537 | Args: 538 | dim (int): Number of input channels. 539 | input_resolution (tuple[int]): Input resolution. 540 | depth (int): Number of blocks. 541 | num_heads (int): Number of attention heads. 542 | window_size (int): Local window size. 543 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 544 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 545 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 546 | drop (float, optional): Dropout rate. Default: 0.0 547 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 548 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 549 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 550 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 551 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 552 | """ 553 | 554 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 555 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 556 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 557 | 558 | super().__init__() 559 | self.dim = dim 560 | self.input_resolution = input_resolution 561 | self.depth = depth 562 | self.use_checkpoint = use_checkpoint 563 | 564 | # build blocks 565 | self.blocks = nn.ModuleList([ 566 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 567 | num_heads=num_heads, window_size=window_size, 568 | shift_size=0 if (i % 2 == 0) else window_size // 2, 569 | mlp_ratio=mlp_ratio, 570 | qkv_bias=qkv_bias, qk_scale=qk_scale, 571 | drop=drop, attn_drop=attn_drop, 572 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 573 | norm_layer=norm_layer) 574 | for i in range(depth)]) 575 | 576 | # patch merging layer 577 | if upsample is not None: 578 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 579 | else: 580 | self.upsample = None 581 | 582 | def forward(self, x): 583 | for blk in self.blocks: 584 | if self.use_checkpoint: 585 | x = checkpoint.checkpoint(blk, x) 586 | else: 587 | x = blk(x) 588 | if self.upsample is not None: 589 | x = self.upsample(x) 590 | return x 591 | 592 | 593 | class SwinTransEncoder(nn.Module): 594 | r""" Swin Transformer 595 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 596 | https://arxiv.org/pdf/2103.14030 597 | 598 | Args: 599 | img_size (int | tuple(int)): Input image size. Default 224 600 | patch_size (int | tuple(int)): Patch size. Default: 4 601 | in_chans (int): Number of input image channels. Default: 3 602 | num_classes (int): Number of classes for classification head. Default: 1000 603 | embed_dim (int): Patch embedding dimension. Default: 96 604 | depths (tuple(int)): Depth of each Swin Transformer layer. 605 | num_heads (tuple(int)): Number of attention heads in different layers. 606 | window_size (int): Window size. Default: 7 607 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 608 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 609 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 610 | drop_rate (float): Dropout rate. Default: 0 611 | attn_drop_rate (float): Attention dropout rate. Default: 0 612 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 613 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 614 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 615 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 616 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 617 | """ 618 | 619 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=2, 620 | embed_dim=192, depths=[2, 2, 18, 2], depths_decoder=[4, 4, 4, 4], num_heads=[6, 12, 24, 48], 621 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 622 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 623 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 624 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 625 | super().__init__() 626 | 627 | print( 628 | "SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format( 629 | depths, 630 | depths_decoder, drop_path_rate, num_classes)) 631 | 632 | self.num_classes = num_classes 633 | self.num_layers = len(depths) 634 | self.embed_dim = embed_dim 635 | self.ape = ape 636 | self.patch_norm = patch_norm 637 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 638 | self.num_features_up = int(embed_dim * 2) 639 | self.mlp_ratio = mlp_ratio 640 | self.final_upsample = final_upsample 641 | 642 | # split image into non-overlapping patches 643 | self.patch_embed = PatchEmbed( 644 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 645 | norm_layer=norm_layer if self.patch_norm else None) 646 | num_patches = self.patch_embed.num_patches 647 | patches_resolution = self.patch_embed.patches_resolution 648 | self.patches_resolution = patches_resolution 649 | 650 | # absolute position embedding 651 | if self.ape: 652 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 653 | trunc_normal_(self.absolute_pos_embed, std=.02) 654 | 655 | self.pos_drop = nn.Dropout(p=drop_rate) 656 | 657 | self.deal = nn.ModuleList() 658 | self.deal.append(nn.LayerNorm(128)) 659 | self.deal.append(nn.LayerNorm(256)) 660 | self.deal.append(nn.LayerNorm(512)) 661 | self.deal.append(nn.LayerNorm(1024)) 662 | 663 | # stochastic depth 664 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 665 | 666 | # build encoder and bottleneck layers 667 | self.layers = nn.ModuleList() 668 | for i_layer in range(self.num_layers): 669 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 670 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 671 | patches_resolution[1] // (2 ** i_layer)), 672 | depth=depths[i_layer], 673 | num_heads=num_heads[i_layer], 674 | window_size=window_size, 675 | mlp_ratio=self.mlp_ratio, 676 | qkv_bias=qkv_bias, qk_scale=qk_scale, 677 | drop=drop_rate, attn_drop=attn_drop_rate, 678 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 679 | norm_layer=norm_layer, 680 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 681 | use_checkpoint=use_checkpoint) 682 | self.layers.append(layer) 683 | 684 | self.norm = norm_layer(self.num_features) 685 | self.fusion = ChannelAttention_1(1024) 686 | 687 | # Encoder and Bottleneck 688 | def transpose(self, x): 689 | B, HW, C = x.size() 690 | H = int(math.sqrt(HW)) 691 | x = x.transpose(1, 2) 692 | x = x.view(B, C, H, H) 693 | return x 694 | 695 | def transpose_verse(self, x): 696 | B, C, H, W = x.size() 697 | x = x.view(B, C, -1) 698 | x = x.transpose(1, 2) 699 | return x 700 | 701 | def forward(self, x1, x2): 702 | # input [1,3,224,224] 703 | x1 = self.patch_embed(x1) # [1,56x56,96(embedding dim)] 704 | x2 = self.patch_embed(x2) # [1,56x56,96(embedding dim)] 705 | if self.ape: 706 | x1 = x1 + self.absolute_pos_embed 707 | x2 = x2 + self.absolute_pos_embed 708 | x1 = 
self.pos_drop(x1) 709 | x2 = self.pos_drop(x2) 710 | x1_downsample = [] 711 | x2_downsample = [] 712 | 713 | for inx, layer in enumerate(self.layers): 714 | if inx != 3: 715 | x1_downsample.append(self.deal[inx](x1)) # self.deal[inx](x)) #self.deal[inx] 716 | x2_downsample.append(self.deal[inx](x2)) 717 | x1 = layer(x1) # ??norm( mlp(layer_norm(self-attention(x))) 718 | x2 = layer(x2) 719 | else: 720 | x1_downsample.append(self.deal[inx](x1)) # self.deal[inx](x)) #self.deal[inx] 721 | x2_downsample.append(self.deal[inx](x2)) 722 | x_mid = self.transpose_verse(self.fusion(self.transpose(torch.cat([x1, x2], dim=2)))) 723 | x_mid = layer(x_mid) 724 | # x_mid = torch.cat([x1,x2],dim=2) 725 | # x1 = self.norm(x1) 726 | # x2 = self.norm(x2) 727 | # x_mid = self.transpose_verse(self.fusion(self.transpose(torch.cat([x1,x2],dim=2)))) 728 | x_mid = self.norm(x_mid) # B L C --1 49 768] 729 | 730 | return x_mid, x1_downsample, x2_downsample 731 | 732 | 733 | # TODO finish this part to transfer trans 2 cnn 734 | class PatchUnembed(nn.Module): 735 | def __init__(self, input_resolution, dim, dim_scale=4, final_dim=64, norm_layer=nn.LayerNorm): 736 | super().__init__() 737 | self.input_resolution = input_resolution 738 | self.dim = dim 739 | self.dim_scale = dim_scale 740 | 741 | self.final_dim = final_dim 742 | 743 | self.expand = nn.Linear(dim, 16 * final_dim, bias=False) 744 | self.output_dim = dim // self.dim_scale ** 2 745 | 746 | self.norm = norm_layer(self.final_dim) 747 | 748 | # self.output = nn.Conv2d(in_channels=self.output_dim,out_channels=self.final_dim,kernel_size=1,bias=False) 749 | 750 | def forward(self, x): 751 | """ 752 | x: B, H*W, C 753 | -> B, C/16, H*4, W*4 754 | """ 755 | H, W = self.input_resolution 756 | B, L, C = x.shape 757 | assert L == H * W, "input feature has wrong size" 758 | 759 | x = self.expand(x) 760 | C = 16 * self.final_dim 761 | x = x.view(B, H, W, C) 762 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, 763 | c=C // (self.dim_scale ** 2)) 764 | x = x.view(B, -1, self.final_dim) 765 | x = self.norm(x) 766 | 767 | x = x.view(B, 4 * H, 4 * W, -1) 768 | x = x.permute(0, 3, 1, 2) # B,C,H,W 769 | # x = self.output(x) 770 | 771 | return x 772 | 773 | 774 | class SwinTransDecoder(nn.Module): 775 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=2, 776 | embed_dim=128, depths=[4, 4, 4, 4], depths_decoder=[2, 2, 2, 2], num_heads=[4, 8, 16, 32], 777 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 778 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 779 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 780 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 781 | super().__init__() 782 | 783 | self.num_classes = num_classes 784 | self.num_layers = len(depths) 785 | self.embed_dim = embed_dim 786 | self.ape = ape 787 | self.patch_norm = patch_norm 788 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 789 | self.num_features_up = int(embed_dim * 2) 790 | self.mlp_ratio = mlp_ratio 791 | self.final_upsample = final_upsample 792 | self.patches_resolution = [img_size // patch_size, img_size // patch_size] 793 | self.patch_size = patch_size 794 | 795 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 796 | 797 | self.concat_linear0 = nn.Linear(1024, 1024) 798 | 799 | self.norm = nn.ModuleList() 800 | self.norm.append(nn.LayerNorm(512)) 801 | self.norm.append(nn.LayerNorm(256)) 802 | self.norm.append(nn.LayerNorm(128)) 803 | 
self.norm.append(nn.LayerNorm(128)) 804 | 805 | self.layers_up = nn.ModuleList() 806 | self.concat_back_dim = nn.ModuleList() 807 | for i_layer in range(self.num_layers): 808 | concat_linear = nn.Linear(int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 809 | int(embed_dim * 2 ** ( 810 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity() 811 | if i_layer == 0: 812 | layer_up = PatchExpand( 813 | input_resolution=(self.patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 814 | self.patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 815 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer) 816 | else: 817 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 818 | input_resolution=( 819 | self.patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 820 | self.patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 821 | depth=depths[(self.num_layers - 1 - i_layer)], 822 | num_heads=num_heads[(self.num_layers - 1 - i_layer)], 823 | window_size=window_size, 824 | mlp_ratio=self.mlp_ratio, 825 | qkv_bias=qkv_bias, qk_scale=qk_scale, 826 | drop=drop_rate, attn_drop=attn_drop_rate, 827 | drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum( 828 | depths[:(self.num_layers - 1 - i_layer) + 1])], 829 | norm_layer=norm_layer, 830 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, 831 | use_checkpoint=use_checkpoint) 832 | self.layers_up.append(layer_up) 833 | self.concat_back_dim.append(concat_linear) 834 | 835 | self.norm_up = norm_layer(self.embed_dim) 836 | if self.final_upsample == "expand_first": 837 | # print("---final upsample expand_first---") 838 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), 839 | dim_scale=4, dim=embed_dim, patchsize=patch_size) 840 | self.up_0 = FinalPatchExpand_X4_1(input_resolution=(56, 56), dim_scale=1, dim=128, patchsize=1) 841 | self.up_1 = FinalPatchExpand_X4_1(input_resolution=(28, 28), dim_scale=1, dim=256, patchsize=1) 842 | self.up_2 = FinalPatchExpand_X4_1(input_resolution=(14, 14), dim_scale=1, dim=512, patchsize=1) 843 | self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) 844 | self.output_0 = nn.Conv2d(in_channels=128, out_channels=self.num_classes, kernel_size=1, bias=False) 845 | self.output_1 = nn.Conv2d(in_channels=256, out_channels=self.num_classes, kernel_size=1, bias=False) 846 | self.output_2 = nn.Conv2d(in_channels=512, out_channels=self.num_classes, kernel_size=1, bias=False) 847 | 848 | self.ppattention = nn.ModuleList() 849 | self.ppattention.append(ppattention_wan(1024)) 850 | self.ppattention.append(ppattention_wan(512)) 851 | self.ppattention.append(ppattention_wan(256)) 852 | self.ppattention.append(ppattention_wan(128)) 853 | 854 | self.DFE = nn.ModuleList() 855 | self.DFE.append(DFE(1024)) 856 | self.DFE.append(DFE(512)) 857 | self.DFE.append(DFE(256)) 858 | self.DFE.append(DFE(128)) 859 | 860 | self.norm_bn = nn.ModuleList() 861 | self.norm_bn.append(nn.BatchNorm2d(2048)) 862 | self.norm_bn.append(nn.BatchNorm2d(1024)) 863 | self.norm_bn.append(nn.BatchNorm2d(512)) 864 | self.norm_bn.append(nn.BatchNorm2d(256)) 865 | 866 | self.channelattention = ChannelAttention(1024) 867 | self.avgpool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1) 868 | 869 | def transpose(self, x): 870 | B, HW, C = x.size() 871 | H = int(math.sqrt(HW)) 872 | x = x.transpose(1, 2) 873 | x = 
x.view(B, C, H, H) 874 | return x 875 | 876 | def transpose_verse(self, x): 877 | B, C, H, W = x.size() 878 | x = x.view(B, C, -1) 879 | x = x.transpose(1, 2) 880 | return x 881 | 882 | def forward_up_features(self, x_mid, x_downsample1, x_downsample2): # 1/4,1/8,1/16,1/32, 1/32 883 | x_upsample = [] 884 | for inx, layer_up in enumerate(self.layers_up): 885 | if inx == 0: 886 | t1 = self.transpose(x_downsample1[3]) # B,C,H,W mlp 887 | t2 = self.transpose(x_downsample2[3]) 888 | diff_feature = t1 - t2 # 1024 889 | add_feature = t1 + t2 890 | 891 | x_mid = self.transpose(x_mid) 892 | x_mid = torch.cat([x_mid, x_mid], dim=1) 893 | # hidden = torch.cat([x_mid,t1,t2],dim=1) 894 | hidden = torch.cat([diff_feature, add_feature], dim=1) 895 | hidden = x_mid + hidden 896 | # hidden = self.norm_bn[inx](hidden) 897 | x = self.transpose_verse(self.ppattention[inx](hidden)) # B,HW,C 898 | 899 | x = self.concat_linear0(x) # C,C 1024 900 | y1 = layer_up(x) # C/2 UP SAMPLE 512 901 | y2 = y1 # B,HW,C/2 512 902 | x = torch.cat([y1, y2], dim=2) # C B,HW,C 1024 903 | x_upsample.append(self.norm[0](y1)) # B,HW,C/2 904 | else: 905 | t1 = self.transpose(x_downsample1[3 - inx]) # 512 906 | t2 = self.transpose(x_downsample2[3 - inx]) 907 | diff_feature = t1 - t2 908 | add_feature = t1 + t2 909 | 910 | x = self.transpose(x) # 1024 911 | # hidden = torch.cat([x,t1,t2],dim=1) 912 | hidden = torch.cat([diff_feature, add_feature], dim=1) 913 | hidden = x + hidden 914 | # hidden = self.norm_bn[inx](hidden) 915 | x = self.ppattention[inx](hidden) 916 | x = self.transpose_verse(x) # B,HW,C 917 | x = self.concat_back_dim[inx](x) ###### 918 | 919 | y1 = layer_up(x) # layer up 初始层有norm,up norm,norm,up 920 | y2 = y1 921 | x = torch.cat([y1, y2], dim=2) # C1024 922 | norm = self.norm[inx] 923 | x_upsample.append((norm(y1))) 924 | 925 | x = self.norm_up(y1) # B L C 最终预测结果 926 | 927 | return x, x_upsample 928 | 929 | def up_x4(self, x, pz): 930 | H, W = self.patches_resolution 931 | B, L, C = x.shape 932 | assert L == H * W, "input features has wrong size" 933 | 934 | if self.final_upsample == "expand_first": 935 | x = self.up(x) 936 | x = x.view(B, pz * H, pz * W, -1) 937 | x = x.permute(0, 3, 1, 2) # B,C,H,W 938 | x = self.output(x) 939 | 940 | return x 941 | 942 | def up_x4_1(self, x, pz): 943 | H, W = self.patches_resolution 944 | B, L, C = x.shape 945 | assert L == H * W, "input features has wrong size" 946 | 947 | if self.final_upsample == "expand_first": 948 | x = self.up_0(x) 949 | x = x.view(B, pz * H, pz * W, -1) 950 | x = x.permute(0, 3, 1, 2) # B,C,H,W 951 | x = self.output_0(x) 952 | 953 | return x 954 | 955 | def up_x8(self, x, pz): 956 | H, W = (28, 28) 957 | B, L, C = x.shape 958 | assert L == H * W, "input features has wrong size" 959 | 960 | if self.final_upsample == "expand_first": 961 | # x = self.up(x,patchsize=pz) 962 | x = self.up_1(x) 963 | x = x.view(B, pz * H, pz * W, -1) 964 | x = x.permute(0, 3, 1, 2) # B,C,H,W 965 | x = self.output_1(x) 966 | 967 | return x 968 | 969 | def up_x16(self, x, pz): 970 | H, W = (14, 14) 971 | B, L, C = x.shape 972 | assert L == H * W, "input features has wrong size" 973 | 974 | if self.final_upsample == "expand_first": 975 | # x = self.up(x,patchsize=pz) 976 | x = self.up_2(x) 977 | x = x.view(B, pz * H, pz * W, -1) 978 | x = x.permute(0, 3, 1, 2) # B,C,H,W 979 | x = self.output_2(x) 980 | 981 | return x 982 | 983 | def forward(self, x, x_down1, x_down2): 984 | x, x_upsample = self.forward_up_features(x, x_down1, x_down2) 985 | 986 | x_p = self.up_x4(x, self.patch_size) 
987 | x_pre2 = self.up_x4_1(x_upsample[2], 1) 988 | x_pre3 = self.up_x8(x_upsample[1], 1) 989 | x_pre4 = self.up_x16(x_upsample[0], 1) 990 | 991 | return x_p, x_pre2, x_pre3, x_pre4 992 | 993 | 994 | class encoder1(nn.Module): 995 | def __init__(self): 996 | super(encoder1, self).__init__() 997 | self.encoder1 = SwinTransEncoder(img_size=224, patch_size=4, in_chans=3, num_classes=2, embed_dim=128, 998 | depths=[2, 2, 18, 2], depths_decoder=[4, 4, 4, 4], num_heads=[4, 8, 16, 32], 999 | window_size=7) 1000 | self.pretrained_path = 'swin_pretrain_224.pth' 1001 | self.load_from() 1002 | 1003 | def load_from(self): 1004 | pretrained_path = self.pretrained_path 1005 | if pretrained_path is not None: 1006 | print("pretrained_path:{}".format(pretrained_path)) 1007 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 1008 | pretrained_dict = torch.load(pretrained_path, map_location=device) 1009 | if "model" not in pretrained_dict: 1010 | print("---start load pretrained modle by splitting---") 1011 | pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()} 1012 | for k in list(pretrained_dict.keys()): 1013 | if "output" in k: 1014 | print("delete key:{}".format(k)) 1015 | del pretrained_dict[k] 1016 | msg = self.encoder1.load_state_dict(pretrained_dict, strict=False) 1017 | return 1018 | pretrained_dict = pretrained_dict['model'] 1019 | print("---start load pretrained modle of swin encoder---") 1020 | 1021 | model_dict = self.encoder1.state_dict() 1022 | full_dict = copy.deepcopy(pretrained_dict) 1023 | for k, v in pretrained_dict.items(): 1024 | if "layers." in k: 1025 | current_layer_num = 3 - int(k[7:8]) 1026 | current_k = "layers_up." + str(current_layer_num) + k[8:] 1027 | full_dict.update({current_k: v}) 1028 | for k in list(full_dict.keys()): 1029 | if k in model_dict: 1030 | # print(1) 1031 | if full_dict[k].shape != model_dict[k].shape: 1032 | print("delete:{};shape pretrain:{};shape model:{}".format(k, v.shape, model_dict[k].shape)) 1033 | del full_dict[k] 1034 | 1035 | msg = self.encoder1.load_state_dict(full_dict, strict=False) 1036 | else: 1037 | print("none pretrain") 1038 | 1039 | def forward(self, img1, img2): 1040 | x, y1, y2 = self.encoder1(img1, img2) 1041 | return x, y1, y2 1042 | 1043 | 1044 | class Encoder(nn.Module): 1045 | def __init__(self): 1046 | super(Encoder, self).__init__() 1047 | 1048 | self.encoder1 = encoder1() 1049 | self.decoder = SwinTransDecoder() 1050 | 1051 | def forward(self, img1, img2): 1052 | x, x_downsample1, x_downsample2 = self.encoder1(img1, img2) 1053 | x_p, x_2, x_3, x_4 = self.decoder(x, x_downsample1, x_downsample2) 1054 | 1055 | return x_p, x_2, x_3, x_4 1056 | --------------------------------------------------------------------------------
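Usage sketch (assumed, not shipped with the repository): how swin_ynet.Encoder is driven end to end. Note that encoder1 loads 'swin_pretrain_224.pth' in its constructor, so that file must be present; the four outputs are change-map logits at full, 1/4, 1/8 and 1/16 resolution, matching the multi-scale supervision used in the training scripts.

    # minimal sketch, assuming a CUDA device and swin_pretrain_224.pth in the working directory
    import torch
    from swin_ynet import Encoder

    model = Encoder().cuda().eval()
    im1 = torch.randn(1, 3, 224, 224).cuda()  # pre-change image
    im2 = torch.randn(1, 3, 224, 224).cuda()  # post-change image

    with torch.no_grad():
        x_p, x_2, x_3, x_4 = model(im1, im2)

    print(x_p.shape)  # torch.Size([1, 2, 224, 224]) - full-resolution logits
    print(x_2.shape)  # torch.Size([1, 2, 56, 56])
    print(x_3.shape)  # torch.Size([1, 2, 28, 28])
    print(x_4.shape)  # torch.Size([1, 2, 14, 14])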