├── README.md
├── config.py
├── images
│   ├── a.png
│   └── r.png
├── networks
│   ├── SwinNet
│   │   ├── README.md
│   │   ├── SwinNet_test.py
│   │   ├── SwinNet_train.py
│   │   ├── cpts
│   │   │   └── RGBD-SwinNet.log
│   │   ├── data.py
│   │   ├── imgs
│   │   │   ├── RGBD.png
│   │   │   ├── RGBT.png
│   │   │   └── main.png
│   │   ├── logger.py
│   │   ├── lr_scheduler.py
│   │   ├── main.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── build.cpython-38.pyc
│   │   │   │   └── swinNet.cpython-38.pyc
│   │   │   ├── back
│   │   │   │   └── Swin_Transformer.py
│   │   │   ├── build.py
│   │   │   └── swinNet.py
│   │   ├── optimizer.py
│   │   ├── options.py
│   │   ├── test_gray_to_rgb.py
│   │   └── utils.py
│   ├── Wavenet.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── Wavenet.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   └── wavemlp.cpython-36.pyc
│   └── wavemlp.py
├── pytorch_iou
│   └── __init__.py
├── rgbt_dataset_KD.py
├── test.py
├── train_wave_KD.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # WaveNet
2 | This project provides the code and results for 'WaveNet: Wavelet Network With Knowledge Distillation for RGB-T Salient Object Detection', IEEE TIP, 2023. [IEEE link](https://ieeexplore.ieee.org/document/10127616)
3 |
4 | - Old codebase: https://github.com/nowander/WaveNet/tree/default
5 | # Requirements
6 | Python 3.7+, PyTorch 1.5.0+, CUDA 10.2+, TensorboardX 2.1, opencv-python, pytorch_wavelets, timm.
7 |
8 | # Architecture and Details
9 |
10 |
11 |
12 | # Results
13 |
14 |
15 |
16 | # Preparation
17 | - Download the RGB-T raw data from [LSNet](https://github.com/zyrant/LSNet).
18 | - Optional: download the pre-trained WaveMLP-S weights from [wavemlp](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/wavemlp_pytorch).
19 | - There are two ways to train with knowledge distillation:
20 |   1. Load the SwinNet teacher model; please refer to the configuration details of [SwinNet](https://github.com/liuzywen/SwinNet).
21 |   2. Directly load the prediction maps of SwinNet from [baidu](https://pan.baidu.com/s/18qwaTwTZ39XtWlP3JaeSOQ) pin: py5y.
22 |   We use the prediction maps of SwinNet as the default setting.
23 |
24 |
25 |
26 | # Training & Testing
27 | Modify the `lr_train_root`, `lr_val_root`, and `save_path` paths in `config.py` according to your own data paths.
28 | - Train the WaveNet:
29 |
30 | `python train_wave_KD.py`
31 |
32 | Modify the `test_path` path in `config.py` according to your own data path.
33 |
34 | - Test the WaveNet:
35 |
36 | `python test.py`
37 |
38 | # Evaluate tools
39 | - You can use either of the following toolboxes to compute the metrics:
40 | [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics)
41 |
42 | # Saliency Maps
43 | - RGB-T [baidu](https://pan.baidu.com/s/1L6gF3hT8ML3uN_p2va4n8w) pin: gl01
44 |
45 |
46 | # Pretraining Models
47 | - RGB-T [baidu](https://pan.baidu.com/s/1PGwu3uVRWyFS1erOBr7KAg) pin: v5pb
48 |
49 | # Citation
50 |     @ARTICLE{10127616,
51 |       author={Zhou, Wujie and Sun, Fan and Jiang, Qiuping and Cong, Runmin and Hwang, Jenq-Neng},
52 |       journal={IEEE Transactions on Image Processing},
53 |       title={WaveNet: Wavelet Network With Knowledge Distillation for RGB-T Salient Object Detection},
54 |       year={2023},
55 |       volume={32},
56 |       number={},
57 |       pages={3027-3039},
58 |       doi={10.1109/TIP.2023.3275538}}
59 |
60 | # Acknowledgement
61 |
62 | The implementation of this project is based on the codebases below.
63 | - [BBS-Net](https://github.com/zyjwuyan/BBS-Net)
64 | - [LSNet](https://github.com/zyrant/LSNet)
65 | - [Wavemlp](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/wavemlp_pytorch)
66 | - Evaluate tools [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics)
67 |
68 | If you find this project helpful, please also cite the codebases above. We also thank [zyrant](https://github.com/zyrant).
69 |
70 | # Contact
71 | For any problems or discussion, please drop me an email: https://wujiezhou.github.io/ (wujiezhou@163.com).
72 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | # change dir according to your own dataset dirs
4 | # train/val
5 | parser.add_argument('--epoch', type=int, default=60, help='epoch number')
6 | parser.add_argument('--lr', type=float, default=3e-5, help='learning rate')
7 | parser.add_argument('--batchsize', type=int, default=6, help='training batch size')
8 | parser.add_argument('--trainsize', type=int, default=384, help='training dataset size')
9 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
10 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')#0.1
11 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate')
12 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints')
13 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
14 | parser.add_argument('--lr_train_root', type=str, default='', help='the train images root')
15 | parser.add_argument('--lr_val_root', type=str, default='', help='the val images root')
16 | parser.add_argument('--save_path', type=str, default='', help='the path to save models and logs')
17 | # test(predict)
18 | parser.add_argument('--testsize', type=int, default=384, help='testing size')
19 | parser.add_argument('--test_path',type=str,default='/media/hjk/shuju/轻量级/data set/RGBT_test/',help='test dataset path')
20 | opt = parser.parse_args()
21 |
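22 | 
23 | # Usage sketch (illustrative, not part of the original file): other modules
24 | # read these options via `from config import opt`; the path options default
25 | # to empty strings and are placeholders until set on the command line.
26 | if __name__ == '__main__':
27 |     print('epoch:', opt.epoch, 'lr:', opt.lr, 'batchsize:', opt.batchsize)
28 |     print('train root:', opt.lr_train_root or '<unset>')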
--------------------------------------------------------------------------------
/images/a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/images/a.png
--------------------------------------------------------------------------------
/images/r.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/images/r.png
--------------------------------------------------------------------------------
/networks/SwinNet/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/README.md
--------------------------------------------------------------------------------
/networks/SwinNet/SwinNet_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: SwinNet_test.py
6 | @time: 2021/5/31 09:34
7 | """
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | import sys
12 | sys.path.append('./models')
13 | import numpy as np
14 | import os, argparse
15 | import cv2
16 | from models.swinNet import SwinNet
17 | from data import test_dataset
18 | import time
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--testsize', type=int, default=384, help='testing size')
22 | parser.add_argument('--gpu_id', type=str, default='1', help='select gpu id')
23 | parser.add_argument('--test_path',type=str,default='./datasets/RGB-D/test/',help='test dataset path')
24 | # parser.add_argument('--test_path',type=str,default='./datasets/RGB-T/Test/',help='test dataset path')
25 | opt = parser.parse_args()
26 |
27 | dataset_path = opt.test_path
28 |
29 | #set device for test
30 | if opt.gpu_id=='0':
31 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
32 | print('USE GPU 0')
33 | elif opt.gpu_id=='1':
34 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
35 | print('USE GPU 1')
36 |
37 | #load the model
38 | model = SwinNet()
39 | # A large epoch count may not generalize well. You can choose a good checkpoint to load according to the log and .pth files saved in './cpts/' during training.
40 | model.load_state_dict(torch.load("./cpts/SwinNet.pth"))
41 |
42 | model.cuda()
43 | model.eval()
44 | fps = 0
45 |
46 | # RGBT-Test
47 | test_datasets = ['VT1000','VT821','VT5000']
48 | for dataset in test_datasets:
49 | time_s = time.time()
50 | sal_save_path = './test_maps/SwinTransNet_RGBT_speed/' + dataset + '/'
51 | edge_save_path = './test_maps/SwinTransNet_RGBT2_best/Edge/' + dataset + '/'
52 | if not os.path.exists(sal_save_path):
53 | os.makedirs(sal_save_path)
54 | # os.makedirs(edge_save_path)
55 | image_root = dataset_path + dataset + '/RGB/'
56 | gt_root = dataset_path + dataset + '/GT/'
57 | depth_root = dataset_path + dataset + '/T/'
58 | test_loader = test_dataset(image_root, gt_root, depth_root, opt.testsize)
59 | nums = test_loader.size
60 | for i in range(test_loader.size):
61 | image, gt, depth, name, image_for_post = test_loader.load_data()
62 | gt = np.asarray(gt, np.float32)
63 | gt /= (gt.max() + 1e-8)
64 | image = image.cuda()
65 |         # depth = depth.repeat(1,3,1,1).cuda()
66 | depth = depth.cuda()
67 |
68 |
69 | # print(depth.shape)
70 | res,edge = model(image,depth)
71 |         res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
72 | # edge = F.upsample(edge, size=gt.shape, mode='bilinear', align_corners=False)
73 | res = res.sigmoid().data.cpu().numpy().squeeze()
74 | # edge = edge.sigmoid().data.cpu().numpy().squeeze()
75 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
76 | # edge = (edge - edge.min()) / (edge.max() - edge.min() + 1e-8)
77 | print('save img to: ', sal_save_path + name)
78 | cv2.imwrite(sal_save_path + name, res * 255)
79 | time_e = time.time()
80 | fps += (nums / (time_e - time_s))
81 | print("FPS:%f" % (nums / (time_e - time_s)))
82 | print('Test Done!')
83 | print("Total FPS %f" % fps) # this result include I/O cost
84 |
85 | # test
86 |
87 | test_datasets = ['SIP','SSD','RedWeb','NJU2K','NLPR','STERE','DES','LFSD',]
88 | # test_datasets = ['DUT-RGBD']
89 | fps = 0
90 | for dataset in test_datasets:
91 | time_s = time.time()
92 | save_path = './test_maps/' + dataset + '/'
93 | edge_save_path = './test_maps/' + dataset + '/edge/'
94 | if not os.path.exists(save_path):
95 | os.makedirs(save_path)
96 | if not os.path.exists(edge_save_path):
97 | os.makedirs(edge_save_path)
98 | image_root = dataset_path + dataset + '/RGB/'
99 | gt_root = dataset_path + dataset + '/GT/'
100 | depth_root = dataset_path + dataset + '/depth/'
101 | test_loader = test_dataset(image_root, gt_root, depth_root, opt.testsize)
102 | nums = test_loader.size
103 | for i in range(test_loader.size):
104 | image, gt, depth, name, image_for_post = test_loader.load_data()
105 | gt = np.asarray(gt, np.float32)
106 | gt /= (gt.max() + 1e-8)
107 | image = image.cuda()
108 | depth = depth.repeat(1,3,1,1).cuda()
109 | res, edge = model(image,depth)
110 |         res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
111 |         edge = F.interpolate(edge, size=gt.shape, mode='bilinear', align_corners=False)
112 | res = res.sigmoid().data.cpu().numpy().squeeze()
113 | edge = edge.sigmoid().data.cpu().numpy().squeeze()
114 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
115 | edge = (edge - edge.min()) / (edge.max() - edge.min() + 1e-8)
116 | print('save img to: ',save_path+name)
117 | cv2.imwrite(save_path + name, res*255)
118 | cv2.imwrite(edge_save_path + name, edge * 255)
119 | time_e = time.time()
120 | fps += (nums / (time_e - time_s))
121 | print("FPS:%f" % (nums / (time_e - time_s)))
122 | print('Test Done!')
123 | print("Total FPS %f" % fps)
124 |
125 |
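126 | 
127 | # Optional sanity check (an illustrative sketch, not part of the original
128 | # script): recompute MAE between saved prediction maps and ground truths,
129 | # assuming both folders contain identically named, same-sized .png files.
130 | def quick_mae(pred_dir, gt_dir):
131 |     maes = []
132 |     for name in sorted(os.listdir(gt_dir)):
133 |         pred = cv2.imread(os.path.join(pred_dir, name), 0).astype(np.float32) / 255.0
134 |         gt_map = cv2.imread(os.path.join(gt_dir, name), 0).astype(np.float32) / 255.0
135 |         maes.append(np.abs(pred - gt_map).mean())
136 |     return sum(maes) / len(maes)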
--------------------------------------------------------------------------------
/networks/SwinNet/SwinNet_train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: SwinNet_train.py
6 | @time: 2021/5/6 16:12
7 | """
8 | import os
9 | import torch
10 | import torch.nn.functional as F
11 | import sys
12 |
13 | sys.path.append('./models')
14 | import numpy as np
15 | from datetime import datetime
16 | from models.swinNet import SwinTransformer,SwinNet
17 | from torchvision.utils import make_grid
18 | from data import get_loader, test_dataset
19 | from utils import clip_gradient, adjust_lr
20 | from tensorboardX import SummaryWriter
21 | import logging
22 | import torch.backends.cudnn as cudnn
23 | from options import opt
24 | import yaml
25 |
26 |
27 | # set the device for training
28 | if opt.gpu_id == '0':
29 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
30 | print('USE GPU 0')
31 | elif opt.gpu_id == '1':
32 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
33 | print('USE GPU 1')
34 | cudnn.benchmark = True
35 |
36 | image_root = opt.rgb_root
37 | gt_root = opt.gt_root
38 | depth_root = opt.depth_root
39 | edge_root = opt.edge_root
40 |
41 | val_image_root = opt.val_rgb_root
42 | val_gt_root = opt.val_gt_root
43 | val_depth_root = opt.val_depth_root
44 | save_path = opt.save_path
45 |
46 | logging.basicConfig(filename=save_path + 'RGBD-SwinNet.log',
47 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO, filemode='a',
48 | datefmt='%Y-%m-%d %I:%M:%S %p')
49 |
50 |
51 | model = SwinNet()
52 |
53 | num_parms = 0
54 | if (opt.load is not None):
55 | model.load_pre(opt.load)
56 | print('load model from ', opt.load)
57 |
58 | model.cuda()
59 | for p in model.parameters():
60 | num_parms += p.numel()
61 | logging.info("Total Parameters (For Reference): {}".format(num_parms))
62 | print("Total Parameters (For Reference): {}".format(num_parms))
63 |
64 | params = model.parameters()
65 | optimizer = torch.optim.Adam(params, opt.lr)
66 |
67 | # set the path
68 | if not os.path.exists(save_path):
69 | os.makedirs(save_path)
70 |
71 | # load data
72 | print('load data...')
73 | train_loader = get_loader(image_root, gt_root,depth_root, edge_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
74 | test_loader = test_dataset(val_image_root, val_gt_root, val_depth_root, opt.trainsize)
75 | total_step = len(train_loader)
76 |
77 | logging.info("Config")
78 | logging.info(
79 | 'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format(
80 | opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.load, save_path,
81 | opt.decay_epoch))
82 |
83 | # set loss function
84 | CE = torch.nn.BCEWithLogitsLoss()
85 | ECE = torch.nn.BCELoss()
86 | step = 0
87 | writer = SummaryWriter(save_path + 'summary')
88 | best_mae = 1
89 | best_epoch = 0
90 |
91 |
92 | # train function
93 | def train(train_loader, model, optimizer, epoch, save_path):
94 | global step
95 | model.train()
96 | loss_all = 0
97 | epoch_step = 0
98 |
99 | try:
100 | for i, (images, gts, depth,edge) in enumerate(train_loader, start=1):
101 | optimizer.zero_grad()
102 |
103 | images = images.cuda()
104 | gts = gts.cuda()
105 | # RGB-T data
106 | # depth = depth.cuda()
107 | # RGB-D data
108 | depth = depth.repeat(1,3,1,1).cuda()
109 | edge = edge.cuda()
110 | s, e = model(images,depth)
111 |
112 | sal_loss = CE(s, gts)
113 | edge_loss = CE(e, edge)
114 | loss = sal_loss + edge_loss
115 | loss.backward()
116 |
117 | clip_gradient(optimizer, opt.clip)
118 | optimizer.step()
119 | step += 1
120 | epoch_step += 1
121 | loss_all += loss.data
122 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
123 | if i % 100 == 0 or i == total_step or i == 1:
124 |                 print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], LR:{:.7f}||sal_loss:{:.4f} ||edge_loss:{:.4f}'.
125 |                       format(datetime.now(), epoch, opt.epoch, i, total_step,
126 |                              optimizer.state_dict()['param_groups'][0]['lr'], sal_loss.data, edge_loss.data))
127 |                 logging.info(
128 |                     '#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], LR:{:.7f}, loss:{:.4f} ||edge_loss:{:.4f}, mem_use:{:.0f}MB'.
129 |                     format(epoch, opt.epoch, i, total_step, optimizer.state_dict()['param_groups'][0]['lr'], loss.data, edge_loss.data, memory_used))
130 | writer.add_scalar('Loss', loss.data, global_step=step)
131 | grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True)
132 | writer.add_image('RGB', grid_image, step)
133 | grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True)
134 | writer.add_image('Ground_truth', grid_image, step)
135 | res = s[0].clone()
136 | res = res.sigmoid().data.cpu().numpy().squeeze()
137 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
138 | writer.add_image('res', torch.tensor(res), step, dataformats='HW')
139 | loss_all /= epoch_step
140 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}],Loss_AVG: {:.4f}'.format(epoch, opt.epoch, loss_all))
141 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)
142 | if (epoch) % 5 == 0:
143 | torch.save(model.state_dict(), save_path + 'SwinNet_epoch_{}.pth'.format(epoch))
144 | except KeyboardInterrupt:
145 | print('Keyboard Interrupt: save model and exit.')
146 | if not os.path.exists(save_path):
147 | os.makedirs(save_path)
148 | torch.save(model.state_dict(), save_path + 'SwinNet_epoch_{}.pth'.format(epoch + 1))
149 | print('save checkpoints successfully!')
150 | raise
151 |
152 | # test function
153 | def test(test_loader, model, epoch, save_path):
154 | global best_mae, best_epoch
155 | model.eval()
156 | with torch.no_grad():
157 | mae_sum = 0
158 | for i in range(test_loader.size):
159 | image, gt, depth, name, img_for_post = test_loader.load_data()
160 |
161 | gt = np.asarray(gt, np.float32)
162 | gt /= (gt.max() + 1e-8)
163 | image = image.cuda()
164 | # print(depth.shape)
165 | # depth = depth.cuda()
166 | depth = depth.repeat(1,3,1,1).cuda()
167 | res,e = model(image,depth)
168 |             res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
169 | res = res.sigmoid().data.cpu().numpy().squeeze()
170 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
171 | mae_sum += np.sum(np.abs(res - gt)) * 1.0 / (gt.shape[0] * gt.shape[1])
172 | mae = mae_sum / test_loader.size
173 | writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch)
174 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch))
175 | if epoch == 1:
176 | best_mae = mae
177 | else:
178 | if mae < best_mae:
179 | best_mae = mae
180 | best_epoch = epoch
181 | torch.save(model.state_dict(), save_path + 'SwinNet_epoch_best.pth')
182 | print('best epoch:{}'.format(epoch))
183 | logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae))
184 |
185 |
186 | if __name__ == '__main__':
187 | print("Start train...")
188 | for epoch in range(1, opt.epoch):
189 | cur_lr = adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
190 | writer.add_scalar('learning_rate', cur_lr, global_step=epoch)
191 | train(train_loader, model, optimizer, epoch, save_path)
192 | test(test_loader, model, epoch, save_path)
193 |
--------------------------------------------------------------------------------
/networks/SwinNet/cpts/RGBD-SwinNet.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/cpts/RGBD-SwinNet.log
--------------------------------------------------------------------------------
/networks/SwinNet/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 | # Reference from BBSNet, Thanks!!!
10 |
11 | # several data augmentation strategies
12 | def cv_random_flip(img, label, depth,edge):
13 | flip_flag = random.randint(0, 1)
14 | # flip_flag2= random.randint(0,1)
15 | # left right flip
16 | if flip_flag == 1:
17 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
18 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
19 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
20 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT)
21 | # top bottom flip
22 | # if flip_flag2==1:
23 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
24 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
25 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
26 | return img, label, depth, edge
27 |
28 |
29 | def randomCrop(image, label, depth, edge):
30 | border = 30
31 | image_width = image.size[0]
32 | image_height = image.size[1]
33 | crop_win_width = np.random.randint(image_width - border, image_width)
34 | crop_win_height = np.random.randint(image_height - border, image_height)
35 | random_region = (
36 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
37 | (image_height + crop_win_height) >> 1)
38 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region), edge.crop(random_region)
39 |
40 |
41 | def randomRotation(image, label, depth, edge):
42 | mode = Image.BICUBIC
43 | if random.random() > 0.8:
44 | random_angle = np.random.randint(-15, 15)
45 | image = image.rotate(random_angle, mode)
46 | label = label.rotate(random_angle, mode)
47 | depth = depth.rotate(random_angle, mode)
48 | edge = edge.rotate(random_angle, mode)
49 | return image, label, depth, edge
50 |
51 |
52 | def colorEnhance(image):
53 | bright_intensity = random.randint(5, 15) / 10.0
54 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
55 | contrast_intensity = random.randint(5, 15) / 10.0
56 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
57 | color_intensity = random.randint(0, 20) / 10.0
58 | image = ImageEnhance.Color(image).enhance(color_intensity)
59 | sharp_intensity = random.randint(0, 30) / 10.0
60 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
61 | return image
62 |
63 |
64 | def randomGaussian(image, mean=0.1, sigma=0.35):
65 | def gaussianNoisy(im, mean=mean, sigma=sigma):
66 | for _i in range(len(im)):
67 | im[_i] += random.gauss(mean, sigma)
68 | return im
69 |
70 | img = np.asarray(image)
71 | width, height = img.shape
72 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
73 | img = img.reshape([width, height])
74 | return Image.fromarray(np.uint8(img))
75 |
76 |
77 | def randomPeper(img):
78 | img = np.array(img)
79 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
80 | for i in range(noiseNum):
81 |
82 | randX = random.randint(0, img.shape[0] - 1)
83 |
84 | randY = random.randint(0, img.shape[1] - 1)
85 |
86 | if random.randint(0, 1) == 0:
87 |
88 | img[randX, randY] = 0
89 |
90 | else:
91 |
92 | img[randX, randY] = 255
93 | return Image.fromarray(img)
94 |
95 |
96 | # dataset for training
97 | # The current loader does not use normalized depth maps for training and testing. If you use normalized depth maps
98 | # (e.g., 0 represents background and 1 represents foreground), the performance will be further improved.
99 | class SalObjDataset(data.Dataset):
100 | def __init__(self, image_root, gt_root, depth_root, edge_root, trainsize):
101 | self.trainsize = trainsize
102 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
103 |
104 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
105 | or f.endswith('.png')]
106 |
107 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
108 | or f.endswith('.png') or f.endswith('.jpg')]
109 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.bmp')
110 | or f.endswith('.png') or f.endswith('.jpg')]
111 | self.images = sorted(self.images)
112 | self.gts = sorted(self.gts)
113 | self.depths = sorted(self.depths)
114 | self.edges = sorted(self.edges)
115 | self.filter_files()
116 | self.size = len(self.images)
117 | self.img_transform = transforms.Compose([
118 | transforms.Resize((self.trainsize, self.trainsize)),
119 | transforms.ToTensor(),
120 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
121 | self.gt_transform = transforms.Compose([
122 | transforms.Resize((self.trainsize, self.trainsize)),
123 | transforms.ToTensor()])
124 | self.depths_transform = transforms.Compose(
125 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()])
126 | self.edges_transform = transforms.Compose(
127 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()]
128 | )
129 |
130 | def __getitem__(self, index):
131 | image = self.rgb_loader(self.images[index])
132 | gt = self.binary_loader(self.gts[index])
133 | depth = self.binary_loader(self.depths[index]) # RGBD
134 | # depth = self.rgb_loader(self.depths[index]) # RGBT
135 | edge = self.binary_loader(self.edges[index])
136 |
137 | image, gt, depth, edge = cv_random_flip(image, gt, depth, edge)
138 | image, gt, depth, edge = randomCrop(image, gt, depth, edge)
139 | image, gt, depth, edge = randomRotation(image, gt, depth, edge)
140 | image = colorEnhance(image)
141 | # gt=randomGaussian(gt)
142 | gt = randomPeper(gt)
143 | image = self.img_transform(image)
144 | gt = self.gt_transform(gt)
145 | depth = self.depths_transform(depth)
146 | edge = self.edges_transform(edge)
147 |
148 | return image, gt, depth, edge
149 |
150 | def filter_files(self):
151 |         assert len(self.images) == len(self.gts) == len(self.depths) == len(self.edges)
152 | images = []
153 | gts = []
154 | depths = []
155 | edges = []
156 | for img_path, gt_path, depth_path, edge_path in zip(self.images, self.gts, self.depths, self.edges):
157 | img = Image.open(img_path)
158 | gt = Image.open(gt_path)
159 | depth = Image.open(depth_path)
160 | edge = Image.open(edge_path)
161 | if img.size == gt.size and gt.size == depth.size and edge.size == img.size:
162 | images.append(img_path)
163 | gts.append(gt_path)
164 | depths.append(depth_path)
165 | edges.append(edge_path)
166 | self.images = images
167 | self.gts = gts
168 | self.depths = depths
169 | self.edges = edges
170 |
171 | def rgb_loader(self, path):
172 | with open(path, 'rb') as f:
173 | img = Image.open(f)
174 | return img.convert('RGB')
175 |
176 | def binary_loader(self, path):
177 | with open(path, 'rb') as f:
178 | img = Image.open(f)
179 | return img.convert('L')
180 |
181 | def resize(self, img, gt, depth, edge):
182 | assert img.size == gt.size and gt.size == depth.size and edge.size == img.size
183 | w, h = img.size
184 | if h < self.trainsize or w < self.trainsize:
185 | h = max(h, self.trainsize)
186 | w = max(w, self.trainsize)
187 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), \
188 | depth.resize((w, h),Image.NEAREST), edge.resize((w,h), Image.NEAREST)
189 | else:
190 | return img, gt, depth, edge
191 |
192 | def __len__(self):
193 | return self.size
194 |
195 |
196 | # dataloader for training
197 | def get_loader(image_root, gt_root, depth_root, edge_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True):
198 | dataset = SalObjDataset(image_root, gt_root, depth_root, edge_root, trainsize)
199 | # print(image_root)
200 | # print(gt_root)
201 | # print(depth_root)
202 | data_loader = data.DataLoader(dataset=dataset,
203 | batch_size=batchsize,
204 | shuffle=shuffle,
205 | num_workers=num_workers,
206 | pin_memory=pin_memory)
207 | return data_loader
208 |
209 |
210 | # test dataset and loader
211 | class test_dataset:
212 | def __init__(self, image_root, gt_root, depth_root, testsize):
213 | self.testsize = testsize
214 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
215 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
216 | or f.endswith('.png')]
217 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
218 | or f.endswith('.png')or f.endswith('.jpg')]
219 |
220 | self.images = sorted(self.images)
221 | self.gts = sorted(self.gts)
222 | self.depths = sorted(self.depths)
223 | self.transform = transforms.Compose([
224 | transforms.Resize((self.testsize, self.testsize)),
225 | transforms.ToTensor(),
226 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
227 | self.gt_transform = transforms.ToTensor()
228 | self.depths_transform = transforms.Compose(
229 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()])
230 | self.size = len(self.images)
231 | self.index = 0
232 |
233 | def load_data(self):
234 | image = self.rgb_loader(self.images[self.index])
235 | image = self.transform(image).unsqueeze(0)
236 | gt = self.binary_loader(self.gts[self.index])
237 | depth = self.binary_loader(self.depths[self.index]) # RGBD
238 | # depth = self.rgb_loader(self.depths[self.index]) # RGBT
239 | depth = self.transform(depth).unsqueeze(0)
240 | name = self.images[self.index].split('/')[-1]
241 | image_for_post = self.rgb_loader(self.images[self.index])
242 | image_for_post = image_for_post.resize(gt.size)
243 | if name.endswith('.jpg'):
244 | name = name.split('.jpg')[0] + '.png'
245 | self.index += 1
246 | self.index = self.index % self.size
247 | return image, gt, depth, name, np.array(image_for_post)
248 |
249 | def rgb_loader(self, path):
250 | with open(path, 'rb') as f:
251 | img = Image.open(f)
252 | return img.convert('RGB')
253 |
254 | def binary_loader(self, path):
255 | with open(path, 'rb') as f:
256 | img = Image.open(f)
257 | return img.convert('L')
258 |
259 | def __len__(self):
260 | return self.size
261 |
262 |
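263 | # Usage sketch (illustrative only; the directory paths below are placeholders
264 | # and must end with '/'). For RGB-T data, switch the binary/rgb depth loaders
265 | # marked "RGBD"/"RGBT" above.
266 | if __name__ == '__main__':
267 |     loader = get_loader('./data/RGB/', './data/GT/', './data/T/', './data/Edge/',
268 |                         batchsize=6, trainsize=384)
269 |     for image, gt, depth, edge in loader:
270 |         print(image.shape, gt.shape, depth.shape, edge.shape)
271 |         break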
--------------------------------------------------------------------------------
/networks/SwinNet/imgs/RGBD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/imgs/RGBD.png
--------------------------------------------------------------------------------
/networks/SwinNet/imgs/RGBT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/imgs/RGBT.png
--------------------------------------------------------------------------------
/networks/SwinNet/imgs/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/imgs/main.png
--------------------------------------------------------------------------------
/networks/SwinNet/logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: logger.py
6 | @time: 2021/5/6 16:00
7 | """
8 | import os
9 | import sys
10 | import logging
11 | import functools
12 | from termcolor import colored
13 |
14 |
15 | @functools.lru_cache()
16 | def create_logger(output_dir, dist_rank=0, name=''):
17 | # create logger
18 | logger = logging.getLogger(name)
19 | logger.setLevel(logging.DEBUG)
20 | logger.propagate = False
21 |
22 | # create formatter
23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
26 |
27 | # create console handlers for master process
28 | if dist_rank == 0:
29 | console_handler = logging.StreamHandler(sys.stdout)
30 | console_handler.setLevel(logging.DEBUG)
31 | console_handler.setFormatter(
32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
33 | logger.addHandler(console_handler)
34 |
35 | # create file handlers
36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
37 | file_handler.setLevel(logging.DEBUG)
38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
39 | logger.addHandler(file_handler)
40 |
41 | return logger
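42 | 
43 | 
44 | # Usage sketch (illustrative, not part of the original file): the output
45 | # directory is a placeholder and must exist before the file handler is created.
46 | if __name__ == '__main__':
47 |     os.makedirs('./logs', exist_ok=True)
48 |     logger = create_logger('./logs', dist_rank=0, name='swinnet')
49 |     logger.info('logger ready')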
--------------------------------------------------------------------------------
/networks/SwinNet/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: lr_scheduler.py
6 | @time: 2021/5/6 16:04
7 | """
8 |
9 | import torch
10 | from timm.scheduler.cosine_lr import CosineLRScheduler
11 | from timm.scheduler.step_lr import StepLRScheduler
12 | from timm.scheduler.scheduler import Scheduler
13 |
14 |
15 | def build_scheduler(config, optimizer, n_iter_per_epoch):
16 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
17 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
18 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
19 |
20 | lr_scheduler = None
21 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
22 | lr_scheduler = CosineLRScheduler(
23 | optimizer,
24 | t_initial=num_steps,
25 | t_mul=1.,
26 | lr_min=config.TRAIN.MIN_LR,
27 | warmup_lr_init=config.TRAIN.WARMUP_LR,
28 | warmup_t=warmup_steps,
29 | cycle_limit=1,
30 | t_in_epochs=False,
31 | )
32 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
33 | lr_scheduler = LinearLRScheduler(
34 | optimizer,
35 | t_initial=num_steps,
36 | lr_min_rate=0.01,
37 | warmup_lr_init=config.TRAIN.WARMUP_LR,
38 | warmup_t=warmup_steps,
39 | t_in_epochs=False,
40 | )
41 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
42 | lr_scheduler = StepLRScheduler(
43 | optimizer,
44 | decay_t=decay_steps,
45 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
46 | warmup_lr_init=config.TRAIN.WARMUP_LR,
47 | warmup_t=warmup_steps,
48 | t_in_epochs=False,
49 | )
50 |
51 | return lr_scheduler
52 |
53 |
54 | class LinearLRScheduler(Scheduler):
55 | def __init__(self,
56 | optimizer: torch.optim.Optimizer,
57 | t_initial: int,
58 | lr_min_rate: float,
59 | warmup_t=0,
60 | warmup_lr_init=0.,
61 | t_in_epochs=True,
62 | noise_range_t=None,
63 | noise_pct=0.67,
64 | noise_std=1.0,
65 | noise_seed=42,
66 | initialize=True,
67 | ) -> None:
68 | super().__init__(
69 | optimizer, param_group_field="lr",
70 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
71 | initialize=initialize)
72 |
73 | self.t_initial = t_initial
74 | self.lr_min_rate = lr_min_rate
75 | self.warmup_t = warmup_t
76 | self.warmup_lr_init = warmup_lr_init
77 | self.t_in_epochs = t_in_epochs
78 | if self.warmup_t:
79 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
80 | super().update_groups(self.warmup_lr_init)
81 | else:
82 | self.warmup_steps = [1 for _ in self.base_values]
83 |
84 | def _get_lr(self, t):
85 | if t < self.warmup_t:
86 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
87 | else:
88 | t = t - self.warmup_t
89 | total_t = self.t_initial - self.warmup_t
90 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
91 | return lrs
92 |
93 | def get_epoch_values(self, epoch: int):
94 | if self.t_in_epochs:
95 | return self._get_lr(epoch)
96 | else:
97 | return None
98 |
99 | def get_update_values(self, num_updates: int):
100 | if not self.t_in_epochs:
101 | return self._get_lr(num_updates)
102 | else:
103 | return None
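104 | 
105 | 
106 | # Minimal sketch of the linear schedule (illustrative, not part of the original
107 | # file): after `warmup_t` update steps, the LR decays linearly from `lr` toward
108 | # `lr * lr_min_rate` at step `t_initial`. The tiny model here is a dummy.
109 | if __name__ == '__main__':
110 |     net = torch.nn.Linear(4, 4)
111 |     optim = torch.optim.AdamW(net.parameters(), lr=1e-3)
112 |     sched = LinearLRScheduler(optim, t_initial=100, lr_min_rate=0.01,
113 |                               warmup_t=10, warmup_lr_init=1e-6, t_in_epochs=False)
114 |     for t in (0, 10, 55, 100):
115 |         sched.step_update(t)
116 |         print(t, optim.param_groups[0]['lr'])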
--------------------------------------------------------------------------------
/networks/SwinNet/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: main.py
6 | @time: 2021/4/6 16:10
7 | """
8 |
--------------------------------------------------------------------------------
/networks/SwinNet/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: __init__.py
6 | @time: 2021/5/6 15:58
7 | """
8 | from .build import build_model
--------------------------------------------------------------------------------
/networks/SwinNet/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/SwinNet/models/__pycache__/build.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/models/__pycache__/build.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/SwinNet/models/__pycache__/swinNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/SwinNet/models/__pycache__/swinNet.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/SwinNet/models/back/Swin_Transformer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: Swin_Transformer.py
6 | @time: 2021-05-03 11:31
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.utils.checkpoint as checkpoint
12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13 | import numpy as np
14 | import math
15 |
16 | class Mlp(nn.Module):
17 | def __init__(self, in_features, hidden_features=None,out_features=None, act_layer=nn.GELU, drop=0.):
18 | super().__init__()
19 | out_features = out_features or in_features
20 | hidden_features = hidden_features or in_features
21 | self.fc1 = nn.Linear(in_features, hidden_features)
22 | self.act = act_layer()
23 | self.fc2 = nn.Linear(hidden_features, out_features)
24 | self.drop = nn.Dropout(drop)
25 |
26 | def forward(self,x):
27 | x = self.fc1(x)
28 | x = self.act(x)
29 | x = self.drop(x)
30 | x = self.fc2(x)
31 | x = self.drop(x)
32 | return x
33 |
34 | def window_partition(x, window_size):
35 | B, H, W, C = x.shape
36 | x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
37 | windows = x.permute(0,1,3,2,4,5).contiguous().view(-1, window_size, window_size, C)
38 | return windows
39 |
40 | def window_reverse(windows, window_size, H, W):
41 | B = int(windows.shape[0] / (H * W / window_size / window_size))
42 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
43 |     x = x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)
44 | return x
45 |
46 | class WindowAttention(nn.Module):
47 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
48 | super().__init__()
49 | self.dim = dim
50 | self.window_size = window_size
51 | self.num_heads = num_heads
52 | head_dim = dim // self.num_heads
53 | self.scale = qk_scale or head_dim ** -0.5
54 |
55 |         self.relative_position_bias_table = nn.Parameter(
56 |             torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
57 |         )
58 |
59 |         coords_h = torch.arange(self.window_size[0])
60 |         coords_w = torch.arange(self.window_size[1])
61 |
62 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
63 | coords_flatten = torch.flatten(coords,1)
64 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
65 | relative_coords = relative_coords.permute(1,2,0).contiguous()
66 | relative_coords[:, :, 0] += self.window_size[0] - 1
67 | relative_coords[:, :, 1] += self.window_size[1] - 1
68 | relative_coords[:, :, 0] *= 2 * self.window_size[1] -1
69 |         relative_position_index = relative_coords.sum(-1)
70 |         self.register_buffer("relative_position_index", relative_position_index)
71 |
72 | self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
73 | self.attn_drop = nn.Dropout(attn_drop)
74 | self.proj = nn.Linear(dim, dim)
75 | self.proj_drop = nn.Dropout(proj_drop)
76 |
77 | trunc_normal_(self.relative_position_bias_table, std=.02)
78 | self.softmax = nn.Softmax(dim=-1)
79 |
80 | def forward(self, x, mask=None):
81 | B_, N, C = x.shape
82 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83 | q, k, v = qkv[0], qkv[1], qkv[2]
84 |
85 | q = q * self.scale
86 | attn = (q @ k.transpose(-2, -1))
87 |
88 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
89 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
90 | attn = attn + relative_position_bias.unsqueeze(0)
91 |
92 | if mask is not None:
93 | nW = mask.shape[0]
94 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
95 | attn = attn.view(-1, self.num_heads, N, N)
96 | attn = self.softmax(attn)
97 |
98 | else:
99 | attn = self.softmax(attn)
100 |
101 | attn = self.attn_drop(attn)
102 |
103 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
104 | x = self.proj(x)
105 | x = self.proj_drop(x)
106 | return x
107 |
108 | def extra_repr(self) -> str:
109 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
110 |
111 | def flops(self, N):
112 | flops = 0
113 | flops += N*self.dim*3*self.dim
114 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
115 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
116 | flops += N * self.dim + self.dim
117 | return flops
118 |
119 | class SwinTransformerBlock(nn.Module):
120 | def __init__(self, dim, input_resolution,num_heads,window_size=7, shift_size=0,
121 | mlp_ratio=4, qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
122 | act_layer=nn.GELU, norm_layer = nn.LayerNorm):
123 | super().__init__()
124 | self.dim = dim
125 | self.input_resolution = input_resolution
126 | self.num_heads = num_heads
127 | self.window_size = window_size
128 | self.shift_size = shift_size
129 | self.mlp_ratio = mlp_ratio
130 | if min(self.input_resolution) < self.window_size:
131 | self.shift_size = 0
132 | self.window_size = min(self.input_resolution)
133 |
134 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
135 |
136 | self.norm1 = norm_layer(dim)
137 | self.attn = WindowAttention(
138 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
139 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
140 | )
141 | self.drop_path= DropPath(drop_path) if drop_path > 0. else nn.Identity()
142 | self.norm2 = norm_layer(dim)
143 | mlp_hidden_dim = int(dim * mlp_ratio)
144 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop)
145 |
146 | if self.shift_size > 0:
147 | H, W = self.input_resolution
148 |             img_mask = torch.zeros((1, H, W, 1))
149 | h_slices = (slice(0, -self.window_size),
150 | slice(-self.window_size, -self.shift_size),
151 | slice(-self.shift_size,None))
152 | w_slices = (slice(0, -self.window_size),
153 | slice(-self.window_size, -self.shift_size),
154 | slice(-self.shift_size, None))
155 |
156 | cnt = 0
157 | for h in h_slices:
158 | for w in w_slices:
159 | img_mask[:, h, w, :] = cnt
160 | cnt += 1
161 |
162 | mask_windows = window_partition(img_mask, self.window_size)
163 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
164 |             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
165 |             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
166 |         else:
167 |             attn_mask = None
168 |
169 |         self.register_buffer("attn_mask", attn_mask)
170 | def forward(self, x):
171 | H, W = self.input_resolution
172 | B, L, C = x.shape
173 | assert L == H * W, "input feature has wrong size"
174 |
175 | shortcut = x
176 | x = self.norm1(x)
177 | x = x.view(B, H, W, C)
178 |
179 | if self.shift_size > 0:
180 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
181 | else:
182 | shifted_x = x
183 |
184 |         # partition windows
185 | x_windows = window_partition(shifted_x, self.window_size)
186 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
187 |
188 | attn_windows = self.attn(x_windows, mask=self.attn_mask)
189 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
190 | shifted_x = window_reverse(attn_windows, self.window_size, H, W)
191 |
192 | if self.shift_size > 0:
193 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
194 | else:
195 | x = shifted_x
196 | x = x.view(B, H*W, C)
197 |
198 | #FFN
199 | x = shortcut + self.drop_path(x)
200 |         x = x + self.drop_path(self.mlp(self.norm2(x)))
201 | return x
202 |
203 | def extra_repr(self) -> str:
204 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
205 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
206 |
207 | def flops(self):
208 | flops = 0
209 | H, W = self.input_resolution
210 |
211 | # norm1
212 | flops += self.dim * H * W
213 | nW = H * W / self.window_size / self.window_size
214 | flops += nW * self.attn.flops(self.window_size * self.window_size)
215 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
216 | flops += self.dim * H * W
217 | return flops
218 |
219 | class PatchMerging(nn.Module):
220 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
221 |         super().__init__()
222 | self.input_resolution = input_resolution
223 | self.dim = dim
224 |         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
225 |         self.norm = norm_layer(4 * dim)
226 |
227 | def forward(self, x):
228 | H, W = self.input_resolution
229 | B, L, C = x.shape
230 | assert L == H * W, "input feature has wrong size"
231 |         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even"
232 |
233 | x = x.view(B, H, W, C)
234 |
235 | x0 = x[:, 0::2, 0::2, :]
236 | x1 = x[:, 1::2, 0::2, :]
237 | x2 = x[:, 0::2, 1::2, :]
238 | x3 = x[:, 1::2, 1::2, :]
239 | x = torch.cat([x0,x1,x2,x3], -1)
240 | x = x.view(B, -1, 4 * C)
241 |
242 | x = self.norm(x)
243 |         x = self.reduction(x)
244 |
245 | return x
246 |
247 | def extra_repr(self) -> str:
248 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
249 |
250 | def flops(self):
251 | H,W = self.input_resolution
252 | flops = H * W * self.dim
253 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
254 | return flops
255 |
256 | class BasicLayer(nn.Module):
257 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
258 | mlp_ratio=4, qkv_bias = True, qk_scale=None, drop=0., attn_drop=0.,
259 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
260 | super().__init__()
261 | self.dim = dim
262 | self.input_resolution = input_resolution
263 | self.depth = depth
264 | self.use_checkpoint=use_checkpoint
265 |
266 | self.blocks = nn.ModuleList([
267 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
268 | num_heads=num_heads, window_size=window_size,
269 | shift_size=0 if (i % 2 == 0) else window_size // 2,
270 | mlp_ratio=mlp_ratio,
271 | qkv_bias=qkv_bias,qk_scale=qk_scale,
272 | drop=drop, attn_drop=attn_drop,
273 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
274 | norm_layer=norm_layer)
275 | for i in range(depth)])
276 |
277 | if downsample is not None:
278 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
279 | else:
280 | self.downsample = None
281 |
282 | def forward(self, x):
283 | for blk in self.blocks:
284 | if self.use_checkpoint:
285 | x = checkpoint.checkpoint(blk, x)
286 | else:
287 | x = blk(x)
288 | if self.downsample is not None:
289 | x = self.downsample(x)
290 | return x
291 |
292 | def extra_repr(self) -> str:
293 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
294 |
295 | def flops(self):
296 | flops = 0
297 | for blk in self.blocks:
298 | flops += blk.flops()
299 | if self.downsample is not None:
300 | flops += self.downsample.flops()
301 | return flops
302 |
303 | class PatchEmbed(nn.Module):
304 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
305 | super(PatchEmbed, self).__init__()
306 | img_size = to_2tuple(img_size)
307 | patch_size = to_2tuple(patch_size)
308 | patchs_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
309 | self.img_size = img_size
310 | self.patch_size = patch_size
311 | self.patchs_resolution = patchs_resolution
312 | self.num_patches = patchs_resolution[0] * patchs_resolution[1]
313 |
314 | self.in_chans = in_chans
315 | self.embed_dim = embed_dim
316 |
317 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
318 | if norm_layer is not None:
319 | self.norm = norm_layer(embed_dim)
320 | else:
321 | self.norm = None
322 |
323 | def forward(self, x):
324 | B, C, H, W = x.shape
325 | assert H == self.img_size[0] and W == self.img_size[1], \
326 |             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})"
327 | x = self.proj(x).flatten(2).transpose(1, 2)
328 | if self.norm is not None:
329 | x = self.norm(x)
330 | return x
331 |
332 | def flops(self):
333 | H0, W0 = self.patchs_resolution
334 | flops = H0 * W0 * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
335 | if self.norm is not None:
336 | flops += H0 * W0 * self.embed_dim
337 | return flops
338 |
339 | class SwinTransformer(nn.Module):
340 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_class=1000,
341 | embed_dim = 96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
342 | window_size=7, mlp_ratio = 4, qkv_bias = True, qk_scale=None,
343 | drop_rate = 0., attn_drop_rate=0., drop_path_rate=0.1,
344 | norm_layer=nn.LayerNorm, ape = False, patch_norm=True,
345 | use_checkpoint=False, **kwargs):
346 | super(SwinTransformer, self).__init__()
347 | self.num_class = num_class
348 | self.num_layers = len(depths)
349 | self.embed_dim = embed_dim
350 | self.ape = ape
351 | self.patch_norm = patch_norm
352 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
353 | self.mlp_ratio = mlp_ratio
354 |
355 | #split image into non-overlapping patches
356 | self.patch_embed = PatchEmbed(
357 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
358 | norm_layer=norm_layer if self.patch_norm else None)
359 | num_patches = self.patch_embed.num_patches
360 | patches_resolution = self.patch_embed.patchs_resolution
361 |         self.patches_resolution = patches_resolution
362 |
363 | # absolute position embedding
364 | if self.ape:
365 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
366 | trunc_normal_(self.absolute_pos_embed, std=.02)
367 | self.pos_drop = nn.Dropout(p=drop_rate)
368 |
369 |         # stochastic depth
370 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
371 |
372 | # build layer
373 | self.layers = nn.ModuleList()
374 | for i_layer in range(self.num_layers):
375 | layer = BasicLayer(dim = int(embed_dim * 2 ** i_layer),
376 | input_resolution = (patches_resolution[0] // (2 ** i_layer),
377 | patches_resolution[1] // (2 ** i_layer)),
378 | depth = depths[i_layer],
379 | num_heads = num_heads[i_layer],
380 | window_size = window_size,
381 | mlp_ratio = mlp_ratio,
382 | qkv_bias = qkv_bias,qk_scale=qk_scale,
383 | drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
384 | norm_layer = norm_layer,
385 | downsample = PatchMerging if (i_layer < self.num_layers -1) else None,
386 | use_checkpoint = use_checkpoint)
387 |             self.layers.append(layer)
388 | self.norm = norm_layer(self.num_features)
389 | self.avgpool = nn.AdaptiveAvgPool1d(1)
390 | self.apply(self._init_weights)
391 |
392 |     def _init_weights(self, m):
393 | if isinstance(m, nn.Linear):
394 | trunc_normal_(m.weight, std=.02)
395 | if isinstance(m, nn.Linear) and m.bias is not None:
396 | nn.init.constant_(m.bias, 0)
397 | elif isinstance(m, nn.LayerNorm):
398 | nn.init.constant_(m.bias, 0)
399 | nn.init.constant_(m.weight, 1.0)
400 |
401 | @torch.jit.ignore
402 | def no_weight_decay(self):
403 | return {'absolute_pos_embed'}
404 |
405 | @torch.jit.ignore
406 | def no_weight_decay_keywords(self):
407 | return {'relative_position_bias_table'}
408 |
409 | def forward_features(self, x):
410 | x = self.patch_embed(x)
411 | if self.ape:
412 | x += self.absolute_pos_embed
413 | x = self.pos_drop(x)
414 |
415 | for layer in self.layers:
416 | x = layer(x)
417 | x = self.norm(x)
418 | x = self.avgpool(x.transpose(1, 2))
419 | x = torch.flatten(x, 1)
420 | return x
421 |
422 | def forward(self, x):
423 | x = self.forward_features(x)
424 | return x
425 |
426 | def flops(self):
427 | flops = 0
428 | flops += self.patch_embed.flops()
429 | for i, layer in enumerate(self.layers):
430 | flops += layer.flops()
431 |         flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
432 | flops += self.num_features * self.num_class
433 | return flops
434 |
435 | if __name__ == '__main__':
436 | a = np.random.random((1,3,224,224))
437 | b = torch.Tensor(a).cuda()
438 | swintransformer = SwinTransformer().cuda()
439 | x = swintransformer(b)
440 | print(x.shape)
441 |
442 |
443 |
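444 | # Round-trip sketch for the windowing helpers (illustrative, not part of the
445 | # original file): `window_partition` cuts (B, H, W, C) feature maps into
446 | # (num_windows*B, ws, ws, C) tiles and `window_reverse` reassembles them, so
447 | # the two are exact inverses:
448 | #
449 | #   x = torch.randn(2, 56, 56, 96)
450 | #   wins = window_partition(x, 7)        # (2*8*8, 7, 7, 96)
451 | #   y = window_reverse(wins, 7, 56, 56)  # (2, 56, 56, 96)
452 | #   assert torch.equal(x, y)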
--------------------------------------------------------------------------------
/networks/SwinNet/models/build.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: build.py
6 | @time: 2021/5/3 00:31
7 | """
8 |
9 | # from models.back.Swin_Transformer import SwinTransformer
10 |
11 | from .swinNet import SwinTransformer
12 |
13 |
14 | def build_model(config):
15 | model_type = config.MODEL.TYPE
16 | if model_type == 'swin':
17 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
18 | patch_size=config.MODEL.SWIN.PATCH_SIZE,
19 | in_chans=config.MODEL.SWIN.IN_CHANS,
20 | num_classes=config.MODEL.NUM_CLASSES,
21 | embed_dim=config.MODEL.SWIN.EMBED_DIM,
22 | depths=config.MODEL.SWIN.DEPTHS,
23 | num_heads=config.MODEL.SWIN.NUM_HEADS,
24 | window_size=config.MODEL.SWIN.WINDOW_SIZE,
25 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
26 | qkv_bias=config.MODEL.SWIN.QKV_BIAS,
27 | qk_scale=config.MODEL.SWIN.QK_SCALE,
28 | drop_rate=config.MODEL.DROP_RATE,
29 | drop_path_rate=config.MODEL.DROP_PATH_RATE,
30 | ape=config.MODEL.SWIN.APE,
31 | patch_norm=config.MODEL.SWIN.PATCH_NORM,
32 | use_checkpoint=config.TRAIN.USE_CHECKPOINT)
33 | else:
34 |         raise NotImplementedError(f"Unknown model: {model_type}")
35 |
36 | return model
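37 | 
38 | 
39 | # Usage sketch (illustrative, not part of the original file): the config is
40 | # normally a yacs-style object built elsewhere; here the required fields are
41 | # faked with namespaces so the factory can be exercised standalone.
42 | #
43 | #   from types import SimpleNamespace as NS
44 | #   cfg = NS(DATA=NS(IMG_SIZE=224),
45 | #            MODEL=NS(TYPE='swin', NUM_CLASSES=1000, DROP_RATE=0., DROP_PATH_RATE=0.1,
46 | #                     SWIN=NS(PATCH_SIZE=4, IN_CHANS=3, EMBED_DIM=96,
47 | #                             DEPTHS=[2, 2, 6, 2], NUM_HEADS=[3, 6, 12, 24],
48 | #                             WINDOW_SIZE=7, MLP_RATIO=4., QKV_BIAS=True, QK_SCALE=None,
49 | #                             APE=False, PATCH_NORM=True)),
50 | #            TRAIN=NS(USE_CHECKPOINT=False))
51 | #   model = build_model(cfg)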
--------------------------------------------------------------------------------
/networks/SwinNet/models/swinNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: swinNet.py
6 | @time: 2021/5/6 6:13
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.utils.checkpoint as checkpoint
12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13 | import numpy as np
14 | import torch.nn.functional as F
15 | import os
16 |
17 | def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=has_bias)
21 |
22 | def conv3x3_bn_relu(in_planes, out_planes, stride=1):
23 | return nn.Sequential(
24 | conv3x3(in_planes, out_planes, stride),
25 | nn.BatchNorm2d(out_planes),
26 | nn.ReLU(inplace=True),
27 | )
28 |
29 |
30 | class Mlp(nn.Module):
31 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
32 | super().__init__()
33 | out_features = out_features or in_features
34 | hidden_features = hidden_features or in_features
35 | self.fc1 = nn.Linear(in_features, hidden_features)
36 | self.act = act_layer()
37 | self.fc2 = nn.Linear(hidden_features, out_features)
38 | self.drop = nn.Dropout(drop)
39 |
40 | def forward(self, x):
41 | x = self.fc1(x)
42 | x = self.act(x)
43 | x = self.drop(x)
44 | x = self.fc2(x)
45 | x = self.drop(x)
46 | return x
47 |
48 |
49 | def window_partition(x, window_size):
50 | """
51 | Args:
52 | x: (B, H, W, C)
53 | window_size (int): window size
54 |
55 | Returns:
56 | windows: (num_windows*B, window_size, window_size, C), i.e. all windows stacked along the batch dimension
57 | """
58 | B, H, W, C = x.shape
59 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
60 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
61 | return windows
62 |
63 |
64 | def window_reverse(windows, window_size, H, W):
65 | """
66 | Args:
67 | windows: (num_windows*B, window_size, window_size, C)
68 | window_size (int): Window size
69 | H (int): Height of image
70 | W (int): Width of image
71 |
72 | Returns:
73 | x: (B, H, W, C)
74 | """
75 | B = int(windows.shape[0] / (H * W / window_size / window_size))
76 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
77 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
78 | return x
79 |
80 |
81 | class WindowAttention(nn.Module):
82 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
83 | It supports both of shifted and non-shifted window.
84 |
85 | Args:
86 | dim (int): Number of input channels.
87 | window_size (tuple[int]): The height and width of the window.
88 | num_heads (int): Number of attention heads.
89 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
90 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
91 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
92 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
93 | """
94 |
95 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
96 |
97 | super().__init__()
98 | self.dim = dim
99 | self.window_size = window_size # Wh, Ww
100 | self.num_heads = num_heads
101 | head_dim = dim // num_heads
102 | self.scale = qk_scale or head_dim ** -0.5
103 |
104 | # define a parameter table of relative position bias
105 | self.relative_position_bias_table = nn.Parameter(
106 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
107 |
108 | # get pair-wise relative position index for each token inside the window
109 | coords_h = torch.arange(self.window_size[0])
110 | coords_w = torch.arange(self.window_size[1])
111 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
112 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
113 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
114 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
115 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
116 | relative_coords[:, :, 1] += self.window_size[1] - 1
117 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
118 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
119 | self.register_buffer("relative_position_index", relative_position_index)
120 |
121 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
122 | self.attn_drop = nn.Dropout(attn_drop)
123 | self.proj = nn.Linear(dim, dim)
124 | self.proj_drop = nn.Dropout(proj_drop)
125 |
126 | trunc_normal_(self.relative_position_bias_table, std=.02)
127 | self.softmax = nn.Softmax(dim=-1)
128 |
129 | def forward(self, x, mask=None):
130 | """
131 | Args:
132 | x: input features with shape of (num_windows*B, N, C)
133 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
134 | """
135 | B_, N, C = x.shape
136 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
137 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
138 |
139 | q = q * self.scale
140 | attn = (q @ k.transpose(-2, -1))
141 |
142 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
143 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
144 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
145 | attn = attn + relative_position_bias.unsqueeze(0)
146 |
147 | if mask is not None:
148 | nW = mask.shape[0]
149 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
150 | attn = attn.view(-1, self.num_heads, N, N)
151 | attn = self.softmax(attn)
152 | else:
153 | attn = self.softmax(attn)
154 |
155 | attn = self.attn_drop(attn)
156 |
157 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
158 | x = self.proj(x)
159 | x = self.proj_drop(x)
160 | return x
161 |
162 | def extra_repr(self) -> str:
163 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
164 |
165 | def flops(self, N):
166 | # calculate flops for 1 window with token length of N
167 | flops = 0
168 | # qkv = self.qkv(x)
169 | flops += N * self.dim * 3 * self.dim
170 | # attn = (q @ k.transpose(-2, -1))
171 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
172 | # x = (attn @ v)
173 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
174 | # x = self.proj(x)
175 | flops += N * self.dim * self.dim
176 | return flops
177 |
178 |
179 | class SwinTransformerBlock(nn.Module):
180 | r""" Swin Transformer Block.
181 |
182 | Args:
183 | dim (int): Number of input channels.
184 | input_resolution (tuple[int]): Input resolution.
185 | num_heads (int): Number of attention heads.
186 | window_size (int): Window size.
187 | shift_size (int): Shift size for SW-MSA.
188 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
189 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
190 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
191 | drop (float, optional): Dropout rate. Default: 0.0
192 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
193 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
194 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
195 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
196 | """
197 |
198 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
199 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
200 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
201 | super().__init__()
202 | self.dim = dim
203 | self.input_resolution = input_resolution
204 | self.num_heads = num_heads
205 | self.window_size = window_size
206 | self.shift_size = shift_size
207 | self.mlp_ratio = mlp_ratio
208 | if min(self.input_resolution) <= self.window_size:
209 | # if window size is larger than input resolution, we don't partition windows
210 | self.shift_size = 0
211 | self.window_size = min(self.input_resolution)
212 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
213 |
214 | self.norm1 = norm_layer(dim)
215 |
216 | self.attn = WindowAttention(
217 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
218 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
219 |
220 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
221 | self.norm2 = norm_layer(dim)
222 | mlp_hidden_dim = int(dim * mlp_ratio)
223 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
224 |
225 | if self.shift_size > 0:
226 | # calculate attention mask for SW-MSA
227 | H, W = self.input_resolution
228 | img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1; dummy map used only to build the attention mask
229 | h_slices = (slice(0, -self.window_size),
230 | slice(-self.window_size, -self.shift_size),
231 | slice(-self.shift_size, None))
232 | w_slices = (slice(0, -self.window_size),
233 | slice(-self.window_size, -self.shift_size),
234 | slice(-self.shift_size, None))
235 | cnt = 0
236 | for h in h_slices:
237 | for w in w_slices:
238 | img_mask[:, h, w, :] = cnt
239 | cnt += 1
240 |
241 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
242 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
243 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
244 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
245 | else:
246 | attn_mask = None
247 |
248 | self.register_buffer("attn_mask", attn_mask)
249 |
250 | def forward(self, x):
251 | # x: B,C,H,W
252 | H, W = self.input_resolution
253 | B, L, C = x.shape
254 | assert L == H * W, "input feature has wrong size"
255 |
256 | shortcut = x
257 | x = self.norm1(x)
258 | x = x.view(B, H, W, C)
259 |
260 | # cyclic shift
261 | if self.shift_size > 0:
262 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
263 | else:
264 | shifted_x = x
265 |
266 | # partition windows
267 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
268 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
269 |
270 | # W-MSA/SW-MSA
271 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
272 |
273 | # merge windows
274 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
275 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
276 |
277 | # reverse cyclic shift
278 | if self.shift_size > 0:
279 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
280 | else:
281 | x = shifted_x
282 | x = x.view(B, H * W, C)
283 |
284 | # FFN
285 | x = shortcut + self.drop_path(x)
286 | x = x + self.drop_path(self.mlp(self.norm2(x)))
287 | # print('FFN',x.shape)
288 | return x
289 |
290 | def extra_repr(self) -> str:
291 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
292 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
293 |
294 | def flops(self):
295 | flops = 0
296 | H, W = self.input_resolution
297 | # norm1
298 | flops += self.dim * H * W
299 | # W-MSA/SW-MSA
300 | nW = H * W / self.window_size / self.window_size
301 | flops += nW * self.attn.flops(self.window_size * self.window_size)
302 | # mlp
303 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
304 | # norm2
305 | flops += self.dim * H * W
306 | return flops
307 |
308 |
309 | class PatchMerging(nn.Module):
310 | r""" Patch Merging Layer.
311 |
312 | Args:
313 | input_resolution (tuple[int]): Resolution of input feature.
314 | dim (int): Number of input channels.
315 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
316 | """
317 |
318 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
319 | super().__init__()
320 | self.input_resolution = input_resolution
321 | self.dim = dim
322 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
323 | self.norm = norm_layer(4 * dim)
324 |
325 | def forward(self, x):
326 | """
327 | x: B, H*W, C
328 | """
329 | H, W = self.input_resolution
330 | B, L, C = x.shape
331 | assert L == H * W, "input feature has wrong size"
332 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
333 |
334 | x = x.view(B, H, W, C)
335 |
336 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
337 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
338 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
339 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
340 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
341 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
342 |
343 | x = self.norm(x)
344 | x = self.reduction(x)
345 |
346 | return x
347 |
348 | def extra_repr(self) -> str:
349 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
350 |
351 | def flops(self):
352 | H, W = self.input_resolution
353 | flops = H * W * self.dim
354 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
355 | return flops
356 |
357 |
358 | class BasicLayer(nn.Module):
359 | """ A basic Swin Transformer layer for one stage.
360 |
361 | Args:
362 | dim (int): Number of input channels.
363 | input_resolution (tuple[int]): Input resolution.
364 | depth (int): Number of blocks.
365 | num_heads (int): Number of attention heads.
366 | window_size (int): Local window size.
367 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
368 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
369 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
370 | drop (float, optional): Dropout rate. Default: 0.0
371 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
372 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
373 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
374 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
375 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
376 | """
377 |
378 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
379 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
380 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
381 |
382 | super().__init__()
383 | self.dim = dim
384 | self.input_resolution = input_resolution
385 | self.depth = depth
386 | self.use_checkpoint = use_checkpoint
387 |
388 | # build blocks
389 | self.blocks = nn.ModuleList([
390 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
391 | num_heads=num_heads, window_size=window_size,
392 | shift_size=0 if (i % 2 == 0) else window_size // 2,
393 | mlp_ratio=mlp_ratio,
394 | qkv_bias=qkv_bias, qk_scale=qk_scale,
395 | drop=drop, attn_drop=attn_drop,
396 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
397 | norm_layer=norm_layer)
398 | for i in range(depth)])
399 |
400 | # patch merging layer
401 | if downsample is not None:
402 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
403 | else:
404 | self.downsample = None
405 |
406 | def forward(self, x):
407 | for blk in self.blocks:
408 | if self.use_checkpoint:
409 | x = checkpoint.checkpoint(blk, x)
410 | else:
411 | x = blk(x)
412 | if self.downsample is not None:
413 | x = self.downsample(x)
414 | return x
415 |
416 | def extra_repr(self) -> str:
417 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
418 |
419 | def flops(self):
420 | flops = 0
421 | for blk in self.blocks:
422 | flops += blk.flops()
423 | if self.downsample is not None:
424 | flops += self.downsample.flops()
425 | return flops
426 |
427 |
428 | class PatchEmbed(nn.Module):
429 | r""" Image to Patch Embedding
430 | First step: downsamples the incoming feature map by 4x and projects the channels to the initial embed_dim.
431 | Args:
432 | img_size (int): Image size. Default: 224.
433 | patch_size (int): Patch token size. Default: 4.
434 | in_chans (int): Number of input image channels. Default: 3.
435 | embed_dim (int): Number of linear projection output channels. Default: 96.
436 | patches_resolution: resolution measured in units of patches
437 | num_patches: total number of patches
438 | norm_layer (nn.Module, optional): Normalization layer. Default: None
439 | """
440 |
441 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
442 | super().__init__()
443 | img_size = to_2tuple(img_size)
444 | patch_size = to_2tuple(patch_size)
445 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
446 | self.img_size = img_size
447 | self.patch_size = patch_size
448 | self.patches_resolution = patches_resolution
449 | self.num_patches = patches_resolution[0] * patches_resolution[1]
450 |
451 | self.in_chans = in_chans # define in_chans == 3
452 | self.embed_dim = embed_dim # Swin-B.embed_dim ==128,(T is 96)
453 |
454 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # dim 3->128
455 | if norm_layer is not None:
456 | self.norm = norm_layer(embed_dim)
457 | else:
458 | self.norm = None
459 |
460 | def forward(self, x):
461 | B, C, H, W = x.shape
462 | # FIXME look at relaxing size constraints; the input size is fixed and asserted below
463 | assert H == self.img_size[0] and W == self.img_size[1], \
464 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
465 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
466 | if self.norm is not None:
467 | x = self.norm(x)
468 | return x
469 |
470 | def flops(self):
471 | Ho, Wo = self.patches_resolution
472 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
473 | if self.norm is not None:
474 | flops += Ho * Wo * self.embed_dim
475 | return flops
476 |
477 |
478 | class SwinTransformer(nn.Module):
479 | r""" Swin Transformer
480 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
481 | https://arxiv.org/pdf/2103.14030
482 |
483 | Args:
484 | img_size (int | tuple(int)): Input image size. Default 384
485 | patch_size (int | tuple(int)): Patch size. Default: 4
486 | in_chans (int): Number of input image channels. Default: 3
487 | num_classes (int): Number of classes for classification head. Default: 1000
488 | embed_dim (int): Patch embedding dimension. Default: 128
489 | depths (tuple(int)): Depth of each Swin Transformer layer.
490 | num_heads (tuple(int)): Number of attention heads in different layers.
491 | window_size (int): Window size. Default: 12
492 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
493 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
494 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
495 | drop_rate (float): Dropout rate. Default: 0
496 | attn_drop_rate (float): Attention dropout rate. Default: 0
497 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
498 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
499 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
500 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
501 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
502 | """
503 |
504 | def __init__(self, img_size=384, patch_size=4, in_chans=3, num_classes=1000,
505 | embed_dim=128, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
506 | window_size=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
507 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
508 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
509 | use_checkpoint=False, **kwargs):
510 | super().__init__()
511 |
512 | self.num_classes = num_classes
513 | self.num_layers = len(depths)
514 | self.embed_dim = embed_dim
515 | self.ape = ape
516 | self.patch_norm = patch_norm
517 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) # TODO
518 | self.mlp_ratio = mlp_ratio
519 |
520 | # split image into non-overlapping patches
521 | self.patch_embed = PatchEmbed(
522 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
523 | norm_layer=norm_layer if self.patch_norm else None)
524 | num_patches = self.patch_embed.num_patches
525 | patches_resolution = self.patch_embed.patches_resolution
526 | self.patches_resolution = patches_resolution
527 |
528 | # absolute position embedding
529 | if self.ape:
530 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
531 | trunc_normal_(self.absolute_pos_embed, std=.02)
532 |
533 | self.pos_drop = nn.Dropout(p=drop_rate)
534 |
535 | # stochastic depth
536 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
537 |
538 | # build layers
539 | self.layers = nn.ModuleList()
540 | for i_layer in range(self.num_layers):
541 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
542 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
543 | patches_resolution[1] // (2 ** i_layer)),
544 | depth=depths[i_layer],
545 | num_heads=num_heads[i_layer],
546 | window_size=window_size,
547 | mlp_ratio=self.mlp_ratio,
548 | qkv_bias=qkv_bias, qk_scale=qk_scale,
549 | drop=drop_rate, attn_drop=attn_drop_rate,
550 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
551 | norm_layer=norm_layer,
552 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
553 | use_checkpoint=use_checkpoint)
554 | self.layers.append(layer)
555 |
556 | self.norm = norm_layer(self.num_features)
557 | # self.avgpool = nn.AdaptiveAvgPool1d(1)
558 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
559 | # self.swin_decoder = Swin_decoder()
560 | self.apply(self._init_weights)
561 |
562 | def _init_weights(self, m):
563 | if isinstance(m, nn.Linear):
564 | trunc_normal_(m.weight, std=.02)
565 | if isinstance(m, nn.Linear) and m.bias is not None:
566 | nn.init.constant_(m.bias, 0)
567 | elif isinstance(m, nn.LayerNorm):
568 | nn.init.constant_(m.bias, 0)
569 | nn.init.constant_(m.weight, 1.0)
570 |
571 | @torch.jit.ignore
572 | def no_weight_decay(self):
573 | return {'absolute_pos_embed'}
574 |
575 | @torch.jit.ignore
576 | def no_weight_decay_keywords(self):
577 | return {'relative_position_bias_table'}
578 |
579 | def forward_features(self, x):
580 | layer_features = []
581 | x = self.patch_embed(x)
582 | B,L,C = x.shape
583 | layer_features.append(x.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous())
584 | # layer_features.append(x)
585 | if self.ape:
586 | x = x + self.absolute_pos_embed
587 | x = self.pos_drop(x)
588 |
589 | for layer in self.layers:
590 | x = layer(x)
591 | B, L, C = x.shape
592 | # print('x:', x.shape)
593 | xl = x.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous()
594 | # print('xl',xl.shape)
595 | layer_features.append(xl)
596 | x = self.norm(x) # B L C
597 | B, L, C = x.shape
598 | # x = self.avgpool(x.transpose(1, 2))  # B C 1
599 | x = x.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous()
600 | # x = torch.flatten(x, 1)
601 | layer_features[-1] = x
602 | # self.swin_decoder(layer_features)
603 |
604 |
605 | return layer_features
606 |
607 | def forward(self, x):
608 | outs = self.forward_features(x)
609 | # print("len outs",len(outs))
610 | # sal, edge = self.swin_decoder(outs)
611 | return outs
612 |
613 | def flops(self):
614 | flops = 0
615 | flops += self.patch_embed.flops()
616 | for i, layer in enumerate(self.layers):
617 | flops += layer.flops()
618 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
619 | flops += self.num_features * self.num_classes
620 | return flops
621 |
622 | class SwinNet(nn.Module):
623 | def __init__(self):
624 | super(SwinNet, self).__init__()
625 |
626 | self.rgb_swin = SwinTransformer(embed_dim=128, depths=[2,2,18,2], num_heads=[4,8,16,32])
627 | self.depth_swin = SwinTransformer(embed_dim=128, depths=[2,2,18,2], num_heads=[4,8,16,32])
628 | self.fuse_enhance1 = fuse_enhance(1024)
629 | self.fuse_enhance2 = fuse_enhance(512)
630 | self.fuse_enhance3 = fuse_enhance(256)
631 | self.fuse_enhance4 = fuse_enhance(128)
632 | self.up2 = nn.UpsamplingBilinear2d(scale_factor = 2)
633 | self.up4 = nn.UpsamplingBilinear2d(scale_factor = 4)
634 | self.conv2048_1024 = conv3x3_bn_relu(2048, 1024)
635 | self.conv1024_512 = conv3x3_bn_relu(1024, 512)
636 | self.conv512_256 = conv3x3_bn_relu(512, 256)
637 | self.conv256_32 = conv3x3_bn_relu(256, 32)
638 | self.conv64_1 = conv3x3(64, 1)
639 |
640 | self.edge_layer = Edge_Module()
641 | self.edge_feature = conv3x3_bn_relu(1, 32)
642 | self.fuse_edge_sal = conv3x3(32, 1)
643 | self.up_edge = nn.Sequential(
644 | nn.UpsamplingBilinear2d(scale_factor = 4),
645 | conv3x3(32, 1)
646 | )
647 |
648 | self.relu = nn.ReLU(True)
649 | def forward(self,x ,d):
650 | rgb_list = self.rgb_swin(x)
651 | depth_list = self.depth_swin(d)
652 |
653 | r4 = rgb_list[0]
654 | r3 = rgb_list[1]
655 | r2 = rgb_list[2]
656 | r1 = rgb_list[3]
657 | d4 = depth_list[0]
658 | d3 = depth_list[1]
659 | d2 = depth_list[2]
660 | d1 = depth_list[3]
661 |
662 | fr1, fd1 = self.fuse_enhance1(r1, d1)
663 | fr2, fd2 = self.fuse_enhance2(r2, d2)
664 | fr3, fd3 = self.fuse_enhance3(r3, d3)
665 | fr4, fd4 = self.fuse_enhance4(r4, d4)
666 |
667 | mul_fea1 = fr1 * fd1
668 | add_fea1 = fr1 + fd1
669 | fuse_fea1 = torch.cat((mul_fea1, add_fea1), dim=1)
670 | fuse_fea1 = self.up2(fuse_fea1)
671 | fuse_fea1 = self.conv2048_1024(fuse_fea1)
672 |
673 | mul_fea2 = fr2 * fd2
674 | add_fea2 = fr2 + fd2
675 | fuse_fea2 = torch.cat((mul_fea2, add_fea2), dim=1)
676 |
677 | fuse_fea2 = fuse_fea2 + fuse_fea1
678 | fuse_fea2 = self.up2(fuse_fea2)
679 | fuse_fea2 = self.conv1024_512(fuse_fea2)
680 |
681 | mul_fea3 = fr3 * fd3
682 | add_fea3 = fr3 + fd3
683 | fuse_fea3 = torch.cat((mul_fea3, add_fea3), dim=1)
684 | fuse_fea3 = fuse_fea3 + fuse_fea2
685 | fuse_fea3 = self.up2(fuse_fea3)
686 | fuse_fea3 = self.conv512_256(fuse_fea3)
687 |
688 | mul_fea4 = fr4 * fd4
689 | add_fea4 = fr4 + fd4
690 | fuse_fea4 = torch.cat((mul_fea4, add_fea4), dim=1)
691 | fuse_fea4 = fuse_fea4 + fuse_fea3
692 |
693 | edge_map = self.edge_layer(d4, d3, d2)
694 | edge_feature = self.edge_feature(edge_map)
695 | end_sal = self.conv256_32(fuse_fea4)
696 | up_edge = self.up_edge(end_sal)
697 | out = self.relu(torch.cat((end_sal, edge_feature), dim=1))
698 | out = self.up4(out)
699 | sal_out = self.conv64_1(out)
700 |
701 | return sal_out, up_edge
702 |
703 | def load_pre(self, pre_model):
704 | self.rgb_swin.load_state_dict(torch.load(pre_model)['model'],strict=False)
705 | print(f"RGB SwinTransformer loading pre_model ${pre_model}")
706 | self.depth_swin.load_state_dict(torch.load(pre_model)['model'], strict=False)
707 | print(f"Depth SwinTransformer loading pre_model ${pre_model}")
708 |
709 |
710 | class fuse_enhance(nn.Module):
711 | def __init__(self, infeature):
712 | super(fuse_enhance, self).__init__()
713 | self.depth_channel_attention = ChannelAttention(infeature)
714 | self.rgb_channel_attention = ChannelAttention(infeature)
715 | self.rd_spatial_attention = SpatialAttention()
716 | self.rgb_spatial_attention = SpatialAttention()
717 | self.depth_spatial_attention = SpatialAttention()
718 |
719 | def forward(self,r,d):
720 | assert r.shape == d.shape, "rgb and depth should have the same size"
721 | mul_fuse = r * d
722 | sa = self.rd_spatial_attention(mul_fuse)
723 | r_f = r * sa
724 | d_f = d * sa
725 | r_ca = self.rgb_channel_attention(r_f)
726 | d_ca = self.depth_channel_attention(d_f)
727 |
728 | r_out = r * r_ca
729 | d_out = d * d_ca
730 | return r_out, d_out
731 |
732 | class CALayer(nn.Module):
733 | def __init__(self, channel, reduction=16):
734 | super(CALayer, self).__init__()
735 | # global average pooling: feature --> point
736 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
737 | # feature channel downscale and upscale --> channel weight
738 | self.conv_du = nn.Sequential(
739 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
740 | nn.ReLU(inplace=True),
741 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
742 | nn.Sigmoid()
743 | )
744 |
745 | def forward(self, x):
746 | y = self.avg_pool(x)
747 | y = self.conv_du(y)
748 | return x * y
749 |
750 | ## Residual Channel Attention Block (RCAB)
751 | class RCAB(nn.Module):
752 | def __init__(
753 | self, n_feat, kernel_size=3, reduction=16,
754 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
755 |
756 | super(RCAB, self).__init__()
757 | modules_body = []
758 | for i in range(2):
759 | modules_body.append(self.default_conv(n_feat, n_feat, kernel_size, bias=bias))
760 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
761 | if i == 0: modules_body.append(act)
762 | modules_body.append(CALayer(n_feat, reduction))
763 | self.body = nn.Sequential(*modules_body)
764 | self.res_scale = res_scale
765 |
766 | def default_conv(self, in_channels, out_channels, kernel_size, bias=True):
767 | return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size // 2), bias=bias)
768 |
769 | def forward(self, x):
770 | res = self.body(x)
771 | #res = self.body(x).mul(self.res_scale)
772 | res += x
773 | return res
774 |
775 | class Edge_Module(nn.Module):
776 | def __init__(self, in_fea=[128, 256, 512], mid_fea=32):
777 | super(Edge_Module, self).__init__()
778 | self.relu = nn.ReLU(inplace=True)
779 | self.conv2 = nn.Conv2d(in_fea[0], mid_fea, 1)
780 | self.conv4 = nn.Conv2d(in_fea[1], mid_fea, 1)
781 | self.conv5 = nn.Conv2d(in_fea[2], mid_fea, 1)
782 | self.conv5_2 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1)
783 | self.conv5_4 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1)
784 | self.conv5_5 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1)
785 | self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
786 | self.classifer = nn.Conv2d(mid_fea * 3, 1, kernel_size=3, padding=1)
787 | self.rcab = RCAB(mid_fea * 3)
788 |
789 | def forward(self, x2, x4, x5):
790 | _, _, h, w = x2.size()
791 | edge2_fea = self.relu(self.conv2(x2))
792 | edge2 = self.relu(self.conv5_2(edge2_fea))
793 | edge4_fea = self.relu(self.conv4(x4))
794 | edge4 = self.relu(self.conv5_4(edge4_fea))
795 | edge5_fea = self.relu(self.conv5(x5))
796 | edge5 = self.relu(self.conv5_5(edge5_fea))
797 |
798 | edge4 = F.interpolate(edge4, size=(h, w), mode='bilinear', align_corners=True)
799 | edge5 = F.interpolate(edge5, size=(h, w), mode='bilinear', align_corners=True)
800 |
801 | edge = torch.cat([edge2, edge4, edge5], dim=1)
802 | edge = self.rcab(edge)
803 | edge = self.classifer(edge)
804 | return edge
805 |
806 | class ChannelAttention(nn.Module):
807 | def __init__(self, in_planes, ratio=16):
808 | super(ChannelAttention, self).__init__()
809 |
810 | self.max_pool = nn.AdaptiveMaxPool2d(1)
811 |
812 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
813 | self.relu1 = nn.ReLU()
814 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
815 |
816 | self.sigmoid = nn.Sigmoid()
817 |
818 | def forward(self, x):
819 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
820 | out = max_out
821 | return self.sigmoid(out)
822 |
823 |
824 | class SpatialAttention(nn.Module):
825 | def __init__(self, kernel_size=7):
826 | super(SpatialAttention, self).__init__()
827 |
828 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
829 | padding = 3 if kernel_size == 7 else 1
830 |
831 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
832 | self.sigmoid = nn.Sigmoid()
833 |
834 | def forward(self, x):
835 | max_out, _ = torch.max(x, dim=1, keepdim=True)
836 | x = max_out
837 | x = self.conv1(x)
838 | return self.sigmoid(x)
839 |
840 |
841 | if __name__ == '__main__':
842 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
843 | pre_path = '../Pre_train/swin_base_patch4_window7_224.pth'
844 | a = torch.randn(1,3,224,224)
845 | b = torch.randn(1,3,224,224)
846 | # net = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32])
847 | net = SwinNet()
848 | out = net(a,b)
849 | # print(out.shape)
850 | for i in out:
851 | print(i.shape)
852 | # model = SwinNet()
853 | # out = model(a,b)
854 | # for i in out:
855 | # print(i.shape)
856 | # model.load_state_dict(torch.load(r"D:\tanyacheng\Experiments\SOD\Transformer_Saliency\Swin\Saliency\Swin-Transformer-Saliency_v19\SwinTransNet_RGBD_cpts\SwinTransNet_epoch_best.pth"), strict=True)
857 |
--------------------------------------------------------------------------------
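One property worth noting in the file above: `window_partition` and `window_reverse` are exact inverses whenever H and W are multiples of the window size. A quick sanity check (a sketch; it assumes the repository root is on `PYTHONPATH`):

```python
import torch
from networks.SwinNet.models.swinNet import window_partition, window_reverse

x = torch.randn(2, 56, 56, 96)             # B, H, W, C with H and W divisible by 7
wins = window_partition(x, 7)              # (2 * 8 * 8, 7, 7, 96)
assert wins.shape == (128, 7, 7, 96)
x_back = window_reverse(wins, 7, 56, 56)   # reassemble the original layout
assert torch.equal(x, x_back)              # the round trip is lossless
```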
/networks/SwinNet/optimizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: optimizer.py
6 | @time: 2021/5/6 15:58
7 | """
8 |
9 | from torch import optim as optim
10 |
11 |
12 | def build_optimizer(config, model):
13 | """
14 | Build optimizer, set weight decay of normalization to 0 by default.
15 | """
16 | skip = {}
17 | skip_keywords = {}
18 | if hasattr(model, 'no_weight_decay'):
19 | skip = model.no_weight_decay()
20 | if hasattr(model, 'no_weight_decay_keywords'):
21 | skip_keywords = model.no_weight_decay_keywords()
22 | parameters = set_weight_decay(model, skip, skip_keywords)
23 |
24 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
25 | optimizer = None
26 | if opt_lower == 'sgd':
27 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
28 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
29 | elif opt_lower == 'adamw':
30 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
31 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
32 |
33 | return optimizer
34 |
35 |
36 | def set_weight_decay(model, skip_list=(), skip_keywords=()):
37 | has_decay = []
38 | no_decay = []
39 |
40 | for name, param in model.named_parameters():
41 | if not param.requires_grad:
42 | continue # frozen weights
43 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
44 | check_keywords_in_name(name, skip_keywords):
45 | no_decay.append(param)
46 | # print(f"{name} has no weight decay")
47 | else:
48 | has_decay.append(param)
49 | return [{'params': has_decay},
50 | {'params': no_decay, 'weight_decay': 0.}]
51 |
52 |
53 | def check_keywords_in_name(name, keywords=()):
54 | isin = False
55 | for keyword in keywords:
56 | if keyword in name:
57 | isin = True
58 | return isin
--------------------------------------------------------------------------------
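`set_weight_decay` puts biases, 1-D parameters (norm scales), and any skip-listed or keyword-matched tensors into a zero-decay group, so only weight matrices are regularized. A toy illustration (a sketch; the module below is hypothetical):

```python
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = set_weight_decay(toy, skip_list=(), skip_keywords=())

# Group 0 keeps the optimizer's default weight decay (the Linear weight);
# group 1 is decay-free (Linear bias, LayerNorm weight and bias).
print(len(groups[0]['params']), len(groups[1]['params']))   # -> 1 3
```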
/networks/SwinNet/options.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: options.py
6 | @time: 2021/5/16 14:52
7 | """
8 |
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--epoch', type=int, default=300, help='epoch number')
13 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
14 | parser.add_argument('--batchsize', type=int, default=3, help='training batch size')
15 | parser.add_argument('--trainsize', type=int, default=384, help='training dataset size')
16 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
17 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
18 | parser.add_argument('--decay_epoch', type=int, default=100, help='every n epochs decay learning rate')
19 | # parser.add_argument('--load', type=str, default='./pre_train/swin_base_patch4_window7_224.pth', help='train from checkpoints')
20 | parser.add_argument('--load', type=str, default='./Pre_train/swin_base_patch4_window12_384_22k.pth', help='train from checkpoints')
21 | parser.add_argument('--gpu_id', type=str, default='1', help='train use gpu')
22 |
23 | # RGB-D Datasets
24 | parser.add_argument('--rgb_root', type=str, default='./datasets/train/RGB/', help='the training RGB images root')
25 | parser.add_argument('--depth_root', type=str, default='./datasets/train/Depth/', help='the training Depth images root')
26 | parser.add_argument('--gt_root', type=str, default='./datasets/train/GT/', help='the training GT images root')
27 | parser.add_argument('--edge_root', type=str, default='./datasets/train/Edge/', help='the training Edge images root')
28 | parser.add_argument('--val_rgb_root', type=str, default='./datasets/RGB-D/validation/RGB/', help='the validation RGB images root')
29 | parser.add_argument('--val_depth_root', type=str, default='./datasets/RGB-D/validation/Depth/', help='the validation Depth images root')
30 | parser.add_argument('--val_gt_root', type=str, default='./datasets/RGB-D/validation/GT/', help='the test validation GT images root')
31 |
32 | # RGB-T Datasets
33 | """
34 | parser.add_argument('--rgb_root', type=str, default='./datasets/RGB-T/train/RGB/', help='the training RGB images root')
35 | parser.add_argument('--depth_root', type=str, default='./datasets/RGB-T/train/T/', help='the training Thermal images root')
36 | parser.add_argument('--gt_root', type=str, default='./datasets/RGB-T/train/GT/', help='the training GT images root')
37 | parser.add_argument('--edge_root', type=str, default='./datasets/RGB-T/train/Edge/', help='the training Edge images root')
38 |
39 | parser.add_argument('--val_rgb_root', type=str, default='./datasets/RGB-T/validation/RGB/', help='the validation RGB images root')
40 | parser.add_argument('--val_depth_root', type=str, default='./datasets/RGB-T/validation/T/', help='the validation Thermal images root')
41 | parser.add_argument('--val_gt_root', type=str, default='./datasets/RGB-T/validation/GT/', help='the validation GT images root')
42 | """
43 | parser.add_argument('--save_path', type=str, default='./cpts/', help='the path to save models and logs')
44 |
45 | opt = parser.parse_args()
--------------------------------------------------------------------------------
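Because `options.py` calls `parse_args()` at import time, the flags come from `sys.argv` of whichever script imports it; for quick experiments they can be injected before the import (a sketch; run from `networks/SwinNet/` so the module resolves):

```python
import sys

# Inject CLI flags before options.py parses them at import time.
sys.argv = ['SwinNet_train.py', '--batchsize', '8', '--gpu_id', '0']

from options import opt
print(opt.batchsize, opt.lr, opt.trainsize)   # -> 8 5e-05 384
```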
/networks/SwinNet/test_gray_to_rgb.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import cv2
3 | # # def rgb_loader(path):
4 | # # with open(path, 'rb') as f:
5 | # # img = Image.open(f)
6 | # # return img.convert('RGB')
7 | root = '/home/sunfan/Desktop/colorization-master/colorization-master/imgs/ansel_adams3.jpg'
8 | save_path = '/home/sunfan/Desktop/save1.png'
9 | img_gray = cv2.imread(root, flags = 0)
10 | img2 = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
11 | cv2.imwrite(save_path, img2)
12 | cv2.imshow('test', img2)
13 | cv2.waitKey(0)           # block until a key press so the window actually shows
14 | cv2.destroyAllWindows()
15 | #
16 | #
17 | #
18 | #
19 | # # img = Image.open(root)
20 | # # rgb = img.convert('RGB')
21 | #
22 | # # cv2.imwrite(save_path,img2)
23 | #
24 | #
25 | #
26 | # import numpy as np
27 | # import cv2
28 | # root = '/home/sunfan/Desktop/colorization-master/colorization-master/imgs/ansel_adams3.jpg'
29 | # root = '/home/sunfan/Downloads/newdata/train/DUT_NJUNLPR/depth/1_02-02-40.png'
30 | #
31 | # save_path = '/home/sunfan/Desktop/c.png'
32 | # src_gray = cv2.imread(root,flags = 0)
33 | # print(src_gray.shape)
34 | # src = cv2.cvtColor(src_gray, cv2.COLOR_GRAY2BGR)
35 | # print(src.shape)
36 | # # OpenCV stores color images in BGR order as a 3-D numpy.array indexed by row, column, channel:
37 | # B = src[:,:,0]
38 | # G = src[:,:,1]
39 | # R = src[:,:,2]
40 | # # Grayscale g = p*R + q*G + t*B (with p = 0.2989, q = 0.5870, t = 0.1140), so B = (g - p*R - q*G)/t; keeping the R and G components plus the gray image g is enough to recover the original RGB image.
41 | # g = src_gray[:]
42 | #
43 | # p = 1; q = 1; t = 1
44 | # B_new = (g-p*R-q*G)/t
45 | # G_new1 = (g-p*R+q*G)/t
46 | # R_new = R
47 | #
48 | # # B_new = np.uint8(B_new)
49 | # src_new = np.zeros((src.shape)).astype("uint8")
50 | # src_new[:,:,0] = B_new
51 | # src_new[:,:,1] = G_new1
52 | # src_new[:,:,2] = R_new
53 | # # display the images
54 | # cv2.imshow("input", src_gray)
55 | # cv2.imshow("output", src)
56 | # cv2.imshow("result", src_new)
57 | # cv2.imwrite(save_path, src_new)
58 | # cv2.waitKey(0)
59 | # cv2.destroyAllWindows()
60 | #
61 |
62 | # Show images
63 |
--------------------------------------------------------------------------------
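The live part of this scratch script just replicates the gray channel into three identical BGR channels; the PIL route sketched in its comments is equivalent (a sketch; paths are placeholders):

```python
from PIL import Image

# Equivalent of cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR):
# 'L' -> 'RGB' copies the single luminance channel into all three channels.
img_gray = Image.open('gray_input.png').convert('L')   # placeholder path
img_rgb = img_gray.convert('RGB')
img_rgb.save('rgb_output.png')                         # placeholder path
```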
/networks/SwinNet/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: caigentan@AnHui University
4 | @software: PyCharm
5 | @file: utils.py
6 | @time: 2021/5/6 16:16
7 | """
8 | def clip_gradient(optimizer, grad_clip):
9 | for group in optimizer.param_groups:
10 | for param in group['params']:
11 | if param.grad is not None:
12 | param.grad.data.clamp_(-grad_clip, grad_clip)
13 |
14 |
15 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
16 | decay = decay_rate ** (epoch // decay_epoch)
17 | for param_group in optimizer.param_groups:
18 | param_group['lr'] = decay*init_lr
19 | lr=param_group['lr']
20 | return lr
21 |
--------------------------------------------------------------------------------
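`adjust_lr` implements a step decay, `lr = init_lr * decay_rate ** (epoch // decay_epoch)`. With the defaults from `options.py` (init_lr=5e-5, decay_rate=0.1, decay_epoch=100), the rate drops tenfold at epochs 100 and 200 over a 300-epoch run. A sketch of the schedule it produces:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=5e-5)

for epoch in (0, 99, 100, 200, 299):
    lr = adjust_lr(optimizer, init_lr=5e-5, epoch=epoch,
                   decay_rate=0.1, decay_epoch=100)
    print(epoch, lr)   # -> 5e-05, 5e-05, 5e-06, 5e-07, 5e-07
```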
/networks/Wavenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from networks.wavemlp import WaveMLP_S
4 | from timm.models.layers import DropPath
5 | from pytorch_wavelets import DTCWTForward, DTCWTInverse
6 | from torch.nn.functional import kl_div
7 |
8 | class BasicConv2d(nn.Module):
9 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
10 | super(BasicConv2d, self).__init__()
11 | self.conv = nn.Conv2d(in_planes, out_planes,
12 | kernel_size=kernel_size, stride=stride,
13 | padding=padding, dilation=dilation, bias=False)
14 | self.bn = nn.BatchNorm2d(out_planes)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | def forward(self, x):
18 | x = self.conv(x)
19 | x = self.bn(x)
20 | x = self.relu(x)
21 | return x
22 |
23 | class TransBasicConv2d(nn.Module):
24 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1,output_padding=0, bias=False):
25 | super(TransBasicConv2d, self).__init__()
26 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes,
27 | kernel_size=kernel_size, stride=stride,
28 | padding=padding,output_padding= output_padding, dilation=dilation, bias=bias)
29 | self.bn = nn.BatchNorm2d(out_planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.inch = in_planes
32 | def forward(self, x):
33 |
34 | x = self.Deconv(x)
35 | x = self.bn(x)
36 | x = self.relu(x)
37 | return x
38 |
39 |
40 | class Mlp(nn.Module):
41 | def __init__(self, in_features=64, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42 | super().__init__()
43 | out_features = out_features or in_features
44 | hidden_features = hidden_features or in_features
45 | self.fc1 = nn.Linear(in_features, hidden_features)
46 | self.act = act_layer()
47 | self.fc2 = nn.Linear(hidden_features, out_features)
48 | self.drop = nn.Dropout(drop)
49 |
50 | def forward(self, x):
51 | # print('x',x.shape)
52 | x = self.fc1(x)
53 | # print('fc',x.shape)
54 | x = self.act(x)
55 | x = self.drop(x)
56 | x = self.fc2(x)
57 | x = self.drop(x)
58 | return x
59 |
60 | class Mlp_wave(nn.Module):
61 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
62 | super().__init__()
63 |
64 | out_features = out_features or in_features
65 | hidden_features = hidden_features or in_features
66 | self.act = act_layer()
67 | self.drop = nn.Dropout(drop)
68 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1)
69 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1)
70 |
71 | def forward(self, x):
72 | x = self.fc1(x)
73 | x = self.act(x)
74 | x = self.drop(x)
75 | x = self.fc2(x)
76 | x = self.drop(x)
77 | return x
78 |
79 |
80 | class BAB_Decoder(nn.Module):
81 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2):
82 | super(BAB_Decoder, self).__init__()
83 |
84 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1)
85 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1)
86 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1)
87 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2)
88 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1)
89 | self.conv_fuse = BasicConv2d(channel_2*3, channel_3, 3, padding=1)
90 |
91 | def forward(self, x):
92 | x1 = self.conv1(x)
93 | x1_dila = self.conv1_Dila(x1)
94 |
95 | x2 = self.conv2(x1)
96 | x2_dila = self.conv2_Dila(x2)
97 |
98 | x3 = self.conv3(x2)
99 |
100 | x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1))
101 |
102 | return x_fuse
103 |
104 | class FFT(nn.Module):
105 | def __init__(self,inchannel,outchannel):
106 | super().__init__()
107 | self.DWT = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
108 | # self.DWT =DTCWTForward(J=3, include_scale=True)
109 | self.IWT = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
110 | self.conv1 = BasicConv2d(outchannel, outchannel)
111 | self.conv2 = BasicConv2d(inchannel, outchannel)
112 | self.conv3 = BasicConv2d(outchannel, outchannel)
113 | self.change = TransBasicConv2d(outchannel, outchannel)
114 |
115 | def forward(self, x, y):
116 | y = self.conv2(y)
117 | Xl, Xh = self.DWT(x)
118 | Yl, Yh = self.DWT(y)
119 | x_y = self.conv1(Xl) + self.conv1(Yl)
120 |
121 | x_m = self.IWT((x_y, Xh))
122 | y_m = self.IWT((x_y, Yh))
123 |
124 | out = self.conv3(x_m + y_m)
125 | return out
126 |
127 | class PATM_BAB(nn.Module):
128 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2):
129 | super().__init__()
130 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1)
131 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1)
132 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1)
133 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2)
134 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1)
135 | self.conv_fuse = BasicConv2d(channel_2 *2, channel_3, 3, padding=1)
136 | self.drop = nn.Dropout(0.5)
137 | self.conv_last=TransBasicConv2d(channel_3, channel_3, kernel_size=2, stride=2,
138 | padding=0, dilation=1, bias=False)
139 | def forward(self, x):
140 | x1 = self.conv1(x)
141 | x1_dila = self.conv1_Dila(x1)
142 |
143 | x2 = self.conv2(x1)
144 | x2_dila = self.conv2_Dila(x2)
145 |
146 | x3 = self.conv3(x2)
147 | x1_dila = torch.cat([x1_dila * torch.cos(x1_dila), x1_dila * torch.sin(x1_dila)], dim=1)
148 | x2_dila = torch.cat([x2_dila * torch.cos(x2_dila), x2_dila * torch.sin(x2_dila)], dim=1)
149 | x3 = torch.cat([x3 * torch.cos(x3), x3 * torch.sin(x3)], dim=1)
150 | # print('x1_dila + x2_dila+x3',x1_dila.shape)
151 | x_fuse = self.conv_fuse(x1_dila + x2_dila +x3)
152 | # x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1))
153 | # print('x_f',x_fuse.shape)
154 | x_fuse= self.drop(x_fuse)
155 | # print()
156 | x_fuse = self.conv_last(x_fuse)
157 | return x_fuse
158 |
159 | class DWT(nn.Module):
160 | def __init__(self, inchannel,outchannel):
161 | super(DWT, self).__init__()
162 | self.DWT = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
163 | self.IWT = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
164 | self.conv1 = BasicConv2d(outchannel,outchannel)
165 | self.conv2 = BasicConv2d(inchannel, outchannel)
166 | self.conv3 = BasicConv2d(outchannel, outchannel)
167 | self.change = TransBasicConv2d(outchannel,outchannel)
168 | def forward(self, x, y):
169 | # print('x',x.shape)
170 | y = self.change(self.conv2(y))
171 | # print('y',y.shape)
172 | Xl, Xh = self.DWT(x)
173 | Yl, Yh = self.DWT(y)
174 | # print('Xl',Xl.shape)
175 | x_y = self.conv1(Xl)+self.conv1(Yl)
176 | # print('x_y',x_y.shape)
177 | # print('Xh',Xh.shape)
178 | # print('Yh',Yh.shape)
179 | x_m = self.IWT((x_y,Xh))
180 | y_m = self.IWT((x_y,Yh))
181 | # print('x_m',x_m.shape)
182 | # print('y_m',y_m.shape)
183 | out = self.conv3(x_m + y_m)
184 | return out
185 | class Edge_Aware(nn.Module):
186 | def __init__(self, ):
187 | super(Edge_Aware, self).__init__()
188 | self.conv1 = TransBasicConv2d(512, 64,kernel_size=4,stride=8,padding=0,dilation=2,output_padding=1)
189 | self.conv2 = TransBasicConv2d(320, 64,kernel_size=2,stride=4,padding=0,dilation=2,output_padding=1)
190 | self.conv3 = TransBasicConv2d(128, 64,kernel_size=2,stride=2,padding=1,dilation=2,output_padding=1)
191 | self.pos_embed = BasicConv2d(64, 64 )
192 | self.pos_embed3 = BasicConv2d(64, 64)
193 | self.conv31 = nn.Conv2d(64,1, kernel_size=1)
194 | self.conv512_64 = TransBasicConv2d(512,64)
195 | self.conv320_64 = TransBasicConv2d(320, 64)
196 | self.conv128_64 = TransBasicConv2d(128, 64)
197 | self.up = nn.Upsample(56)
198 | self.up2 = nn.Upsample(384)
199 | self.norm1 = nn.LayerNorm(64)
200 | self.norm2 = nn.BatchNorm2d(64)
201 | self.drop_path = DropPath(0.3)
202 | self.maxpool =nn.AdaptiveMaxPool2d(1)
203 | # self.qkv = nn.Linear(64, 64 * 3, bias=False)
204 | self.num_heads = 8
205 | self.mlp1 = Mlp(in_features=64, out_features=64)
206 | self.mlp2 = Mlp(in_features=64, out_features=64)
207 | self.mlp3 = Mlp(in_features=64, out_features=64)
208 | def forward(self, x, y, z, v):
209 |
210 |
211 | # v = self.conv1(v)
212 | # z = self.conv2(z)
213 | # y = self.conv3(y)
214 | # print('v',v)
215 | v = self.up(self.conv512_64(v))
216 | z = self.up(self.conv320_64(z))
217 | y = self.up(self.conv128_64(y))
218 | x = self.up(x)
219 |
220 | x_max = self.maxpool(x)
221 | # print('x_max',x_max.shape)
222 | b,_,_,_ = x_max.shape
223 | x_max = x_max.reshape(b, -1)
224 | x_y = self.mlp1(x_max)
225 | # print('s',x_y.shape)
226 | x_z = self.mlp2(x_max)
227 | x_v = self.mlp3(x_max)
228 |
229 | x_y = x_y.reshape(b,64,1,1)
230 | x_z = x_z.reshape(b, 64, 1, 1)
231 | x_v = x_v.reshape(b, 64, 1, 1)
232 | x_y = torch.mul(x_y, y)
233 | x_z = torch.mul(x_z, z)
234 | x_v = torch.mul(x_v, v)
235 |
236 |
237 | # x_mix_1 = torch.cat((x_y,x_z,x_v),dim=1)
238 | x_mix_1 = x_y+ x_z+ x_v
239 | # print('sd',x_mix_1.shape)
240 | x_mix_1 = self.norm2(x_mix_1)
241 | # print('x_mix_1',x_mix_1.shape)
242 | x_mix_1= self.pos_embed3(x_mix_1)
243 | x_mix = self.drop_path(x_mix_1)
244 | x_mix = x_mix_1 + self.pos_embed3(x_mix)
245 | x_mix = self.up2(self.conv31(x_mix))
246 | return x_mix
247 |
248 | class Mutual_info_reg(nn.Module):
249 | def __init__(self, input_channels=64, channels=64, latent_size=6):
250 | super(Mutual_info_reg, self).__init__()
251 | self.soft = torch.nn.Softmax(dim=1)
252 | def forward(self, rgb_feat, depth_feat):
253 |
254 | # print('rgb_feat',rgb_feat.shape)
255 | # print('depth_feat', depth_feat.shape)
256 | rgb_feat = self.soft(rgb_feat)
257 | depth_feat = self.soft(depth_feat)
258 | #
259 | # print('rgb_feat',rgb_feat.shape)
260 | # print('depth_feat', depth_feat.shape)
261 | return kl_div(rgb_feat.log(), depth_feat)
262 |
263 | class WaveNet(nn.Module):
264 | def __init__(self, channel=32):
265 | super(WaveNet, self).__init__()
266 |
267 |
268 | self.encoderR = WaveMLP_S()
269 | # Lateral layers
270 | self.lateral_conv0 = BasicConv2d(64, 64, 3, stride=1, padding=1)
271 | self.lateral_conv1 = BasicConv2d(128, 64, 3, stride=1, padding=1)
272 | self.lateral_conv2 = BasicConv2d(320, 128, 3, stride=1, padding=1)
273 | self.lateral_conv3 = BasicConv2d(512, 320, 3, stride=1, padding=1)
274 |
275 |
276 | self.FFT1 = FFT(64,64)
277 | self.FFT2 = FFT(128,128)
278 | self.FFT3 = FFT(320,320)
279 | self.FFT4 = FFT(512,512)
280 |
281 |
282 |
283 | self.conv512_64 = BasicConv2d(512, 64)
284 | self.conv320_64 = BasicConv2d(320, 64)
285 | self.conv128_64 = BasicConv2d(128, 64)
286 | self.sigmoid = nn.Sigmoid()
287 | self.S4 = nn.Conv2d(512, 1, 3, stride=1, padding=1)
288 | self.S3 = nn.Conv2d(320, 1, 3, stride=1, padding=1)
289 | self.S2 = nn.Conv2d(128, 1, 3, stride=1, padding=1)
290 | self.S1 = nn.Conv2d(64, 1, 3, stride=1, padding=1)
291 | self.up1 = nn.Upsample(384)
292 | self.up2 = nn.Upsample(384)
293 | self.up3 = nn.Upsample(384)
294 | self.up_loss = nn.Upsample(92)
295 | # Mutual_info_reg1
296 | self.mi_level1 = Mutual_info_reg(64, 64, 6)
297 | self.mi_level2 = Mutual_info_reg(64, 64, 6)
298 | self.mi_level3 = Mutual_info_reg(64, 64, 6)
299 | self.mi_level4 = Mutual_info_reg(64, 64, 6)
300 |
301 | self.edge = Edge_Aware()
302 | self.PATM4 = PATM_BAB(512, 512, 512, 3, 2)
303 | self.PATM3 = PATM_BAB(832, 512, 320, 3, 2)
304 | self.PATM2 = PATM_BAB(448, 256, 128, 5, 3)
305 | self.PATM1 = PATM_BAB(192, 128, 64, 5, 3)
306 |
307 | def forward(self, x_rgb,x_thermal):
308 | x0,x1,x2,x3 = self.encoderR(x_rgb)
309 | y0, y1, y2, y3 = self.encoderR(x_thermal)
310 |
311 | x2_ACCoM = self.FFT1(x0, y0)
312 | x3_ACCoM = self.FFT2(x1, y1)
313 | x4_ACCoM = self.FFT3(x2, y2)
314 | x5_ACCoM = self.FFT4(x3, y3)
315 |
316 | edge = self.edge(x2_ACCoM, x3_ACCoM, x4_ACCoM, x5_ACCoM)
317 |
318 | mer_cros4 = self.PATM4(x5_ACCoM)
319 | m4 = torch.cat((mer_cros4,x4_ACCoM),dim=1)
320 | mer_cros3 = self.PATM3(m4)
321 | m3 = torch.cat((mer_cros3, x3_ACCoM), dim=1)
322 | mer_cros2 = self.PATM2(m3)
323 | m2 = torch.cat((mer_cros2, x2_ACCoM), dim=1)
324 | mer_cros1 = self.PATM1(m2)
325 |
326 | s1 = self.up1(self.S1(mer_cros1))
327 | s2 = self.up2(self.S2(mer_cros2))
328 | s3 = self.up3(self.S3(mer_cros3))
329 | s4 = self.up3(self.S4(mer_cros4))
330 |
331 | x_loss0 = x0
332 | y_loss0 = y0
333 | x_loss1 = self.up_loss(self.conv128_64(x1))
334 | y_loss1 = self.up_loss(self.conv128_64(y1))
335 | x_loss2 = self.up_loss(self.conv320_64(x2))
336 | y_loss2 = self.up_loss(self.conv320_64(y2))
337 | x_loss3 = self.up_loss(self.conv512_64(x3))
338 | y_loss3 = self.up_loss(self.conv512_64(y3))
339 |
340 | lat_loss0 = self.mi_level1(x_loss0, y_loss0)
341 | lat_loss1 = self.mi_level2(x_loss1, y_loss1)
342 | lat_loss2 = self.mi_level3(x_loss2, y_loss2)
343 | lat_loss3 = self.mi_level4(x_loss3, y_loss3)
344 | lat_loss = lat_loss0 + lat_loss1 + lat_loss2 + lat_loss3
345 | return s1, s2, s3, s4, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4),edge,lat_loss
346 |
347 | if __name__ == '__main__':
348 | image = torch.randn(1, 3, 384, 384).cuda(0)
349 | ndsm = torch.randn(1, 64, 56, 56)
350 | ndsm1 = torch.randn(1, 128, 28, 28)
351 | ndsm2 = torch.randn(1, 320, 14, 14)
352 | ndsm3 = torch.randn(1, 512, 7, 7)
353 |
354 | net = WaveNet().cuda()
355 |
356 |
--------------------------------------------------------------------------------
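The `__main__` block above builds `WaveNet` but never calls it. A CPU sketch of a full forward pass (weights are randomly initialized here, since no checkpoint is loaded):

```python
import torch

net = WaveNet().eval()
rgb = torch.randn(1, 3, 384, 384)
thermal = torch.randn(1, 3, 384, 384)

with torch.no_grad():
    outs = net(rgb, thermal)

# Ten outputs: saliency logits s1..s4, their sigmoids, the edge map,
# and the scalar KL-based mutual-information loss.
s1, edge, lat_loss = outs[0], outs[8], outs[9]
print(s1.shape, edge.shape, float(lat_loss))   # logits and edge map are 384x384
```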
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__init__.py
--------------------------------------------------------------------------------
/networks/__pycache__/Wavenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/Wavenet.cpython-36.pyc
--------------------------------------------------------------------------------
/networks/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/networks/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/__pycache__/wavemlp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowander/WaveNet/7e1343bcc5b583f1aaa8d581d0ea3dbf101d9011/networks/__pycache__/wavemlp.cpython-36.pyc
--------------------------------------------------------------------------------
/networks/wavemlp.py:
--------------------------------------------------------------------------------
1 | # 2022.02.14 - Changed for main script for the Wave-MLP model
2 | # Huawei Technologies Co., Ltd.
3 | """ Wave-MLP backbone in PyTorch
4 |
5 | A PyTorch implementation of the Wave-MLP architecture
6 | ('An Image Patch is a Wave: Phase-Aware Vision MLP'), used in this
7 | project as the encoder of WaveNet.
8 |
9 | Notes:
10 | * Each token is treated as a wave: the PATM block modulates channel
11 |   features (amplitude) with a learned phase via cos/sin terms before
12 |   mixing tokens along the height and width directions.
13 | * With fork_feat=True the model returns the four pyramid feature maps
14 |   instead of classification logits.
15 | * The scaffolding (config dicts, patch embedding, model registry) is
16 |   adapted from the timm Vision Transformer implementation.
17 |
18 | Acknowledgments:
19 | * The timm library and its ViT code,
20 |   hacked together by / Copyright 2020 Ross Wightman
21 | """
22 |
23 |
24 |
25 |
26 | import os
27 | import torch
28 | import torch.nn as nn
29 |
30 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
31 | from timm.models.layers import DropPath, trunc_normal_
32 | from timm.models.registry import register_model
33 | from timm.models.layers.helpers import to_2tuple
34 |
35 | import math
36 | from torch import Tensor
37 | from torch.nn import init
38 | from torch.nn.modules.utils import _pair
39 | import torch.nn.functional as F
40 |
41 |
42 | def _cfg(url='', **kwargs):
43 | return {
44 | 'url': url,
45 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
46 | 'crop_pct': .96, 'interpolation': 'bicubic',
47 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
48 | **kwargs
49 | }
50 |
51 | default_cfgs = {
52 | 'wave_T': _cfg(crop_pct=0.9),
53 | 'wave_S': _cfg(crop_pct=0.9),
54 | 'wave_M': _cfg(crop_pct=0.9),
55 | 'wave_B': _cfg(crop_pct=0.875),
56 | }
57 |
58 | class Mlp(nn.Module):
59 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
60 | super().__init__()
61 |
62 | out_features = out_features or in_features
63 | hidden_features = hidden_features or in_features
64 | self.act = act_layer()
65 | self.drop = nn.Dropout(drop)
66 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1)
67 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1)
68 |
69 | def forward(self, x):
70 | x = self.fc1(x)
71 | x = self.act(x)
72 | x = self.drop(x)
73 | x = self.fc2(x)
74 | x = self.drop(x)
75 | return x
76 |
77 |
78 | class PATM(nn.Module):
79 | def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,mode='fc'):
80 | super().__init__()
81 |
82 |
83 |         self.fc_h = nn.Conv2d(dim, dim, 1, 1, bias=qkv_bias)
84 |         self.fc_w = nn.Conv2d(dim, dim, 1, 1, bias=qkv_bias)
85 |         self.fc_c = nn.Conv2d(dim, dim, 1, 1, bias=qkv_bias)
86 |
87 |         self.tfc_h = nn.Conv2d(2 * dim, dim, (1, 7), stride=1, padding=(0, 7 // 2), groups=dim, bias=False)
88 |         self.tfc_w = nn.Conv2d(2 * dim, dim, (7, 1), stride=1, padding=(7 // 2, 0), groups=dim, bias=False)
89 |         self.reweight = Mlp(dim, dim // 4, dim * 3)
90 |         self.proj = nn.Conv2d(dim, dim, 1, 1, bias=True)
91 |         self.proj_drop = nn.Dropout(proj_drop)
92 |         self.mode = mode
93 |         # phase estimation: 1x1 convs in 'fc' mode, 3x3 depthwise convs otherwise
94 |         if mode == 'fc':
95 |             self.theta_h_conv = nn.Sequential(nn.Conv2d(dim, dim, 1, 1, bias=True), nn.BatchNorm2d(dim), nn.ReLU())
96 |             self.theta_w_conv = nn.Sequential(nn.Conv2d(dim, dim, 1, 1, bias=True), nn.BatchNorm2d(dim), nn.ReLU())
97 |         else:
98 |             self.theta_h_conv = nn.Sequential(nn.Conv2d(dim, dim, 3, stride=1, padding=1, groups=dim, bias=False), nn.BatchNorm2d(dim), nn.ReLU())
99 |             self.theta_w_conv = nn.Sequential(nn.Conv2d(dim, dim, 3, stride=1, padding=1, groups=dim, bias=False), nn.BatchNorm2d(dim), nn.ReLU())
100 |
101 |
102 |
103 |     def forward(self, x):
104 |         B, C, H, W = x.shape
105 |         # estimate a phase map for each spatial direction
106 |         theta_h = self.theta_h_conv(x)
107 |         theta_w = self.theta_w_conv(x)
108 |
109 |         # wave representation: the amplitude (1x1 conv output) is modulated
110 |         # by the phase through cos/sin terms, concatenated along channels
111 |         x_h = self.fc_h(x)
112 |         x_w = self.fc_w(x)
113 |         x_h = torch.cat([x_h * torch.cos(theta_h), x_h * torch.sin(theta_h)], dim=1)
114 |         x_w = torch.cat([x_w * torch.cos(theta_w), x_w * torch.sin(theta_w)], dim=1)
115 |
116 |         # token mixing along height and width, plus a channel-only branch
117 |         h = self.tfc_h(x_h)
118 |         w = self.tfc_w(x_w)
119 |         c = self.fc_c(x)
120 |         # re-weight the three branches with a squeeze-and-excite style MLP
121 |         a = F.adaptive_avg_pool2d(h + w + c, output_size=1)
122 |         a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(-1).unsqueeze(-1)
123 |         x = h * a[0] + w * a[1] + c * a[2]
124 |         x = self.proj(x)
125 |         x = self.proj_drop(x)
126 |         return x
127 |
128 |
129 | class WaveBlock(nn.Module):
130 |
131 | def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
132 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, mode='fc'):
133 | super().__init__()
134 | self.norm1 = norm_layer(dim)
135 |         self.attn = PATM(dim, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, mode=mode)
136 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
137 | self.norm2 = norm_layer(dim)
138 | mlp_hidden_dim = int(dim * mlp_ratio)
139 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
140 |
141 | def forward(self, x):
142 | x = x + self.drop_path(self.attn(self.norm1(x)))
143 | x = x + self.drop_path(self.mlp(self.norm2(x)))
144 | return x
145 |
146 |
147 | class PatchEmbedOverlapping(nn.Module):
148 | def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d, groups=1,use_norm=True):
149 | super().__init__()
150 | patch_size = to_2tuple(patch_size)
151 | stride = to_2tuple(stride)
152 | padding = to_2tuple(padding)
153 | self.patch_size = patch_size
154 |
155 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding, groups=groups)
156 |         self.norm = norm_layer(embed_dim) if use_norm else nn.Identity()
157 |
158 | def forward(self, x):
159 | x = self.proj(x)
160 | x = self.norm(x)
161 | return x
162 |
163 |
164 | class Downsample(nn.Module):
165 | def __init__(self, in_embed_dim, out_embed_dim, patch_size,norm_layer=nn.BatchNorm2d,use_norm=True):
166 | super().__init__()
167 | assert patch_size == 2, patch_size
168 | self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=1)
169 |         self.norm = norm_layer(out_embed_dim) if use_norm else nn.Identity()
170 | def forward(self, x):
171 | x = self.proj(x)
172 | x = self.norm(x)
173 | return x
174 |
175 |
176 | def basic_blocks(dim, index, layers, mlp_ratio=3., qkv_bias=False, qk_scale=None, attn_drop=0.,
177 | drop_path_rate=0.,norm_layer=nn.BatchNorm2d,mode='fc', **kwargs):
178 | blocks = []
179 | for block_idx in range(layers[index]):
180 | block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
181 | blocks.append(WaveBlock(dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
182 | attn_drop=attn_drop, drop_path=block_dpr, norm_layer=norm_layer,mode=mode))
183 | blocks = nn.Sequential(*blocks)
184 | return blocks
185 |
186 | class WaveNet(nn.Module):
187 | def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
188 | embed_dims=None, transitions=None, mlp_ratios=None,
189 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
190 | norm_layer=nn.BatchNorm2d, fork_feat=False,mode='fc',ds_use_norm=True,args=None):
191 |
192 | super().__init__()
193 |         self.embed_dim = embed_dims[-1]  # stored for reset_classifier
194 | if not fork_feat:
195 | self.num_classes = num_classes
196 | self.fork_feat = fork_feat
197 |
198 |         self.patch_embed = PatchEmbedOverlapping(patch_size=7, stride=4, padding=2, in_chans=in_chans, embed_dim=embed_dims[0], norm_layer=norm_layer, use_norm=ds_use_norm)
199 |
200 | network = []
201 | for i in range(len(layers)):
202 | stage = basic_blocks(embed_dims[i], i, layers, mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
203 | qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate,
204 | norm_layer=norm_layer,mode=mode)
205 | network.append(stage)
206 | if i >= len(layers) - 1:
207 | break
208 | # transitions = [True, True, True, True]
209 | if transitions[i] or embed_dims[i] != embed_dims[i+1]:
210 | patch_size = 2 if transitions[i] else 1
211 | network.append(Downsample(embed_dims[i], embed_dims[i+1], patch_size,norm_layer=norm_layer,use_norm=ds_use_norm))
212 |
213 | self.network = nn.ModuleList(network)
214 |
215 | if self.fork_feat:
216 | # add a norm layer for each output
217 | self.out_indices = [0, 2, 4, 6]
218 | for i_emb, i_layer in enumerate(self.out_indices):
219 | if i_emb == 0 and os.environ.get('FORK_LAST3', None):
220 | layer = nn.Identity()
221 | else:
222 | layer = norm_layer(embed_dims[i_emb])
223 | layer_name = f'norm{i_layer}'
224 | self.add_module(layer_name, layer)
225 | else:
226 | self.norm = norm_layer(embed_dims[-1])
227 | self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
228 | self.apply(self.cls_init_weights)
229 |
230 | def cls_init_weights(self, m):
231 | if isinstance(m, nn.Linear):
232 | trunc_normal_(m.weight, std=.02)
233 | if isinstance(m, nn.Linear) and m.bias is not None:
234 | nn.init.constant_(m.bias, 0)
235 | elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d):
236 | nn.init.constant_(m.bias, 0)
237 | nn.init.constant_(m.weight, 1.0)
238 |
239 | def init_weights(self, pretrained=None):
240 | """ mmseg or mmdet `init_weight` """
241 |
242 | if isinstance(pretrained, str):
243 |             pass
244 | # logger = get_root_logger()
245 | # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
246 |
247 | def get_classifier(self):
248 | return self.head
249 |
250 | def reset_classifier(self, num_classes, global_pool=''):
251 | self.num_classes = num_classes
252 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
253 |
254 | def forward_embeddings(self, x):
255 | x = self.patch_embed(x)
256 | return x
257 |
258 | def forward_tokens(self, x):
259 | outs = []
260 | for idx, block in enumerate(self.network):
261 | x = block(x)
262 | if self.fork_feat and idx in self.out_indices:
263 | norm_layer = getattr(self, f'norm{idx}')
264 | x_out = norm_layer(x)
265 | outs.append(x_out)
266 | # print('x_out',x_out.shape)
267 | # return outs
268 | # print('self.fork_feat',self.fork_feat)
269 | if self.fork_feat:
270 | return outs
271 | return x
272 |
273 | def forward(self, x):
274 | # print('a',x.shape)
275 | x = self.forward_embeddings(x)
276 | # print('b', x.shape)
277 | x = self.forward_tokens(x)
278 | # print('c',len(x))
279 | # if self.fork_feat:
280 | # return x
281 | # x = self.norm(x)
282 | return x
283 | # cls_out = self.head(F.adaptive_avg_pool2d(x,output_size=1).flatten(1))#x.mean(1)
284 | # return cls_out
285 |
286 | def MyNorm(dim):
287 | return nn.GroupNorm(1, dim)
288 |
289 | @register_model
290 | def WaveMLP_T_dw(pretrained=False, **kwargs):
291 | transitions = [True, True, True, True]
292 | layers = [2, 2, 4, 2]
293 | mlp_ratios = [4, 4, 4, 4]
294 | embed_dims = [64, 128, 320, 512]
295 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
296 | fork_feat=True,mlp_ratios=mlp_ratios,mode='depthwise', **kwargs)
297 | # print('model',model)
298 | model.default_cfg = default_cfgs['wave_T']
299 | return model
300 |
301 | @register_model
302 | def WaveMLP_T(pretrained=False, **kwargs):
303 | transitions = [True, True, True, True]
304 | layers = [2, 2, 4, 2]
305 | mlp_ratios = [4, 4, 4, 4]
306 | embed_dims = [64, 128, 320, 512]
307 | # model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
308 | # mlp_ratios=mlp_ratios, **kwargs)
309 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
310 | fork_feat=True, mlp_ratios=mlp_ratios, **kwargs)
311 | model.default_cfg = default_cfgs['wave_T']
312 | return model
313 |
314 | @register_model
315 | def WaveMLP_S(pretrained=False, **kwargs):
316 | transitions = [True, True, True, True]
317 | layers = [2, 3, 10, 3]
318 | mlp_ratios = [4, 4, 4, 4]
319 | embed_dims = [64, 128, 320, 512]
320 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
321 | fork_feat=True, mlp_ratios=mlp_ratios,norm_layer=MyNorm, **kwargs)
322 | model.default_cfg = default_cfgs['wave_S']
323 | return model
324 |
325 | @register_model
326 | def WaveMLP_M(pretrained=False, **kwargs):
327 | transitions = [True, True, True, True]
328 | layers = [3, 4, 18, 3]
329 | mlp_ratios = [8, 8, 4, 4]
330 | embed_dims = [64, 128, 320, 512]
331 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
332 | fork_feat=True, mlp_ratios=mlp_ratios,norm_layer=MyNorm,ds_use_norm=False, **kwargs)
333 | model.default_cfg = default_cfgs['wave_M']
334 | return model
335 |
336 | @register_model
337 | def WaveMLP_B(pretrained=False, **kwargs):
338 | transitions = [True, True, True, True]
339 | layers = [2, 2, 18, 2]
340 | mlp_ratios = [4, 4, 4, 4]
341 | embed_dims = [96, 192, 384, 768]
342 | model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
343 | fork_feat=True, mlp_ratios=mlp_ratios,norm_layer=MyNorm,ds_use_norm=False, **kwargs)
344 | model.default_cfg = default_cfgs['wave_B']
345 | return model
346 | if __name__ == '__main__':
347 |     # quick shape check: with fork_feat=True the backbone returns four feature maps
348 |     image = torch.randn(1, 3, 224, 224)
349 |     net = WaveMLP_B()
350 |     out = net(image)
351 |     for i in out:
352 |         print(i.shape)
353 |
354 |
355 |
356 |
--------------------------------------------------------------------------------
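A minimal usage sketch for the backbone above (not part of the repository): with `fork_feat=True`, each registered variant returns the four pyramid feature maps rather than classification logits. The shapes below assume a 224x224 input and the `WaveMLP_S` configuration (strides 4/8/16/32, channels 64/128/320/512).

```python
import torch
from networks.wavemlp import WaveMLP_S

net = WaveMLP_S()
feats = net(torch.randn(1, 3, 224, 224))  # fork_feat=True -> list of 4 maps
for f in feats:
    print(f.shape)
# torch.Size([1, 64, 56, 56])
# torch.Size([1, 128, 28, 28])
# torch.Size([1, 320, 14, 14])
# torch.Size([1, 512, 7, 7])
```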
/pytorch_iou/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 |
5 |
6 |
7 |
8 | def _iou(pred, target, size_average=True):
9 |     # pred is expected to already be a probability map in [0, 1]
10 | b = pred.shape[0]
11 | IoU = 0.0
12 |     for i in range(0, b):
13 |         # compute the IoU of the foreground
14 |         Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :])
15 |         Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1
16 |         IoU1 = Iand1 / Ior1
17 |
18 |         # IoU loss is (1 - IoU1), averaged over the batch
19 |         IoU = IoU + (1 - IoU1)
20 |
21 |     return IoU / b
22 |
23 | class IOU(torch.nn.Module):
24 |     def __init__(self, size_average=True):
25 | super(IOU, self).__init__()
26 | self.size_average = size_average
27 |
28 | def forward(self, pred, target):
29 |
30 | return _iou(pred, target, self.size_average)
31 |
--------------------------------------------------------------------------------
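A small usage sketch of the IoU loss above. Two details from the code are worth noting: `pred` must already be a probability map (the training script passes the sigmoid outputs `s1_sig` etc.), and the `size_average` flag is accepted but not used by `_iou`.

```python
import torch
import pytorch_iou

iou_loss = pytorch_iou.IOU(size_average=True)
pred = torch.rand(2, 1, 224, 224)                # probabilities in [0, 1]
gt = (torch.rand(2, 1, 224, 224) > 0.5).float()  # binary ground truth
print(iou_loss(pred, gt))                        # scalar loss; lower is better
```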
/rgbt_dataset_KD.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label, depth,bound,gt2):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
19 | bound = bound.transpose(Image.FLIP_LEFT_RIGHT)
20 | gt2 = gt2.transpose(Image.FLIP_LEFT_RIGHT)
21 | # top bottom flip
22 | # if flip_flag2==1:
23 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
24 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
25 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
26 | return img, label, depth,bound,gt2
27 |
28 |
29 | def randomCrop(image, label, depth,bound,gt2):
30 | border = 30
31 | image_width = image.size[0]
32 | image_height = image.size[1]
33 | crop_win_width = np.random.randint(image_width - border, image_width)
34 | crop_win_height = np.random.randint(image_height - border, image_height)
35 | random_region = (
36 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
37 | (image_height + crop_win_height) >> 1)
38 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region),bound.crop(random_region),gt2.crop(random_region)
39 |
40 |
41 | def randomRotation(image, label, depth,bound,gt2):
42 | mode = Image.BICUBIC
43 | if random.random() > 0.8:
44 | random_angle = np.random.randint(-15, 15)
45 | image = image.rotate(random_angle, mode)
46 | label = label.rotate(random_angle, mode)
47 | depth = depth.rotate(random_angle, mode)
48 | bound = bound.rotate(random_angle, mode)
49 | gt2 = gt2.rotate(random_angle, mode)
50 | return image, label, depth,bound,gt2
51 |
52 |
53 | def colorEnhance(image):
54 | bright_intensity = random.randint(5, 15) / 10.0
55 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
56 | contrast_intensity = random.randint(5, 15) / 10.0
57 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
58 | color_intensity = random.randint(0, 20) / 10.0
59 | image = ImageEnhance.Color(image).enhance(color_intensity)
60 | sharp_intensity = random.randint(0, 30) / 10.0
61 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
62 | return image
63 |
64 |
65 | def randomGaussian(image, mean=0.1, sigma=0.35):
66 | def gaussianNoisy(im, mean=mean, sigma=sigma):
67 | for _i in range(len(im)):
68 | im[_i] += random.gauss(mean, sigma)
69 | return im
70 |
71 | img = np.asarray(image)
72 | width, height = img.shape
73 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
74 | img = img.reshape([width, height])
75 | return Image.fromarray(np.uint8(img))
76 |
77 |
78 | def randomPeper(img):
79 | img = np.array(img)
80 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
81 | for i in range(noiseNum):
82 |
83 | randX = random.randint(0, img.shape[0] - 1)
84 |
85 | randY = random.randint(0, img.shape[1] - 1)
86 |
87 | if random.randint(0, 1) == 0:
88 |
89 | img[randX, randY] = 0
90 |
91 | else:
92 |
93 | img[randX, randY] = 255
94 | return Image.fromarray(img)
95 |
96 |
97 | # dataset for training
98 | # NOTE: this loader does not use normalized depth maps for training and testing. If you use
99 | # normalized depth maps (e.g., 0 for background, 1 for foreground), performance improves further.
100 | class SalObjDataset(data.Dataset):
101 | def __init__(self, image_root, gt_root, depth_root, bound_root,gt2_root,trainsize):
102 | self.trainsize = trainsize
103 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
104 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
105 | or f.endswith('.png')]
106 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
107 | or f.endswith('.png') or f.endswith('.jpg')]
108 | self.bound = [bound_root + f for f in os.listdir(bound_root) if f.endswith('.jpg')
109 | or f.endswith('.png')]
110 | self.images = sorted(self.images)
111 |         self.gts2 = [gt2_root + f for f in os.listdir(gt2_root) if f.endswith('.jpg')
112 |                      or f.endswith('.png')]
113 | # print(len(self.images))
114 | self.gts = sorted(self.gts)
115 | self.gts2 = sorted(self.gts2)
116 | # print(len(self.gts))
117 | self.depths = sorted(self.depths)
118 | self.bound = sorted(self.bound)
119 | # print(len(self.depths))
120 | self.filter_files()
121 | self.size = len(self.images)
122 | self.img_transform = transforms.Compose([
123 | transforms.Resize((self.trainsize, self.trainsize)),
124 | transforms.ToTensor(),
125 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
126 | self.gt_transform = transforms.Compose([
127 | transforms.Resize((self.trainsize, self.trainsize)),
128 | transforms.ToTensor()])
129 | self.bound_transform = transforms.Compose([
130 | transforms.Resize((self.trainsize, self.trainsize)),
131 | transforms.ToTensor()])
132 | self.depths_transform = transforms.Compose([
133 | transforms.Resize((self.trainsize, self.trainsize)),
134 | transforms.ToTensor(),
135 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
136 | ])
137 | self.gt2_transform = transforms.Compose([
138 | transforms.Resize((self.trainsize, self.trainsize)),
139 | transforms.ToTensor()])
140 |
141 | def __getitem__(self, index):
142 | image = self.rgb_loader(self.images[index])
143 | gt = self.binary_loader(self.gts[index])
144 | gt2 = self.binary_loader(self.gts2[index])
145 | # depth = self.binary_loader(self.depths[index])
146 | depth = self.rgb_loader(self.depths[index])
147 | bound = self.binary_loader(self.bound[index])
148 | image, gt, depth,bound,gt2 = cv_random_flip(image, gt, depth,bound,gt2)
149 | image, gt, depth,bound,gt2 = randomCrop(image, gt, depth,bound,gt2)
150 | image, gt, depth,bound,gt2 = randomRotation(image, gt, depth,bound,gt2)
151 | image = colorEnhance(image)
152 | # gt=randomGaussian(gt)
153 | gt = randomPeper(gt)
154 | bound = randomPeper(bound)
155 | gt2 = randomPeper(gt2)
156 | # image, gt, depth = self.resize(image,gt, depth)
157 | image = self.img_transform(image)
158 | gt = self.gt_transform(gt)
159 | bound = self.bound_transform(bound)
160 | depth = self.depths_transform(depth)
161 | gt2 = self.gt2_transform(gt2)
162 | return image, gt, depth,bound,gt2
163 |
164 | def filter_files(self):
165 | # print('len(self.images)',len(self.images))
166 | # print('len(self.depths)', len(self.depths))
167 | # print('len(self.images)', len(self.images))
168 | assert len(self.images) == len(self.depths) and len(self.gts) == len(self.images)
169 | images = []
170 | gts = []
171 | depths = []
172 | bounds = []
173 | gts2 = []
174 | for img_path, gt_path, depth_path,bound_path,gt2_path in zip(self.images, self.gts, self.depths,self.bound,self.gts2):
175 | img = Image.open(img_path)
176 | gt = Image.open(gt_path)
177 | depth = Image.open(depth_path)
178 | bound = Image.open(bound_path)
179 | # if img.size == depth.size:
180 | # and gt.size == depth.size
181 | images.append(img_path)
182 | gts.append(gt_path)
183 | depths.append(depth_path)
184 | bounds.append(bound_path)
185 | gts2.append(gt2_path)
186 | self.images = images
187 | self.gts = gts
188 | self.depths = depths
189 | self.bound = bounds
190 | self.gts2 = gts2
191 |
192 | def rgb_loader(self, path):
193 | with open(path, 'rb') as f:
194 | img = Image.open(f)
195 | return img.convert('RGB')
196 |
197 | def binary_loader(self, path):
198 | with open(path, 'rb') as f:
199 | img = Image.open(f)
200 | return img.convert('L')
201 |
202 | def resize(self, img, gt, depth,gt2):
203 | assert img.size == gt.size and gt.size == depth.size
204 | h = self.trainsize
205 | w = self.trainsize
206 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),Image.NEAREST)#, gt2.resize((w, h), Image.NEAREST)
207 |
208 |
209 |
210 | def __len__(self):
211 | return self.size
212 |
213 |
214 | # dataloader for training
215 | def get_loader(image_root, gt_root, depth_root, bound_root,gt2_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=False):
216 | dataset = SalObjDataset(image_root, gt_root, depth_root, bound_root,gt2_root,trainsize)
217 | data_loader = data.DataLoader(dataset=dataset,
218 | batch_size=batchsize,
219 | shuffle=shuffle,
220 | num_workers=num_workers,
221 | pin_memory=pin_memory,
222 | drop_last=True
223 | )
224 | return data_loader
225 |
226 |
227 | # test dataset and loader
228 | class test_dataset:
229 | def __init__(self, image_root, gt_root, depth_root, testsize):
230 | self.testsize = testsize
231 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
232 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
233 | or f.endswith('.png')]
234 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
235 | or f.endswith('.png')or f.endswith('.jpg')]
236 | self.images = sorted(self.images)
237 | self.gts = sorted(self.gts)
238 | self.depths = sorted(self.depths)
239 | self.transform = transforms.Compose([
240 | transforms.Resize((self.testsize, self.testsize)),
241 | transforms.ToTensor(),
242 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
243 | # self.gt_transform = transforms.ToTensor()
244 | self.gt_transform = transforms.Compose([
245 | transforms.Resize((self.testsize, self.testsize)),
246 | transforms.ToTensor()])
247 | self.depths_transform = transforms.Compose([
248 | transforms.Resize((self.testsize, self.testsize)),
249 | transforms.ToTensor(),
250 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
251 | self.size = len(self.images)
252 | self.index = 0
253 |
254 | def load_data(self):
255 | image = self.rgb_loader(self.images[self.index])
256 | gt = self.binary_loader(self.gts[self.index])
257 | # depth = self.binary_loader(self.depths[self.index])
258 | depth = self.rgb_loader(self.depths[self.index])
259 | # image, gt, depth = self.resize(image, gt, depth)
260 | image = self.transform(image).unsqueeze(0)
261 | gt = self.gt_transform(gt).unsqueeze(0)
262 | depth = self.depths_transform(depth).unsqueeze(0)
263 | name = self.images[self.index].split('/')[-1]
264 | if name.endswith('.jpg'):
265 | name = name.split('.jpg')[0] + '.png'
266 | self.index += 1
267 | self.index = self.index % self.size
268 | return image, gt, depth, name
269 |
270 | def rgb_loader(self, path):
271 | with open(path, 'rb') as f:
272 | img = Image.open(f)
273 | return img.convert('RGB')
274 |
275 | def binary_loader(self, path):
276 | with open(path, 'rb') as f:
277 | img = Image.open(f)
278 | return img.convert('L')
279 |
280 | def resize(self, img, gt, depth):
281 | # assert img.size == gt.size and gt.size == depth.size
282 | h = self.testsize
283 | w = self.testsize
284 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),
285 | Image.NEAREST)
286 |
287 |
288 | def __len__(self):
289 | return self.size
290 |
--------------------------------------------------------------------------------
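A sketch of how the training loader is typically wired up (the root path is illustrative; the directory names follow `train_wave_KD.py`). One subtlety worth flagging: the dataset joins paths by plain string concatenation (`image_root + f`), so every root must end with a trailing `/`.

```python
from rgbt_dataset_KD import get_loader

root = '/data/RGBT/Train/'  # hypothetical dataset location
loader = get_loader(root + 'RGB/', root + 'GT/', root + 'T/',
                    root + 'bound/', root + 'GT_SwinNet/',
                    batchsize=8, trainsize=384)
images, gts, depths, bounds, gts2 = next(iter(loader))
print(images.shape)  # e.g. torch.Size([8, 3, 384, 384])
```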
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import os
4 | import cv2
5 |
6 | from networks.Wavenet import WaveNet
7 |
8 | import matplotlib.pyplot as plt
9 | from config import opt
10 | from rgbt_dataset_KD import test_dataset
11 | from datetime import datetime
12 |
13 | dataset_path = opt.test_path
14 |
15 | #set device for test
16 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
17 | print('USE GPU:', opt.gpu_id)
18 |
19 | #load the model
20 | model = WaveNet()
21 | print('NOW USING:WaveNet')
22 | model.load_state_dict(torch.load(''))  # TODO: fill in the path to a trained checkpoint
23 |
24 | model.cuda()
25 | model.eval()
26 |
27 | #test
28 |
29 | test_mae = []
30 | test_datasets = ['VT800','VT1000', 'VT5000']
31 |
32 | for dataset in test_datasets:
33 |
34 | mae_sum = 0
35 | save_path = 'show/' + dataset + '/'
36 | if not os.path.exists(save_path):
37 | os.makedirs(save_path)
38 | image_root = dataset_path + dataset + '/RGB/'
39 | gt_root = dataset_path + dataset + '/GT/'
40 | depth_root = dataset_path + dataset + '/T/'
41 |
42 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize)
43 | prec_time = datetime.now()
44 | for i in range(test_loader.size):
45 | image, gt, depth, name = test_loader.load_data()
46 | gt = gt.cuda()
47 | image = image.cuda()
48 | depth = depth.cuda()
49 |
50 | res = model(image, depth)
51 | res = torch.sigmoid(res[0])
52 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
53 | mae_train = torch.sum((torch.abs(res - gt)) * 1.0 / (torch.numel(gt)))
54 | mae_sum = mae_train.item() + mae_sum
55 | predict = res.data.cpu().numpy().squeeze()
56 | print('save img to: ', save_path + name, )
57 | predict = cv2.resize(predict,(224,224),interpolation=cv2.INTER_LINEAR)
58 | plt.imsave(save_path + name, arr=predict, cmap='gray')
59 |
60 | cur_time = datetime.now()
61 | test_mae.append(mae_sum / len(test_loader))
62 |
63 | print('Test_mae:', test_mae)
64 | print('Test Done!')
--------------------------------------------------------------------------------
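For clarity, the per-image min-max rescaling and MAE computation used in the test loop above, pulled out as standalone helpers (a sketch, not part of the repository):

```python
import torch

def minmax_norm(res: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Rescale a predicted saliency map to [0, 1]."""
    return (res - res.min()) / (res.max() - res.min() + eps)

def mae(res: torch.Tensor, gt: torch.Tensor) -> float:
    """Mean absolute error between prediction and ground truth."""
    return (torch.abs(res - gt).sum() / torch.numel(gt)).item()
```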
/train_wave_KD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from datetime import datetime
6 | from torchvision.utils import make_grid
7 |
8 | from networks.Wavenet import WaveNet
9 | from rgbt_dataset_KD import get_loader, test_dataset
10 | from utils import clip_gradient, adjust_lr
11 | from tensorboardX import SummaryWriter
12 | import logging
13 | import torch.backends.cudnn as cudnn
14 | from config import opt
15 | from torch.cuda import amp
16 | import pytorch_iou
17 |
18 | # set the device for training
19 | cudnn.benchmark = True
20 |
21 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
22 | print('USE GPU:', opt.gpu_id)
23 |
24 | # build the model
25 |
26 |
27 | model = WaveNet()
28 | model.load_pre()
29 |
30 | model.cuda()
31 | params = model.parameters()
32 | optimizer = torch.optim.Adam(params, opt.lr)
33 |
34 | # set the path
35 | train_dataset_path = opt.lr_train_root
36 | image_root = train_dataset_path + '/RGB/'
37 | depth_root = train_dataset_path + '/T/'
38 | gt_root = train_dataset_path + '/GT/'
39 | bound_root = train_dataset_path + '/bound/'
40 | gt2_root = train_dataset_path + '/GT_SwinNet/'
41 | val_dataset_path = opt.lr_val_root
42 | val_image_root = val_dataset_path + '/RGB/'
43 | val_depth_root = val_dataset_path + '/T/'
44 | val_gt_root = val_dataset_path + '/GT/'
45 | save_path = opt.save_path
46 |
47 | if not os.path.exists(save_path):
48 | os.makedirs(save_path)
49 |
50 | # load data
51 | print('load data...')
52 |
53 | train_loader = get_loader(image_root, gt_root,depth_root,bound_root,gt2_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
54 | # print(len(train_loader))
55 | test_loader = test_dataset(val_image_root, val_gt_root,val_depth_root, opt.trainsize)
56 | total_step = len(train_loader)
57 |
58 | logging.basicConfig(filename=save_path + 'log.log', format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
59 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p')
60 | logging.info(save_path + "Train")
61 | logging.info("Config")
62 | logging.info(
63 | 'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format(
64 | opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.load, save_path,
65 | opt.decay_epoch))
66 |
67 | # set loss function
68 | def linear_annealing(init, fin, step, annealing_steps):
69 | """Linear annealing of a parameter."""
70 | if annealing_steps == 0:
71 | return fin
72 | assert fin > init
73 | delta = fin - init
74 | annealed = min(init + delta * step / annealing_steps, fin)
75 | return annealed
76 |
77 |
78 |
79 | def joint_loss(pred, mask):
80 | weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
81 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
82 | wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
83 |
84 | pred = torch.sigmoid(pred)
85 | inter = ((pred * mask)*weit).sum(dim=(2, 3))
86 | union = ((pred + mask)*weit).sum(dim=(2, 3))
87 | wiou = 1 - (inter + 1)/(union - inter+1)
88 | return (wbce + wiou).mean()
89 |
90 |
91 | CE = torch.nn.BCEWithLogitsLoss()
92 | IOU = pytorch_iou.IOU(size_average = True)
93 |
94 |
95 |
96 | # hyper-parameters
97 | step = 0
98 | writer = SummaryWriter(save_path + 'summary')
99 | best_mae = 1
100 | best_epoch = 0
101 | scaler = amp.GradScaler()  # reserved for optional mixed-precision (AMP) training; unused below
102 |
103 | # train function
104 | length = 821  # hard-coded dataset length; unused below
105 |
106 | def train(train_loader, model, optimizer, epoch, save_path):
107 | global step
108 | model.train()
109 | # model2.train()
110 | loss_all = 0
111 | epoch_step = 0
112 | try:
113 | for i, (images, gts,depths,bound,gts2) in enumerate(train_loader, start=1):
114 | optimizer.zero_grad()
115 | # ima = images
116 | # dep = depths
117 | images = images.cuda()
118 | # print(images.shape)
119 | depths = depths.cuda()
120 | gts = gts.cuda()
121 | bound = bound.cuda()
122 | gts2 = gts2.cuda()
123 |             _, _, w_t, h_t = gts.size()
124 |
125 |             s1, s2, s3, s4, s1_sig, s2_sig, s3_sig, s4_sig, edge, latent_loss = model(images, depths)
126 | '''
127 |         We directly use the prediction maps of the teacher model, but you can also load the teacher model itself, for example:
128 |
129 | t1,t2, t3, t4, t1_sig, t2_sig, t3_sig, t4_sig, tedge, tlatent_loss = model_T(images, depths)
130 | t1, t2 = model_T(images, depths)
131 | s1, s2, s3,s4,s1_sig, s2_sig, s3_sig,s4_sig= model(images, depths) # , self.sigmoid(s5)
132 | target,_ = model2(ima, dep)
133 | t1 = t1.cuda()
134 | m = nn.Sigmoid()
135 | t1 = m(t1)
136 | t1 = t1.detach()
137 |
138 | t2 = t2.cuda()
139 | m = nn.Sigmoid()
140 | t2 = m(t2)
141 | t2 = t2.detach()
142 | print('latent_loss',latent_loss)
143 | '''
144 | loss1 = CE(s1, gts) + IOU(s1_sig, gts)
145 | loss2 = CE(s2, gts) + IOU(s2_sig, gts)
146 | loss3 = CE(s3, gts) + IOU(s3_sig, gts)
147 | loss4 = CE(s4, gts) + IOU(s4_sig, gts)
148 | loss5 = CE(edge, bound)
149 | anneal_reg = linear_annealing(0, 1, epoch, opt.epoch)
150 | loss6 = 0.1 * anneal_reg * latent_loss
151 | loss7 = CE(s1_sig, gts2) + IOU(s1_sig, gts2)
152 |
153 | loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7
154 | loss.backward()
155 |
156 |
157 | clip_gradient(optimizer, opt.clip)
158 |
159 | optimizer.step()
160 | step += 1
161 | epoch_step = epoch_step +1
162 | loss_all =loss_all + loss.item()
163 | if i % 100 == 0 or i == total_step or i == 1:
164 | print('{} Epoch [{:03d}/{:03d}],W*H [{:03d}*{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}'.
165 | format(datetime.now(), epoch+1, opt.epoch, w_t, h_t, i, total_step, loss.item()))
166 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}'.
167 | format(epoch+1, opt.epoch, i, total_step, loss.item()))
168 | writer.add_scalar('Loss/total_loss', loss, global_step=step)
169 | grid_image = make_grid(images[0].clone().cpu().data,1,normalize=True)
170 | writer.add_image('train/RGB',grid_image, step)
171 | grid_image = make_grid(depths[0].clone().cpu().data, 1, normalize=True)
172 | writer.add_image('train/Ti', grid_image, step)
173 |
174 | loss_all /= epoch_step
175 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch+1, opt.epoch, loss_all))
176 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)
177 | if (epoch+1) % 10 == 0 or (epoch+1) == opt.epoch:
178 | torch.save(model.state_dict(), save_path + 'Epoch_{}_test.pth'.format(epoch+1))
179 | except KeyboardInterrupt:
180 | print('Keyboard Interrupt: save model and exit.')
181 | if not os.path.exists(save_path):
182 | os.makedirs(save_path)
183 | torch.save(model.state_dict(), save_path + 'Epoch_{}_test.pth'.format(epoch + 1))
184 | print('save checkpoints successfully!')
185 | raise
186 |
187 | #
188 |
189 | # test function
190 | def test(test_loader, model, epoch, save_path):
191 | global best_mae, best_epoch
192 | model.eval()
193 | with torch.no_grad():
194 | mae_sum = 0
195 | for i in range(test_loader.size):
196 | image, gt, depth, name = test_loader.load_data()
197 | gt = gt.cuda()
198 | image = image.cuda()
199 | depth = depth.cuda()
200 |
201 | res = model(image, depth)
202 | res = torch.sigmoid(res[0])
203 |
204 | res = (res-res.min())/(res.max()-res.min()+1e-8)
205 | mae_train =torch.sum(torch.abs(res-gt))*1.0/(torch.numel(gt))
206 | mae_sum = mae_train.item()+mae_sum
207 |
208 |         mae = mae_sum / test_loader.size
209 |
210 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch))
211 |         # keep the checkpoint with the lowest validation MAE seen so far
212 |         if mae < best_mae:
213 |             best_mae = mae
214 |             best_epoch = epoch
215 |             torch.save(model.state_dict(), save_path + 'Best_mae_test.pth')
216 |             print('best epoch:{}'.format(epoch))
217 |
218 |
219 | logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae))
220 |
221 |
222 |
223 | if __name__ == '__main__':
224 | print("Start train...")
225 | start_time = datetime.now()
226 | for epoch in range(opt.epoch):
227 | cur_lr = adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
228 | writer.add_scalar('learning_rate', cur_lr, global_step=epoch)
229 | train(train_loader, model, optimizer, epoch, save_path)
230 | test(test_loader, model, epoch, save_path)
231 | finish_time = datetime.now()
232 | h, remainder = divmod((finish_time - start_time).seconds, 3600)
233 | m, s = divmod(remainder, 60)
234 | time = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s)
235 | print(time)
236 |
--------------------------------------------------------------------------------
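To make the latent-loss weighting above concrete: `loss6 = 0.1 * linear_annealing(0, 1, epoch, opt.epoch) * latent_loss`, so the KD/latent regularizer ramps in linearly from 0 to a maximum weight of 0.1 over training. A quick check of the schedule (a sketch assuming `opt.epoch == 100`):

```python
def linear_annealing(init, fin, step, annealing_steps):
    """Linear annealing of a parameter (as in train_wave_KD.py)."""
    if annealing_steps == 0:
        return fin
    assert fin > init
    delta = fin - init
    return min(init + delta * step / annealing_steps, fin)

for epoch in (0, 25, 50, 100):
    print(epoch, 0.1 * linear_annealing(0, 1, epoch, 100))
# 0 -> 0.0, 25 -> 0.025, 50 -> 0.05, 100 -> 0.1
```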
/utils.py:
--------------------------------------------------------------------------------
1 | def clip_gradient(optimizer, grad_clip):
2 | for group in optimizer.param_groups:
3 | for param in group['params']:
4 | if param.grad is not None:
5 | param.grad.data.clamp_(-grad_clip, grad_clip)
6 |
7 |
8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
9 |     decay = decay_rate ** (epoch // decay_epoch)
10 |     for param_group in optimizer.param_groups:
11 |         param_group['lr'] = decay * init_lr
12 |         lr = param_group['lr']
13 |     return lr
14 |
15 |
--------------------------------------------------------------------------------
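For reference, the step decay implemented by `adjust_lr` is `lr = init_lr * decay_rate ** (epoch // decay_epoch)`; with the defaults (`decay_rate=0.1`, `decay_epoch=30`) the learning rate drops by 10x every 30 epochs. A quick numeric check with an illustrative `init_lr`:

```python
init_lr, decay_rate, decay_epoch = 1e-4, 0.1, 30
for epoch in (0, 29, 30, 60):
    print(epoch, init_lr * decay_rate ** (epoch // decay_epoch))
# 0 and 29 -> 1e-4, 30 -> 1e-5, 60 -> 1e-6 (up to float rounding)
```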