├── README.md ├── config.py ├── images ├── a.png └── r.png ├── networks ├── SwinNet │ ├── README.md │ ├── SwinNet_test.py │ ├── SwinNet_train.py │ ├── cpts │ │ └── RGBD-SwinNet.log │ ├── data.py │ ├── imgs │ │ ├── RGBD.png │ │ ├── RGBT.png │ │ └── main.png │ ├── logger.py │ ├── lr_scheduler.py │ ├── main.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── build.cpython-38.pyc │ │ │ └── swinNet.cpython-38.pyc │ │ ├── back │ │ │ └── Swin_Transformer.py │ │ ├── build.py │ │ └── swinNet.py │ ├── optimizer.py │ ├── options.py │ ├── test_gray_to_rgb.py │ └── utils.py ├── Wavenet.py ├── __init__.py ├── __pycache__ │ ├── Wavenet.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ └── wavemlp.cpython-36.pyc └── wavemlp.py ├── pytorch_iou └── __init__.py ├── rgbt_dataset_KD.py ├── test.py ├── train_wave_KD.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # WaveNet 2 | This project provides the code and results for 'WaveNet: Wavelet Network With Knowledge Distillation for RGB-T Salient Object Detection', IEEE TIP, 2023. [IEEE link](https://ieeexplore.ieee.org/document/10127616)
3 | 4 | - Old codebase: https://github.com/nowander/WaveNet/tree/default 5 | # Requirements 6 | Python 3.7+, PyTorch 1.5.0+, CUDA 10.2+, TensorboardX 2.1, opencv-python, pytorch_wavelets, timm.
7 | 8 | # Architecture and Details 9 | 10 | drawing 11 | 12 | # Results 13 | drawing 14 | 15 | 16 | # Preparation 17 | - Download the RGB-T raw data from [LSNet](https://github.com/zyrant/LSNet).
18 | - Optional: Download the pre-trained WaveMLP-S weights from [wavemlp](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/wavemlp_pytorch).
19 | - There are two ways to supply the SwinNet teacher for knowledge distillation: 20 | 1. Load the SwinNet model itself; see [SwinNet](https://github.com/liuzywen/SwinNet) for its configuration. 21 | 2. Directly load SwinNet's precomputed prediction maps from [baidu](https://pan.baidu.com/s/18qwaTwTZ39XtWlP3JaeSOQ) pin: py5y.
22 | Using SwinNet's prediction maps is the default setting. 23 | 24 | 25 | 26 | # Training & Testing 27 | Set `lr_train_root`, `lr_val_root` and `save_path` in `config.py` to your own data paths. 28 | - Train WaveNet: 29 | 30 | `python train_wave_KD.py` 31 | 32 | Set `test_path` in `config.py` to your own test data path. 33 | 34 | - Test WaveNet: 35 | 36 | `python test.py` 37 | 38 | # Evaluation tools 39 | - Either of the following toolboxes can be used to compute the metrics: 40 | [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics) 41 | 42 | # Saliency Maps 43 | - RGB-T [baidu](https://pan.baidu.com/s/1L6gF3hT8ML3uN_p2va4n8w) pin: gl01
44 | 45 | 46 | # Pretrained Models 47 | - RGB-T [baidu](https://pan.baidu.com/s/1PGwu3uVRWyFS1erOBr7KAg) pin: v5pb
48 | 49 | # Citation 50 | @ARTICLE{10127616, 51 | author={Zhou, Wujie and Sun, Fan and Jiang, Qiuping and Cong, Runmin and Hwang, Jenq-Neng}, 52 | journal={IEEE Transactions on Image Processing}, 53 | title={WaveNet: Wavelet Network With Knowledge Distillation for RGB-T Salient Object Detection}, 54 | year={2023}, 55 | volume={32}, 56 | number={}, 57 | pages={3027-3039}, 58 | doi={10.1109/TIP.2023.3275538}} 59 | 60 | # Acknowledgement 61 | 62 | The implementation of this project is based on the codebases below.
63 | - [BBS-Net](https://github.com/zyjwuyan/BBS-Net)
64 | - [LSNet](https://github.com/zyrant/LSNet)
65 | - [Wavemlp](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/wavemlp_pytorch)
66 | - Evaluation tools: [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics)
67 | 68 | If you find this project helpful, Please also cite the codebases above. Besides, we also thank [zyrant](https://github.com/zyrant). 69 | 70 | # Contact 71 | Please drop me an email for any problems or discussion: https://wujiezhou.github.io/ (wujiezhou@163.com). 72 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse # Command-line configuration for the train/val and test(predict) stages defined below 2 | parser = argparse.ArgumentParser() # populated below and parsed into the module-level `opt` 3 | # change dir according to your own dataset dirs 4 | # train/val 5 | parser.add_argument('--epoch', type=int, default=60, help='epoch number') 6 | parser.add_argument('--lr', type=float, default=3e-5, help='learning rate') 7 | parser.add_argument('--batchsize', type=int, default=6, help='training batch size') 8 | parser.add_argument('--trainsize', type=int, default=384, help='training dataset size') # NOTE(review): presumably the training input resolution rather than a sample count - confirm in the data loader 9 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 10 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')#0.1 11 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate') 12 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints') 13 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id') 14 | parser.add_argument('--lr_train_root', type=str, default='', help='the train images root') # a data directory per its help text; the `lr_` prefix does not make it a learning-rate option 15 | parser.add_argument('--lr_val_root', type=str, default='', help='the val images root') 16 | parser.add_argument('--save_path', type=str, default='', help='the path to save models and logs') 17 | # test(predict) 18 | parser.add_argument('--testsize', type=int, default=384, help='testing size') 19 | parser.add_argument('--test_path',type=str,default='/media/hjk/shuju/轻量级/data set/RGBT_test/',help='test dataset path') # NOTE(review): machine-specific absolute default; override it with --test_path or edit it for your setup 20 | opt = parser.parse_args() # parsed at import time: importing this module reads sys.argv, and argparse exits on unrecognised flags 21 |
-------------------------------------------------------------------------------- /images/a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/images/a.png -------------------------------------------------------------------------------- /images/r.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/images/r.png -------------------------------------------------------------------------------- /networks/SwinNet/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/README.md -------------------------------------------------------------------------------- /networks/SwinNet/SwinNet_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: SwinNet_test.py 6 | @time: 2021/5/31 09:34 7 | """ 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import sys 12 | sys.path.append('./models') 13 | import numpy as np 14 | import os, argparse 15 | import cv2 16 | from models.swinNet import SwinNet 17 | from data import test_dataset 18 | import time 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--testsize', type=int, default=384, help='testing size') 22 | parser.add_argument('--gpu_id', type=str, default='1', help='select gpu id') 23 | parser.add_argument('--test_path',type=str,default='./datasets/RGB-D/test/',help='test dataset path') 24 | # parser.add_argument('--test_path',type=str,default='./datasets/RGB-T/Test/',help='test dataset path') 25 | opt = parser.parse_args() 26 | 27 | dataset_path = opt.test_path 28 | 29 
| #set device for test 30 | if opt.gpu_id=='0': 31 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 32 | print('USE GPU 0') 33 | elif opt.gpu_id=='1': 34 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 35 | print('USE GPU 1') 36 | 37 | #load the model 38 | model = SwinNet() 39 | #Large epoch size may not generalize well. You can choose a good model to load according to the log file and pth files saved in ('./BBSNet_cpts/') when training. 40 | model.load_state_dict(torch.load("./cpts/SwinNet.pth")) 41 | 42 | model.cuda() 43 | model.eval() 44 | fps = 0 45 | 46 | # RGBT-Test 47 | test_datasets = ['VT1000','VT821','VT5000'] 48 | for dataset in test_datasets: 49 | time_s = time.time() 50 | sal_save_path = './test_maps/SwinTransNet_RGBT_speed/' + dataset + '/' 51 | edge_save_path = './test_maps/SwinTransNet_RGBT2_best/Edge/' + dataset + '/' 52 | if not os.path.exists(sal_save_path): 53 | os.makedirs(sal_save_path) 54 | # os.makedirs(edge_save_path) 55 | image_root = dataset_path + dataset + '/RGB/' 56 | gt_root = dataset_path + dataset + '/GT/' 57 | depth_root = dataset_path + dataset + '/T/' 58 | test_loader = test_dataset(image_root, gt_root, depth_root, opt.testsize) 59 | nums = test_loader.size 60 | for i in range(test_loader.size): 61 | image, gt, depth, name, image_for_post = test_loader.load_data() 62 | gt = np.asarray(gt, np.float32) 63 | gt /= (gt.max() + 1e-8) 64 | image = image.cuda() 65 | # depth = depth = depth.repeat(1,3,1,1).cuda() 66 | depth = depth.cuda() 67 | 68 | 69 | # print(depth.shape) 70 | res,edge = model(image,depth) 71 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 72 | # edge = F.upsample(edge, size=gt.shape, mode='bilinear', align_corners=False) 73 | res = res.sigmoid().data.cpu().numpy().squeeze() 74 | # edge = edge.sigmoid().data.cpu().numpy().squeeze() 75 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 76 | # edge = (edge - edge.min()) / (edge.max() - edge.min() + 1e-8) 77 | print('save img to: ', sal_save_path 
+ name) 78 | cv2.imwrite(sal_save_path + name, res * 255) 79 | time_e = time.time() 80 | fps += (nums / (time_e - time_s)) 81 | print("FPS:%f" % (nums / (time_e - time_s))) 82 | print('Test Done!') 83 | print("Total FPS %f" % fps) # this result include I/O cost 84 | 85 | # test 86 | 87 | test_datasets = ['SIP','SSD','RedWeb','NJU2K','NLPR','STERE','DES','LFSD',] 88 | # test_datasets = ['DUT-RGBD'] 89 | fps = 0 90 | for dataset in test_datasets: 91 | time_s = time.time() 92 | save_path = './test_maps/' + dataset + '/' 93 | edge_save_path = './test_maps/' + dataset + '/edge/' 94 | if not os.path.exists(save_path): 95 | os.makedirs(save_path) 96 | if not os.path.exists(edge_save_path): 97 | os.makedirs(edge_save_path) 98 | image_root = dataset_path + dataset + '/RGB/' 99 | gt_root = dataset_path + dataset + '/GT/' 100 | depth_root = dataset_path + dataset + '/depth/' 101 | test_loader = test_dataset(image_root, gt_root, depth_root, opt.testsize) 102 | nums = test_loader.size 103 | for i in range(test_loader.size): 104 | image, gt, depth, name, image_for_post = test_loader.load_data() 105 | gt = np.asarray(gt, np.float32) 106 | gt /= (gt.max() + 1e-8) 107 | image = image.cuda() 108 | depth = depth.repeat(1,3,1,1).cuda() 109 | res, edge = model(image,depth) 110 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 111 | edge = F.upsample(edge, size=gt.shape, mode='bilinear', align_corners=False) 112 | res = res.sigmoid().data.cpu().numpy().squeeze() 113 | edge = edge.sigmoid().data.cpu().numpy().squeeze() 114 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 115 | edge = (edge - edge.min()) / (edge.max() - edge.min() + 1e-8) 116 | print('save img to: ',save_path+name) 117 | cv2.imwrite(save_path + name, res*255) 118 | cv2.imwrite(edge_save_path + name, edge * 255) 119 | time_e = time.time() 120 | fps += (nums / (time_e - time_s)) 121 | print("FPS:%f" % (nums / (time_e - time_s))) 122 | print('Test Done!') 123 | print("Total FPS %f" % fps) 
124 | 125 | -------------------------------------------------------------------------------- /networks/SwinNet/SwinNet_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: SwinNet.py 6 | @time: 2021/5/6 16:12 7 | """ 8 | import os 9 | import torch 10 | import torch.nn.functional as F 11 | import sys 12 | 13 | sys.path.append('./models') 14 | import numpy as np 15 | from datetime import datetime 16 | from models.swinNet import SwinTransformer,SwinNet 17 | from torchvision.utils import make_grid 18 | from data import get_loader, test_dataset 19 | from utils import clip_gradient, adjust_lr 20 | from tensorboardX import SummaryWriter 21 | import logging 22 | import torch.backends.cudnn as cudnn 23 | from options import opt 24 | import yaml 25 | 26 | 27 | # set the device for training 28 | if opt.gpu_id == '0': 29 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 30 | print('USE GPU 0') 31 | elif opt.gpu_id == '1': 32 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 33 | print('USE GPU 1') 34 | cudnn.benchmark = True 35 | 36 | image_root = opt.rgb_root 37 | gt_root = opt.gt_root 38 | depth_root = opt.depth_root 39 | edge_root = opt.edge_root 40 | 41 | val_image_root = opt.val_rgb_root 42 | val_gt_root = opt.val_gt_root 43 | val_depth_root = opt.val_depth_root 44 | save_path = opt.save_path 45 | 46 | logging.basicConfig(filename=save_path + 'RGBD-SwinNet.log', 47 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO, filemode='a', 48 | datefmt='%Y-%m-%d %I:%M:%S %p') 49 | 50 | 51 | model = SwinNet() 52 | 53 | num_parms = 0 54 | if (opt.load is not None): 55 | model.load_pre(opt.load) 56 | print('load model from ', opt.load) 57 | 58 | model.cuda() 59 | for p in model.parameters(): 60 | num_parms += p.numel() 61 | logging.info("Total Parameters (For Reference): {}".format(num_parms)) 62 | print("Total 
Parameters (For Reference): {}".format(num_parms)) 63 | 64 | params = model.parameters() 65 | optimizer = torch.optim.Adam(params, opt.lr) 66 | 67 | # set the path 68 | if not os.path.exists(save_path): 69 | os.makedirs(save_path) 70 | 71 | # load data 72 | print('load data...') 73 | train_loader = get_loader(image_root, gt_root,depth_root, edge_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 74 | test_loader = test_dataset(val_image_root, val_gt_root, val_depth_root, opt.trainsize) 75 | total_step = len(train_loader) 76 | 77 | logging.info("Config") 78 | logging.info( 79 | 'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format( 80 | opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.load, save_path, 81 | opt.decay_epoch)) 82 | 83 | # set loss function 84 | CE = torch.nn.BCEWithLogitsLoss() 85 | ECE = torch.nn.BCELoss() 86 | step = 0 87 | writer = SummaryWriter(save_path + 'summary') 88 | best_mae = 1 89 | best_epoch = 0 90 | 91 | 92 | # train function 93 | def train(train_loader, model, optimizer, epoch, save_path): 94 | global step 95 | model.train() 96 | loss_all = 0 97 | epoch_step = 0 98 | 99 | try: 100 | for i, (images, gts, depth,edge) in enumerate(train_loader, start=1): 101 | optimizer.zero_grad() 102 | 103 | images = images.cuda() 104 | gts = gts.cuda() 105 | # RGB-T data 106 | # depth = depth.cuda() 107 | # RGB-D data 108 | depth = depth.repeat(1,3,1,1).cuda() 109 | edge = edge.cuda() 110 | s, e = model(images,depth) 111 | 112 | sal_loss = CE(s, gts) 113 | edge_loss = CE(e, edge) 114 | loss = sal_loss + edge_loss 115 | loss.backward() 116 | 117 | clip_gradient(optimizer, opt.clip) 118 | optimizer.step() 119 | step += 1 120 | epoch_step += 1 121 | loss_all += loss.data 122 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 123 | if i % 100 == 0 or i == total_step or i == 1: 124 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], 
LR:{:.7f}||sal_loss:{:4f} ||edge_loss:{:4f}'. 125 | format(datetime.now(), epoch, opt.epoch, i, total_step, 126 | optimizer.state_dict()['param_groups'][0]['lr'], sal_loss.data, edge_loss.data)) 127 | logging.info( 128 | '#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], LR:{:.7f}, sal_loss:{:4f} ||edge_loss:{:4f}, mem_use:{:.0f}MB'. 129 | format(epoch, opt.epoch, i, total_step, optimizer.state_dict()['param_groups'][0]['lr'], loss.data,edge_loss.data,memory_used)) 130 | writer.add_scalar('Loss', loss.data, global_step=step) 131 | grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True) 132 | writer.add_image('RGB', grid_image, step) 133 | grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True) 134 | writer.add_image('Ground_truth', grid_image, step) 135 | res = s[0].clone() 136 | res = res.sigmoid().data.cpu().numpy().squeeze() 137 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 138 | writer.add_image('res', torch.tensor(res), step, dataformats='HW') 139 | loss_all /= epoch_step 140 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}],Loss_AVG: {:.4f}'.format(epoch, opt.epoch, loss_all)) 141 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch) 142 | if (epoch) % 5 == 0: 143 | torch.save(model.state_dict(), save_path + 'SwinNet_epoch_{}.pth'.format(epoch)) 144 | except KeyboardInterrupt: 145 | print('Keyboard Interrupt: save model and exit.') 146 | if not os.path.exists(save_path): 147 | os.makedirs(save_path) 148 | torch.save(model.state_dict(), save_path + 'SwinNet_epoch_{}.pth'.format(epoch + 1)) 149 | print('save checkpoints successfully!') 150 | raise 151 | 152 | # test function 153 | def test(test_loader, model, epoch, save_path): 154 | global best_mae, best_epoch 155 | model.eval() 156 | with torch.no_grad(): 157 | mae_sum = 0 158 | for i in range(test_loader.size): 159 | image, gt, depth, name, img_for_post = test_loader.load_data() 160 | 161 | gt = np.asarray(gt, np.float32) 162 | gt /= (gt.max() + 1e-8) 163 | 
image = image.cuda() 164 | # print(depth.shape) 165 | # depth = depth.cuda() 166 | depth = depth.repeat(1,3,1,1).cuda() 167 | res,e = model(image,depth) 168 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 169 | res = res.sigmoid().data.cpu().numpy().squeeze() 170 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 171 | mae_sum += np.sum(np.abs(res - gt)) * 1.0 / (gt.shape[0] * gt.shape[1]) 172 | mae = mae_sum / test_loader.size 173 | writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch) 174 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch)) 175 | if epoch == 1: 176 | best_mae = mae 177 | else: 178 | if mae < best_mae: 179 | best_mae = mae 180 | best_epoch = epoch 181 | torch.save(model.state_dict(), save_path + 'SwinNet_epoch_best.pth') 182 | print('best epoch:{}'.format(epoch)) 183 | logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae)) 184 | 185 | 186 | if __name__ == '__main__': 187 | print("Start train...") 188 | for epoch in range(1, opt.epoch): 189 | cur_lr = adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) 190 | writer.add_scalar('learning_rate', cur_lr, global_step=epoch) 191 | train(train_loader, model, optimizer, epoch, save_path) 192 | test(test_loader, model, epoch, save_path) 193 | -------------------------------------------------------------------------------- /networks/SwinNet/cpts/RGBD-SwinNet.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/cpts/RGBD-SwinNet.log -------------------------------------------------------------------------------- /networks/SwinNet/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import 
torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | # Reference from BBSNet, Thanks!!! 10 | 11 | # several data augumentation strategies 12 | def cv_random_flip(img, label, depth,edge): 13 | flip_flag = random.randint(0, 1) 14 | # flip_flag2= random.randint(0,1) 15 | # left right flip 16 | if flip_flag == 1: 17 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 18 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 19 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 20 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT) 21 | # top bottom flip 22 | # if flip_flag2==1: 23 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 24 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 25 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 26 | return img, label, depth, edge 27 | 28 | 29 | def randomCrop(image, label, depth, edge): 30 | border = 30 31 | image_width = image.size[0] 32 | image_height = image.size[1] 33 | crop_win_width = np.random.randint(image_width - border, image_width) 34 | crop_win_height = np.random.randint(image_height - border, image_height) 35 | random_region = ( 36 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 37 | (image_height + crop_win_height) >> 1) 38 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region), edge.crop(random_region) 39 | 40 | 41 | def randomRotation(image, label, depth, edge): 42 | mode = Image.BICUBIC 43 | if random.random() > 0.8: 44 | random_angle = np.random.randint(-15, 15) 45 | image = image.rotate(random_angle, mode) 46 | label = label.rotate(random_angle, mode) 47 | depth = depth.rotate(random_angle, mode) 48 | edge = edge.rotate(random_angle, mode) 49 | return image, label, depth, edge 50 | 51 | 52 | def colorEnhance(image): 53 | bright_intensity = random.randint(5, 15) / 10.0 54 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 55 | contrast_intensity = 
random.randint(5, 15) / 10.0 56 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 57 | color_intensity = random.randint(0, 20) / 10.0 58 | image = ImageEnhance.Color(image).enhance(color_intensity) 59 | sharp_intensity = random.randint(0, 30) / 10.0 60 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 61 | return image 62 | 63 | 64 | def randomGaussian(image, mean=0.1, sigma=0.35): 65 | def gaussianNoisy(im, mean=mean, sigma=sigma): 66 | for _i in range(len(im)): 67 | im[_i] += random.gauss(mean, sigma) 68 | return im 69 | 70 | img = np.asarray(image) 71 | width, height = img.shape 72 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 73 | img = img.reshape([width, height]) 74 | return Image.fromarray(np.uint8(img)) 75 | 76 | 77 | def randomPeper(img): 78 | img = np.array(img) 79 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 80 | for i in range(noiseNum): 81 | 82 | randX = random.randint(0, img.shape[0] - 1) 83 | 84 | randY = random.randint(0, img.shape[1] - 1) 85 | 86 | if random.randint(0, 1) == 0: 87 | 88 | img[randX, randY] = 0 89 | 90 | else: 91 | 92 | img[randX, randY] = 255 93 | return Image.fromarray(img) 94 | 95 | 96 | # dataset for training 97 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 98 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 
99 | class SalObjDataset(data.Dataset): 100 | def __init__(self, image_root, gt_root, depth_root, edge_root, trainsize): 101 | self.trainsize = trainsize 102 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 103 | 104 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 105 | or f.endswith('.png')] 106 | 107 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 108 | or f.endswith('.png') or f.endswith('.jpg')] 109 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.bmp') 110 | or f.endswith('.png') or f.endswith('.jpg')] 111 | self.images = sorted(self.images) 112 | self.gts = sorted(self.gts) 113 | self.depths = sorted(self.depths) 114 | self.edges = sorted(self.edges) 115 | self.filter_files() 116 | self.size = len(self.images) 117 | self.img_transform = transforms.Compose([ 118 | transforms.Resize((self.trainsize, self.trainsize)), 119 | transforms.ToTensor(), 120 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 121 | self.gt_transform = transforms.Compose([ 122 | transforms.Resize((self.trainsize, self.trainsize)), 123 | transforms.ToTensor()]) 124 | self.depths_transform = transforms.Compose( 125 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()]) 126 | self.edges_transform = transforms.Compose( 127 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()] 128 | ) 129 | 130 | def __getitem__(self, index): 131 | image = self.rgb_loader(self.images[index]) 132 | gt = self.binary_loader(self.gts[index]) 133 | depth = self.binary_loader(self.depths[index]) # RGBD 134 | # depth = self.rgb_loader(self.depths[index]) # RGBT 135 | edge = self.binary_loader(self.edges[index]) 136 | 137 | image, gt, depth, edge = cv_random_flip(image, gt, depth, edge) 138 | image, gt, depth, edge = randomCrop(image, gt, depth, edge) 139 | image, gt, depth, edge = randomRotation(image, gt, depth, 
edge) 140 | image = colorEnhance(image) 141 | # gt=randomGaussian(gt) 142 | gt = randomPeper(gt) 143 | image = self.img_transform(image) 144 | gt = self.gt_transform(gt) 145 | depth = self.depths_transform(depth) 146 | edge = self.edges_transform(edge) 147 | 148 | return image, gt, depth, edge 149 | 150 | def filter_files(self): 151 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 152 | images = [] 153 | gts = [] 154 | depths = [] 155 | edges = [] 156 | for img_path, gt_path, depth_path, edge_path in zip(self.images, self.gts, self.depths, self.edges): 157 | img = Image.open(img_path) 158 | gt = Image.open(gt_path) 159 | depth = Image.open(depth_path) 160 | edge = Image.open(edge_path) 161 | if img.size == gt.size and gt.size == depth.size and edge.size == img.size: 162 | images.append(img_path) 163 | gts.append(gt_path) 164 | depths.append(depth_path) 165 | edges.append(edge_path) 166 | self.images = images 167 | self.gts = gts 168 | self.depths = depths 169 | self.edges = edges 170 | 171 | def rgb_loader(self, path): 172 | with open(path, 'rb') as f: 173 | img = Image.open(f) 174 | return img.convert('RGB') 175 | 176 | def binary_loader(self, path): 177 | with open(path, 'rb') as f: 178 | img = Image.open(f) 179 | return img.convert('L') 180 | 181 | def resize(self, img, gt, depth, edge): 182 | assert img.size == gt.size and gt.size == depth.size and edge.size == img.size 183 | w, h = img.size 184 | if h < self.trainsize or w < self.trainsize: 185 | h = max(h, self.trainsize) 186 | w = max(w, self.trainsize) 187 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), \ 188 | depth.resize((w, h),Image.NEAREST), edge.resize((w,h), Image.NEAREST) 189 | else: 190 | return img, gt, depth, edge 191 | 192 | def __len__(self): 193 | return self.size 194 | 195 | 196 | # dataloader for training 197 | def get_loader(image_root, gt_root, depth_root, edge_root, batchsize, trainsize, shuffle=True, num_workers=0, 
pin_memory=True): 198 | dataset = SalObjDataset(image_root, gt_root, depth_root, edge_root, trainsize) 199 | # print(image_root) 200 | # print(gt_root) 201 | # print(depth_root) 202 | data_loader = data.DataLoader(dataset=dataset, 203 | batch_size=batchsize, 204 | shuffle=shuffle, 205 | num_workers=num_workers, 206 | pin_memory=pin_memory) 207 | return data_loader 208 | 209 | 210 | # test dataset and loader 211 | class test_dataset: 212 | def __init__(self, image_root, gt_root, depth_root, testsize): 213 | self.testsize = testsize 214 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 215 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 216 | or f.endswith('.png')] 217 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 218 | or f.endswith('.png')or f.endswith('.jpg')] 219 | 220 | self.images = sorted(self.images) 221 | self.gts = sorted(self.gts) 222 | self.depths = sorted(self.depths) 223 | self.transform = transforms.Compose([ 224 | transforms.Resize((self.testsize, self.testsize)), 225 | transforms.ToTensor(), 226 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 227 | self.gt_transform = transforms.ToTensor() 228 | self.depths_transform = transforms.Compose( 229 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()]) 230 | self.size = len(self.images) 231 | self.index = 0 232 | 233 | def load_data(self): 234 | image = self.rgb_loader(self.images[self.index]) 235 | image = self.transform(image).unsqueeze(0) 236 | gt = self.binary_loader(self.gts[self.index]) 237 | depth = self.binary_loader(self.depths[self.index]) # RGBD 238 | # depth = self.rgb_loader(self.depths[self.index]) # RGBT 239 | depth = self.transform(depth).unsqueeze(0) 240 | name = self.images[self.index].split('/')[-1] 241 | image_for_post = self.rgb_loader(self.images[self.index]) 242 | image_for_post = image_for_post.resize(gt.size) 243 | if 
name.endswith('.jpg'): 244 | name = name.split('.jpg')[0] + '.png' 245 | self.index += 1 246 | self.index = self.index % self.size 247 | return image, gt, depth, name, np.array(image_for_post) 248 | 249 | def rgb_loader(self, path): 250 | with open(path, 'rb') as f: 251 | img = Image.open(f) 252 | return img.convert('RGB') 253 | 254 | def binary_loader(self, path): 255 | with open(path, 'rb') as f: 256 | img = Image.open(f) 257 | return img.convert('L') 258 | 259 | def __len__(self): 260 | return self.size 261 | 262 | -------------------------------------------------------------------------------- /networks/SwinNet/imgs/RGBD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/imgs/RGBD.png -------------------------------------------------------------------------------- /networks/SwinNet/imgs/RGBT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/imgs/RGBT.png -------------------------------------------------------------------------------- /networks/SwinNet/imgs/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/imgs/main.png -------------------------------------------------------------------------------- /networks/SwinNet/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: logger.py 6 | @time: 2021/5/6 16:00 7 | """ 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def 
create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger -------------------------------------------------------------------------------- /networks/SwinNet/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: lr_scheduler.py 6 | @time: 2021/5/6 16:04 7 | """ 8 | 9 | import torch 10 | from timm.scheduler.cosine_lr import CosineLRScheduler 11 | from timm.scheduler.step_lr import StepLRScheduler 12 | from timm.scheduler.scheduler import Scheduler 13 | 14 | 15 | def build_scheduler(config, optimizer, n_iter_per_epoch): 16 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 17 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 18 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 19 | 20 | lr_scheduler = None 21 | if config.TRAIN.LR_SCHEDULER.NAME == 
'cosine': 22 | lr_scheduler = CosineLRScheduler( 23 | optimizer, 24 | t_initial=num_steps, 25 | t_mul=1., 26 | lr_min=config.TRAIN.MIN_LR, 27 | warmup_lr_init=config.TRAIN.WARMUP_LR, 28 | warmup_t=warmup_steps, 29 | cycle_limit=1, 30 | t_in_epochs=False, 31 | ) 32 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 33 | lr_scheduler = LinearLRScheduler( 34 | optimizer, 35 | t_initial=num_steps, 36 | lr_min_rate=0.01, 37 | warmup_lr_init=config.TRAIN.WARMUP_LR, 38 | warmup_t=warmup_steps, 39 | t_in_epochs=False, 40 | ) 41 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 42 | lr_scheduler = StepLRScheduler( 43 | optimizer, 44 | decay_t=decay_steps, 45 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 46 | warmup_lr_init=config.TRAIN.WARMUP_LR, 47 | warmup_t=warmup_steps, 48 | t_in_epochs=False, 49 | ) 50 | 51 | return lr_scheduler 52 | 53 | 54 | class LinearLRScheduler(Scheduler): 55 | def __init__(self, 56 | optimizer: torch.optim.Optimizer, 57 | t_initial: int, 58 | lr_min_rate: float, 59 | warmup_t=0, 60 | warmup_lr_init=0., 61 | t_in_epochs=True, 62 | noise_range_t=None, 63 | noise_pct=0.67, 64 | noise_std=1.0, 65 | noise_seed=42, 66 | initialize=True, 67 | ) -> None: 68 | super().__init__( 69 | optimizer, param_group_field="lr", 70 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 71 | initialize=initialize) 72 | 73 | self.t_initial = t_initial 74 | self.lr_min_rate = lr_min_rate 75 | self.warmup_t = warmup_t 76 | self.warmup_lr_init = warmup_lr_init 77 | self.t_in_epochs = t_in_epochs 78 | if self.warmup_t: 79 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 80 | super().update_groups(self.warmup_lr_init) 81 | else: 82 | self.warmup_steps = [1 for _ in self.base_values] 83 | 84 | def _get_lr(self, t): 85 | if t < self.warmup_t: 86 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 87 | else: 88 | t = t - self.warmup_t 89 | total_t = self.t_initial - self.warmup_t 
90 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 91 | return lrs 92 | 93 | def get_epoch_values(self, epoch: int): 94 | if self.t_in_epochs: 95 | return self._get_lr(epoch) 96 | else: 97 | return None 98 | 99 | def get_update_values(self, num_updates: int): 100 | if not self.t_in_epochs: 101 | return self._get_lr(num_updates) 102 | else: 103 | return None -------------------------------------------------------------------------------- /networks/SwinNet/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: main.py 6 | @time: 2021/4/6 16:10 7 | """ 8 | -------------------------------------------------------------------------------- /networks/SwinNet/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: __init__.py 6 | @time: 2021/5/6 15:58 7 | """ 8 | from .build import build_model -------------------------------------------------------------------------------- /networks/SwinNet/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/SwinNet/models/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/models/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- 
/networks/SwinNet/models/__pycache__/swinNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/models/__pycache__/swinNet.cpython-38.pyc -------------------------------------------------------------------------------- /networks/SwinNet/models/back/Swin_Transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: build.py 6 | @time: 2021-05-03 11:31 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | import numpy as np 14 | import math 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None,out_features=None, act_layer=nn.GELU, drop=0.): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Linear(in_features, hidden_features) 22 | self.act = act_layer() 23 | self.fc2 = nn.Linear(hidden_features, out_features) 24 | self.drop = nn.Dropout(drop) 25 | 26 | def forward(self,x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop(x) 30 | x = self.fc2(x) 31 | x = self.drop(x) 32 | return x 33 | 34 | def window_partition(x, window_size): 35 | B, H, W, C = x.shape 36 | x = x.view(B, H//window_size, window_size, W//window_size, window_size, C) 37 | windows = x.permute(0,1,3,2,4,5).contiguous().view(-1, window_size, window_size, C) 38 | return windows 39 | 40 | def window_reverse(windows, window_size, H, W): 41 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 42 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 43 | x = x.permute(0,1,3,2,4,5).contiguous().view*(B, H, 
W, -1) 44 | return x 45 | 46 | class WindowAttention(nn.Module): 47 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 48 | super().__init__() 49 | self.dim = dim 50 | self.window_size = window_size 51 | self.num_heads = num_heads 52 | head_dim = dim // self.num_heads 53 | self.scale = qk_scale or head_dim ** -0.5 54 | 55 | self.relative_position_bias_table = nn.Parameter( 56 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1, num_heads)) 57 | ) 58 | 59 | coords_h = torch.arrange(self.window_size[0]) 60 | coords_w = torch.arrange(self.window_size[1]) 61 | 62 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 63 | coords_flatten = torch.flatten(coords,1) 64 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 65 | relative_coords = relative_coords.permute(1,2,0).contiguous() 66 | relative_coords[:, :, 0] += self.window_size[0] - 1 67 | relative_coords[:, :, 1] += self.window_size[1] - 1 68 | relative_coords[:, :, 0] *= 2 * self.window_size[1] -1 69 | self.relative_position_index = relative_coords.sum(-1) 70 | self.register_buffer("relative_position_index", self.relative_position_index) 71 | 72 | self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias) 73 | self.attn_drop = nn.Dropout(attn_drop) 74 | self.proj = nn.Linear(dim, dim) 75 | self.proj_drop = nn.Dropout(proj_drop) 76 | 77 | trunc_normal_(self.relative_position_bias_table, std=.02) 78 | self.softmax = nn.Softmax(dim=-1) 79 | 80 | def forward(self, x, mask=None): 81 | B_, N, C = x.shape 82 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 83 | q, k, v = qkv[0], qkv[1], qkv[2] 84 | 85 | q = q * self.scale 86 | attn = (q @ k.transpose(-2, -1)) 87 | 88 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 89 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) 90 | attn = attn + 
relative_position_bias.unsqueeze(0) 91 | 92 | if mask is not None: 93 | nW = mask.shape[0] 94 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 95 | attn = attn.view(-1, self.num_heads, N, N) 96 | attn = self.softmax(attn) 97 | 98 | else: 99 | attn = self.softmax(attn) 100 | 101 | attn = self.attn_drop(attn) 102 | 103 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 104 | x = self.proj(x) 105 | x = self.proj_drop(x) 106 | return x 107 | 108 | def extra_repr(self) -> str: 109 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 110 | 111 | def flops(self, N): 112 | flops = 0 113 | flops += N*self.dim*3*self.dim 114 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 115 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 116 | flops += N * self.dim + self.dim 117 | return flops 118 | 119 | class SwinTransformerBlock(nn.Module): 120 | def __init__(self, dim, input_resolution,num_heads,window_size=7, shift_size=0, 121 | mlp_ratio=4, qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 122 | act_layer=nn.GELU, norm_layer = nn.LayerNorm): 123 | super().__init__() 124 | self.dim = dim 125 | self.input_resolution = input_resolution 126 | self.num_heads = num_heads 127 | self.window_size = window_size 128 | self.shift_size = shift_size 129 | self.mlp_ratio = mlp_ratio 130 | if min(self.input_resolution) < self.window_size: 131 | self.shift_size = 0 132 | self.window_size = min(self.input_resolution) 133 | 134 | assert 0 <= self.shift_size < self.window_size,"shift_size must in 0-window_size" 135 | 136 | self.norm1 = norm_layer(dim) 137 | self.attn = WindowAttention( 138 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 139 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop 140 | ) 141 | self.drop_path= DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 142 | self.norm2 = norm_layer(dim) 143 | mlp_hidden_dim = int(dim * mlp_ratio) 144 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop) 145 | 146 | if self.shift_size > 0: 147 | H, W = self.input_resolution 148 | img_mask = torch.zeros((1, H, W, -1)) 149 | h_slices = (slice(0, -self.window_size), 150 | slice(-self.window_size, -self.shift_size), 151 | slice(-self.shift_size,None)) 152 | w_slices = (slice(0, -self.window_size), 153 | slice(-self.window_size, -self.shift_size), 154 | slice(-self.shift_size, None)) 155 | 156 | cnt = 0 157 | for h in h_slices: 158 | for w in w_slices: 159 | img_mask[:, h, w, :] = cnt 160 | cnt += 1 161 | 162 | mask_windows = window_partition(img_mask, self.window_size) 163 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 164 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 165 | self.attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 166 | else: 167 | self.attn_mask = None 168 | 169 | self.register_buffer("attn_mask", self.attn_mask) 170 | def forward(self, x): 171 | H, W = self.input_resolution 172 | B, L, C = x.shape 173 | assert L == H * W, "input feature has wrong size" 174 | 175 | shortcut = x 176 | x = self.norm1(x) 177 | x = x.view(B, H, W, C) 178 | 179 | if self.shift_size > 0: 180 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 181 | else: 182 | shifted_x = x 183 | 184 | # position window 185 | x_windows = window_partition(shifted_x, self.window_size) 186 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 187 | 188 | attn_windows = self.attn(x_windows, mask=self.attn_mask) 189 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 190 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) 191 | 192 | if self.shift_size > 0: 193 | x = torch.roll(shifted_x, 
shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 194 | else: 195 | x = shifted_x 196 | x = x.view(B, H*W, C) 197 | 198 | #FFN 199 | x = shortcut + self.drop_path(x) 200 | x = x = self.drop_path(self.mlp(self.norm2(x))) 201 | return x 202 | 203 | def extra_repr(self) -> str: 204 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 205 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 206 | 207 | def flops(self): 208 | flops = 0 209 | H, W = self.input_resolution 210 | 211 | # norm1 212 | flops += self.dim * H * W 213 | nW = H * W / self.window_size / self.window_size 214 | flops += nW * self.attn.flops(self.window_size * self.window_size) 215 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 216 | flops += self.dim * H * W 217 | return flops 218 | 219 | class PatchMerging(nn.Module): 220 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 221 | super(self).__init__() 222 | self.input_resolution = input_resolution 223 | self.dim = dim 224 | self.reduction = nn.Linear(4 * dim, 2* dim, bias = False) 225 | self.norm =norm_layer(4 * dim) 226 | 227 | def forward(self, x): 228 | H, W = self.input_resolution 229 | B, L, C = x.shape 230 | assert L == H * W, "input feature has wrong size" 231 | assert H % 2 == 0 and W % 2 == 0,f"x size ({H}*{W}) are nor even(偶数)" 232 | 233 | x = x.view(B, H, W, C) 234 | 235 | x0 = x[:, 0::2, 0::2, :] 236 | x1 = x[:, 1::2, 0::2, :] 237 | x2 = x[:, 0::2, 1::2, :] 238 | x3 = x[:, 1::2, 1::2, :] 239 | x = torch.cat([x0,x1,x2,x3], -1) 240 | x = x.view(B, -1, 4 * C) 241 | 242 | x = self.norm(x) 243 | x =self.reduction(x) 244 | 245 | return x 246 | 247 | def extra_repr(self) -> str: 248 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 249 | 250 | def flops(self): 251 | H,W = self.input_resolution 252 | flops = H * W * self.dim 253 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 254 | return 
flops 255 | 256 | class BasicLayer(nn.Module): 257 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 258 | mlp_ratio=4, qkv_bias = True, qk_scale=None, drop=0., attn_drop=0., 259 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 260 | super().__init__() 261 | self.dim = dim 262 | self.input_resolution = input_resolution 263 | self.depth = depth 264 | self.use_checkpoint=use_checkpoint 265 | 266 | self.blocks = nn.ModuleList([ 267 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 268 | num_heads=num_heads, window_size=window_size, 269 | shift_size=0 if (i % 2 == 0) else window_size // 2, 270 | mlp_ratio=mlp_ratio, 271 | qkv_bias=qkv_bias,qk_scale=qk_scale, 272 | drop=drop, attn_drop=attn_drop, 273 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 274 | norm_layer=norm_layer) 275 | for i in range(depth)]) 276 | 277 | if downsample is not None: 278 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 279 | else: 280 | self.downsample = None 281 | 282 | def forward(self, x): 283 | for blk in self.blocks: 284 | if self.use_checkpoint: 285 | x = checkpoint.checkpoint(blk, x) 286 | else: 287 | x = blk(x) 288 | if self.downsample is not None: 289 | x = self.downsample(x) 290 | return x 291 | 292 | def extra_repr(self) -> str: 293 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 294 | 295 | def flops(self): 296 | flops = 0 297 | for blk in self.blocks: 298 | flops += blk.flops() 299 | if self.downsample is not None: 300 | flops += self.downsample.flops() 301 | return flops 302 | 303 | class PatchEmbed(nn.Module): 304 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 305 | super(PatchEmbed, self).__init__() 306 | img_size = to_2tuple(img_size) 307 | patch_size = to_2tuple(patch_size) 308 | patchs_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 309 
| self.img_size = img_size 310 | self.patch_size = patch_size 311 | self.patchs_resolution = patchs_resolution 312 | self.num_patches = patchs_resolution[0] * patchs_resolution[1] 313 | 314 | self.in_chans = in_chans 315 | self.embed_dim = embed_dim 316 | 317 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 318 | if norm_layer is not None: 319 | self.norm = norm_layer(embed_dim) 320 | else: 321 | self.norm = None 322 | 323 | def forward(self, x): 324 | B, C, H, W = x.shape 325 | assert H == self.img_size[0] and W == self.img_size[1], \ 326 | f"Imput image size ({H}*{W}) doesn't match model({self.img_size[0]}*{self.img_size[1]})" 327 | x = self.proj(x).flatten(2).transpose(1, 2) 328 | if self.norm is not None: 329 | x = self.norm(x) 330 | return x 331 | 332 | def flops(self): 333 | H0, W0 = self.patchs_resolution 334 | flops = H0 * W0 * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 335 | if self.norm is not None: 336 | flops += H0 * W0 * self.embed_dim 337 | return flops 338 | 339 | class SwinTransformer(nn.Module): 340 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_class=1000, 341 | embed_dim = 96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 342 | window_size=7, mlp_ratio = 4, qkv_bias = True, qk_scale=None, 343 | drop_rate = 0., attn_drop_rate=0., drop_path_rate=0.1, 344 | norm_layer=nn.LayerNorm, ape = False, patch_norm=True, 345 | use_checkpoint=False, **kwargs): 346 | super(SwinTransformer, self).__init__() 347 | self.num_class = num_class 348 | self.num_layers = len(depths) 349 | self.embed_dim = embed_dim 350 | self.ape = ape 351 | self.patch_norm = patch_norm 352 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 353 | self.mlp_ratio = mlp_ratio 354 | 355 | #split image into non-overlapping patches 356 | self.patch_embed = PatchEmbed( 357 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 358 | norm_layer=norm_layer if 
self.patch_norm else None) 359 | num_patches = self.patch_embed.num_patches 360 | patches_resolution = self.patch_embed.patchs_resolution 361 | self.patches_resloution = patches_resolution 362 | 363 | # absolute position embedding 364 | if self.ape: 365 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 366 | trunc_normal_(self.absolute_pos_embed, std=.02) 367 | self.pos_drop = nn.Dropout(p=drop_rate) 368 | 369 | # stochastic depth 随机深度 370 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 371 | 372 | # build layer 373 | self.layers = nn.ModuleList() 374 | for i_layer in range(self.num_layers): 375 | layer = BasicLayer(dim = int(embed_dim * 2 ** i_layer), 376 | input_resolution = (patches_resolution[0] // (2 ** i_layer), 377 | patches_resolution[1] // (2 ** i_layer)), 378 | depth = depths[i_layer], 379 | num_heads = num_heads[i_layer], 380 | window_size = window_size, 381 | mlp_ratio = mlp_ratio, 382 | qkv_bias = qkv_bias,qk_scale=qk_scale, 383 | drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 384 | norm_layer = norm_layer, 385 | downsample = PatchMerging if (i_layer < self.num_layers -1) else None, 386 | use_checkpoint = use_checkpoint) 387 | self.layers.appens(layer) 388 | self.norm = norm_layer(self.num_features) 389 | self.avgpool = nn.AdaptiveAvgPool1d(1) 390 | self.apply(self._init_weights) 391 | 392 | def _int_weights(self,m): 393 | if isinstance(m, nn.Linear): 394 | trunc_normal_(m.weight, std=.02) 395 | if isinstance(m, nn.Linear) and m.bias is not None: 396 | nn.init.constant_(m.bias, 0) 397 | elif isinstance(m, nn.LayerNorm): 398 | nn.init.constant_(m.bias, 0) 399 | nn.init.constant_(m.weight, 1.0) 400 | 401 | @torch.jit.ignore 402 | def no_weight_decay(self): 403 | return {'absolute_pos_embed'} 404 | 405 | @torch.jit.ignore 406 | def no_weight_decay_keywords(self): 407 | return {'relative_position_bias_table'} 408 | 409 | def forward_features(self, x): 410 | x = self.patch_embed(x) 
411 | if self.ape: 412 | x += self.absolute_pos_embed 413 | x = self.pos_drop(x) 414 | 415 | for layer in self.layers: 416 | x = layer(x) 417 | x = self.norm(x) 418 | x = self.avgpool(x.transpose(1, 2)) 419 | x = torch.flatten(x, 1) 420 | return x 421 | 422 | def forward(self, x): 423 | x = self.forward_features(x) 424 | return x 425 | 426 | def flops(self): 427 | flops = 0 428 | flops += self.patch_embed.flops() 429 | for i, layer in enumerate(self.layers): 430 | flops += layer.flops() 431 | flops += self.num_features * self.patches_resloution[0] * self.patches_resloution[1] // (2 ** self.num_layers) 432 | flops += self.num_features * self.num_class 433 | return flops 434 | 435 | if __name__ == '__main__': 436 | a = np.random.random((1,3,224,224)) 437 | b = torch.Tensor(a).cuda() 438 | swintransformer = SwinTransformer().cuda() 439 | x = swintransformer(b) 440 | print(x.shape) 441 | 442 | 443 | -------------------------------------------------------------------------------- /networks/SwinNet/models/build.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: build.py 6 | @time: 2021/5/3 00:31 7 | """ 8 | 9 | # from models.back.Swin_Transformer import SwinTransformer 10 | 11 | from .swinNet import SwinTransformer 12 | 13 | 14 | def build_model(config): 15 | model_type = config.MODEL.TYPE 16 | if model_type == 'swin': 17 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 18 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 19 | in_chans=config.MODEL.SWIN.IN_CHANS, 20 | num_classes=config.MODEL.NUM_CLASSES, 21 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 22 | depths=config.MODEL.SWIN.DEPTHS, 23 | num_heads=config.MODEL.SWIN.NUM_HEADS, 24 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 25 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 26 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 27 | qk_scale=config.MODEL.SWIN.QK_SCALE, 28 | 
drop_rate=config.MODEL.DROP_RATE, 29 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 30 | ape=config.MODEL.SWIN.APE, 31 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 32 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 33 | else: 34 | raise NotImplementedError(f"Unkown model: {model_type}") 35 | 36 | return model -------------------------------------------------------------------------------- /networks/SwinNet/models/swinNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: SwinTransformer.py 6 | @time: 2021/5/6 6:13 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | import numpy as np 14 | import torch.nn.functional as F 15 | import os 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, has_bias=False): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=has_bias) 21 | 22 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 23 | return nn.Sequential( 24 | conv3x3(in_planes, out_planes, stride), 25 | nn.BatchNorm2d(out_planes), 26 | nn.ReLU(inplace=True), 27 | ) 28 | 29 | 30 | class Mlp(nn.Module): 31 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 32 | super().__init__() 33 | out_features = out_features or in_features 34 | hidden_features = hidden_features or in_features 35 | self.fc1 = nn.Linear(in_features, hidden_features) 36 | self.act = act_layer() 37 | self.fc2 = nn.Linear(hidden_features, out_features) 38 | self.drop = nn.Dropout(drop) 39 | 40 | def forward(self, x): 41 | x = self.fc1(x) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | 49 | def window_partition(x, window_size): 50 | """ 51 | Args: 
52 | x: (B, H, W, C) 53 | window_size (int): window size 54 | 55 | Returns: 56 | windows: (num_windows*B, window_size, window_size, C) 堆叠到一起形成一个长条 57 | """ 58 | B, H, W, C = x.shape 59 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 60 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 61 | return windows 62 | 63 | 64 | def window_reverse(windows, window_size, H, W): 65 | """ 66 | Args: 67 | windows: (num_windows*B, window_size, window_size, C) 68 | window_size (int): Window size 69 | H (int): Height of image 70 | W (int): Width of image 71 | 72 | Returns: 73 | x: (B, H, W, C) 74 | """ 75 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 76 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 77 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 78 | return x 79 | 80 | 81 | class WindowAttention(nn.Module): 82 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 83 | It supports both of shifted and non-shifted window. 84 | 85 | Args: 86 | dim (int): Number of input channels. 87 | window_size (tuple[int]): The height and width of the window. 88 | num_heads (int): Number of attention heads. 89 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 90 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 91 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 92 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 93 | """ 94 | 95 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 96 | 97 | super().__init__() 98 | self.dim = dim 99 | self.window_size = window_size # Wh, Ww 100 | self.num_heads = num_heads 101 | head_dim = dim // num_heads 102 | self.scale = qk_scale or head_dim ** -0.5 103 | 104 | # define a parameter table of relative position bias 105 | self.relative_position_bias_table = nn.Parameter( 106 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 107 | 108 | # get pair-wise relative position index for each token inside the window 109 | coords_h = torch.arange(self.window_size[0]) 110 | coords_w = torch.arange(self.window_size[1]) 111 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 112 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 113 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 114 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 115 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 116 | relative_coords[:, :, 1] += self.window_size[1] - 1 117 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 118 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 119 | self.register_buffer("relative_position_index", relative_position_index) 120 | 121 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 122 | self.attn_drop = nn.Dropout(attn_drop) 123 | self.proj = nn.Linear(dim, dim) 124 | self.proj_drop = nn.Dropout(proj_drop) 125 | 126 | trunc_normal_(self.relative_position_bias_table, std=.02) 127 | self.softmax = nn.Softmax(dim=-1) 128 | 129 | def forward(self, x, mask=None): 130 | """ 131 | Args: 132 | x: input features with shape of (num_windows*B, N, C) 133 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 134 | """ 135 | B_, N, C = x.shape 136 | qkv = 
self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 137 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 138 | 139 | q = q * self.scale 140 | attn = (q @ k.transpose(-2, -1)) 141 | 142 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 143 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 144 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 145 | attn = attn + relative_position_bias.unsqueeze(0) 146 | 147 | if mask is not None: 148 | nW = mask.shape[0] 149 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 150 | attn = attn.view(-1, self.num_heads, N, N) 151 | attn = self.softmax(attn) 152 | else: 153 | attn = self.softmax(attn) 154 | 155 | attn = self.attn_drop(attn) 156 | 157 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 158 | x = self.proj(x) 159 | x = self.proj_drop(x) 160 | return x 161 | 162 | def extra_repr(self) -> str: 163 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 164 | 165 | def flops(self, N): 166 | # calculate flops for 1 window with token length of N 167 | flops = 0 168 | # qkv = self.qkv(x) 169 | flops += N * self.dim * 3 * self.dim 170 | # attn = (q @ k.transpose(-2, -1)) 171 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 172 | # x = (attn @ v) 173 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 174 | # x = self.proj(x) 175 | flops += N * self.dim * self.dim 176 | return flops 177 | 178 | 179 | class SwinTransformerBlock(nn.Module): 180 | r""" Swin Transformer Block. 181 | 182 | Args: 183 | dim (int): Number of input channels. 184 | input_resolution (tuple[int]): Input resulotion. 185 | num_heads (int): Number of attention heads. 186 | window_size (int): Window size. 
187 | shift_size (int): Shift size for SW-MSA. 188 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 189 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 190 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 191 | drop (float, optional): Dropout rate. Default: 0.0 192 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 193 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 194 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 195 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 196 | """ 197 | 198 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 199 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 200 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 201 | super().__init__() 202 | self.dim = dim 203 | self.input_resolution = input_resolution 204 | self.num_heads = num_heads 205 | self.window_size = window_size 206 | self.shift_size = shift_size 207 | self.mlp_ratio = mlp_ratio 208 | if min(self.input_resolution) <= self.window_size: 209 | # if window size is larger than input resolution, we don't partition windows 210 | self.shift_size = 0 211 | self.window_size = min(self.input_resolution) 212 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 213 | 214 | self.norm1 = norm_layer(dim) 215 | 216 | self.attn = WindowAttention( 217 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 218 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 219 | 220 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 221 | self.norm2 = norm_layer(dim) 222 | mlp_hidden_dim = int(dim * mlp_ratio) 223 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 224 | 225 | if self.shift_size > 0: 226 | # calculate attention mask for SW-MSA 227 | H, W = self.input_resolution 228 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1---Important!!! 229 | h_slices = (slice(0, -self.window_size), 230 | slice(-self.window_size, -self.shift_size), 231 | slice(-self.shift_size, None)) 232 | w_slices = (slice(0, -self.window_size), 233 | slice(-self.window_size, -self.shift_size), 234 | slice(-self.shift_size, None)) 235 | cnt = 0 236 | for h in h_slices: 237 | for w in w_slices: 238 | img_mask[:, h, w, :] = cnt 239 | cnt += 1 240 | 241 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 242 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 243 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 244 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 245 | else: 246 | attn_mask = None 247 | 248 | self.register_buffer("attn_mask", attn_mask) 249 | 250 | def forward(self, x): 251 | # x: B,C,H,W 252 | H, W = self.input_resolution 253 | B, L, C = x.shape 254 | assert L == H * W, "input feature has wrong size" 255 | 256 | shortcut = x 257 | x = self.norm1(x) 258 | x = x.view(B, H, W, C) 259 | 260 | # cyclic shift 261 | if self.shift_size > 0: 262 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 263 | else: 264 | shifted_x = x 265 | 266 | # partition windows 267 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 268 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 269 | 270 | # W-MSA/SW-MSA 271 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, 
window_size*window_size, C 272 | 273 | # merge windows 274 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 275 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 276 | 277 | # reverse cyclic shift 278 | if self.shift_size > 0: 279 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 280 | else: 281 | x = shifted_x 282 | x = x.view(B, H * W, C) 283 | 284 | # FFN 285 | x = shortcut + self.drop_path(x) 286 | x = x + self.drop_path(self.mlp(self.norm2(x))) 287 | # print('FFN',x.shape) 288 | return x 289 | 290 | def extra_repr(self) -> str: 291 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 292 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 293 | 294 | def flops(self): 295 | flops = 0 296 | H, W = self.input_resolution 297 | # norm1 298 | flops += self.dim * H * W 299 | # W-MSA/SW-MSA 300 | nW = H * W / self.window_size / self.window_size 301 | flops += nW * self.attn.flops(self.window_size * self.window_size) 302 | # mlp 303 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 304 | # norm2 305 | flops += self.dim * H * W 306 | return flops 307 | 308 | 309 | class PatchMerging(nn.Module): 310 | r""" Patch Merging Layer. 311 | 312 | Args: 313 | input_resolution (tuple[int]): Resolution of input feature. 314 | dim (int): Number of input channels. 315 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 316 | """ 317 | 318 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 319 | super().__init__() 320 | self.input_resolution = input_resolution 321 | self.dim = dim 322 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 323 | self.norm = norm_layer(4 * dim) 324 | 325 | def forward(self, x): 326 | """ 327 | x: B, H*W, C 328 | """ 329 | H, W = self.input_resolution 330 | B, L, C = x.shape 331 | assert L == H * W, "input feature has wrong size" 332 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 333 | 334 | x = x.view(B, H, W, C) 335 | 336 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 337 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 338 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 339 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 340 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 341 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 342 | 343 | x = self.norm(x) 344 | x = self.reduction(x) 345 | 346 | return x 347 | 348 | def extra_repr(self) -> str: 349 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 350 | 351 | def flops(self): 352 | H, W = self.input_resolution 353 | flops = H * W * self.dim 354 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 355 | return flops 356 | 357 | 358 | class BasicLayer(nn.Module): 359 | """ A basic Swin Transformer layer for one stage. 360 | 361 | Args: 362 | dim (int): Number of input channels. 363 | input_resolution (tuple[int]): Input resolution. 364 | depth (int): Number of blocks. 365 | num_heads (int): Number of attention heads. 366 | window_size (int): Local window size. 367 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 368 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 369 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 370 | drop (float, optional): Dropout rate. Default: 0.0 371 | attn_drop (float, optional): Attention dropout rate. 
Default: 0.0 372 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 373 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 374 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 375 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 376 | """ 377 | 378 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 379 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 380 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 381 | 382 | super().__init__() 383 | self.dim = dim 384 | self.input_resolution = input_resolution 385 | self.depth = depth 386 | self.use_checkpoint = use_checkpoint 387 | 388 | # build blocks 389 | self.blocks = nn.ModuleList([ 390 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 391 | num_heads=num_heads, window_size=window_size, 392 | shift_size=0 if (i % 2 == 0) else window_size // 2, 393 | mlp_ratio=mlp_ratio, 394 | qkv_bias=qkv_bias, qk_scale=qk_scale, 395 | drop=drop, attn_drop=attn_drop, 396 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 397 | norm_layer=norm_layer) 398 | for i in range(depth)]) 399 | 400 | # patch merging layer 401 | if downsample is not None: 402 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 403 | else: 404 | self.downsample = None 405 | 406 | def forward(self, x): 407 | for blk in self.blocks: 408 | if self.use_checkpoint: 409 | x = checkpoint.checkpoint(blk, x) 410 | else: 411 | x = blk(x) 412 | if self.downsample is not None: 413 | x = self.downsample(x) 414 | return x 415 | 416 | def extra_repr(self) -> str: 417 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 418 | 419 | def flops(self): 420 | flops = 0 421 | for blk in self.blocks: 422 | flops += blk.flops() 423 | if self.downsample is not None: 424 
| flops += self.downsample.flops() 425 | return flops 426 | 427 | 428 | class PatchEmbed(nn.Module): 429 | r""" Image to Patch Embedding 430 | First step !!!!!!主要作用在于将要输入到SwinTransNet的特征图下采样4倍并将通道变成初始的embed_dim 431 | Args: 432 | img_size (int): Image size. Default: 224. 433 | patch_size (int): Patch token size. Default: 4. 434 | in_chans (int): Number of input image channels. Default: 3. 435 | embed_dim (int): Number of linear projection output channels. Default: 96. 436 | patches_resolution: 以每个patch大小为单位的分辨率 437 | num_patches: patch 的数量 438 | norm_layer (nn.Module, optional): Normalization layer. Default: None 439 | """ 440 | 441 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 442 | super().__init__() 443 | img_size = to_2tuple(img_size) 444 | patch_size = to_2tuple(patch_size) 445 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 446 | self.img_size = img_size 447 | self.patch_size = patch_size 448 | self.patches_resolution = patches_resolution 449 | self.num_patches = patches_resolution[0] * patches_resolution[1] 450 | 451 | self.in_chans = in_chans # define in_chans == 3 452 | self.embed_dim = embed_dim # Swin-B.embed_dim ==128,(T is 96) 453 | 454 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # dim 3->128 455 | if norm_layer is not None: 456 | self.norm = norm_layer(embed_dim) 457 | else: 458 | self.norm = None 459 | 460 | def forward(self, x): 461 | B, C, H, W = x.shape 462 | # FIXME look at relaxing size constraints,尺寸固定,下有断言 463 | assert H == self.img_size[0] and W == self.img_size[1], \ 464 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
465 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 466 | if self.norm is not None: 467 | x = self.norm(x) 468 | return x 469 | 470 | def flops(self): 471 | Ho, Wo = self.patches_resolution 472 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 473 | if self.norm is not None: 474 | flops += Ho * Wo * self.embed_dim 475 | return flops 476 | 477 | 478 | class SwinTransformer(nn.Module): 479 | r""" Swin Transformer 480 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 481 | https://arxiv.org/pdf/2103.14030 482 | 483 | Args: 484 | img_size (int | tuple(int)): Input image size. Default 224 485 | patch_size (int | tuple(int)): Patch size. Default: 4 486 | in_chans (int): Number of input image channels. Default: 3 487 | num_classes (int): Number of classes for classification head. Default: 1000 488 | embed_dim (int): Patch embedding dimension. Default: 96 489 | depths (tuple(int)): Depth of each Swin Transformer layer. 490 | num_heads (tuple(int)): Number of attention heads in different layers. 491 | window_size (int): Window size. Default: 7 492 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 493 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 494 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 495 | drop_rate (float): Dropout rate. Default: 0 496 | attn_drop_rate (float): Attention dropout rate. Default: 0 497 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 498 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 499 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 500 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 501 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 502 | """ 503 | 504 | def __init__(self, img_size=384, patch_size=4, in_chans=3, num_classes=1000, 505 | embed_dim=128, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 506 | window_size=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, 507 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 508 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 509 | use_checkpoint=False, **kwargs): 510 | super().__init__() 511 | 512 | self.num_classes = num_classes 513 | self.num_layers = len(depths) 514 | self.embed_dim = embed_dim 515 | self.ape = ape 516 | self.patch_norm = patch_norm 517 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) # TODO 518 | self.mlp_ratio = mlp_ratio 519 | 520 | # split image into non-overlapping patches 521 | self.patch_embed = PatchEmbed( 522 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 523 | norm_layer=norm_layer if self.patch_norm else None) 524 | num_patches = self.patch_embed.num_patches 525 | patches_resolution = self.patch_embed.patches_resolution 526 | self.patches_resolution = patches_resolution 527 | 528 | # absolute position embedding 529 | if self.ape: 530 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 531 | trunc_normal_(self.absolute_pos_embed, std=.02) 532 | 533 | self.pos_drop = nn.Dropout(p=drop_rate) 534 | 535 | # stochastic depth 536 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 537 | 538 | # build layers 539 | self.layers = nn.ModuleList() 540 | for i_layer in range(self.num_layers): 541 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 542 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 543 | patches_resolution[1] // (2 ** i_layer)), 544 | depth=depths[i_layer], 545 | num_heads=num_heads[i_layer], 546 | window_size=window_size, 547 | mlp_ratio=self.mlp_ratio, 548 | qkv_bias=qkv_bias, qk_scale=qk_scale, 549 | drop=drop_rate, 
attn_drop=attn_drop_rate, 550 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 551 | norm_layer=norm_layer, 552 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 553 | use_checkpoint=use_checkpoint) 554 | self.layers.append(layer) 555 | 556 | self.norm = norm_layer(self.num_features) 557 | # self.avgpool = nn.AdaptiveAvgPool1d(1) 558 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 559 | # self.swin_decoder = Swin_decoder() 560 | self.apply(self._init_weights) 561 | 562 | def _init_weights(self, m): 563 | if isinstance(m, nn.Linear): 564 | trunc_normal_(m.weight, std=.02) 565 | if isinstance(m, nn.Linear) and m.bias is not None: 566 | nn.init.constant_(m.bias, 0) 567 | elif isinstance(m, nn.LayerNorm): 568 | nn.init.constant_(m.bias, 0) 569 | nn.init.constant_(m.weight, 1.0) 570 | 571 | @torch.jit.ignore 572 | def no_weight_decay(self): 573 | return {'absolute_pos_embed'} 574 | 575 | @torch.jit.ignore 576 | def no_weight_decay_keywords(self): 577 | return {'relative_position_bias_table'} 578 | 579 | def forward_features(self, x): 580 | layer_features = [] 581 | x = self.patch_embed(x) 582 | B,L,C = x.shape 583 | layer_features.append(x.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous()) 584 | # layer_features.append(x) 585 | if self.ape: 586 | x = x + self.absolute_pos_embed 587 | x = self.pos_drop(x) 588 | 589 | for layer in self.layers: 590 | x = layer(x) 591 | B, L, C = x.shape 592 | # print('x:', x.shape) 593 | xl = x.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous() 594 | # print('xl',xl.shape) 595 | layer_features.append(xl) 596 | x = self.norm(x) # B L C 597 | B, L, C = x.shape 598 | # x = self.avgpool(x.ranspose(1, 2)) # B C 1 599 | x = x.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous() 600 | # x = torch.flatten(x, 1) 601 | layer_features[-1] = x 602 | # self.swin_decoder(layer_features) 
603 | 604 | 605 | return layer_features 606 | 607 | def forward(self, x): 608 | outs = self.forward_features(x) 609 | # print("len outs",len(outs)) 610 | # sal, edge = self.swin_decoder(outs) 611 | return outs 612 | 613 | def flops(self): 614 | flops = 0 615 | flops += self.patch_embed.flops() 616 | for i, layer in enumerate(self.layers): 617 | flops += layer.flops() 618 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 619 | flops += self.num_features * self.num_classes 620 | return flops 621 | 622 | class SwinNet(nn.Module): 623 | def __init__(self): 624 | super(SwinNet, self).__init__() 625 | 626 | self.rgb_swin = SwinTransformer(embed_dim=128, depths=[2,2,18,2], num_heads=[4,8,16,32]) 627 | self.depth_swin = SwinTransformer(embed_dim=128, depths=[2,2,18,2], num_heads=[4,8,16,32]) 628 | self.fuse_enhance1 = fuse_enhance(1024) 629 | self.fuse_enhance2 = fuse_enhance(512) 630 | self.fuse_enhance3 = fuse_enhance(256) 631 | self.fuse_enhance4 = fuse_enhance(128) 632 | self.up2 = nn.UpsamplingBilinear2d(scale_factor = 2) 633 | self.up4 = nn.UpsamplingBilinear2d(scale_factor = 4) 634 | self.conv2048_1024 = conv3x3_bn_relu(2048, 1024) 635 | self.conv1024_512 = conv3x3_bn_relu(1024, 512) 636 | self.conv512_256 = conv3x3_bn_relu(512, 256) 637 | self.conv256_32 = conv3x3_bn_relu(256, 32) 638 | self.conv64_1 = conv3x3(64, 1) 639 | 640 | self.edge_layer = Edge_Module() 641 | self.edge_feature = conv3x3_bn_relu(1, 32) 642 | self.fuse_edge_sal = conv3x3(32, 1) 643 | self.up_edge = nn.Sequential( 644 | nn.UpsamplingBilinear2d(scale_factor = 4), 645 | conv3x3(32, 1) 646 | ) 647 | 648 | self.relu = nn.ReLU(True) 649 | def forward(self,x ,d): 650 | rgb_list = self.rgb_swin(x) 651 | depth_list = self.depth_swin(d) 652 | 653 | r4 = rgb_list[0] 654 | r3 = rgb_list[1] 655 | r2 = rgb_list[2] 656 | r1 = rgb_list[3] 657 | d4 = depth_list[0] 658 | d3 = depth_list[1] 659 | d2 = depth_list[2] 660 | d1 = depth_list[3] 661 | 662 
| fr1, fd1 = self.fuse_enhance1(r1, d1) 663 | fr2, fd2 = self.fuse_enhance2(r2, d2) 664 | fr3, fd3 = self.fuse_enhance3(r3, d3) 665 | fr4, fd4 = self.fuse_enhance4(r4, d4) 666 | 667 | mul_fea1 = fr1 * fd1 668 | add_fea1 = fr1 + fd1 669 | fuse_fea1 = torch.cat((mul_fea1, add_fea1), dim=1) 670 | fuse_fea1 = self.up2(fuse_fea1) 671 | fuse_fea1 = self.conv2048_1024(fuse_fea1) 672 | 673 | mul_fea2 = fr2 * fd2 674 | add_fea2 = fr2 + fd2 675 | fuse_fea2 = torch.cat((mul_fea2, add_fea2), dim=1) 676 | 677 | fuse_fea2 = fuse_fea2 + fuse_fea1 678 | fuse_fea2 = self.up2(fuse_fea2) 679 | fuse_fea2 = self.conv1024_512(fuse_fea2) 680 | 681 | mul_fea3 = fr3 * fd3 682 | add_fea3 = fr3 + fd3 683 | fuse_fea3 = torch.cat((mul_fea3, add_fea3), dim=1) 684 | fuse_fea3 = fuse_fea3 + fuse_fea2 685 | fuse_fea3 = self.up2(fuse_fea3) 686 | fuse_fea3 = self.conv512_256(fuse_fea3) 687 | 688 | mul_fea4 = fr4 * fd4 689 | add_fea4 = fr4 + fd4 690 | fuse_fea4 = torch.cat((mul_fea4, add_fea4), dim=1) 691 | fuse_fea4 = fuse_fea4 + fuse_fea3 692 | 693 | edge_map = self.edge_layer(d4, d3, d2) 694 | edge_feature = self.edge_feature(edge_map) 695 | end_sal = self.conv256_32(fuse_fea4) 696 | up_edge = self.up_edge(end_sal) 697 | out = self.relu(torch.cat((end_sal, edge_feature), dim=1)) 698 | out = self.up4(out) 699 | sal_out = self.conv64_1(out) 700 | 701 | return sal_out, up_edge 702 | 703 | def load_pre(self, pre_model): 704 | self.rgb_swin.load_state_dict(torch.load(pre_model)['model'],strict=False) 705 | print(f"RGB SwinTransformer loading pre_model ${pre_model}") 706 | self.depth_swin.load_state_dict(torch.load(pre_model)['model'], strict=False) 707 | print(f"Depth SwinTransformer loading pre_model ${pre_model}") 708 | 709 | 710 | class fuse_enhance(nn.Module): 711 | def __init__(self, infeature): 712 | super(fuse_enhance, self).__init__() 713 | self.depth_channel_attention = ChannelAttention(infeature) 714 | self.rgb_channel_attention = ChannelAttention(infeature) 715 | self.rd_spatial_attention = 
SpatialAttention() 716 | self.rgb_spatial_attention = SpatialAttention() 717 | self.depth_spatial_attention = SpatialAttention() 718 | 719 | def forward(self,r,d): 720 | assert r.shape == d.shape,"rgb and depth should have same size" 721 | mul_fuse = r * d 722 | sa = self.rd_spatial_attention(mul_fuse) 723 | r_f = r * sa 724 | d_f = d * sa 725 | r_ca = self.rgb_channel_attention(r_f) 726 | d_ca = self.depth_channel_attention(d_f) 727 | 728 | r_out = r * r_ca 729 | d_out = d * d_ca 730 | return r_out, d_out 731 | 732 | class CALayer(nn.Module): 733 | def __init__(self, channel, reduction=16): 734 | super(CALayer, self).__init__() 735 | # global average pooling: feature --> point 736 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 737 | # feature channel downscale and upscale --> channel weight 738 | self.conv_du = nn.Sequential( 739 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 740 | nn.ReLU(inplace=True), 741 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 742 | nn.Sigmoid() 743 | ) 744 | 745 | def forward(self, x): 746 | y = self.avg_pool(x) 747 | y = self.conv_du(y) 748 | return x * y 749 | 750 | ## Residual Channel Attention Block (RCAB) 751 | class RCAB(nn.Module): 752 | def __init__( 753 | self, n_feat, kernel_size=3, reduction=16, 754 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 755 | 756 | super(RCAB, self).__init__() 757 | modules_body = [] 758 | for i in range(2): 759 | modules_body.append(self.default_conv(n_feat, n_feat, kernel_size, bias=bias)) 760 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 761 | if i == 0: modules_body.append(act) 762 | modules_body.append(CALayer(n_feat, reduction)) 763 | self.body = nn.Sequential(*modules_body) 764 | self.res_scale = res_scale 765 | 766 | def default_conv(self, in_channels, out_channels, kernel_size, bias=True): 767 | return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size // 2), bias=bias) 768 | 769 | def forward(self, x): 770 | res = 
self.body(x) 771 | #res = self.body(x).mul(self.res_scale) 772 | res += x 773 | return res 774 | 775 | class Edge_Module(nn.Module): 776 | def __init__(self, in_fea=[128, 256, 512], mid_fea=32): 777 | super(Edge_Module, self).__init__() 778 | self.relu = nn.ReLU(inplace=True) 779 | self.conv2 = nn.Conv2d(in_fea[0], mid_fea, 1) 780 | self.conv4 = nn.Conv2d(in_fea[1], mid_fea, 1) 781 | self.conv5 = nn.Conv2d(in_fea[2], mid_fea, 1) 782 | self.conv5_2 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1) 783 | self.conv5_4 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1) 784 | self.conv5_5 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1) 785 | self.up2 = nn.UpsamplingBilinear2d(scale_factor=2) 786 | self.classifer = nn.Conv2d(mid_fea * 3, 1, kernel_size=3, padding=1) 787 | self.rcab = RCAB(mid_fea * 3) 788 | 789 | def forward(self, x2, x4, x5): 790 | _, _, h, w = x2.size() 791 | edge2_fea = self.relu(self.conv2(x2)) 792 | edge2 = self.relu(self.conv5_2(edge2_fea)) 793 | edge4_fea = self.relu(self.conv4(x4)) 794 | edge4 = self.relu(self.conv5_4(edge4_fea)) 795 | edge5_fea = self.relu(self.conv5(x5)) 796 | edge5 = self.relu(self.conv5_5(edge5_fea)) 797 | 798 | edge4 = F.interpolate(edge4, size=(h, w), mode='bilinear', align_corners=True) 799 | edge5 = F.interpolate(edge5, size=(h, w), mode='bilinear', align_corners=True) 800 | 801 | edge = torch.cat([edge2, edge4, edge5], dim=1) 802 | edge = self.rcab(edge) 803 | edge = self.classifer(edge) 804 | return edge 805 | 806 | class ChannelAttention(nn.Module): 807 | def __init__(self, in_planes, ratio=16): 808 | super(ChannelAttention, self).__init__() 809 | 810 | self.max_pool = nn.AdaptiveMaxPool2d(1) 811 | 812 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 813 | self.relu1 = nn.ReLU() 814 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 815 | 816 | self.sigmoid = nn.Sigmoid() 817 | 818 | def forward(self, x): 819 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 820 | out = max_out 821 | return 
self.sigmoid(out) 822 | 823 | 824 | class SpatialAttention(nn.Module): 825 | def __init__(self, kernel_size=7): 826 | super(SpatialAttention, self).__init__() 827 | 828 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 829 | padding = 3 if kernel_size == 7 else 1 830 | 831 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False) 832 | self.sigmoid = nn.Sigmoid() 833 | 834 | def forward(self, x): 835 | max_out, _ = torch.max(x, dim=1, keepdim=True) 836 | x = max_out 837 | x = self.conv1(x) 838 | return self.sigmoid(x) 839 | 840 | 841 | if __name__ == '__main__': 842 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 843 | pre_path = '../Pre_train/swin_base_patch4_window7_224.pth' 844 | a = torch.randn(1,3,224,224) 845 | b = torch.randn(1,3,224,224) 846 | # net = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]) 847 | net = SwinNet() 848 | out = net(a,b) 849 | # print(out.shape) 850 | for i in out: 851 | print(i.shape) 852 | # model = SwinNet() 853 | # out = model(a,b) 854 | # for i in out: 855 | # print(i.shape) 856 | # model.load_state_dict(torch.load(r"D:\tanyacheng\Experiments\SOD\Transformer_Saliency\Swin\Saliency\Swin-Transformer-Saliency_v19\SwinTransNet_RGBD_cpts\SwinTransNet_epoch_best.pth"), strict=True) 857 | -------------------------------------------------------------------------------- /networks/SwinNet/optimizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: optimizer.py 6 | @time: 2021/5/6 15:58 7 | """ 8 | 9 | from torch import optim as optim 10 | 11 | 12 | def build_optimizer(config, model): 13 | """ 14 | Build optimizer, set weight decay of normalization to 0 by default. 
15 | """ 16 | skip = {} 17 | skip_keywords = {} 18 | if hasattr(model, 'no_weight_decay'): 19 | skip = model.no_weight_decay() 20 | if hasattr(model, 'no_weight_decay_keywords'): 21 | skip_keywords = model.no_weight_decay_keywords() 22 | parameters = set_weight_decay(model, skip, skip_keywords) 23 | 24 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 25 | optimizer = None 26 | if opt_lower == 'sgd': 27 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 28 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 29 | elif opt_lower == 'adamw': 30 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 31 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 32 | 33 | return optimizer 34 | 35 | 36 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 37 | has_decay = [] 38 | no_decay = [] 39 | 40 | for name, param in model.named_parameters(): 41 | if not param.requires_grad: 42 | continue # frozen weights 43 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 44 | check_keywords_in_name(name, skip_keywords): 45 | no_decay.append(param) 46 | # print(f"{name} has no weight decay") 47 | else: 48 | has_decay.append(param) 49 | return [{'params': has_decay}, 50 | {'params': no_decay, 'weight_decay': 0.}] 51 | 52 | 53 | def check_keywords_in_name(name, keywords=()): 54 | isin = False 55 | for keyword in keywords: 56 | if keyword in name: 57 | isin = True 58 | return isin -------------------------------------------------------------------------------- /networks/SwinNet/options.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: options.py 6 | @time: 2021/5/16 14:52 7 | """ 8 | 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--epoch', 
type=int, default=300, help='epoch number') 13 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') 14 | parser.add_argument('--batchsize', type=int, default=3, help='training batch size') 15 | parser.add_argument('--trainsize', type=int, default=384, help='training dataset size') 16 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 17 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 18 | parser.add_argument('--decay_epoch', type=int, default=100, help='every n epochs decay learning rate') 19 | # parser.add_argument('--load', type=str, default='./pre_train/swin_base_patch4_window7_224.pth', help='train from checkpoints') 20 | parser.add_argument('--load', type=str, default='./Pre_train/swin_base_patch4_window12_384_22k.pth', help='train from checkpoints') 21 | parser.add_argument('--gpu_id', type=str, default='1', help='train use gpu') 22 | 23 | # RGB-D Datasets 24 | parser.add_argument('--rgb_root', type=str, default='./datasets/train/RGB/', help='the training RGB images root') 25 | parser.add_argument('--depth_root', type=str, default='./datasets/train/Depth/', help='the training Depth images root') 26 | parser.add_argument('--gt_root', type=str, default='./datasets/train/GT/', help='the training GT images root') 27 | parser.add_argument('--edge_root', type=str, default='./datasets/train/Edge/', help='the training Edge images root') 28 | parser.add_argument('--val_rgb_root', type=str, default='./datasets/RGB-D/validation/RGB/', help='the validation RGB images root') 29 | parser.add_argument('--val_depth_root', type=str, default='./datasets/RGB-D/validation/Depth/', help='the validation Depth images root') 30 | parser.add_argument('--val_gt_root', type=str, default='./datasets/RGB-D/validation/GT/', help='the test validation GT images root') 31 | 32 | # RGB-T Datasets 33 | """ 34 | parser.add_argument('--rgb_root', type=str, 
default='./datasets/RGB-T/train/RGB/', help='the training RGB images root') 35 | parser.add_argument('--depth_root', type=str, default='./datasets/RGB-T/train/T/', help='the training Thermal images root') 36 | parser.add_argument('--gt_root', type=str, default='./datasets/RGB-T/train/GT/', help='the training GT images root') 37 | parser.add_argument('--edge_root', type=str, default='./datasets/RGB-T/train/Edge/', help='the training Edge images root') 38 | 39 | parser.add_argument('--val_rgb_root', type=str, default='./datasets/RGB-T/validation/RGB/', help='the validation RGB images root') 40 | parser.add_argument('--val_depth_root', type=str, default='./datasets/RGB-T/validation/T/', help='the validation Thermal images root') 41 | parser.add_argument('--val_gt_root', type=str, default='./datasets/RGB-T/validation/GT/', help='the validation GT images root') 42 | """ 43 | parser.add_argument('--save_path', type=str, default='./cpts/', help='the path to save models and logs') 44 | 45 | opt = parser.parse_args() -------------------------------------------------------------------------------- /networks/SwinNet/test_gray_to_rgb.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import cv2 3 | # # def rgb_loader(path): 4 | # # with open(path, 'rb') as f: 5 | # # img = Image.open(f) 6 | # # return img.convert('RGB') 7 | root = '/home/sunfan/Desktop/colorization-master/colorization-master/imgs/ansel_adams3.jpg' 8 | save_path = '/home/sunfan/Desktop/save1.png' 9 | img_gray = cv2.imread(root, flags = 0) 10 | img2 = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) 11 | cv2.imwrite(save_path, img2) 12 | cv2.imshow('test', img2) 13 | # # 14 | # # 15 | # # cv2.waitKey(0) 16 | # # 17 | # # cv2.destroyAllWindows() 18 | # 19 | # # img = Image.open(root) 20 | # # rgb = img.convert('RGB') 21 | # 22 | # # cv2.imwrite(save_path,img2) 23 | # 24 | # 25 | # 26 | # import numpy as np 27 | # import cv2 28 | # root = 
'/home/sunfan/Desktop/colorization-master/colorization-master/imgs/ansel_adams3.jpg' 29 | # root = '/home/sunfan/Downloads/newdata/train/DUT_NJUNLPR/depth/1_02-02-40.png' 30 | # 31 | # save_path = '/home/sunfan/Desktop/c.png' 32 | # src_gray = cv2.imread(root,flags = 0) 33 | # print(src_gray.shape) 34 | # src = cv2.cvtColor(src_gray, cv2.COLOR_GRAY2BGR) 35 | # print(src.shape) 36 | # # RGB在opencv中存储为BGR的顺序,数据结构为一个3D的numpy.array,索引的顺序是行,列,通道: 37 | # B = src[:,:,0] 38 | # G = src[:,:,1] 39 | # R = src[:,:,2] 40 | # # 灰度g=p*R+q*G+t*B(其中p=0.2989,q=0.5870,t=0.1140),于是B=(g-p*R-q*G)/t。于是我们只要保留R和G两个颜色分量,再加上灰度图g,就可以回复原来的RGB图像。 41 | # g = src_gray[:] 42 | # 43 | # p = 1; q = 1; t = 1 44 | # B_new = (g-p*R-q*G)/t 45 | # G_new1 = (g-p*R+q*G)/t 46 | # R_new = R 47 | # 48 | # # B_new = np.uint8(B_new) 49 | # src_new = np.zeros((src.shape)).astype("uint8") 50 | # src_new[:,:,0] = B_new 51 | # src_new[:,:,1] = G_new1 52 | # src_new[:,:,2] = R_new 53 | # # 显示图像 54 | # cv2.imshow("input", src_gray) 55 | # cv2.imshow("output", src) 56 | # cv2.imshow("result", src_new) 57 | # cv2.imwrite(save_path, src_new) 58 | # cv2.waitKey(0) 59 | # cv2.destroyAllWindows() 60 | # 61 | 62 | # Show images 63 | -------------------------------------------------------------------------------- /networks/SwinNet/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: caigentan@AnHui University 4 | @software: PyCharm 5 | @file: utils.py 6 | @time: 2021/5/6 16:16 7 | """ 8 | def clip_gradient(optimizer, grad_clip): 9 | for group in optimizer.param_groups: 10 | for param in group['params']: 11 | if param.grad is not None: 12 | param.grad.data.clamp_(-grad_clip, grad_clip) 13 | 14 | 15 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 16 | decay = decay_rate ** (epoch // decay_epoch) 17 | for param_group in optimizer.param_groups: 18 | param_group['lr'] = decay*init_lr 19 | lr=param_group['lr'] 20 | 
return lr 21 | -------------------------------------------------------------------------------- /networks/Wavenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from networks.wavemlp import WaveMLP_S 4 | from timm.models.layers import DropPath 5 | from pytorch_wavelets import DTCWTForward, DTCWTInverse 6 | from torch.nn.functional import kl_div 7 | 8 | class BasicConv2d(nn.Module): 9 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 10 | super(BasicConv2d, self).__init__() 11 | self.conv = nn.Conv2d(in_planes, out_planes, 12 | kernel_size=kernel_size, stride=stride, 13 | padding=padding, dilation=dilation, bias=False) 14 | self.bn = nn.BatchNorm2d(out_planes) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.relu(x) 21 | return x 22 | 23 | class TransBasicConv2d(nn.Module): 24 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1,output_padding=0, bias=False): 25 | super(TransBasicConv2d, self).__init__() 26 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes, 27 | kernel_size=kernel_size, stride=stride, 28 | padding=padding,output_padding= output_padding, dilation=dilation, bias=bias) 29 | self.bn = nn.BatchNorm2d(out_planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.inch = in_planes 32 | def forward(self, x): 33 | 34 | x = self.Deconv(x) 35 | x = self.bn(x) 36 | x = self.relu(x) 37 | return x 38 | 39 | 40 | class Mlp(nn.Module): 41 | def __init__(self, in_features=64, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 42 | super().__init__() 43 | out_features = out_features or in_features 44 | hidden_features = hidden_features or in_features 45 | self.fc1 = nn.Linear(in_features, hidden_features) 46 | self.act = act_layer() 47 | self.fc2 = nn.Linear(hidden_features, out_features) 48 | self.drop 
= nn.Dropout(drop) 49 | 50 | def forward(self, x): 51 | # print('x',x.shape) 52 | x = self.fc1(x) 53 | # print('fc',x.shape) 54 | x = self.act(x) 55 | x = self.drop(x) 56 | x = self.fc2(x) 57 | x = self.drop(x) 58 | return x 59 | 60 | class Mlp_wave(nn.Module): 61 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 62 | super().__init__() 63 | 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | self.act = act_layer() 67 | self.drop = nn.Dropout(drop) 68 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1) 69 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1) 70 | 71 | def forward(self, x): 72 | x = self.fc1(x) 73 | x = self.act(x) 74 | x = self.drop(x) 75 | x = self.fc2(x) 76 | x = self.drop(x) 77 | return x 78 | 79 | 80 | class BAB_Decoder(nn.Module): 81 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2): 82 | super(BAB_Decoder, self).__init__() 83 | 84 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1) 85 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1) 86 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1) 87 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2) 88 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1) 89 | self.conv_fuse = BasicConv2d(channel_2*3, channel_3, 3, padding=1) 90 | 91 | def forward(self, x): 92 | x1 = self.conv1(x) 93 | x1_dila = self.conv1_Dila(x1) 94 | 95 | x2 = self.conv2(x1) 96 | x2_dila = self.conv2_Dila(x2) 97 | 98 | x3 = self.conv3(x2) 99 | 100 | x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1)) 101 | 102 | return x_fuse 103 | 104 | class FFT(nn.Module): 105 | def __init__(self,inchannel,outchannel): 106 | super().__init__() 107 | self.DWT = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b') 108 | # self.DWT 
=DTCWTForward(J=3, include_scale=True) 109 | self.IWT = DTCWTInverse(biort='near_sym_b', qshift='qshift_b') 110 | self.conv1 = BasicConv2d(outchannel, outchannel) 111 | self.conv2 = BasicConv2d(inchannel, outchannel) 112 | self.conv3 = BasicConv2d(outchannel, outchannel) 113 | self.change = TransBasicConv2d(outchannel, outchannel) 114 | 115 | def forward(self, x, y): 116 | y = self.conv2(y) 117 | Xl, Xh = self.DWT(x) 118 | Yl, Yh = self.DWT(y) 119 | x_y = self.conv1(Xl) + self.conv1(Yl) 120 | 121 | x_m = self.IWT((x_y, Xh)) 122 | y_m = self.IWT((x_y, Yh)) 123 | 124 | out = self.conv3(x_m + y_m) 125 | return out 126 | 127 | class PATM_BAB(nn.Module): 128 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2): 129 | super().__init__() 130 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1) 131 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1) 132 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1) 133 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2) 134 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1) 135 | self.conv_fuse = BasicConv2d(channel_2 *2, channel_3, 3, padding=1) 136 | self.drop = nn.Dropout(0.5) 137 | self.conv_last=TransBasicConv2d(channel_3, channel_3, kernel_size=2, stride=2, 138 | padding=0, dilation=1, bias=False) 139 | def forward(self, x): 140 | x1 = self.conv1(x) 141 | x1_dila = self.conv1_Dila(x1) 142 | 143 | x2 = self.conv2(x1) 144 | x2_dila = self.conv2_Dila(x2) 145 | 146 | x3 = self.conv3(x2) 147 | x1_dila = torch.cat([x1_dila * torch.cos(x1_dila), x1_dila * torch.sin(x1_dila)], dim=1) 148 | x2_dila = torch.cat([x2_dila * torch.cos(x2_dila), x2_dila * torch.sin(x2_dila)], dim=1) 149 | x3 = torch.cat([x3 * torch.cos(x3), x3 * torch.sin(x3)], dim=1) 150 | # print('x1_dila + x2_dila+x3',x1_dila.shape) 151 | x_fuse = self.conv_fuse(x1_dila + x2_dila +x3) 152 | # x_fuse = 
self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1)) 153 | # print('x_f',x_fuse.shape) 154 | x_fuse= self.drop(x_fuse) 155 | # print() 156 | x_fuse = self.conv_last(x_fuse) 157 | return x_fuse 158 | 159 | class DWT(nn.Module): 160 | def __init__(self, inchannel,outchannel): 161 | super(DWT, self).__init__() 162 | self.DWT = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b') 163 | self.IWT = DTCWTInverse(biort='near_sym_b', qshift='qshift_b') 164 | self.conv1 = BasicConv2d(outchannel,outchannel) 165 | self.conv2 = BasicConv2d(inchannel, outchannel) 166 | self.conv3 = BasicConv2d(outchannel, outchannel) 167 | self.change = TransBasicConv2d(outchannel,outchannel) 168 | def forward(self, x, y): 169 | # print('x',x.shape) 170 | y = self.change(self.conv2(y)) 171 | # print('y',y.shape) 172 | Xl, Xh = self.DWT(x) 173 | Yl, Yh = self.DWT(y) 174 | # print('Xl',Xl.shape) 175 | x_y = self.conv1(Xl)+self.conv1(Yl) 176 | # print('x_y',x_y.shape) 177 | # print('Xh',Xh.shape) 178 | # print('Yh',Yh.shape) 179 | x_m = self.IWT((x_y,Xh)) 180 | y_m = self.IWT((x_y,Yh)) 181 | # print('x_m',x_m.shape) 182 | # print('y_m',y_m.shape) 183 | out = self.conv3(x_m + y_m) 184 | return out 185 | class Edge_Aware(nn.Module): 186 | def __init__(self, ): 187 | super(Edge_Aware, self).__init__() 188 | self.conv1 = TransBasicConv2d(512, 64,kernel_size=4,stride=8,padding=0,dilation=2,output_padding=1) 189 | self.conv2 = TransBasicConv2d(320, 64,kernel_size=2,stride=4,padding=0,dilation=2,output_padding=1) 190 | self.conv3 = TransBasicConv2d(128, 64,kernel_size=2,stride=2,padding=1,dilation=2,output_padding=1) 191 | self.pos_embed = BasicConv2d(64, 64 ) 192 | self.pos_embed3 = BasicConv2d(64, 64) 193 | self.conv31 = nn.Conv2d(64,1, kernel_size=1) 194 | self.conv512_64 = TransBasicConv2d(512,64) 195 | self.conv320_64 = TransBasicConv2d(320, 64) 196 | self.conv128_64 = TransBasicConv2d(128, 64) 197 | self.up = nn.Upsample(56) 198 | self.up2 = nn.Upsample(384) 199 | self.norm1 = nn.LayerNorm(64) 
200 | self.norm2 = nn.BatchNorm2d(64) 201 | self.drop_path = DropPath(0.3) 202 | self.maxpool =nn.AdaptiveMaxPool2d(1) 203 | # self.qkv = nn.Linear(64, 64 * 3, bias=False) 204 | self.num_heads = 8 205 | self.mlp1 = Mlp(in_features=64, out_features=64) 206 | self.mlp2 = Mlp(in_features=64, out_features=64) 207 | self.mlp3 = Mlp(in_features=64, out_features=64) 208 | def forward(self, x, y, z, v): 209 | 210 | 211 | # v = self.conv1(v) 212 | # z = self.conv2(z) 213 | # y = self.conv3(y) 214 | # print('v',v) 215 | v = self.up(self.conv512_64(v)) 216 | z = self.up(self.conv320_64(z)) 217 | y = self.up(self.conv128_64(y)) 218 | x = self.up(x) 219 | 220 | x_max = self.maxpool(x) 221 | # print('x_max',x_max.shape) 222 | b,_,_,_ = x_max.shape 223 | x_max = x_max.reshape(b, -1) 224 | x_y = self.mlp1(x_max) 225 | # print('s',x_y.shape) 226 | x_z = self.mlp2(x_max) 227 | x_v = self.mlp3(x_max) 228 | 229 | x_y = x_y.reshape(b,64,1,1) 230 | x_z = x_z.reshape(b, 64, 1, 1) 231 | x_v = x_v.reshape(b, 64, 1, 1) 232 | x_y = torch.mul(x_y, y) 233 | x_z = torch.mul(x_z, z) 234 | x_v = torch.mul(x_v, v) 235 | 236 | 237 | # x_mix_1 = torch.cat((x_y,x_z,x_v),dim=1) 238 | x_mix_1 = x_y+ x_z+ x_v 239 | # print('sd',x_mix_1.shape) 240 | x_mix_1 = self.norm2(x_mix_1) 241 | # print('x_mix_1',x_mix_1.shape) 242 | x_mix_1= self.pos_embed3(x_mix_1) 243 | x_mix = self.drop_path(x_mix_1) 244 | x_mix = x_mix_1 + self. 
pos_embed3(x_mix) 245 | x_mix = self.up2(self.conv31(x_mix)) 246 | return x_mix 247 | 248 | class Mutual_info_reg(nn.Module): 249 | def __init__(self, input_channels=64, channels=64, latent_size=6): 250 | super(Mutual_info_reg, self).__init__() 251 | self.soft = torch.nn.Softmax(dim=1) 252 | def forward(self, rgb_feat, depth_feat): 253 | 254 | # print('rgb_feat',rgb_feat.shape) 255 | # print('depth_feat', depth_feat.shape) 256 | rgb_feat = self.soft(rgb_feat) 257 | depth_feat = self.soft(depth_feat) 258 | # 259 | # print('rgb_feat',rgb_feat.shape) 260 | # print('depth_feat', depth_feat.shape) 261 | return kl_div(rgb_feat.log(), depth_feat) 262 | 263 | class WaveNet(nn.Module): 264 | def __init__(self, channel=32): 265 | super(WaveNet, self).__init__() 266 | 267 | 268 | self.encoderR = WaveMLP_S() 269 | # Lateral layers 270 | self.lateral_conv0 = BasicConv2d(64, 64, 3, stride=1, padding=1) 271 | self.lateral_conv1 = BasicConv2d(128, 64, 3, stride=1, padding=1) 272 | self.lateral_conv2 = BasicConv2d(320, 128, 3, stride=1, padding=1) 273 | self.lateral_conv3 = BasicConv2d(512, 320, 3, stride=1, padding=1) 274 | 275 | 276 | self.FFT1 = FFT(64,64) 277 | self.FFT2 = FFT(128,128) 278 | self.FFT3 = FFT(320,320) 279 | self.FFT4 = FFT(512,512) 280 | 281 | 282 | 283 | self.conv512_64 = BasicConv2d(512, 64) 284 | self.conv320_64 = BasicConv2d(320, 64) 285 | self.conv128_64 = BasicConv2d(128, 64) 286 | self.sigmoid = nn.Sigmoid() 287 | self.S4 = nn.Conv2d(512, 1, 3, stride=1, padding=1) 288 | self.S3 = nn.Conv2d(320, 1, 3, stride=1, padding=1) 289 | self.S2 = nn.Conv2d(128, 1, 3, stride=1, padding=1) 290 | self.S1 = nn.Conv2d(64, 1, 3, stride=1, padding=1) 291 | self.up1 = nn.Upsample(384) 292 | self.up2 = nn.Upsample(384) 293 | self.up3 = nn.Upsample(384) 294 | self.up_loss = nn.Upsample(92) 295 | # Mutual_info_reg1 296 | self.mi_level1 = Mutual_info_reg(64, 64, 6) 297 | self.mi_level2 = Mutual_info_reg(64, 64, 6) 298 | self.mi_level3 = Mutual_info_reg(64, 64, 6) 299 | 
self.mi_level4 = Mutual_info_reg(64, 64, 6) 300 | 301 | self.edge = Edge_Aware() 302 | self.PATM4 = PATM_BAB(512, 512, 512, 3, 2) 303 | self.PATM3 = PATM_BAB(832, 512, 320, 3, 2) 304 | self.PATM2 = PATM_BAB(448, 256, 128, 5, 3) 305 | self.PATM1 = PATM_BAB(192, 128, 64, 5, 3) 306 | 307 | def forward(self, x_rgb,x_thermal): 308 | x0,x1,x2,x3 = self.encoderR(x_rgb) 309 | y0, y1, y2, y3 = self.encoderR(x_thermal) 310 | 311 | x2_ACCoM = self.FFT1(x0, y0) 312 | x3_ACCoM = self.FFT2(x1, y1) 313 | x4_ACCoM = self.FFT3(x2, y2) 314 | x5_ACCoM = self.FFT4(x3, y3) 315 | 316 | edge = self.edge(x2_ACCoM, x3_ACCoM, x4_ACCoM, x5_ACCoM) 317 | 318 | mer_cros4 = self.PATM4(x5_ACCoM) 319 | m4 = torch.cat((mer_cros4,x4_ACCoM),dim=1) 320 | mer_cros3 = self.PATM3(m4) 321 | m3 = torch.cat((mer_cros3, x3_ACCoM), dim=1) 322 | mer_cros2 = self.PATM2(m3) 323 | m2 = torch.cat((mer_cros2, x2_ACCoM), dim=1) 324 | mer_cros1 = self.PATM1(m2) 325 | 326 | s1 = self.up1(self.S1(mer_cros1)) 327 | s2 = self.up2(self.S2(mer_cros2)) 328 | s3 = self.up3(self.S3(mer_cros3)) 329 | s4 = self.up3(self.S4(mer_cros4)) 330 | 331 | x_loss0 = x0 332 | y_loss0 = y0 333 | x_loss1 = self.up_loss(self.conv128_64(x1)) 334 | y_loss1 = self.up_loss(self.conv128_64(y1)) 335 | x_loss2 = self.up_loss(self.conv320_64(x2)) 336 | y_loss2 = self.up_loss(self.conv320_64(y2)) 337 | x_loss3 = self.up_loss(self.conv512_64(x3)) 338 | y_loss3 = self.up_loss(self.conv512_64(y3)) 339 | 340 | lat_loss0 = self.mi_level1(x_loss0, y_loss0) 341 | lat_loss1 = self.mi_level2(x_loss1, y_loss1) 342 | lat_loss2 = self.mi_level3(x_loss2, y_loss2) 343 | lat_loss3 = self.mi_level4(x_loss3, y_loss3) 344 | lat_loss = lat_loss0 + lat_loss1 + lat_loss2 + lat_loss3 345 | return s1, s2, s3, s4, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4),edge,lat_loss 346 | 347 | if __name__=='__main__': 348 | image = torch.randn(1, 3, 384, 384).cuda(0) 349 | ndsm = torch.randn(1, 64, 56, 56) 350 | ndsm1 = torch.randn(1, 128, 28, 28) 351 | 
ndsm2 = torch.randn(1, 320, 14, 14) 352 | ndsm3 = torch.randn(1, 512, 7, 7) 353 | 354 | net = WaveNet().cuda() 355 | 356 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__init__.py -------------------------------------------------------------------------------- /networks/__pycache__/Wavenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/Wavenet.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/wavemlp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/wavemlp.cpython-36.pyc -------------------------------------------------------------------------------- /networks/wavemlp.py: -------------------------------------------------------------------------------- 1 | # 
2022.02.14-Changed for main script for wavemlp model 2 | # Huawei Technologies Co., Ltd. 3 | """ Vision Transformer (ViT) in PyTorch 4 | 5 | A PyTorch implement of Vision Transformers as described in 6 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 7 | 8 | The official jax code is released and available at https://github.com/google-research/vision_transformer 9 | 10 | Status/TODO: 11 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. 12 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. 13 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. 14 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2020 Ross Wightman 24 | """ 25 | 26 | import os 27 | import torch 28 | import torch.nn as nn 29 | 30 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 31 | from timm.models.layers import DropPath, trunc_normal_ 32 | from timm.models.registry import register_model 33 | from timm.models.layers.helpers import to_2tuple 34 | 35 | import math 36 | from torch import Tensor 37 | from torch.nn import init 38 | from torch.nn.modules.utils import _pair 39 | import torch.nn.functional as F 40 | 41 | 42 | def _cfg(url='', **kwargs): 43 | return { 44 | 'url': url, 45 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 46 | 'crop_pct': .96, 'interpolation': 'bicubic', 47 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', 48 | **kwargs 49 | } 50 | 51 | default_cfgs = { 52 | 'wave_T': _cfg(crop_pct=0.9), 53 | 'wave_S': _cfg(crop_pct=0.9), 54 | 'wave_M': _cfg(crop_pct=0.9), 55 | 'wave_B': _cfg(crop_pct=0.875), 56 | } 57 | 58 | class Mlp(nn.Module): 59 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 60 | super().__init__() 61 | 62 | out_features = out_features or in_features 63 | hidden_features = hidden_features or in_features 64 | self.act = act_layer() 65 | self.drop = nn.Dropout(drop) 66 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1) 67 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1) 68 | 69 | def forward(self, x): 70 | x = self.fc1(x) 71 | x = self.act(x) 72 | x = self.drop(x) 73 | x = self.fc2(x) 74 | x = self.drop(x) 75 | return x 76 | 77 | 78 | class PATM(nn.Module): 79 | def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., 
proj_drop=0.,mode='fc'): 80 | super().__init__() 81 | 82 | 83 | self.fc_h = nn.Conv2d(dim, dim, 1, 1,bias=qkv_bias) 84 | self.fc_w = nn.Conv2d(dim, dim, 1, 1,bias=qkv_bias) 85 | self.fc_c = nn.Conv2d(dim, dim, 1, 1,bias=qkv_bias) 86 | 87 | self.tfc_h = nn.Conv2d(2*dim, dim, (1,7), stride=1, padding=(0,7//2), groups=dim, bias=False) 88 | self.tfc_w = nn.Conv2d(2*dim, dim, (7,1), stride=1, padding=(7//2,0), groups=dim, bias=False) 89 | self.reweight = Mlp(dim, dim // 4, dim * 3) 90 | self.proj = nn.Conv2d(dim, dim, 1, 1,bias=True) 91 | self.proj_drop = nn.Dropout(proj_drop) 92 | self.mode=mode 93 | 94 | if mode=='fc': 95 | self.theta_h_conv=nn.Sequential(nn.Conv2d(dim, dim, 1, 1,bias=True),nn.BatchNorm2d(dim),nn.ReLU()) 96 | self.theta_w_conv=nn.Sequential(nn.Conv2d(dim, dim, 1, 1,bias=True),nn.BatchNorm2d(dim),nn.ReLU()) 97 | else: 98 | self.theta_h_conv=nn.Sequential(nn.Conv2d(dim, dim, 3, stride=1, padding=1, groups=dim, bias=False),nn.BatchNorm2d(dim),nn.ReLU()) 99 | self.theta_w_conv=nn.Sequential(nn.Conv2d(dim, dim, 3, stride=1, padding=1, groups=dim, bias=False),nn.BatchNorm2d(dim),nn.ReLU()) 100 | 101 | 102 | 103 | def forward(self, x): 104 | 105 | B, C, H, W = x.shape 106 | theta_h=self.theta_h_conv(x) 107 | theta_w=self.theta_w_conv(x) 108 | 109 | x_h=self.fc_h(x) 110 | x_w=self.fc_w(x) 111 | x_h=torch.cat([x_h*torch.cos(theta_h),x_h*torch.sin(theta_h)],dim=1) 112 | x_w=torch.cat([x_w*torch.cos(theta_w),x_w*torch.sin(theta_w)],dim=1) 113 | 114 | # x_1=self.fc_h(x) 115 | # x_2=self.fc_w(x) 116 | # x_h=torch.cat([x_1*torch.cos(theta_h),x_2*torch.sin(theta_h)],dim=1) 117 | # x_w=torch.cat([x_1*torch.cos(theta_w),x_2*torch.sin(theta_w)],dim=1) 118 | 119 | h = self.tfc_h(x_h) 120 | w = self.tfc_w(x_w) 121 | c = self.fc_c(x) 122 | a = F.adaptive_avg_pool2d(h + w + c,output_size=1) 123 | a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(-1).unsqueeze(-1) 124 | x = h * a[0] + w * a[1] + c * a[2] 125 | x = self.proj(x) 126 | x = 
self.proj_drop(x) 127 | return x 128 | 129 | class WaveBlock(nn.Module): 130 | 131 | def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 132 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, mode='fc'): 133 | super().__init__() 134 | self.norm1 = norm_layer(dim) 135 | self.attn = PATM(dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop,mode=mode) 136 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 137 | self.norm2 = norm_layer(dim) 138 | mlp_hidden_dim = int(dim * mlp_ratio) 139 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 140 | 141 | def forward(self, x): 142 | x = x + self.drop_path(self.attn(self.norm1(x))) 143 | x = x + self.drop_path(self.mlp(self.norm2(x))) 144 | return x 145 | 146 | 147 | class PatchEmbedOverlapping(nn.Module): 148 | def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d, groups=1,use_norm=True): 149 | super().__init__() 150 | patch_size = to_2tuple(patch_size) 151 | stride = to_2tuple(stride) 152 | padding = to_2tuple(padding) 153 | self.patch_size = patch_size 154 | 155 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding, groups=groups) 156 | self.norm = norm_layer(embed_dim) if use_norm==True else nn.Identity() 157 | 158 | def forward(self, x): 159 | x = self.proj(x) 160 | x = self.norm(x) 161 | return x 162 | 163 | 164 | class Downsample(nn.Module): 165 | def __init__(self, in_embed_dim, out_embed_dim, patch_size,norm_layer=nn.BatchNorm2d,use_norm=True): 166 | super().__init__() 167 | assert patch_size == 2, patch_size 168 | self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=1) 169 | self.norm = norm_layer(out_embed_dim) if use_norm==True else nn.Identity() 170 | def forward(self, x): 171 | x = self.proj(x) 172 | x = self.norm(x) 173 | return x 174 | 175 | 176 | def 
basic_blocks(dim, index, layers, mlp_ratio=3., qkv_bias=False, qk_scale=None, attn_drop=0., 177 | drop_path_rate=0.,norm_layer=nn.BatchNorm2d,mode='fc', **kwargs): 178 | blocks = [] 179 | for block_idx in range(layers[index]): 180 | block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) 181 | blocks.append(WaveBlock(dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 182 | attn_drop=attn_drop, drop_path=block_dpr, norm_layer=norm_layer,mode=mode)) 183 | blocks = nn.Sequential(*blocks) 184 | return blocks 185 | 186 | class WaveNet(nn.Module): 187 | def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 188 | embed_dims=None, transitions=None, mlp_ratios=None, 189 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 190 | norm_layer=nn.BatchNorm2d, fork_feat=False,mode='fc',ds_use_norm=True,args=None): 191 | 192 | super().__init__() 193 | 194 | if not fork_feat: 195 | self.num_classes = num_classes 196 | self.fork_feat = fork_feat 197 | 198 | self.patch_embed = PatchEmbedOverlapping(patch_size=7, stride=4, padding=2, in_chans=3, embed_dim=embed_dims[0],norm_layer=norm_layer,use_norm=ds_use_norm) 199 | 200 | network = [] 201 | for i in range(len(layers)): 202 | stage = basic_blocks(embed_dims[i], i, layers, mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 203 | qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate, 204 | norm_layer=norm_layer,mode=mode) 205 | network.append(stage) 206 | if i >= len(layers) - 1: 207 | break 208 | # transitions = [True, True, True, True] 209 | if transitions[i] or embed_dims[i] != embed_dims[i+1]: 210 | patch_size = 2 if transitions[i] else 1 211 | network.append(Downsample(embed_dims[i], embed_dims[i+1], patch_size,norm_layer=norm_layer,use_norm=ds_use_norm)) 212 | 213 | self.network = nn.ModuleList(network) 214 | 215 | if self.fork_feat: 216 | # add a norm layer for each output 217 | self.out_indices = [0, 2, 4, 6] 218 
| for i_emb, i_layer in enumerate(self.out_indices): 219 | if i_emb == 0 and os.environ.get('FORK_LAST3', None): 220 | layer = nn.Identity() 221 | else: 222 | layer = norm_layer(embed_dims[i_emb]) 223 | layer_name = f'norm{i_layer}' 224 | self.add_module(layer_name, layer) 225 | else: 226 | self.norm = norm_layer(embed_dims[-1]) 227 | self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() 228 | self.apply(self.cls_init_weights) 229 | 230 | def cls_init_weights(self, m): 231 | if isinstance(m, nn.Linear): 232 | trunc_normal_(m.weight, std=.02) 233 | if isinstance(m, nn.Linear) and m.bias is not None: 234 | nn.init.constant_(m.bias, 0) 235 | elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): 236 | nn.init.constant_(m.bias, 0) 237 | nn.init.constant_(m.weight, 1.0) 238 | 239 | def init_weights(self, pretrained=None): 240 | """ mmseg or mmdet `init_weight` """ 241 | 242 | if isinstance(pretrained, str): 243 | pass; 244 | # logger = get_root_logger() 245 | # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 246 | 247 | def get_classifier(self): 248 | return self.head 249 | 250 | def reset_classifier(self, num_classes, global_pool=''): 251 | self.num_classes = num_classes 252 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 253 | 254 | def forward_embeddings(self, x): 255 | x = self.patch_embed(x) 256 | return x 257 | 258 | def forward_tokens(self, x): 259 | outs = [] 260 | for idx, block in enumerate(self.network): 261 | x = block(x) 262 | if self.fork_feat and idx in self.out_indices: 263 | norm_layer = getattr(self, f'norm{idx}') 264 | x_out = norm_layer(x) 265 | outs.append(x_out) 266 | # print('x_out',x_out.shape) 267 | # return outs 268 | # print('self.fork_feat',self.fork_feat) 269 | if self.fork_feat: 270 | return outs 271 | return x 272 | 273 | def forward(self, x): 274 | # print('a',x.shape) 275 | x = self.forward_embeddings(x) 276 | # 
print('b', x.shape) 277 | x = self.forward_tokens(x) 278 | # print('c',len(x)) 279 | # if self.fork_feat: 280 | # return x 281 | # x = self.norm(x) 282 | return x 283 | # cls_out = self.head(F.adaptive_avg_pool2d(x,output_size=1).flatten(1))#x.mean(1) 284 | # return cls_out 285 | 286 | def MyNorm(dim): 287 | return nn.GroupNorm(1, dim) 288 | 289 | @register_model 290 | def WaveMLP_T_dw(pretrained=False, **kwargs): 291 | transitions = [True, True, True, True] 292 | layers = [2, 2, 4, 2] 293 | mlp_ratios = [4, 4, 4, 4] 294 | embed_dims = [64, 128, 320, 512] 295 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 296 | fork_feat=True,mlp_ratios=mlp_ratios,mode='depthwise', **kwargs) 297 | # print('model',model) 298 | model.default_cfg = default_cfgs['wave_T'] 299 | return model 300 | 301 | @register_model 302 | def WaveMLP_T(pretrained=False, **kwargs): 303 | transitions = [True, True, True, True] 304 | layers = [2, 2, 4, 2] 305 | mlp_ratios = [4, 4, 4, 4] 306 | embed_dims = [64, 128, 320, 512] 307 | # model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 308 | # mlp_ratios=mlp_ratios, **kwargs) 309 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 310 | fork_feat=True, mlp_ratios=mlp_ratios, **kwargs) 311 | model.default_cfg = default_cfgs['wave_T'] 312 | return model 313 | 314 | @register_model 315 | def WaveMLP_S(pretrained=False, **kwargs): 316 | transitions = [True, True, True, True] 317 | layers = [2, 3, 10, 3] 318 | mlp_ratios = [4, 4, 4, 4] 319 | embed_dims = [64, 128, 320, 512] 320 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 321 | fork_feat=True, mlp_ratios=mlp_ratios,norm_layer=MyNorm, **kwargs) 322 | model.default_cfg = default_cfgs['wave_S'] 323 | return model 324 | 325 | @register_model 326 | def WaveMLP_M(pretrained=False, **kwargs): 327 | transitions = [True, True, True, True] 328 | layers = [3, 4, 18, 3] 
329 | mlp_ratios = [8, 8, 4, 4] 330 | embed_dims = [64, 128, 320, 512] 331 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 332 | fork_feat=True, mlp_ratios=mlp_ratios,norm_layer=MyNorm,ds_use_norm=False, **kwargs) 333 | model.default_cfg = default_cfgs['wave_M'] 334 | return model 335 | 336 | @register_model 337 | def WaveMLP_B(pretrained=False, **kwargs): 338 | transitions = [True, True, True, True] 339 | layers = [2, 2, 18, 2] 340 | mlp_ratios = [4, 4, 4, 4] 341 | embed_dims = [96, 192, 384, 768] 342 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 343 | fork_feat=True, mlp_ratios=mlp_ratios,norm_layer=MyNorm,ds_use_norm=False, **kwargs) 344 | model.default_cfg = default_cfgs['wave_B'] 345 | return model 346 | if __name__ == '__main__': 347 | image = torch.randn(1, 3, 224, 224) 348 | ndsm = torch.randn(1, 3, 224, 224) 349 | ndsm1 = torch.randn(1, 3, 224, 224) 350 | net = WaveMLP_B() 351 | out = net(image) 352 | # print(out.shape) 353 | for i in out: 354 | print(i.shape) 355 | 356 | -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | def _iou(pred, target, size_average = True): 9 | 10 | b = pred.shape[0] 11 | IoU = 0.0 12 | for i in range(0,b): 13 | #compute the IoU of the foreground 14 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 15 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 16 | IoU1 = Iand1/Ior1 17 | 18 | #IoU loss is (1-IoU1) 19 | IoU = IoU + (1-IoU1) 20 | 21 | return IoU/b 22 | 23 | class IOU(torch.nn.Module): 24 | def __init__(self, size_average = True): 25 | super(IOU, self).__init__() 26 | self.size_average = size_average 27 | 28 | def forward(self, pred, target): 29 | 30 | return 
_iou(pred, target, self.size_average) 31 | -------------------------------------------------------------------------------- /rgbt_dataset_KD.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | 10 | # several data augumentation strategies 11 | def cv_random_flip(img, label, depth,bound,gt2): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | # left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 19 | bound = bound.transpose(Image.FLIP_LEFT_RIGHT) 20 | gt2 = gt2.transpose(Image.FLIP_LEFT_RIGHT) 21 | # top bottom flip 22 | # if flip_flag2==1: 23 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 24 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 25 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 26 | return img, label, depth,bound,gt2 27 | 28 | 29 | def randomCrop(image, label, depth,bound,gt2): 30 | border = 30 31 | image_width = image.size[0] 32 | image_height = image.size[1] 33 | crop_win_width = np.random.randint(image_width - border, image_width) 34 | crop_win_height = np.random.randint(image_height - border, image_height) 35 | random_region = ( 36 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 37 | (image_height + crop_win_height) >> 1) 38 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region),bound.crop(random_region),gt2.crop(random_region) 39 | 40 | 41 | def randomRotation(image, label, depth,bound,gt2): 42 | mode = Image.BICUBIC 43 | if random.random() > 0.8: 44 | random_angle = np.random.randint(-15, 15) 45 | image = image.rotate(random_angle, mode) 46 | 
label = label.rotate(random_angle, mode) 47 | depth = depth.rotate(random_angle, mode) 48 | bound = bound.rotate(random_angle, mode) 49 | gt2 = gt2.rotate(random_angle, mode) 50 | return image, label, depth,bound,gt2 51 | 52 | 53 | def colorEnhance(image): 54 | bright_intensity = random.randint(5, 15) / 10.0 55 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 56 | contrast_intensity = random.randint(5, 15) / 10.0 57 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 58 | color_intensity = random.randint(0, 20) / 10.0 59 | image = ImageEnhance.Color(image).enhance(color_intensity) 60 | sharp_intensity = random.randint(0, 30) / 10.0 61 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 62 | return image 63 | 64 | 65 | def randomGaussian(image, mean=0.1, sigma=0.35): 66 | def gaussianNoisy(im, mean=mean, sigma=sigma): 67 | for _i in range(len(im)): 68 | im[_i] += random.gauss(mean, sigma) 69 | return im 70 | 71 | img = np.asarray(image) 72 | width, height = img.shape 73 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 74 | img = img.reshape([width, height]) 75 | return Image.fromarray(np.uint8(img)) 76 | 77 | 78 | def randomPeper(img): 79 | img = np.array(img) 80 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 81 | for i in range(noiseNum): 82 | 83 | randX = random.randint(0, img.shape[0] - 1) 84 | 85 | randY = random.randint(0, img.shape[1] - 1) 86 | 87 | if random.randint(0, 1) == 0: 88 | 89 | img[randX, randY] = 0 90 | 91 | else: 92 | 93 | img[randX, randY] = 255 94 | return Image.fromarray(img) 95 | 96 | 97 | # dataset for training 98 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 99 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 
100 | class SalObjDataset(data.Dataset): 101 | def __init__(self, image_root, gt_root, depth_root, bound_root,gt2_root,trainsize): 102 | self.trainsize = trainsize 103 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 104 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 105 | or f.endswith('.png')] 106 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 107 | or f.endswith('.png') or f.endswith('.jpg')] 108 | self.bound = [bound_root + f for f in os.listdir(bound_root) if f.endswith('.jpg') 109 | or f.endswith('.png')] 110 | self.images = sorted(self.images) 111 | self.gts2 = [gt2_root + f for f in os.listdir(gt2_root) if f.endswith('.jpg')  # list gt2_root itself instead of assuming it mirrors gt_root's file names 112 | or f.endswith('.png')] 113 | # print(len(self.images)) 114 | self.gts = sorted(self.gts) 115 | self.gts2 = sorted(self.gts2) 116 | # print(len(self.gts)) 117 | self.depths = sorted(self.depths) 118 | self.bound = sorted(self.bound) 119 | # print(len(self.depths)) 120 | self.filter_files() 121 | self.size = len(self.images) 122 | self.img_transform = transforms.Compose([ 123 | transforms.Resize((self.trainsize, self.trainsize)), 124 | transforms.ToTensor(), 125 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 126 | self.gt_transform = transforms.Compose([ 127 | transforms.Resize((self.trainsize, self.trainsize)), 128 | transforms.ToTensor()]) 129 | self.bound_transform = transforms.Compose([ 130 | transforms.Resize((self.trainsize, self.trainsize)), 131 | transforms.ToTensor()]) 132 | self.depths_transform = transforms.Compose([ 133 | transforms.Resize((self.trainsize, self.trainsize)), 134 | transforms.ToTensor(), 135 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 136 | ]) 137 | self.gt2_transform = transforms.Compose([ 138 | transforms.Resize((self.trainsize, self.trainsize)), 139 | transforms.ToTensor()]) 140 | 141 | def __getitem__(self, index): 142 | image = 
self.rgb_loader(self.images[index]) 143 | gt = self.binary_loader(self.gts[index]) 144 | gt2 = self.binary_loader(self.gts2[index]) 145 | # depth = self.binary_loader(self.depths[index]) 146 | depth = self.rgb_loader(self.depths[index]) 147 | bound = self.binary_loader(self.bound[index]) 148 | image, gt, depth,bound,gt2 = cv_random_flip(image, gt, depth,bound,gt2) 149 | image, gt, depth,bound,gt2 = randomCrop(image, gt, depth,bound,gt2) 150 | image, gt, depth,bound,gt2 = randomRotation(image, gt, depth,bound,gt2) 151 | image = colorEnhance(image) 152 | # gt=randomGaussian(gt) 153 | gt = randomPeper(gt) 154 | bound = randomPeper(bound) 155 | gt2 = randomPeper(gt2) 156 | # image, gt, depth = self.resize(image,gt, depth) 157 | image = self.img_transform(image) 158 | gt = self.gt_transform(gt) 159 | bound = self.bound_transform(bound) 160 | depth = self.depths_transform(depth) 161 | gt2 = self.gt2_transform(gt2) 162 | return image, gt, depth,bound,gt2 163 | 164 | def filter_files(self): 165 | # print('len(self.images)',len(self.images)) 166 | # print('len(self.depths)', len(self.depths)) 167 | # print('len(self.images)', len(self.images)) 168 | assert len(self.images) == len(self.depths) == len(self.gts) == len(self.bound) == len(self.gts2)  # the five sorted lists are paired by position in zip() below; a missing/extra file in any folder would otherwise misalign or truncate them silently 169 | images = [] 170 | gts = [] 171 | depths = [] 172 | bounds = [] 173 | gts2 = [] 174 | for img_path, gt_path, depth_path,bound_path,gt2_path in zip(self.images, self.gts, self.depths,self.bound,self.gts2): 175 | img = Image.open(img_path) 176 | gt = Image.open(gt_path) 177 | depth = Image.open(depth_path) 178 | bound = Image.open(bound_path) 179 | # if img.size == depth.size: 180 | # and gt.size == depth.size 181 | images.append(img_path) 182 | gts.append(gt_path) 183 | depths.append(depth_path) 184 | bounds.append(bound_path) 185 | gts2.append(gt2_path) 186 | self.images = images 187 | self.gts = gts 188 | self.depths = depths 189 | self.bound = bounds 190 | self.gts2 = gts2 191 | 192 | def rgb_loader(self, path): 193 | with open(path, 'rb')
as f: 194 | img = Image.open(f) 195 | return img.convert('RGB') 196 | 197 | def binary_loader(self, path): 198 | with open(path, 'rb') as f: 199 | img = Image.open(f) 200 | return img.convert('L') 201 | 202 | def resize(self, img, gt, depth,gt2): 203 | assert img.size == gt.size and gt.size == depth.size 204 | h = self.trainsize 205 | w = self.trainsize 206 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),Image.NEAREST)#, gt2.resize((w, h), Image.NEAREST) 207 | 208 | 209 | 210 | def __len__(self): 211 | return self.size 212 | 213 | 214 | # dataloader for training 215 | def get_loader(image_root, gt_root, depth_root, bound_root,gt2_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=False): 216 | dataset = SalObjDataset(image_root, gt_root, depth_root, bound_root,gt2_root,trainsize) 217 | data_loader = data.DataLoader(dataset=dataset, 218 | batch_size=batchsize, 219 | shuffle=shuffle, 220 | num_workers=num_workers, 221 | pin_memory=pin_memory, 222 | drop_last=True 223 | ) 224 | return data_loader 225 | 226 | 227 | # test dataset and loader 228 | class test_dataset: 229 | def __init__(self, image_root, gt_root, depth_root, testsize): 230 | self.testsize = testsize 231 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 232 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 233 | or f.endswith('.png')] 234 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 235 | or f.endswith('.png')or f.endswith('.jpg')] 236 | self.images = sorted(self.images) 237 | self.gts = sorted(self.gts) 238 | self.depths = sorted(self.depths) 239 | self.transform = transforms.Compose([ 240 | transforms.Resize((self.testsize, self.testsize)), 241 | transforms.ToTensor(), 242 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 243 | # self.gt_transform = transforms.ToTensor() 244 | self.gt_transform = 
transforms.Compose([ 245 | transforms.Resize((self.testsize, self.testsize)), 246 | transforms.ToTensor()]) 247 | self.depths_transform = transforms.Compose([ 248 | transforms.Resize((self.testsize, self.testsize)), 249 | transforms.ToTensor(), 250 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 251 | self.size = len(self.images) 252 | self.index = 0 253 | 254 | def load_data(self): 255 | image = self.rgb_loader(self.images[self.index]) 256 | gt = self.binary_loader(self.gts[self.index]) 257 | # depth = self.binary_loader(self.depths[self.index]) 258 | depth = self.rgb_loader(self.depths[self.index]) 259 | # image, gt, depth = self.resize(image, gt, depth) 260 | image = self.transform(image).unsqueeze(0) 261 | gt = self.gt_transform(gt).unsqueeze(0) 262 | depth = self.depths_transform(depth).unsqueeze(0) 263 | name = self.images[self.index].split('/')[-1] 264 | if name.endswith('.jpg'): 265 | name = name.split('.jpg')[0] + '.png' 266 | self.index += 1 267 | self.index = self.index % self.size 268 | return image, gt, depth, name 269 | 270 | def rgb_loader(self, path): 271 | with open(path, 'rb') as f: 272 | img = Image.open(f) 273 | return img.convert('RGB') 274 | 275 | def binary_loader(self, path): 276 | with open(path, 'rb') as f: 277 | img = Image.open(f) 278 | return img.convert('L') 279 | 280 | def resize(self, img, gt, depth): 281 | # assert img.size == gt.size and gt.size == depth.size 282 | h = self.testsize 283 | w = self.testsize 284 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h), 285 | Image.NEAREST) 286 | 287 | 288 | def __len__(self): 289 | return self.size 290 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | import cv2 5 | 6 | from networks.Wavenet import WaveNet 7 | 8 | import matplotlib.pyplot as plt 9 | from 
config import opt 10 | from rgbt_dataset_KD import test_dataset 11 | from datetime import datetime 12 | 13 | dataset_path = opt.test_path 14 | 15 | #set device for test 16 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 17 | print('USE GPU:', opt.gpu_id) 18 | 19 | #load the model 20 | model = WaveNet() 21 | print('NOW USING:WaveNet') 22 | model.load_state_dict(torch.load('')) 23 | 24 | model.cuda() 25 | model.eval() 26 | 27 | #test 28 | 29 | test_mae = [] 30 | test_datasets = ['VT800','VT1000', 'VT5000'] 31 | 32 | for dataset in test_datasets: 33 | 34 | mae_sum = 0 35 | save_path = 'show/' + dataset + '/' 36 | if not os.path.exists(save_path): 37 | os.makedirs(save_path) 38 | image_root = dataset_path + dataset + '/RGB/' 39 | gt_root = dataset_path + dataset + '/GT/' 40 | depth_root = dataset_path + dataset + '/T/' 41 | 42 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize) 43 | prec_time = datetime.now() 44 | for i in range(test_loader.size): 45 | image, gt, depth, name = test_loader.load_data() 46 | gt = gt.cuda() 47 | image = image.cuda() 48 | depth = depth.cuda() 49 | 50 | res = model(image, depth) 51 | res = torch.sigmoid(res[0]) 52 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 53 | mae_train = torch.sum((torch.abs(res - gt)) * 1.0 / (torch.numel(gt))) 54 | mae_sum = mae_train.item() + mae_sum 55 | predict = res.data.cpu().numpy().squeeze() 56 | print('save img to: ', save_path + name, ) 57 | predict = cv2.resize(predict,(224,224),interpolation=cv2.INTER_LINEAR) 58 | plt.imsave(save_path + name, arr=predict, cmap='gray') 59 | 60 | cur_time = datetime.now() 61 | test_mae.append(mae_sum / len(test_loader)) 62 | 63 | print('Test_mae:', test_mae) 64 | print('Test Done!') -------------------------------------------------------------------------------- /train_wave_KD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from 
datetime import datetime 6 | from torchvision.utils import make_grid 7 | 8 | from networks.Wavenet import WaveNet 9 | from rgbt_dataset_KD import get_loader, test_dataset 10 | from utils import clip_gradient, adjust_lr 11 | from tensorboardX import SummaryWriter 12 | import logging 13 | import torch.backends.cudnn as cudnn 14 | from config import opt 15 | from torch.cuda import amp 16 | import pytorch_iou 17 | 18 | # set the device for training 19 | cudnn.benchmark = True 20 | 21 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 22 | print('USE GPU:', opt.gpu_id) 23 | 24 | # build the model 25 | 26 | 27 | model = WaveNet() 28 | model.load_pre() 29 | 30 | model.cuda() 31 | params = model.parameters() 32 | optimizer = torch.optim.Adam(params, opt.lr) 33 | 34 | # set the path 35 | train_dataset_path = opt.lr_train_root 36 | image_root = train_dataset_path + '/RGB/' 37 | depth_root = train_dataset_path + '/T/' 38 | gt_root = train_dataset_path + '/GT/' 39 | bound_root = train_dataset_path + '/bound/' 40 | gt2_root = train_dataset_path + '/GT_SwinNet/' 41 | val_dataset_path = opt.lr_val_root 42 | val_image_root = val_dataset_path + '/RGB/' 43 | val_depth_root = val_dataset_path + '/T/' 44 | val_gt_root = val_dataset_path + '/GT/' 45 | save_path = opt.save_path 46 | 47 | if not os.path.exists(save_path): 48 | os.makedirs(save_path) 49 | 50 | # load data 51 | print('load data...') 52 | 53 | train_loader = get_loader(image_root, gt_root,depth_root,bound_root,gt2_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 54 | # print(len(train_loader)) 55 | test_loader = test_dataset(val_image_root, val_gt_root,val_depth_root, opt.trainsize) 56 | total_step = len(train_loader) 57 | 58 | logging.basicConfig(filename=save_path + 'log.log', format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', 59 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p') 60 | logging.info(save_path + "Train") 61 | logging.info("Config") 62 | logging.info( 63 | 
'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format( 64 | opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.load, save_path, 65 | opt.decay_epoch)) 66 | 67 | # set loss function 68 | def linear_annealing(init, fin, step, annealing_steps): 69 | """Linear annealing of a parameter.""" 70 | if annealing_steps == 0: 71 | return fin 72 | assert fin > init 73 | delta = fin - init 74 | annealed = min(init + delta * step / annealing_steps, fin) 75 | return annealed 76 | 77 | 78 | 79 | def joint_loss(pred, mask):  # boundary-weighted BCE + IoU; not called by train() below 80 | weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 81 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')  # keep per-pixel values for the weighting; the legacy reduce='none' is truthy and silently averaged to a scalar 82 | wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 83 | 84 | pred = torch.sigmoid(pred) 85 | inter = ((pred * mask)*weit).sum(dim=(2, 3)) 86 | union = ((pred + mask)*weit).sum(dim=(2, 3)) 87 | wiou = 1 - (inter + 1)/(union - inter+1) 88 | return (wbce + wiou).mean() 89 | 90 | 91 | CE = torch.nn.BCEWithLogitsLoss() 92 | IOU = pytorch_iou.IOU(size_average = True) 93 | 94 | 95 | 96 | # hyperparameters 97 | step = 0 98 | writer = SummaryWriter(save_path + 'summary') 99 | best_mae = 1 100 | best_epoch = 0 101 | Sacler = amp.GradScaler() 102 | 103 | # train function 104 | length = 821 105 | 106 | def train(train_loader, model, optimizer, epoch, save_path): 107 | global step 108 | model.train() 109 | # model2.train() 110 | loss_all = 0 111 | epoch_step = 0 112 | try: 113 | for i, (images, gts,depths,bound,gts2) in enumerate(train_loader, start=1): 114 | optimizer.zero_grad() 115 | # ima = images 116 | # dep = depths 117 | images = images.cuda() 118 | # print(images.shape) 119 | depths = depths.cuda() 120 | gts = gts.cuda() 121 | bound = bound.cuda() 122 | gts2 = gts2.cuda() 123 | _,_,w_t,h_t = gts.size() 124 | 125 | s1, s2, s3, s4, s1_sig, s2_sig, s3_sig, s4_sig, edge, latent_loss =model(images, depths) 126 | ''' 127 | We
directly use the predicts from the teacher model, but you can also use this type to load teacher model. 128 | 129 | t1,t2, t3, t4, t1_sig, t2_sig, t3_sig, t4_sig, tedge, tlatent_loss = model_T(images, depths) 130 | t1, t2 = model_T(images, depths) 131 | s1, s2, s3,s4,s1_sig, s2_sig, s3_sig,s4_sig= model(images, depths) # , self.sigmoid(s5) 132 | target,_ = model2(ima, dep) 133 | t1 = t1.cuda() 134 | m = nn.Sigmoid() 135 | t1 = m(t1) 136 | t1 = t1.detach() 137 | 138 | t2 = t2.cuda() 139 | m = nn.Sigmoid() 140 | t2 = m(t2) 141 | t2 = t2.detach() 142 | print('latent_loss',latent_loss) 143 | ''' 144 | loss1 = CE(s1, gts) + IOU(s1_sig, gts) 145 | loss2 = CE(s2, gts) + IOU(s2_sig, gts) 146 | loss3 = CE(s3, gts) + IOU(s3_sig, gts) 147 | loss4 = CE(s4, gts) + IOU(s4_sig, gts) 148 | loss5 = CE(edge, bound) 149 | anneal_reg = linear_annealing(0, 1, epoch, opt.epoch) 150 | loss6 = 0.1 * anneal_reg * latent_loss 151 | loss7 = CE(s1, gts2) + IOU(s1_sig, gts2)  # term against the SwinNet maps (gts2); CE is BCEWithLogitsLoss, so it takes the logits s1 as loss1-loss4 do, not the sigmoid map s1_sig 152 | 153 | loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 154 | loss.backward() 155 | 156 | 157 | clip_gradient(optimizer, opt.clip) 158 | 159 | optimizer.step() 160 | step += 1 161 | epoch_step = epoch_step +1 162 | loss_all =loss_all + loss.item() 163 | if i % 100 == 0 or i == total_step or i == 1: 164 | print('{} Epoch [{:03d}/{:03d}],W*H [{:03d}*{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}'. 165 | format(datetime.now(), epoch+1, opt.epoch, w_t, h_t, i, total_step, loss.item())) 166 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}'.
167 | format(epoch+1, opt.epoch, i, total_step, loss.item())) 168 | writer.add_scalar('Loss/total_loss', loss, global_step=step) 169 | grid_image = make_grid(images[0].clone().cpu().data,1,normalize=True) 170 | writer.add_image('train/RGB',grid_image, step) 171 | grid_image = make_grid(depths[0].clone().cpu().data, 1, normalize=True) 172 | writer.add_image('train/Ti', grid_image, step) 173 | 174 | loss_all /= epoch_step 175 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch+1, opt.epoch, loss_all)) 176 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch) 177 | if (epoch+1) % 10 == 0 or (epoch+1) == opt.epoch: 178 | torch.save(model.state_dict(), save_path + 'Epoch_{}_test.pth'.format(epoch+1)) 179 | except KeyboardInterrupt: 180 | print('Keyboard Interrupt: save model and exit.') 181 | if not os.path.exists(save_path): 182 | os.makedirs(save_path) 183 | torch.save(model.state_dict(), save_path + 'Epoch_{}_test.pth'.format(epoch + 1)) 184 | print('save checkpoints successfully!') 185 | raise 186 | 187 | # 188 | 189 | # test function 190 | def test(test_loader, model, epoch, save_path): 191 | global best_mae, best_epoch 192 | model.eval() 193 | with torch.no_grad(): 194 | mae_sum = 0 195 | for i in range(test_loader.size): 196 | image, gt, depth, name = test_loader.load_data() 197 | gt = gt.cuda() 198 | image = image.cuda() 199 | depth = depth.cuda() 200 | 201 | res = model(image, depth) 202 | res = torch.sigmoid(res[0]) 203 | 204 | res = (res-res.min())/(res.max()-res.min()+1e-8) 205 | mae_train =torch.sum(torch.abs(res-gt))*1.0/(torch.numel(gt)) 206 | mae_sum = mae_train.item()+mae_sum 207 | 208 | mae = mae_sum / test_loader.size # mae / length:.8f 209 | 210 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch)) 211 | if mae < best_mae: 212 | best_mae = mae 213 | best_epoch = epoch 214 | # No first-epoch special case: epochs are 0-based, best_mae starts at 1 and the 215 | # normalized MAE lies in [0, 1], so the first epoch (epoch 0) is compared and 216 | # checkpointed exactly like every later epoch. 217 | torch.save(model.state_dict(), save_path + 
'Best_mae_test.pth') 218 | print('best epoch:{}'.format(epoch)) 219 | logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae)) 220 | 221 | 222 | 223 | if __name__ == '__main__': 224 | print("Start train...") 225 | start_time = datetime.now() 226 | for epoch in range(opt.epoch): 227 | cur_lr = adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) 228 | writer.add_scalar('learning_rate', cur_lr, global_step=epoch) 229 | train(train_loader, model, optimizer, epoch, save_path) 230 | test(test_loader, model, epoch, save_path) 231 | finish_time = datetime.now() 232 | h, remainder = divmod(int((finish_time - start_time).total_seconds()), 3600)  # timedelta.seconds excludes whole days, so long runs would report the wrong hour count 233 | m, s = divmod(remainder, 60) 234 | time = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s) 235 | print(time) 236 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def clip_gradient(optimizer, grad_clip): 2 | for group in optimizer.param_groups: 3 | for param in group['params']: 4 | if param.grad is not None: 5 | param.grad.data.clamp_(-grad_clip, grad_clip) 6 | 7 | 8 | def adjust_lr(optimizer1, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 9 | decay = decay_rate ** (epoch // decay_epoch) 10 | for param_group in optimizer1.param_groups: 11 | param_group['lr'] = decay*init_lr 12 | lr=param_group['lr'] 13 | return lr 14 | 15 | --------------------------------------------------------------------------------