├── C2DFNet ├── metric │ ├── __pycache__ │ │ ├── metric.cpython-36.pyc │ │ └── metric.cpython-38.pyc │ └── metric.py ├── models │ ├── BaseBlocks.py │ └── resnet_dilation.py ├── test.py ├── dataset │ └── data_RGB.py ├── DualFastnet_res.py ├── SDFM.py └── MDEM.py └── README.md /C2DFNet/metric/__pycache__/metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUT-IIAU-OIP-Lab/C2DFNet/HEAD/C2DFNet/metric/__pycache__/metric.cpython-36.pyc -------------------------------------------------------------------------------- /C2DFNet/metric/__pycache__/metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUT-IIAU-OIP-Lab/C2DFNet/HEAD/C2DFNet/metric/__pycache__/metric.cpython-38.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # C2DFNet 2 | This is the official implementation of the TMM 2022 paper "C2DFNet: Criss-Cross Dynamic Filter Network for RGB-D Salient Object Detection". 3 | 4 | Miao Zhang, Shunyu Yao, Beiqi Hu, [Yongri Piao](http://ice.dlut.edu.cn/yrpiao/), Wei Ji. 5 | 6 | ## Prerequisites 7 | + Ubuntu 16 8 | + PyTorch 1.10.0 9 | + CUDA 11.3 10 | + Python 3.8 11 | 12 | ## Training and Testing Datasets 13 | Training dataset 14 | * [Download Link](https://pan.baidu.com/s/14cGEwcCRulWDOuKNIjuGCg). Code: 0fj8 15 | 16 | Testing dataset 17 | * [Download Link](https://pan.baidu.com/s/1Yp5YtVIBB3-9PMFruYhxSw). Code: f7vk 18 | 19 | ## Testing 20 | Download the pretrained model from [here](https://pan.baidu.com/s/1_3rA5Y_jtUXzIJO8imZz2g). Code: qcra 21 | * Set the path to the testing dataset in test.py 22 | * Run test.py to generate saliency maps 23 | * Saliency maps generated by the model can be downloaded from [here](https://pan.baidu.com/s/10UQOmUbDWDvw87gGAjeM-A).
Code: hp32 24 | 25 | ```shell 26 | python test.py 27 | ``` 28 | 29 | ## Contact and Questions 30 | Contact: Shunyu Yao. Email: yao_shunyu@foxmail.com or ysyfeverfew@mail.dlut.edu.cn 31 | -------------------------------------------------------------------------------- /C2DFNet/models/BaseBlocks.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch 5 | # PRelu 6 | class BasicConv_PRelu(nn.Module): 7 | 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 9 | super(BasicConv_PRelu, self).__init__() 10 | self.out_channels = out_planes 11 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 12 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 13 | self.relu = nn.PReLU() if relu else None 14 | #self.relu = h_sigmoid() if relu else None 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | if self.bn is not None: 19 | x = self.bn(x) 20 | if self.relu is not None: 21 | x = self.relu(x) 22 | return x 23 | 24 | class BasicConv2d(nn.Module): 25 | def __init__( 26 | self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, 27 | ): 28 | super(BasicConv2d, self).__init__() 29 | 30 | self.basicconv = nn.Sequential( 31 | nn.Conv2d( 32 | in_planes, 33 | out_planes, 34 | kernel_size=kernel_size, 35 | stride=stride, 36 | padding=padding, 37 | dilation=dilation, 38 | groups=groups, 39 | bias=bias, 40 | ), 41 | nn.BatchNorm2d(out_planes), 42 | nn.ReLU(inplace=True), 43 | ) 44 | 45 | def forward(self, x): 46 | return self.basicconv(x) 47 | 48 | 49 | BN_MOMENTUM = 0.1 50 | def conv3x3(in_planes, out_planes, stride=1): 51 | """3x3 convolution with padding""" 52 | return nn.Conv2d(in_planes, out_planes, 
kernel_size=3, stride=stride, padding=1, bias=False) 53 | 54 | 55 | 56 | class BasicBlock(nn.Module): 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(BasicBlock, self).__init__() 60 | self.conv1 = conv3x3(inplanes, planes, stride) 61 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.conv2 = conv3x3(planes, planes) 64 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x): 69 | residual = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out -------------------------------------------------------------------------------- /C2DFNet/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pdb, os, argparse 3 | import imageio 4 | from DualFastnet_res import DualFastnet 5 | from dataset.data_RGB import get_loader 6 | from skimage import img_as_ubyte 7 | from metric.metric import CalFM,CalMAE,CalSM 8 | # from torchvision import transforms 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--testsize', type=int, default=256, help='testing size') 12 | parser.add_argument('--numworkers', type=int, default=8, help='the number of workers') 13 | opt = parser.parse_args() 14 | 15 | dataset_path = 'xxxxx' 16 | model = DualFastnet() 17 | 18 | #ck path 19 | load_name = 'xxxxx/best_epoch.pth' 20 | 21 | split_name = load_name.split('/')[6] 22 | print(split_name) 23 | model.cuda() 24 | a = torch.load(load_name) 25 | # model = torch.nn.DataParallel(model) 26 | model.load_state_dict(a) 27 | model.eval() 28 | test_datasets=['/DUT-RGBD/test_data','/LFSD','/NLPR/test_data','/RGBD135','/SSD','/STEREO','/SIP','/NJU2K','/STEREO1000'] 29 | 30 | 31 | conter = 0 32 | F_dic 
={} 33 | Fmax_dic ={} 34 | MAE_dic ={} 35 | S_dic ={} 36 | for dataset in test_datasets: 37 | # 38 | print(test_datasets[conter]) 39 | 40 | save_path = './results-final/'+split_name+'/'+test_datasets[conter].split('/')[1]+'/' 41 | conter +=1 42 | if not os.path.exists(save_path): 43 | os.makedirs(save_path) 44 | image_root = dataset_path + dataset + '/images/' 45 | gt_root = dataset_path + dataset + '/gts/' 46 | depth_root = dataset_path + dataset + '/depths/' 47 | test_loader, test_samples = get_loader(image_root, depth_root, gt_root, batchsize=1, 48 | numworkers=opt.numworkers, trainsize=opt.testsize) 49 | cal_fm = CalFM(num=test_samples)# cal是一个对象 50 | cal_mae = CalMAE(num=test_samples) 51 | cal_sm = CalSM(num=test_samples) 52 | for step, packs in enumerate(test_loader): 53 | input,depth, target,name = packs 54 | input = input.cuda(non_blocking=True) 55 | target = target.cuda(non_blocking=True) 56 | target =torch.squeeze(target) 57 | depth = depth.cuda(non_blocking=True) 58 | n, c, h, w = depth.size() 59 | depth = depth.view(n, h, w, 1).repeat(1, 1, 1, 3) 60 | depth = depth.transpose(3, 1) 61 | depth = depth.transpose(3, 2) 62 | with torch.no_grad(): 63 | out1u = model(input.cuda(),depth.cuda()) 64 | output_rgb = torch.squeeze(out1u) 65 | predict_rgb = output_rgb.sigmoid().cpu().detach().numpy() 66 | max_pred_array = predict_rgb.max() 67 | min_pred_array = predict_rgb.min() 68 | if max_pred_array == min_pred_array: 69 | predict_rgb = predict_rgb / 255 70 | else: 71 | predict_rgb = (predict_rgb - min_pred_array) / (max_pred_array - min_pred_array) 72 | 73 | cal_fm.update(predict_rgb,target.data.cpu().detach().numpy()) 74 | cal_mae.update(predict_rgb,target.data.cpu().detach().numpy()) 75 | cal_sm.update(predict_rgb,target.data.cpu().detach().numpy()) 76 | 77 | # 这个负责写图 78 | imageio.imwrite(save_path + name[0], img_as_ubyte(predict_rgb)) 79 | 80 | _,maxf,mmf,_,_=cal_fm.show() 81 | mae = cal_mae.show() 82 | sm = cal_sm.show() 83 | 
F_dic[test_datasets[conter-1].split('/')[1]] = mmf 84 | Fmax_dic[test_datasets[conter-1].split('/')[1]] = maxf 85 | MAE_dic[test_datasets[conter-1].split('/')[1]] = mae 86 | S_dic[test_datasets[conter-1].split('/')[1]] =sm 87 | print(split_name) 88 | print("maxF-measure") 89 | print(Fmax_dic) 90 | print("F-measure") 91 | print(F_dic) 92 | print("MAE") 93 | print(MAE_dic) 94 | print("S-measure") 95 | print(S_dic) -------------------------------------------------------------------------------- /C2DFNet/dataset/data_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from PIL import ImageEnhance 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import torch 7 | import numpy as np 8 | import random 9 | 10 | class SalObjDataset(data.Dataset): 11 | def __init__(self, image_root, depth_root, gt_root, trainsize): 12 | self.trainsize = trainsize 13 | 14 | self.image = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') 15 | or f.endswith('.png')] 16 | self.depth = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.jpg') 17 | or f.endswith('.png')] 18 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 19 | or f.endswith('.png')] 20 | 21 | self.image = sorted(self.image) 22 | self.depth = sorted(self.depth) 23 | self.gts = sorted(self.gts) 24 | self.filter_files() 25 | self.size = len(self.image) 26 | self.img_transform = transforms.Compose([ 27 | transforms.Resize((self.trainsize, self.trainsize)), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 30 | self.depth_transform = transforms.Compose([ 31 | transforms.Resize((self.trainsize, self.trainsize)), 32 | transforms.ToTensor()]) 33 | self.gt_transform = transforms.Compose([ 34 | transforms.Resize((self.trainsize, self.trainsize)), ]) 35 | 36 | def __getitem__(self, index): 37 | image = 
self.rgb_loader(self.image[index]) 38 | depth = self.binary_loader(self.depth[index]) 39 | gt = self.binary_loader(self.gts[index]) 40 | 41 | # 不加数据增强 42 | 43 | image = self.img_transform(image) 44 | depth = self.depth_transform(depth) 45 | depth = torch.div(depth.float(),255.0) 46 | gt = self.gt_transform(gt) 47 | gt = np.array(gt, dtype=np.int32) 48 | gt[gt <= 255/2] = 0 49 | gt[gt > 255/2] = 1 50 | gt = torch.from_numpy(gt).float() 51 | gt = gt.unsqueeze(0) 52 | gt = gt.reshape(1,self.trainsize,self.trainsize) 53 | 54 | 55 | # name 56 | name = self.image[index].split('/')[-1] 57 | if name.endswith('.jpg'): 58 | name = name.split('.jpg')[0] + '.png' 59 | return image, depth, gt,name 60 | 61 | def filter_files(self): 62 | assert len(self.image) == len(self.gts) 63 | depth = [] 64 | image = [] 65 | gts = [] 66 | for image_path, depth_path, gt_path in zip(self.image, self.depth, self.gts): 67 | img = Image.open(image_path) 68 | dep = Image.open(depth_path) 69 | gt = Image.open(gt_path) 70 | if img.size == gt.size == dep.size: 71 | image.append(image_path) 72 | depth.append(depth_path) 73 | gts.append(gt_path) 74 | # print(len(depth)) 75 | print("Read done") 76 | self.image = image 77 | self.depth = depth 78 | self.gts = gts 79 | 80 | def rgb_loader(self, path): 81 | with open(path, 'rb') as f: 82 | img = Image.open(f) 83 | return img.convert('RGB') 84 | 85 | def binary_loader(self, path): 86 | with open(path, 'rb') as f: 87 | img = Image.open(f) 88 | # return img.convert('1') 89 | return img.convert('L') 90 | 91 | def resize(self, img, gt): 92 | assert img.size == gt.size 93 | w, h = img.size 94 | if h < self.trainsize or w < self.trainsize: 95 | h = max(h, self.trainsize) 96 | w = max(w, self.trainsize) 97 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) 98 | else: 99 | return img, gt 100 | 101 | def __len__(self): 102 | return self.size 103 | 104 | 105 | def get_loader(image_root, depth_root, gt_root, batchsize,numworkers, trainsize, 
shuffle=True, pin_memory=True,iftrain = True): 106 | 107 | dataset = SalObjDataset(image_root, depth_root, gt_root, trainsize) 108 | # 获取数量 109 | numbers = len(dataset) 110 | # 多卡 111 | # train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 112 | data_loader = data.DataLoader(dataset=dataset, 113 | batch_size=batchsize, 114 | num_workers=numworkers, 115 | shuffle=shuffle, 116 | pin_memory=pin_memory) # 多卡,sampler=train_sampler 117 | return data_loader,numbers 118 | 119 | 120 | class test_dataset: 121 | def __init__(self, image_root, gt_root, testsize): 122 | self.testsize = testsize 123 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 124 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 125 | or f.endswith('.png')] 126 | self.images = sorted(self.images) 127 | # self.depth = sorted(self.depth) 128 | self.gts = sorted(self.gts) 129 | self.transform = transforms.Compose([ 130 | transforms.Resize((self.testsize, self.testsize)), 131 | transforms.ToTensor(), 132 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 133 | self.gt_transform = transforms.ToTensor() 134 | self.size = len(self.images) 135 | self.index = 0 136 | 137 | def load_data(self): 138 | image = self.rgb_loader(self.images[self.index]) 139 | image = self.transform(image).unsqueeze(0) 140 | gt = self.binary_loader(self.gts[self.index]) 141 | gt = np.array(gt, dtype=np.int32) 142 | gt[gt <= 255/2] = 0 143 | gt[gt > 255/2] = 1 144 | gt = torch.from_numpy(gt).float() 145 | gt = gt.unsqueeze(0) 146 | gt = gt.reshape(1,self.trainsize,self.trainsize) 147 | 148 | name = self.images[self.index].split('\\')[-1] 149 | if name.endswith('.jpg'): 150 | name = name.split('.jpg')[0] + '.png' 151 | self.index += 1 152 | return image, gt, name 153 | 154 | def rgb_loader(self, path): 155 | with open(path, 'rb') as f: 156 | img = Image.open(f) 157 | return img.convert('RGB') 158 | 159 | def 
binary_loader(self, path): 160 | with open(path, 'rb') as f: 161 | img = Image.open(f) 162 | return img.convert('L') 163 | 164 | 165 | -------------------------------------------------------------------------------- /C2DFNet/models/resnet_dilation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # This code is based on torchvison resnet 5 | # URL: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 6 | 7 | import torch.nn as nn 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=padding, dilation=dilation, bias=False) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride, dilation, downsample=None): 39 | super(BasicBlock, self).__init__() 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes, 1, dilation, dilation) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.downsample = 
downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride, dilation, downsample=None, expansion=4): 71 | super(Bottleneck, self).__init__() 72 | self.expansion = expansion 73 | self.conv1 = conv1x1(inplanes, planes) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | self.conv2 = conv3x3(planes, planes, stride, dilation, dilation) 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = conv1x1(planes, planes * self.expansion) 78 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | 108 | def __init__(self, block, layers, output_stride, num_classes=1000, input_channels=3): 109 | super(ResNet, self).__init__() 110 | if output_stride == 8: 111 | stride = [1, 1, 2, 2] 112 | dilation = [1, 1, 2, 2] 113 | elif output_stride == 16: 114 | stride = [1, 2, 2, 2] 115 | dilation = [1, 1, 2, 2] 116 | 117 | self.inplanes = 64 118 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, 119 | bias=False) 120 | self.bn1 = nn.BatchNorm2d(64) 121 | self.relu = 
nn.ReLU(inplace=True) 122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 123 | self.layer1 = self._make_layer(block, 64, layers[0], stride=stride[0], dilation=dilation[0]) 124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[1], dilation=dilation[1]) 125 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[2], dilation=dilation[2]) 126 | self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[3], dilation=dilation[3]) 127 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 128 | self.fc = nn.Linear(512 * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, nn.BatchNorm2d): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride, dilation): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | conv1x1(self.inplanes, planes * block.expansion, stride), 142 | nn.BatchNorm2d(planes * block.expansion), 143 | ) 144 | 145 | layers = [] 146 | layers.append(block(self.inplanes, planes, stride, dilation, downsample)) 147 | self.inplanes = planes * block.expansion 148 | for _ in range(1, blocks): 149 | layers.append(block(self.inplanes, planes, 1, dilation)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x): 154 | x = self.conv1(x) 155 | x = self.bn1(x) 156 | x = self.relu(x) 157 | x = self.maxpool(x) 158 | 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | x = self.layer4(x) 163 | 164 | x = self.avgpool(x) 165 | x = x.view(x.size(0), -1) 166 | x = self.fc(x) 167 | 168 | return x 169 | 170 | 171 | def resnet18(pretrained=False, **kwargs): 172 | """Constructs a ResNet-18 model. 
173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 177 | if pretrained: 178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 179 | return model 180 | 181 | 182 | def resnet34(pretrained=False, **kwargs): 183 | """Constructs a ResNet-34 model. 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 190 | return model 191 | 192 | 193 | def resnet50(pretrained=False, **kwargs): 194 | """Constructs a ResNet-50 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 201 | return model 202 | 203 | 204 | def resnet101(pretrained=False, **kwargs): 205 | """Constructs a ResNet-101 model. 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 212 | return model 213 | 214 | 215 | def resnet152(pretrained=False, **kwargs): 216 | """Constructs a ResNet-152 model. 
217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | """ 220 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 221 | if pretrained: 222 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 223 | return model -------------------------------------------------------------------------------- /C2DFNet/DualFastnet_res.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from models.resnet_dilation import resnet50, Bottleneck, conv1x1 6 | from SDFM import (SDFM, DenseTransLayer,) 7 | from MDEM import DFM 8 | from models.BaseBlocks import BasicConv_PRelu 9 | import torchvision 10 | class DenseLayer(nn.Module): 11 | def __init__(self, in_C, out_C, down_factor=4, k=4): 12 | """ 13 | 更像是DenseNet的Block,从而构造特征内的密集连接 14 | """ 15 | super(DenseLayer, self).__init__() 16 | self.k = k 17 | self.down_factor = down_factor 18 | mid_C = out_C // self.down_factor 19 | 20 | self.down = nn.Conv2d(in_C, mid_C, 1) 21 | 22 | self.denseblock = nn.ModuleList() 23 | for i in range(1, self.k + 1): 24 | self.denseblock.append(BasicConv2d(mid_C * i, mid_C, 3, 1, 1)) 25 | 26 | self.fuse = BasicConv2d(in_C + mid_C, out_C, kernel_size=3, stride=1, padding=1) 27 | 28 | def forward(self, in_feat): 29 | down_feats = self.down(in_feat) 30 | out_feats = [] 31 | for denseblock in self.denseblock: 32 | feats = denseblock(torch.cat((*out_feats, down_feats), dim=1)) 33 | out_feats.append(feats) 34 | feats = torch.cat((in_feat, feats), dim=1) 35 | return self.fuse(feats) 36 | 37 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 38 | """Make a 2D bilinear kernel suitable for upsampling""" 39 | factor = (kernel_size + 1) // 2 40 | if kernel_size % 2 == 1: 41 | center = factor - 1 42 | else: 43 | center = factor - 0.5 44 | og = np.ogrid[:kernel_size, :kernel_size] 45 | filt = (1 - abs(og[0] - center) / 
factor) * \ 46 | (1 - abs(og[1] - center) / factor) 47 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 48 | dtype=np.float64) 49 | weight[range(in_channels), range(out_channels), :, :] = filt 50 | return torch.from_numpy(weight).float() 51 | 52 | 53 | class BasicConv2d(nn.Module): 54 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 55 | super(BasicConv2d, self).__init__() 56 | self.conv = nn.Conv2d(in_planes, out_planes, 57 | kernel_size=kernel_size, stride=stride, 58 | padding=padding, dilation=dilation, bias=False) 59 | self.bn = nn.BatchNorm2d(out_planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | 62 | def forward(self, x): 63 | x = self.conv(x) 64 | x = self.bn(x) 65 | return x 66 | 67 | class conbine_feature(nn.Module): 68 | def __init__(self): 69 | super(conbine_feature, self).__init__() 70 | self.up2_high = DilatedParallelConvBlockD2(32, 16) # 32 16 71 | self.up2_low = nn.Conv2d(256, 16, 1, stride=1, padding=0,bias=False) 72 | self.up2_bn2 = nn.BatchNorm2d(16) 73 | self.up2_act = nn.PReLU(16) 74 | self.refine=nn.Sequential(nn.Conv2d(16,16,3,padding=1,bias=False),nn.BatchNorm2d(16),nn.PReLU()) 75 | 76 | def forward(self, low_fea,high_fea): 77 | high_fea = self.up2_high(high_fea) # c 16 78 | low_fea = self.up2_bn2(self.up2_low(low_fea)) # c 16 79 | refine_feature = self.refine(self.up2_act(high_fea+low_fea)) # 卷积层 80 | return refine_feature 81 | 82 | class DilatedParallelConvBlockD2(nn.Module): # 表面像是降通道的 83 | def __init__(self, nIn, nOut, add=False): 84 | super(DilatedParallelConvBlockD2, self).__init__() 85 | n = int(np.ceil(nOut / 2.)) # 向上取整数 86 | n2 = nOut - n # 这个不就是减去了一半 87 | #这里有个问题是既然是降低了,为什么还要按照通道分开,这里没有提到 88 | self.conv0 = nn.Conv2d(nIn, nOut, 1, stride=1, padding=0, dilation=1, bias=False) 89 | self.conv1 = nn.Conv2d(n, n, 3, stride=1, padding=1, dilation=1, bias=False) 90 | self.conv2 = nn.Conv2d(n2, n2, 3, stride=1, padding=2, dilation=2, bias=False) # 降低了维度 91 | 92 | self.bn = 
nn.BatchNorm2d(nOut) 93 | #self.act = nn.PReLU(nOut) 94 | self.add = add 95 | # 在通道上进行不同的空洞操作类似于八度卷积吗 96 | def forward(self, input): 97 | in0 = self.conv0(input) # 先改通道 98 | in1, in2 = torch.chunk(in0, 2, dim=1) # 按照通道数分块 99 | b1 = self.conv1(in1) # 空洞率1 100 | b2 = self.conv2(in2) # 空洞率2 101 | output = torch.cat([b1, b2], dim=1) 102 | 103 | if self.add: 104 | output = input + output 105 | output = self.bn(output) 106 | #output = self.act(output) # 为什么不加relu了 107 | 108 | return output 109 | 110 | class DualFastnet(nn.Module): 111 | def __init__(self, channel=32): # ,down_factor=4 112 | super(DualFastnet, self).__init__() 113 | # num_of_feat = 512 114 | # 这里是两个encoder 115 | self.Res50_depth = resnet50(pretrained=True, output_stride=16, input_channels=3) 116 | self.Res50_rgb = resnet50(pretrained=True, output_stride=16, input_channels=3) 117 | # 这是特征融合的层 118 | 119 | self.translayer = DenseTransLayer(32, 32) 120 | # 动态卷积融合 121 | 122 | self.selfdc = SDFM(32, 32, 32, 3, 4) 123 | self.decoder_plus_rgb = DFM() 124 | self.decoder_plus_depth = DFM() 125 | # transfor 126 | self.tranposelayer_rgb3 = BasicConv_PRelu(512,32,1) 127 | self.tranposelayer_rgb4 = BasicConv_PRelu(1024,32,1) 128 | self.tranposelayer_rgb5 = BasicConv_PRelu(2048,32,1) 129 | self.tranposelayer_depth3 = BasicConv_PRelu(512,32,1) 130 | self.tranposelayer_depth4 = BasicConv_PRelu(1024,32,1) 131 | self.tranposelayer_depth5 = BasicConv_PRelu(2048,32,1) 132 | 133 | self.combine=conbine_feature() 134 | # Drop这里有什么用 135 | self.SegNIN = nn.Sequential(nn.Dropout2d(0.1),nn.Conv2d(16, 1, kernel_size=1,bias=False)) 136 | 137 | # # 这里是上采样 138 | # self.upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) 139 | 140 | 141 | def forward(self,rgb,depth): 142 | # 两个encoder分别获得 143 | 144 | # rgb net 145 | block0 = self.Res50_rgb.conv1(rgb) 146 | block0 = self.Res50_rgb.bn1(block0) 147 | block0 = self.Res50_rgb.relu(block0) # 256x256 148 | block0 = self.Res50_rgb.maxpool(block0) # 128x128 149 | 
frist_rgb = self.Res50_rgb.layer1(block0) # 64x64 150 | conv3_rgb = self.Res50_rgb.layer2(frist_rgb) # 32x32 151 | conv4_rgb = self.Res50_rgb.layer3(conv3_rgb) # 16x16 152 | conv5_rgb = self.Res50_rgb.layer4(conv4_rgb) # 8x8 153 | 154 | # depth net 155 | block0_im = self.Res50_depth.conv1(depth) 156 | block0_im = self.Res50_depth.bn1(block0_im) 157 | block0_im = self.Res50_depth.relu(block0_im) 158 | block0_im = self.Res50_depth.maxpool(block0_im) 159 | frist_depth = self.Res50_depth.layer1(block0_im) # 256 160 | conv3_depth = self.Res50_depth.layer2(frist_depth) 161 | conv4_depth = self.Res50_depth.layer3(conv3_depth) 162 | conv5_depth = self.Res50_depth.layer4(conv4_depth) 163 | 164 | 165 | # transpose 166 | conv3_rgb = self.tranposelayer_rgb3(conv3_rgb) 167 | conv4_rgb = self.tranposelayer_rgb4(conv4_rgb) 168 | conv5_rgb = self.tranposelayer_rgb5(conv5_rgb) 169 | conv3_depth = self.tranposelayer_depth3(conv3_depth) 170 | conv4_depth = self.tranposelayer_depth4(conv4_depth) 171 | conv5_depth = self.tranposelayer_depth5(conv5_depth) 172 | 173 | 174 | # scale fuse 175 | rgb_final = self.decoder_plus_rgb(conv3_rgb, conv4_rgb,conv5_rgb) 176 | depth_final = self.decoder_plus_depth(conv3_depth,conv4_depth, conv5_depth) #1/8 177 | 178 | 179 | # DDPM 180 | trans_rgb = self.translayer(rgb_final,depth_final) 181 | 182 | 183 | rgb_high_feature_dy = self.selfdc(rgb_final,trans_rgb)+rgb_final 184 | 185 | # decoder 186 | # rgb decoder 187 | rgb_final = F.interpolate(rgb_high_feature_dy, size=(frist_rgb.shape[-2], frist_rgb.shape[-1]), 188 | mode="bilinear", 189 | align_corners=False) 190 | rgb_final = self.combine(frist_rgb,rgb_final) # 1/8 191 | 192 | rgb_final = F.interpolate(self.SegNIN(rgb_final), size=(rgb.shape[-2], rgb.shape[-1]), mode="bilinear",align_corners=False) 193 | 194 | return rgb_final 195 | 196 | if __name__=="__main__": 197 | # from torchstat import stat 198 | a = torch.zeros(1, 3, 256, 256).cuda() 199 | b = torch.zeros(1, 3, 256, 256).cuda() 200 | 201 | 
mobile = DualFastnet().cuda() 202 | c = mobile(a, b) 203 | print(c.size()) 204 | total_paramters = sum([np.prod(p.size()) for p in mobile.parameters()]) 205 | print('Total network parameters: ' + str(total_paramters / 1e6) + "M") -------------------------------------------------------------------------------- /C2DFNet/SDFM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.BaseBlocks import BasicConv2d,BasicConv_PRelu 4 | import torch.nn.functional as F 5 | def Split(x): 6 | c = int(x.size()[1]) # x的通道数量 7 | c1 = round(c * 0.5) # 大约为0.5 8 | x1 = x[:, :c1, :, :].contiguous() 9 | x2 = x[:, c1:, :, :].contiguous() 10 | return x1, x2 11 | class h_sigmoid(nn.Module): 12 | def __init__(self, inplace=True): 13 | super(h_sigmoid, self).__init__() 14 | self.relu = nn.ReLU6(inplace=inplace) 15 | 16 | def forward(self, x): 17 | return self.relu(x + 3) / 6 18 | 19 | 20 | class DenseLayer(nn.Module): 21 | def __init__(self, in_C, out_C, down_factor=4, k=4): 22 | """ 23 | 更像是DenseNet的Block,从而构造特征内的密集连接 24 | """ 25 | super(DenseLayer, self).__init__() 26 | self.k = k 27 | self.down_factor = down_factor 28 | mid_C = out_C // self.down_factor 29 | 30 | self.down = nn.Conv2d(in_C, mid_C, 1) 31 | 32 | self.denseblock = nn.ModuleList() 33 | for i in range(1, self.k + 1): 34 | self.denseblock.append(BasicConv_PRelu(mid_C * i, mid_C, 3, 1, 1)) 35 | 36 | self.fuse = BasicConv_PRelu(in_C + mid_C, out_C, kernel_size=3, stride=1, padding=1) 37 | 38 | def forward(self, in_feat): 39 | down_feats = self.down(in_feat) 40 | out_feats = [] 41 | for denseblock in self.denseblock: 42 | feats = denseblock(torch.cat((*out_feats, down_feats), dim=1)) 43 | out_feats.append(feats) 44 | feats = torch.cat((in_feat, feats), dim=1) 45 | return self.fuse(feats) 46 | 47 | 48 | class DenseTransLayer(nn.Module): 49 | def __init__(self, in_C, out_C): 50 | super(DenseTransLayer, self).__init__() 51 | down_factor = in_C // 
out_C 52 | self.fuse_down_mul = BasicConv_PRelu(in_C*2, in_C, 3, 1, 1) 53 | #去掉denselayer会提升速度 54 | self.res_main = DenseLayer(in_C, in_C, down_factor=down_factor) 55 | # self.res_main = _OSA_module(in_C,in_C,in_C,4,True) 56 | 57 | self.fuse_main = BasicConv_PRelu(in_C, out_C, kernel_size=3, stride=1, padding=1) 58 | 59 | def forward(self, rgb, depth): 60 | assert rgb.size() == depth.size() 61 | feat = self.fuse_down_mul(torch.cat([rgb,depth],dim=1)) 62 | # feat =Channel_shuffle(feat,4) # 不知道这个通道洗牌有用吗,但是速度没怎么降低 63 | return self.fuse_main(self.res_main(feat)+feat) #self.res_main(feat)+ 64 | # return self.fuse_main(feat) #self.res_main(feat)+ 65 | 66 | 67 | class SDFM(nn.Module): 68 | def __init__(self, in_xC, in_yC, out_C, kernel_size=3, down_factor=4): 69 | """ 70 | Args: 71 | in_xC (int): 第一个输入的通道数 72 | in_yC (int): 第二个输入的通道数 73 | out_C (int): 最终输出的通道数 74 | kernel_size (int): 指定的生成的卷积核的大小 75 | down_factor (int): 用来降低卷积核生成过程中的参数量的一个降低通道数的参数 76 | """ 77 | #(32, 32, 32, 3, 4) 78 | super(SDFM, self).__init__() 79 | self.kernel_size = kernel_size 80 | self.mid_c = out_C # 这里没有缩减通道 =8 81 | self.down_input = nn.Conv2d(in_xC, self.mid_c, 1) 82 | self.branch_1 = DepthDC3x3_1(self.mid_c, in_yC, self.mid_c, down_factor=down_factor) 83 | self.fuse = BasicConv_PRelu(2 * self.mid_c, out_C, 3, 1, 1) 84 | 85 | def forward(self, x, y): 86 | x = self.down_input(x) # channel 32 to 8 87 | result_1 = self.branch_1(x, y) 88 | # result_3 = self.branch_3(x, y) 89 | # result_5 = self.branch_5(x, y) 90 | # return self.fuse(torch.cat((x, result_1, result_3, result_5), dim=1)) 91 | return self.fuse(torch.cat((x, result_1), dim=1)) 92 | 93 | 94 | class DepthDC3x3_1(nn.Module): 95 | def __init__(self, in_xC, in_yC, out_C, down_factor=4): 96 | """DepthDC3x3_1,利用nn.Unfold实现的动态卷积模块 97 | 这里的x应该是被卷的,y是核 98 | Args: 99 | in_xC (int): 第一个输入的通道数 rgb 100 | in_yC (int): 第二个输入的通道数 kernel 101 | out_C (int): 最终输出的通道数 102 | down_factor (int): 用来降低卷积核生成过程中的参数量的一个降低通道数的参数 103 | 104 | 
这个版本改为是通道的和空间的平行,并且采用分组的方式,最后cat在一起然后洗牌 105 | """ 106 | super(DepthDC3x3_1, self).__init__() 107 | 108 | self.kernel_size = 3 109 | mid_in_yC = in_yC//2 110 | mid_in_xC = in_xC//2 111 | self.fuse = nn.Conv2d(in_xC, out_C, 3, 1, 1) 112 | self.gernerate_kernel_spatial = nn.Sequential( 113 | nn.Conv2d(mid_in_yC, mid_in_yC, 3, 1, 1), 114 | # DenseLayer(in_yC, in_yC, k=down_factor), 115 | nn.Conv2d(mid_in_yC, self.kernel_size ** 2, 1),# in_xC 116 | #N C W H -> N k2 W H 117 | ) 118 | self.gernerate_kernel_channel = nn.Sequential( 119 | # nn.Conv2d(in_yC, in_yC, 3, 1, 1), 120 | # DenseLayer(in_yC, in_yC, k=down_factor), 121 | nn.AdaptiveAvgPool2d(self.kernel_size), 122 | nn.Conv2d(mid_in_yC, mid_in_xC, 1), 123 | ) 124 | self.unfold = nn.Unfold(kernel_size=3, dilation=1, padding=1, stride=1) 125 | self.padding = 1 126 | self.dilation = 1 127 | self.stride = 1 128 | self.dynamic_bias = None 129 | 130 | # channel attention part 131 | self.avg_pool = nn.AdaptiveAvgPool2d((self.kernel_size, self.kernel_size)) 132 | self.num_lat = int((self.kernel_size * self.kernel_size) / 2 + 1) 133 | self.ce = nn.Linear(self.kernel_size * self.kernel_size, self.num_lat, False) 134 | self.gd = nn.Linear(self.num_lat, self.kernel_size * self.kernel_size, False) 135 | self.ce_bn = nn.BatchNorm1d(mid_in_xC) 136 | # 激活层 137 | self.act = nn.ReLU(inplace=True) 138 | self.sig = nn.Sigmoid() 139 | 140 | # spatial attention part 141 | self.conv_sp_1 = nn.Conv2d(mid_in_xC,1,kernel_size=3,padding=1,bias=False) 142 | self.conv_sp = nn.Conv2d(2,1,kernel_size=3,padding=1,bias=False) 143 | self.unfold_sp = nn.Unfold(kernel_size=3, dilation=1, padding=1, stride=1) 144 | self.sig_sp = nn.Sigmoid() 145 | def forward(self, x, y): # x : rgb y :kernel 146 | N, xC, xH, xW = x.size() 147 | 148 | # split 149 | x1,x2 =Split(x) 150 | y1,y2 = Split(y) 151 | # channel filter 152 | # --------------channel attention------------------- 153 | # 这里是用混合特征生成核 154 | N, yC, yH, yW = y1.size() 155 | gl = 
self.avg_pool(x1).view(N,yC, -1) # N C k^2 156 | # # 实际实现过程就是一个se一样的 157 | out = self.ce(gl) # N C numlat 158 | out = self.ce_bn(out) # bn 159 | out = self.act(out) # act 160 | out = self.gd(out) 161 | out = self.sig(out.view(-1,1,self.kernel_size,self.kernel_size)) 162 | # 163 | kernel_channel = self.gernerate_kernel_channel(y1) 164 | # 165 | kernel_channel = kernel_channel.reshape(-1, 1, self.kernel_size, self.kernel_size) 166 | #kernel * filter 167 | kernel_channel_after = kernel_channel*out 168 | # ----------------------------------------------------- 169 | # 1 NC k k 170 | x_input = x1.view(1, -1, x1.size()[2], x1.size()[3]) 171 | channel_after = F.conv2d(x_input, weight=kernel_channel_after, bias=self.dynamic_bias, stride=self.stride, 172 | padding=self.padding, dilation=self.dilation, groups=N * xC//2) 173 | channel_after = channel_after.reshape(N, -1, xH, xW) 174 | 175 | # spatial filter 176 | kernel = self.gernerate_kernel_spatial(y2) 177 | kernel = kernel.reshape([N, self.kernel_size ** 2, xH, xW, 1]) 178 | # spatial attention 179 | # 这个是CBAM里的空间att 180 | kernel_sp = self.conv_sp_1(x2) # N 1 H W 181 | kernel_sp =self.unfold_sp(kernel_sp).reshape([N,-1,xH,xW,1]) # N k2 H W 1 182 | avg_out = torch.mean(kernel_sp, dim=1, keepdim=True) # N 1 H W 1 183 | max_out,_ = torch.max(kernel_sp, dim=1, keepdim=True) # N 1 H W 1 184 | x = torch.cat([avg_out, max_out], dim=1) # N 2 H W 1 185 | x = x.squeeze(4) # N 1 H W 1 186 | x = self.conv_sp(x) # N 1 W H 187 | sp_x = self.sig_sp(x) # N 1 W H 这里可以当作是一个att 188 | sp_x = sp_x.unsqueeze(4) 189 | sp_x =sp_x.permute(0,2,3,4,1).contiguous() # N W H 1 C 190 | kernel = kernel.permute(0,2,3,4,1).contiguous() # N W H 1 C 191 | 192 | kernel_after = kernel* sp_x 193 | 194 | kernel_after = kernel_after.permute(0,4,1,2,3) # N C W H 1 195 | # 这里就应该是kernel的部分 196 | kernel_after = kernel_after.permute(0, 2, 3, 1, 4).contiguous() # N H W k2 1 197 | unfold_x = self.unfold(x2).reshape([N, xH, xW, xC//2, -1]) # N H W C k2 198 | # 
这里就是两个矩阵的低维度相乘 199 | spatial_after = torch.matmul(unfold_x,kernel_after) #N H W k2 1×N H W C k2 = N H W C 1 200 | spatial_after = spatial_after.squeeze(4).permute(0,3,1,2) 201 | 202 | result = torch.cat([channel_after,spatial_after],dim=1) 203 | 204 | 205 | return self.fuse(result) -------------------------------------------------------------------------------- /C2DFNet/metric/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from scipy.ndimage import center_of_mass, convolve, distance_transform_edt as bwdist 4 | 5 | 6 | class CalFM(object): 7 | # Fmeasure(maxFm, meanFm)---Frequency-tuned salient region detection(CVPR 2009) 8 | def __init__(self, num, thds=255): 9 | self.precision = np.zeros((num, thds)) 10 | self.recall = np.zeros((num, thds)) 11 | self.meanF = np.zeros(num) 12 | self.idx = 0 13 | self.num = num 14 | 15 | def update(self, pred, gt): 16 | if gt.max() != 0: 17 | prediction, recall, mfmeasure = self.cal(pred, gt) 18 | self.precision[self.idx, :] = prediction 19 | self.recall[self.idx, :] = recall 20 | self.meanF[self.idx] = mfmeasure 21 | self.idx += 1 22 | 23 | def cal(self, pred, gt): 24 | ########################meanF############################## 25 | th = 2 * pred.mean() 26 | if th > 1: 27 | th = 1 28 | binary = np.zeros_like(pred) 29 | binary[pred >= th] = 1 30 | hard_gt = np.zeros_like(gt) 31 | hard_gt[gt > 0.5] = 1 32 | tp = (binary * hard_gt).sum() 33 | if tp == 0: 34 | mfmeasure = 0 35 | else: 36 | pre = tp / binary.sum() 37 | rec = tp / hard_gt.sum() 38 | mfmeasure = 1.3 * pre * rec / (0.3 * pre + rec) 39 | 40 | ########################maxF############################## 41 | pred = np.uint8(pred * 255) 42 | target = pred[gt > 0.5] 43 | nontarget = pred[gt <= 0.5] 44 | targetHist, _ = np.histogram(target, bins=range(256)) 45 | nontargetHist, _ = np.histogram(nontarget, bins=range(256)) 46 | targetHist = np.cumsum(np.flip(targetHist), axis=0) 47 | 
nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0) 48 | precision = targetHist / (targetHist + nontargetHist + 1e-8) 49 | recall = targetHist / np.sum(gt) 50 | return precision, recall, mfmeasure 51 | 52 | def show(self): 53 | assert self.num == self.idx, f"{self.num}, {self.idx}" 54 | precision = self.precision.mean(axis=0) 55 | recall = self.recall.mean(axis=0) 56 | fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8) 57 | mmfmeasure = np.around(self.meanF.mean(),4) 58 | return fmeasure, fmeasure.max(), mmfmeasure, precision, recall 59 | 60 | 61 | class CalMAE(object): 62 | # mean absolute error 63 | def __init__(self, num): 64 | # self.prediction = [] 65 | self.prediction = np.zeros(num) 66 | self.idx = 0 67 | self.num = num 68 | 69 | def update(self, pred, gt): 70 | self.prediction[self.idx] = self.cal(pred, gt) 71 | self.idx += 1 72 | 73 | def cal(self, pred, gt): 74 | return np.mean(np.abs(pred - gt)) 75 | 76 | def show(self): 77 | assert self.num == self.idx, f"{self.num}, {self.idx}" 78 | return np.around(self.prediction.mean(),4) 79 | 80 | 81 | class CalSM(object): 82 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017) 83 | def __init__(self, num, alpha=0.5): 84 | self.prediction = np.zeros(num) 85 | self.alpha = alpha 86 | self.idx = 0 87 | self.num = num 88 | 89 | def update(self, pred, gt): 90 | gt = gt > 0.5 91 | self.prediction[self.idx] = self.cal(pred, gt) 92 | self.idx += 1 93 | 94 | def show(self): 95 | assert self.num == self.idx, f"{self.num}, {self.idx}" 96 | return np.around(self.prediction.mean(),4) 97 | 98 | def cal(self, pred, gt): 99 | y = np.mean(gt) 100 | if y == 0: 101 | score = 1 - np.mean(pred) 102 | elif y == 1: 103 | score = np.mean(pred) 104 | else: 105 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 106 | return score 107 | 108 | def object(self, pred, gt): 109 | fg = pred * gt 110 | bg = (1 - pred) * (1 - gt) 111 | 112 | u = np.mean(gt) 113 | 
return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt)) 114 | 115 | def s_object(self, in1, in2): 116 | x = np.mean(in1[in2]) 117 | sigma_x = np.std(in1[in2]) 118 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8) 119 | 120 | def region(self, pred, gt): 121 | [y, x] = center_of_mass(gt) 122 | y = int(round(y)) + 1 123 | x = int(round(x)) + 1 124 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y) 125 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y) 126 | 127 | score1 = self.ssim(pred1, gt1) 128 | score2 = self.ssim(pred2, gt2) 129 | score3 = self.ssim(pred3, gt3) 130 | score4 = self.ssim(pred4, gt4) 131 | 132 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 133 | 134 | def divideGT(self, gt, x, y): 135 | h, w = gt.shape 136 | area = h * w 137 | LT = gt[0:y, 0:x] 138 | RT = gt[0:y, x:w] 139 | LB = gt[y:h, 0:x] 140 | RB = gt[y:h, x:w] 141 | 142 | w1 = x * y / area 143 | w2 = y * (w - x) / area 144 | w3 = (h - y) * x / area 145 | w4 = (h - y) * (w - x) / area 146 | 147 | return LT, RT, LB, RB, w1, w2, w3, w4 148 | 149 | def dividePred(self, pred, x, y): 150 | h, w = pred.shape 151 | LT = pred[0:y, 0:x] 152 | RT = pred[0:y, x:w] 153 | LB = pred[y:h, 0:x] 154 | RB = pred[y:h, x:w] 155 | 156 | return LT, RT, LB, RB 157 | 158 | def ssim(self, in1, in2): 159 | in2 = np.float32(in2) 160 | h, w = in1.shape 161 | N = h * w 162 | 163 | x = np.mean(in1) 164 | y = np.mean(in2) 165 | sigma_x = np.var(in1) 166 | sigma_y = np.var(in2) 167 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1) 168 | 169 | alpha = 4 * x * y * sigma_xy 170 | beta = (x * x + y * y) * (sigma_x + sigma_y) 171 | 172 | if alpha != 0: 173 | score = alpha / (beta + 1e-8) 174 | elif alpha == 0 and beta == 0: 175 | score = 1 176 | else: 177 | score = 0 178 | 179 | return score 180 | 181 | 182 | class CalEM(object): 183 | # Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018) 184 | def __init__(self, num): 185 | self.prediction 
= np.zeros(num) 186 | self.idx = 0 187 | self.num = num 188 | 189 | def update(self, pred, gt): 190 | self.prediction[self.idx] = self.cal(pred, gt) 191 | self.idx += 1 192 | 193 | def cal(self, pred, gt): 194 | th = 2 * pred.mean() 195 | if th > 1: 196 | th = 1 197 | FM = np.zeros(gt.shape) 198 | FM[pred >= th] = 1 199 | FM = np.array(FM, dtype=bool) 200 | GT = np.array(gt, dtype=bool) 201 | dFM = np.double(FM) 202 | if sum(sum(np.double(GT))) == 0: 203 | enhanced_matrix = 1.0 - dFM 204 | elif sum(sum(np.double(~GT))) == 0: 205 | enhanced_matrix = dFM 206 | else: 207 | dGT = np.double(GT) 208 | align_matrix = self.AlignmentTerm(dFM, dGT) 209 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix) 210 | [w, h] = np.shape(GT) 211 | score = sum(sum(enhanced_matrix)) / (w * h - 1 + 1e-8) 212 | return score 213 | 214 | def AlignmentTerm(self, dFM, dGT): 215 | mu_FM = np.mean(dFM) 216 | mu_GT = np.mean(dGT) 217 | align_FM = dFM - mu_FM 218 | align_GT = dGT - mu_GT 219 | align_Matrix = 2.0 * (align_GT * align_FM) / (align_GT * align_GT + align_FM * align_FM + 1e-8) 220 | return align_Matrix 221 | 222 | def EnhancedAlignmentTerm(self, align_Matrix): 223 | enhanced = np.power(align_Matrix + 1, 2) / 4 224 | return enhanced 225 | 226 | def show(self): 227 | assert self.num == self.idx, f"{self.num}, {self.idx}" 228 | return np.around(self.prediction.mean()) 229 | 230 | 231 | class CalWFM(object): 232 | def __init__(self, num, beta=1): 233 | self.scores_list = np.zeros(num) 234 | self.beta = beta 235 | self.eps = 1e-6 236 | self.idx = 0 237 | self.num = num 238 | 239 | def update(self, pred, gt): 240 | gt = gt > 0.5 241 | self.scores_list[self.idx] = 0 if gt.max() == 0 else self.cal(pred, gt) 242 | self.idx += 1 243 | 244 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5): 245 | """ 246 | 2D gaussian mask - should give the same result as MATLAB's 247 | fspecial('gaussian',[shape],[sigma]) 248 | """ 249 | m, n = [(ss - 1.0) / 2.0 for ss in shape] 250 | y, x = 
np.ogrid[-m : m + 1, -n : n + 1] 251 | h = np.exp(-(x * x + y * y) / (2.0 * sigma * sigma)) 252 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 253 | sumh = h.sum() 254 | if sumh != 0: 255 | h /= sumh 256 | return h 257 | 258 | def cal(self, pred, gt): 259 | # [Dst,IDXT] = bwdist(dGT); 260 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 261 | 262 | # %Pixel dependency 263 | # E = abs(FG-dGT); 264 | E = np.abs(pred - gt) 265 | # Et = E; 266 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 267 | Et = np.copy(E) 268 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 269 | 270 | # K = fspecial('gaussian',7,5); 271 | # EA = imfilter(Et,K); 272 | # MIN_E_EA(GT & EA= 0 315 | assert gt.max() <= 1 and gt.min() >= 0 316 | 317 | self.cal_mae.update(pred, gt) 318 | self.cal_fm.update(pred, gt) 319 | self.cal_sm.update(pred, gt) 320 | self.cal_em.update(pred, gt) 321 | self.cal_wfm.update(pred, gt) 322 | 323 | def show(self): 324 | MAE = self.cal_mae.show() 325 | _, Maxf, Meanf, _, _, = self.cal_fm.show() 326 | SM = self.cal_sm.show() 327 | EM = self.cal_em.show() 328 | WFM = self.cal_wfm.show() 329 | results = { 330 | "MaxF": Maxf, 331 | "MeanF": Meanf, 332 | "WFM": WFM, 333 | "MAE": MAE, 334 | "SM": SM, 335 | "EM": EM, 336 | } 337 | return results 338 | 339 | 340 | if __name__ == "__main__": 341 | pred = Image 342 | -------------------------------------------------------------------------------- /C2DFNet/MDEM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.BaseBlocks import BasicConv_PRelu,BasicConv2d 4 | import torch.nn.functional as F 5 | 6 | 7 | class decoder_plus(nn.Module): 8 | def __init__(self,code=512): 9 | super(decoder_plus,self).__init__() 10 | self.conv_lowfuse = BasicConv_PRelu(96,64,1,1,bias=False) 11 | self.conv_highcode = BasicConv_PRelu(96,code,1,1,bias=False) 12 | self.conv_highfuse = BasicConv_PRelu(96,64,1,1,bias=False) 
13 | self.conv_transfor = nn.Conv2d(64,code,1,1,padding=0,bias=False) 14 | self.conv_fuse = BasicConv_PRelu(96,64,1,1,bias=False) 15 | self.softmax_weight = nn.Softmax(dim=1) 16 | self.gap = nn.AdaptiveAvgPool2d(1) 17 | def forward(self, fea_down,fea_up): # up low \ down high 18 | # high 19 | n,c,hh,wh = fea_down.shape 20 | n,c,hl,wl = fea_up.shape 21 | fea_down_code = self.softmax_weight(self.conv_highcode(fea_down)) #n 512 h w 22 | 23 | fea_down_code =fea_down_code.view(n,-1,hh*wh) 24 | fea_down_code_t = fea_down_code.transpose(1, 2).contiguous() 25 | # base map 26 | fea_down_base = self.conv_highfuse(fea_down) # n c h w 27 | fea_down_base_change = fea_down_base.view(n,-1,hh*wh) 28 | 29 | # pool 30 | fea_down_base_pool = self.gap(fea_down_base) 31 | fea_down_base_pool = F.interpolate(fea_down_base_pool,size=(fea_up.shape[-2],fea_up.shape[-1]),mode="bilinear",align_corners=False) 32 | codebook = torch.matmul(fea_down_base_change,fea_down_code_t) # n c code 33 | 34 | # low 35 | fea_up_fuse_base = self.conv_lowfuse(fea_up) # N C H W 36 | fea_up_fuse_trans = fea_up_fuse_base + fea_down_base_pool 37 | fea_up_fuse_trans = self.conv_transfor(fea_up_fuse_trans) # N code H W 38 | fea_up_fuse_trans = fea_up_fuse_trans.view(n,fea_up_fuse_trans.shape[1],-1)# N code H*W 39 | 40 | 41 | # multiply 42 | fea_new = torch.matmul(codebook,fea_up_fuse_trans) # N C H*W 43 | fea_new = fea_new.view(n,-1,hl,wl) 44 | 45 | # fuse 46 | final = self.conv_fuse(torch.cat((fea_new,fea_up_fuse_base),dim=1)) 47 | return final 48 | class DenseLayer(nn.Module): 49 | def __init__(self, in_C, out_C, down_factor=4, k=4): 50 | """ 51 | 更像是DenseNet的Block,从而构造特征内的密集连接 52 | """ 53 | super(DenseLayer, self).__init__() 54 | self.k = k 55 | self.down_factor = down_factor 56 | mid_C = out_C // self.down_factor 57 | 58 | self.down = nn.Conv2d(in_C, mid_C, 1) 59 | 60 | self.denseblock = nn.ModuleList() 61 | for i in range(1, self.k + 1): 62 | self.denseblock.append(BasicConv2d(mid_C * i, mid_C, 3, 1, 1)) 63 | 
64 | self.fuse = BasicConv2d(in_C + mid_C, out_C, kernel_size=3, stride=1, padding=1) 65 | 66 | def forward(self, in_feat): 67 | down_feats = self.down(in_feat) 68 | out_feats = [] 69 | for denseblock in self.denseblock: 70 | feats = denseblock(torch.cat((*out_feats, down_feats), dim=1)) 71 | out_feats.append(feats) 72 | feats = torch.cat((in_feat, feats), dim=1) 73 | return self.fuse(feats) 74 | 75 | class DFM(nn.Module): 76 | def __init__(self): 77 | super(DFM,self).__init__() 78 | #计算改的通道数量 79 | # large = 32 80 | # middle = 16 81 | # small = 8 82 | large = 128 83 | middle = 64 84 | small = 32 85 | self.d_ls = int(large/small) # 4 这个比例是扩张比例 86 | self.d_ms = int(middle/small) # 2 87 | self.kup = 3 # 这个是核的大小 88 | 89 | # 用于最后的特征融合 90 | # 这个融合的部分用3x3的还是1x1的呢 91 | self.conv_fuse = BasicConv_PRelu(192, 32, 1, 1, bias=False) 92 | # 用于改通道生成核 93 | # high 94 | self.conv_chchaneg_high = BasicConv_PRelu(32, self.kup ** 2, 1, 1) # 这里high的不用d的扩张比例 95 | # mid 96 | self.conv_chchaneg_mid = BasicConv_PRelu(32,self.d_ms**2*self.kup**2,1,1) # 这里该用3x3还是1x1 97 | # low 98 | self.conv_chchaneg_low = BasicConv_PRelu(32,self.d_ls**2*self.kup**2,1,1) 99 | 100 | #resize channel residual part 101 | self.conv_low_sp = BasicConv_PRelu(32,self.kup**2,1,1) 102 | self.conv_mid_sp = BasicConv_PRelu(32,self.kup**2,1,1) 103 | 104 | #-----------通道Dy的部分--------------------- 105 | # 空洞率1 106 | # low 107 | self.resize_low = BasicConv_PRelu(32,32*self.d_ls**2,1,1) #变成 N C*d2 H W 108 | self.gernerate_kernel_channel = nn.Sequential( 109 | nn.Conv2d(32, 32, 1, 1, 1), 110 | # DenseLayer(in_yC, in_yC, k=down_factor), 111 | nn.AdaptiveAvgPool2d(self.kup), 112 | #nn.Conv2d(32, 32, 1), 113 | BasicConv_PRelu(32,32,1), 114 | ) 115 | self.padding = 1 116 | self.dilation = 1 117 | self.stride = 1 118 | self.dynamic_bias = None 119 | #mid 120 | self.resize_mid = BasicConv_PRelu(32,32*self.d_ms**2,1,1) 121 | self.gernerate_kernel_channel_mid = nn.Sequential( 122 | nn.Conv2d(32, 32, 1, 1, 1), 123 | 
nn.AdaptiveAvgPool2d(self.kup), 124 | BasicConv_PRelu(32, 32, 1), 125 | ) 126 | self.padding_mid = 1 127 | self.dilation_mid = 1 128 | self.stride_mid = 1 129 | self.dynamic_bias_mid = None 130 | 131 | self.gernerate_kernel_channel_high = nn.Sequential( 132 | nn.Conv2d(32, 32, 1, 1, 1), 133 | nn.AdaptiveAvgPool2d(self.kup), 134 | BasicConv_PRelu(32, 32, 1), 135 | ) 136 | self.padding_high = 1 137 | self.dilation_high = 1 138 | self.stride_high = 1 139 | self.dynamic_bias_high = None 140 | 141 | # fuse conv 142 | self.conv_fuse_high = BasicConv_PRelu(96,32,1) 143 | self.gap = nn.AdaptiveAvgPool2d(1) 144 | 145 | 146 | 147 | 148 | def forward(self,fea_low,fea_mid,fea_high):# up low up的尺寸大\ down high的尺寸小 149 | # 这里的特征是两个concat在一起的 150 | 151 | # 尺寸小 152 | nh, ch, hh, wh = fea_high.shape 153 | #中等尺寸的 154 | nm,cm,hm,wm = fea_mid.shape 155 | #尺寸大的 156 | nl, cl, hl, wl = fea_low.shape 157 | 158 | # resize to high 159 | fea_low_new = F.interpolate(fea_low,size=(fea_high.shape[-2],fea_high.shape[-1]),mode="bilinear",align_corners=False) 160 | fea_mid_new = F.interpolate(fea_mid,size=(fea_high.shape[-2],fea_high.shape[-1]),mode="bilinear",align_corners=False) 161 | 162 | 163 | #fuse high 164 | fea_high_fused = self.conv_fuse_high(torch.cat([fea_high,fea_mid_new,fea_low_new],dim=1)) 165 | 166 | # pool 167 | fea_down_pool = self.gap(fea_high_fused) 168 | fea_down_pool_high = F.interpolate(fea_down_pool,size=(fea_high.shape[-2],fea_high.shape[-1]),mode="bilinear",align_corners=False) 169 | fea_down_pool_mid = F.interpolate(fea_down_pool,size=(fea_mid.shape[-2],fea_mid.shape[-1]),mode="bilinear",align_corners=False) 170 | fea_down_pool_low = F.interpolate(fea_down_pool,size=(fea_low.shape[-2],fea_low.shape[-1]),mode="bilinear",align_corners=False) 171 | 172 | 173 | 174 | #-----生成dy的核的部分-------- 175 | # low 176 | kernel_tensor_low = self.conv_chchaneg_low(fea_high_fused) # N d^2*k^2 w h 177 | kernel_tensor_low = F.pixel_shuffle(kernel_tensor_low,self.d_ls) # N d^2*k^2 w h -> N k^2 
dh dw = N k^2 H W 178 | # 这个正规化有没有用再试试 179 | # 添加特征low和high 180 | fea_low_d = self.conv_low_sp(fea_low) # 这里相当于去添加原来的特征。 181 | kernel_tensor_low += fea_low_d 182 | kernel_tensor_low = F.softmax(kernel_tensor_low,dim=1) # N k2 H W 这个对 183 | # reshape成k2在最后 184 | kernel_tensor_low = kernel_tensor_low.permute(0,2,3,1).contiguous() # N H W k^2 185 | # mid 186 | kernel_tensor_mid = self.conv_chchaneg_mid(fea_high_fused) 187 | kernel_tensor_mid = F.pixel_shuffle(kernel_tensor_mid,self.d_ms) 188 | fea_mid_d =self.conv_mid_sp(fea_mid) 189 | kernel_tensor_mid +=fea_mid_d 190 | kernel_tensor_mid = F.softmax(kernel_tensor_mid,dim=1) 191 | kernel_tensor_mid = kernel_tensor_mid.permute(0,2,3,1).contiguous() 192 | 193 | # high 194 | kernel_tensor_high = self.conv_chchaneg_high(fea_high_fused) # 这里少一步 195 | kernel_tensor_high = F.softmax(kernel_tensor_high,dim=1) 196 | kernel_tensor_high = kernel_tensor_high.permute(0,2,3,1).contiguous() 197 | 198 | #------生成特征的d维度-------- 199 | # N C H+k W+k 200 | # low 201 | # New !!这里添加了各种pool 202 | fea_low_pad = F.pad(fea_low+fea_down_pool_low, pad=(self.kup // 2, self.kup // 2,self.kup // 2, self.kup // 2),mode='constant', value=0) 203 | fea_low_pad = fea_low_pad.unfold(dimension=2,size=self.kup,step=1) # N C H W+k k 204 | fea_low_pad = fea_low_pad.unfold(3,self.kup,step=1) # N C H W k k 205 | fea_low_pad = fea_low_pad.reshape(nl,cl,hl,wl,-1) # N C H W k^2 206 | fea_low_pad = fea_low_pad.permute(0,2,3,1,4).contiguous() # N H W C k^2 207 | # mid 208 | # New !!这里添加了各种pool 209 | fea_mid_pad = F.pad(fea_mid+fea_down_pool_mid, pad=(self.kup // 2, self.kup // 2, self.kup // 2, self.kup // 2), mode='constant', 210 | value=0) 211 | fea_mid_pad = fea_mid_pad.unfold(dimension=2, size=self.kup, step=1) # N C H W+k k 212 | fea_mid_pad = fea_mid_pad.unfold(3, self.kup, step=1) # N C H W k k 213 | fea_mid_pad = fea_mid_pad.reshape(nm, cm, hm, wm, -1) # N C H W k^2 214 | fea_mid_pad = fea_mid_pad.permute(0, 2, 3, 1, 4).contiguous() # N H W C k^2 215 | # 
high 216 | # New !!这里添加了各种pool 217 | fea_high_pad = F.pad(fea_high+fea_down_pool_high, pad=(self.kup // 2, self.kup // 2, self.kup // 2, self.kup // 2), mode='constant', 218 | value=0) 219 | fea_high_pad = fea_high_pad.unfold(dimension=2, size=self.kup, step=1) # N C H W+k k 220 | fea_high_pad = fea_high_pad.unfold(3, self.kup, step=1) # N C H W k k 221 | fea_high_pad = fea_high_pad.reshape(nh, ch, hh, wh, -1) # N C H W k^2 222 | fea_high_pad = fea_high_pad.permute(0, 2, 3, 1, 4).contiguous() # N H W C k^2 223 | 224 | #-------核的相乘---------- 225 | # low 226 | # 首先扩张维度,将核的扩张成5维 227 | kernel_tensor_low = kernel_tensor_low.unsqueeze(4) # N H W k^2 1 228 | # N H W C k^2 * N H W k^2 1 = N H W C 1 229 | fea_low_new = torch.matmul(fea_low_pad,kernel_tensor_low) 230 | # 压缩通道 231 | fea_low_new = fea_low_new.squeeze(dim=4) # N H W C 232 | # 改变通道 233 | fea_low_new = fea_low_new.permute(0,3,1,2) # N C H W 234 | # mid 235 | # 首先扩张维度,将核的扩张成5维 236 | kernel_tensor_mid = kernel_tensor_mid.unsqueeze(4) # N H W k^2 1 237 | # N H W C k^2 * N H W k^2 1 = N H W C 1 238 | fea_mid_new = torch.matmul(fea_mid_pad, kernel_tensor_mid) 239 | # 压缩通道 240 | fea_mid_new = fea_mid_new.squeeze(dim=4) # N H W C 241 | # 改变通道 242 | fea_mid_new = fea_mid_new.permute(0, 3, 1, 2) # N C H W 243 | # high 244 | # 首先扩张维度,将核的扩张成5维 245 | kernel_tensor_high = kernel_tensor_high.unsqueeze(4) # N H W k^2 1 246 | # N H W C k^2 * N H W k^2 1 = N H W C 1 247 | fea_high_new = torch.matmul(fea_high_pad, kernel_tensor_high) 248 | # 压缩通道 249 | fea_high_new = fea_high_new.squeeze(dim=4) # N H W C 250 | # 改变通道 251 | fea_high_new = fea_high_new.permute(0, 3, 1, 2) # N C H W 252 | 253 | 254 | 255 | #--------通道Dy的部分---------- 256 | #low 257 | #空洞率1 258 | # 生成核的部分 259 | fea_high_lowch = self.resize_low(fea_high_fused) # N C*d2 h w 260 | fea_high_lowch = F.pixel_shuffle(fea_high_lowch,self.d_ls) # N C h*d w*d 261 | kernel_tensor_ch_low =self.gernerate_kernel_channel(fea_high_lowch+fea_low).reshape(-1, 1, self.kup, self.kup) #ch 
32,# NC 1 k k 262 | # 改变input的shape 263 | fea_low_re = fea_low +fea_down_pool_low 264 | fea_up_change_low = fea_low_re.reshape(1, -1, fea_low.size()[2], fea_low.size()[3]) 265 | #卷积 266 | channel_after = F.conv2d(fea_up_change_low, weight=kernel_tensor_ch_low, bias=self.dynamic_bias, stride=self.stride, 267 | padding=self.padding, dilation=self.dilation, groups=nl * cl) 268 | channel_after =channel_after.reshape(nl, -1, hl, wl) 269 | 270 | # mid 271 | # 空洞率1 272 | # 生成核的部分 273 | fea_high_midch = self.resize_mid(fea_high_fused) # N C*d2 h w 274 | fea_high_midch = F.pixel_shuffle(fea_high_midch,self.d_ms) 275 | kernel_tensor_ch_mid = self.gernerate_kernel_channel_mid(fea_high_midch+fea_mid).reshape(-1, 1, self.kup, 276 | self.kup) # ch 32,# NC 1 k k 277 | # 改变input的shape 278 | fea_mid_re = fea_mid +fea_down_pool_mid 279 | fea_up_change_mid = fea_mid_re.reshape(1, -1, fea_mid.size()[2], fea_mid.size()[3]) 280 | # 卷积 281 | channel_after_mid = F.conv2d(fea_up_change_mid, weight=kernel_tensor_ch_mid, bias=self.dynamic_bias_mid, 282 | stride=self.stride_mid, 283 | padding=self.padding_mid, dilation=self.dilation_mid, groups=nm * cm) 284 | channel_after_mid = channel_after_mid.reshape(nm, -1, hm, wm) 285 | 286 | # high 287 | # 空洞率1 288 | # 生成核的部分 289 | kernel_tensor_ch_high = self.gernerate_kernel_channel_high(fea_high_fused).reshape(-1, 1, self.kup, 290 | self.kup) # ch 32,# NC 1 k k 291 | # 改变input的shape 292 | fea_high_re = fea_high +fea_down_pool_high 293 | fea_up_change_high = fea_high_re.reshape(1, -1, fea_high.size()[2], fea_high.size()[3]) 294 | # 卷积 295 | channel_after_high = F.conv2d(fea_up_change_high, weight=kernel_tensor_ch_high, bias=self.dynamic_bias_high, 296 | stride=self.stride_high, 297 | padding=self.padding_high, dilation=self.dilation_high, groups=nh * ch) 298 | channel_after_high = channel_after_high.reshape(nh, -1, hh, wh) 299 | 300 | # 特征融合 301 | # 特征合并 concat在一起fea_up_af, 302 | 303 | # 融合还是要上采样 304 | fea_high_new = F.interpolate(fea_high_new, 305 | 
size=(fea_low_new.shape[-2], fea_low_new.shape[-1]), 306 | mode="bilinear", align_corners=False) 307 | fea_mid_new = F.interpolate(fea_mid_new, 308 | size=(fea_low_new.shape[-2], fea_low_new.shape[-1]), 309 | mode="bilinear", align_corners=False) 310 | 311 | # ch dy 上采样 312 | channel_after_mid = F.interpolate(channel_after_mid, 313 | size=(fea_low_new.shape[-2], fea_low_new.shape[-1]), 314 | mode="bilinear", align_corners=False) 315 | channel_after_high = F.interpolate(channel_after_high, 316 | size=(fea_low_new.shape[-2], fea_low_new.shape[-1]), 317 | mode="bilinear", align_corners=False) 318 | 319 | fea_final = torch.cat([fea_high_new,channel_after_high,fea_mid_new,channel_after_mid,fea_low_new,channel_after], dim=1) 320 | 321 | fea_final = self.conv_fuse(fea_final) 322 | 323 | # 最后返回特征 324 | return fea_final --------------------------------------------------------------------------------