├── model
│   ├── __init__.py
│   ├── MobileNetV2.py
│   ├── MAGNet.py
│   └── smt.py
├── ckps
│   └── smt
│       └── pretrained SMT here
├── test_maps
│   └── pred maps here
├── pytorch_iou
│   ├── __pycache__
│   │   └── __init__.cpython-39.pyc
│   └── __init__.py
├── LICENSE
├── test_Net.py
├── utils.py
├── README.md
├── train_Net.py
└── data.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ckps/smt/pretrained SMT here:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test_maps/pred maps here:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pytorch_iou/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyu6346/MAGNet/HEAD/pytorch_iou/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/pytorch_iou/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 |
8 | def _iou(pred, target, size_average = True):
9 |
10 | b = pred.shape[0]
11 | IoU = 0.0
12 | for i in range(0,b):
13 | #compute the IoU of the foreground
14 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:])
15 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1
16 | IoU1 = Iand1/Ior1
17 |
18 | #IoU loss is (1-IoU1)
19 | IoU = IoU + (1-IoU1)
20 |
21 | return IoU/b
22 |
23 | class IOU(torch.nn.Module):
24 | def __init__(self, size_average = True):
25 | super(IOU, self).__init__()
26 | self.size_average = size_average
27 |
28 | def forward(self, pred, target):
29 |
30 | return _iou(pred, target, self.size_average)
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 mingyu6346
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/test_Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import sys
4 | import warnings
5 | from tqdm import tqdm
6 | import numpy as np
7 | import os, argparse
8 | import cv2
9 | from model.MAGNet import MAGNet
10 | from data import test_dataset
11 |
12 | warnings.filterwarnings("ignore")
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--trainsize', type=int, default=384, help='testing size')
16 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
17 | parser.add_argument('--test_path', type=str, default='', help='test dataset path')
18 | parser.add_argument('--save_path', type=str, default='./test_maps/MAGNet/', help='save path')
19 | parser.add_argument('--pth_path', type=str, default='', help='checkpoint path')
20 | opt = parser.parse_args()
21 |
22 | dataset_path = opt.test_path
23 |
24 | # set device for test
25 | if opt.gpu_id == '0':
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
27 | print('USE GPU 0')
28 | elif opt.gpu_id == '1':
29 |     os.environ["CUDA_VISIBLE_DEVICES"] = "1"
30 | print('USE GPU 1')
31 | # load the model
32 | model = MAGNet().eval().cuda()
33 | model.load_state_dict(torch.load(opt.pth_path))
34 |
35 | test_datasets = ['DUT', 'LFSD', 'NJU2K', 'NLPR', 'SIP', 'STERE']
36 |
37 |
38 | for dataset in test_datasets:
39 | save_path = opt.save_path + dataset + '/'
40 | if not os.path.exists(save_path):
41 | os.makedirs(save_path)
42 | image_root = dataset_path + dataset + '/RGB/'
43 | gt_root = dataset_path + dataset + '/GT/'
44 | depth_root = dataset_path + dataset + '/depth/'
45 | # depth_root = dataset_path + dataset + '/T/'
46 | test_loader = test_dataset(image_root, gt_root, depth_root, opt.trainsize)
47 | for i in tqdm(range(test_loader.size), desc=dataset, file=sys.stdout):
48 | image, gt, depth, name, image_for_post = test_loader.load_data()
49 | gt = np.asarray(gt, np.float32)
50 | gt /= (gt.max() + 1e-8)
51 | image = image.cuda()
52 | depth = depth.repeat(1, 3, 1, 1).cuda()
53 | res, pred_2, pred_3, pred_4 = model(image, depth)
54 |         res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
55 | res = res.sigmoid().data.cpu().numpy().squeeze()
56 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
57 | cv2.imwrite(save_path + name, res * 255)
58 | print('Test Done!')
59 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import requests
5 | import torch
6 |
7 |
8 | def clip_gradient(optimizer, grad_clip):
9 | for group in optimizer.param_groups:
10 | for param in group['params']:
11 | if param.grad is not None:
12 | param.grad.data.clamp_(-grad_clip, grad_clip)
13 |
14 |
15 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
16 | decay = decay_rate ** (epoch // decay_epoch)
17 | for param_group in optimizer.param_groups:
18 | param_group['lr'] = decay * init_lr
19 | lr = param_group['lr']
20 | return lr
21 |
22 |
23 | def opt_save(opt):
24 | log_path = opt.save_path + "train_settings.log"
25 | if not os.path.exists(opt.save_path):
26 | os.makedirs(opt.save_path)
27 | att = [i for i in opt.__dir__() if not i.startswith("_")]
28 | with open(log_path, "w") as f:
29 | for i in att:
30 | print("{}:{}".format(i, eval(f"opt.{i}")), file=f)
31 |
32 |
33 | def iou_loss(pred, mask):
34 | pred = torch.sigmoid(pred)
35 | inter = (pred * mask).sum(dim=(2, 3))
36 | union = (pred + mask).sum(dim=(2, 3))
37 | iou = 1 - (inter + 1) / (union - inter + 1)
38 | return iou.mean()
39 |
40 |
41 | def fps(model, epoch_num, size):
42 |     ls = [] # FPS obtained in each measurement round
43 |     iterations = 300 # number of timed iterations per round
44 | device = torch.device("cuda:0")
45 | # device = torch.device("cpu")
46 | model.to(device)
47 |
48 | random_input = torch.randn(1, 3, size, size).to(device)
49 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
50 |
51 |     # GPU warm-up
52 | for _ in range(50):
53 | _ = model(random_input, random_input)
54 |
55 | for i in range(epoch_num):
56 |         # speed measurement
57 |         times = torch.zeros(iterations) # store the time of each iteration
58 | with torch.no_grad():
59 | for iter in range(iterations):
60 | starter.record()
61 | _ = model(random_input, random_input)
62 | ender.record()
63 |                 # synchronize GPU time
64 | torch.cuda.synchronize()
65 |                 curr_time = starter.elapsed_time(ender) # elapsed time in milliseconds
66 | times[iter] = curr_time
67 | # print(curr_time)
68 |
69 | mean_time = times.mean().item()
70 | ls.append(1000 / mean_time)
71 | print("{}/{} Inference time: {:.6f}, FPS: {} ".format(i + 1, epoch_num, mean_time, 1000 / mean_time))
72 |     print(f"Average FPS: {np.mean(ls):.2f}")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # README.md
2 |
3 | This project provides the code and results for 'MAGNet: Multi-scale Awareness and Global Fusion Network for RGB-D Salient Object Detection'
4 |
5 | # Environments
6 |
7 | ```bash
8 | conda create -n magnet python=3.9.18
9 | conda activate magnet
10 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
11 | conda install -c conda-forge opencv-python==4.7.0
12 | pip install timm==0.6.5
13 | conda install -c conda-forge tqdm
14 | conda install yacs
15 | ```
16 |
17 | # Data Preparation
18 |
19 | - Download the RGB-D raw data from [baidu](https://pan.baidu.com/s/10Y90OXUFoW8yAeRmr5LFnA?pwd=exwj) / [Google drive](https://drive.google.com/file/d/19HXwGJCtz0QdEDsEbH7cJqTBfD-CEXxX/view?usp=sharing)
20 | - Download the RGB-T raw data from [baidu](https://pan.baidu.com/s/1eexJSI4a2EGoaYcDkt1B9Q?pwd=i7a2) / [Google drive](https://drive.google.com/file/d/1hLhn5WV6xh-Q41upXF-bzyVpbszF9hUc/view?usp=sharing)
21 |
22 | Note that in the depth maps of the raw data above, the foreground is white.
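
If your own depth maps follow the opposite convention (foreground in black), a minimal sketch for flipping them before training is shown below; `invert_depth` and its paths are illustrative helpers, not part of this repo:

```python
# Hypothetical helper (not part of this repo): invert an 8-bit depth map whose
# foreground is black so it matches the convention above (foreground = white).
import numpy as np
from PIL import Image

def invert_depth(in_path, out_path):
    depth = np.array(Image.open(in_path).convert('L'))  # load as single-channel uint8
    Image.fromarray(255 - depth).save(out_path)          # flip foreground/background intensity
```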
23 |
24 | # Training & Testing
25 |
26 | - Train the MAGNet:
27 | 1. download the pretrained SMT pth from [baidu](https://pan.baidu.com/s/11bNtCS7HyjnB7Lf3RIbpFg?pwd=bxiw) / [Google drive](https://drive.google.com/file/d/1eNhQwUHmfjR-vVGY38D_dFYUOqD_pw-H/view?usp=sharing), and put it under `ckps/smt/`.
28 |   2. modify `rgb_root`, `depth_root` and `gt_root` in `train_Net.py` according to your own data path.
29 | 3. run `python train_Net.py`
30 | - Test the MAGNet:
31 |   1. modify `test_path` and `pth_path` in `test_Net.py` according to your own data path.
32 |   2. run `python test_Net.py` (see the command sketch below)
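
A minimal sketch of both steps with the paths passed on the command line instead of editing the defaults (the dataset locations are placeholders; adjust them to your own layout):

```bash
# training: point rgb_root / depth_root / gt_root to your training set
python train_Net.py --rgb_root <data>/train_dut/RGB/ --depth_root <data>/train_dut/depth/ --gt_root <data>/train_dut/GT/ --save_path ckps/MAGNet/

# testing: test_path is the parent folder of the test datasets, pth_path is a trained
# checkpoint (e.g. the MAGNet_mae_best.pth saved during training)
python test_Net.py --test_path <data>/test_data/ --pth_path ckps/MAGNet/MAGNet_mae_best.pth --save_path ./test_maps/MAGNet/
```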
33 |
34 | # Evaluate tools
35 |
36 | - You can use either of the following toolboxes to compute the metrics:
37 | [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [SOD_Evaluation_Metrics](https://github.com/zyjwuyan/SOD_Evaluation_Metrics)
38 |
39 | # Saliency Maps
40 |
41 | We provide the saliency maps of the DUT, LFSD, NJU2K, NLPR, SIP, SSD, and STERE datasets.
42 |
43 | - RGB-D [baidu](https://pan.baidu.com/s/1FK8jcDb61QdFvZF1qKMV6g?pwd=c3a6) / [Google drive](https://drive.google.com/file/d/1uoeNZPzsj2RAr0ofM8fPD6N0JJ8HCyn9/view?usp=sharing)
44 |
45 | We provide the saliency maps of the VT821, VT1000, and VT5000 datasets.
46 |
47 | - RGB-T [baidu](https://pan.baidu.com/s/1IQIkZS9GzmBT0PHflHqMNw?pwd=ebuw) / [Google drive](https://drive.google.com/file/d/198k-3R-yDF_y0Br7MoeSBP5XQOPuXPnL/view?usp=sharing)
48 |
49 | # Trained Models
50 |
51 | - RGB-D [baidu](https://pan.baidu.com/s/1RPMA5Z3liMoUlG0AWuGeRA?pwd=5aqf) / [Google drive](https://drive.google.com/file/d/1vb2Vcbz9bCjvaSwoRZjIi39ae5Ei1GVs/view?usp=sharing)
52 |
53 | # Acknowledgement
54 |
55 | The implementation of this project is based on the codebases below.
56 |
57 | - [SeaNet](https://github.com/MathLee/SeaNet)
58 | - [LSNet](https://github.com/zyrant/LSNet)
59 | - Fps/speed test [MobileSal](https://github.com/yuhuan-wu/MobileSal/blob/master/speed_test.py)
60 | - Evaluate tools [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [SOD_Evaluation_Metrics](https://github.com/zyjwuyan/SOD_Evaluation_Metrics)
61 |
62 | # Contact
63 |
64 | Feel free to contact me if you have any questions: (mingyu6346 at 163 dot com)
65 |
--------------------------------------------------------------------------------
/model/MobileNetV2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 | try:
5 | from torchvision.models.utils import load_state_dict_from_url # torchvision 0.4+
6 | except ModuleNotFoundError:
7 | try:
8 | from torch.hub import load_state_dict_from_url # torch 1.x
9 | except ModuleNotFoundError:
10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url # torch 0.4.1
11 |
12 | model_urls = {
13 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
14 | }
15 |
16 |
17 | class ConvBNReLU(nn.Sequential):
18 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1):
19 | padding = (kernel_size - 1) // 2
20 | if dilation != 1:
21 | padding = dilation
22 | super(ConvBNReLU, self).__init__(
23 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation,
24 | bias=False),
25 | nn.BatchNorm2d(out_planes),
26 | nn.ReLU6(inplace=True)
27 | )
28 |
29 |
30 | class InvertedResidual(nn.Module):
31 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
32 | super(InvertedResidual, self).__init__()
33 | self.stride = stride
34 | assert stride in [1, 2]
35 |
36 | hidden_dim = int(round(inp * expand_ratio))
37 | self.use_res_connect = self.stride == 1 and inp == oup
38 |
39 | layers = []
40 | if expand_ratio != 1:
41 | # pw
42 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
43 | layers.extend([
44 | # dw
45 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation),
46 | # pw-linear
47 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
48 | nn.BatchNorm2d(oup),
49 | ])
50 | self.conv = nn.Sequential(*layers)
51 |
52 | def forward(self, x):
53 | if self.use_res_connect:
54 | return x + self.conv(x)
55 | else:
56 | return self.conv(x)
57 |
58 |
59 | class MobileNetV2(nn.Module):
60 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0):
61 | super(MobileNetV2, self).__init__()
62 | block = InvertedResidual
63 | input_channel = 32
64 | last_channel = 1280
65 | inverted_residual_setting = [
66 | # t, c, n, s, d
67 | [1, 16, 1, 1, 1], # conv1 112*112*16
68 | [6, 24, 2, 2, 1], # conv2 56*56*24
69 | [6, 32, 3, 2, 1], # conv3 28*28*32
70 | [6, 64, 4, 2, 1],
71 | [6, 96, 3, 1, 1], # conv4 14*14*96
72 | [6, 160, 3, 2, 1],
73 | [6, 320, 1, 1, 1], # conv5 7*7*320
74 | ]
75 |
76 | # building first layer
77 | input_channel = int(input_channel * width_mult)
78 | self.last_channel = int(last_channel * max(1.0, width_mult))
79 | features = [ConvBNReLU(3, input_channel, stride=2)]
80 | # building inverted residual blocks
81 | for t, c, n, s, d in inverted_residual_setting:
82 | output_channel = int(c * width_mult)
83 | for i in range(n):
84 | stride = s if i == 0 else 1
85 | dilation = d if i == 0 else 1
86 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d))
87 | input_channel = output_channel
88 | # building last several layers
89 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
90 | # make it nn.Sequential
91 | self.features = nn.Sequential(*features)
92 |
93 | # weight initialization
94 | for m in self.modules():
95 | if isinstance(m, nn.Conv2d):
96 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
97 | if m.bias is not None:
98 | nn.init.zeros_(m.bias)
99 | elif isinstance(m, nn.BatchNorm2d):
100 | nn.init.ones_(m.weight)
101 | nn.init.zeros_(m.bias)
102 | elif isinstance(m, nn.Linear):
103 | nn.init.normal_(m.weight, 0, 0.01)
104 | nn.init.zeros_(m.bias)
105 |
106 | def forward(self, x):
107 | res = []
108 | for idx, m in enumerate(self.features):
109 | x = m(x)
110 | if idx in [1, 3, 6, 13, 17]:
111 | res.append(x)
112 | return res
113 |
114 |
115 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
116 | model = MobileNetV2(**kwargs)
117 | if pretrained:
118 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
119 | progress=progress)
120 | print("loading imagenet pretrained mobilenetv2")
121 | model.load_state_dict(state_dict, strict=False)
122 | print("loaded imagenet pretrained mobilenetv2")
123 | return model
--------------------------------------------------------------------------------
/train_Net.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import torch.nn.functional as F
4 | import torch
5 | import numpy as np
6 | import os, argparse
7 | from datetime import datetime
8 | from tqdm import tqdm
9 | from model.MAGNet import MAGNet
10 | from data import get_loader, test_dataset
11 | from utils import clip_gradient, adjust_lr, opt_save, iou_loss
12 |
13 | import pytorch_iou
14 |
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--epoch', type=int, default=200, help='epoch number')
18 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
19 | parser.add_argument('--batchsize', type=int, default=8, help='training batch size')
20 | parser.add_argument('--trainsize', type=int, default=384, help='training image size')
21 | parser.add_argument('--continue_train', type=bool, default=False, help='continue training')
22 | parser.add_argument('--continue_train_path', type=str, default='', help='continue training path')
23 |
24 | parser.add_argument('--rgb_root', type=str, default='D:/DataSet/SOD/RGB-D SOD/train_dut/RGB/',
25 | help='the training rgb images root') # train_dut
26 | parser.add_argument('--depth_root', type=str, default='D:/DataSet/SOD/RGB-D SOD/train_dut/depth/',
27 | help='the training depth images root')
28 | parser.add_argument('--gt_root', type=str, default='D:/DataSet/SOD/RGB-D SOD/train_dut/GT/',
29 | help='the training gt images root')
30 |
31 | parser.add_argument('--val_rgb', type=str, default="D:/DataSet/SOD/RGB-D SOD/test_data/NLPR/RGB/",
32 | help='validate rgb path')
33 | parser.add_argument('--val_depth', type=str, default="D:/DataSet/SOD/RGB-D SOD/test_data/NLPR/depth/",
34 | help='validate depth path')
35 | parser.add_argument('--val_gt', type=str, default="D:/DataSet/SOD/RGB-D SOD/test_data/NLPR/GT/",
36 | help='validate gt path')
37 |
38 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
39 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
40 | parser.add_argument('--decay_epoch', type=int, default=80, help='every n epochs decay learning rate')
41 | parser.add_argument('--save_path', type=str, default="ckps/MAGNet/", help='checkpoint path')
42 |
43 | opt = parser.parse_args()
44 |
45 | opt_save(opt)
46 | logging.basicConfig(filename=opt.save_path + 'log.log',
47 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO, filemode='a',
48 | datefmt='%Y-%m-%d %H:%M:%S %p')
49 | logging.info("Net-Train")
50 | # model
51 | model = MAGNet()
52 | if os.path.exists("ckps/smt/smt_tiny.pth"):
53 | model.rgb_backbone.load_state_dict(torch.load("ckps/smt/smt_tiny.pth")['model'])
54 | print(f"loaded imagenet pretrained SMT from ckps/smt/smt_tiny.pth")
55 | else:
56 |     raise FileNotFoundError("please put smt_tiny.pth under ckps/smt/ folder")
57 | if opt.continue_train:
58 | model.load_state_dict(torch.load(opt.continue_train_path))
59 | print(f"continue training from {opt.continue_train_path}")
60 |
61 | model.cuda()
62 | params = model.parameters()
63 | optimizer = torch.optim.Adam(params, opt.lr)
64 |
65 | print('load data...')
66 | train_loader = get_loader(opt.rgb_root, opt.gt_root, opt.depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
67 | total_step = len(train_loader)
68 |
69 | CE = torch.nn.BCEWithLogitsLoss()
70 | IOU = pytorch_iou.IOU(size_average=True)
71 | best_loss = 1.0
72 |
73 |
74 | def train(train_loader, model, optimizer, epoch):
75 | model.train()
76 | loss_list = []
77 | for i, pack in enumerate(train_loader, start=1):
78 | optimizer.zero_grad()
79 | (images, gts, depth) = pack
80 | images = images.cuda()
81 | gts = gts.cuda()
82 | depth = depth.cuda().repeat(1, 3, 1, 1)
83 |
84 | pred_1, pred_2, pred_3, pred_4 = model(images, depth)
85 |
86 | loss1 = CE(pred_1, gts) + iou_loss(pred_1, gts)
87 | loss2 = CE(pred_2, gts) + iou_loss(pred_2, gts)
88 | loss3 = CE(pred_3, gts) + iou_loss(pred_3, gts)
89 | loss4 = CE(pred_4, gts) + iou_loss(pred_4, gts)
90 |
91 | loss = loss1 + loss2 + loss3 + loss4
92 |
93 | loss.backward()
94 |
95 | clip_gradient(optimizer, opt.clip)
96 | optimizer.step()
97 |
98 | loss_list.append(float(loss))
99 | if i % 20 == 0 or i == total_step:
100 | msg = '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Learning Rate: {}, Loss: {:.4f}, Loss1: {:.4f}, Loss2: {:.4f}, Loss3: {:.4f}, Loss4: {:.4f}'.format(
101 | datetime.now(), epoch, opt.epoch, i, total_step,
102 | opt.lr * opt.decay_rate ** (epoch // opt.decay_epoch), loss.data, loss1.data,
103 | loss2.data, loss3.data, loss4.data)
104 | print(msg)
105 | logging.info(msg)
106 | epoch_loss = np.mean(loss_list)
107 | if not os.path.exists(opt.save_path):
108 | os.makedirs(opt.save_path)
109 |
110 | global best_loss
111 | if epoch_loss < best_loss:
112 | best_loss = epoch_loss
113 | torch.save(model.state_dict(), opt.save_path + 'Net_rgb_d.pth' + f'.{epoch}_{epoch_loss:.3f}',
114 | _use_new_zipfile_serialization=False)
115 | with open(opt.save_path + "loss.log", "a") as f:
116 | print(f"{datetime.now()} epoch {epoch} loss {np.mean(loss_list):.3f}", file=f)
117 |
118 |
119 | best_mae = 1.0
120 | best_epoch = 0
121 |
122 |
123 | def validate(test_dataset, model, epoch, opt):
124 | global best_mae, best_epoch
125 | model.eval().cuda()
126 | mae_sum = 0
127 | test_loader = test_dataset(opt.val_rgb, opt.val_gt, opt.val_depth, opt.trainsize)
128 | with torch.no_grad():
129 | for i in tqdm(range(test_loader.size), desc="Validating", file=sys.stdout):
130 | image, gt, depth, name, image_for_post = test_loader.load_data()
131 | gt = np.asarray(gt, np.float32)
132 | gt /= (gt.max() + 1e-8)
133 | image = image.cuda()
134 | depth = depth.repeat(1, 3, 1, 1).cuda()
135 |
136 | res, _, _, _ = model(image, depth)
137 | res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
138 | res = res.sigmoid().data.cpu().numpy().squeeze()
139 |
140 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
141 | # mae_train = torch.sum(torch.abs(res - gt)) * 1.0 / (torch.numel(gt))
142 | mae_train = np.mean(np.abs(res - gt))
143 | mae_sum = mae_train + mae_sum
144 | mae = mae_sum / test_loader.size
145 |
146 | if epoch == 0:
147 | best_mae = mae
148 | else:
149 | if mae < best_mae:
150 | best_mae = round(mae, 5)
151 | best_epoch = epoch
152 | torch.save(model.state_dict(), opt.save_path + 'MAGNet_mae_best.pth', _use_new_zipfile_serialization=False)
153 | print('best epoch:{}'.format(epoch))
154 | msg = 'Epoch: {} MAE: {:.5f} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch)
155 | print(msg)
156 | logging.info(msg)
157 | with open(f"{opt.save_path}mae.log", "a", encoding='utf-8') as f:
158 | f.write('Epoch: {:03d} MAE: {:.5f} #### bestMAE: {:.5f} bestEpoch: {:03d}\n'.format(epoch, mae, best_mae,
159 | best_epoch))
160 | return mae
161 |
162 |
163 | print("Let's go!")
164 | for epoch in range(opt.epoch):
165 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
166 | train(train_loader, model, optimizer, epoch)
167 | validate(test_dataset, model, epoch, opt)
168 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 | from tqdm import tqdm
9 |
10 |
11 | # several data augmentation strategies
12 | def cv_random_flip(img, label, depth):
13 | flip_flag = random.randint(0, 1)
14 | # flip_flag2= random.randint(0,1)
15 | # left right flip
16 | if flip_flag == 1:
17 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
18 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
19 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
20 | # top bottom flip
21 | # if flip_flag2==1:
22 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
23 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
24 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
25 | return img, label, depth
26 |
27 |
28 | def randomCrop(image, label, depth):
29 | border = 30
30 | image_width = image.size[0]
31 | image_height = image.size[1]
32 | crop_win_width = np.random.randint(image_width - border, image_width)
33 | crop_win_height = np.random.randint(image_height - border, image_height)
34 | random_region = (
35 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
36 | (image_height + crop_win_height) >> 1)
37 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region)
38 |
39 |
40 | def randomRotation(image, label, depth):
41 | mode = Image.BICUBIC
42 | if random.random() > 0.8:
43 | random_angle = np.random.randint(-15, 15)
44 | image = image.rotate(random_angle, mode)
45 | label = label.rotate(random_angle, mode)
46 | depth = depth.rotate(random_angle, mode)
47 | return image, label, depth
48 |
49 |
50 | def colorEnhance(image):
51 | bright_intensity = random.randint(5, 15) / 10.0
52 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
53 | contrast_intensity = random.randint(5, 15) / 10.0
54 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
55 | color_intensity = random.randint(0, 20) / 10.0
56 | image = ImageEnhance.Color(image).enhance(color_intensity)
57 | sharp_intensity = random.randint(0, 30) / 10.0
58 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
59 | return image
60 |
61 |
62 | def randomGaussian(image, mean=0.1, sigma=0.35):
63 | def gaussianNoisy(im, mean=mean, sigma=sigma):
64 | for _i in range(len(im)):
65 | im[_i] += random.gauss(mean, sigma)
66 | return im
67 |
68 | img = np.asarray(image)
69 | width, height = img.shape
70 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
71 | img = img.reshape([width, height])
72 | return Image.fromarray(np.uint8(img))
73 |
74 |
75 | def randomPepper(img):
76 | img = np.array(img)
77 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
78 | for i in range(noiseNum):
79 |
80 | randX = random.randint(0, img.shape[0] - 1)
81 |
82 | randY = random.randint(0, img.shape[1] - 1)
83 |
84 | if random.randint(0, 1) == 0:
85 |
86 | img[randX, randY] = 0
87 |
88 | else:
89 |
90 | img[randX, randY] = 255
91 | return Image.fromarray(img)
92 |
93 |
94 | # dataset for training
95 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps
96 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.
97 | class SalObjDataset(data.Dataset):
98 | def __init__(self, image_root, gt_root, depth_root, trainsize):
99 | self.trainsize = trainsize
100 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
101 |
102 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
103 | or f.endswith('.png')]
104 |
105 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
106 | or f.endswith('.png') or f.endswith('.jpg')]
107 | self.images = sorted(self.images)
108 | self.gts = sorted(self.gts)
109 | self.depths = sorted(self.depths)
110 | self.filter_files()
111 | self.size = len(self.images)
112 | self.img_transform = transforms.Compose([
113 | transforms.Resize((self.trainsize, self.trainsize)),
114 | transforms.ToTensor(),
115 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
116 | self.gt_transform = transforms.Compose([
117 | transforms.Resize((self.trainsize, self.trainsize)),
118 | transforms.ToTensor()])
119 | self.depths_transform = transforms.Compose(
120 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()])
121 |
122 | def __getitem__(self, index):
123 | image = self.rgb_loader(self.images[index])
124 | gt = self.binary_loader(self.gts[index])
125 | depth = self.binary_loader(self.depths[index])
126 |
127 | image, gt, depth = cv_random_flip(image, gt, depth)
128 | image, gt, depth = randomCrop(image, gt, depth)
129 | image, gt, depth = randomRotation(image, gt, depth)
130 | image = colorEnhance(image)
131 | # gt=randomGaussian(gt)
132 | gt = randomPepper(gt)
133 | image = self.img_transform(image)
134 | gt = self.gt_transform(gt)
135 | depth = self.depths_transform(depth)
136 |
137 | return image, gt, depth
138 |
139 | def filter_files(self):
140 |         assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
141 | images = []
142 | gts = []
143 | depths = []
144 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
145 | img = Image.open(img_path)
146 | gt = Image.open(gt_path)
147 | depth = Image.open(depth_path)
148 | if img.size == gt.size and gt.size == depth.size:
149 | images.append(img_path)
150 | gts.append(gt_path)
151 | depths.append(depth_path)
152 | self.images = images
153 | self.gts = gts
154 | self.depths = depths
155 |
156 | def rgb_loader(self, path):
157 | with open(path, 'rb') as f:
158 | img = Image.open(f)
159 | return img.convert('RGB')
160 |
161 | def binary_loader(self, path):
162 | with open(path, 'rb') as f:
163 | img = Image.open(f)
164 | return img.convert('L')
165 |
166 | def resize(self, img, gt, depth):
167 | assert img.size == gt.size and gt.size == depth.size
168 | w, h = img.size
169 | if h < self.trainsize or w < self.trainsize:
170 | h = max(h, self.trainsize)
171 | w = max(w, self.trainsize)
172 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), \
173 | depth.resize((w, h), Image.NEAREST)
174 | else:
175 | return img, gt, depth
176 |
177 | def __len__(self):
178 | return self.size
179 |
180 |
181 | # dataloader for training
182 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True):
183 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize)
184 | # print(image_root)
185 | # print(gt_root)
186 | # print(depth_root)
187 | data_loader = data.DataLoader(dataset=dataset,
188 | batch_size=batchsize,
189 | shuffle=shuffle,
190 | num_workers=num_workers,
191 | pin_memory=pin_memory)
192 | return data_loader
193 |
194 |
195 | # test dataset and loader
196 | class test_dataset:
197 | def __init__(self, image_root, gt_root, depth_root, testsize):
198 | self.testsize = testsize
199 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
200 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
201 | or f.endswith('.png')]
202 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
203 | or f.endswith('.png') or f.endswith('.jpg')]
204 |
205 | self.images = sorted(self.images)
206 | self.gts = sorted(self.gts)
207 | self.depths = sorted(self.depths)
208 | self.transform = transforms.Compose([
209 | transforms.Resize((self.testsize, self.testsize)),
210 | transforms.ToTensor(),
211 | transforms.Normalize(mean=(0.485, 0.458, 0.407), std=(0.229, 0.224, 0.225))])
212 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
213 | self.gt_transform = transforms.ToTensor()
214 | # self.gt_transform = transforms.Compose([
215 | # transforms.Resize((self.trainsize, self.trainsize)),
216 | # transforms.ToTensor()])
217 | self.depths_transform = transforms.Compose(
218 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()])
219 |
220 | self.size = len(self.images)
221 | self.index = 0
222 |
223 | def load_data(self):
224 | image = self.rgb_loader(self.images[self.index])
225 | image = self.transform(image).unsqueeze(0)
226 | gt = self.binary_loader(self.gts[self.index])
227 | depth = self.binary_loader(self.depths[self.index])
228 | depth = self.depths_transform(depth).unsqueeze(0)
229 | name = self.gts[self.index].split('/')[-1]
230 | image_for_post = self.rgb_loader(self.images[self.index])
231 | image_for_post = image_for_post.resize(gt.size)
232 | if name.endswith('.jpg'):
233 | name = name.split('.jpg')[0] + '.jpg'
234 | self.index += 1
235 | self.index = self.index % self.size
236 | return image, gt, depth, name, np.array(image_for_post)
237 |
238 | def rgb_loader(self, path):
239 | with open(path, 'rb') as f:
240 | img = Image.open(f)
241 | return img.convert('RGB')
242 |
243 | def binary_loader(self, path):
244 | with open(path, 'rb') as f:
245 | img = Image.open(f)
246 | return img.convert('L')
247 |
248 | def __len__(self):
249 | return self.size
250 |
251 |
252 | if __name__ == "__main__":
253 | ir = "E:/lmt/train_dut/train_dut/train_images/"
254 | dr = "E:/lmt/train_dut/train_dut/train_depth/"
255 | gt = "E:/lmt/train_dut/train_dut/train_masks/"
256 | b = 8
257 | ts = 384
258 | train_loader = get_loader(ir, gt, dr, b, ts)
259 | train_loader = tqdm(train_loader)
260 | for i, (images, gts, depth) in enumerate(train_loader, start=1):
261 | pass
262 |
--------------------------------------------------------------------------------
/model/MAGNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from model.smt import smt_t
3 | from model.MobileNetV2 import mobilenet_v2
4 | import torch.nn as nn
5 | import torch
6 | import torch.nn.functional as F
7 | from timm.models.layers import trunc_normal_
8 |
9 | TRAIN_SIZE = 384
10 |
11 |
12 | class MAGNet(nn.Module):
13 | def __init__(self, pretrained=True):
14 | super().__init__()
15 | self.rgb_backbone = smt_t(pretrained)
16 | self.d_backbone = mobilenet_v2(pretrained)
17 |
18 | # Fuse
19 | self.gfm2 = GFM(inc=512, expend_ratio=2)
20 | self.gfm1 = GFM(inc=256, expend_ratio=3)
21 | self.mafm2 = MAFM(inc=128)
22 | self.mafm1 = MAFM(inc=64)
23 |
24 | # Pred
25 | self.mcm3 = MCM(inc=512, outc=256)
26 | self.mcm2 = MCM(inc=256, outc=128)
27 | self.mcm1 = MCM(inc=128, outc=64)
28 |
29 | self.d_trans_4 = Trans(320, 512)
30 | self.d_trans_3 = Trans(96, 256)
31 | self.d_trans_2 = Trans(32, 128)
32 | self.d_trans_1 = Trans(24, 64)
33 |
34 | self.predtrans = nn.Sequential(
35 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, groups=512),
36 | nn.BatchNorm2d(512),
37 | nn.GELU(),
38 | nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1)
39 | )
40 |
41 | def forward(self, x_rgb, x_d):
42 | # rgb
43 | _, (rgb_1, rgb_2, rgb_3, rgb_4) = self.rgb_backbone(x_rgb)
44 |
45 | # d
46 | _, d_1, d_2, d_3, d_4 = self.d_backbone(x_d)
47 |
48 | d_4 = self.d_trans_4(d_4)
49 | d_3 = self.d_trans_3(d_3)
50 | d_2 = self.d_trans_2(d_2)
51 | d_1 = self.d_trans_1(d_1)
52 |
53 | # Fuse
54 | fuse_4 = self.gfm2(rgb_4, d_4) # [B, 512, 12, 12]
55 | fuse_3 = self.gfm1(rgb_3, d_3) # [B, 256, 24, 24]
56 | fuse_2 = self.mafm2(rgb_2, d_2) # [B, 128, 48, 48]
57 | fuse_1 = self.mafm1(rgb_1, d_1) # [B, 64, 96, 96]
58 |
59 | # Pred
60 | pred_4 = F.interpolate(self.predtrans(fuse_4), TRAIN_SIZE, mode="bilinear", align_corners=True)
61 | pred_3, xf_3 = self.mcm3(fuse_3, fuse_4)
62 | pred_2, xf_2 = self.mcm2(fuse_2, xf_3)
63 | pred_1, xf_1 = self.mcm1(fuse_1, xf_2)
64 |
65 | return pred_1, pred_2, pred_3, pred_4
66 |
67 |
68 | class Trans(nn.Module):
69 | def __init__(self, inc, outc):
70 | super().__init__()
71 | self.trans = nn.Sequential(
72 | nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1),
73 | nn.BatchNorm2d(outc),
74 | nn.GELU()
75 | )
76 | self.apply(self._init_weights)
77 |
78 | def forward(self, d):
79 | return self.trans(d)
80 |
81 | def _init_weights(self, m):
82 | if isinstance(m, nn.Linear):
83 | trunc_normal_(m.weight, std=.02)
84 | if isinstance(m, nn.Linear) and m.bias is not None:
85 | nn.init.constant_(m.bias, 0)
86 | elif isinstance(m, nn.LayerNorm):
87 | nn.init.constant_(m.bias, 0)
88 | nn.init.constant_(m.weight, 1.0)
89 | elif isinstance(m, nn.Conv2d):
90 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
91 | fan_out //= m.groups
92 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
93 | if m.bias is not None:
94 | m.bias.data.zero_()
95 |
96 |
97 | # Conv_One_Identity
98 | class COI(nn.Module):
99 | def __init__(self, inc, k=3, p=1):
100 | super().__init__()
101 | self.outc = inc
102 | self.dw = nn.Conv2d(inc, self.outc, kernel_size=k, padding=p, groups=inc)
103 | self.conv1_1 = nn.Conv2d(inc, self.outc, kernel_size=1, stride=1)
104 | self.bn1 = nn.BatchNorm2d(self.outc)
105 | self.bn2 = nn.BatchNorm2d(self.outc)
106 | self.bn3 = nn.BatchNorm2d(self.outc)
107 | self.act = nn.GELU()
108 | self.apply(self._init_weights)
109 |
110 | def forward(self, x):
111 | shortcut = self.bn1(x)
112 |
113 | x_dw = self.bn2(self.dw(x))
114 | x_conv1_1 = self.bn3(self.conv1_1(x))
115 | return self.act(shortcut + x_dw + x_conv1_1)
116 |
117 | def _init_weights(self, m):
118 | if isinstance(m, nn.Linear):
119 | trunc_normal_(m.weight, std=.02)
120 | if isinstance(m, nn.Linear) and m.bias is not None:
121 | nn.init.constant_(m.bias, 0)
122 | elif isinstance(m, nn.LayerNorm):
123 | nn.init.constant_(m.bias, 0)
124 | nn.init.constant_(m.weight, 1.0)
125 | elif isinstance(m, nn.Conv2d):
126 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
127 | fan_out //= m.groups
128 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
129 | if m.bias is not None:
130 | m.bias.data.zero_()
131 |
132 |
133 | class MHMC(nn.Module):
134 | def __init__(self, dim, ca_num_heads=4, qkv_bias=True, proj_drop=0., ca_attention=1, expand_ratio=2):
135 | super().__init__()
136 |
137 | self.ca_attention = ca_attention
138 | self.dim = dim
139 | self.ca_num_heads = ca_num_heads
140 |
141 | assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."
142 |
143 | self.act = nn.GELU()
144 | self.proj = nn.Linear(dim, dim)
145 | self.proj_drop = nn.Dropout(proj_drop)
146 |
147 | self.split_groups = self.dim // ca_num_heads
148 |
149 | self.v = nn.Linear(dim, dim, bias=qkv_bias)
150 | self.s = nn.Linear(dim, dim, bias=qkv_bias)
151 | for i in range(self.ca_num_heads):
152 | local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads, kernel_size=(3 + i * 2),
153 | padding=(1 + i), stride=1,
154 |                                    groups=dim // self.ca_num_heads) # kernel_size 3,5,7,9 large-kernel depthwise convs, padding 1,2,3,4
155 | setattr(self, f"local_conv_{i + 1}", local_conv)
156 | self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1,
157 | groups=self.split_groups)
158 | self.bn = nn.BatchNorm2d(dim * expand_ratio)
159 | self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1)
160 |
161 | self.apply(self._init_weights)
162 |
163 | def _init_weights(self, m):
164 | if isinstance(m, nn.Linear):
165 | trunc_normal_(m.weight, std=.02)
166 | if isinstance(m, nn.Linear) and m.bias is not None:
167 | nn.init.constant_(m.bias, 0)
168 | elif isinstance(m, nn.LayerNorm):
169 | nn.init.constant_(m.bias, 0)
170 | nn.init.constant_(m.weight, 1.0)
171 | elif isinstance(m, nn.Conv2d):
172 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
173 | fan_out //= m.groups
174 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
175 | if m.bias is not None:
176 | m.bias.data.zero_()
177 |
178 | def forward(self, x, H, W):
179 | B, N, C = x.shape
180 | v = self.v(x)
181 | s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1,
182 | 2) # num_heads,B,C,H,W
183 | for i in range(self.ca_num_heads):
184 | local_conv = getattr(self, f"local_conv_{i + 1}")
185 | s_i = s[i] # B,C,H,W
186 | s_i = local_conv(s_i).reshape(B, self.split_groups, -1, H, W)
187 | if i == 0:
188 | s_out = s_i
189 | else:
190 | s_out = torch.cat([s_out, s_i], 2)
191 | s_out = s_out.reshape(B, C, H, W)
192 | s_out = self.proj1(self.act(self.bn(self.proj0(s_out))))
193 | self.modulator = s_out
194 | s_out = s_out.reshape(B, C, N).permute(0, 2, 1)
195 | x = s_out * v
196 |
197 | x = self.proj(x)
198 | x = self.proj_drop(x)
199 | return x
200 |
201 |
202 | class SAttention(nn.Module):
203 | def __init__(self, dim, sa_num_heads=8, qkv_bias=True, qk_scale=None,
204 | attn_drop=0., proj_drop=0.):
205 | super().__init__()
206 |
207 | self.dim = dim
208 | self.sa_num_heads = sa_num_heads
209 |
210 | assert dim % sa_num_heads == 0, f"dim {dim} should be divided by num_heads {sa_num_heads}."
211 |
212 | self.act = nn.GELU()
213 | self.proj = nn.Linear(dim, dim)
214 | self.proj_drop = nn.Dropout(proj_drop)
215 |
216 | head_dim = dim // sa_num_heads
217 | self.scale = qk_scale or head_dim ** -0.5
218 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
219 | self.attn_drop = nn.Dropout(attn_drop)
220 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
221 | self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=dim)
222 |
223 | self.apply(self._init_weights)
224 |
225 | def _init_weights(self, m):
226 | if isinstance(m, nn.Linear):
227 | trunc_normal_(m.weight, std=.02)
228 | if isinstance(m, nn.Linear) and m.bias is not None:
229 | nn.init.constant_(m.bias, 0)
230 | elif isinstance(m, nn.LayerNorm):
231 | nn.init.constant_(m.bias, 0)
232 | nn.init.constant_(m.weight, 1.0)
233 | elif isinstance(m, nn.Conv2d):
234 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
235 | fan_out //= m.groups
236 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
237 | if m.bias is not None:
238 | m.bias.data.zero_()
239 |
240 | def forward(self, x, H, W):
241 | B, N, C = x.shape
242 |
243 | q = self.q(x).reshape(B, N, self.sa_num_heads, C // self.sa_num_heads).permute(0, 2, 1, 3)
244 | kv = self.kv(x).reshape(B, -1, 2, self.sa_num_heads, C // self.sa_num_heads).permute(2, 0, 3, 1, 4)
245 | k, v = kv[0], kv[1]
246 | attn = (q @ k.transpose(-2, -1)) * self.scale
247 | attn = attn.softmax(dim=-1)
248 | attn = self.attn_drop(attn)
249 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) + \
250 | self.local_conv(v.transpose(1, 2).reshape(B, N, C).transpose(1, 2).view(B, C, H, W)).view(B, C,
251 | N).transpose(1, 2)
252 |
253 | x = self.proj(x)
254 | x = self.proj_drop(x)
255 |
256 | return x.permute(0, 2, 1).reshape(B, C, H, W)
257 |
258 |
259 | # Multi-scale Awareness Fusion Module
260 | class MAFM(nn.Module):
261 | def __init__(self, inc):
262 | super().__init__()
263 | self.outc = inc
264 | self.attention = MHMC(dim=inc)
265 | self.coi = COI(inc)
266 | self.pw = nn.Sequential(
267 | nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=1, stride=1),
268 | nn.BatchNorm2d(inc),
269 | nn.GELU()
270 | )
271 | self.pre_att = nn.Sequential(
272 | nn.Conv2d(inc * 2, inc * 2, kernel_size=3, padding=1, groups=inc * 2),
273 | nn.BatchNorm2d(inc * 2),
274 | nn.GELU(),
275 | nn.Conv2d(inc * 2, inc, kernel_size=1),
276 | nn.BatchNorm2d(inc),
277 | nn.GELU()
278 | )
279 |
280 | self.apply(self._init_weights)
281 |
282 | def forward(self, x, d):
283 | # multi = x * d
284 | # B, C, H, W = x.shape
285 | # x_cat = torch.cat((x, d, multi), dim=1)
286 |
287 | B, C, H, W = x.shape
288 | x_cat = torch.cat((x, d), dim=1)
289 | x_pre = self.pre_att(x_cat)
290 | # Attention
291 | x_reshape = x_pre.flatten(2).permute(0, 2, 1) # B,C,H,W to B,N,C
292 | attention = self.attention(x_reshape, H, W) # attention
293 | attention = attention.permute(0, 2, 1).reshape(B, C, H, W) # B,N,C to B,C,H,W
294 |
295 | # COI
296 | x_conv = self.coi(attention) # dw3*3,1*1,identity
297 | x_conv = self.pw(x_conv) # pw
298 |
299 | return x_conv
300 |
301 | def _init_weights(self, m):
302 | if isinstance(m, nn.Linear):
303 | trunc_normal_(m.weight, std=.02)
304 | if isinstance(m, nn.Linear) and m.bias is not None:
305 | nn.init.constant_(m.bias, 0)
306 | elif isinstance(m, nn.LayerNorm):
307 | nn.init.constant_(m.bias, 0)
308 | nn.init.constant_(m.weight, 1.0)
309 | elif isinstance(m, nn.Conv2d):
310 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
311 | fan_out //= m.groups
312 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
313 | if m.bias is not None:
314 | m.bias.data.zero_()
315 |
316 |
317 | # Decoder
318 | class MCM(nn.Module):
319 | def __init__(self, inc, outc):
320 | super().__init__()
321 | self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
322 | self.rc = nn.Sequential(
323 | nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc),
324 | nn.BatchNorm2d(inc),
325 | nn.GELU(),
326 | nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1),
327 | nn.BatchNorm2d(outc),
328 | nn.GELU()
329 | )
330 | self.predtrans = nn.Sequential(
331 | nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, padding=1, groups=outc),
332 | nn.BatchNorm2d(outc),
333 | nn.GELU(),
334 | nn.Conv2d(in_channels=outc, out_channels=1, kernel_size=1)
335 | )
336 |
337 | self.rc2 = nn.Sequential(
338 | nn.Conv2d(in_channels=outc * 2, out_channels=outc * 2, kernel_size=3, padding=1, groups=outc * 2),
339 | nn.BatchNorm2d(outc * 2),
340 | nn.GELU(),
341 | nn.Conv2d(in_channels=outc * 2, out_channels=outc, kernel_size=1, stride=1),
342 | nn.BatchNorm2d(outc),
343 | nn.GELU()
344 | )
345 |
346 | self.apply(self._init_weights)
347 |
348 | def forward(self, x1, x2):
349 |         x2_upsample = self.upsample2(x2) # upsample by 2
350 |         x2_rc = self.rc(x2_upsample) # reduce channel dimension
351 |         shortcut = x2_rc
352 |
353 |         x_cat = torch.cat((x1, x2_rc), dim=1) # concatenate
354 |         x_forward = self.rc2(x_cat) # reduce channel dimension again
355 |         x_forward = x_forward + shortcut
356 |         pred = F.interpolate(self.predtrans(x_forward), TRAIN_SIZE, mode="bilinear", align_corners=True) # prediction map
357 |
358 | return pred, x_forward
359 |
360 | def _init_weights(self, m):
361 | if isinstance(m, nn.Linear):
362 | trunc_normal_(m.weight, std=.02)
363 | if isinstance(m, nn.Linear) and m.bias is not None:
364 | nn.init.constant_(m.bias, 0)
365 | elif isinstance(m, nn.LayerNorm):
366 | nn.init.constant_(m.bias, 0)
367 | nn.init.constant_(m.weight, 1.0)
368 | elif isinstance(m, nn.Conv2d):
369 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
370 | fan_out //= m.groups
371 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
372 | if m.bias is not None:
373 | m.bias.data.zero_()
374 |
375 |
376 | # Global Fusion Module
377 | class GFM(nn.Module):
378 | def __init__(self, inc, expend_ratio=2):
379 | super().__init__()
380 | self.expend_ratio = expend_ratio
381 | assert expend_ratio in [2, 3], f"expend_ratio {expend_ratio} mismatch"
382 |
383 | self.sa = SAttention(dim=inc)
384 | self.dw_pw = DWPWConv(inc * expend_ratio, inc)
385 | self.act = nn.GELU()
386 | self.apply(self._init_weights)
387 |
388 | def forward(self, x, d):
389 | B, C, H, W = x.shape
390 | if self.expend_ratio == 2:
391 | cat = torch.cat((x, d), dim=1)
392 | else:
393 | multi = x * d
394 | cat = torch.cat((x, d, multi), dim=1)
395 | x_rc = self.dw_pw(cat).flatten(2).permute(0, 2, 1)
396 | x_ = self.sa(x_rc, H, W)
397 | x_ = x_ + x
398 | return self.act(x_)
399 |
400 | def _init_weights(self, m):
401 | if isinstance(m, nn.Linear):
402 | trunc_normal_(m.weight, std=.02)
403 | if isinstance(m, nn.Linear) and m.bias is not None:
404 | nn.init.constant_(m.bias, 0)
405 | elif isinstance(m, nn.LayerNorm):
406 | nn.init.constant_(m.bias, 0)
407 | nn.init.constant_(m.weight, 1.0)
408 | elif isinstance(m, nn.Conv2d):
409 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
410 | fan_out //= m.groups
411 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
412 | if m.bias is not None:
413 | m.bias.data.zero_()
414 |
415 |
416 | class DWPWConv(nn.Module):
417 | def __init__(self, inc, outc):
418 | super().__init__()
419 | self.conv = nn.Sequential(
420 | nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc),
421 | nn.BatchNorm2d(inc),
422 | nn.GELU(),
423 | nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1),
424 | nn.BatchNorm2d(outc),
425 | nn.GELU()
426 | )
427 |
428 | def forward(self, x):
429 | return self.conv(x)
430 |
--------------------------------------------------------------------------------
/model/smt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 | from timm.models.vision_transformer import _cfg
9 | import math
10 |
11 | from torchvision import transforms
12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13 | from timm.data import create_transform
14 | from timm.data.transforms import str_to_pil_interp
15 | # from ptflops import get_model_complexity_info
16 | # from thop import profile
17 |
18 |
19 | class Mlp(nn.Module):
20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
21 | super().__init__()
22 | out_features = out_features or in_features
23 | hidden_features = hidden_features or in_features
24 | self.fc1 = nn.Linear(in_features, hidden_features)
25 | self.dwconv = DWConv(hidden_features)
26 | self.act = act_layer()
27 | self.fc2 = nn.Linear(hidden_features, out_features)
28 | self.drop = nn.Dropout(drop)
29 | self.apply(self._init_weights)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 | elif isinstance(m, nn.LayerNorm):
37 | nn.init.constant_(m.bias, 0)
38 | nn.init.constant_(m.weight, 1.0)
39 | elif isinstance(m, nn.Conv2d):
40 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | fan_out //= m.groups
42 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
43 | if m.bias is not None:
44 | m.bias.data.zero_()
45 |
46 | def forward(self, x, H, W):
47 | x = self.fc1(x)
48 | x = self.act(x + self.dwconv(x, H, W))
49 | x = self.drop(x)
50 | x = self.fc2(x)
51 | x = self.drop(x)
52 | return x
53 |
54 |
55 | class Attention(nn.Module):
56 | def __init__(self, dim, ca_num_heads=4, sa_num_heads=8, qkv_bias=False, qk_scale=None,
57 | attn_drop=0., proj_drop=0., ca_attention=1, expand_ratio=2):
58 | super().__init__()
59 |
60 | self.ca_attention = ca_attention
61 | self.dim = dim
62 | self.ca_num_heads = ca_num_heads
63 | self.sa_num_heads = sa_num_heads
64 |
65 | assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."
66 | assert dim % sa_num_heads == 0, f"dim {dim} should be divided by num_heads {sa_num_heads}."
67 |
68 | self.act = nn.GELU()
69 | self.proj = nn.Linear(dim, dim)
70 | self.proj_drop = nn.Dropout(proj_drop)
71 |
72 | self.split_groups = self.dim // ca_num_heads
73 |
74 | if ca_attention == 1:
75 | self.v = nn.Linear(dim, dim, bias=qkv_bias)
76 | self.s = nn.Linear(dim, dim, bias=qkv_bias)
77 | for i in range(self.ca_num_heads):
78 | local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads, kernel_size=(3 + i * 2),
79 |                                        padding=(1 + i), stride=1, groups=dim // self.ca_num_heads) # kernel_size 3,5,7,9 large-kernel depthwise convs, padding 1,2,3,4
80 | setattr(self, f"local_conv_{i + 1}", local_conv)
81 | self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1,
82 | groups=self.split_groups)
83 | self.bn = nn.BatchNorm2d(dim * expand_ratio)
84 | self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1)
85 |
86 | else:
87 | head_dim = dim // sa_num_heads
88 | self.scale = qk_scale or head_dim ** -0.5
89 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
90 | self.attn_drop = nn.Dropout(attn_drop)
91 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
92 | self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=dim)
93 |
94 | self.apply(self._init_weights)
95 |
96 | def _init_weights(self, m):
97 | if isinstance(m, nn.Linear):
98 | trunc_normal_(m.weight, std=.02)
99 | if isinstance(m, nn.Linear) and m.bias is not None:
100 | nn.init.constant_(m.bias, 0)
101 | elif isinstance(m, nn.LayerNorm):
102 | nn.init.constant_(m.bias, 0)
103 | nn.init.constant_(m.weight, 1.0)
104 | elif isinstance(m, nn.Conv2d):
105 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | fan_out //= m.groups
107 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
108 | if m.bias is not None:
109 | m.bias.data.zero_()
110 |
111 | def forward(self, x, H, W):
112 | B, N, C = x.shape
113 | if self.ca_attention == 1:
114 | v = self.v(x)
115 | s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1, 2) # num_heads,B,C,H,W
116 | for i in range(self.ca_num_heads):
117 | local_conv = getattr(self, f"local_conv_{i + 1}")
118 | s_i = s[i] # B,C,H,W
119 | s_i = local_conv(s_i).reshape(B, self.split_groups, -1, H, W)
120 | if i == 0:
121 | s_out = s_i
122 | else:
123 | s_out = torch.cat([s_out, s_i], 2)
124 | s_out = s_out.reshape(B, C, H, W)
125 | s_out = self.proj1(self.act(self.bn(self.proj0(s_out))))
126 | self.modulator = s_out
127 | s_out = s_out.reshape(B, C, N).permute(0, 2, 1)
128 | x = s_out * v
129 |
130 | else:
131 | q = self.q(x).reshape(B, N, self.sa_num_heads, C // self.sa_num_heads).permute(0, 2, 1, 3)
132 | kv = self.kv(x).reshape(B, -1, 2, self.sa_num_heads, C // self.sa_num_heads).permute(2, 0, 3, 1, 4)
133 | k, v = kv[0], kv[1]
134 | attn = (q @ k.transpose(-2, -1)) * self.scale
135 | attn = attn.softmax(dim=-1)
136 | attn = self.attn_drop(attn)
137 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) + \
138 | self.local_conv(v.transpose(1, 2).reshape(B, N, C).transpose(1, 2).view(B, C, H, W)).view(B, C, N).transpose(1, 2)
139 |
140 | x = self.proj(x)
141 | x = self.proj_drop(x)
142 |
143 | return x
144 |
145 |
146 | class Block(nn.Module):
147 |
148 | def __init__(self, dim, ca_num_heads, sa_num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
149 | use_layerscale=False, layerscale_value=1e-4, drop=0., attn_drop=0.,
150 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, ca_attention=1, expand_ratio=2):
151 | super().__init__()
152 | self.norm1 = norm_layer(dim)
153 | self.attn = Attention(
154 | dim,
155 | ca_num_heads=ca_num_heads, sa_num_heads=sa_num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
156 | attn_drop=attn_drop, proj_drop=drop, ca_attention=ca_attention,
157 | expand_ratio=expand_ratio)
158 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
159 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
160 | self.norm2 = norm_layer(dim)
161 | mlp_hidden_dim = int(dim * mlp_ratio)
162 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
163 |
164 | self.gamma_1 = 1.0
165 | self.gamma_2 = 1.0
166 | if use_layerscale:
167 | self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
168 | self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
169 |
170 | self.apply(self._init_weights)
171 |
172 | def _init_weights(self, m):
173 | if isinstance(m, nn.Linear):
174 | trunc_normal_(m.weight, std=.02)
175 | if isinstance(m, nn.Linear) and m.bias is not None:
176 | nn.init.constant_(m.bias, 0)
177 | elif isinstance(m, nn.LayerNorm):
178 | nn.init.constant_(m.bias, 0)
179 | nn.init.constant_(m.weight, 1.0)
180 | elif isinstance(m, nn.Conv2d):
181 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
182 | fan_out //= m.groups
183 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
184 | if m.bias is not None:
185 | m.bias.data.zero_()
186 |
187 | def forward(self, x, H, W):
188 |
189 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), H, W))
190 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x), H, W))
191 |
192 | return x
193 |
194 |
195 | class OverlapPatchEmbed(nn.Module):
196 | """ Image to Patch Embedding
197 | """
198 |
199 | def __init__(self, img_size=224, patch_size=3, stride=2, in_chans=3, embed_dim=768):
200 | super().__init__()
201 | patch_size = to_2tuple(patch_size)
202 | img_size = to_2tuple(img_size)
203 |
204 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
205 | padding=(patch_size[0] // 2, patch_size[1] // 2))
206 | self.norm = nn.LayerNorm(embed_dim)
207 |
208 | self.apply(self._init_weights)
209 |
210 | def _init_weights(self, m):
211 | if isinstance(m, nn.Linear):
212 | trunc_normal_(m.weight, std=.02)
213 | if isinstance(m, nn.Linear) and m.bias is not None:
214 | nn.init.constant_(m.bias, 0)
215 | elif isinstance(m, nn.LayerNorm):
216 | nn.init.constant_(m.bias, 0)
217 | nn.init.constant_(m.weight, 1.0)
218 | elif isinstance(m, nn.Conv2d):
219 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
220 | fan_out //= m.groups
221 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
222 | if m.bias is not None:
223 | m.bias.data.zero_()
224 |
225 | def forward(self, x):
226 | x = self.proj(x)
227 | _, _, H, W = x.shape
228 | x = x.flatten(2).transpose(1, 2)
229 | x = self.norm(x)
230 |
231 | return x, H, W
232 |
233 |
234 | class Head(nn.Module):
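235 | """Convolutional stem for stage 1: a stride-2 head_conv x head_conv convolution with
236 | BN/ReLU followed by a 2x2 stride-2 convolution, i.e. 4x downsampling before the first blocks."""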
235 | def __init__(self, head_conv, dim):
236 | super(Head, self).__init__()
237 | stem = [nn.Conv2d(3, dim, head_conv, 2, padding=3 if head_conv == 7 else 1, bias=False), nn.BatchNorm2d(dim),
238 | nn.ReLU(True)]
239 | stem.append(nn.Conv2d(dim, dim, kernel_size=2, stride=2))
240 | self.conv = nn.Sequential(*stem)
241 | self.norm = nn.LayerNorm(dim)
242 | self.apply(self._init_weights)
243 |
244 | def _init_weights(self, m):
245 | if isinstance(m, nn.Linear):
246 | trunc_normal_(m.weight, std=.02)
247 | if isinstance(m, nn.Linear) and m.bias is not None:
248 | nn.init.constant_(m.bias, 0)
249 | elif isinstance(m, nn.LayerNorm):
250 | nn.init.constant_(m.bias, 0)
251 | nn.init.constant_(m.weight, 1.0)
252 | elif isinstance(m, nn.Conv2d):
253 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
254 | fan_out //= m.groups
255 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
256 | if m.bias is not None:
257 | m.bias.data.zero_()
258 |
259 | def forward(self, x):
260 | x = self.conv(x)
261 | _, _, H, W = x.shape
262 | x = x.flatten(2).transpose(1, 2)
263 | x = self.norm(x)
264 | return x, H, W
265 |
266 |
267 | class SMT(nn.Module):
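268 | """SMT backbone, used here as a feature extractor: the classification head is
269 | bypassed and forward() returns a pooled tensor together with the list of the
270 | four stage feature maps (f_list) for the downstream MAGNet model."""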
268 | def __init__(self, img_size=224, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
269 | ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16], mlp_ratios=[8, 6, 4, 2],
270 | qkv_bias=False, qk_scale=None, use_layerscale=False, layerscale_value=1e-4, drop_rate=0.,
271 | attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
272 | depths=[2, 2, 8, 1], ca_attentions=[1, 1, 1, 0], num_stages=4, head_conv=3, expand_ratio=2, **kwargs):
273 | super().__init__()
274 | self.num_classes = num_classes
275 | self.depths = depths
276 | self.num_stages = num_stages
277 | self.embed_dims = embed_dims
277 |
278 |
279 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
280 | cur = 0
281 |
282 | for i in range(num_stages):
283 | if i == 0:
284 | patch_embed = Head(head_conv, embed_dims[i])  # stage 1 uses the convolutional stem
285 | else:
286 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
287 | patch_size=3,
288 | stride=2,
289 | in_chans=embed_dims[i - 1],
290 | embed_dim=embed_dims[i])
291 |
292 | block = nn.ModuleList([Block(
293 | dim=embed_dims[i], ca_num_heads=ca_num_heads[i], sa_num_heads=sa_num_heads[i], mlp_ratio=mlp_ratios[i],
294 | qkv_bias=qkv_bias, qk_scale=qk_scale,
295 | use_layerscale=use_layerscale,
296 | layerscale_value=layerscale_value,
297 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer,
298 | ca_attention=0 if i == 2 and j % 2 != 0 else ca_attentions[i], expand_ratio=expand_ratio)
299 | for j in range(depths[i])])
300 | norm = norm_layer(embed_dims[i])
301 | cur += depths[i]
302 |
303 | setattr(self, f"patch_embed{i + 1}", patch_embed)
304 | setattr(self, f"block{i + 1}", block)
305 | setattr(self, f"norm{i + 1}", norm)
306 |
307 | # classification head
308 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
309 |
310 | self.apply(self._init_weights)
311 |
312 | def _init_weights(self, m):
313 | if isinstance(m, nn.Linear):
314 | trunc_normal_(m.weight, std=.02)
315 | if isinstance(m, nn.Linear) and m.bias is not None:
316 | nn.init.constant_(m.bias, 0)
317 | elif isinstance(m, nn.LayerNorm):
318 | nn.init.constant_(m.bias, 0)
319 | nn.init.constant_(m.weight, 1.0)
320 | elif isinstance(m, nn.Conv2d):
321 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
322 | fan_out //= m.groups
323 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
324 | if m.bias is not None:
325 | m.bias.data.zero_()
326 |
327 | def freeze_patch_emb(self):
328 | for p in self.patch_embed1.parameters(): p.requires_grad = False  # freeze the stem's parameters (setting requires_grad on the Module itself is a no-op)
329 |
330 | @torch.jit.ignore
331 | def no_weight_decay(self):
332 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # inherited names; this model defines no positional embeddings or cls token
333 |
334 | def get_classifier(self):
335 | return self.head
336 |
337 | def reset_classifier(self, num_classes, global_pool=''):
338 | self.num_classes = num_classes
339 | self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
340 |
341 | def forward_features(self, x):
342 | B = x.shape[0]
343 | f_list = []
344 | for i in range(self.num_stages):
345 | patch_embed = getattr(self, f"patch_embed{i + 1}")
346 | block = getattr(self, f"block{i + 1}")
347 | norm = getattr(self, f"norm{i + 1}")
348 | x, H, W = patch_embed(x)
349 | for blk in block:
350 | x = blk(x, H, W)
351 | x = norm(x)
352 | # unlike the original SMT classifier, the last stage is also reshaped back to
353 | # (B, C, H, W) so that every stage can be collected as a spatial feature map
354 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
355 | f_list.append(x)
356 |
357 | return x.mean(dim=1), f_list
358 |
359 | def forward(self, x):
360 | x, f_list = self.forward_features(x)
361 | # x = self.head(x)
362 |
363 | return x, f_list
364 |
365 |
366 | class DWConv(nn.Module):
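367 | """Depth-wise 3x3 convolution over a token sequence: (B, N, C) tokens are reshaped to
368 | (B, C, H, W), convolved per channel, and flattened back to (B, N, C)."""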
367 | def __init__(self, dim=768):
368 | super(DWConv, self).__init__()
369 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
370 |
371 | def forward(self, x, H, W):
372 | B, N, C = x.shape
373 | x = x.transpose(1, 2).view(B, C, H, W)
374 | x = self.dwconv(x)
375 | x = x.flatten(2).transpose(1, 2)
376 |
377 | return x
378 |
379 |
380 | def build_transforms(img_size, center_crop=False):
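381 | """ImageNet-style eval preprocessing: resize (optionally a 256/224-ratio resize plus
382 | center crop), convert to tensor, and normalize with ImageNet mean/std."""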
381 | t = []
382 | if center_crop:
383 | size = int((256 / 224) * img_size)
384 | t.append(
385 | transforms.Resize(size, interpolation=str_to_pil_interp('bicubic'))
386 | )
387 | t.append(
388 | transforms.CenterCrop(img_size)
389 | )
390 | else:
391 | t.append(
392 | transforms.Resize(img_size, interpolation=str_to_pil_interp('bicubic'))
393 | )
394 | t.append(transforms.ToTensor())
395 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
396 | return transforms.Compose(t)
397 |
398 |
399 | def build_transforms4display(img_size, center_crop=False):
400 | t = []
401 | if center_crop:
402 | size = int((256 / 224) * img_size)
403 | t.append(
404 | transforms.Resize(size, interpolation=str_to_pil_interp('bicubic'))
405 | )
406 | t.append(
407 | transforms.CenterCrop(img_size)
408 | )
409 | else:
410 | t.append(
411 | transforms.Resize(img_size, interpolation=str_to_pil_interp('bicubic'))
412 | )
413 | t.append(transforms.ToTensor())
414 | return transforms.Compose(t)
415 |
416 |
417 | def smt_t(pretrained=False, **kwargs):
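418 | # Tiny variant. `pretrained` is accepted for interface compatibility but not used here;
419 | # weights are loaded explicitly (see the __main__ block below).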
418 | model = SMT(
419 | embed_dims=[64, 128, 256, 512], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16],
420 | mlp_ratios=[4, 4, 4, 2],
421 | qkv_bias=True, depths=[2, 2, 8, 1], ca_attentions=[1, 1, 1, 0], head_conv=3, expand_ratio=2, **kwargs)
422 | model.default_cfg = _cfg()
423 |
424 | return model
425 |
426 |
427 | def smt_s(pretrained=False, **kwargs):
428 | model = SMT(
429 | embed_dims=[64, 128, 256, 512], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16],
430 | mlp_ratios=[4, 4, 4, 2],
431 | qkv_bias=True, depths=[3, 4, 18, 2], ca_attentions=[1, 1, 1, 0], head_conv=3, expand_ratio=2, **kwargs)
432 | model.default_cfg = _cfg()
433 | return model
434 |
435 |
436 | def smt_b(pretrained=False, **kwargs):
437 | model = SMT(
438 | embed_dims=[64, 128, 256, 512], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16],
439 | mlp_ratios=[8, 6, 4, 2],
440 | qkv_bias=True, depths=[4, 6, 28, 2], ca_attentions=[1, 1, 1, 0], head_conv=7, expand_ratio=2, **kwargs)
441 | model.default_cfg = _cfg()
442 |
443 | return model
444 |
445 |
446 | def smt_l(pretrained=False, **kwargs):
447 | model = SMT(
448 | embed_dims=[96, 192, 384, 768], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16],
449 | mlp_ratios=[8, 6, 4, 2],
450 | qkv_bias=True, depths=[4, 6, 28, 4], ca_attentions=[1, 1, 1, 0], head_conv=7, expand_ratio=2, **kwargs)
451 | model.default_cfg = _cfg()
452 |
453 | return model
454 |
455 |
456 | if __name__ == '__main__':
457 | import torch
458 | from collections import OrderedDict
459 |
460 | model = smt_t()
461 | # load the pretrained SMT checkpoint and re-save a plain state_dict
462 | model_path = '../ckps/smt/smt_tiny.pth'
463 | device = 'cuda:0'
464 | model.load_state_dict(torch.load(model_path)['model'])
465 | torch.save(model.state_dict(), "../ckps/smt.pth")
466 | # model2 = torch.load(model_path, map_location=device)["model"]
467 | # new_pth = dict()
468 | # for i in model2:
469 | # for k,v in i:
470 | # new_pth[k] = v
471 | # model.load_state_dict(new_pth)
472 |
473 | # move the model to the specified device
474 | # model = model.cuda()
475 | # input = torch.rand(4, 3, 224, 224).cuda()
476 | # output = model(input)
477 | # print(model)
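478 | # Sanity check (added sketch): run a dummy CPU forward pass and print the four
479 | # stage feature-map shapes produced by the backbone (1/4, 1/8, 1/16, 1/32 of the input).
480 | _, feats = model(torch.rand(1, 3, 224, 224))
481 | print([f.shape for f in feats])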
478 |
479 | ### FLOPs / params with thop ###
480 | # input_shape = (1, 3, 384, 384)  # input shape
481 | # input_data = torch.randn(*input_shape)
482 | # macs, params = profile(model, inputs=(input_data,))
483 | # print(f"FLOPS: {macs / 1e9:.2f}G")
484 | # print(f"params: {params / 1e6:.2f}M")
485 |
486 | ### FLOPs / params with ptflops ###
487 | # flops_count, params_count = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
488 | # print_per_layer_stat=False)
489 | #
490 | # print('flops: ', flops_count)
491 | # print('params: ', params_count)
492 |
--------------------------------------------------------------------------------