├── model ├── __init__.py ├── MobileNetV2.py ├── MAGNet.py └── smt.py ├── ckps └── smt │ └── pretrained SMT here ├── test_maps └── pred maps here ├── pytorch_iou ├── __pycache__ │ └── __init__.cpython-39.pyc └── __init__.py ├── LICENSE ├── test_Net.py ├── utils.py ├── README.md ├── train_Net.py └── data.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ckps/smt/pretrained SMT here: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_maps/pred maps here: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_iou/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyu6346/MAGNet/HEAD/pytorch_iou/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | def _iou(pred, target, size_average = True): 9 | 10 | b = pred.shape[0] 11 | IoU = 0.0 12 | for i in range(0,b): 13 | #compute the IoU of the foreground 14 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 15 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 16 | IoU1 = Iand1/Ior1 17 | 18 | #IoU loss is (1-IoU1) 19 | IoU = IoU + (1-IoU1) 20 | 21 | return IoU/b 22 | 23 | class IOU(torch.nn.Module): 24 | def __init__(self, size_average = True): 25 | super(IOU, self).__init__() 26 | self.size_average = size_average 27 | 28 | def forward(self, pred, target): 29 | 30 | return _iou(pred, target, self.size_average) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 mingyu6346 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/test_Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import sys
4 | import warnings
5 | from tqdm import tqdm
6 | import numpy as np
7 | import os, argparse
8 | import cv2
9 | from model.MAGNet import MAGNet
10 | from data import test_dataset
11 |
12 | warnings.filterwarnings("ignore")
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--trainsize', type=int, default=384, help='testing size')
16 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
17 | parser.add_argument('--test_path', type=str, default='', help='test dataset path')
18 | parser.add_argument('--save_path', type=str, default='./test_maps/MAGNet/', help='save path')
19 | parser.add_argument('--pth_path', type=str, default='', help='checkpoint path')
20 | opt = parser.parse_args()
21 |
22 | dataset_path = opt.test_path
23 |
24 | # set device for test
25 | if opt.gpu_id == '0':
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
27 | print('USE GPU 0')
28 | elif opt.gpu_id == '1':
29 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
30 | print('USE GPU 1')
31 | # load the model
32 | model = MAGNet().eval().cuda()
33 | model.load_state_dict(torch.load(opt.pth_path))
34 |
35 | test_datasets = ['DUT', 'LFSD', 'NJU2K', 'NLPR', 'SIP', 'STERE']
36 |
37 |
38 | for dataset in test_datasets:
39 | save_path = opt.save_path + dataset + '/'
40 | if not os.path.exists(save_path):
41 | os.makedirs(save_path)
42 | image_root = dataset_path + dataset + '/RGB/'
43 | gt_root = dataset_path + dataset + '/GT/'
44 | depth_root = dataset_path + dataset + '/depth/'
45 | # depth_root = dataset_path + dataset + '/T/'
46 | test_loader = test_dataset(image_root, gt_root, depth_root, opt.trainsize)
47 | for i in tqdm(range(test_loader.size), desc=dataset, file=sys.stdout):
48 | image, gt, depth, name, image_for_post = test_loader.load_data()
49 | gt = np.asarray(gt, np.float32)
50 | gt /= (gt.max() + 1e-8)
51 | image = image.cuda()
52 | depth = depth.repeat(1, 3, 1, 1).cuda()
53 | res, pred_2, pred_3, pred_4 = model(image, depth)
54 | res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
55 | res = res.sigmoid().data.cpu().numpy().squeeze()
56 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
57 | cv2.imwrite(save_path + name, res * 255)
58 | print('Test Done!')
59 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import requests
5 | import torch
6 |
7 |
8 | def clip_gradient(optimizer, grad_clip):
9 | for group in optimizer.param_groups:
10 | for param in group['params']:
11 | if param.grad is not None:
12 | param.grad.data.clamp_(-grad_clip, grad_clip)
13 |
14 |
15 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
16 | decay = decay_rate ** (epoch // decay_epoch)
17 | for param_group in optimizer.param_groups:
18 | param_group['lr'] = decay * init_lr
19 | lr = param_group['lr']
20 | return lr
21 |
22 |
23 | def opt_save(opt):
24 | log_path = opt.save_path + "train_settings.log"
25 | if not os.path.exists(opt.save_path):
26 | os.makedirs(opt.save_path)
27 | att = [i for i in opt.__dir__() if not i.startswith("_")]
28 | with open(log_path, "w") as f:
29 | for i in att:
30 | print("{}:{}".format(i, eval(f"opt.{i}")), file=f)
31 |
32 |
33 | def iou_loss(pred, mask):
34 | pred = torch.sigmoid(pred)
35 | inter = (pred * mask).sum(dim=(2, 3))
36 | union = (pred + mask).sum(dim=(2, 3))
37 | iou = 1 - (inter + 1) / (union - inter + 1)
38 | return iou.mean()
39 |
40 |
41 | def fps(model, epoch_num, size):
42 | ls = [] # FPS obtained from each measurement run
43 | iterations = 300 # number of timed iterations per run
44 | device = torch.device("cuda:0")
45 | # device = torch.device("cpu")
46 | model.to(device)
47 |
48 | random_input = torch.randn(1, 3, size, size).to(device)
49 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
50 |
51 | # GPU warm-up
52 | for _ in range(50):
53 | _ = model(random_input, random_input)
54 |
55 | for i in range(epoch_num):
56 | # speed measurement
57 | times = torch.zeros(iterations) # stores the time of each iteration
58 | with torch.no_grad():
59 | for iter in range(iterations):
60 | starter.record()
61 | _ = model(random_input, random_input)
62 | ender.record()
63 | # synchronize GPU timing
64 | torch.cuda.synchronize()
65 | curr_time = starter.elapsed_time(ender) # elapsed time
66 | times[iter] = curr_time
67 | # print(curr_time)
68 |
69 | mean_time = times.mean().item()
70 | ls.append(1000 / mean_time)
71 | print("{}/{} Inference time: {:.6f}, FPS: {} ".format(i + 1, epoch_num, mean_time, 1000 / mean_time))
72 | print(f"Average FPS: {np.mean(ls):.2f}")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # README.md
2 |
3 | This project provides the code and results for 'MAGNet: Multi-scale Awareness and Global Fusion Network for RGB-D Salient Object Detection'
4 | 5 | # Environments 6 | 7 | ```bash 8 | conda create -n magnet python=3.9.18 9 | conda activate magnet 10 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 11 | conda install -c conda-forge opencv-python==4.7.0 12 | pip install timm==0.6.5 13 | conda install -c conda-forge tqdm 14 | conda install yacs 15 | ``` 16 | 17 | # Data Preparation 18 | 19 | - Download the RGB-D raw data from [baidu](https://pan.baidu.com/s/10Y90OXUFoW8yAeRmr5LFnA?pwd=exwj) / [Google drive](https://drive.google.com/file/d/19HXwGJCtz0QdEDsEbH7cJqTBfD-CEXxX/view?usp=sharing)
20 | - Download the RGB-T raw data from [baidu](https://pan.baidu.com/s/1eexJSI4a2EGoaYcDkt1B9Q?pwd=i7a2) / [Google drive](https://drive.google.com/file/d/1hLhn5WV6xh-Q41upXF-bzyVpbszF9hUc/view?usp=sharing)
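The loaders in `train_Net.py` / `test_Net.py` build their paths as `<root>/<DATASET>/RGB/`, `<root>/<DATASET>/depth/` and `<root>/<DATASET>/GT/`. Below is a minimal sketch for checking that an unpacked test set matches this layout; the `./test_data/` location is only an assumption, use the same path you would pass to `--test_path`:

```python
import os

test_path = "./test_data/"  # assumed location of the unpacked test sets
for dataset in ['DUT', 'LFSD', 'NJU2K', 'NLPR', 'SIP', 'STERE']:
    for sub in ['RGB', 'depth', 'GT']:
        folder = os.path.join(test_path, dataset, sub)
        print(folder, 'ok' if os.path.isdir(folder) else 'MISSING')
```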
21 |
22 | Note that in the depth maps of the raw data above, the foreground is white.
23 |
24 | # Training & Testing
25 |
26 | - Train the MAGNet:
27 | 1. download the pretrained SMT pth from [baidu](https://pan.baidu.com/s/11bNtCS7HyjnB7Lf3RIbpFg?pwd=bxiw) / [Google drive](https://drive.google.com/file/d/1eNhQwUHmfjR-vVGY38D_dFYUOqD_pw-H/view?usp=sharing), and put it under `ckps/smt/`.
28 | 2. modify `rgb_root`, `depth_root` and `gt_root` in `train_Net.py` according to your own data path.
29 | 3. run `python train_Net.py`
30 | - Test the MAGNet:
31 | 1. modify `test_path` and `pth_path` in `test_Net.py` according to your own data path.
32 | 2. run `python test_Net.py`
33 |
34 | # Evaluate tools
35 |
36 | - You can select one of the toolboxes below to get the metrics
37 | [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [SOD_Evaluation_Metrics](https://github.com/zyjwuyan/SOD_Evaluation_Metrics)
38 |
39 | # Saliency Maps
40 |
41 | We provide the saliency maps of the DUT, LFSD, NJU2K, NLPR, SIP, SSD and STERE datasets.
42 |
43 | - RGB-D [baidu](https://pan.baidu.com/s/1FK8jcDb61QdFvZF1qKMV6g?pwd=c3a6) / [Google drive](https://drive.google.com/file/d/1uoeNZPzsj2RAr0ofM8fPD6N0JJ8HCyn9/view?usp=sharing)
44 |
45 | We provide the saliency maps of the VT821, VT1000 and VT5000 datasets.
46 |
47 | - RGB-T [baidu](https://pan.baidu.com/s/1IQIkZS9GzmBT0PHflHqMNw?pwd=ebuw) / [Google drive](https://drive.google.com/file/d/198k-3R-yDF_y0Br7MoeSBP5XQOPuXPnL/view?usp=sharing)
48 | 49 | # Trained Models 50 | 51 | - RGB-D [baidu](https://pan.baidu.com/s/1RPMA5Z3liMoUlG0AWuGeRA?pwd=5aqf) / [Google drive](https://drive.google.com/file/d/1vb2Vcbz9bCjvaSwoRZjIi39ae5Ei1GVs/view?usp=sharing)
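The downloaded weights are loaded the same way as in `test_Net.py`; a minimal sketch, where the checkpoint file name is a placeholder for whichever `.pth` you downloaded:

```python
import torch
from model.MAGNet import MAGNet

# mirrors the loading step in test_Net.py; requires a CUDA device
model = MAGNet().eval().cuda()
model.load_state_dict(torch.load("ckps/MAGNet/MAGNet_mae_best.pth"))  # placeholder path
```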
52 |
53 | # Acknowledgement
54 |
55 | The implementation of this project is based on the codebases below.
56 | 57 | - [SeaNet](https://github.com/MathLee/SeaNet)
58 | - [LSNet](https://github.com/zyrant/LSNet)
59 | - Fps/speed test [MobileSal](https://github.com/yuhuan-wu/MobileSal/blob/master/speed_test.py) 60 | - Evaluate tools [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [SOD_Evaluation_Metrics](https://github.com/zyjwuyan/SOD_Evaluation_Metrics)
61 | 62 | # Contact 63 | 64 | Feel free to contact me if you have any questions: (mingyu6346 at 163 dot com) 65 | -------------------------------------------------------------------------------- /model/MobileNetV2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | try: 5 | from torchvision.models.utils import load_state_dict_from_url # torchvision 0.4+ 6 | except ModuleNotFoundError: 7 | try: 8 | from torch.hub import load_state_dict_from_url # torch 1.x 9 | except ModuleNotFoundError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url # torch 0.4.1 11 | 12 | model_urls = { 13 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 14 | } 15 | 16 | 17 | class ConvBNReLU(nn.Sequential): 18 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1): 19 | padding = (kernel_size - 1) // 2 20 | if dilation != 1: 21 | padding = dilation 22 | super(ConvBNReLU, self).__init__( 23 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation, 24 | bias=False), 25 | nn.BatchNorm2d(out_planes), 26 | nn.ReLU6(inplace=True) 27 | ) 28 | 29 | 30 | class InvertedResidual(nn.Module): 31 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1): 32 | super(InvertedResidual, self).__init__() 33 | self.stride = stride 34 | assert stride in [1, 2] 35 | 36 | hidden_dim = int(round(inp * expand_ratio)) 37 | self.use_res_connect = self.stride == 1 and inp == oup 38 | 39 | layers = [] 40 | if expand_ratio != 1: 41 | # pw 42 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 43 | layers.extend([ 44 | # dw 45 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation), 46 | # pw-linear 47 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 48 | nn.BatchNorm2d(oup), 49 | ]) 50 | self.conv = nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | if self.use_res_connect: 54 | return x + self.conv(x) 55 | else: 56 | return self.conv(x) 57 | 58 | 59 | class MobileNetV2(nn.Module): 60 | def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0): 61 | super(MobileNetV2, self).__init__() 62 | block = InvertedResidual 63 | input_channel = 32 64 | last_channel = 1280 65 | inverted_residual_setting = [ 66 | # t, c, n, s, d 67 | [1, 16, 1, 1, 1], # conv1 112*112*16 68 | [6, 24, 2, 2, 1], # conv2 56*56*24 69 | [6, 32, 3, 2, 1], # conv3 28*28*32 70 | [6, 64, 4, 2, 1], 71 | [6, 96, 3, 1, 1], # conv4 14*14*96 72 | [6, 160, 3, 2, 1], 73 | [6, 320, 1, 1, 1], # conv5 7*7*320 74 | ] 75 | 76 | # building first layer 77 | input_channel = int(input_channel * width_mult) 78 | self.last_channel = int(last_channel * max(1.0, width_mult)) 79 | features = [ConvBNReLU(3, input_channel, stride=2)] 80 | # building inverted residual blocks 81 | for t, c, n, s, d in inverted_residual_setting: 82 | output_channel = int(c * width_mult) 83 | for i in range(n): 84 | stride = s if i == 0 else 1 85 | dilation = d if i == 0 else 1 86 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d)) 87 | input_channel = output_channel 88 | # building last several layers 89 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 90 | # make it nn.Sequential 91 | self.features = nn.Sequential(*features) 92 | 93 | # weight initialization 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | nn.init.kaiming_normal_(m.weight, 
mode='fan_out') 97 | if m.bias is not None: 98 | nn.init.zeros_(m.bias) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | nn.init.ones_(m.weight) 101 | nn.init.zeros_(m.bias) 102 | elif isinstance(m, nn.Linear): 103 | nn.init.normal_(m.weight, 0, 0.01) 104 | nn.init.zeros_(m.bias) 105 | 106 | def forward(self, x): 107 | res = [] 108 | for idx, m in enumerate(self.features): 109 | x = m(x) 110 | if idx in [1, 3, 6, 13, 17]: 111 | res.append(x) 112 | return res 113 | 114 | 115 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 116 | model = MobileNetV2(**kwargs) 117 | if pretrained: 118 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 119 | progress=progress) 120 | print("loading imagenet pretrained mobilenetv2") 121 | model.load_state_dict(state_dict, strict=False) 122 | print("loaded imagenet pretrained mobilenetv2") 123 | return model -------------------------------------------------------------------------------- /train_Net.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import torch.nn.functional as F 4 | import torch 5 | import numpy as np 6 | import os, argparse 7 | from datetime import datetime 8 | from tqdm import tqdm 9 | from model.MAGNet import MAGNet 10 | from data import get_loader, test_dataset 11 | from utils import clip_gradient, adjust_lr, opt_save, iou_loss 12 | 13 | import pytorch_iou 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--epoch', type=int, default=200, help='epoch number') 18 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') 19 | parser.add_argument('--batchsize', type=int, default=8, help='training batch size') 20 | parser.add_argument('--trainsize', type=int, default=384, help='training image size') 21 | parser.add_argument('--continue_train', type=bool, default=False, help='continue training') 22 | parser.add_argument('--continue_train_path', type=str, default='', help='continue training path') 23 | 24 | parser.add_argument('--rgb_root', type=str, default='D:/DataSet/SOD/RGB-D SOD/train_dut/RGB/', 25 | help='the training rgb images root') # train_dut 26 | parser.add_argument('--depth_root', type=str, default='D:/DataSet/SOD/RGB-D SOD/train_dut/depth/', 27 | help='the training depth images root') 28 | parser.add_argument('--gt_root', type=str, default='D:/DataSet/SOD/RGB-D SOD/train_dut/GT/', 29 | help='the training gt images root') 30 | 31 | parser.add_argument('--val_rgb', type=str, default="D:/DataSet/SOD/RGB-D SOD/test_data/NLPR/RGB/", 32 | help='validate rgb path') 33 | parser.add_argument('--val_depth', type=str, default="D:/DataSet/SOD/RGB-D SOD/test_data/NLPR/depth/", 34 | help='validate depth path') 35 | parser.add_argument('--val_gt', type=str, default="D:/DataSet/SOD/RGB-D SOD/test_data/NLPR/GT/", 36 | help='validate gt path') 37 | 38 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 39 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 40 | parser.add_argument('--decay_epoch', type=int, default=80, help='every n epochs decay learning rate') 41 | parser.add_argument('--save_path', type=str, default="ckps/MAGNet/", help='checkpoint path') 42 | 43 | opt = parser.parse_args() 44 | 45 | opt_save(opt) 46 | logging.basicConfig(filename=opt.save_path + 'log.log', 47 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO, filemode='a', 48 | 
datefmt='%Y-%m-%d %H:%M:%S %p') 49 | logging.info("Net-Train") 50 | # model 51 | model = MAGNet() 52 | if os.path.exists("ckps/smt/smt_tiny.pth"): 53 | model.rgb_backbone.load_state_dict(torch.load("ckps/smt/smt_tiny.pth")['model']) 54 | print(f"loaded imagenet pretrained SMT from ckps/smt/smt_tiny.pth") 55 | else: 56 | raise "please put smt_tiny.pth under ckps/smt/ folder" 57 | if opt.continue_train: 58 | model.load_state_dict(torch.load(opt.continue_train_path)) 59 | print(f"continue training from {opt.continue_train_path}") 60 | 61 | model.cuda() 62 | params = model.parameters() 63 | optimizer = torch.optim.Adam(params, opt.lr) 64 | 65 | print('load data...') 66 | train_loader = get_loader(opt.rgb_root, opt.gt_root, opt.depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 67 | total_step = len(train_loader) 68 | 69 | CE = torch.nn.BCEWithLogitsLoss() 70 | IOU = pytorch_iou.IOU(size_average=True) 71 | best_loss = 1.0 72 | 73 | 74 | def train(train_loader, model, optimizer, epoch): 75 | model.train() 76 | loss_list = [] 77 | for i, pack in enumerate(train_loader, start=1): 78 | optimizer.zero_grad() 79 | (images, gts, depth) = pack 80 | images = images.cuda() 81 | gts = gts.cuda() 82 | depth = depth.cuda().repeat(1, 3, 1, 1) 83 | 84 | pred_1, pred_2, pred_3, pred_4 = model(images, depth) 85 | 86 | loss1 = CE(pred_1, gts) + iou_loss(pred_1, gts) 87 | loss2 = CE(pred_2, gts) + iou_loss(pred_2, gts) 88 | loss3 = CE(pred_3, gts) + iou_loss(pred_3, gts) 89 | loss4 = CE(pred_4, gts) + iou_loss(pred_4, gts) 90 | 91 | loss = loss1 + loss2 + loss3 + loss4 92 | 93 | loss.backward() 94 | 95 | clip_gradient(optimizer, opt.clip) 96 | optimizer.step() 97 | 98 | loss_list.append(float(loss)) 99 | if i % 20 == 0 or i == total_step: 100 | msg = '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Learning Rate: {}, Loss: {:.4f}, Loss1: {:.4f}, Loss2: {:.4f}, Loss3: {:.4f}, Loss4: {:.4f}'.format( 101 | datetime.now(), epoch, opt.epoch, i, total_step, 102 | opt.lr * opt.decay_rate ** (epoch // opt.decay_epoch), loss.data, loss1.data, 103 | loss2.data, loss3.data, loss4.data) 104 | print(msg) 105 | logging.info(msg) 106 | epoch_loss = np.mean(loss_list) 107 | if not os.path.exists(opt.save_path): 108 | os.makedirs(opt.save_path) 109 | 110 | global best_loss 111 | if epoch_loss < best_loss: 112 | best_loss = epoch_loss 113 | torch.save(model.state_dict(), opt.save_path + 'Net_rgb_d.pth' + f'.{epoch}_{epoch_loss:.3f}', 114 | _use_new_zipfile_serialization=False) 115 | with open(opt.save_path + "loss.log", "a") as f: 116 | print(f"{datetime.now()} epoch {epoch} loss {np.mean(loss_list):.3f}", file=f) 117 | 118 | 119 | best_mae = 1.0 120 | best_epoch = 0 121 | 122 | 123 | def validate(test_dataset, model, epoch, opt): 124 | global best_mae, best_epoch 125 | model.eval().cuda() 126 | mae_sum = 0 127 | test_loader = test_dataset(opt.val_rgb, opt.val_gt, opt.val_depth, opt.trainsize) 128 | with torch.no_grad(): 129 | for i in tqdm(range(test_loader.size), desc="Validating", file=sys.stdout): 130 | image, gt, depth, name, image_for_post = test_loader.load_data() 131 | gt = np.asarray(gt, np.float32) 132 | gt /= (gt.max() + 1e-8) 133 | image = image.cuda() 134 | depth = depth.repeat(1, 3, 1, 1).cuda() 135 | 136 | res, _, _, _ = model(image, depth) 137 | res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False) 138 | res = res.sigmoid().data.cpu().numpy().squeeze() 139 | 140 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 141 | # mae_train = torch.sum(torch.abs(res - gt)) * 1.0 / 
(torch.numel(gt)) 142 | mae_train = np.mean(np.abs(res - gt)) 143 | mae_sum = mae_train + mae_sum 144 | mae = mae_sum / test_loader.size 145 | 146 | if epoch == 0: 147 | best_mae = mae 148 | else: 149 | if mae < best_mae: 150 | best_mae = round(mae, 5) 151 | best_epoch = epoch 152 | torch.save(model.state_dict(), opt.save_path + 'MAGNet_mae_best.pth', _use_new_zipfile_serialization=False) 153 | print('best epoch:{}'.format(epoch)) 154 | msg = 'Epoch: {} MAE: {:.5f} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch) 155 | print(msg) 156 | logging.info(msg) 157 | with open(f"{opt.save_path}mae.log", "a", encoding='utf-8') as f: 158 | f.write('Epoch: {:03d} MAE: {:.5f} #### bestMAE: {:.5f} bestEpoch: {:03d}\n'.format(epoch, mae, best_mae, 159 | best_epoch)) 160 | return mae 161 | 162 | 163 | print("Let's go!") 164 | for epoch in range(opt.epoch): 165 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) 166 | train(train_loader, model, optimizer, epoch) 167 | validate(test_dataset, model, epoch, opt) 168 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | from tqdm import tqdm 9 | 10 | 11 | # several data augumentation strategies 12 | def cv_random_flip(img, label, depth): 13 | flip_flag = random.randint(0, 1) 14 | # flip_flag2= random.randint(0,1) 15 | # left right flip 16 | if flip_flag == 1: 17 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 18 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 19 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 20 | # top bottom flip 21 | # if flip_flag2==1: 22 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 23 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 24 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 25 | return img, label, depth 26 | 27 | 28 | def randomCrop(image, label, depth): 29 | border = 30 30 | image_width = image.size[0] 31 | image_height = image.size[1] 32 | crop_win_width = np.random.randint(image_width - border, image_width) 33 | crop_win_height = np.random.randint(image_height - border, image_height) 34 | random_region = ( 35 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 36 | (image_height + crop_win_height) >> 1) 37 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region) 38 | 39 | 40 | def randomRotation(image, label, depth): 41 | mode = Image.BICUBIC 42 | if random.random() > 0.8: 43 | random_angle = np.random.randint(-15, 15) 44 | image = image.rotate(random_angle, mode) 45 | label = label.rotate(random_angle, mode) 46 | depth = depth.rotate(random_angle, mode) 47 | return image, label, depth 48 | 49 | 50 | def colorEnhance(image): 51 | bright_intensity = random.randint(5, 15) / 10.0 52 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 53 | contrast_intensity = random.randint(5, 15) / 10.0 54 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 55 | color_intensity = random.randint(0, 20) / 10.0 56 | image = ImageEnhance.Color(image).enhance(color_intensity) 57 | sharp_intensity = random.randint(0, 30) / 10.0 58 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 59 | return image 60 | 61 | 62 | def randomGaussian(image, mean=0.1, 
sigma=0.35): 63 | def gaussianNoisy(im, mean=mean, sigma=sigma): 64 | for _i in range(len(im)): 65 | im[_i] += random.gauss(mean, sigma) 66 | return im 67 | 68 | img = np.asarray(image) 69 | width, height = img.shape 70 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 71 | img = img.reshape([width, height]) 72 | return Image.fromarray(np.uint8(img)) 73 | 74 | 75 | def randomPepper(img): 76 | img = np.array(img) 77 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 78 | for i in range(noiseNum): 79 | 80 | randX = random.randint(0, img.shape[0] - 1) 81 | 82 | randY = random.randint(0, img.shape[1] - 1) 83 | 84 | if random.randint(0, 1) == 0: 85 | 86 | img[randX, randY] = 0 87 | 88 | else: 89 | 90 | img[randX, randY] = 255 91 | return Image.fromarray(img) 92 | 93 | 94 | # dataset for training 95 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 96 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 97 | class SalObjDataset(data.Dataset): 98 | def __init__(self, image_root, gt_root, depth_root, trainsize): 99 | self.trainsize = trainsize 100 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 101 | 102 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 103 | or f.endswith('.png')] 104 | 105 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 106 | or f.endswith('.png') or f.endswith('.jpg')] 107 | self.images = sorted(self.images) 108 | self.gts = sorted(self.gts) 109 | self.depths = sorted(self.depths) 110 | self.filter_files() 111 | self.size = len(self.images) 112 | self.img_transform = transforms.Compose([ 113 | transforms.Resize((self.trainsize, self.trainsize)), 114 | transforms.ToTensor(), 115 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]) 116 | self.gt_transform = transforms.Compose([ 117 | transforms.Resize((self.trainsize, self.trainsize)), 118 | transforms.ToTensor()]) 119 | self.depths_transform = transforms.Compose( 120 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()]) 121 | 122 | def __getitem__(self, index): 123 | image = self.rgb_loader(self.images[index]) 124 | gt = self.binary_loader(self.gts[index]) 125 | depth = self.binary_loader(self.depths[index]) 126 | 127 | image, gt, depth = cv_random_flip(image, gt, depth) 128 | image, gt, depth = randomCrop(image, gt, depth) 129 | image, gt, depth = randomRotation(image, gt, depth) 130 | image = colorEnhance(image) 131 | # gt=randomGaussian(gt) 132 | gt = randomPepper(gt) 133 | image = self.img_transform(image) 134 | gt = self.gt_transform(gt) 135 | depth = self.depths_transform(depth) 136 | 137 | return image, gt, depth 138 | 139 | def filter_files(self): 140 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 141 | images = [] 142 | gts = [] 143 | depths = [] 144 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths): 145 | img = Image.open(img_path) 146 | gt = Image.open(gt_path) 147 | depth = Image.open(depth_path) 148 | if img.size == gt.size and gt.size == depth.size: 149 | images.append(img_path) 150 | gts.append(gt_path) 151 | depths.append(depth_path) 152 | self.images = images 153 | self.gts = gts 154 | self.depths = depths 155 | 156 | def rgb_loader(self, path): 157 | with open(path, 'rb') as f: 158 | img = Image.open(f) 159 | return img.convert('RGB') 160 | 161 | def 
binary_loader(self, path): 162 | with open(path, 'rb') as f: 163 | img = Image.open(f) 164 | return img.convert('L') 165 | 166 | def resize(self, img, gt, depth): 167 | assert img.size == gt.size and gt.size == depth.size 168 | w, h = img.size 169 | if h < self.trainsize or w < self.trainsize: 170 | h = max(h, self.trainsize) 171 | w = max(w, self.trainsize) 172 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), \ 173 | depth.resize((w, h), Image.NEAREST) 174 | else: 175 | return img, gt, depth 176 | 177 | def __len__(self): 178 | return self.size 179 | 180 | 181 | # dataloader for training 182 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True): 183 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize) 184 | # print(image_root) 185 | # print(gt_root) 186 | # print(depth_root) 187 | data_loader = data.DataLoader(dataset=dataset, 188 | batch_size=batchsize, 189 | shuffle=shuffle, 190 | num_workers=num_workers, 191 | pin_memory=pin_memory) 192 | return data_loader 193 | 194 | 195 | # test dataset and loader 196 | class test_dataset: 197 | def __init__(self, image_root, gt_root, depth_root, testsize): 198 | self.testsize = testsize 199 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 200 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 201 | or f.endswith('.png')] 202 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 203 | or f.endswith('.png') or f.endswith('.jpg')] 204 | 205 | self.images = sorted(self.images) 206 | self.gts = sorted(self.gts) 207 | self.depths = sorted(self.depths) 208 | self.transform = transforms.Compose([ 209 | transforms.Resize((self.testsize, self.testsize)), 210 | transforms.ToTensor(), 211 | transforms.Normalize(mean=(0.485, 0.458, 0.407), std=(0.229, 0.224, 0.225))]) 212 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 213 | self.gt_transform = transforms.ToTensor() 214 | # self.gt_transform = transforms.Compose([ 215 | # transforms.Resize((self.trainsize, self.trainsize)), 216 | # transforms.ToTensor()]) 217 | self.depths_transform = transforms.Compose( 218 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()]) 219 | 220 | self.size = len(self.images) 221 | self.index = 0 222 | 223 | def load_data(self): 224 | image = self.rgb_loader(self.images[self.index]) 225 | image = self.transform(image).unsqueeze(0) 226 | gt = self.binary_loader(self.gts[self.index]) 227 | depth = self.binary_loader(self.depths[self.index]) 228 | depth = self.depths_transform(depth).unsqueeze(0) 229 | name = self.gts[self.index].split('/')[-1] 230 | image_for_post = self.rgb_loader(self.images[self.index]) 231 | image_for_post = image_for_post.resize(gt.size) 232 | if name.endswith('.jpg'): 233 | name = name.split('.jpg')[0] + '.jpg' 234 | self.index += 1 235 | self.index = self.index % self.size 236 | return image, gt, depth, name, np.array(image_for_post) 237 | 238 | def rgb_loader(self, path): 239 | with open(path, 'rb') as f: 240 | img = Image.open(f) 241 | return img.convert('RGB') 242 | 243 | def binary_loader(self, path): 244 | with open(path, 'rb') as f: 245 | img = Image.open(f) 246 | return img.convert('L') 247 | 248 | def __len__(self): 249 | return self.size 250 | 251 | 252 | if __name__ == "__main__": 253 | ir = "E:/lmt/train_dut/train_dut/train_images/" 254 | dr = "E:/lmt/train_dut/train_dut/train_depth/" 255 | gt = 
"E:/lmt/train_dut/train_dut/train_masks/" 256 | b = 8 257 | ts = 384 258 | train_loader = get_loader(ir, gt, dr, b, ts) 259 | train_loader = tqdm(train_loader) 260 | for i, (images, gts, depth) in enumerate(train_loader, start=1): 261 | pass 262 | -------------------------------------------------------------------------------- /model/MAGNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from model.smt import smt_t 3 | from model.MobileNetV2 import mobilenet_v2 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | from timm.models.layers import trunc_normal_ 8 | 9 | TRAIN_SIZE = 384 10 | 11 | 12 | class MAGNet(nn.Module): 13 | def __init__(self, pretrained=True): 14 | super().__init__() 15 | self.rgb_backbone = smt_t(pretrained) 16 | self.d_backbone = mobilenet_v2(pretrained) 17 | 18 | # Fuse 19 | self.gfm2 = GFM(inc=512, expend_ratio=2) 20 | self.gfm1 = GFM(inc=256, expend_ratio=3) 21 | self.mafm2 = MAFM(inc=128) 22 | self.mafm1 = MAFM(inc=64) 23 | 24 | # Pred 25 | self.mcm3 = MCM(inc=512, outc=256) 26 | self.mcm2 = MCM(inc=256, outc=128) 27 | self.mcm1 = MCM(inc=128, outc=64) 28 | 29 | self.d_trans_4 = Trans(320, 512) 30 | self.d_trans_3 = Trans(96, 256) 31 | self.d_trans_2 = Trans(32, 128) 32 | self.d_trans_1 = Trans(24, 64) 33 | 34 | self.predtrans = nn.Sequential( 35 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, groups=512), 36 | nn.BatchNorm2d(512), 37 | nn.GELU(), 38 | nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1) 39 | ) 40 | 41 | def forward(self, x_rgb, x_d): 42 | # rgb 43 | _, (rgb_1, rgb_2, rgb_3, rgb_4) = self.rgb_backbone(x_rgb) 44 | 45 | # d 46 | _, d_1, d_2, d_3, d_4 = self.d_backbone(x_d) 47 | 48 | d_4 = self.d_trans_4(d_4) 49 | d_3 = self.d_trans_3(d_3) 50 | d_2 = self.d_trans_2(d_2) 51 | d_1 = self.d_trans_1(d_1) 52 | 53 | # Fuse 54 | fuse_4 = self.gfm2(rgb_4, d_4) # [B, 512, 12, 12] 55 | fuse_3 = self.gfm1(rgb_3, d_3) # [B, 256, 24, 24] 56 | fuse_2 = self.mafm2(rgb_2, d_2) # [B, 128, 48, 48] 57 | fuse_1 = self.mafm1(rgb_1, d_1) # [B, 64, 96, 96] 58 | 59 | # Pred 60 | pred_4 = F.interpolate(self.predtrans(fuse_4), TRAIN_SIZE, mode="bilinear", align_corners=True) 61 | pred_3, xf_3 = self.mcm3(fuse_3, fuse_4) 62 | pred_2, xf_2 = self.mcm2(fuse_2, xf_3) 63 | pred_1, xf_1 = self.mcm1(fuse_1, xf_2) 64 | 65 | return pred_1, pred_2, pred_3, pred_4 66 | 67 | 68 | class Trans(nn.Module): 69 | def __init__(self, inc, outc): 70 | super().__init__() 71 | self.trans = nn.Sequential( 72 | nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1), 73 | nn.BatchNorm2d(outc), 74 | nn.GELU() 75 | ) 76 | self.apply(self._init_weights) 77 | 78 | def forward(self, d): 79 | return self.trans(d) 80 | 81 | def _init_weights(self, m): 82 | if isinstance(m, nn.Linear): 83 | trunc_normal_(m.weight, std=.02) 84 | if isinstance(m, nn.Linear) and m.bias is not None: 85 | nn.init.constant_(m.bias, 0) 86 | elif isinstance(m, nn.LayerNorm): 87 | nn.init.constant_(m.bias, 0) 88 | nn.init.constant_(m.weight, 1.0) 89 | elif isinstance(m, nn.Conv2d): 90 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 91 | fan_out //= m.groups 92 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 93 | if m.bias is not None: 94 | m.bias.data.zero_() 95 | 96 | 97 | # Conv_One_Identity 98 | class COI(nn.Module): 99 | def __init__(self, inc, k=3, p=1): 100 | super().__init__() 101 | self.outc = inc 102 | self.dw = nn.Conv2d(inc, self.outc, kernel_size=k, padding=p, 
groups=inc) 103 | self.conv1_1 = nn.Conv2d(inc, self.outc, kernel_size=1, stride=1) 104 | self.bn1 = nn.BatchNorm2d(self.outc) 105 | self.bn2 = nn.BatchNorm2d(self.outc) 106 | self.bn3 = nn.BatchNorm2d(self.outc) 107 | self.act = nn.GELU() 108 | self.apply(self._init_weights) 109 | 110 | def forward(self, x): 111 | shortcut = self.bn1(x) 112 | 113 | x_dw = self.bn2(self.dw(x)) 114 | x_conv1_1 = self.bn3(self.conv1_1(x)) 115 | return self.act(shortcut + x_dw + x_conv1_1) 116 | 117 | def _init_weights(self, m): 118 | if isinstance(m, nn.Linear): 119 | trunc_normal_(m.weight, std=.02) 120 | if isinstance(m, nn.Linear) and m.bias is not None: 121 | nn.init.constant_(m.bias, 0) 122 | elif isinstance(m, nn.LayerNorm): 123 | nn.init.constant_(m.bias, 0) 124 | nn.init.constant_(m.weight, 1.0) 125 | elif isinstance(m, nn.Conv2d): 126 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 127 | fan_out //= m.groups 128 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 129 | if m.bias is not None: 130 | m.bias.data.zero_() 131 | 132 | 133 | class MHMC(nn.Module): 134 | def __init__(self, dim, ca_num_heads=4, qkv_bias=True, proj_drop=0., ca_attention=1, expand_ratio=2): 135 | super().__init__() 136 | 137 | self.ca_attention = ca_attention 138 | self.dim = dim 139 | self.ca_num_heads = ca_num_heads 140 | 141 | assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}." 142 | 143 | self.act = nn.GELU() 144 | self.proj = nn.Linear(dim, dim) 145 | self.proj_drop = nn.Dropout(proj_drop) 146 | 147 | self.split_groups = self.dim // ca_num_heads 148 | 149 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 150 | self.s = nn.Linear(dim, dim, bias=qkv_bias) 151 | for i in range(self.ca_num_heads): 152 | local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads, kernel_size=(3 + i * 2), 153 | padding=(1 + i), stride=1, 154 | groups=dim // self.ca_num_heads) # kernel_size 3,5,7,9 大核dw卷积,padding 1,2,3,4 155 | setattr(self, f"local_conv_{i + 1}", local_conv) 156 | self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1, 157 | groups=self.split_groups) 158 | self.bn = nn.BatchNorm2d(dim * expand_ratio) 159 | self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1) 160 | 161 | self.apply(self._init_weights) 162 | 163 | def _init_weights(self, m): 164 | if isinstance(m, nn.Linear): 165 | trunc_normal_(m.weight, std=.02) 166 | if isinstance(m, nn.Linear) and m.bias is not None: 167 | nn.init.constant_(m.bias, 0) 168 | elif isinstance(m, nn.LayerNorm): 169 | nn.init.constant_(m.bias, 0) 170 | nn.init.constant_(m.weight, 1.0) 171 | elif isinstance(m, nn.Conv2d): 172 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 173 | fan_out //= m.groups 174 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 175 | if m.bias is not None: 176 | m.bias.data.zero_() 177 | 178 | def forward(self, x, H, W): 179 | B, N, C = x.shape 180 | v = self.v(x) 181 | s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1, 182 | 2) # num_heads,B,C,H,W 183 | for i in range(self.ca_num_heads): 184 | local_conv = getattr(self, f"local_conv_{i + 1}") 185 | s_i = s[i] # B,C,H,W 186 | s_i = local_conv(s_i).reshape(B, self.split_groups, -1, H, W) 187 | if i == 0: 188 | s_out = s_i 189 | else: 190 | s_out = torch.cat([s_out, s_i], 2) 191 | s_out = s_out.reshape(B, C, H, W) 192 | s_out = self.proj1(self.act(self.bn(self.proj0(s_out)))) 193 | self.modulator = s_out 194 | s_out = s_out.reshape(B, C, 
N).permute(0, 2, 1) 195 | x = s_out * v 196 | 197 | x = self.proj(x) 198 | x = self.proj_drop(x) 199 | return x 200 | 201 | 202 | class SAttention(nn.Module): 203 | def __init__(self, dim, sa_num_heads=8, qkv_bias=True, qk_scale=None, 204 | attn_drop=0., proj_drop=0.): 205 | super().__init__() 206 | 207 | self.dim = dim 208 | self.sa_num_heads = sa_num_heads 209 | 210 | assert dim % sa_num_heads == 0, f"dim {dim} should be divided by num_heads {sa_num_heads}." 211 | 212 | self.act = nn.GELU() 213 | self.proj = nn.Linear(dim, dim) 214 | self.proj_drop = nn.Dropout(proj_drop) 215 | 216 | head_dim = dim // sa_num_heads 217 | self.scale = qk_scale or head_dim ** -0.5 218 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 219 | self.attn_drop = nn.Dropout(attn_drop) 220 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 221 | self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=dim) 222 | 223 | self.apply(self._init_weights) 224 | 225 | def _init_weights(self, m): 226 | if isinstance(m, nn.Linear): 227 | trunc_normal_(m.weight, std=.02) 228 | if isinstance(m, nn.Linear) and m.bias is not None: 229 | nn.init.constant_(m.bias, 0) 230 | elif isinstance(m, nn.LayerNorm): 231 | nn.init.constant_(m.bias, 0) 232 | nn.init.constant_(m.weight, 1.0) 233 | elif isinstance(m, nn.Conv2d): 234 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 235 | fan_out //= m.groups 236 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 237 | if m.bias is not None: 238 | m.bias.data.zero_() 239 | 240 | def forward(self, x, H, W): 241 | B, N, C = x.shape 242 | 243 | q = self.q(x).reshape(B, N, self.sa_num_heads, C // self.sa_num_heads).permute(0, 2, 1, 3) 244 | kv = self.kv(x).reshape(B, -1, 2, self.sa_num_heads, C // self.sa_num_heads).permute(2, 0, 3, 1, 4) 245 | k, v = kv[0], kv[1] 246 | attn = (q @ k.transpose(-2, -1)) * self.scale 247 | attn = attn.softmax(dim=-1) 248 | attn = self.attn_drop(attn) 249 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) + \ 250 | self.local_conv(v.transpose(1, 2).reshape(B, N, C).transpose(1, 2).view(B, C, H, W)).view(B, C, 251 | N).transpose(1, 2) 252 | 253 | x = self.proj(x) 254 | x = self.proj_drop(x) 255 | 256 | return x.permute(0, 2, 1).reshape(B, C, H, W) 257 | 258 | 259 | # Multi-scale Awareness Fusion Module 260 | class MAFM(nn.Module): 261 | def __init__(self, inc): 262 | super().__init__() 263 | self.outc = inc 264 | self.attention = MHMC(dim=inc) 265 | self.coi = COI(inc) 266 | self.pw = nn.Sequential( 267 | nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=1, stride=1), 268 | nn.BatchNorm2d(inc), 269 | nn.GELU() 270 | ) 271 | self.pre_att = nn.Sequential( 272 | nn.Conv2d(inc * 2, inc * 2, kernel_size=3, padding=1, groups=inc * 2), 273 | nn.BatchNorm2d(inc * 2), 274 | nn.GELU(), 275 | nn.Conv2d(inc * 2, inc, kernel_size=1), 276 | nn.BatchNorm2d(inc), 277 | nn.GELU() 278 | ) 279 | 280 | self.apply(self._init_weights) 281 | 282 | def forward(self, x, d): 283 | # multi = x * d 284 | # B, C, H, W = x.shape 285 | # x_cat = torch.cat((x, d, multi), dim=1) 286 | 287 | B, C, H, W = x.shape 288 | x_cat = torch.cat((x, d), dim=1) 289 | x_pre = self.pre_att(x_cat) 290 | # Attention 291 | x_reshape = x_pre.flatten(2).permute(0, 2, 1) # B,C,H,W to B,N,C 292 | attention = self.attention(x_reshape, H, W) # attention 293 | attention = attention.permute(0, 2, 1).reshape(B, C, H, W) # B,N,C to B,C,H,W 294 | 295 | # COI 296 | x_conv = self.coi(attention) # dw3*3,1*1,identity 297 | x_conv = self.pw(x_conv) # pw 298 | 299 | return x_conv 300 | 301 | 
def _init_weights(self, m): 302 | if isinstance(m, nn.Linear): 303 | trunc_normal_(m.weight, std=.02) 304 | if isinstance(m, nn.Linear) and m.bias is not None: 305 | nn.init.constant_(m.bias, 0) 306 | elif isinstance(m, nn.LayerNorm): 307 | nn.init.constant_(m.bias, 0) 308 | nn.init.constant_(m.weight, 1.0) 309 | elif isinstance(m, nn.Conv2d): 310 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 311 | fan_out //= m.groups 312 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 313 | if m.bias is not None: 314 | m.bias.data.zero_() 315 | 316 | 317 | # Decoder 318 | class MCM(nn.Module): 319 | def __init__(self, inc, outc): 320 | super().__init__() 321 | self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 322 | self.rc = nn.Sequential( 323 | nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc), 324 | nn.BatchNorm2d(inc), 325 | nn.GELU(), 326 | nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1), 327 | nn.BatchNorm2d(outc), 328 | nn.GELU() 329 | ) 330 | self.predtrans = nn.Sequential( 331 | nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, padding=1, groups=outc), 332 | nn.BatchNorm2d(outc), 333 | nn.GELU(), 334 | nn.Conv2d(in_channels=outc, out_channels=1, kernel_size=1) 335 | ) 336 | 337 | self.rc2 = nn.Sequential( 338 | nn.Conv2d(in_channels=outc * 2, out_channels=outc * 2, kernel_size=3, padding=1, groups=outc * 2), 339 | nn.BatchNorm2d(outc * 2), 340 | nn.GELU(), 341 | nn.Conv2d(in_channels=outc * 2, out_channels=outc, kernel_size=1, stride=1), 342 | nn.BatchNorm2d(outc), 343 | nn.GELU() 344 | ) 345 | 346 | self.apply(self._init_weights) 347 | 348 | def forward(self, x1, x2): 349 | x2_upsample = self.upsample2(x2) # 上采样 350 | x2_rc = self.rc(x2_upsample) # 减少通道数 351 | shortcut = x2_rc 352 | 353 | x_cat = torch.cat((x1, x2_rc), dim=1) # 拼接 354 | x_forward = self.rc2(x_cat) # 减少通道数2 355 | x_forward = x_forward + shortcut 356 | pred = F.interpolate(self.predtrans(x_forward), TRAIN_SIZE, mode="bilinear", align_corners=True) # 预测图 357 | 358 | return pred, x_forward 359 | 360 | def _init_weights(self, m): 361 | if isinstance(m, nn.Linear): 362 | trunc_normal_(m.weight, std=.02) 363 | if isinstance(m, nn.Linear) and m.bias is not None: 364 | nn.init.constant_(m.bias, 0) 365 | elif isinstance(m, nn.LayerNorm): 366 | nn.init.constant_(m.bias, 0) 367 | nn.init.constant_(m.weight, 1.0) 368 | elif isinstance(m, nn.Conv2d): 369 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 370 | fan_out //= m.groups 371 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 372 | if m.bias is not None: 373 | m.bias.data.zero_() 374 | 375 | 376 | # Global Fusion Module 377 | class GFM(nn.Module): 378 | def __init__(self, inc, expend_ratio=2): 379 | super().__init__() 380 | self.expend_ratio = expend_ratio 381 | assert expend_ratio in [2, 3], f"expend_ratio {expend_ratio} mismatch" 382 | 383 | self.sa = SAttention(dim=inc) 384 | self.dw_pw = DWPWConv(inc * expend_ratio, inc) 385 | self.act = nn.GELU() 386 | self.apply(self._init_weights) 387 | 388 | def forward(self, x, d): 389 | B, C, H, W = x.shape 390 | if self.expend_ratio == 2: 391 | cat = torch.cat((x, d), dim=1) 392 | else: 393 | multi = x * d 394 | cat = torch.cat((x, d, multi), dim=1) 395 | x_rc = self.dw_pw(cat).flatten(2).permute(0, 2, 1) 396 | x_ = self.sa(x_rc, H, W) 397 | x_ = x_ + x 398 | return self.act(x_) 399 | 400 | def _init_weights(self, m): 401 | if isinstance(m, nn.Linear): 402 | trunc_normal_(m.weight, std=.02) 
403 | if isinstance(m, nn.Linear) and m.bias is not None: 404 | nn.init.constant_(m.bias, 0) 405 | elif isinstance(m, nn.LayerNorm): 406 | nn.init.constant_(m.bias, 0) 407 | nn.init.constant_(m.weight, 1.0) 408 | elif isinstance(m, nn.Conv2d): 409 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 410 | fan_out //= m.groups 411 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 412 | if m.bias is not None: 413 | m.bias.data.zero_() 414 | 415 | 416 | class DWPWConv(nn.Module): 417 | def __init__(self, inc, outc): 418 | super().__init__() 419 | self.conv = nn.Sequential( 420 | nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc), 421 | nn.BatchNorm2d(inc), 422 | nn.GELU(), 423 | nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1), 424 | nn.BatchNorm2d(outc), 425 | nn.GELU() 426 | ) 427 | 428 | def forward(self, x): 429 | return self.conv(x) 430 | -------------------------------------------------------------------------------- /model/smt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | import math 10 | 11 | from torchvision import transforms 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.data import create_transform 14 | from timm.data.transforms import str_to_pil_interp 15 | # from ptflops import get_model_complexity_info 16 | # from thop import profile 17 | 18 | 19 | class Mlp(nn.Module): 20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | self.fc1 = nn.Linear(in_features, hidden_features) 25 | self.dwconv = DWConv(hidden_features) 26 | self.act = act_layer() 27 | self.fc2 = nn.Linear(hidden_features, out_features) 28 | self.drop = nn.Dropout(drop) 29 | self.apply(self._init_weights) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | elif isinstance(m, nn.LayerNorm): 37 | nn.init.constant_(m.bias, 0) 38 | nn.init.constant_(m.weight, 1.0) 39 | elif isinstance(m, nn.Conv2d): 40 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | fan_out //= m.groups 42 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 43 | if m.bias is not None: 44 | m.bias.data.zero_() 45 | 46 | def forward(self, x, H, W): 47 | x = self.fc1(x) 48 | x = self.act(x + self.dwconv(x, H, W)) 49 | x = self.drop(x) 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | return x 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__(self, dim, ca_num_heads=4, sa_num_heads=8, qkv_bias=False, qk_scale=None, 57 | attn_drop=0., proj_drop=0., ca_attention=1, expand_ratio=2): 58 | super().__init__() 59 | 60 | self.ca_attention = ca_attention 61 | self.dim = dim 62 | self.ca_num_heads = ca_num_heads 63 | self.sa_num_heads = sa_num_heads 64 | 65 | assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}." 66 | assert dim % sa_num_heads == 0, f"dim {dim} should be divided by num_heads {sa_num_heads}." 
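# ca_attention selects which branch is built below: 1 -> convolutional-modulation attention
# (linear v/s projections plus per-head depthwise convolutions with kernel sizes 3/5/7/9),
# 0 -> standard multi-head self-attention (q/kv projections plus a 3x3 depthwise local_conv).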
67 | 68 | self.act = nn.GELU() 69 | self.proj = nn.Linear(dim, dim) 70 | self.proj_drop = nn.Dropout(proj_drop) 71 | 72 | self.split_groups = self.dim // ca_num_heads 73 | 74 | if ca_attention == 1: 75 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 76 | self.s = nn.Linear(dim, dim, bias=qkv_bias) 77 | for i in range(self.ca_num_heads): 78 | local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads, kernel_size=(3 + i * 2), 79 | padding=(1 + i), stride=1, groups=dim // self.ca_num_heads) # kernel_size 3,5,7,9 大核dw卷积,padding 1,2,3,4 80 | setattr(self, f"local_conv_{i + 1}", local_conv) 81 | self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1, 82 | groups=self.split_groups) 83 | self.bn = nn.BatchNorm2d(dim * expand_ratio) 84 | self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1) 85 | 86 | else: 87 | head_dim = dim // sa_num_heads 88 | self.scale = qk_scale or head_dim ** -0.5 89 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 90 | self.attn_drop = nn.Dropout(attn_drop) 91 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 92 | self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=dim) 93 | 94 | self.apply(self._init_weights) 95 | 96 | def _init_weights(self, m): 97 | if isinstance(m, nn.Linear): 98 | trunc_normal_(m.weight, std=.02) 99 | if isinstance(m, nn.Linear) and m.bias is not None: 100 | nn.init.constant_(m.bias, 0) 101 | elif isinstance(m, nn.LayerNorm): 102 | nn.init.constant_(m.bias, 0) 103 | nn.init.constant_(m.weight, 1.0) 104 | elif isinstance(m, nn.Conv2d): 105 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | fan_out //= m.groups 107 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 108 | if m.bias is not None: 109 | m.bias.data.zero_() 110 | 111 | def forward(self, x, H, W): 112 | B, N, C = x.shape 113 | if self.ca_attention == 1: 114 | v = self.v(x) 115 | s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1, 2) # num_heads,B,C,H,W 116 | for i in range(self.ca_num_heads): 117 | local_conv = getattr(self, f"local_conv_{i + 1}") 118 | s_i = s[i] # B,C,H,W 119 | s_i = local_conv(s_i).reshape(B, self.split_groups, -1, H, W) 120 | if i == 0: 121 | s_out = s_i 122 | else: 123 | s_out = torch.cat([s_out, s_i], 2) 124 | s_out = s_out.reshape(B, C, H, W) 125 | s_out = self.proj1(self.act(self.bn(self.proj0(s_out)))) 126 | self.modulator = s_out 127 | s_out = s_out.reshape(B, C, N).permute(0, 2, 1) 128 | x = s_out * v 129 | 130 | else: 131 | q = self.q(x).reshape(B, N, self.sa_num_heads, C // self.sa_num_heads).permute(0, 2, 1, 3) 132 | kv = self.kv(x).reshape(B, -1, 2, self.sa_num_heads, C // self.sa_num_heads).permute(2, 0, 3, 1, 4) 133 | k, v = kv[0], kv[1] 134 | attn = (q @ k.transpose(-2, -1)) * self.scale 135 | attn = attn.softmax(dim=-1) 136 | attn = self.attn_drop(attn) 137 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) + \ 138 | self.local_conv(v.transpose(1, 2).reshape(B, N, C).transpose(1, 2).view(B, C, H, W)).view(B, C, N).transpose(1, 2) 139 | 140 | x = self.proj(x) 141 | x = self.proj_drop(x) 142 | 143 | return x 144 | 145 | 146 | class Block(nn.Module): 147 | 148 | def __init__(self, dim, ca_num_heads, sa_num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 149 | use_layerscale=False, layerscale_value=1e-4, drop=0., attn_drop=0., 150 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, ca_attention=1, expand_ratio=2): 151 | super().__init__() 152 | self.norm1 = norm_layer(dim) 153 | self.attn = 
Attention( 154 | dim, 155 | ca_num_heads=ca_num_heads, sa_num_heads=sa_num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 156 | attn_drop=attn_drop, proj_drop=drop, ca_attention=ca_attention, 157 | expand_ratio=expand_ratio) 158 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 159 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 160 | self.norm2 = norm_layer(dim) 161 | mlp_hidden_dim = int(dim * mlp_ratio) 162 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 163 | 164 | self.gamma_1 = 1.0 165 | self.gamma_2 = 1.0 166 | if use_layerscale: 167 | self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) 168 | self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) 169 | 170 | self.apply(self._init_weights) 171 | 172 | def _init_weights(self, m): 173 | if isinstance(m, nn.Linear): 174 | trunc_normal_(m.weight, std=.02) 175 | if isinstance(m, nn.Linear) and m.bias is not None: 176 | nn.init.constant_(m.bias, 0) 177 | elif isinstance(m, nn.LayerNorm): 178 | nn.init.constant_(m.bias, 0) 179 | nn.init.constant_(m.weight, 1.0) 180 | elif isinstance(m, nn.Conv2d): 181 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 182 | fan_out //= m.groups 183 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 184 | if m.bias is not None: 185 | m.bias.data.zero_() 186 | 187 | def forward(self, x, H, W): 188 | 189 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), H, W)) 190 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x), H, W)) 191 | 192 | return x 193 | 194 | 195 | class OverlapPatchEmbed(nn.Module): 196 | """ Image to Patch Embedding 197 | """ 198 | 199 | def __init__(self, img_size=224, patch_size=3, stride=2, in_chans=3, embed_dim=768): 200 | super().__init__() 201 | patch_size = to_2tuple(patch_size) 202 | img_size = to_2tuple(img_size) 203 | 204 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 205 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 206 | self.norm = nn.LayerNorm(embed_dim) 207 | 208 | self.apply(self._init_weights) 209 | 210 | def _init_weights(self, m): 211 | if isinstance(m, nn.Linear): 212 | trunc_normal_(m.weight, std=.02) 213 | if isinstance(m, nn.Linear) and m.bias is not None: 214 | nn.init.constant_(m.bias, 0) 215 | elif isinstance(m, nn.LayerNorm): 216 | nn.init.constant_(m.bias, 0) 217 | nn.init.constant_(m.weight, 1.0) 218 | elif isinstance(m, nn.Conv2d): 219 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 220 | fan_out //= m.groups 221 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | 225 | def forward(self, x): 226 | x = self.proj(x) 227 | _, _, H, W = x.shape 228 | x = x.flatten(2).transpose(1, 2) 229 | x = self.norm(x) 230 | 231 | return x, H, W 232 | 233 | 234 | class Head(nn.Module): 235 | def __init__(self, head_conv, dim): 236 | super(Head, self).__init__() 237 | stem = [nn.Conv2d(3, dim, head_conv, 2, padding=3 if head_conv == 7 else 1, bias=False), nn.BatchNorm2d(dim), 238 | nn.ReLU(True)] 239 | stem.append(nn.Conv2d(dim, dim, kernel_size=2, stride=2)) 240 | self.conv = nn.Sequential(*stem) 241 | self.norm = nn.LayerNorm(dim) 242 | self.apply(self._init_weights) 243 | 244 | def _init_weights(self, m): 245 | if isinstance(m, nn.Linear): 246 | trunc_normal_(m.weight, std=.02) 247 | if isinstance(m, nn.Linear) and m.bias is not None: 248 | 
nn.init.constant_(m.bias, 0) 249 | elif isinstance(m, nn.LayerNorm): 250 | nn.init.constant_(m.bias, 0) 251 | nn.init.constant_(m.weight, 1.0) 252 | elif isinstance(m, nn.Conv2d): 253 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 254 | fan_out //= m.groups 255 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 256 | if m.bias is not None: 257 | m.bias.data.zero_() 258 | 259 | def forward(self, x): 260 | x = self.conv(x) 261 | _, _, H, W = x.shape 262 | x = x.flatten(2).transpose(1, 2) 263 | x = self.norm(x) 264 | return x, H, W 265 | 266 | 267 | class SMT(nn.Module): 268 | def __init__(self, img_size=224, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 269 | ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16], mlp_ratios=[8, 6, 4, 2], 270 | qkv_bias=False, qk_scale=None, use_layerscale=False, layerscale_value=1e-4, drop_rate=0., 271 | attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), 272 | depths=[2, 2, 8, 1], ca_attentions=[1, 1, 1, 0], num_stages=4, head_conv=3, expand_ratio=2, **kwargs): 273 | super().__init__() 274 | self.num_classes = num_classes 275 | self.depths = depths 276 | self.num_stages = num_stages 277 | 278 | 279 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 280 | cur = 0 281 | 282 | for i in range(num_stages): 283 | if i == 0: 284 | patch_embed = Head(head_conv, embed_dims[i]) # 285 | else: 286 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 287 | patch_size=3, 288 | stride=2, 289 | in_chans=embed_dims[i - 1], 290 | embed_dim=embed_dims[i]) 291 | 292 | block = nn.ModuleList([Block( 293 | dim=embed_dims[i], ca_num_heads=ca_num_heads[i], sa_num_heads=sa_num_heads[i], mlp_ratio=mlp_ratios[i], 294 | qkv_bias=qkv_bias, qk_scale=qk_scale, 295 | use_layerscale=use_layerscale, 296 | layerscale_value=layerscale_value, 297 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, 298 | ca_attention=0 if i == 2 and j % 2 != 0 else ca_attentions[i], expand_ratio=expand_ratio) 299 | for j in range(depths[i])]) 300 | norm = norm_layer(embed_dims[i]) 301 | cur += depths[i] 302 | 303 | setattr(self, f"patch_embed{i + 1}", patch_embed) 304 | setattr(self, f"block{i + 1}", block) 305 | setattr(self, f"norm{i + 1}", norm) 306 | 307 | # classification head 308 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 309 | 310 | self.apply(self._init_weights) 311 | 312 | def _init_weights(self, m): 313 | if isinstance(m, nn.Linear): 314 | trunc_normal_(m.weight, std=.02) 315 | if isinstance(m, nn.Linear) and m.bias is not None: 316 | nn.init.constant_(m.bias, 0) 317 | elif isinstance(m, nn.LayerNorm): 318 | nn.init.constant_(m.bias, 0) 319 | nn.init.constant_(m.weight, 1.0) 320 | elif isinstance(m, nn.Conv2d): 321 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 322 | fan_out //= m.groups 323 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 324 | if m.bias is not None: 325 | m.bias.data.zero_() 326 | 327 | def freeze_patch_emb(self): 328 | self.patch_embed1.requires_grad = False 329 | 330 | @torch.jit.ignore 331 | def no_weight_decay(self): 332 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 333 | 334 | def get_classifier(self): 335 | return self.head 336 | 337 | def reset_classifier(self, num_classes, global_pool=''): 338 | self.num_classes = num_classes 339 | self.head = 
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 340 | 341 | def forward_features(self, x): 342 | B = x.shape[0] 343 | f_list = [] 344 | for i in range(self.num_stages): 345 | patch_embed = getattr(self, f"patch_embed{i + 1}") 346 | block = getattr(self, f"block{i + 1}") 347 | norm = getattr(self, f"norm{i + 1}") 348 | x, H, W = patch_embed(x) 349 | for blk in block: 350 | x = blk(x, H, W) 351 | x = norm(x) 352 | # if i != self.num_stages - 1: 353 | # x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 354 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 355 | f_list.append(x) 356 | 357 | return x.mean(dim=1), f_list 358 | 359 | def forward(self, x): 360 | x, f_list = self.forward_features(x) 361 | # x = self.head(x) 362 | 363 | return x, f_list 364 | 365 | 366 | class DWConv(nn.Module): 367 | def __init__(self, dim=768): 368 | super(DWConv, self).__init__() 369 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 370 | 371 | def forward(self, x, H, W): 372 | B, N, C = x.shape 373 | x = x.transpose(1, 2).view(B, C, H, W) 374 | x = self.dwconv(x) 375 | x = x.flatten(2).transpose(1, 2) 376 | 377 | return x 378 | 379 | 380 | def build_transforms(img_size, center_crop=False): 381 | t = [] 382 | if center_crop: 383 | size = int((256 / 224) * img_size) 384 | t.append( 385 | transforms.Resize(size, interpolation=str_to_pil_interp('bicubic')) 386 | ) 387 | t.append( 388 | transforms.CenterCrop(img_size) 389 | ) 390 | else: 391 | t.append( 392 | transforms.Resize(img_size, interpolation=str_to_pil_interp('bicubic')) 393 | ) 394 | t.append(transforms.ToTensor()) 395 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 396 | return transforms.Compose(t) 397 | 398 | 399 | def build_transforms4display(img_size, center_crop=False): 400 | t = [] 401 | if center_crop: 402 | size = int((256 / 224) * img_size) 403 | t.append( 404 | transforms.Resize(size, interpolation=str_to_pil_interp('bicubic')) 405 | ) 406 | t.append( 407 | transforms.CenterCrop(img_size) 408 | ) 409 | else: 410 | t.append( 411 | transforms.Resize(img_size, interpolation=str_to_pil_interp('bicubic')) 412 | ) 413 | t.append(transforms.ToTensor()) 414 | return transforms.Compose(t) 415 | 416 | 417 | def smt_t(pretrained=False, **kwargs): 418 | model = SMT( 419 | embed_dims=[64, 128, 256, 512], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16], 420 | mlp_ratios=[4, 4, 4, 2], 421 | qkv_bias=True, depths=[2, 2, 8, 1], ca_attentions=[1, 1, 1, 0], head_conv=3, expand_ratio=2, **kwargs) 422 | model.default_cfg = _cfg() 423 | 424 | return model 425 | 426 | 427 | def smt_s(pretrained=False, **kwargs): 428 | model = SMT( 429 | embed_dims=[64, 128, 256, 512], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16], 430 | mlp_ratios=[4, 4, 4, 2], 431 | qkv_bias=True, depths=[3, 4, 18, 2], ca_attentions=[1, 1, 1, 0], head_conv=3, expand_ratio=2, **kwargs) 432 | model.default_cfg = _cfg() 433 | return model 434 | 435 | 436 | def smt_b(pretrained=False, **kwargs): 437 | model = SMT( 438 | embed_dims=[64, 128, 256, 512], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16], 439 | mlp_ratios=[8, 6, 4, 2], 440 | qkv_bias=True, depths=[4, 6, 28, 2], ca_attentions=[1, 1, 1, 0], head_conv=7, expand_ratio=2, **kwargs) 441 | model.default_cfg = _cfg() 442 | 443 | return model 444 | 445 | 446 | def smt_l(pretrained=False, **kwargs): 447 | model = SMT( 448 | embed_dims=[96, 192, 384, 768], ca_num_heads=[4, 4, 4, -1], sa_num_heads=[-1, -1, 8, 16], 449 | 
mlp_ratios=[8, 6, 4, 2], 450 |         qkv_bias=True, depths=[4, 6, 28, 4], ca_attentions=[1, 1, 1, 0], head_conv=7, expand_ratio=2, **kwargs) 451 |     model.default_cfg = _cfg() 452 | 453 |     return model 454 | 455 | 456 | if __name__ == '__main__': 457 |     import torch 458 |     from collections import OrderedDict 459 | 460 |     model = smt_t() 461 |     # solution 2 462 |     model_path = '../ckps/smt/smt_tiny.pth' 463 |     device = 'cuda:0' 464 |     model.load_state_dict(torch.load(model_path)['model']) 465 |     torch.save(model.state_dict(), "../ckps/smt.pth") 466 |     # model2 = torch.load(model_path, map_location=device)["model"] 467 |     # new_pth = dict() 468 |     # for i in model2: 469 |     #     for k,v in i: 470 |     #         new_pth[k] = v 471 |     # model.load_state_dict(new_pth) 472 | 473 |     # move the model to the specified device 474 |     # model = model.cuda() 475 |     # input = torch.rand(4, 3, 224, 224).cuda() 476 |     # output = model(input) 477 |     # print(model) 478 | 479 |     ### thop cal ### 480 |     # input_shape = (1, 3, 384, 384)  # input shape 481 |     # input_data = torch.randn(*input_shape) 482 |     # macs, params = profile(model, inputs=(input_data,)) 483 |     # print(f"FLOPS: {macs / 1e9:.2f}G") 484 |     # print(f"params: {params / 1e6:.2f}M") 485 | 486 |     ### ptflops cal ### 487 |     # flops_count, params_count = get_model_complexity_info(model, (3, 224, 224), as_strings=True, 488 |     #                                                        print_per_layer_stat=False) 489 |     # 490 |     # print('flops: ', flops_count) 491 |     # print('params: ', params_count) 492 | --------------------------------------------------------------------------------
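A minimal usage sketch (not part of the repository) of how the SMT backbone above can be queried for the four-stage feature pyramid that the rest of the network consumes. The import path and checkpoint location are assumptions based on the directory tree and on the conversion step in smt.py's `__main__` block; only calls already present in the file (smt_t, forward) are used.

# hypothetical example, assumed to run from the repository root after smt.py's
# __main__ block has converted the pretrained weights to ckps/smt.pth
import torch
from model.smt import smt_t

backbone = smt_t()
state = torch.load("ckps/smt.pth", map_location="cpu")  # assumed converted-checkpoint path
backbone.load_state_dict(state, strict=False)
backbone.eval()

rgb = torch.randn(1, 3, 384, 384)  # depth inputs are repeated to 3 channels in test_Net.py
with torch.no_grad():
    _, feats = backbone(rgb)  # forward() returns (pooled features, list of 4 stage feature maps)
for f in feats:
    print(f.shape)  # expected: (1,64,96,96), (1,128,48,48), (1,256,24,24), (1,512,12,12)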