├── FSEL_ECCV_2024 ├── Instructions for lib files and lib_initial files.txt ├── Test.py ├── Test_flops.py ├── Train.py ├── environment.txt ├── lib │ ├── FSEL_modules.py │ ├── Network_PVT.py │ ├── Network_Res2Net.py │ ├── Network_ResNet.py │ ├── Res2Net_v1b.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── GatedConv.cpython-38.pyc │ │ ├── Modules.cpython-38.pyc │ │ ├── Network.cpython-38.pyc │ │ └── __init__.cpython-38.pyc │ └── pvt_v2.py ├── lib_initial │ ├── Network_PVT_initial.py │ ├── Network_Res2Net_initial.py │ ├── Network_ResNet_initial.py │ └── module_FSEL.py ├── test │ ├── saliency_metric.py │ ├── sod_metrics.py │ ├── test_data.py │ ├── test_data1.py │ └── test_metric_score.py └── utils │ ├── FeatureViz.py │ ├── MyFeatureVisulization.py │ ├── cod10k_subclass_split.py │ ├── data_val.py │ ├── dataloader.py │ ├── fps.py │ ├── generate_LaTeX.py │ ├── heatmap.py │ ├── pytorch_jittor_convert.py │ ├── tif2png.py │ └── utils.py ├── LICENSE └── README.md /FSEL_ECCV_2024/Instructions for lib files and lib_initial files.txt: -------------------------------------------------------------------------------- 1 | The main code in the lib file and the lib_initial file is identical. The primary difference between the two lies in the module naming conventions. In the lib file, the code is named with the most precise labels, while lib_initial is a roughly named version primarily used to test the provided .pth file. -------------------------------------------------------------------------------- /FSEL_ECCV_2024/Test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os, argparse 5 | import cv2 6 | from lib.Network_ResNet import Network 7 | from utils.data_val import test_dataset 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--testsize', type=int, default=416, help='testing size') # 13 | parser.add_argument('--pth_path', type=str, default='') 14 | parser.add_argument('--test_dataset_path', type=str, default='') 15 | opt = parser.parse_args() 16 | 17 | for _data_name in ['CHAMELEON','CAMO','NC4K','COD10K']: 18 | data_path = opt.test_dataset_path+'/{}/'.format(_data_name) 19 | save_path = './Our_ResNet/{}_3/{}/'.format(opt.pth_path.split('/')[-2], _data_name) 20 | os.makedirs(save_path, exist_ok=True) 21 | 22 | model = Network(channels=128) 23 | model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(opt.pth_path).items()}) 24 | model.cuda() 25 | model.eval() 26 | 27 | image_root = '{}/Imgs/'.format(data_path) 28 | gt_root = '{}/GT/'.format(data_path) 29 | test_loader = test_dataset(image_root, gt_root, opt.testsize) 30 | 31 | for i in range(test_loader.size): 32 | image, gt, name, _ = test_loader.load_data() 33 | print('> {} - {}'.format(_data_name, name)) 34 | 35 | gt = np.asarray(gt, np.float32) 36 | gt /= (gt.max() + 1e-8) 37 | image = image.cuda() 38 | 39 | result = model(image) 40 | 41 | res = F.interpolate(result[4], size=gt.shape, mode='bilinear', align_corners=False) 42 | res = res.sigmoid().data.cpu().numpy().squeeze() 43 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 44 | cv2.imwrite(save_path+name,res*255) 45 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/Test_flops.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import cv2 3 | from lib.Network_PVT import Network 
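# Test_flops.py: profiles the PVT-backbone Network with thop.profile on a single
# 1x3x416x416 input and prints FLOPs/params (raw values, then scaled to G / M).
# Note: lib.Network_PVT loads pvt_v2_b4 weights from a hard-coded local path, so
# running this script assumes that checkpoint (or an edited path) is available.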
4 | import torch 5 | from thop import profile 6 | 7 | 8 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 9 | 10 | print('==> Building model..') 11 | input_features = torch.randn(1, 3, 416, 416) 12 | model = Network(128) 13 | 14 | flops, params = profile(model, (input_features,)) 15 | print('flops: ', flops, 'params: ', params) 16 | print('flops: %.2f G, params: %.2f M' % (flops / 1000000000.0, params / 1000000.0)) 17 | 18 | 19 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/Train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from datetime import datetime 6 | from torchvision.utils import make_grid 7 | from lib.Network_ResNet import Network 8 | 9 | from utils.data_val import get_loader, test_dataset 10 | from utils.utils import clip_gradient, adjust_lr, get_coef,cal_ual 11 | from tensorboardX import SummaryWriter 12 | import logging 13 | import torch.backends.cudnn as cudnn 14 | from torch import optim 15 | 16 | 17 | def structure_loss(pred, mask): 18 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 19 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 20 | wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 21 | 22 | pred = torch.sigmoid(pred) 23 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 24 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 25 | wiou = 1 - (inter + 1) / (union - inter + 1) 26 | return (wbce + wiou).mean() 27 | 28 | def dice_loss(predict, target): 29 | smooth = 1 30 | p = 2 31 | valid_mask = torch.ones_like(target) 32 | predict = predict.contiguous().view(predict.shape[0], -1) 33 | target = target.contiguous().view(target.shape[0], -1) 34 | valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1) 35 | num = torch.sum(torch.mul(predict, target) * valid_mask, dim=1) * 2 + smooth 36 | den = torch.sum((predict.pow(p) + target.pow(p)) * valid_mask, dim=1) + smooth 37 | loss = 1 - num / den 38 | return loss.mean() 39 | 40 | def train(train_loader, model, optimizer, epoch, save_path, writer): 41 | global step 42 | model.train() 43 | loss_all = 0 44 | epoch_step = 0 45 | try: 46 | for i, (images, gts, edges) in enumerate(train_loader, start=1): 47 | optimizer.zero_grad() 48 | images = images.cuda(device=device_ids[0]) 49 | gts = gts.cuda(device=device_ids[0]) 50 | #edges = edges.cuda(device=device_ids[0]) 51 | 52 | preds = model(images) 53 | 54 | ual_coef = get_coef(iter_percentage=i/total_step, method='cos') 55 | ual_loss = cal_ual(seg_logits=preds[4], seg_gts=gts) 56 | ual_loss *= ual_coef 57 | 58 | loss_init = structure_loss(preds[0], gts)*0.0625 + structure_loss(preds[1], gts)*0.125 + structure_loss(preds[2], gts)*0.25 + \ 59 | structure_loss(preds[3], gts)*0.5 60 | loss_final = structure_loss(preds[4], gts) 61 | loss = loss_init + loss_final + 2 * ual_loss 62 | loss.backward() 63 | clip_gradient(optimizer, opt.clip) 64 | optimizer.step() 65 | 66 | 67 | 68 | 69 | step += 1 70 | epoch_step += 1 71 | loss_all += loss.data 72 | 73 | if i % 20 == 0 or i == total_step or i == 1: 74 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Total_loss: {:.4f} Loss1: {:.4f} Loss2: {:0.4f}'. 
75 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss.data, loss_init.data, loss_final.data)) # loss_edge.data 76 | logging.info( 77 | '[Train Info]:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Total_loss: {:.4f} Loss1: {:.4f} Loss2: {:0.4f}'. 78 | format(epoch, opt.epoch, i, total_step, loss.data, loss_init.data, loss_final.data)) 79 | # TensorboardX-Loss 80 | writer.add_scalars('Loss_Statistics', 81 | {'Loss_init': loss_init.data, 'Loss_final': loss_final.data, 'Loss_total': loss.data}, 82 | global_step=step) 83 | # TensorboardX-Training Data 84 | grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True) 85 | writer.add_image('RGB', grid_image, step) 86 | grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True) 87 | writer.add_image('GT', grid_image, step) 88 | 89 | # TensorboardX-Outputs 90 | res = preds[0][0].clone() 91 | res = res.sigmoid().data.cpu().numpy().squeeze() 92 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 93 | writer.add_image('Pred_init', torch.tensor(res), step, dataformats='HW') 94 | 95 | res = preds[4][0].clone() 96 | res = res.sigmoid().data.cpu().numpy().squeeze() 97 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 98 | writer.add_image('Pred_final', torch.tensor(res), step, dataformats='HW') 99 | 100 | loss_all /= epoch_step 101 | logging.info('[Train Info]: Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch, opt.epoch, loss_all)) 102 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch) 103 | if epoch % 80 == 0: 104 | torch.save(model.state_dict(), save_path + 'Net_epoch_{}.pth'.format(epoch)) 105 | except KeyboardInterrupt: 106 | print('Keyboard Interrupt: save model and exit.') 107 | if not os.path.exists(save_path): 108 | os.makedirs(save_path) 109 | torch.save(model.state_dict(), save_path + 'Net_epoch_{}.pth'.format(epoch + 1)) 110 | print('Save checkpoints successfully!') 111 | raise 112 | 113 | 114 | def val(test_loader, model, epoch, save_path, writer): 115 | """ 116 | validation function 117 | """ 118 | global best_mae, best_epoch 119 | model.eval() 120 | with torch.no_grad(): 121 | mae_sum = 0 122 | mae_sum_edge = 0 123 | for i in range(test_loader.size): 124 | image, gt, name, img_for_post = test_loader.load_data() 125 | gt = np.asarray(gt, np.float32) 126 | gt /= (gt.max() + 1e-8) 127 | image = image.cuda(device=device_ids[0]) 128 | 129 | result = model(image) 130 | 131 | res = F.interpolate(result[4], size=gt.shape, mode='bilinear', align_corners=False) 132 | res = res.sigmoid().data.cpu().numpy().squeeze() 133 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 134 | mae_sum += np.sum(np.abs(res - gt)) * 1.0 / (gt.shape[0] * gt.shape[1]) 135 | 136 | mae = mae_sum / test_loader.size 137 | writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch) 138 | print('Epoch: {}, MAE: {}, bestMAE: {}, bestEpoch: {}.'.format(epoch, mae, best_mae, best_epoch)) 139 | if epoch == 1: 140 | best_mae = mae 141 | best_epoch = 1 142 | else: 143 | if mae < best_mae: 144 | best_mae = mae 145 | best_epoch = epoch 146 | torch.save(model.state_dict(), save_path + 'Net_epoch_best.pth') 147 | print('Save state_dict successfully! 
Best epoch:{}.'.format(epoch)) 148 | logging.info( 149 | '[Val Info]:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae)) 150 | 151 | if __name__ == '__main__': 152 | import argparse 153 | 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--epoch', type=int, default=180, help='epoch number') 156 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 157 | parser.add_argument('--batchsize', type=int, default=40, help='training batch size') 158 | parser.add_argument('--trainsize', type=int, default=416, help='training dataset size') 159 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 160 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 161 | parser.add_argument('--decay_epoch', type=int, default=60, help='every n epochs decay learning rate') 162 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints') 163 | parser.add_argument('--gpu_id', type=str, default='0,1,2,3', help='train use gpu') 164 | parser.add_argument('--train_root', type=str, default='', 165 | help='the training rgb images root') 166 | parser.add_argument('--val_root', type=str, default='', 167 | help='the test rgb images root') 168 | parser.add_argument('--save_path', type=str,default='', help='the path to save model and log') 169 | opt = parser.parse_args() 170 | 171 | 172 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 173 | print('USE GPU 0,1,2,3') 174 | cudnn.benchmark = True 175 | 176 | # build the model 177 | device_ids = [0,1,2,3] # if you want to use more gpus than 2, you shoule change it just like when use opt.gpu_id='1,2,6,8' , device_ids = [0,1,2,3] 178 | model = torch.nn.DataParallel(Network(channels=128), device_ids=device_ids) 179 | model = model.cuda(device=device_ids[0]) 180 | 181 | if opt.load is not None: 182 | model.load_state_dict(torch.load(opt.load)) 183 | print('load model from ', opt.load) 184 | 185 | optimizer = torch.optim.Adam(model.parameters(), opt.lr) 186 | save_path = opt.save_path 187 | if not os.path.exists(save_path): 188 | os.makedirs(save_path) 189 | 190 | # load data 191 | print('load data...') 192 | train_loader = get_loader(image_root=opt.train_root + 'Imgs/', 193 | gt_root=opt.train_root + 'GT/', 194 | edge_root=opt.train_root + 'Edge/', 195 | batchsize=opt.batchsize, 196 | trainsize=opt.trainsize, 197 | num_workers=8) 198 | val_loader = test_dataset(image_root=opt.val_root + 'Imgs/', 199 | gt_root=opt.val_root + 'GT/', 200 | testsize=opt.trainsize) 201 | total_step = len(train_loader) 202 | 203 | 204 | # logging 205 | logging.basicConfig(filename=save_path + 'log.log', 206 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', 207 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p') 208 | logging.info("Network-Train") 209 | logging.info('Config: epoch: {}; lr: {}; batchsize: {}; trainsize: {}; clip: {}; decay_rate: {}; load: {}; ' 210 | 'save_path: {}; decay_epoch: {}'.format(opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, 211 | opt.decay_rate, opt.load, save_path, opt.decay_epoch)) 212 | 213 | step = 0 214 | writer = SummaryWriter(save_path + 'summary') 215 | best_mae = 1 216 | best_epoch = 0 217 | 218 | # learning rate schedule 219 | cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=30, eta_min=1e-6) 220 | print("Start train...") 221 | for epoch in range(1, opt.epoch): 222 | 223 | cur_lr = adjust_lr(optimizer, opt.lr, epoch, 
opt.decay_rate, opt.decay_epoch) 224 | writer.add_scalar('learning_rate', cur_lr, global_step=epoch) 225 | 226 | cosine_schedule.step() 227 | writer.add_scalar('learning_rate', cosine_schedule.get_last_lr()[0], global_step=epoch) 228 | logging.info('>>> current lr: {}'.format(cosine_schedule.get_last_lr()[0])) 229 | 230 | train(train_loader, model, optimizer, epoch, save_path, writer) 231 | val(val_loader, model, epoch, save_path, writer) 232 | 233 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/environment.txt: -------------------------------------------------------------------------------- 1 | appdirs 1.4.4 2 | argon2-cffi 20.1.0 3 | async-generator 1.10 4 | attrs 20.3.0 5 | Babel 2.9.0 6 | backcall 0.2.0 7 | bce-python-sdk 0.8.49 8 | bleach 3.2.1 9 | bokeh 2.2.3 10 | certifi 2020.11.8 11 | cffi 1.14.3 12 | cfgv 3.2.0 13 | chardet 3.0.4 14 | click 7.1.2 15 | cloudpickle 1.6.0 16 | cycler 0.10.0 17 | Cython 0.29.21 18 | dataclasses 0.8 19 | decorator 4.4.2 20 | defusedxml 0.6.0 21 | distlib 0.3.1 22 | einops 0.4.1 23 | entrypoints 0.3 24 | filelock 3.0.12 25 | flake8 3.8.4 26 | Flask 1.1.2 27 | Flask-Babel 2.0.0 28 | future 0.18.2 29 | fvcore 0.1.5.post20221221 30 | graphviz 0.15 31 | horovod 0.20.0 32 | huggingface-hub 0.4.0 33 | identify 1.5.9 34 | idna 2.10 35 | importlib-metadata 2.0.0 36 | importlib-resources 3.3.0 37 | iopath 0.1.10 38 | ipykernel 5.3.4 39 | ipython 7.16.1 40 | ipython-genutils 0.2.0 41 | ipywidgets 7.5.1 42 | itsdangerous 1.1.0 43 | jedi 0.17.2 44 | Jinja2 2.11.2 45 | joblib 0.17.0 46 | json5 0.9.5 47 | jsonpatch 1.26 48 | jsonpointer 2.0 49 | jsonschema 3.2.0 50 | jupyter 1.0.0 51 | jupyter-client 6.1.7 52 | jupyter-console 6.2.0 53 | jupyter-core 4.6.3 54 | jupyterlab 2.2.9 55 | jupyterlab-pygments 0.1.2 56 | jupyterlab-server 1.2.0 57 | kiwisolver 1.3.1 58 | lib 4.0.0 59 | MarkupSafe 1.1.1 60 | matplotlib 3.3.3 61 | mccabe 0.6.1 62 | mistune 0.8.4 63 | nbclient 0.5.1 64 | nbconvert 6.0.7 65 | nbformat 5.0.8 66 | nest-asyncio 1.4.3 67 | networkx 2.5 68 | nodeenv 1.5.0 69 | notebook 6.1.5 70 | numpy 1.19.4 71 | opencv-python 4.4.0.46 72 | packaging 21.3 73 | pandas 1.1.4 74 | pandocfilters 1.4.3 75 | parso 0.7.1 76 | patsy 0.5.1 77 | pexpect 4.8.0 78 | pickleshare 0.7.5 79 | Pillow 8.0.1 80 | pip 21.3.1 81 | plotly 4.12.0 82 | portalocker 2.7.0 83 | pre-commit 2.8.2 84 | prometheus-client 0.8.0 85 | prompt-toolkit 3.0.8 86 | protobuf 3.14.0 87 | psutil 5.7.3 88 | ptyprocess 0.6.0 89 | pycodestyle 2.6.0 90 | pycparser 2.20 91 | pycryptodome 3.9.9 92 | pyflakes 2.2.0 93 | Pygments 2.7.2 94 | PyGObject 3.26.1 95 | pyparsing 2.4.7 96 | pyrsistent 0.17.3 97 | python-apt 1.6.5+ubuntu0.3 98 | python-dateutil 2.8.1 99 | pytz 2020.4 100 | PyYAML 5.3.1 101 | pyzmq 20.0.0 102 | qtconsole 4.7.7 103 | QtPy 1.9.0 104 | requests 2.25.0 105 | retrying 1.3.3 106 | scikit-learn 0.23.2 107 | scipy 1.5.4 108 | seaborn 0.11.0 109 | Send2Trash 1.5.0 110 | setuptools 41.0.0 111 | six 1.15.0 112 | statsmodels 0.12.1 113 | tabulate 0.8.10 114 | tensorboardX 2.1 115 | termcolor 1.1.0 116 | terminado 0.9.1 117 | testpath 0.4.4 118 | thop 0.1.1.post2209072238 119 | threadpoolctl 2.1.0 120 | timm 0.6.12 121 | toml 0.10.2 122 | torch 1.8.0+cu111 123 | torch-dct 0.1.6 124 | torchaudio 0.8.0 125 | torchfile 0.1.0 126 | torchvision 0.9.0+cu111 127 | tornado 6.1 128 | tqdm 4.64.1 129 | traitlets 4.3.3 130 | typing-extensions 3.7.4.3 131 | urllib3 1.26.2 132 | virtualenv 20.1.0 133 | visdom 0.1.8.9 134 | visualdl 2.0.4 135 | wcwidth 0.2.5 136 | 
webencodings 0.5.1 137 | websocket-client 0.57.0 138 | Werkzeug 1.0.1 139 | widgetsnbextension 3.5.1 140 | yacs 0.1.8 141 | zipp 3.4.0 142 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/FSEL_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import fvcore.nn.weight_init as weight_init 5 | from einops import rearrange 6 | import numbers 7 | 8 | 9 | def to_3d(x): 10 | return rearrange(x, 'b c h w -> b (h w) c') 11 | 12 | def to_4d(x,h,w): 13 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 14 | 15 | class BiasFree_LayerNorm(nn.Module): 16 | def __init__(self, normalized_shape): 17 | super(BiasFree_LayerNorm, self).__init__() 18 | if isinstance(normalized_shape, numbers.Integral): 19 | normalized_shape = (normalized_shape,) 20 | normalized_shape = torch.Size(normalized_shape) 21 | 22 | assert len(normalized_shape) == 1 23 | 24 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 25 | self.normalized_shape = normalized_shape 26 | 27 | def forward(self, x): 28 | sigma = x.var(-1, keepdim=True, unbiased=False) 29 | return x / torch.sqrt(sigma + 1e-5) * self.weight 30 | 31 | 32 | class WithBias_LayerNorm(nn.Module): 33 | def __init__(self, normalized_shape): 34 | super(WithBias_LayerNorm, self).__init__() 35 | if isinstance(normalized_shape, numbers.Integral): 36 | normalized_shape = (normalized_shape,) 37 | normalized_shape = torch.Size(normalized_shape) 38 | 39 | assert len(normalized_shape) == 1 40 | 41 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 42 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 43 | self.normalized_shape = normalized_shape 44 | 45 | def forward(self, x): 46 | mu = x.mean(-1, keepdim=True) 47 | sigma = x.var(-1, keepdim=True, unbiased=False) 48 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 49 | 50 | def initialize(self): 51 | weight_init(self) 52 | 53 | 54 | class LayerNorm(nn.Module): 55 | def __init__(self, dim, LayerNorm_type): 56 | super(LayerNorm, self).__init__() 57 | if LayerNorm_type == 'BiasFree': 58 | self.body = BiasFree_LayerNorm(dim) 59 | else: 60 | self.body = WithBias_LayerNorm(dim) 61 | 62 | def forward(self, x): 63 | h, w = x.shape[-2:] 64 | return to_4d(self.body(to_3d(x)), h, w) 65 | 66 | def initialize(self): 67 | weight_init(self) 68 | 69 | 70 | class FeedForward(nn.Module): 71 | def __init__(self, dim, ffn_expansion_factor, bias): 72 | super(FeedForward, self).__init__() 73 | 74 | self.dwconv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim, bias=bias) 75 | self.dwconv2 = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) 76 | self.project_out = nn.Conv2d(dim*4, dim, kernel_size=1, bias=bias) 77 | self.weight = nn.Sequential( 78 | nn.Conv2d(dim, dim // 16, 1, bias=True), 79 | nn.BatchNorm2d(dim // 16), 80 | nn.ReLU(True), 81 | nn.Conv2d(dim // 16, dim, 1, bias=True), 82 | nn.Sigmoid()) 83 | self.weight1 = nn.Sequential( 84 | nn.Conv2d(dim*2, dim // 16, 1, bias=True), 85 | nn.BatchNorm2d(dim // 16), 86 | nn.ReLU(True), 87 | nn.Conv2d(dim // 16, dim*2, 1, bias=True), 88 | nn.Sigmoid()) 89 | def forward(self, x): 90 | 91 | x_f = torch.abs(self.weight(torch.fft.fft2(x.float()).real)*torch.fft.fft2(x.float())) 92 | x_f_gelu = F.gelu(x_f) * x_f 93 | 94 | x_s = self.dwconv1(x) 95 | x_s_gelu = F.gelu(x_s) * x_s 96 | 97 | x_f = 
torch.fft.fft2(torch.cat((x_f_gelu,x_s_gelu),1)) 98 | x_f = torch.abs(torch.fft.ifft2(self.weight1(x_f.real) * x_f)) 99 | 100 | x_s = self.dwconv2(torch.cat((x_f_gelu,x_s_gelu),1)) 101 | out = self.project_out(torch.cat((x_f,x_s),1)) 102 | 103 | return out 104 | 105 | def initialize(self): 106 | weight_init(self) 107 | 108 | def custom_complex_normalization(input_tensor, dim=-1): 109 | real_part = input_tensor.real 110 | imag_part = input_tensor.imag 111 | norm_real = F.softmax(real_part, dim=dim) 112 | norm_imag = F.softmax(imag_part, dim=dim) 113 | 114 | normalized_tensor = torch.complex(norm_real, norm_imag) 115 | 116 | return normalized_tensor 117 | 118 | class Attention_F(nn.Module): 119 | def __init__(self, dim, num_heads, bias,): 120 | super(Attention_F, self).__init__() 121 | self.num_heads = num_heads 122 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 123 | self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias) 124 | self.weight = nn.Sequential( 125 | nn.Conv2d(dim, dim // 16, 1, bias=True), 126 | nn.BatchNorm2d(dim // 16), 127 | nn.ReLU(True), 128 | nn.Conv2d(dim // 16, dim, 1, bias=True), 129 | nn.Sigmoid()) 130 | def forward(self, x): 131 | b, c, h, w = x.shape 132 | 133 | q_f = torch.fft.fft2(x.float()) 134 | k_f = torch.fft.fft2(x.float()) 135 | v_f = torch.fft.fft2(x.float()) 136 | 137 | q_f = rearrange(q_f, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 138 | k_f = rearrange(k_f, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 139 | v_f = rearrange(v_f, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 140 | 141 | q_f = torch.nn.functional.normalize(q_f, dim=-1) 142 | k_f = torch.nn.functional.normalize(k_f, dim=-1) 143 | attn_f = (q_f @ k_f.transpose(-2, -1)) * self.temperature 144 | attn_f = custom_complex_normalization(attn_f, dim=-1) 145 | out_f = torch.abs(torch.fft.ifft2(attn_f @ v_f)) 146 | out_f = rearrange(out_f, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 147 | out_f_l = torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(x.float()).real)*torch.fft.fft2(x.float()))) 148 | out = self.project_out(torch.cat((out_f,out_f_l),1)) 149 | return out 150 | 151 | class Attention_S(nn.Module): 152 | def __init__(self, dim, num_heads, bias,): 153 | super(Attention_S, self).__init__() 154 | self.num_heads = num_heads 155 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 156 | 157 | self.qkv1conv_1 = nn.Conv2d(dim,dim,kernel_size=1) 158 | self.qkv2conv_1 = nn.Conv2d(dim, dim, kernel_size=1) 159 | self.qkv3conv_1 = nn.Conv2d(dim, dim, kernel_size=1) 160 | 161 | 162 | self.qkv1conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 163 | self.qkv2conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 164 | self.qkv3conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 165 | 166 | self.qkv1conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 167 | self.qkv2conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 168 | self.qkv3conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 169 | 170 | 171 | self.conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 172 | self.conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 173 | self.project_out = nn.Conv2d(dim*2, dim, 
kernel_size=1, bias=bias) 174 | 175 | def forward(self, x): 176 | b, c, h, w = x.shape 177 | q_s = torch.cat((self.qkv1conv_3(self.qkv1conv_1(x)),self.qkv1conv_5(self.qkv1conv_1(x))),1) 178 | k_s = torch.cat((self.qkv2conv_3(self.qkv2conv_1(x)),self.qkv2conv_5(self.qkv2conv_1(x))),1) 179 | v_s = torch.cat((self.qkv3conv_3(self.qkv3conv_1(x)),self.qkv3conv_5(self.qkv3conv_1(x))),1) 180 | 181 | q_s = rearrange(q_s, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 182 | k_s = rearrange(k_s, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 183 | v_s = rearrange(v_s, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 184 | 185 | q_s = torch.nn.functional.normalize(q_s, dim=-1) 186 | k_s = torch.nn.functional.normalize(k_s, dim=-1) 187 | attn_s = (q_s @ k_s.transpose(-2, -1)) * self.temperature 188 | attn_s = attn_s.softmax(dim=-1) 189 | out_s = (attn_s @ v_s) 190 | out_s = rearrange(out_s, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 191 | out_s_l = torch.cat((self.conv_3(x),self.conv_5(x)),1) 192 | out = self.project_out(torch.cat((out_s,out_s_l),1)) 193 | 194 | return out 195 | 196 | def initialize(self): 197 | weight_init(self) 198 | 199 | class Module1(nn.Module): 200 | def __init__(self, mode='dilation', dim=128, num_heads=8, ffn_expansion_factor=4, bias=False, 201 | LayerNorm_type='WithBias'): 202 | super(Module1, self).__init__() 203 | self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias) 204 | self.norm1 = LayerNorm(dim, LayerNorm_type) 205 | self.attn_S = Attention_S(dim, num_heads, bias) 206 | self.attn_F = Attention_F(dim, num_heads, bias) 207 | self.norm2 = LayerNorm(dim, LayerNorm_type) 208 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 209 | 210 | def forward(self, x): 211 | x = x + torch.add(self.attn_F(self.norm1(x)),self.attn_S(self.norm1(x))) 212 | x = x + self.ffn(self.norm2(x)) 213 | return x 214 | 215 | 216 | class ETB(nn.Module): # ETB (Entanglement Transformer Block) 217 | def __init__(self, in_channel, out_channel): 218 | super(ETB, self).__init__() 219 | self.conv1 = nn.Sequential( 220 | nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),nn.ReLU(True) 221 | ) 222 | self.reduce = nn.Sequential( 223 | nn.Conv2d(out_channel*2, out_channel, 1),nn.BatchNorm2d(out_channel),nn.ReLU(True) 224 | ) 225 | self.relu = nn.ReLU(True) 226 | self.Module1 = Module1(dim=out_channel) 227 | 228 | def forward(self, x): 229 | x0 = self.conv1(x) 230 | x_FT = self.Module1(x0) 231 | x = self.reduce(torch.cat((x0,x_FT),1))+x0 232 | return x 233 | 234 | class JDPM(nn.Module): # JDPM (Joint Domain Perception Module) 235 | def __init__(self, channels, in_channels): 236 | super(JDPM, self).__init__() 237 | 238 | self.conv1 = nn.Sequential( 239 | nn.Conv2d(channels, in_channels, 1), nn.BatchNorm2d(in_channels), nn.ReLU(True) 240 | ) 241 | 242 | self.Dconv3 = nn.Sequential( 243 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 244 | nn.Conv2d(in_channels, in_channels, 3, padding=3,dilation=3), nn.BatchNorm2d(in_channels), nn.ReLU(True) 245 | ) 246 | 247 | self.Dconv5 = nn.Sequential( 248 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 249 | nn.Conv2d(in_channels, in_channels, 3, padding=5,dilation=5), nn.BatchNorm2d(in_channels), nn.ReLU(True) 250 | ) 251 | self.Dconv7 = nn.Sequential( 252 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 253 | nn.Conv2d(in_channels, in_channels, 3, padding=7,dilation=7), nn.BatchNorm2d(in_channels), nn.ReLU(True) 254 | 
) 255 | self.Dconv9 = nn.Sequential( 256 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 257 | nn.Conv2d(in_channels, in_channels, 3, padding=9,dilation=9), nn.BatchNorm2d(in_channels),nn.ReLU(True) 258 | ) 259 | 260 | self.reduce = nn.Sequential( 261 | nn.Conv2d(in_channels * 5, in_channels, 1), nn.BatchNorm2d(in_channels),nn.ReLU(True) 262 | ) 263 | 264 | self.out = nn.Sequential( 265 | nn.Conv2d(in_channels, in_channels//2, kernel_size=3, padding=1), nn.BatchNorm2d(in_channels//2), nn.ReLU(True), 266 | nn.Conv2d(in_channels//2, 1, kernel_size=1) 267 | ) 268 | 269 | self.weight = nn.Sequential( 270 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 271 | nn.BatchNorm2d(in_channels // 16), 272 | nn.ReLU(True), 273 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 274 | nn.Sigmoid()) 275 | 276 | self.norm = nn.BatchNorm2d(in_channels) 277 | self.relu = nn.ReLU(True) 278 | 279 | def forward(self, F1): 280 | 281 | F1_input = self.conv1(F1) 282 | 283 | F1_3_s = self.Dconv3(F1_input) 284 | F1_3_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_3_s.float()).real)*torch.fft.fft2(F1_3_s.float()))))) 285 | F1_3 = torch.add(F1_3_s,F1_3_f) 286 | 287 | F1_5_s = self.Dconv5(F1_input + F1_3) 288 | F1_5_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_5_s.float()).real)*torch.fft.fft2(F1_5_s.float()))))) 289 | F1_5 = torch.add(F1_5_s, F1_5_f) 290 | 291 | F1_7_s = self.Dconv7(F1_input + F1_5) 292 | F1_7_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_7_s.float()).real)*torch.fft.fft2(F1_7_s.float()))))) 293 | F1_7 = torch.add(F1_7_s, F1_7_f) 294 | 295 | F1_9_s = self.Dconv9(F1_input + F1_7) 296 | F1_9_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_9_s.float()).real)*torch.fft.fft2(F1_9_s.float()))))) 297 | F1_9 = torch.add(F1_9_s, F1_9_f) 298 | 299 | F_out = self.out(self.reduce(torch.cat((F1_3,F1_5,F1_7,F1_9,F1_input),1)) + F1_input ) 300 | 301 | return F_out 302 | 303 | class DRP_1(nn.Module): # DRP (Dual-domain Reverse Parser) 304 | def __init__(self, in_channels, mid_channels): 305 | super(DRP_1, self).__init__() 306 | self.conv = nn.Sequential( 307 | nn.Conv2d(in_channels * 2, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 308 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 309 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), nn.ReLU(True) 310 | ) 311 | 312 | self.out = nn.Sequential( 313 | nn.Conv2d(in_channels * 2, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels),nn.ReLU(True), 314 | nn.Conv2d(mid_channels, 1, kernel_size=1) 315 | ) 316 | 317 | self.conv3 = nn.Sequential( 318 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 319 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 320 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 321 | ) 322 | 323 | self.weight = nn.Sequential( 324 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 325 | nn.BatchNorm2d(in_channels // 16), 326 | nn.ReLU(True), 327 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 328 | nn.Sigmoid()) 329 | 330 | self.norm = nn.BatchNorm2d(in_channels) 331 | self.relu = nn.ReLU(in_channels) 332 | 333 | def forward(self, X, prior_cam): 334 | prior_cam = 
F.interpolate(prior_cam, size=X.size()[2:], mode='bilinear',align_corners=True) 335 | 336 | FI = X 337 | 338 | yt = self.conv(torch.cat([FI, prior_cam.expand(-1, X.size()[1], -1, -1)], dim=1)) 339 | 340 | yt_s = self.conv3(yt) 341 | yt_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(yt.float()).real)*torch.fft.fft2(yt.float()))))) 342 | yt_out = torch.add(yt_s,yt_f) 343 | 344 | r_prior_cam_f = torch.abs(torch.fft.fft2(prior_cam)) 345 | r_prior_cam_f = -1 * (torch.sigmoid(r_prior_cam_f)) + 1 346 | r_prior_cam_s = -1 * (torch.sigmoid(prior_cam)) + 1 347 | r_prior_cam = r_prior_cam_s + r_prior_cam_f 348 | 349 | y_ra = r_prior_cam.expand(-1, X.size()[1], -1, -1).mul(FI) 350 | 351 | out = torch.cat([y_ra, yt_out], dim=1) # 2,128,48,48 352 | 353 | y = self.out(out) 354 | y = y + prior_cam 355 | return y 356 | 357 | class DRP_2(nn.Module): # DRP (Dual-domain Reverse Parser) 358 | def __init__(self, in_channels, mid_channels): 359 | super(DRP_2, self).__init__() 360 | self.conv = nn.Sequential( 361 | nn.Conv2d(in_channels * 3, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 362 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 363 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 364 | ) 365 | 366 | self.conv3 = nn.Sequential( 367 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 368 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 369 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 370 | ) 371 | 372 | self.out = nn.Sequential( 373 | nn.Conv2d(in_channels * 2, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels),nn.ReLU(True), 374 | nn.Conv2d(mid_channels, 1, kernel_size=1) 375 | ) 376 | 377 | self.weight = nn.Sequential( 378 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 379 | nn.BatchNorm2d(in_channels // 16), 380 | nn.ReLU(True), 381 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 382 | nn.Sigmoid()) 383 | 384 | self.norm = nn.BatchNorm2d(in_channels) 385 | self.relu = nn.ReLU(True) 386 | 387 | def forward(self, X, x1, prior_cam): 388 | prior_cam = F.interpolate(prior_cam, size=X.size()[2:], mode='bilinear',align_corners=True) 389 | x1_prior_cam = F.interpolate(x1, size=X.size()[2:], mode='bilinear', align_corners=True) 390 | FI = X 391 | 392 | yt = self.conv(torch.cat([FI, prior_cam.expand(-1, X.size()[1], -1, -1), x1_prior_cam.expand(-1, X.size()[1], -1, -1)],dim=1)) 393 | 394 | yt_s = self.conv3(yt) 395 | yt_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(yt.float()).real) * torch.fft.fft2(yt.float()))))) 396 | yt_out = torch.add(yt_s, yt_f) 397 | 398 | r_prior_cam_f = torch.abs(torch.fft.fft2(prior_cam)) 399 | r_prior_cam_f = -1 * (torch.sigmoid(r_prior_cam_f)) + 1 400 | r_prior_cam_s = -1 * (torch.sigmoid(prior_cam)) + 1 401 | r_prior_cam = r_prior_cam_s + r_prior_cam_f 402 | 403 | r1_prior_cam_f = torch.abs(torch.fft.fft2(x1_prior_cam)) 404 | r1_prior_cam_f = -1 * (torch.sigmoid(r1_prior_cam_f)) + 1 405 | r1_prior_cam_s = -1 * (torch.sigmoid(x1_prior_cam)) + 1 406 | r1_prior_cam = r1_prior_cam_s + r1_prior_cam_f 407 | 408 | r_prior_cam = r_prior_cam + r1_prior_cam 409 | 410 | y_ra = r_prior_cam.expand(-1, X.size()[1], -1, -1).mul(FI) 411 | 412 | out = torch.cat([y_ra, yt_out], dim=1) 413 | 414 | y = 
self.out(out) 415 | y = y + prior_cam + x1_prior_cam 416 | return y 417 | 418 | class DRP_3(nn.Module): # DRP (Dual-domain Reverse Parser) 419 | def __init__(self, in_channels, mid_channels): 420 | super(DRP_3, self).__init__() 421 | self.conv = nn.Sequential( 422 | nn.Conv2d(in_channels * 4, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 423 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 424 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 425 | ) 426 | 427 | self.conv3 = nn.Sequential( 428 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 429 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 430 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 431 | ) 432 | 433 | self.out = nn.Sequential( 434 | nn.Conv2d(in_channels * 2, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels),nn.ReLU(True), 435 | nn.Conv2d(mid_channels, 1, kernel_size=1) 436 | ) 437 | 438 | self.weight = nn.Sequential( 439 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 440 | nn.BatchNorm2d(in_channels // 16), 441 | nn.ReLU(True), 442 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 443 | nn.Sigmoid()) 444 | 445 | self.norm = nn.BatchNorm2d(in_channels) 446 | self.relu = nn.ReLU(True) 447 | 448 | def forward(self, X, x1,x2, prior_cam): 449 | prior_cam = F.interpolate(prior_cam, size=X.size()[2:], mode='bilinear',align_corners=True) # 450 | x1_prior_cam = F.interpolate(x1, size=X.size()[2:], mode='bilinear', align_corners=True) 451 | x2_prior_cam = F.interpolate(x2, size=X.size()[2:], mode='bilinear', align_corners=True) 452 | FI = X 453 | 454 | yt = self.conv(torch.cat([FI, prior_cam.expand(-1, X.size()[1], -1, -1), x1_prior_cam.expand(-1, X.size()[1], -1, -1),x2_prior_cam.expand(-1, X.size()[1], -1, -1)],dim=1)) 455 | 456 | yt_s = self.conv3(yt) 457 | yt_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(yt.float()).real) * torch.fft.fft2(yt.float()))))) 458 | yt_out = torch.add(yt_s, yt_f) 459 | 460 | r_prior_cam_f = torch.abs(torch.fft.fft2(prior_cam)) 461 | r_prior_cam_f = -1 * (torch.sigmoid(r_prior_cam_f)) + 1 462 | r_prior_cam_s = -1 * (torch.sigmoid(prior_cam)) + 1 463 | r_prior_cam = r_prior_cam_s + r_prior_cam_f 464 | 465 | r1_prior_cam_f = torch.abs(torch.fft.fft2(x1_prior_cam)) 466 | r1_prior_cam_f = -1 * (torch.sigmoid(r1_prior_cam_f)) + 1 467 | r1_prior_cam_s = -1 * (torch.sigmoid(x1_prior_cam)) + 1 468 | r1_prior_cam1 = r1_prior_cam_s + r1_prior_cam_f 469 | 470 | r2_prior_cam_f = torch.abs(torch.fft.fft2(x2_prior_cam)) 471 | r2_prior_cam_f = -1 * (torch.sigmoid(r2_prior_cam_f)) + 1 472 | r2_prior_cam_s = -1 * (torch.sigmoid(x2_prior_cam)) + 1 473 | r1_prior_cam2 = r2_prior_cam_s + r2_prior_cam_f 474 | 475 | r_prior_cam = r_prior_cam + r1_prior_cam1 + r1_prior_cam2 476 | 477 | y_ra = r_prior_cam.expand(-1, X.size()[1], -1, -1).mul(FI) 478 | 479 | out = torch.cat([y_ra, yt_out], dim=1) 480 | 481 | y = self.out(out) 482 | 483 | y = y + prior_cam + x1_prior_cam + x2_prior_cam 484 | 485 | return y 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/Network_PVT.py: 
-------------------------------------------------------------------------------- 1 | import timm 2 | from lib.pvt_v2 import pvt_v2_b4 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from lib.FSEL_modules import DRP_1, DRP_2, DRP_3, JDPM, ETB 7 | 8 | 9 | 10 | ''' 11 | backbone: PVT_v2_b4 12 | ''' 13 | 14 | 15 | class Network(nn.Module): 16 | def __init__(self, channels=128): 17 | super(Network, self).__init__() 18 | self.shared_encoder = pvt_v2_b4() 19 | pretrained_dict = torch.load('/opt/data/private/Syg/COD/pre_train_pth/pvt_v2_b4.pth') 20 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.shared_encoder.state_dict()} 21 | self.shared_encoder.load_state_dict(pretrained_dict) 22 | self.dePixelShuffle = torch.nn.PixelShuffle(2) 23 | self.up = nn.Sequential( 24 | nn.Conv2d(channels//4, channels, kernel_size=1),nn.BatchNorm2d(channels), 25 | nn.Conv2d(channels, channels, kernel_size=3, padding=1),nn.BatchNorm2d(channels),nn.ReLU(True) 26 | ) 27 | 28 | self.ETB_5 = ETB(512+channels, channels) 29 | self.ETB_4 = ETB(320+channels, channels) 30 | self.ETB_3 = ETB(128+channels, channels) 31 | self.ETB_2 = ETB(64+channels, channels) 32 | 33 | self.JDPM = JDPM(512, channels) 34 | 35 | self.DRP_1 = DRP_1(channels, channels) 36 | self.DRP_2 = DRP_2(channels, channels) 37 | self.DRP_3 = DRP_3(channels,channels) 38 | 39 | def forward(self, x): 40 | image = x 41 | 42 | en_feats = self.shared_encoder(x) 43 | x4, x3, x2, x1 = en_feats 44 | 45 | 46 | p1 = self.JDPM(x4) 47 | x5_4 = p1 48 | x5_4_1 = x5_4.expand(-1, 128, -1, -1) 49 | 50 | x4 = self.ETB_5(torch.cat((x4,x5_4_1),1)) 51 | x4_up = self.up(self.dePixelShuffle(x4)) 52 | 53 | x3 = self.ETB_4(torch.cat((x3,x4_up),1)) 54 | x3_up = self.up(self.dePixelShuffle(x3)) 55 | 56 | x2 = self.ETB_3(torch.cat((x2,x3_up),1)) 57 | x2_up = self.up(self.dePixelShuffle(x2)) 58 | 59 | 60 | x1 = self.ETB_2(torch.cat((x1,x2_up),1)) 61 | 62 | 63 | x4 = self.DRP_1(x4,x5_4) 64 | x3 = self.DRP_1(x3,x4) 65 | x2 = self.DRP_2(x2,x3,x4) 66 | x1 = self.DRP_3(x1,x2,x3,x4) 67 | 68 | 69 | p0 = F.interpolate(p1, size=image.size()[2:], mode='bilinear', align_corners=True) 70 | f4 = F.interpolate(x4, size=image.size()[2:], mode='bilinear', align_corners=True) 71 | f3 = F.interpolate(x3, size=image.size()[2:], mode='bilinear', align_corners=True) 72 | f2 = F.interpolate(x2, size=image.size()[2:], mode='bilinear', align_corners=True) 73 | f1 = F.interpolate(x1, size=image.size()[2:], mode='bilinear', align_corners=True) 74 | 75 | 76 | return p0, f4, f3, f2, f1 77 | 78 | 79 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/Network_Res2Net.py: -------------------------------------------------------------------------------- 1 | 2 | import timm 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from lib.FSEL_modules import DRP_1, DRP_2, DRP_3, JDPM, ETB 7 | from lib.Res2Net_v1b import res2net50_v1b_26w_4s 8 | 9 | 10 | ''' 11 | backbone: Res2Net50 12 | ''' 13 | 14 | 15 | class Network(nn.Module): 16 | def __init__(self, channels=128): 17 | super(Network, self).__init__() 18 | self.shared_encoder = res2net50_v1b_26w_4s() 19 | 20 | pretrained_dict = torch.load('/opt/data/private/Syg/COD/pre_train_pth/res2net50_v1b_26w_4s-3cf99910.pth') 21 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.shared_encoder.state_dict()} 22 | self.shared_encoder.load_state_dict(pretrained_dict) 23 | 24 | self.dePixelShuffle = torch.nn.PixelShuffle(2) 
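# PixelShuffle(2) rearranges (B, C, H, W) -> (B, C/4, 2H, 2W); the self.up block
# below then projects the C/4 channels back to C, so each decoder stage receives a
# 2x-upsampled feature map with the original channel count before entering its ETB.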
25 | 26 | self.up = nn.Sequential( 27 | nn.Conv2d(channels//4, channels, kernel_size=1),nn.BatchNorm2d(channels), 28 | nn.Conv2d(channels, channels, kernel_size=3, padding=1),nn.BatchNorm2d(channels),nn.ReLU(True) 29 | ) 30 | 31 | self.ETB_5 = ETB(2048+channels, channels) 32 | self.ETB_4 = ETB(1024+channels, channels) 33 | self.ETB_3 = ETB(512+channels, channels) 34 | self.ETB_2 = ETB(256+channels, channels) 35 | 36 | self.JDPM = JDPM(2048, channels) 37 | 38 | self.DRP_1 = DRP_1(channels, channels) 39 | self.DRP_2 = DRP_2(channels, channels) 40 | self.DRP_3 = DRP_3(channels,channels) 41 | 42 | def forward(self, x): 43 | image = x 44 | 45 | x = self.shared_encoder.conv1(x) 46 | x = self.shared_encoder.bn1(x) 47 | x = self.shared_encoder.relu(x) 48 | x = self.shared_encoder.maxpool(x) # bs, 64, 88, 88 49 | x1 = self.shared_encoder.layer1(x) # bs, 256, 88, 88 50 | x2 = self.shared_encoder.layer2(x1) # bs, 512, 44, 44 51 | x3 = self.shared_encoder.layer3(x2) # bs, 1024, 22, 22 52 | x4 = self.shared_encoder.layer4(x3) 53 | 54 | 55 | p1 = self.JDPM(x4) 56 | x5_4 = p1 57 | x5_4_1 = x5_4.expand(-1, 128, -1, -1) 58 | 59 | x4 = self.ETB_5(torch.cat((x4,x5_4_1),1)) 60 | x4_up = self.up(self.dePixelShuffle(x4)) 61 | 62 | 63 | x3 = self.ETB_4(torch.cat((x3,x4_up),1)) 64 | x3_up = self.up(self.dePixelShuffle(x3)) 65 | 66 | 67 | x2 = self.ETB_3(torch.cat((x2,x3_up),1)) 68 | x2_up = self.up(self.dePixelShuffle(x2)) 69 | 70 | 71 | x1 = self.ETB_2(torch.cat((x1,x2_up),1)) 72 | 73 | x4 = self.DRP_1(x4,x5_4) 74 | x3 = self.DRP_1(x3,x4) 75 | x2 = self.DRP_2(x2,x3,x4) 76 | x1 = self.DRP_3(x1,x2,x3,x4) 77 | 78 | p0 = F.interpolate(p1, size=image.size()[2:], mode='bilinear', align_corners=True) 79 | f4 = F.interpolate(x4, size=image.size()[2:], mode='bilinear', align_corners=True) 80 | f3 = F.interpolate(x3, size=image.size()[2:], mode='bilinear', align_corners=True) 81 | f2 = F.interpolate(x2, size=image.size()[2:], mode='bilinear', align_corners=True) 82 | f1 = F.interpolate(x1, size=image.size()[2:], mode='bilinear', align_corners=True) 83 | 84 | return p0, f4, f3, f2, f1 85 | 86 | 87 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/Network_ResNet.py: -------------------------------------------------------------------------------- 1 | 2 | import timm 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from lib.FSEL_modules import DRP_1, DRP_2, DRP_3, JDPM, ETB 7 | 8 | 9 | 10 | ''' 11 | backbone: resnet50 12 | ''' 13 | 14 | 15 | class Network(nn.Module): 16 | # resnet based encoder decoder 17 | def __init__(self, channels=128): 18 | super(Network, self).__init__() 19 | self.shared_encoder = timm.create_model(model_name="resnet50", pretrained=False, in_chans=3, features_only=True) 20 | 21 | self.dePixelShuffle = torch.nn.PixelShuffle(2) 22 | 23 | self.up = nn.Sequential( 24 | nn.Conv2d(channels//4, channels, kernel_size=1),nn.BatchNorm2d(channels), 25 | nn.Conv2d(channels, channels, kernel_size=3, padding=1),nn.BatchNorm2d(channels),nn.ReLU(True) 26 | ) 27 | self.channel = channels 28 | self.ETB_5 = ETB(2048+channels, channels) 29 | self.ETB_4 = ETB(1024+channels, channels) 30 | self.ETB_3 = ETB(512+channels, channels) 31 | self.ETB_2 = ETB(256+channels, channels) 32 | 33 | self.JDPM = JDPM(2048, channels) 34 | 35 | self.DRP_1 = DRP_1(channels, channels) 36 | self.DRP_2 = DRP_2(channels, channels) 37 | self.DRP_3 = DRP_3(channels,channels) 38 | 39 | def forward(self, x): 40 | image = x 41 | # Feature Extraction 42 | en_feats = 
self.shared_encoder(x) 43 | x0, x1, x2, x3, x4 = en_feats 44 | 45 | p1 = self.JDPM(x4) 46 | x5_4 = p1 47 | x5_4_1 = x5_4.expand(-1, self.channel, -1, -1) 48 | 49 | x4 = self.ETB_5(torch.cat((x4,x5_4_1),1)) 50 | x4_up = self.up(self.dePixelShuffle(x4)) 51 | 52 | x3 = self.ETB_4(torch.cat((x3,x4_up),1)) 53 | x3_up = self.up(self.dePixelShuffle(x3)) 54 | 55 | x2 = self.ETB_3(torch.cat((x2,x3_up),1)) 56 | x2_up = self.up(self.dePixelShuffle(x2)) 57 | 58 | x1 = self.ETB_2(torch.cat((x1,x2_up),1)) 59 | 60 | x4 = self.DRP_1(x4,x5_4) 61 | x3 = self.DRP_1(x3,x4) 62 | x2 = self.DRP_2(x2,x3,x4) 63 | x1 = self.DRP_3(x1,x2,x3,x4) 64 | 65 | p0 = F.interpolate(p1, size=image.size()[2:], mode='bilinear', align_corners=True) 66 | f4 = F.interpolate(x4, size=image.size()[2:], mode='bilinear', align_corners=True) 67 | f3 = F.interpolate(x3, size=image.size()[2:], mode='bilinear', align_corners=True) 68 | f2 = F.interpolate(x2, size=image.size()[2:], mode='bilinear', align_corners=True) 69 | f1 = F.interpolate(x1, size=image.size()[2:], mode='bilinear', align_corners=True) 70 | 71 | return p0, f4, f3, f2, f1 72 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/Res2Net_v1b.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 8 | 9 | model_urls = { 10 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 11 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 12 | } 13 | 14 | 15 | class Bottle2neck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 19 | """ Constructor 20 | Args: 21 | inplanes: input channel dimensionality 22 | planes: output channel dimensionality 23 | stride: conv stride. Replaces pooling layer. 24 | downsample: None when stride = 1 25 | baseWidth: basic width of conv3x3 26 | scale: number of scale. 27 | type: 'normal': normal set. 'stage': first block of a new stage. 
28 | """ 29 | super(Bottle2neck, self).__init__() 30 | 31 | width = int(math.floor(planes * (baseWidth / 64.0))) 32 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(width * scale) 34 | 35 | if scale == 1: 36 | self.nums = 1 37 | else: 38 | self.nums = scale - 1 39 | if stype == 'stage': 40 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 41 | convs = [] 42 | bns = [] 43 | for i in range(self.nums): 44 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) 45 | bns.append(nn.BatchNorm2d(width)) 46 | self.convs = nn.ModuleList(convs) 47 | self.bns = nn.ModuleList(bns) 48 | 49 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stype = stype 55 | self.scale = scale 56 | self.width = width 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | spx = torch.split(out, self.width, 1) 66 | for i in range(self.nums): 67 | if i == 0 or self.stype == 'stage': 68 | sp = spx[i] 69 | else: 70 | sp = sp + spx[i] 71 | sp = self.convs[i](sp) 72 | sp = self.relu(self.bns[i](sp)) 73 | if i == 0: 74 | out = sp 75 | else: 76 | out = torch.cat((out, sp), 1) 77 | if self.scale != 1 and self.stype == 'normal': 78 | out = torch.cat((out, spx[self.nums]), 1) 79 | elif self.scale != 1 and self.stype == 'stage': 80 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Res2Net(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Sequential( 102 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 103 | nn.BatchNorm2d(32), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 106 | nn.BatchNorm2d(32), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 109 | ) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | nn.AvgPool2d(kernel_size=stride, stride=stride, 132 | ceil_mode=True, count_include_pad=False), 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=1, 
bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 140 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | x = self.bn1(x) 150 | x = self.relu(x) 151 | x = self.maxpool(x) 152 | 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | 158 | x = self.avgpool(x) 159 | x = x.view(x.size(0), -1) 160 | x = self.fc(x) 161 | 162 | return x 163 | 164 | 165 | def res2net50_v1b(pretrained=False, **kwargs): 166 | """Constructs a Res2Net-50_v1b lib. 167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 168 | Args: 169 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 170 | """ 171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 174 | return model 175 | 176 | 177 | def res2net101_v1b(pretrained=False, **kwargs): 178 | """Constructs a Res2Net-50_v1b_26w_4s lib. 179 | Args: 180 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 181 | """ 182 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 183 | if pretrained: 184 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 185 | return model 186 | 187 | 188 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 189 | """Constructs a Res2Net-50_v1b_26w_4s lib. 190 | Args: 191 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 192 | """ 193 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 194 | if pretrained: 195 | model_state = torch.load('/media/nercms/NERCMS/GepengJi/Medical_Seqmentation/CRANet/models/res2net50_v1b_26w_4s-3cf99910.pth') 196 | model.load_state_dict(model_state) 197 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 198 | return model 199 | 200 | 201 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 202 | """Constructs a Res2Net-50_v1b_26w_4s lib. 203 | Args: 204 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 205 | """ 206 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 207 | if pretrained: 208 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 209 | return model 210 | 211 | 212 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 213 | """Constructs a Res2Net-50_v1b_26w_4s lib. 
214 | Args: 215 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 216 | """ 217 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 220 | return model 221 | 222 | 223 | if __name__ == '__main__': 224 | images = torch.rand(1, 3, 352, 352).cuda(0) 225 | model = res2net50_v1b_26w_4s(pretrained=False) 226 | model = model.cuda(0) 227 | print(model(images).size()) 228 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/lib/__init__.py -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/__pycache__/GatedConv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/lib/__pycache__/GatedConv.cpython-38.pyc -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/__pycache__/Modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/lib/__pycache__/Modules.cpython-38.pyc -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/__pycache__/Network.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/lib/__pycache__/Network.cpython-38.pyc -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/lib/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib/pvt_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | import math 10 | 11 | 12 | class Mlp(nn.Module): 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.dwconv = DWConv(hidden_features) 19 | self.act = act_layer() 20 | self.fc2 = nn.Linear(hidden_features, out_features) 21 | self.drop = nn.Dropout(drop) 22 | self.linear = linear 23 | if self.linear: 24 | self.relu = nn.ReLU(inplace=True) 25 | self.apply(self._init_weights) 26 | 27 | def _init_weights(self, m): 28 | if isinstance(m, nn.Linear): 29 | trunc_normal_(m.weight, std=.02) 30 | if isinstance(m, nn.Linear) and m.bias is not 
None: 31 | nn.init.constant_(m.bias, 0) 32 | elif isinstance(m, nn.LayerNorm): 33 | nn.init.constant_(m.bias, 0) 34 | nn.init.constant_(m.weight, 1.0) 35 | elif isinstance(m, nn.Conv2d): 36 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | fan_out //= m.groups 38 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | 42 | def forward(self, x, H, W): 43 | x = self.fc1(x) 44 | if self.linear: 45 | x = self.relu(x) 46 | x = self.dwconv(x, H, W) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 56 | super().__init__() 57 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 58 | 59 | self.dim = dim 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | self.scale = qk_scale or head_dim ** -0.5 63 | 64 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 65 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 66 | self.attn_drop = nn.Dropout(attn_drop) 67 | self.proj = nn.Linear(dim, dim) 68 | self.proj_drop = nn.Dropout(proj_drop) 69 | 70 | self.linear = linear 71 | self.sr_ratio = sr_ratio 72 | if not linear: 73 | if sr_ratio > 1: 74 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 75 | self.norm = nn.LayerNorm(dim) 76 | else: 77 | self.pool = nn.AdaptiveAvgPool2d(7) 78 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 79 | self.norm = nn.LayerNorm(dim) 80 | self.act = nn.GELU() 81 | self.apply(self._init_weights) 82 | 83 | def _init_weights(self, m): 84 | if isinstance(m, nn.Linear): 85 | trunc_normal_(m.weight, std=.02) 86 | if isinstance(m, nn.Linear) and m.bias is not None: 87 | nn.init.constant_(m.bias, 0) 88 | elif isinstance(m, nn.LayerNorm): 89 | nn.init.constant_(m.bias, 0) 90 | nn.init.constant_(m.weight, 1.0) 91 | elif isinstance(m, nn.Conv2d): 92 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 93 | fan_out //= m.groups 94 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 95 | if m.bias is not None: 96 | m.bias.data.zero_() 97 | 98 | def forward(self, x, H, W): 99 | B, N, C = x.shape 100 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 101 | 102 | if not self.linear: 103 | if self.sr_ratio > 1: 104 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 105 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 106 | x_ = self.norm(x_) 107 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 108 | else: 109 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 110 | else: 111 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 112 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 113 | x_ = self.norm(x_) 114 | x_ = self.act(x_) 115 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 116 | k, v = kv[0], kv[1] 117 | 118 | attn = (q @ k.transpose(-2, -1)) * self.scale 119 | attn = attn.softmax(dim=-1) 120 | attn = self.attn_drop(attn) 121 | 122 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 123 | x = self.proj(x) 124 | x = self.proj_drop(x) 125 | 126 | return x 127 | 128 | 129 | class Block(nn.Module): 130 | 131 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 132 | drop_path=0., 
act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): 133 | super().__init__() 134 | self.norm1 = norm_layer(dim) 135 | self.attn = Attention( 136 | dim, 137 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 138 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) 139 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 140 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 141 | self.norm2 = norm_layer(dim) 142 | mlp_hidden_dim = int(dim * mlp_ratio) 143 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) 144 | 145 | self.apply(self._init_weights) 146 | 147 | def _init_weights(self, m): 148 | if isinstance(m, nn.Linear): 149 | trunc_normal_(m.weight, std=.02) 150 | if isinstance(m, nn.Linear) and m.bias is not None: 151 | nn.init.constant_(m.bias, 0) 152 | elif isinstance(m, nn.LayerNorm): 153 | nn.init.constant_(m.bias, 0) 154 | nn.init.constant_(m.weight, 1.0) 155 | elif isinstance(m, nn.Conv2d): 156 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 157 | fan_out //= m.groups 158 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 159 | if m.bias is not None: 160 | m.bias.data.zero_() 161 | 162 | def forward(self, x, H, W): 163 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 164 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 165 | 166 | return x 167 | 168 | 169 | class OverlapPatchEmbed(nn.Module): 170 | """ Image to Patch Embedding 171 | """ 172 | 173 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 174 | super().__init__() 175 | 176 | img_size = to_2tuple(img_size) 177 | patch_size = to_2tuple(patch_size) 178 | 179 | assert max(patch_size) > stride, "Set larger patch_size than stride" 180 | 181 | self.img_size = img_size 182 | self.patch_size = patch_size 183 | self.H, self.W = img_size[0] // stride, img_size[1] // stride 184 | self.num_patches = self.H * self.W 185 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 186 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 187 | self.norm = nn.LayerNorm(embed_dim) 188 | 189 | self.apply(self._init_weights) 190 | 191 | def _init_weights(self, m): 192 | if isinstance(m, nn.Linear): 193 | trunc_normal_(m.weight, std=.02) 194 | if isinstance(m, nn.Linear) and m.bias is not None: 195 | nn.init.constant_(m.bias, 0) 196 | elif isinstance(m, nn.LayerNorm): 197 | nn.init.constant_(m.bias, 0) 198 | nn.init.constant_(m.weight, 1.0) 199 | elif isinstance(m, nn.Conv2d): 200 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 201 | fan_out //= m.groups 202 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 203 | if m.bias is not None: 204 | m.bias.data.zero_() 205 | 206 | def forward(self, x): 207 | x = self.proj(x) 208 | _, _, H, W = x.shape 209 | x = x.flatten(2).transpose(1, 2) 210 | x = self.norm(x) 211 | 212 | return x, H, W 213 | 214 | 215 | class PyramidVisionTransformerV2(nn.Module): 216 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 217 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 218 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 219 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False): 220 | super().__init__() 221 | self.num_classes = num_classes 222 | self.depths = depths 223 | self.num_stages = 
num_stages 224 | 225 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 226 | cur = 0 227 | 228 | for i in range(num_stages): 229 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 230 | patch_size=7 if i == 0 else 3, 231 | stride=4 if i == 0 else 2, 232 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 233 | embed_dim=embed_dims[i]) 234 | 235 | block = nn.ModuleList([Block( 236 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale, 237 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, 238 | sr_ratio=sr_ratios[i], linear=linear) 239 | for j in range(depths[i])]) 240 | norm = norm_layer(embed_dims[i]) 241 | cur += depths[i] 242 | 243 | setattr(self, f"patch_embed{i + 1}", patch_embed) 244 | setattr(self, f"block{i + 1}", block) 245 | setattr(self, f"norm{i + 1}", norm) 246 | 247 | # classification head 248 | #self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 249 | 250 | self.apply(self._init_weights) 251 | self.initialize() 252 | 253 | def _init_weights(self, m): 254 | if isinstance(m, nn.Linear): 255 | trunc_normal_(m.weight, std=.02) 256 | if isinstance(m, nn.Linear) and m.bias is not None: 257 | nn.init.constant_(m.bias, 0) 258 | elif isinstance(m, nn.LayerNorm): 259 | nn.init.constant_(m.bias, 0) 260 | nn.init.constant_(m.weight, 1.0) 261 | elif isinstance(m, nn.Conv2d): 262 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 263 | fan_out //= m.groups 264 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 265 | if m.bias is not None: 266 | m.bias.data.zero_() 267 | 268 | def init_weights(self, pretrained=None): 269 | if isinstance(pretrained, str): 270 | logger = 1 271 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 272 | 273 | def reset_drop_path(self, drop_path_rate): 274 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 275 | cur = 0 276 | for i in range(self.depths[0]): 277 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 278 | 279 | cur += self.depths[0] 280 | for i in range(self.depths[1]): 281 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 282 | 283 | cur += self.depths[1] 284 | for i in range(self.depths[2]): 285 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 286 | 287 | cur += self.depths[2] 288 | for i in range(self.depths[3]): 289 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 290 | 291 | def freeze_patch_emb(self): 292 | self.patch_embed1.requires_grad = False 293 | 294 | @torch.jit.ignore 295 | def no_weight_decay(self): 296 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 297 | 298 | def get_classifier(self): 299 | return self.head 300 | 301 | def reset_classifier(self, num_classes, global_pool=''): 302 | self.num_classes = num_classes 303 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 304 | 305 | def forward_features(self, x): 306 | B = x.shape[0] 307 | outs = [] 308 | 309 | for i in range(self.num_stages): 310 | patch_embed = getattr(self, f"patch_embed{i + 1}") 311 | block = getattr(self, f"block{i + 1}") 312 | norm = getattr(self, f"norm{i + 1}") 313 | x, H, W = patch_embed(x) 314 | for blk in block: 315 | x = blk(x, H, W) 316 | x = norm(x) 317 | if i != self.num_stages: 318 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 
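# Added note (not part of the original source): since `i` ranges over range(self.num_stages),
# the `if i != self.num_stages` guard above is always true, so every stage's token sequence is
# reshaped from (B, N, C) back to an NCHW feature map before being collected. E.g., for
# pvt_v2_b4 (embed_dims=[64, 128, 320, 512]) and a 352x352 input, the collected maps are
# (B, 64, 88, 88), (B, 128, 44, 44), (B, 320, 22, 22) and (B, 512, 11, 11); the list is
# returned reversed (outs[::-1]) so the deepest stage comes first, matching the
# `x4, x3, x2, x1 = en_feats` unpacking used by the FSEL decoders.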
319 | outs.append(x) 320 | return outs[::-1] 321 | 322 | def forward(self, x): 323 | x = self.forward_features(x) 324 | #x = self.head(x) 325 | 326 | return x 327 | 328 | def initialize(self): 329 | pass 330 | class DWConv(nn.Module): 331 | def __init__(self, dim=768): 332 | super(DWConv, self).__init__() 333 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 334 | 335 | def forward(self, x, H, W): 336 | B, N, C = x.shape 337 | x = x.transpose(1, 2).view(B, C, H, W) 338 | x = self.dwconv(x) 339 | x = x.flatten(2).transpose(1, 2) 340 | 341 | return x 342 | 343 | 344 | def _conv_filter(state_dict, patch_size=16): 345 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 346 | out_dict = {} 347 | for k, v in state_dict.items(): 348 | if 'patch_embed.proj.weight' in k: 349 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 350 | out_dict[k] = v 351 | 352 | return out_dict 353 | 354 | 355 | @register_model 356 | def pvt_v2_b0(pretrained=False, **kwargs): 357 | model = PyramidVisionTransformerV2( 358 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 359 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 360 | **kwargs) 361 | model.default_cfg = _cfg() 362 | 363 | return model 364 | 365 | 366 | @register_model 367 | def pvt_v2_b1(pretrained=False, **kwargs): 368 | model = PyramidVisionTransformerV2( 369 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 370 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 371 | **kwargs) 372 | model.default_cfg = _cfg() 373 | 374 | return model 375 | 376 | 377 | @register_model 378 | def pvt_v2_b2(pretrained=False, **kwargs): 379 | model = PyramidVisionTransformerV2( 380 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 381 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) 382 | model.default_cfg = _cfg() 383 | 384 | return model 385 | 386 | 387 | @register_model 388 | def pvt_v2_b3(pretrained=False, **kwargs): 389 | model = PyramidVisionTransformerV2( 390 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 391 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 392 | **kwargs) 393 | model.default_cfg = _cfg() 394 | 395 | return model 396 | 397 | 398 | @register_model 399 | def pvt_v2_b4(pretrained=False, **kwargs): 400 | model = PyramidVisionTransformerV2( 401 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 402 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 403 | **kwargs) 404 | model.default_cfg = _cfg() 405 | 406 | return model 407 | 408 | 409 | @register_model 410 | def pvt_v2_b5(pretrained=False, **kwargs): 411 | model = PyramidVisionTransformerV2( 412 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 413 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 414 | **kwargs) 415 | model.default_cfg = _cfg() 416 | 417 | return model 418 | 419 | 420 | @register_model 421 | def pvt_v2_b2_li(pretrained=False, **kwargs): 422 | model = PyramidVisionTransformerV2( 423 | patch_size=4, embed_dims=[64, 
128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 424 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], linear=True, **kwargs) 425 | model.default_cfg = _cfg() 426 | 427 | return model 428 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib_initial/Network_PVT_initial.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from lib.pvt_v2 import pvt_v2_b4 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from lib_initial.module_FSEL import Module3_1, Module3_2, Module3_3, Module2,Module1_res 7 | 8 | 9 | ''' 10 | backbone: PVT 11 | ''' 12 | 13 | 14 | class Network(nn.Module): 15 | # PVT based encoder decoder 16 | def __init__(self, channels=128): 17 | super(Network, self).__init__() 18 | 19 | 20 | self.shared_encoder = pvt_v2_b4() 21 | 22 | pretrained_dict = torch.load('/COD/pre_train_pth/pvt_v2_b4.pth') 23 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.shared_encoder.state_dict()} 24 | self.shared_encoder.load_state_dict(pretrained_dict) 25 | 26 | self.dePixelShuffle = torch.nn.PixelShuffle(2) 27 | 28 | self.up = nn.Sequential( 29 | nn.Conv2d(channels//4, channels, kernel_size=1),nn.BatchNorm2d(channels), 30 | nn.Conv2d(channels, channels, kernel_size=3, padding=1),nn.BatchNorm2d(channels),nn.ReLU(True) 31 | ) 32 | 33 | self.Module1_5 = Module1_res(512+channels, channels) 34 | self.Module1_4 = Module1_res(320+channels, channels) 35 | self.Module1_3 = Module1_res(128+channels, channels) 36 | self.Module1_2 = Module1_res(64+channels, channels) 37 | 38 | self.Module2 = Module2(512, channels) 39 | 40 | self.Module3_1 = Module3_1(channels, channels) 41 | self.Module3_2 = Module3_2(channels, channels) 42 | self.Module3_3 = Module3_3(channels,channels) 43 | 44 | def forward(self, x): 45 | image = x 46 | 47 | en_feats = self.shared_encoder(x) 48 | x4, x3, x2, x1 = en_feats 49 | 50 | 51 | p1 = self.Module2(x4) 52 | x5_4 = p1 53 | x5_4_1 = x5_4.expand(-1, 128, -1, -1) 54 | 55 | x4 = self.Module1_5(torch.cat((x4,x5_4_1),1)) 56 | x4_up = self.up(self.dePixelShuffle(x4)) 57 | 58 | x3 = self.Module1_4(torch.cat((x3,x4_up),1)) 59 | x3_up = self.up(self.dePixelShuffle(x3)) 60 | 61 | x2 = self.Module1_3(torch.cat((x2,x3_up),1)) 62 | x2_up = self.up(self.dePixelShuffle(x2)) 63 | 64 | 65 | x1 = self.Module1_2(torch.cat((x1,x2_up),1)) 66 | 67 | 68 | x4 = self.Module3_1(x4,x5_4) 69 | x3 = self.Module3_1(x3,x4) 70 | x2 = self.Module3_2(x2,x3,x4) 71 | x1 = self.Module3_3(x1,x2,x3,x4) 72 | 73 | 74 | p0 = F.interpolate(p1, size=image.size()[2:], mode='bilinear', align_corners=True) 75 | f4 = F.interpolate(x4, size=image.size()[2:], mode='bilinear', align_corners=True) 76 | f3 = F.interpolate(x3, size=image.size()[2:], mode='bilinear', align_corners=True) 77 | f2 = F.interpolate(x2, size=image.size()[2:], mode='bilinear', align_corners=True) 78 | f1 = F.interpolate(x1, size=image.size()[2:], mode='bilinear', align_corners=True) 79 | 80 | 81 | return p0, f4, f3, f2, f1 82 | 83 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib_initial/Network_Res2Net_initial.py: -------------------------------------------------------------------------------- 1 | 2 | import timm 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from lib_initial.module_FSEL import Module3_1, Module3_2, Module3_3, Module2, Module1_res 7 | 
from lib.Res2Net_v1b import res2net50_v1b_26w_4s 8 | 9 | 10 | ''' 11 | backbone: res2net 12 | ''' 13 | 14 | 15 | class Network(nn.Module): 16 | # res2net based encoder decoder 17 | def __init__(self, channels=128): 18 | super(Network, self).__init__() 19 | self.shared_encoder = res2net50_v1b_26w_4s() 20 | 21 | pretrained_dict = torch.load('/COD/pre_train_pth/res2net50_v1b_26w_4s-3cf99910.pth') 22 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.shared_encoder.state_dict()} 23 | self.shared_encoder.load_state_dict(pretrained_dict) 24 | 25 | self.dePixelShuffle = torch.nn.PixelShuffle(2) 26 | 27 | self.up = nn.Sequential( 28 | nn.Conv2d(channels//4, channels, kernel_size=1),nn.BatchNorm2d(channels), 29 | nn.Conv2d(channels, channels, kernel_size=3, padding=1),nn.BatchNorm2d(channels),nn.ReLU(True) 30 | ) 31 | 32 | self.Module1_5 = Module1_res(2048+channels, channels) 33 | self.Module1_4 = Module1_res(1024+channels, channels) 34 | self.Module1_3 = Module1_res(512+channels, channels) 35 | self.Module1_2 = Module1_res(256+channels, channels) 36 | 37 | self.Module2 = Module2(2048, channels) 38 | 39 | self.Module3_1 = Module3_1(channels, channels) 40 | self.Module3_2 = Module3_2(channels, channels) 41 | self.Module3_3 = Module3_3(channels,channels) 42 | 43 | def forward(self, x): 44 | image = x 45 | # Feature Extraction 46 | x = self.shared_encoder.conv1(x) 47 | x = self.shared_encoder.bn1(x) 48 | x = self.shared_encoder.relu(x) 49 | x = self.shared_encoder.maxpool(x) 50 | x1 = self.shared_encoder.layer1(x) 51 | x2 = self.shared_encoder.layer2(x1) 52 | x3 = self.shared_encoder.layer3(x2) 53 | x4 = self.shared_encoder.layer4(x3) 54 | 55 | 56 | 57 | p1 = self.Module2(x4) 58 | x5_4 = p1 59 | x5_4_1 = x5_4.expand(-1, 128, -1, -1) 60 | 61 | x4 = self.Module1_5(torch.cat((x4,x5_4_1),1)) 62 | x4_up = self.up(self.dePixelShuffle(x4)) 63 | 64 | 65 | x3 = self.Module1_4(torch.cat((x3,x4_up),1)) 66 | x3_up = self.up(self.dePixelShuffle(x3)) 67 | 68 | 69 | x2 = self.Module1_3(torch.cat((x2,x3_up),1)) 70 | x2_up = self.up(self.dePixelShuffle(x2)) 71 | 72 | 73 | x1 = self.Module1_2(torch.cat((x1,x2_up),1)) 74 | 75 | x4 = self.Module3_1(x4,x5_4) 76 | x3 = self.Module3_1(x3,x4) 77 | x2 = self.Module3_2(x2,x3,x4) 78 | x1 = self.Module3_3(x1,x2,x3,x4) 79 | 80 | p0 = F.interpolate(p1, size=image.size()[2:], mode='bilinear', align_corners=True) 81 | f4 = F.interpolate(x4, size=image.size()[2:], mode='bilinear', align_corners=True) 82 | f3 = F.interpolate(x3, size=image.size()[2:], mode='bilinear', align_corners=True) 83 | f2 = F.interpolate(x2, size=image.size()[2:], mode='bilinear', align_corners=True) 84 | f1 = F.interpolate(x1, size=image.size()[2:], mode='bilinear', align_corners=True) 85 | 86 | return p0, f4, f3, f2, f1 87 | 88 | 89 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib_initial/Network_ResNet_initial.py: -------------------------------------------------------------------------------- 1 | 2 | import timm 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from lib_initial.module_FSEL import Module3_1, Module3_2, Module3_3, Module2, Module1_res 7 | 8 | 9 | ''' 10 | backbone: resnet50 11 | ''' 12 | 13 | 14 | class Network(nn.Module): 15 | # resnet based encoder decoder 16 | def __init__(self, channels=128): 17 | super(Network, self).__init__() 18 | self.shared_encoder = timm.create_model(model_name="resnet50", pretrained=True, in_chans=3, features_only=True) 19 | 20 | self.dePixelShuffle = 
torch.nn.PixelShuffle(2) 21 | 22 | self.up = nn.Sequential( 23 | nn.Conv2d(channels//4, channels, kernel_size=1),nn.BatchNorm2d(channels), 24 | nn.Conv2d(channels, channels, kernel_size=3, padding=1),nn.BatchNorm2d(channels),nn.ReLU(True) 25 | ) 26 | self.channel = channels 27 | self.Module1_5 = Module1_res(2048+channels, channels) 28 | self.Module1_4 = Module1_res(1024+channels, channels) 29 | self.Module1_3 = Module1_res(512+channels, channels) 30 | self.Module1_2 = Module1_res(256+channels, channels) 31 | 32 | self.Module2 = Module2(2048, channels) 33 | 34 | self.Module3_1 = Module3_1(channels, channels) 35 | self.Module3_2 = Module3_2(channels, channels) 36 | self.Module3_3 = Module3_3(channels,channels) 37 | 38 | def forward(self, x): 39 | image = x 40 | # Feature Extraction 41 | en_feats = self.shared_encoder(x) 42 | x0, x1, x2, x3, x4 = en_feats 43 | 44 | p1 = self.Module2(x4) 45 | x5_4 = p1 46 | x5_4_1 = x5_4.expand(-1, self.channel, -1, -1) 47 | 48 | x4 = self.Module1_5(torch.cat((x4,x5_4_1),1)) 49 | x4_up = self.up(self.dePixelShuffle(x4)) 50 | 51 | x3 = self.Module1_4(torch.cat((x3,x4_up),1)) 52 | x3_up = self.up(self.dePixelShuffle(x3)) 53 | 54 | x2 = self.Module1_3(torch.cat((x2,x3_up),1)) 55 | x2_up = self.up(self.dePixelShuffle(x2)) 56 | 57 | x1 = self.Module1_2(torch.cat((x1,x2_up),1)) 58 | 59 | x4 = self.Module3_1(x4,x5_4) 60 | x3 = self.Module3_1(x3,x4) 61 | x2 = self.Module3_2(x2,x3,x4) 62 | x1 = self.Module3_3(x1,x2,x3,x4) 63 | 64 | p0 = F.interpolate(p1, size=image.size()[2:], mode='bilinear', align_corners=True) 65 | f4 = F.interpolate(x4, size=image.size()[2:], mode='bilinear', align_corners=True) 66 | f3 = F.interpolate(x3, size=image.size()[2:], mode='bilinear', align_corners=True) 67 | f2 = F.interpolate(x2, size=image.size()[2:], mode='bilinear', align_corners=True) 68 | f1 = F.interpolate(x1, size=image.size()[2:], mode='bilinear', align_corners=True) 69 | 70 | return p0, f4, f3, f2, f1 71 | 72 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/lib_initial/module_FSEL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import fvcore.nn.weight_init as weight_init 5 | from einops import rearrange 6 | import numbers 7 | 8 | 9 | def to_3d(x): 10 | return rearrange(x, 'b c h w -> b (h w) c') 11 | 12 | def to_4d(x,h,w): 13 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 14 | 15 | class BiasFree_LayerNorm(nn.Module): 16 | def __init__(self, normalized_shape): 17 | super(BiasFree_LayerNorm, self).__init__() 18 | if isinstance(normalized_shape, numbers.Integral): 19 | normalized_shape = (normalized_shape,) 20 | normalized_shape = torch.Size(normalized_shape) 21 | 22 | assert len(normalized_shape) == 1 23 | 24 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 25 | self.normalized_shape = normalized_shape 26 | 27 | def forward(self, x): 28 | sigma = x.var(-1, keepdim=True, unbiased=False) 29 | return x / torch.sqrt(sigma + 1e-5) * self.weight 30 | 31 | 32 | class WithBias_LayerNorm(nn.Module): 33 | def __init__(self, normalized_shape): 34 | super(WithBias_LayerNorm, self).__init__() 35 | if isinstance(normalized_shape, numbers.Integral): 36 | normalized_shape = (normalized_shape,) 37 | normalized_shape = torch.Size(normalized_shape) 38 | 39 | assert len(normalized_shape) == 1 40 | 41 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 42 | self.bias = 
nn.Parameter(torch.zeros(normalized_shape)) 43 | self.normalized_shape = normalized_shape 44 | 45 | def forward(self, x): 46 | mu = x.mean(-1, keepdim=True) 47 | sigma = x.var(-1, keepdim=True, unbiased=False) 48 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 49 | 50 | def initialize(self): 51 | weight_init(self) 52 | 53 | 54 | class LayerNorm(nn.Module): 55 | def __init__(self, dim, LayerNorm_type): 56 | super(LayerNorm, self).__init__() 57 | if LayerNorm_type == 'BiasFree': 58 | self.body = BiasFree_LayerNorm(dim) 59 | else: 60 | self.body = WithBias_LayerNorm(dim) 61 | 62 | def forward(self, x): 63 | h, w = x.shape[-2:] 64 | return to_4d(self.body(to_3d(x)), h, w) 65 | 66 | def initialize(self): 67 | weight_init(self) 68 | 69 | 70 | class FeedForward(nn.Module): 71 | def __init__(self, dim, ffn_expansion_factor, bias): 72 | super(FeedForward, self).__init__() 73 | 74 | self.dwconv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim, bias=bias) 75 | self.dwconv2 = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) 76 | self.project_out = nn.Conv2d(dim*4, dim, kernel_size=1, bias=bias) 77 | self.weight = nn.Sequential( 78 | nn.Conv2d(dim, dim // 16, 1, bias=True), 79 | nn.BatchNorm2d(dim // 16), 80 | nn.ReLU(True), 81 | nn.Conv2d(dim // 16, dim, 1, bias=True), 82 | nn.Sigmoid()) 83 | self.weight1 = nn.Sequential( 84 | nn.Conv2d(dim*2, dim // 16, 1, bias=True), 85 | nn.BatchNorm2d(dim // 16), 86 | nn.ReLU(True), 87 | nn.Conv2d(dim // 16, dim*2, 1, bias=True), 88 | nn.Sigmoid()) 89 | def forward(self, x): 90 | 91 | x_p = torch.abs(self.weight(torch.fft.fft2(x.float()).real)*torch.fft.fft2(x.float())) 92 | x_p_gelu = F.gelu(x_p)*x_p 93 | 94 | x_t = self.dwconv1(x) 95 | x_t_gelu = F.gelu(x_t)*x_t 96 | 97 | x_p = torch.fft.fft2(torch.cat((x_t_gelu,x_p_gelu),1)) 98 | x_p = torch.abs(torch.fft.ifft2(self.weight1(x_p.real)*x_p)) 99 | 100 | x_t = self.dwconv2(torch.cat((x_t_gelu,x_p_gelu),1)) 101 | out = self.project_out(torch.cat((x_p,x_t),1)) 102 | return out 103 | 104 | def initialize(self): 105 | weight_init(self) 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | def custom_complex_normalization(input_tensor, dim=-1): 114 | real_part = input_tensor.real 115 | imag_part = input_tensor.imag 116 | norm_real = F.softmax(real_part, dim=dim) 117 | norm_imag = F.softmax(imag_part, dim=dim) 118 | 119 | normalized_tensor = torch.complex(norm_real, norm_imag) 120 | 121 | return normalized_tensor 122 | 123 | class Attention_F(nn.Module): 124 | def __init__(self, dim, num_heads, bias,): 125 | super(Attention_F, self).__init__() 126 | self.num_heads = num_heads 127 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 128 | self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias) 129 | self.weight = nn.Sequential( 130 | nn.Conv2d(dim, dim // 16, 1, bias=True), 131 | nn.BatchNorm2d(dim // 16), 132 | nn.ReLU(True), 133 | nn.Conv2d(dim // 16, dim, 1, bias=True), 134 | nn.Sigmoid()) 135 | def forward(self, x): 136 | b, c, h, w = x.shape 137 | 138 | q_f = torch.fft.fft2(x.float()) 139 | k_f = torch.fft.fft2(x.float()) 140 | v_f = torch.fft.fft2(x.float()) 141 | 142 | q_f = rearrange(q_f, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 143 | k_f = rearrange(k_f, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 144 | v_f = rearrange(v_f, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 145 | 146 | q_f = torch.nn.functional.normalize(q_f, dim=-1) 147 | k_f = torch.nn.functional.normalize(k_f, dim=-1) 
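# Added note (not in the original source): q_f, k_f and v_f are all the 2-D FFT of the same
# input x, flattened per attention head to shape (b, head, c, h*w). The attention map below is
# therefore computed on complex tensors; custom_complex_normalization applies a softmax to the
# real and imaginary parts separately, and torch.abs(torch.fft.ifft2(...)) takes the magnitude
# of the inverse transform to obtain a real-valued tensor that is rearranged back to
# (b, c, h, w). out_lf adds a second branch in which the spectrum of x is re-weighted
# channel-wise by self.weight (a sigmoid gate driven by the real part) before the inverse FFT;
# the two branches are concatenated and fused by project_out.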
148 | attn_f = (q_f @ k_f.transpose(-2, -1)) * self.temperature 149 | attn_f = custom_complex_normalization(attn_f, dim=-1) 150 | out_f = torch.abs(torch.fft.ifft2(attn_f @ v_f)) 151 | out_f = rearrange(out_f, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 152 | out_lf = torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(x.float()).real)*torch.fft.fft2(x.float()))) 153 | out = self.project_out(torch.cat((out_f,out_lf),1)) 154 | 155 | return out 156 | 157 | class Attention_S(nn.Module): 158 | def __init__(self, dim, num_heads, bias,): 159 | super(Attention_S, self).__init__() 160 | self.num_heads = num_heads 161 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 162 | 163 | self.qkv1conv_1 = nn.Conv2d(dim,dim,kernel_size=1) 164 | self.qkv2conv_1 = nn.Conv2d(dim, dim, kernel_size=1) 165 | self.qkv3conv_1 = nn.Conv2d(dim, dim, kernel_size=1) 166 | 167 | 168 | self.qkv1conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 169 | self.qkv2conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 170 | self.qkv3conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 171 | 172 | self.qkv1conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 173 | self.qkv2conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 174 | self.qkv3conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 175 | 176 | 177 | self.conv_3 = nn.Conv2d(dim, dim//2, kernel_size=3, stride=1, padding=1, groups=dim//2, bias=bias) 178 | self.conv_5 = nn.Conv2d(dim, dim // 2, kernel_size=5, stride=1, padding=2, groups=dim//2, bias=bias) 179 | self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias) 180 | 181 | def forward(self, x): 182 | b, c, h, w = x.shape 183 | q_t = torch.cat((self.qkv1conv_3(self.qkv1conv_1(x)),self.qkv1conv_5(self.qkv1conv_1(x))),1) 184 | k_t = torch.cat((self.qkv2conv_3(self.qkv2conv_1(x)),self.qkv2conv_5(self.qkv2conv_1(x))),1) 185 | v_t = torch.cat((self.qkv3conv_3(self.qkv3conv_1(x)),self.qkv3conv_5(self.qkv3conv_1(x))),1) 186 | 187 | q_t = rearrange(q_t, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 188 | k_t = rearrange(k_t, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 189 | v_t = rearrange(v_t, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 190 | 191 | q_t = torch.nn.functional.normalize(q_t, dim=-1) 192 | k_t = torch.nn.functional.normalize(k_t, dim=-1) 193 | attn_t = (q_t @ k_t.transpose(-2, -1)) * self.temperature 194 | attn_t = attn_t.softmax(dim=-1) 195 | out_t = (attn_t @ v_t) 196 | out_t = rearrange(out_t, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 197 | out_p = torch.cat((self.conv_3(x),self.conv_5(x)),1) 198 | out = self.project_out(torch.cat((out_t,out_p),1)) 199 | 200 | return out 201 | 202 | 203 | def initialize(self): 204 | weight_init(self) 205 | 206 | 207 | class Module1(nn.Module): 208 | def __init__(self, mode='dilation', dim=128, num_heads=8, ffn_expansion_factor=4, bias=False, 209 | LayerNorm_type='WithBias'): 210 | super(Module1, self).__init__() 211 | self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias) 212 | self.norm1 = LayerNorm(dim, LayerNorm_type) 213 | self.attn_S = Attention_S(dim, num_heads, bias) 214 | self.attn_F = Attention_F(dim, num_heads, bias) 215 | self.norm2 = LayerNorm(dim, LayerNorm_type) 216 | self.ffn = 
FeedForward(dim, ffn_expansion_factor, bias) 217 | 218 | def forward(self, x): 219 | x = x + torch.add(self.attn_F(self.norm1(x)),self.attn_S(self.norm1(x))) 220 | x = x + self.ffn(self.norm2(x)) 221 | return x 222 | 223 | 224 | class Module1_res(nn.Module): 225 | def __init__(self, in_channel, out_channel): 226 | super(Module1_res, self).__init__() 227 | self.conv1 = nn.Sequential( 228 | nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),nn.ReLU(True) 229 | ) 230 | self.reduce = nn.Sequential( 231 | nn.Conv2d(out_channel*2, out_channel, 1),nn.BatchNorm2d(out_channel),nn.ReLU(True) 232 | ) 233 | self.relu = nn.ReLU(True) 234 | self.Module1 = Module1(dim=out_channel) 235 | 236 | def forward(self, x): 237 | x0 = self.conv1(x) 238 | x_FT = self.Module1(x0) 239 | x = self.reduce(torch.cat((x0,x_FT),1))+x0 240 | return x 241 | 242 | class Module2(nn.Module): 243 | def __init__(self, channels, in_channels): 244 | super(Module2, self).__init__() 245 | 246 | self.conv1 = nn.Sequential( 247 | nn.Conv2d(channels, in_channels, 1), nn.BatchNorm2d(in_channels), nn.ReLU(True) 248 | ) 249 | 250 | self.Dconv3 = nn.Sequential( 251 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 252 | nn.Conv2d(in_channels, in_channels, 3, padding=3,dilation=3), nn.BatchNorm2d(in_channels), nn.ReLU(True) 253 | ) 254 | 255 | self.Dconv5 = nn.Sequential( 256 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 257 | nn.Conv2d(in_channels, in_channels, 3, padding=5,dilation=5), nn.BatchNorm2d(in_channels), nn.ReLU(True) 258 | ) 259 | self.Dconv7 = nn.Sequential( 260 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 261 | nn.Conv2d(in_channels, in_channels, 3, padding=7,dilation=7), nn.BatchNorm2d(in_channels), nn.ReLU(True) 262 | ) 263 | self.Dconv9 = nn.Sequential( 264 | nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), 265 | nn.Conv2d(in_channels, in_channels, 3, padding=9,dilation=9), nn.BatchNorm2d(in_channels),nn.ReLU(True) 266 | ) 267 | 268 | self.reduce = nn.Sequential( 269 | nn.Conv2d(in_channels * 5, in_channels, 1), nn.BatchNorm2d(in_channels),nn.ReLU(True) 270 | ) 271 | 272 | self.out = nn.Sequential( 273 | nn.Conv2d(in_channels, in_channels//2, kernel_size=3, padding=1), nn.BatchNorm2d(in_channels//2), nn.ReLU(True), 274 | nn.Conv2d(in_channels//2, 1, kernel_size=1) 275 | ) 276 | 277 | self.weight = nn.Sequential( 278 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 279 | nn.BatchNorm2d(in_channels // 16), 280 | nn.ReLU(True), 281 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 282 | nn.Sigmoid()) 283 | 284 | self.norm = nn.BatchNorm2d(in_channels) 285 | self.relu = nn.ReLU(True) 286 | 287 | def forward(self, F1): 288 | 289 | F1_input = self.conv1(F1) 290 | F1_3_t = self.Dconv3(F1_input) 291 | F1_3_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_3_t.float()).real)*torch.fft.fft2(F1_3_t.float()))))) 292 | F1_3 = torch.add(F1_3_t,F1_3_f) 293 | 294 | F1_5_t = self.Dconv5(F1_input + F1_3) 295 | F1_5_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_5_t.float()).real)*torch.fft.fft2(F1_5_t.float()))))) 296 | F1_5 = torch.add(F1_5_t, F1_5_f) 297 | 298 | F1_7_t = self.Dconv7(F1_input + F1_5) 299 | F1_7_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_7_t.float()).real)*torch.fft.fft2(F1_7_t.float()))))) 300 | F1_7 = torch.add(F1_7_t, F1_7_f) 301 | 302 | F1_9_t = self.Dconv9(F1_input + F1_7) 303 | F1_9_f = 
self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(F1_9_t.float()).real)*torch.fft.fft2(F1_9_t.float()))))) 304 | F1_9 = torch.add(F1_9_t, F1_9_f) 305 | 306 | F_out = self.out(self.reduce(torch.cat((F1_3,F1_5,F1_7,F1_9,F1_input),1)) + F1_input ) 307 | 308 | return F_out 309 | 310 | 311 | class Module3_1(nn.Module): 312 | def __init__(self, in_channels, mid_channels): 313 | super(Module3_1, self).__init__() 314 | self.conv = nn.Sequential( 315 | nn.Conv2d(in_channels * 2, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 316 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 317 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), nn.ReLU(True) 318 | 319 | ) 320 | 321 | self.out = nn.Sequential( 322 | nn.Conv2d(in_channels * 2, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels),nn.ReLU(True), 323 | nn.Conv2d(mid_channels, 1, kernel_size=1) 324 | ) 325 | 326 | self.conv3 = nn.Sequential( 327 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 328 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 329 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 330 | ) 331 | 332 | self.weight = nn.Sequential( 333 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 334 | nn.BatchNorm2d(in_channels // 16), 335 | nn.ReLU(True), 336 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 337 | nn.Sigmoid()) 338 | 339 | self.norm = nn.BatchNorm2d(in_channels) 340 | self.relu = nn.ReLU(in_channels) 341 | 342 | def forward(self, X, prior_cam): 343 | prior_cam = F.interpolate(prior_cam, size=X.size()[2:], mode='bilinear',align_corners=True) # 2,1,12,12->2,1,48,48 344 | 345 | FI = X 346 | 347 | yt = self.conv(torch.cat([FI, prior_cam.expand(-1, X.size()[1], -1, -1)], dim=1)) 348 | 349 | yt_t = self.conv3(yt) 350 | yt_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(yt.float()).real)*torch.fft.fft2(yt.float()))))) 351 | yt_out = torch.add(yt_t,yt_f) 352 | 353 | r_prior_cam_f = torch.abs(torch.fft.fft2(prior_cam)) 354 | r_prior_cam_f = -1 * (torch.sigmoid(r_prior_cam_f)) + 1 355 | r_prior_cam_t = -1 * (torch.sigmoid(prior_cam)) + 1 356 | r_prior_cam = r_prior_cam_t+r_prior_cam_f 357 | 358 | y_1 = r_prior_cam.expand(-1, X.size()[1], -1, -1).mul(FI) 359 | 360 | cat2 = torch.cat([y_1, yt_out], dim=1) # 2,128,48,48 361 | 362 | y = self.out(cat2) 363 | y = y + prior_cam 364 | return y 365 | 366 | class Module3_2(nn.Module): 367 | def __init__(self, in_channels, mid_channels): 368 | super(Module3_2, self).__init__() 369 | self.conv = nn.Sequential( 370 | nn.Conv2d(in_channels * 3, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 371 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 372 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 373 | ) 374 | 375 | self.conv3 = nn.Sequential( 376 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 377 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 378 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 379 | ) 380 | 381 | self.out = nn.Sequential( 382 | 
nn.Conv2d(in_channels * 2, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels),nn.ReLU(True), 383 | nn.Conv2d(mid_channels, 1, kernel_size=1) 384 | ) 385 | 386 | self.weight = nn.Sequential( 387 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 388 | nn.BatchNorm2d(in_channels // 16), 389 | nn.ReLU(True), 390 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 391 | nn.Sigmoid()) 392 | 393 | self.norm = nn.BatchNorm2d(in_channels) 394 | self.relu = nn.ReLU(True) 395 | 396 | def forward(self, X, x1, prior_cam): 397 | prior_cam = F.interpolate(prior_cam, size=X.size()[2:], mode='bilinear', 398 | align_corners=True) # 399 | x1_prior_cam = F.interpolate(x1, size=X.size()[2:], mode='bilinear', align_corners=True) 400 | FI = X 401 | 402 | yt = self.conv(torch.cat([FI, prior_cam.expand(-1, X.size()[1], -1, -1), x1_prior_cam.expand(-1, X.size()[1], -1, -1)],dim=1)) 403 | 404 | yt_t = self.conv3(yt) 405 | yt_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(yt.float()).real) * torch.fft.fft2(yt.float()))))) 406 | yt_out = torch.add(yt_t, yt_f) 407 | 408 | r_prior_cam_f = torch.abs(torch.fft.fft2(prior_cam)) 409 | r_prior_cam_f = -1 * (torch.sigmoid(r_prior_cam_f)) + 1 410 | r_prior_cam_t = -1 * (torch.sigmoid(prior_cam)) + 1 411 | r_prior_cam = r_prior_cam_t+r_prior_cam_f 412 | 413 | r1_prior_cam_f = torch.abs(torch.fft.fft2(x1_prior_cam)) 414 | r1_prior_cam_f = -1 * (torch.sigmoid(r1_prior_cam_f)) + 1 415 | r1_prior_cam_t = -1 * (torch.sigmoid(x1_prior_cam)) + 1 416 | r1_prior_cam = r1_prior_cam_t+r1_prior_cam_f 417 | 418 | r_prior_cam = r_prior_cam + r1_prior_cam 419 | 420 | y_1 = r_prior_cam.expand(-1, X.size()[1], -1, -1).mul(FI) 421 | 422 | cat2 = torch.cat([y_1, yt_out], dim=1) # 423 | 424 | y = self.out(cat2) 425 | y = y + prior_cam + x1_prior_cam 426 | return y 427 | 428 | class Module3_3(nn.Module): 429 | def __init__(self, in_channels, mid_channels): 430 | super(Module3_3, self).__init__() 431 | self.conv = nn.Sequential( 432 | nn.Conv2d(in_channels * 4, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 433 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 434 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 435 | ) 436 | 437 | self.conv3 = nn.Sequential( 438 | nn.Conv2d(in_channels, in_channels, kernel_size=1), nn.BatchNorm2d(in_channels), 439 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels), 440 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(in_channels),nn.ReLU(True), 441 | ) 442 | 443 | self.out = nn.Sequential( 444 | nn.Conv2d(in_channels * 2, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels),nn.ReLU(True), 445 | nn.Conv2d(mid_channels, 1, kernel_size=1) 446 | ) 447 | 448 | self.weight = nn.Sequential( 449 | nn.Conv2d(in_channels, in_channels // 16, 1, bias=True), 450 | nn.BatchNorm2d(in_channels // 16), 451 | nn.ReLU(True), 452 | nn.Conv2d(in_channels // 16, in_channels, 1, bias=True), 453 | nn.Sigmoid()) 454 | 455 | self.norm = nn.BatchNorm2d(in_channels) 456 | self.relu = nn.ReLU(True) 457 | 458 | def forward(self, X, x1,x2, prior_cam): 459 | prior_cam = F.interpolate(prior_cam, size=X.size()[2:], mode='bilinear',align_corners=True) # 460 | x1_prior_cam = F.interpolate(x1, size=X.size()[2:], mode='bilinear', align_corners=True) 461 | x2_prior_cam = F.interpolate(x2, 
size=X.size()[2:], mode='bilinear', align_corners=True) 462 | 463 | FI = X 464 | 465 | yt = self.conv(torch.cat([FI, prior_cam.expand(-1, X.size()[1], -1, -1), x1_prior_cam.expand(-1, X.size()[1], -1, -1),x2_prior_cam.expand(-1, X.size()[1], -1, -1)],dim=1)) 466 | 467 | yt_t = self.conv3(yt) 468 | yt_f = self.relu(self.norm(torch.abs(torch.fft.ifft2(self.weight(torch.fft.fft2(yt.float()).real) * torch.fft.fft2(yt.float()))))) 469 | yt_out = torch.add(yt_t, yt_f) 470 | 471 | r_prior_cam_f = torch.abs(torch.fft.fft2(prior_cam)) 472 | r_prior_cam_f = -1 * (torch.sigmoid(r_prior_cam_f)) + 1 473 | r_prior_cam_t = -1 * (torch.sigmoid(prior_cam)) + 1 474 | r_prior_cam = r_prior_cam_t+r_prior_cam_f 475 | 476 | r1_prior_cam_f = torch.abs(torch.fft.fft2(x1_prior_cam)) 477 | r1_prior_cam_f = -1 * (torch.sigmoid(r1_prior_cam_f)) + 1 478 | r1_prior_cam_t = -1 * (torch.sigmoid(x1_prior_cam)) + 1 479 | r1_prior_cam1 = r1_prior_cam_t+r1_prior_cam_f 480 | 481 | r2_prior_cam_f = torch.abs(torch.fft.fft2(x2_prior_cam)) 482 | r2_prior_cam_f = -1 * (torch.sigmoid(r2_prior_cam_f)) + 1 483 | r2_prior_cam_t = -1 * (torch.sigmoid(x2_prior_cam)) + 1 484 | r1_prior_cam2 = r2_prior_cam_t + r2_prior_cam_f 485 | 486 | r_prior_cam = r_prior_cam + r1_prior_cam1+r1_prior_cam2 487 | 488 | y_1 = r_prior_cam.expand(-1, X.size()[1], -1, -1).mul(FI) 489 | 490 | cat2 = torch.cat([y_1, yt_out], dim=1) # 491 | 492 | y = self.out(cat2) 493 | 494 | y = y + prior_cam + x1_prior_cam + x2_prior_cam 495 | 496 | return y 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/test/saliency_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 4 | 5 | 6 | class cal_fm(object): 7 | # Fmeasure(maxFm,meanFm)---Frequency-tuned salient region detection(CVPR 2009) 8 | def __init__(self, num, thds=255): 9 | self.num = num 10 | self.thds = thds 11 | self.precision = np.zeros((self.num, self.thds)) 12 | self.recall = np.zeros((self.num, self.thds)) 13 | self.meanF = np.zeros((self.num,1)) 14 | self.idx = 0 15 | 16 | def update(self, pred, gt): 17 | if gt.max() != 0: 18 | prediction, recall, Fmeasure_temp = self.cal(pred, gt) 19 | self.precision[self.idx, :] = prediction 20 | self.recall[self.idx, :] = recall 21 | self.meanF[self.idx, :] = Fmeasure_temp 22 | self.idx += 1 23 | 24 | def cal(self, pred, gt): 25 | ########################meanF############################## 26 | th = 2 * pred.mean() 27 | if th > 1: 28 | th = 1 29 | binary = np.zeros_like(pred) 30 | binary[pred >= th] = 1 31 | hard_gt = np.zeros_like(gt) 32 | hard_gt[gt > 0.5] = 1 33 | tp = (binary * hard_gt).sum() 34 | if tp == 0: 35 | meanF = 0 36 | else: 37 | pre = tp / binary.sum() 38 | rec = tp / hard_gt.sum() 39 | meanF = 1.3 * pre * rec / (0.3 * pre + rec) 40 | ########################maxF############################## 41 | pred = np.uint8(pred * 255) 42 | target = pred[gt > 0.5] 43 | nontarget = pred[gt <= 0.5] 44 | targetHist, _ = np.histogram(target, bins=range(256)) 45 | nontargetHist, _ = np.histogram(nontarget, bins=range(256)) 46 | targetHist = np.cumsum(np.flip(targetHist), axis=0) 47 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0) 48 | precision = targetHist / (targetHist + nontargetHist + 1e-8) 49 | recall = targetHist / 
np.sum(gt) 50 | return precision, recall, meanF 51 | 52 | def show(self): 53 | assert self.num == self.idx 54 | precision = self.precision.mean(axis=0) 55 | recall = self.recall.mean(axis=0) 56 | fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8) 57 | fmeasure_avg = self.meanF.mean(axis=0) 58 | return fmeasure.max(),fmeasure_avg[0],precision,recall 59 | 60 | 61 | class cal_mae(object): 62 | # mean absolute error 63 | def __init__(self): 64 | self.prediction = [] 65 | 66 | def update(self, pred, gt): 67 | score = self.cal(pred, gt) 68 | self.prediction.append(score) 69 | 70 | def cal(self, pred, gt): 71 | return np.mean(np.abs(pred - gt)) 72 | 73 | def show(self): 74 | return np.mean(self.prediction) 75 | 76 | 77 | class cal_sm(object): 78 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017) 79 | def __init__(self, alpha=0.5): 80 | self.prediction = [] 81 | self.alpha = alpha 82 | 83 | def update(self, pred, gt): 84 | gt = gt > 0.5 85 | score = self.cal(pred, gt) 86 | self.prediction.append(score) 87 | 88 | def show(self): 89 | return np.mean(self.prediction) 90 | 91 | def cal(self, pred, gt): 92 | y = np.mean(gt) 93 | if y == 0: 94 | score = 1 - np.mean(pred) 95 | elif y == 1: 96 | score = np.mean(pred) 97 | else: 98 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 99 | return score 100 | 101 | def object(self, pred, gt): 102 | fg = pred * gt 103 | bg = (1 - pred) * (1 - gt) 104 | 105 | u = np.mean(gt) 106 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt)) 107 | 108 | def s_object(self, in1, in2): 109 | x = np.mean(in1[in2]) 110 | sigma_x = np.std(in1[in2]) 111 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8) 112 | 113 | def region(self, pred, gt): 114 | [y, x] = ndimage.center_of_mass(gt) 115 | y = int(round(y)) + 1 116 | x = int(round(x)) + 1 117 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y) 118 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y) 119 | 120 | score1 = self.ssim(pred1, gt1) 121 | score2 = self.ssim(pred2, gt2) 122 | score3 = self.ssim(pred3, gt3) 123 | score4 = self.ssim(pred4, gt4) 124 | 125 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 126 | 127 | def divideGT(self, gt, x, y): 128 | h, w = gt.shape 129 | area = h * w 130 | LT = gt[0:y, 0:x] 131 | RT = gt[0:y, x:w] 132 | LB = gt[y:h, 0:x] 133 | RB = gt[y:h, x:w] 134 | 135 | w1 = x * y / area 136 | w2 = y * (w - x) / area 137 | w3 = (h - y) * x / area 138 | w4 = (h - y) * (w - x) / area 139 | 140 | return LT, RT, LB, RB, w1, w2, w3, w4 141 | 142 | def dividePred(self, pred, x, y): 143 | h, w = pred.shape 144 | LT = pred[0:y, 0:x] 145 | RT = pred[0:y, x:w] 146 | LB = pred[y:h, 0:x] 147 | RB = pred[y:h, x:w] 148 | 149 | return LT, RT, LB, RB 150 | 151 | def ssim(self, in1, in2): 152 | in2 = np.float32(in2) 153 | h, w = in1.shape 154 | N = h * w 155 | 156 | x = np.mean(in1) 157 | y = np.mean(in2) 158 | sigma_x = np.var(in1) 159 | sigma_y = np.var(in2) 160 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1) 161 | 162 | alpha = 4 * x * y * sigma_xy 163 | beta = (x * x + y * y) * (sigma_x + sigma_y) 164 | 165 | if alpha != 0: 166 | score = alpha / (beta + 1e-8) 167 | elif alpha == 0 and beta == 0: 168 | score = 1 169 | else: 170 | score = 0 171 | 172 | return score 173 | 174 | class cal_em(object): 175 | #Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018) 176 | def __init__(self): 177 | self.prediction = [] 178 | 179 | def update(self, pred, gt): 
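# Added note (not in the original source): unlike cal_fm.update above, which skips images whose
# ground truth is entirely empty, the E-measure is accumulated for every image. self.cal below
# binarizes the prediction with an adaptive threshold of 2 * pred.mean() (clipped to 1) and
# handles the all-background / all-foreground ground-truth cases separately before computing
# the enhanced alignment matrix.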
180 | score = self.cal(pred, gt) 181 | self.prediction.append(score) 182 | 183 | def cal(self, pred, gt): 184 | th = 2 * pred.mean() 185 | if th > 1: 186 | th = 1 187 | FM = np.zeros(gt.shape) 188 | FM[pred >= th] = 1 189 | FM = np.array(FM,dtype=bool) 190 | GT = np.array(gt,dtype=bool) 191 | dFM = np.double(FM) 192 | if (sum(sum(np.double(GT)))==0): 193 | enhanced_matrix = 1.0-dFM 194 | elif (sum(sum(np.double(~GT)))==0): 195 | enhanced_matrix = dFM 196 | else: 197 | dGT = np.double(GT) 198 | align_matrix = self.AlignmentTerm(dFM, dGT) 199 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix) 200 | [w, h] = np.shape(GT) 201 | score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8) 202 | return score 203 | def AlignmentTerm(self,dFM,dGT): 204 | mu_FM = np.mean(dFM) 205 | mu_GT = np.mean(dGT) 206 | align_FM = dFM - mu_FM 207 | align_GT = dGT - mu_GT 208 | align_Matrix = 2. * (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8) 209 | return align_Matrix 210 | def EnhancedAlignmentTerm(self,align_Matrix): 211 | enhanced = np.power(align_Matrix + 1,2) / 4 212 | return enhanced 213 | def show(self): 214 | return np.mean(self.prediction) 215 | class cal_wfm(object): 216 | def __init__(self, beta=1): 217 | self.beta = beta 218 | self.eps = 1e-6 219 | self.scores_list = [] 220 | 221 | def update(self, pred, gt): 222 | assert pred.ndim == gt.ndim and pred.shape == gt.shape 223 | assert pred.max() <= 1 and pred.min() >= 0 224 | assert gt.max() <= 1 and gt.min() >= 0 225 | 226 | gt = gt > 0.5 227 | if gt.max() == 0: 228 | score = 0 229 | else: 230 | score = self.cal(pred, gt) 231 | self.scores_list.append(score) 232 | 233 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5): 234 | """ 235 | 2D gaussian mask - should give the same result as MATLAB's 236 | fspecial('gaussian',[shape],[sigma]) 237 | """ 238 | m, n = [(ss - 1.) / 2. for ss in shape] 239 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 240 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma)) 241 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 242 | sumh = h.sum() 243 | if sumh != 0: 244 | h /= sumh 245 | return h 246 | 247 | def cal(self, pred, gt): 248 | # [Dst,IDXT] = bwdist(dGT); 249 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 250 | 251 | # %Pixel dependency 252 | # E = abs(FG-dGT); 253 | E = np.abs(pred - gt) 254 | # Et = E; 255 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 256 | Et = np.copy(E) 257 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 258 | 259 | # K = fspecial('gaussian',7,5); 260 | # EA = imfilter(Et,K); 261 | # MIN_E_EA(GT & EA tuple: 11 | """ 12 | A numpy-based function for preparing ``pred`` and ``gt``. 13 | 14 | - for ``pred``, it looks like ``mapminmax(im2double(...))`` of matlab; 15 | - ``gt`` will be binarized by 128. 16 | 17 | :param pred: prediction 18 | :param gt: mask 19 | :return: pred, gt 20 | """ 21 | 22 | """ 23 | gt = gt > 128 24 | # im2double, mapminmax 25 | pred = np.array(pred) 26 | pred = pred / 255 27 | 28 | if pred.max() != pred.min(): 29 | pred = (pred - pred.min()) / (pred.max() - pred.min()) 30 | """ 31 | return pred, gt 32 | 33 | 34 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: 35 | """ 36 | Return an adaptive threshold, which is equal to twice the mean of ``matrix``. 
37 | 38 | :param matrix: a data array 39 | :param max_value: the upper limit of the threshold 40 | :return: min(2 * matrix.mean(), max_value) 41 | """ 42 | return min(2 * matrix.mean(), max_value) 43 | 44 | 45 | class Fmeasure(object): 46 | def __init__(self, beta: float = 0.3): 47 | """ 48 | F-measure for SOD. 49 | 50 | :: 51 | 52 | @inproceedings{Fmeasure, 53 | title={Frequency-tuned salient region detection}, 54 | author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and S{\"u}sstrunk, Sabine}, 55 | booktitle=CVPR, 56 | number={CONF}, 57 | pages={1597--1604}, 58 | year={2009} 59 | } 60 | 61 | :param beta: the weight of the precision 62 | """ 63 | self.beta = beta 64 | self.precisions = [] 65 | self.recalls = [] 66 | self.adaptive_fms = [] 67 | self.changeable_fms = [] 68 | 69 | def step(self, pred: np.ndarray, gt: np.ndarray): 70 | pred, gt = _prepare_data(pred, gt) 71 | 72 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) 73 | self.adaptive_fms.append(adaptive_fm) 74 | 75 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) 76 | self.precisions.append(precisions) 77 | self.recalls.append(recalls) 78 | self.changeable_fms.append(changeable_fms) 79 | 80 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: 81 | """ 82 | Calculate the adaptive F-measure. 83 | 84 | :return: adaptive_fm 85 | """ 86 | # ``np.count_nonzero`` is faster and better 87 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 88 | binary_predcition = pred >= adaptive_threshold 89 | area_intersection = binary_predcition[gt].sum() 90 | if area_intersection == 0: 91 | adaptive_fm = 0 92 | else: 93 | pre = area_intersection / np.count_nonzero(binary_predcition) 94 | rec = area_intersection / np.count_nonzero(gt) 95 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) 96 | return adaptive_fm 97 | 98 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: 99 | """ 100 | Calculate the corresponding precision and recall when the threshold changes from 0 to 255. 101 | 102 | These precisions and recalls can be used to obtain the mean F-measure, maximum F-measure, 103 | precision-recall curve and F-measure-threshold curve. 104 | 105 | For convenience, ``changeable_fms`` is provided here, which can be used directly to obtain 106 | the mean F-measure, maximum F-measure and F-measure-threshold curve. 107 | 108 | :return: precisions, recalls, changeable_fms 109 | """ 110 | # 1. 获取预测结果在真值前背景区域中的直方图 111 | pred = (pred * 255).astype(np.uint8) 112 | bins = np.linspace(0, 256, 257) 113 | fg_hist, _ = np.histogram(pred[gt], bins=bins) # 最后一个bin为[255, 256] 114 | bg_hist, _ = np.histogram(pred[~gt], bins=bins) 115 | # 2. 使用累积直方图(Cumulative Histogram)获得对应真值前背景中大于不同阈值的像素数量 116 | # 这里使用累加(cumsum)就是为了一次性得出 >=不同阈值 的像素数量, 这里仅计算了前景区域 117 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) 118 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) 119 | # 3. 
使用不同阈值的结果计算对应的precision和recall 120 | # p和r的计算的真值是pred==1>==1,二者仅有分母不同,分母前者是pred==1,后者是gt==1 121 | # 为了同时计算不同阈值的结果,这里使用hsitogram&flip&cumsum 获得了不同各自的前景像素数量 122 | TPs = fg_w_thrs 123 | Ps = fg_w_thrs + bg_w_thrs 124 | # 为防止除0,这里针对除0的情况分析后直接对于0分母设为1,因为此时分子必为0 125 | Ps[Ps == 0] = 1 126 | T = max(np.count_nonzero(gt), 1) 127 | # TODO: T=0 或者 特定阈值下fg_w_thrs=0或者bg_w_thrs=0,这些都会包含在TPs[i]=0的情况中, 128 | # 但是这里使用TPs不便于处理列表 129 | precisions = TPs / Ps 130 | recalls = TPs / T 131 | 132 | numerator = (1 + self.beta) * precisions * recalls 133 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) 134 | changeable_fms = numerator / denominator 135 | return precisions, recalls, changeable_fms 136 | 137 | def get_results(self) -> dict: 138 | """ 139 | Return the results about F-measure. 140 | 141 | :return: dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), pr=dict(p=precision, r=recall)) 142 | """ 143 | adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE)) 144 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0) 145 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256 146 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256 147 | return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), pr=dict(p=precision, r=recall)) 148 | 149 | 150 | class MAE(object): 151 | def __init__(self): 152 | """ 153 | MAE(mean absolute error) for SOD. 154 | 155 | :: 156 | 157 | @inproceedings{MAE, 158 | title={Saliency filters: Contrast based filtering for salient region detection}, 159 | author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander}, 160 | booktitle=CVPR, 161 | pages={733--740}, 162 | year={2012} 163 | } 164 | """ 165 | self.maes = [] 166 | 167 | def step(self, pred: np.ndarray, gt: np.ndarray): 168 | pred, gt = _prepare_data(pred, gt) 169 | 170 | mae = self.cal_mae(pred, gt) 171 | self.maes.append(mae) 172 | 173 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 174 | """ 175 | Calculate the mean absolute error. 176 | 177 | :return: mae 178 | """ 179 | mae = np.mean(np.abs(pred - gt)) 180 | return mae 181 | 182 | def get_results(self) -> dict: 183 | """ 184 | Return the results about MAE. 185 | 186 | :return: dict(mae=mae) 187 | """ 188 | mae = np.mean(np.array(self.maes, _TYPE)) 189 | return dict(mae=mae) 190 | 191 | 192 | class Smeasure(object): 193 | def __init__(self, alpha: float = 0.5): 194 | """ 195 | S-measure(Structure-measure) of SOD. 196 | 197 | :: 198 | 199 | @inproceedings{Smeasure, 200 | title={Structure-measure: A new way to eval foreground maps}, 201 | author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali}, 202 | booktitle=ICCV, 203 | pages={4548--4557}, 204 | year={2017} 205 | } 206 | 207 | :param alpha: the weight for balancing the object score and the region score 208 | """ 209 | self.sms = [] 210 | self.alpha = alpha 211 | 212 | def step(self, pred: np.ndarray, gt: np.ndarray): 213 | pred, gt = _prepare_data(pred=pred, gt=gt) 214 | 215 | sm = self.cal_sm(pred, gt) 216 | self.sms.append(sm) 217 | 218 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: 219 | """ 220 | Calculate the S-measure. 
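As implemented below: when the ground truth is all background the score is 1 - mean(pred),
when it is all foreground the score is mean(pred), and otherwise it is
alpha * object_score + (1 - alpha) * region_score, clipped to be non-negative.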
221 | 222 | :return: s-measure 223 | """ 224 | y = np.mean(gt) 225 | if y == 0: 226 | sm = 1 - np.mean(pred) 227 | elif y == 1: 228 | sm = np.mean(pred) 229 | else: 230 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 231 | sm = max(0, sm) 232 | return sm 233 | 234 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float: 235 | """ 236 | Calculate the object score. 237 | """ 238 | fg = pred * gt 239 | bg = (1 - pred) * (1 - gt) 240 | u = np.mean(gt) 241 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) 242 | return object_score 243 | 244 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: 245 | x = np.mean(pred[gt == 1]) 246 | sigma_x = np.std(pred[gt == 1]) 247 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) 248 | return score 249 | 250 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float: 251 | """ 252 | Calculate the region score. 253 | """ 254 | x, y = self.centroid(gt) 255 | part_info = self.divide_with_xy(pred, gt, x, y) 256 | w1, w2, w3, w4 = part_info["weight"] 257 | # assert np.isclose(w1 + w2 + w3 + w4, 1), (w1 + w2 + w3 + w4, pred.mean(), gt.mean()) 258 | 259 | pred1, pred2, pred3, pred4 = part_info["pred"] 260 | gt1, gt2, gt3, gt4 = part_info["gt"] 261 | score1 = self.ssim(pred1, gt1) 262 | score2 = self.ssim(pred2, gt2) 263 | score3 = self.ssim(pred3, gt3) 264 | score4 = self.ssim(pred4, gt4) 265 | 266 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 267 | 268 | def centroid(self, matrix: np.ndarray) -> tuple: 269 | """ 270 | To ensure consistency with the matlab code, one is added to the centroid coordinate, 271 | so there is no need to use the redundant addition operation when dividing the region later, 272 | because the sequence generated by ``1:X`` in matlab will contain ``X``. 273 | 274 | :param matrix: a data array 275 | :return: the centroid coordinate 276 | """ 277 | h, w = matrix.shape 278 | if matrix.sum() == 0: 279 | x = np.round(w / 2) 280 | y = np.round(h / 2) 281 | else: 282 | area_object = np.sum(matrix) 283 | row_ids = np.arange(h) 284 | col_ids = np.arange(w) 285 | x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object) 286 | y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object) 287 | return int(x) + 1, int(y) + 1 288 | 289 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x: int, y: int) -> dict: 290 | """ 291 | Use (x,y) to divide the ``pred`` and the ``gt`` into four submatrices, respectively. 292 | """ 293 | h, w = gt.shape 294 | area = h * w 295 | 296 | gt_LT = gt[0:y, 0:x] 297 | gt_RT = gt[0:y, x:w] 298 | gt_LB = gt[y:h, 0:x] 299 | gt_RB = gt[y:h, x:w] 300 | 301 | pred_LT = pred[0:y, 0:x] 302 | pred_RT = pred[0:y, x:w] 303 | pred_LB = pred[y:h, 0:x] 304 | pred_RB = pred[y:h, x:w] 305 | 306 | w1 = x * y / area 307 | w2 = y * (w - x) / area 308 | w3 = (h - y) * x / area 309 | w4 = 1 - w1 - w2 - w3 310 | 311 | return dict( 312 | gt=(gt_LT, gt_RT, gt_LB, gt_RB), 313 | pred=(pred_LT, pred_RT, pred_LB, pred_RB), 314 | weight=(w1, w2, w3, w4), 315 | ) 316 | 317 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: 318 | """ 319 | Calculate the ssim score. 
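        Mirrors the region-wise similarity used by the S-measure: with x = mean(pred),
        y = mean(gt) and the corresponding (co)variances, the score is
        4 * x * y * sigma_xy / ((x ** 2 + y ** 2) * (sigma_x + sigma_y) + eps); it falls back
        to 1 when both numerator and denominator are zero, and to 0 when only the
        numerator is zero.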
320 | """ 321 | h, w = pred.shape 322 | N = h * w 323 | 324 | x = np.mean(pred) 325 | y = np.mean(gt) 326 | 327 | sigma_x = np.sum((pred - x) ** 2) / (N - 1) 328 | sigma_y = np.sum((gt - y) ** 2) / (N - 1) 329 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) 330 | 331 | alpha = 4 * x * y * sigma_xy 332 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) 333 | 334 | if alpha != 0: 335 | score = alpha / (beta + _EPS) 336 | elif alpha == 0 and beta == 0: 337 | score = 1 338 | else: 339 | score = 0 340 | return score 341 | 342 | def get_results(self) -> dict: 343 | """ 344 | Return the results about S-measure. 345 | 346 | :return: dict(sm=sm) 347 | """ 348 | sm = np.mean(np.array(self.sms, dtype=_TYPE)) 349 | return dict(sm=sm) 350 | 351 | 352 | class Emeasure(object): 353 | def __init__(self): 354 | """ 355 | E-measure(Enhanced-alignment Measure) for SOD. 356 | 357 | More details about the implementation can be found in https://www.yuque.com/lart/blog/lwgt38 358 | 359 | :: 360 | 361 | @inproceedings{Emeasure, 362 | title="Enhanced-alignment Measure for Binary Foreground Map Evaluation", 363 | author="Deng-Ping {Fan} and Cheng {Gong} and Yang {Cao} and Bo {Ren} and Ming-Ming {Cheng} and Ali {Borji}", 364 | booktitle=IJCAI, 365 | pages="698--704", 366 | year={2018} 367 | } 368 | """ 369 | self.adaptive_ems = [] 370 | self.changeable_ems = [] 371 | 372 | def step(self, pred: np.ndarray, gt: np.ndarray): 373 | pred, gt = _prepare_data(pred=pred, gt=gt) 374 | self.gt_fg_numel = np.count_nonzero(gt) 375 | self.gt_size = gt.shape[0] * gt.shape[1] 376 | 377 | changeable_ems = self.cal_changeable_em(pred, gt) 378 | self.changeable_ems.append(changeable_ems) 379 | adaptive_em = self.cal_adaptive_em(pred, gt) 380 | self.adaptive_ems.append(adaptive_em) 381 | 382 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: 383 | """ 384 | Calculate the adaptive E-measure. 385 | 386 | :return: adaptive_em 387 | """ 388 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 389 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) 390 | return adaptive_em 391 | 392 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 393 | """ 394 | Calculate the changeable E-measure, which can be used to obtain the mean E-measure, 395 | the maximum E-measure and the E-measure-threshold curve. 396 | 397 | :return: changeable_ems 398 | """ 399 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) 400 | return changeable_ems 401 | 402 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: 403 | """ 404 | Calculate the E-measure corresponding to the specific threshold. 405 | 406 | Variable naming rules within the function: 407 | ``[pred attribute(foreground fg, background bg)]_[gt attribute(foreground fg, background bg)]_[meaning]`` 408 | 409 | If only ``pred`` or ``gt`` is considered, another corresponding attribute location is replaced with '``_``'. 
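        The prediction is binarized at ``threshold`` and, together with the GT, split into the
        four foreground/background parts. For each part, the de-meaned prediction and GT
        values are combined into an alignment term 2 * d_pred * d_gt / (d_pred ** 2 + d_gt ** 2),
        enhanced as (align + 1) ** 2 / 4, and weighted by the part's pixel count; the sum is
        normalized by (H * W - 1).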
410 | """ 411 | binarized_pred = pred >= threshold 412 | fg_fg_numel = np.count_nonzero(binarized_pred & gt) 413 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) 414 | 415 | fg___numel = fg_fg_numel + fg_bg_numel 416 | bg___numel = self.gt_size - fg___numel 417 | 418 | if self.gt_fg_numel == 0: 419 | enhanced_matrix_sum = bg___numel 420 | elif self.gt_fg_numel == self.gt_size: 421 | enhanced_matrix_sum = fg___numel 422 | else: 423 | parts_numel, combinations = self.generate_parts_numel_combinations( 424 | fg_fg_numel=fg_fg_numel, 425 | fg_bg_numel=fg_bg_numel, 426 | pred_fg_numel=fg___numel, 427 | pred_bg_numel=bg___numel, 428 | ) 429 | 430 | results_parts = [] 431 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): 432 | align_matrix_value = ( 433 | 2 434 | * (combination[0] * combination[1]) 435 | / (combination[0] ** 2 + combination[1] ** 2 + _EPS) 436 | ) 437 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 438 | results_parts.append(enhanced_matrix_value * part_numel) 439 | enhanced_matrix_sum = sum(results_parts) 440 | 441 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 442 | return em 443 | 444 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 445 | """ 446 | Calculate the E-measure corresponding to the threshold that varies from 0 to 255.. 447 | 448 | Variable naming rules within the function: 449 | ``[pred attribute(foreground fg, background bg)]_[gt attribute(foreground fg, background bg)]_[meaning]`` 450 | 451 | If only ``pred`` or ``gt`` is considered, another corresponding attribute location is replaced with '``_``'. 452 | """ 453 | pred = (pred * 255).astype(np.uint8) 454 | bins = np.linspace(0, 256, 257) 455 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) 456 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) 457 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) 458 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) 459 | 460 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs 461 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs 462 | 463 | if self.gt_fg_numel == 0: 464 | enhanced_matrix_sum = bg___numel_w_thrs 465 | elif self.gt_fg_numel == self.gt_size: 466 | enhanced_matrix_sum = fg___numel_w_thrs 467 | else: 468 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( 469 | fg_fg_numel=fg_fg_numel_w_thrs, 470 | fg_bg_numel=fg_bg_numel_w_thrs, 471 | pred_fg_numel=fg___numel_w_thrs, 472 | pred_bg_numel=bg___numel_w_thrs, 473 | ) 474 | 475 | results_parts = np.empty(shape=(4, 256), dtype=np.float64) 476 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): 477 | align_matrix_value = ( 478 | 2 479 | * (combination[0] * combination[1]) 480 | / (combination[0] ** 2 + combination[1] ** 2 + _EPS) 481 | ) 482 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 483 | results_parts[i] = enhanced_matrix_value * part_numel 484 | enhanced_matrix_sum = results_parts.sum(axis=0) 485 | 486 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 487 | return em 488 | 489 | def generate_parts_numel_combinations( 490 | self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel 491 | ): 492 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel 493 | bg_bg_numel = pred_bg_numel - bg_fg_numel 494 | 495 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] 496 | 497 | mean_pred_value = pred_fg_numel / self.gt_size 498 | mean_gt_value = self.gt_fg_numel / self.gt_size 499 | 500 | 
demeaned_pred_fg_value = 1 - mean_pred_value 501 | demeaned_pred_bg_value = 0 - mean_pred_value 502 | demeaned_gt_fg_value = 1 - mean_gt_value 503 | demeaned_gt_bg_value = 0 - mean_gt_value 504 | 505 | combinations = [ 506 | (demeaned_pred_fg_value, demeaned_gt_fg_value), 507 | (demeaned_pred_fg_value, demeaned_gt_bg_value), 508 | (demeaned_pred_bg_value, demeaned_gt_fg_value), 509 | (demeaned_pred_bg_value, demeaned_gt_bg_value), 510 | ] 511 | return parts_numel, combinations 512 | 513 | def get_results(self) -> dict: 514 | """ 515 | Return the results about E-measure. 516 | 517 | :return: dict(em=dict(adp=adaptive_em, curve=changeable_em)) 518 | """ 519 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE)) 520 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0) 521 | return dict(em=dict(adp=adaptive_em, curve=changeable_em)) 522 | 523 | 524 | class WeightedFmeasure(object): 525 | def __init__(self, beta: float = 1): 526 | """ 527 | Weighted F-measure for SOD. 528 | 529 | :: 530 | 531 | @inproceedings{wFmeasure, 532 | title={How to eval foreground maps?}, 533 | author={Margolin, Ran and Zelnik-Manor, Lihi and Tal, Ayellet}, 534 | booktitle=CVPR, 535 | pages={248--255}, 536 | year={2014} 537 | } 538 | 539 | :param beta: the weight of the precision 540 | """ 541 | self.beta = beta 542 | self.weighted_fms = [] 543 | 544 | def step(self, pred: np.ndarray, gt: np.ndarray): 545 | pred, gt = _prepare_data(pred=pred, gt=gt) 546 | 547 | if np.all(~gt): 548 | wfm = 0 549 | else: 550 | wfm = self.cal_wfm(pred, gt) 551 | self.weighted_fms.append(wfm) 552 | 553 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float: 554 | """ 555 | Calculate the weighted F-measure. 556 | """ 557 | # [Dst,IDXT] = bwdist(dGT); 558 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 559 | 560 | # %Pixel dependency 561 | # E = abs(FG-dGT); 562 | E = np.abs(pred - gt) 563 | # Et = E; 564 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 565 | Et = np.copy(E) 566 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 567 | 568 | # K = fspecial('gaussian',7,5); 569 | # EA = imfilter(Et,K); 570 | K = self.matlab_style_gauss2D((7, 7), sigma=5) 571 | EA = convolve(Et, weights=K, mode="constant", cval=0) 572 | # MIN_E_EA = E; 573 | # MIN_E_EA(GT & EA np.ndarray: 600 | """ 601 | 2D gaussian mask - should give the same result as MATLAB's 602 | fspecial('gaussian',[shape],[sigma]) 603 | """ 604 | m, n = [(ss - 1) / 2 for ss in shape] 605 | y, x = np.ogrid[-m : m + 1, -n : n + 1] 606 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 607 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 608 | sumh = h.sum() 609 | if sumh != 0: 610 | h /= sumh 611 | return h 612 | 613 | def get_results(self) -> dict: 614 | """ 615 | Return the results about weighted F-measure. 
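        The per-image weighted F-measures accumulated by ``step`` (0 for images whose GT
        contains no foreground) are averaged here.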
616 | 617 | :return: dict(wfm=weighted_fm) 618 | """ 619 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE)) 620 | return dict(wfm=weighted_fm) 621 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/test/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | class test_dataset: 6 | def __init__(self, image_root, gt_root): 7 | self.img_list = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png')] 8 | self.image_root = image_root 9 | self.gt_root = gt_root 10 | self.transform = transforms.Compose([ 11 | transforms.ToTensor(), 12 | ]) 13 | self.gt_transform = transforms.ToTensor() 14 | self.size = len(self.img_list) 15 | self.index = 0 16 | 17 | def load_data(self): 18 | #image = self.rgb_loader(self.images[self.index]) 19 | image = self.binary_loader(os.path.join(self.image_root,self.img_list[self.index]+ '.png')) 20 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.png')) 21 | self.index += 1 22 | return image, gt 23 | 24 | def rgb_loader(self, path): 25 | with open(path, 'rb') as f: 26 | img = Image.open(f) 27 | return img.convert('RGB') 28 | 29 | def binary_loader(self, path): 30 | with open(path, 'rb') as f: 31 | img = Image.open(f) 32 | return img.convert('L') 33 | 34 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/test/test_data1.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | class test_dataset: 6 | def __init__(self, image_root, gt_root): 7 | self.img_list = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg')] 8 | self.image_root = image_root 9 | self.gt_root = gt_root 10 | self.transform = transforms.Compose([ 11 | transforms.ToTensor(), 12 | ]) 13 | self.gt_transform = transforms.ToTensor() 14 | self.size = len(self.img_list) 15 | self.index = 0 16 | 17 | def load_data(self): 18 | #image = self.rgb_loader(self.images[self.index]) 19 | image = self.binary_loader(os.path.join(self.image_root,self.img_list[self.index]+ '.jpg')) 20 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')) 21 | self.index += 1 22 | return image, gt 23 | 24 | def rgb_loader(self, path): 25 | with open(path, 'rb') as f: 26 | img = Image.open(f) 27 | return img.convert('RGB') 28 | 29 | def binary_loader(self, path): 30 | with open(path, 'rb') as f: 31 | img = Image.open(f) 32 | return img.convert('L') 33 | 34 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/test/test_metric_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from test_data import test_dataset 3 | from saliency_metric import cal_mae,cal_fm,cal_sm,cal_em,cal_wfm 4 | from sod_metrics import MAE, Emeasure, Fmeasure, Smeasure, WeightedFmeasure 5 | dataset_path = '' ##gt_path 6 | dataset_path_pre = '' ##pre_salmap_path 7 | 8 | test_datasets = [''] ##test_datasets_name 9 | 10 | for dataset in test_datasets: 11 | sal_root = dataset_path_pre +dataset+'/' 12 | gt_root = dataset_path +dataset+'/GT/' 13 | test_loader = test_dataset(sal_root, gt_root) 14 | mae,fm,sm,em,wfm= cal_mae(),cal_fm(test_loader.size),cal_sm(),cal_em(),cal_wfm() 
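    # Evaluation loop: each predicted map is resized to the GT resolution when needed, the GT is
    # binarized at 0.5, the prediction is min-max normalized to [0, 1], and both are fed to the
    # MAE / F-measure / S-measure / E-measure / weighted F-measure accumulators created above;
    # the aggregated scores are printed once per dataset.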
15 | for i in range(test_loader.size): 16 | print ('predicting for %d / %d' % ( i + 1, test_loader.size)) 17 | sal, gt = test_loader.load_data() 18 | if sal.size != gt.size: 19 | x, y = gt.size 20 | sal = sal.resize((x, y)) 21 | gt = np.asarray(gt, np.float32) 22 | gt /= (gt.max() + 1e-8) 23 | gt[gt > 0.5] = 1 24 | gt[gt != 1] = 0 25 | res = sal 26 | res = np.array(res) 27 | if res.max() == res.min(): 28 | res = res/255 29 | else: 30 | res = (res - res.min()) / (res.max() - res.min()) 31 | mae.update(res, gt) 32 | sm.update(res,gt) 33 | fm.update(res, gt) 34 | em.update(res,gt) 35 | wfm.update(res,gt) 36 | 37 | MAE = mae.show() 38 | maxf,meanf,_,_ = fm.show() 39 | sm = sm.show() 40 | em = em.show() 41 | wfm = wfm.show() 42 | print('dataset: {} MAE: {:.4f} maxF: {:.4f} avgF: {:.4f} wfm: {:.4f} Sm: {:.4f} Em: {:.4f}'.format(dataset, MAE, maxf,meanf,wfm,sm,em)) -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/FeatureViz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os, argparse, cv2 4 | from lib.Network_Res2Net_GRA_NCD_FeatureViz import Network 5 | from utils.dataloader import test_dataset 6 | 7 | 8 | def heatmap(feat_viz, ori_img, save_path=None): 9 | feat_viz = torch.mean(feat_viz, dim=1, keepdim=True).data.cpu().numpy().squeeze() 10 | feat_viz = (feat_viz - feat_viz.min()) / (feat_viz.max() - feat_viz.min() + 1e-8) 11 | 12 | ori_img = ori_img.data.cpu().numpy().squeeze() 13 | ori_img = ori_img.transpose((1, 2, 0)) 14 | ori_img = ori_img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)) 15 | ori_img = ori_img[:, :, ::-1] 16 | ori_img = np.uint8(255 * ori_img) 17 | feat_viz = np.uint8(255 * feat_viz) 18 | feat_viz = cv2.applyColorMap(feat_viz, cv2.COLORMAP_JET) 19 | feat_viz = cv2.resize(feat_viz, (320, 320)) 20 | ori_img = cv2.resize(ori_img, (320, 320)) 21 | feat_viz = cv2.addWeighted(ori_img, 0.5, feat_viz, 0.5, 0) 22 | 23 | cv2.imwrite(save_path, feat_viz) 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--testsize', type=int, default=352, help='testing size') 29 | parser.add_argument('--pth_path', type=str, default='./snapshot/20201214-Network_Res2Net_GRA_NCD/Net_epoch_best.pth') 30 | opt = parser.parse_args() 31 | 32 | for _data_name in ['CAMO', 'COD10K', 'CHAMELEON']: 33 | data_path = '/media/nercms/NERCMS/GepengJi/2020ACMMM/Dataset/COD_New_data/TestDataset/{}/'.format(_data_name) 34 | save_path = './res/{}/Feature_Viz/{}/'.format(opt.pth_path.split('/')[-2], _data_name) 35 | model = Network() 36 | model.load_state_dict(torch.load(opt.pth_path)) 37 | model.cuda() 38 | model.eval() 39 | 40 | os.makedirs(save_path, exist_ok=True) 41 | image_root = '{}/Imgs/'.format(data_path) 42 | gt_root = '{}/GT/'.format(data_path) 43 | test_loader = test_dataset(image_root, gt_root, opt.testsize) 44 | 45 | for i in range(test_loader.size): 46 | image, gt, name = test_loader.load_data() 47 | gt = np.asarray(gt, np.float32) 48 | gt /= (gt.max() + 1e-8) 49 | image = image.cuda() 50 | 51 | res5, res4, res3, res2, feat_viz = model(image) 52 | for i in range(0, 3): 53 | for j in range(0, 4): 54 | for k in range(0, 2): 55 | cur_feat_viz = feat_viz[i][j][k] 56 | label = 'feat' if k == 0 else 'guid' 57 | img_name = name.split('.')[0] + '_level{}_GRA{}_'.format(i+3, j+1) + label + '.png' 58 | heatmap(feat_viz=cur_feat_viz, ori_img=image, save_path=save_path+img_name) 59 | print('> Dataset: {}, 
Image: {}'.format(_data_name, save_path+img_name)) -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/MyFeatureVisulization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os, argparse 5 | from scipy import misc 6 | from lib.Network_Res2Net_GRA_NCD import Network 7 | from utils.dataloader import test_dataset 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--testsize', type=int, default=352, help='testing size') 11 | parser.add_argument('--pth_path', type=str, default='./snapshot/20201214-Network_Res2Net_GRA_NCD/Net_epoch_best.pth') 12 | opt = parser.parse_args() 13 | 14 | for _data_name in ['vis']: 15 | data_path = '/{}/'.format(_data_name) 16 | save_path = './res/{}/middle_vis/{}/'.format(opt.pth_path.split('/')[-2], _data_name) 17 | model = Network() 18 | model.load_state_dict(torch.load(opt.pth_path)) 19 | model.cuda() 20 | model.eval() 21 | 22 | os.makedirs(save_path, exist_ok=True) 23 | image_root = '{}/Imgs/'.format(data_path) 24 | gt_root = '{}/GT/'.format(data_path) 25 | test_loader = test_dataset(image_root, gt_root, opt.testsize) 26 | 27 | for i in range(test_loader.size): 28 | image, gt, name = test_loader.load_data() 29 | gt = np.asarray(gt, np.float32) 30 | gt /= (gt.max() + 1e-8) 31 | image = image.cuda() 32 | 33 | res_list = model(image) 34 | for i in range(4): 35 | res = res_list[i] 36 | res = -1 * (torch.sigmoid(res)) + 1 37 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 38 | res = res.sigmoid().data.cpu().numpy().squeeze() 39 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 40 | misc.imsave(save_path+name.replace('.png', '_{}.png'.format(i)), res) 41 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/cod10k_subclass_split.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/utils/cod10k_subclass_split.py -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/data_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | import torch 9 | import cv2 10 | 11 | 12 | 13 | # several data augumentation strategies 14 | def cv_random_flip(img, label, edge): 15 | # left right flip 16 | flip_flag = random.randint(0, 1) 17 | if flip_flag == 1: 18 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 19 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 20 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT) 21 | return img, label, edge 22 | 23 | 24 | def randomCrop(image, label, edge): 25 | border = 30 26 | image_width = image.size[0] 27 | image_height = image.size[1] 28 | crop_win_width = np.random.randint(image_width - border, image_width) 29 | crop_win_height = np.random.randint(image_height - border, image_height) 30 | random_region = ( 31 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 32 | (image_height + crop_win_height) >> 1) 33 | return image.crop(random_region), label.crop(random_region), edge.crop(random_region) 34 | 35 | 36 | def 
randomRotation(image, label, edge): 37 | mode = Image.BICUBIC 38 | if random.random() > 0.8: 39 | random_angle = np.random.randint(-15, 15) 40 | image = image.rotate(random_angle, mode) 41 | label = label.rotate(random_angle, mode) 42 | edge = edge.rotate(random_angle, mode) 43 | return image, label, edge 44 | 45 | 46 | def colorEnhance(image): 47 | bright_intensity = random.randint(5, 15) / 10.0 48 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 49 | contrast_intensity = random.randint(5, 15) / 10.0 50 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 51 | color_intensity = random.randint(0, 20) / 10.0 52 | image = ImageEnhance.Color(image).enhance(color_intensity) 53 | sharp_intensity = random.randint(0, 30) / 10.0 54 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 55 | return image 56 | 57 | 58 | def randomGaussian(image, mean=0.1, sigma=0.35): 59 | def gaussianNoisy(im, mean=mean, sigma=sigma): 60 | for _i in range(len(im)): 61 | im[_i] += random.gauss(mean, sigma) 62 | return im 63 | 64 | img = np.asarray(image) 65 | width, height = img.shape 66 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 67 | img = img.reshape([width, height]) 68 | return Image.fromarray(np.uint8(img)) 69 | 70 | 71 | def randomPeper(img): 72 | img = np.array(img) 73 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 74 | for i in range(noiseNum): 75 | 76 | randX = random.randint(0, img.shape[0] - 1) 77 | 78 | randY = random.randint(0, img.shape[1] - 1) 79 | 80 | if random.randint(0, 1) == 0: 81 | 82 | img[randX, randY] = 0 83 | 84 | else: 85 | 86 | img[randX, randY] = 255 87 | return Image.fromarray(img) 88 | 89 | 90 | # dataset for training 91 | class PolypObjDataset(data.Dataset): 92 | def __init__(self, image_root, gt_root, edge_root, trainsize): 93 | self.trainsize = trainsize 94 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')or f.endswith('.png')] 95 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')] 96 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg') or f.endswith('.png')] 97 | self.images = sorted(self.images) 98 | self.gts = sorted(self.gts) 99 | self.edges = sorted(self.edges) 100 | self.filter_files() 101 | self.img_transform = transforms.Compose([ 102 | transforms.Resize((self.trainsize, self.trainsize)), 103 | transforms.ToTensor(), 104 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 105 | self.gt_transform = transforms.Compose([ 106 | transforms.Resize((self.trainsize, self.trainsize)), 107 | transforms.ToTensor()]) 108 | 109 | self.edge_transform = transforms.Compose([ 110 | transforms.Resize((self.trainsize, self.trainsize)), 111 | transforms.ToTensor()]) 112 | 113 | self.kernel = np.ones((3, 3), np.uint8) 114 | self.size = len(self.images) 115 | 116 | def __getitem__(self, index): 117 | image = self.rgb_loader(self.images[index]) 118 | gt = self.binary_loader(self.gts[index]) 119 | edge = cv2.imread(self.edges[index], cv2.IMREAD_GRAYSCALE) 120 | edge = cv2.dilate(edge, self.kernel, iterations=1) 121 | edge = Image.fromarray(edge) 122 | 123 | image, gt, edge = cv_random_flip(image, gt, edge) 124 | image, gt, edge = randomCrop(image, gt, edge) 125 | image, gt, edge = randomRotation(image, gt, edge) 126 | 127 | image = colorEnhance(image) 128 | gt = randomPeper(gt) 129 | edge = randomPeper(edge) 130 | 131 | image = self.img_transform(image) 132 | gt = self.gt_transform(gt) 133 | edge = 
self.edge_transform(edge) 134 | 135 | edge_small = self.Threshold_process(edge) 136 | 137 | 138 | return image, gt, edge_small 139 | 140 | def filter_files(self): 141 | assert len(self.images) == len(self.gts) and len(self.edges) == len(self.images) \ 142 | and len(self.edges) == len(self.gts) 143 | images = [] 144 | gts = [] 145 | edges = [] 146 | for img_path, gt_path, edge_path in zip(self.images, self.gts, self.edges): 147 | img = Image.open(img_path) 148 | gt = Image.open(gt_path) 149 | edge = Image.open(edge_path) 150 | if img.size == gt.size and img.size == edge.size: 151 | images.append(img_path) 152 | gts.append(gt_path) 153 | edges.append(edge_path) 154 | self.images = images 155 | self.gts = gts 156 | self.edges = edges 157 | 158 | def rgb_loader(self, path): 159 | with open(path, 'rb') as f: 160 | img = Image.open(f) 161 | return img.convert('RGB') 162 | 163 | def binary_loader(self, path): 164 | with open(path, 'rb') as f: 165 | img = Image.open(f) 166 | return img.convert('L') 167 | 168 | 169 | def Threshold_process(self, a): 170 | one = torch.ones_like(a) 171 | return torch.where(a > 0, one, a) 172 | 173 | def __len__(self): 174 | return self.size 175 | 176 | # solve dataloader random bug 177 | def seed_worker(worker_id): 178 | worker_seed = torch.initial_seed() % 2**32 179 | np.random.seed(worker_seed) 180 | 181 | # dataloader for training 182 | def get_loader(image_root, gt_root, edge_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True): 183 | dataset = PolypObjDataset(image_root, gt_root, edge_root, trainsize) 184 | data_loader = data.DataLoader(dataset=dataset, 185 | batch_size=batchsize, 186 | shuffle=shuffle, 187 | num_workers=num_workers, 188 | pin_memory=pin_memory, 189 | worker_init_fn=seed_worker) 190 | return data_loader 191 | 192 | 193 | # test dataset and loader 194 | class test_dataset: 195 | def __init__(self, image_root, gt_root, testsize): 196 | self.testsize = testsize 197 | 198 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 199 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] 200 | self.images = sorted(self.images) 201 | self.gts = sorted(self.gts) 202 | 203 | self.transform = transforms.Compose([ 204 | transforms.Resize((self.testsize, self.testsize)), 205 | transforms.ToTensor(), 206 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 207 | self.gt_transform = transforms.ToTensor() 208 | 209 | self.size = len(self.images) 210 | self.index = 0 211 | 212 | def load_data(self): 213 | image = self.rgb_loader(self.images[self.index]) 214 | image = self.transform(image).unsqueeze(0) 215 | 216 | gt = self.binary_loader(self.gts[self.index]) 217 | 218 | name = self.images[self.index].split('/')[-1] 219 | 220 | image_for_post = self.rgb_loader(self.images[self.index]) 221 | image_for_post = image_for_post.resize(gt.size) 222 | 223 | if name.endswith('.jpg'): 224 | name = name.split('.jpg')[0] + '.png' 225 | 226 | self.index += 1 227 | self.index = self.index % self.size 228 | 229 | return image, gt, name, np.array(image_for_post) 230 | 231 | def rgb_loader(self, path): 232 | with open(path, 'rb') as f: 233 | img = Image.open(f) 234 | return img.convert('RGB') 235 | 236 | def binary_loader(self, path): 237 | with open(path, 'rb') as f: 238 | img = Image.open(f) 239 | return img.convert('L') 240 | 241 | def __len__(self): 242 | return self.size 243 | 244 | 245 | if __name__ =='__main__': 246 | train_root = 
'/dataset/COD/TrainDataset/' 247 | batchsize = 36 248 | trainsize = 512 249 | train_loader = get_loader(image_root=train_root + 'Imgs/', 250 | gt_root=train_root + 'GT/', 251 | edge_root=train_root + 'Edge/', 252 | batchsize=batchsize, 253 | trainsize=trainsize, 254 | num_workers=8) 255 | for i, (images, gts, edges) in enumerate(train_loader, start=1): 256 | gt = gts[0].sigmoid().data.cpu().numpy().squeeze() 257 | edge =edges[0].sigmoid().data.cpu().numpy().squeeze() 258 | print(edge.shape) 259 | res_gt = (gt - gt.min()) / (gt.max() - gt.min() + 1e-8) 260 | cv2.imwrite('ceshi_gt.png',res_gt*255) 261 | res = (edge - edge.min()) / (edge.max() - edge.min() + 1e-8) 262 | cv2.imwrite('ceshi_edge.png',res*255) 263 | break -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | 6 | 7 | class PolypDataset(data.Dataset): 8 | """ 9 | dataloader for polyp segmentation tasks 10 | """ 11 | def __init__(self, image_root, gt_root, trainsize): 12 | self.trainsize = trainsize 13 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 14 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')] 15 | self.images = sorted(self.images) 16 | self.gts = sorted(self.gts) 17 | self.filter_files() 18 | self.size = len(self.images) 19 | self.img_transform = transforms.Compose([ 20 | transforms.Resize((self.trainsize, self.trainsize)), 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.485, 0.456, 0.406], 23 | [0.229, 0.224, 0.225])]) 24 | self.gt_transform = transforms.Compose([ 25 | transforms.Resize((self.trainsize, self.trainsize)), 26 | transforms.ToTensor()]) 27 | 28 | def __getitem__(self, index): 29 | image = self.rgb_loader(self.images[index]) 30 | gt = self.binary_loader(self.gts[index]) 31 | image = self.img_transform(image) 32 | gt = self.gt_transform(gt) 33 | return image, gt 34 | 35 | def filter_files(self): 36 | assert len(self.images) == len(self.gts) 37 | images = [] 38 | gts = [] 39 | for img_path, gt_path in zip(self.images, self.gts): 40 | img = Image.open(img_path) 41 | gt = Image.open(gt_path) 42 | if img.size == gt.size: 43 | images.append(img_path) 44 | gts.append(gt_path) 45 | self.images = images 46 | self.gts = gts 47 | 48 | def rgb_loader(self, path): 49 | with open(path, 'rb') as f: 50 | img = Image.open(f) 51 | return img.convert('RGB') 52 | 53 | def binary_loader(self, path): 54 | with open(path, 'rb') as f: 55 | img = Image.open(f) 56 | # return img.convert('1') 57 | return img.convert('L') 58 | 59 | def resize(self, img, gt): 60 | assert img.size == gt.size 61 | w, h = img.size 62 | if h < self.trainsize or w < self.trainsize: 63 | h = max(h, self.trainsize) 64 | w = max(w, self.trainsize) 65 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) 66 | else: 67 | return img, gt 68 | 69 | def __len__(self): 70 | return self.size 71 | 72 | 73 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True): 74 | 75 | dataset = PolypDataset(image_root, gt_root, trainsize) 76 | data_loader = data.DataLoader(dataset=dataset, 77 | batch_size=batchsize, 78 | shuffle=shuffle, 79 | num_workers=num_workers, 80 | pin_memory=pin_memory) 81 | return data_loader 82 | 83 | 84 | class 
test_dataset: 85 | def __init__(self, image_root, gt_root, testsize): 86 | self.testsize = testsize 87 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 88 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] 89 | self.images = sorted(self.images) 90 | self.gts = sorted(self.gts) 91 | self.transform = transforms.Compose([ 92 | transforms.Resize((self.testsize, self.testsize)), 93 | transforms.ToTensor(), 94 | transforms.Normalize([0.485, 0.456, 0.406], 95 | [0.229, 0.224, 0.225])]) 96 | self.gt_transform = transforms.ToTensor() 97 | self.size = len(self.images) 98 | self.index = 0 99 | 100 | def load_data(self): 101 | image = self.rgb_loader(self.images[self.index]) 102 | image = self.transform(image).unsqueeze(0) 103 | gt = self.binary_loader(self.gts[self.index]) 104 | name = self.images[self.index].split('/')[-1] 105 | if name.endswith('.jpg'): 106 | name = name.split('.jpg')[0] + '.png' 107 | self.index += 1 108 | return image, gt, name 109 | 110 | def rgb_loader(self, path): 111 | with open(path, 'rb') as f: 112 | img = Image.open(f) 113 | return img.convert('RGB') 114 | 115 | def binary_loader(self, path): 116 | with open(path, 'rb') as f: 117 | img = Image.open(f) 118 | return img.convert('L') 119 | -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/fps.py: -------------------------------------------------------------------------------- 1 | import torch, time 2 | import numpy as np 3 | import jittor as jt 4 | 5 | from lib.Network_Res2Net_GRA_NCD import Network as py_Network 6 | from jittor.lib.Network_Res2Net_GRA_NCD import Network as jt_Network 7 | 8 | jt.flags.use_cuda = 1 9 | 10 | # 定义numpy输入矩阵 11 | bs = 32 12 | test_img = np.random.random((bs,3,224,224)).astype('float32') 13 | 14 | # 定义 pytorch & jittor 输入矩阵 15 | pytorch_test_img = torch.Tensor(test_img).cuda() 16 | jittor_test_img = jt.array(test_img) 17 | 18 | # 跑turns次前向求平均值 19 | turns = 100 20 | 21 | # 定义 pytorch & jittor 的xxx模型,如vgg 22 | pytorch_model = py_Network().cuda() 23 | jittor_model = jt_Network() 24 | 25 | # 把模型都设置为eval来防止dropout层对输出结果的随机影响 26 | pytorch_model.eval() 27 | jittor_model.eval() 28 | 29 | # jittor加载pytorch的初始化参数来保证参数完全相同 30 | jittor_model.load_parameters(pytorch_model.state_dict()) 31 | 32 | # 测试Pytorch一次前向传播的平均用时 33 | for i in range(10): 34 | pytorch_result = pytorch_model(pytorch_test_img) # Pytorch热身 35 | torch.cuda.synchronize() 36 | sta = time.time() 37 | for i in range(turns): 38 | pytorch_result = pytorch_model(pytorch_test_img) 39 | torch.cuda.synchronize() # 只有运行了torch.cuda.synchronize()才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话 40 | end = time.time() 41 | tc_time = round((end - sta) / turns, 5) # 执行turns次的平均时间,输出时保留5位小数 42 | tc_fps = round(bs * turns / (end - sta),0) # 计算FPS 43 | print(f"- Pytorch forward average time cost: {tc_time}, Batch Size: {bs}, FPS: {tc_fps}") 44 | 45 | 46 | # 测试Jittor一次前向传播的平均用时 47 | for i in range(10): 48 | jittor_result = jittor_model(jittor_test_img) # Jittor热身 49 | jittor_result[0].sync() 50 | jt.sync_all(True) 51 | # sync_all(true)是把计算图发射到计算设备上,并且同步。只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话 52 | sta = time.time() 53 | for i in range(turns): 54 | jittor_result = jittor_model(jittor_test_img) 55 | jittor_result[0].sync() # sync是把计算图发送到计算设备上 56 | jt.sync_all(True) 57 | end = time.time() 58 | jt_time = round((time.time() - sta) / turns, 5) # 执行turns次的平均时间,输出时保留5位小数 59 | jt_fps = round(bs * turns / 
(end - sta),0) # 计算FPS 60 | print(f"- Jittor forward average time cost: {jt_time}, Batch Size: {bs}, FPS: {jt_fps}") -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/generate_LaTeX.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/utils/generate_LaTeX.py -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/heatmap.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/utils/heatmap.py -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/pytorch_jittor_convert.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/utils/pytorch_jittor_convert.py -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/tif2png.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSYSI/FSEL/f3be46404486f70404c79f1c1b93025e77a7233e/FSEL_ECCV_2024/utils/tif2png.py -------------------------------------------------------------------------------- /FSEL_ECCV_2024/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from thop import profile 4 | from thop import clever_format 5 | 6 | def get_coef(iter_percentage, method): 7 | if method == "linear": 8 | milestones = (0.3, 0.7) 9 | coef_range = (0, 1) 10 | min_point, max_point = min(milestones), max(milestones) 11 | min_coef, max_coef = min(coef_range), max(coef_range) 12 | if iter_percentage < min_point: 13 | ual_coef = min_coef 14 | elif iter_percentage > max_point: 15 | ual_coef = max_coef 16 | else: 17 | ratio = (max_coef - min_coef) / (max_point - min_point) 18 | ual_coef = ratio * (iter_percentage - min_point) 19 | elif method == "cos": 20 | coef_range = (0, 1) 21 | min_coef, max_coef = min(coef_range), max(coef_range) 22 | normalized_coef = (1 - np.cos(iter_percentage * np.pi)) / 2 23 | ual_coef = normalized_coef * (max_coef - min_coef) + min_coef 24 | else: 25 | ual_coef = 1.0 26 | return ual_coef 27 | 28 | 29 | def cal_ual(seg_logits, seg_gts): 30 | assert seg_logits.shape == seg_gts.shape, (seg_logits.shape, seg_gts.shape) 31 | sigmoid_x = seg_logits.sigmoid() 32 | loss_map = 1 - (2 * sigmoid_x - 1).abs().pow(2) 33 | return loss_map.mean() 34 | 35 | def clip_gradient(optimizer, grad_clip): 36 | """ 37 | For calibrating misalignment gradient via cliping gradient technique 38 | :param optimizer: 39 | :param grad_clip: 40 | :return: 41 | """ 42 | for group in optimizer.param_groups: 43 | for param in group['params']: 44 | if param.grad is not None: 45 | param.grad.data.clamp_(-grad_clip, grad_clip) 46 | 47 | 48 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 49 | decay = decay_rate ** (epoch // decay_epoch) 50 | for param_group in optimizer.param_groups: 51 | param_group['lr'] = decay*init_lr 52 | lr=param_group['lr'] 53 | return lr 54 | 55 | class AvgMeter(object): 56 | def __init__(self, num=40): 57 | self.num = num 58 | self.reset() 59 | 60 | def reset(self): 61 | self.val = 0 62 | self.avg = 0 
63 | self.sum = 0 64 | self.count = 0 65 | self.losses = [] 66 | 67 | def update(self, val, n=1): 68 | self.val = val 69 | self.sum += val * n 70 | self.count += n 71 | self.avg = self.sum / self.count 72 | self.losses.append(val) 73 | 74 | def show(self): 75 | return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):])) 76 | 77 | 78 | def CalParams(model, input_tensor): 79 | """ 80 | Usage: 81 | Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter) 82 | Necessarity: 83 | from thop import profile 84 | from thop import clever_format 85 | :param model: 86 | :param input_tensor: 87 | :return: 88 | """ 89 | flops, params = profile(model, inputs=(input_tensor,)) 90 | flops, params = clever_format([flops, params], "%.3f") 91 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params)) 92 | 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yanguang Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ECCV 2024] Frequency-Spatial Entanglement Learning for Camouflaged Object Detection 2 | Yanguang Sun, Chunyan Xu, Jian Yang, Hanyu Xuan, Lei Luo
3 | 
4 | Our work has been accepted for ECCV 2024. The code has been open-sourced.
5 | 
6 | If you are interested in our work, please do not hesitate to contact us at Sunyg@njust.edu.cn via email.
7 | 
8 | 
9 | ![image](https://github.com/user-attachments/assets/15fb5bc4-1662-4b4b-aeec-e41d20be4828)
10 | 
11 | ![image](https://github.com/CSYSI/FSEL/assets/171759588/88a36f96-6e5e-42eb-9e50-a4b464a0f63a)
12 | 
13 | ![image](https://github.com/CSYSI/FSEL/assets/171759588/a296b40d-2b15-49f1-8c05-bfa7de5e20ff)
14 | 
15 | 
16 | 
17 | 
18 | ## Prediction maps
19 | 
20 | We provide the prediction maps of our FSEL model for camouflaged object detection, salient object detection, and polyp segmentation tasks.
21 | 
22 | FSEL-camouflaged object detection (COD) (PVT/ResNet/Res2Net) [[baidu](https://pan.baidu.com/s/1ogYw7NNCJLahYzBurhvnKw),PIN:u5sb]
23 | 
24 | FSEL-salient object detection (SOD) (PVT/ResNet) [[baidu](https://pan.baidu.com/s/1oVgPSDeibQ2HN9LNnzzPbw),PIN:pelf]
25 | 
26 | FSEL-polyp segmentation (PS) (PVT/ResNet) [[baidu](https://pan.baidu.com/s/1x-eeELRpKH1XZwvQGaPvAg),PIN:48bi]
27 | 
28 | 
29 | ## Training weights
30 | 
31 | We provide the training weights of our FSEL model for the COD task.
32 | 
33 | Note that you should use the corresponding network in the lib_initial files to test these .pth files.
34 | 
35 | FSEL-COD-weights (PVT/ResNet/Res2Net) [[baidu](https://pan.baidu.com/s/1D7nxuXxcF0RRCcVIZPp9Xg),PIN:u0mq]
36 | 
37 | 
38 | 
39 | 
40 | ## Citation
41 | 
42 | If you use the FSEL method in your research or wish to refer to its baseline results, please use one of the following BibTeX entries.
43 | ```
44 | @article{FSEL,
45 |   title={Frequency-Spatial Entanglement Learning for Camouflaged Object Detection},
46 |   author={Sun, Yanguang and Xu, Chunyan and Yang, Jian and Xuan, Hanyu and Luo, Lei},
47 |   journal={arXiv preprint arXiv:2409.01686},
48 |   year={2024}
49 | }
50 | ```
51 | 
52 | ```
53 | @inproceedings{FSEL,
54 |   title={Frequency-Spatial Entanglement Learning for Camouflaged Object Detection},
55 |   author={Sun, Yanguang and Xu, Chunyan and Yang, Jian and Xuan, Hanyu and Luo, Lei},
56 |   booktitle={European Conference on Computer Vision},
57 |   year={2024},
58 |   pages={343--360},
59 | }
60 | ```
61 | 
62 | 
63 | 
--------------------------------------------------------------------------------
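For quick reference, below is a minimal sketch of how the metric classes in `test/sod_metrics.py` can be driven directly over a folder of prediction/GT pairs; the directory paths, the matching-filename assumption, and the grayscale loading are illustrative assumptions rather than part of the repository scripts.

```python
# Minimal sketch (assumed paths and filename matching): accumulate the metrics
# implemented in test/sod_metrics.py over a folder of prediction / ground-truth pairs.
import os

import cv2
from sod_metrics import MAE, Emeasure, Fmeasure, Smeasure, WeightedFmeasure

pred_dir = "./Our_ResNet/exp/COD10K/"   # hypothetical prediction folder
gt_dir = "/path/to/COD10K/GT/"          # hypothetical ground-truth folder

FM, WFM, SM, EM, M = Fmeasure(), WeightedFmeasure(), Smeasure(), Emeasure(), MAE()

for name in sorted(os.listdir(gt_dir)):
    # Predictions are assumed to share the GT filenames and resolution.
    pred = cv2.imread(os.path.join(pred_dir, name), cv2.IMREAD_GRAYSCALE)
    gt = cv2.imread(os.path.join(gt_dir, name), cv2.IMREAD_GRAYSCALE)
    for metric in (FM, WFM, SM, EM, M):
        # step() is assumed to normalize the prediction to [0, 1] and binarize the GT
        # via _prepare_data before updating the per-image statistics.
        metric.step(pred=pred, gt=gt)

print(dict(
    Smeasure=SM.get_results()["sm"],
    wFmeasure=WFM.get_results()["wfm"],
    MAE=M.get_results()["mae"],
    adpEm=EM.get_results()["em"]["adp"],
    meanEm=EM.get_results()["em"]["curve"].mean(),
    maxEm=EM.get_results()["em"]["curve"].max(),
    adpFm=FM.get_results()["fm"]["adp"],
    meanFm=FM.get_results()["fm"]["curve"].mean(),
    maxFm=FM.get_results()["fm"]["curve"].max(),
))
```

The keys used above follow the dictionaries returned by each class's `get_results()`; `test/test_metric_score.py` provides the repository's own evaluation loop built on the legacy `saliency_metric` helpers.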