├── test_fps.py ├── utils ├── config.py ├── misc.py ├── test_data.py ├── dataset_strategy_fpn.py └── saliency_metric.py ├── LICENSE ├── requirements.txt ├── predict.py ├── test.py ├── README.md ├── train.py └── model ├── MVANet.py └── SwinTransformer.py /test_fps.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | 5 | from model.MVANet import MVANet 6 | 7 | data = torch.randn(1, 3, 1024, 1024, device='cuda', dtype=torch.float32) 8 | 9 | model = MVANet().eval().cuda() 10 | for _ in range(10): 11 | model(data) 12 | 13 | torch.cuda.synchronize() 14 | start_time = time.perf_counter() 15 | for _ in range(100): 16 | model(data) 17 | torch.cuda.synchronize() 18 | print(f"FPS: {100 / (time.perf_counter() - start_time):.03f}") 19 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | diste1 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE1/' 4 | diste2 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE2/' 5 | diste3 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE3/' 6 | diste4 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE4/' 7 | disvd = '/home/vanessa/code/DIS-main/DIS5K/DIS-VD/' 8 | 9 | diste1 = os.path.join(diste1) 10 | diste2 = os.path.join(diste2) 11 | diste3 = os.path.join(diste3) 12 | diste4 = os.path.join(diste4) 13 | disvd = os.path.join(disvd) 14 | 15 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 qianyu-dlut 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.20.3 3 | addict==2.4.0 4 | aiohttp==3.8.6 5 | aiosignal==1.3.1 6 | aliyun-python-sdk-core==2.16.0 7 | aliyun-python-sdk-kms==2.16.5 8 | async-timeout==4.0.3 9 | asynctest==0.13.0 10 | attrs==24.2.0 11 | cachetools==5.5.0 12 | certifi==2022.12.7 13 | cffi==1.15.1 14 | charset-normalizer==3.4.0 15 | click==8.1.7 16 | colorama==0.4.6 17 | crcmod==1.7 18 | cryptography==43.0.1 19 | cycler==0.11.0 20 | datasets==2.13.2 21 | dill==0.3.6 22 | einops==0.6.1 23 | filelock==3.12.2 24 | fonttools==4.38.0 25 | frozenlist==1.3.3 26 | fsspec==2023.1.0 27 | google-auth==2.35.0 28 | google-auth-oauthlib==0.4.6 29 | grpcio==1.62.3 30 | huggingface-hub==0.16.4 31 | idna==3.10 32 | importlib-metadata==6.7.0 33 | jmespath==0.10.0 34 | kiwisolver==1.4.5 35 | Markdown==3.4.4 36 | markdown-it-py==2.2.0 37 | MarkupSafe==2.1.5 38 | matplotlib==3.5.3 39 | mdurl==0.1.2 40 | mmdet==2.17.0 41 | mmengine==0.8.1 42 | mmsegmentation==0.19.0 43 | model-index==0.1.11 44 | multidict==6.0.5 45 | multiprocess==0.70.14 46 | numpy==1.21.6 47 | oauthlib==3.2.2 48 | opencv-python==4.10.0.84 49 | opendatalab==0.0.10 50 | openmim==0.3.9 51 | openxlab==0.0.10 52 | ordered-set==4.1.0 53 | oss2==2.17.0 54 | packaging==24.0 55 | pandas==1.1.5 56 | Pillow==9.5.0 57 | pip==22.3.1 58 | platformdirs==4.0.0 59 | prettytable==3.7.0 60 | protobuf==3.20.3 61 | psutil==6.1.0 62 | pyarrow==12.0.1 63 | pyasn1==0.5.1 64 | pyasn1-modules==0.3.0 65 | pycocotools==2.0.7 66 | pycparser==2.21 67 | pycryptodome==3.21.0 68 | Pygments==2.17.2 69 | pyparsing==3.1.4 70 | python-dateutil==2.9.0.post0 71 | pytz==2023.4 72 | PyYAML==6.0.1 73 | regex==2024.4.16 74 | requests==2.28.2 75 | requests-oauthlib==2.0.0 76 | rich==13.8.1 77 | rsa==4.9 78 | safetensors==0.4.5 79 | scipy==1.7.3 80 | setuptools==60.2.0 81 | six==1.16.0 82 | tabulate==0.9.0 83 | tensorboard==2.11.2 84 | tensorboard-data-server==0.6.1 85 | tensorboard-plugin-wit==1.8.1 86 | termcolor==2.3.0 87 | terminaltables==3.1.10 88 | timm==0.9.12 89 | tokenizers==0.13.3 90 | tomli==2.0.1 91 | tqdm==4.65.2 92 | transformers==4.30.2 93 | ttach==0.0.3 94 | typing_extensions==4.7.1 95 | urllib3==1.26.20 96 | wcwidth==0.2.13 97 | Werkzeug==2.2.3 98 | wheel==0.38.4 99 | xxhash==3.5.0 100 | yapf==0.40.2 101 | yarl==1.9.4 102 | zipp==3.15.0 103 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import os 5 | 6 | def clip_gradient(optimizer, grad_clip): 7 | for group in optimizer.param_groups: 8 | for param in group['params']: 9 | if param.grad is not None: 10 | param.grad.data.clamp_(-grad_clip, grad_clip) 11 | 12 | 13 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=5): 14 | decay = decay_rate ** (epoch // decay_epoch) 15 | for param_group in optimizer.param_groups: 16 | param_group['lr'] *= decay 17 | 18 | 19 | def truncated_normal_(tensor, mean=0, std=1): 20 | size = tensor.shape 21 | tmp = tensor.new_empty(size + (4,)).normal_() 22 | valid = (tmp < 2) & (tmp > -2) 23 | ind = valid.max(-1, keepdim=True)[1] 24 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 25 | tensor.data.mul_(std).add_(mean) 26 | 27 | def init_weights(m): 28 | if type(m) == nn.Conv2d or type(m) == 
nn.ConvTranspose2d: 29 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 30 | #nn.init.normal_(m.weight, std=0.001) 31 | #nn.init.normal_(m.bias, std=0.001) 32 | truncated_normal_(m.bias, mean=0, std=0.001) 33 | 34 | def init_weights_orthogonal_normal(m): 35 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: 36 | nn.init.orthogonal_(m.weight) 37 | truncated_normal_(m.bias, mean=0, std=0.001) 38 | #nn.init.normal_(m.bias, std=0.001) 39 | 40 | def l2_regularisation(m): 41 | l2_reg = None 42 | 43 | for W in m.parameters(): 44 | if l2_reg is None: 45 | l2_reg = W.norm(2) 46 | else: 47 | l2_reg = l2_reg + W.norm(2) 48 | return l2_reg 49 | 50 | def check_mkdir(dir_name): 51 | if not os.path.isdir(dir_name): 52 | os.makedirs(dir_name) 53 | 54 | class AvgMeter(object): 55 | def __init__(self, num=40): 56 | self.num = num 57 | self.reset() 58 | 59 | def reset(self): 60 | self.val = 0 61 | self.avg = 0 62 | self.sum = 0 63 | self.count = 0 64 | self.losses = [] 65 | 66 | def update(self, val, n=1): 67 | self.val = val 68 | self.sum += val * n 69 | self.count += n 70 | self.avg = self.sum / self.count 71 | self.losses.append(val) 72 | 73 | def show(self): 74 | a = len(self.losses) 75 | b = np.maximum(a-self.num, 0) 76 | c = self.losses[b:] 77 | #print(c) 78 | #d = torch.mean(torch.stack(c)) 79 | #print(d) 80 | return torch.mean(torch.stack(c)) 81 | 82 | # def save_mask_prediction_example(mask, pred, iter): 83 | # plt.imshow(pred[0,:,:],cmap='Greys') 84 | # plt.savefig('images/'+str(iter)+"_prediction.png") 85 | # plt.imshow(mask[0,:,:],cmap='Greys') 86 | # plt.savefig('images/'+str(iter)+"_mask.png") -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | import torch 5 | from PIL import Image 6 | from torch.autograd import Variable 7 | from torchvision import transforms 8 | from utils.config import diste1,diste2,diste3,diste4,disvd 9 | from utils.misc import check_mkdir 10 | from model.MVANet import inf_MVANet 11 | import ttach as tta 12 | 13 | torch.cuda.set_device(0) 14 | ckpt_path = './saved_model/' 15 | args = { 16 | 'save_results': True 17 | } 18 | 19 | 20 | img_transform = transforms.Compose([ 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 23 | 24 | ]) 25 | 26 | depth_transform = transforms.ToTensor() 27 | target_transform = transforms.ToTensor() 28 | to_pil = transforms.ToPILImage() 29 | 30 | to_test ={ 31 | 'DIS-TE1':diste1, 32 | 'DIS-TE2':diste2, 33 | 'DIS-TE3':diste3, 34 | 'DIS-TE4':diste4, 35 | 'DIS-VD':disvd, 36 | } 37 | 38 | transforms = tta.Compose( 39 | [ 40 | tta.HorizontalFlip(), 41 | tta.Scale(scales=[0.75, 1, 1.125], interpolation='bilinear', align_corners=False), 42 | ] 43 | ) 44 | 45 | def main(item): 46 | net = inf_MVANet().cuda() 47 | pretrained_dict = torch.load(os.path.join(ckpt_path, item + '.pth'), map_location='cuda') 48 | model_dict = net.state_dict() 49 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 50 | model_dict.update(pretrained_dict) 51 | net.load_state_dict(model_dict) 52 | net.eval() 53 | with torch.no_grad(): 54 | for name, root in to_test.items(): 55 | root1 = os.path.join(root, 'images') 56 | img_list = [os.path.splitext(f) for f in os.listdir(root1)] 57 | for idx, img_name in enumerate(img_list): 58 | print ('predicting for %s: %d / %d' % (name, idx + 1, len(img_list))) 
59 | rgb_png_path = os.path.join(root, 'images', img_name[0] + '.png') 60 | rgb_jpg_path = os.path.join(root, 'images', img_name[0] + '.jpg') 61 | if os.path.exists(rgb_png_path): 62 | img = Image.open(rgb_png_path).convert('RGB') 63 | else: 64 | img = Image.open(rgb_jpg_path).convert('RGB') 65 | w_,h_ = img.size 66 | img_resize = img.resize([1024,1024],Image.BILINEAR) 67 | img_var = Variable(img_transform(img_resize).unsqueeze(0), volatile=True).cuda() 68 | mask = [] 69 | for transformer in transforms: 70 | rgb_trans = transformer.augment_image(img_var) 71 | model_output = net(rgb_trans) 72 | deaug_mask = transformer.deaugment_mask(model_output) 73 | mask.append(deaug_mask) 74 | 75 | prediction = torch.mean(torch.stack(mask, dim=0), dim=0) 76 | prediction = prediction.sigmoid() 77 | prediction = to_pil(prediction.data.squeeze(0).cpu()) 78 | prediction = prediction.resize((w_, h_), Image.BILINEAR) 79 | if args['save_results']: 80 | check_mkdir(os.path.join(ckpt_path, item, name)) 81 | prediction.save(os.path.join(ckpt_path, item, name, img_name[0] + '.png')) 82 | 83 | 84 | 85 | if __name__ == '__main__': 86 | files = os.listdir(ckpt_path) 87 | files.sort() 88 | for items in files: 89 | if '80.pth' in items: 90 | item = items.split('.')[0] 91 | main(item) 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from utils.test_data import test_dataset 4 | from utils.saliency_metric import cal_mae,cal_fm,cal_sm,cal_em,cal_wfm, cal_dice, cal_iou,cal_ber,cal_acc, HCEMeasure 5 | from utils.config import diste1,diste2,diste3,diste4,disvd 6 | from tqdm import tqdm 7 | import cv2 8 | from skimage.morphology import skeletonize 9 | 10 | test_datasets = { 11 | 'DIS-TE1':diste1, 12 | 'DIS-TE2':diste2, 13 | 'DIS-TE3':diste3, 14 | 'DIS-TE4':diste4, 15 | 'DIS-VD':disvd, 16 | } 17 | 18 | 19 | dir = [ 20 | './saved_model/MVANet/Model_80/' 21 | ] 22 | for d in dir: 23 | for name, root in test_datasets.items(): 24 | print(name) 25 | sal_root = os.path.join(d, name) 26 | print(sal_root) 27 | gt_root = root + 'masks' 28 | print(gt_root) 29 | if os.path.exists(sal_root): 30 | test_loader = test_dataset(sal_root, gt_root) 31 | mae, fm, sm, em, wfm, m_dice, m_iou, ber, acc, hce = cal_mae(), cal_fm( 32 | test_loader.size), cal_sm(), cal_em(), cal_wfm(), cal_dice(), cal_iou(), cal_ber(), cal_acc(), HCEMeasure() 33 | for i in tqdm(range(test_loader.size)): 34 | # print ('predicting for %d / %d' % ( i + 1, test_loader.size)) 35 | sal, gt, gt_path = test_loader.load_data() 36 | 37 | if sal.size != gt.size: 38 | x, y = gt.size 39 | sal = sal.resize((x, y)) 40 | gt = np.asarray(gt, np.float64) 41 | gt /= (gt.max() + 1e-8) 42 | gt[gt > 0.5] = 1 43 | gt[gt != 1] = 0 44 | res = sal 45 | res = np.array(res, np.float64) 46 | if res.max() == res.min(): 47 | res = res / 255 48 | else: 49 | res = (res - res.min()) / (res.max() - res.min()) 50 | 51 | ske_path = gt_path.replace("/masks/", "/ske/") 52 | if os.path.exists(ske_path): 53 | ske_ary = cv2.imread(ske_path, cv2.IMREAD_GRAYSCALE) 54 | ske_ary = ske_ary > 128 55 | else: 56 | ske_ary = skeletonize(gt > 0.5) 57 | ske_save_dir = os.path.join(*ske_path.split(os.sep)[:-1]) 58 | if ske_path[0] == os.sep: 59 | ske_save_dir = os.sep + ske_save_dir 60 | os.makedirs(ske_save_dir, exist_ok=True) 61 | cv2.imwrite(ske_path, ske_ary.astype(np.uint8) * 255) 62 | 63 | mae.update(res, gt) 64 | sm.update(res, gt) 
65 | fm.update(res, gt) 66 | em.update(res, gt) 67 | wfm.update(res, gt) 68 | m_dice.update(res, gt) 69 | m_iou.update(res, gt) 70 | ber.update(res, gt) 71 | acc.update(res, gt) 72 | hce.step(pred=res, gt=gt, gt_ske=ske_ary) 73 | 74 | 75 | MAE = mae.show() 76 | maxf, meanf, _, _ = fm.show() 77 | sm = sm.show() 78 | em = em.show() 79 | wfm = wfm.show() 80 | m_dice = m_dice.show() 81 | m_iou = m_iou.show() 82 | ber = ber.show() 83 | acc = acc.show() 84 | hce = hce.get_results()["hce"] 85 | 86 | print( 87 | 'dataset: {} MAE: {:.4f} Ber: {:.4f} maxF: {:.4f} avgF: {:.4f} wfm: {:.4f} Sm: {:.4f} adpEm: {:.4f} M_dice: {:.4f} M_iou: {:.4f} Acc: {:.4f} HCE:{}'.format( 88 | name, MAE, ber, maxf, meanf, wfm, sm, em, m_dice, m_iou, acc, int(hce))) 89 | -------------------------------------------------------------------------------- /utils/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | 6 | class test_dataset: 7 | def __init__(self, image_root, gt_root): 8 | self.img_list_1 = [os.path.splitext(f)[0] for f in os.listdir(image_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')] 9 | self.img_list_2 = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')] 10 | self.img_list = list(set(self.img_list_1).intersection(set(self.img_list_2))) 11 | 12 | self.image_root = image_root 13 | self.gt_root = gt_root 14 | self.transform = transforms.Compose([ 15 | transforms.ToTensor(), 16 | ]) 17 | self.gt_transform = transforms.ToTensor() 18 | self.size = len(self.img_list) 19 | self.index = 0 20 | 21 | def load_data(self): 22 | #image = self.rgb_loader(self.images[self.index]) 23 | rgb_png_path = os.path.join(self.image_root,self.img_list[self.index]+ '.png') 24 | rgb_jpg_path = os.path.join(self.image_root,self.img_list[self.index]+ '.jpg') 25 | rgb_bmp_path = os.path.join(self.image_root,self.img_list[self.index]+ '.bmp') 26 | if os.path.exists(rgb_png_path): 27 | image = self.binary_loader(rgb_png_path) 28 | elif os.path.exists(rgb_jpg_path): 29 | image = self.binary_loader(rgb_jpg_path) 30 | else: 31 | image = self.binary_loader(rgb_bmp_path) 32 | if os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.png')): 33 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.png')) 34 | elif os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')): 35 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')) 36 | else: 37 | gt = self.binary_loader(os.path.join(self.gt_root, self.img_list[self.index] + '.bmp')) 38 | 39 | self.index += 1 40 | return image, gt, os.path.join(self.gt_root, self.img_list[self.index - 1] + ".jpg") 41 | 42 | def rgb_loader(self, path): 43 | with open(path, 'rb') as f: 44 | img = Image.open(f) 45 | return img.convert('RGB') 46 | 47 | def binary_loader(self, path): 48 | with open(path, 'rb') as f: 49 | img = Image.open(f) 50 | return img.convert('L') 51 | 52 | class val_dataset: 53 | def __init__(self, image_root, gt_root): 54 | self.img_list_1 = [os.path.splitext(f)[0] for f in os.listdir(image_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')] 55 | self.img_list_2 = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')] 56 | self.img_list = 
list(set(self.img_list_1).intersection(set(self.img_list_2))) 57 | 58 | self.image_root = image_root 59 | self.gt_root = gt_root 60 | self.transform = transforms.Compose([ 61 | transforms.ToTensor(), 62 | ]) 63 | self.gt_transform = transforms.ToTensor() 64 | self.size = len(self.img_list) 65 | self.index = 0 66 | 67 | def load_data(self): 68 | #image = self.rgb_loader(self.images[self.index]) 69 | rgb_png_path = os.path.join(self.image_root,self.img_list[self.index]+ '.png') 70 | rgb_jpg_path = os.path.join(self.image_root,self.img_list[self.index]+ '.jpg') 71 | rgb_bmp_path = os.path.join(self.image_root,self.img_list[self.index]+ '.bmp') 72 | if os.path.exists(rgb_png_path): 73 | image = self.rgb_loader(rgb_png_path) 74 | elif os.path.exists(rgb_jpg_path): 75 | image = self.rgb_loader(rgb_jpg_path) 76 | else: 77 | image = self.rgb_loader(rgb_bmp_path) 78 | if os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.png')): 79 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.png')) 80 | elif os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')): 81 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')) 82 | else: 83 | gt = self.binary_loader(os.path.join(self.gt_root, self.img_list[self.index] + '.bmp')) 84 | 85 | self.index += 1 86 | return image, gt 87 | 88 | def rgb_loader(self, path): 89 | with open(path, 'rb') as f: 90 | img = Image.open(f) 91 | return img.convert('RGB') 92 | 93 | def binary_loader(self, path): 94 | with open(path, 'rb') as f: 95 | img = Image.open(f) 96 | return img.convert('L') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MVANet 2 | The official repo of the CVPR 2024 paper (Highlight), [Multi-view Aggregation Network for Dichotomous Image Segmentation](https://arxiv.org/abs/2404.07445) 3 | 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te1)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te1?p=multi-view-aggregation-network-for) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te2)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te2?p=multi-view-aggregation-network-for) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te3)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te3?p=multi-view-aggregation-network-for) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te4)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te4?p=multi-view-aggregation-network-for) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=multi-view-aggregation-network-for) 10 | ## Introduction 11 | Dichotomous Image Segmentation (DIS) has recently emerged towards high-precision object segmentation from high-resolution natural images. 
When designing an effective DIS model, the main challenge is how to balance the semantic dispersion of high-resolution targets in the small receptive field and the loss of high-precision details in the large receptive field. Existing methods rely on multiple tedious encoder-decoder streams and stages to gradually complete global localization and local refinement. 12 | 13 | The human visual system captures regions of interest by observing them from multiple views. Inspired by this, we model DIS as a multi-view object perception problem and provide a parsimonious multi-view aggregation network (MVANet), which unifies the feature fusion of the distant view and close-up views into a single stream with one encoder-decoder structure. Specifically, we split the high-resolution input images from the original view into distant-view images with global information and close-up-view images with local details; together, these form a set of complementary multi-view low-resolution input patches. 14 |

15 | image 16 |
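As a concrete illustration of the view splitting described above, the sketch below turns one 1024×1024 input into four close-up patches plus one distant view at the same low resolution. It is written against the `image2patches` helper in `./model/MVANet.py`; the bilinear downsampling and the order in which the views are stacked here are illustrative assumptions, not a reproduction of the model's `forward` pass.

```
import torch
import torch.nn.functional as F
from einops import rearrange

def split_views(x):
    # x: (1, C, H, W) original-view image
    # close-up views: 2x2 non-overlapping crops stacked along the batch axis,
    # exactly what image2patches() in model/MVANet.py does
    loc = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)           # (4, C, H/2, W/2)
    # distant view: the whole image downsampled to the patch resolution
    # (bilinear interpolation is an illustrative choice, not necessarily the model's)
    glb = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)   # (1, C, H/2, W/2)
    return torch.cat((loc, glb), dim=0)                                              # (5, C, H/2, W/2)

views = split_views(torch.randn(1, 3, 1024, 1024))
print(views.shape)  # torch.Size([5, 3, 512, 512])
```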

17 | 18 | Moreover, two efficient transformer-based multi-view complementary localization and refinement modules (MCLM & MCRM) are proposed to jointly capture the localization of the targets and restore their boundary details. 19 |

20 | image 21 |
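For readers who want the gist of the attention pattern without reading the full `MCLM` class in `./model/MVANet.py`, the condensed sketch below keeps only the two cross-attention steps: the distant-view (global) tokens first query multi-scale pooled close-up (local) tokens, and each close-up patch then queries its corresponding window of the refreshed global map. The positional encodings, feed-forward blocks, LayerNorms, dropout and per-patch attention heads of the real module are deliberately omitted, so this is an illustration rather than a drop-in replacement.

```
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

class TinyMCLM(nn.Module):
    """Condensed sketch of the MCLM attention pattern (not the real module)."""
    def __init__(self, d_model, num_heads, pool_ratios=(1, 4, 8)):
        super().__init__()
        self.glb_attn = nn.MultiheadAttention(d_model, num_heads)  # global queries local
        self.loc_attn = nn.MultiheadAttention(d_model, num_heads)  # local queries global
        self.pool_ratios = pool_ratios

    def forward(self, loc, glb):
        # loc: (4, C, h, w) close-up features, glb: (1, C, h, w) distant-view features
        _, _, h, w = loc.shape
        merged = rearrange(loc, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
        # multi-scale pooled local tokens serve as keys/values for the global query
        pools = torch.cat([
            rearrange(F.adaptive_avg_pool2d(merged, (round(h / r), round(w / r))), 'b c h w -> (h w) b c')
            for r in self.pool_ratios], dim=0)
        g = rearrange(glb, 'b c h w -> (h w) b c')
        g = g + self.glb_attn(g, pools, pools)[0]            # global localization step
        # split the refreshed global map into the four windows matching the patches
        g_win = rearrange(g, '(h w) b c -> h w b c', h=h, w=w)
        g_win = rearrange(g_win, '(ng h) (nw w) b c -> (h w) (ng nw b) c', ng=2, nw=2)
        l = rearrange(loc, 'b c h w -> (h w) b c')
        l = l + self.loc_attn(l, g_win, g_win)[0]            # local refinement step
        return rearrange(torch.cat((l, g), dim=1), '(h w) b c -> b c h w', h=h, w=w)
```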

22 | 23 | 24 | NOTE: Initially, we computed Fm by first averaging precision and recall over the whole dataset and then plugging those averages into the Fm formula. Thanks to feedback, we identified this bug and revised the approach to compute Fm for each image individually before averaging. We have updated the `./utils/saliency_metric.py` file to fix this issue. Additionally, we have updated the results on the DIS-VD dataset and included the HCE metric. The updated results are shown below: 25 |

26 | image 27 |
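To make the difference concrete, here is a small NumPy sketch of the two aggregation orders (the helper names are illustrative, not part of the repository); the per-image variant matches the corrected behaviour of `cal_fm` in `./utils/saliency_metric.py`, which uses beta^2 = 0.3.

```
import numpy as np

def fmeasure(precision, recall, beta2=0.3):
    return (1 + beta2) * precision * recall / (beta2 * precision + recall + 1e-8)

def fm_from_averages(precisions, recalls):
    # old (buggy) order: average precision and recall over the dataset first,
    # then plug the averages into the F-measure formula
    return fmeasure(np.mean(precisions), np.mean(recalls))

def fm_per_image(precisions, recalls):
    # corrected order: compute the F-measure of every image, then average
    return np.mean([fmeasure(p, r) for p, r in zip(precisions, recalls)])

precisions, recalls = [0.9, 0.2], [0.8, 0.9]
print(fm_from_averages(precisions, recalls), fm_per_image(precisions, recalls))  # the two orders disagree
```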

28 | 29 | Here are some of our visual results: 30 |

31 | image 32 |

33 | 34 | 35 | ## I. Requirements 36 | 37 | 1. Clone this repository 38 | ``` 39 | git clone git@github.com:qianyu-dlut/MVANet.git 40 | cd MVANet 41 | ``` 42 | 43 | 2. Install packages 44 | 45 | ``` 46 | conda create -n mvanet python==3.7 47 | conda activate mvanet 48 | pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu102/torch_stable.html 49 | pip install -U openmim 50 | mim install mmcv-full==1.3.17 51 | pip install -r requirements.txt 52 | ``` 53 | 54 | ## II. Training 55 | 1. Download the dataset [DIS5K](https://drive.google.com/file/d/1O1eIuXX1hlGsV7qx4eSkjH231q7G1by1/view?usp=sharing) and update `image_root` and `gt_root` in `./train.py` (lines 39-40). 56 | 2. Download the pretrained model at [Google Drive](https://drive.google.com/file/d/1-Zi_DtCT8oC2UAZpB3_XoFOIxIweIAyk/view?usp=sharing) and update the pretrained model path in `./model/SwinTransformer.py` (line 643). 57 | 3. Then, you can start training by simply running: 58 | ``` 59 | python train.py 60 | ``` 61 | 62 | ## III. Testing 63 | 1. Update the data paths in the config file `./utils/config.py` (lines 3-7). 64 | 2. Replace the existing path with the path to your saved model in `./predict.py` (line 14). 65 | 66 | You can also download our trained model at [Google Drive](https://drive.google.com/file/d/1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv/view?usp=sharing). 67 | 3. Start predicting by: 68 | ``` 69 | python predict.py 70 | ``` 71 | 4. Change the predicted map path in `./test.py` (line 19) and start testing: 72 | ``` 73 | python test.py 74 | ``` 75 | 76 | You can get our prediction maps at [Google Drive](https://drive.google.com/file/d/1qN9mVNK9hfS_a1radFQ9QNYsAQo9FpYS/view?usp=sharing). 77 | 78 | 5. You can measure the FPS by running: 79 | ``` 80 | python test_fps.py 81 | ``` 82 | 83 | ## Contact 84 | If you have any questions, please feel free to contact me (ms.yuqian AT mail DOT dlut DOT edu DOT cn). 
85 | 86 | ## Citations 87 | ``` 88 | @inproceedings{MVANet, 89 | title={Multi-view Aggregation Network for Dichotomous Image Segmentation}, 90 | author={Yu, Qian and Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and Lu, Huchuan}, 91 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 92 | pages={3921--3930}, 93 | year={2024} 94 | } 95 | 96 | ``` 97 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, argparse 3 | os.environ["CUDA_VISIBLE_DEVICES"] ='0' 4 | from datetime import datetime 5 | from model.MVANet import MVANet 6 | from utils.dataset_strategy_fpn import get_loader 7 | from utils.misc import adjust_lr, AvgMeter 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from torch.backends import cudnn 11 | from torchvision import transforms 12 | import torch.nn as nn 13 | from torch.cuda import amp 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | writer = SummaryWriter() 17 | 18 | cudnn.benchmark = True 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--epoch', type=int, default=80, help='epoch number') 22 | parser.add_argument('--lr_gen', type=float, default=1e-5, help='learning rate') 23 | parser.add_argument('--batchsize', type=int, default=1, help='training batch size') 24 | parser.add_argument('--trainsize', type=int, default=1024, help='training dataset size') 25 | parser.add_argument('--decay_rate', type=float, default=0.9, help='decay rate of learning rate') 26 | parser.add_argument('--decay_epoch', type=int, default=60, help='every n epochs decay learning rate') 27 | 28 | opt = parser.parse_args() 29 | print('Generator Learning Rate: {}'.format(opt.lr_gen)) 30 | # build models 31 | if hasattr(torch.cuda, 'empty_cache'): 32 | torch.cuda.empty_cache() 33 | generator = MVANet() 34 | generator.cuda() 35 | 36 | generator_params = generator.parameters() 37 | generator_optimizer = torch.optim.Adam(generator_params, opt.lr_gen) 38 | 39 | image_root = './data/DIS5K/DIS-TR/images/' 40 | gt_root = './data/DIS5K/DIS-TR/masks/' 41 | 42 | train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 43 | total_step = len(train_loader) 44 | to_pil = transforms.ToPILImage() 45 | ## define loss 46 | 47 | CE = torch.nn.BCELoss() 48 | mse_loss = torch.nn.MSELoss(size_average=True, reduce=True) 49 | size_rates = [1] 50 | criterion = nn.BCEWithLogitsLoss().cuda() 51 | criterion_mae = nn.L1Loss().cuda() 52 | criterion_mse = nn.MSELoss().cuda() 53 | use_fp16 = True 54 | scaler = amp.GradScaler(enabled=use_fp16) 55 | 56 | def structure_loss(pred, mask): 57 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask) 58 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 59 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3)) 60 | 61 | 62 | pred = torch.sigmoid(pred) 63 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 64 | 65 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 66 | wiou = 1-(inter+1)/(union-inter+1) 67 | 68 | return (wbce+wiou).mean() 69 | 70 | 71 | 72 | for epoch in range(1, opt.epoch+1): 73 | torch.cuda.empty_cache() 74 | generator.train() 75 | loss_record = AvgMeter() 76 | print('Generator Learning Rate: {}'.format(generator_optimizer.param_groups[0]['lr'])) 77 | for i, pack in enumerate(train_loader, start=1): 78 | torch.cuda.empty_cache() 
79 | for rate in size_rates: 80 | torch.cuda.empty_cache() 81 | generator_optimizer.zero_grad() 82 | images, gts = pack 83 | images = Variable(images) 84 | gts = Variable(gts) 85 | images = images.cuda() 86 | gts = gts.cuda() 87 | trainsize = int(round(opt.trainsize * rate / 32) * 32) 88 | if rate != 1: 89 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', 90 | align_corners=True) 91 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 92 | 93 | b, c, h, w = gts.size() 94 | target_1 = F.upsample(gts, size=h // 4, mode='nearest') 95 | target_2 = F.upsample(gts, size=h // 8, mode='nearest').cuda() 96 | target_3 = F.upsample(gts, size=h // 16, mode='nearest').cuda() 97 | target_4 = F.upsample(gts, size=h // 32, mode='nearest').cuda() 98 | target_5 = F.upsample(gts, size=h // 64, mode='nearest').cuda() 99 | 100 | with amp.autocast(enabled=use_fp16): 101 | sideout5, sideout4, sideout3, sideout2, sideout1, final, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3,tokenattmap2,tokenattmap1= generator.forward(images) 102 | loss1 = structure_loss(sideout5, target_4) 103 | loss2 = structure_loss(sideout4, target_3) 104 | loss3 = structure_loss(sideout3, target_2) 105 | loss4 = structure_loss(sideout2, target_1) 106 | loss5 = structure_loss(sideout1, target_1) 107 | loss6 = structure_loss(final, gts) 108 | loss7 = structure_loss(glb5, target_5) 109 | loss8 = structure_loss(glb4, target_4) 110 | loss9 = structure_loss(glb3, target_3) 111 | loss10 = structure_loss(glb2, target_2) 112 | loss11 = structure_loss(glb1, target_2) 113 | loss12 = structure_loss(tokenattmap4, target_3) 114 | loss13 = structure_loss(tokenattmap3, target_2) 115 | loss14 = structure_loss(tokenattmap2, target_1) 116 | loss15 = structure_loss(tokenattmap1, target_1) 117 | loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + 0.3*(loss7 + loss8 + loss9 + loss10 + loss11)+ 0.3*(loss12 + loss13 + loss14 + loss15) 118 | Loss_loc = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 119 | Loss_glb = loss7 + loss8 + loss9 + loss10 + loss11 120 | Loss_map = loss12 + loss13 + loss14 + loss15 121 | writer.add_scalar('loss', loss.item(), epoch * len(train_loader) + i) 122 | 123 | generator_optimizer.zero_grad() 124 | scaler.scale(loss).backward() 125 | scaler.step(generator_optimizer) 126 | scaler.update() 127 | 128 | if rate == 1: 129 | loss_record.update(loss.data, opt.batchsize) 130 | 131 | 132 | if i % 10 == 0 or i == total_step: 133 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], gen Loss: {:.4f}'. 
134 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record.show())) 135 | 136 | adjust_lr(generator_optimizer, opt.lr_gen, epoch, opt.decay_rate, opt.decay_epoch) 137 | # save checkpoints every 20 epochs 138 | if epoch % 20== 0 : 139 | save_path = './saved_model/MVANet/' 140 | if not os.path.exists(save_path): 141 | os.mkdir(save_path) 142 | torch.save(generator.state_dict(), save_path + 'Model' + '_%d' % epoch + '.pth') 143 | 144 | 145 | -------------------------------------------------------------------------------- /utils/dataset_strategy_fpn.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | 10 | # several data augumentation strategies 11 | def cv_random_flip(img, label): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | # left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | # top bottom flip 19 | # if flip_flag2==1: 20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 23 | return img, label 24 | 25 | 26 | def randomCrop(image, label): 27 | border = 30 28 | image_width = image.size[0] 29 | image_height = image.size[1] 30 | crop_win_width = np.random.randint(image_width - border, image_width) 31 | crop_win_height = np.random.randint(image_height - border, image_height) 32 | random_region = ( 33 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 34 | (image_height + crop_win_height) >> 1) 35 | return image.crop(random_region), label.crop(random_region) 36 | 37 | 38 | def randomRotation(image, label): 39 | mode = Image.BICUBIC 40 | if random.random() > 0.8: 41 | random_angle = np.random.randint(-15, 15) 42 | image = image.rotate(random_angle, mode) 43 | label = label.rotate(random_angle, mode) 44 | return image, label 45 | 46 | 47 | def colorEnhance(image): 48 | bright_intensity = random.randint(5, 15) / 10.0 49 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 50 | contrast_intensity = random.randint(5, 15) / 10.0 51 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 52 | color_intensity = random.randint(0, 20) / 10.0 53 | image = ImageEnhance.Color(image).enhance(color_intensity) 54 | sharp_intensity = random.randint(0, 30) / 10.0 55 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 56 | return image 57 | 58 | 59 | def randomGaussian(image, mean=0.1, sigma=0.35): 60 | def gaussianNoisy(im, mean=mean, sigma=sigma): 61 | for _i in range(len(im)): 62 | im[_i] += random.gauss(mean, sigma) 63 | return im 64 | 65 | img = np.asarray(image) 66 | width, height = img.shape 67 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 68 | img = img.reshape([width, height]) 69 | return Image.fromarray(np.uint8(img)) 70 | 71 | 72 | def randomPeper(img): 73 | img = np.array(img) 74 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 75 | for i in range(noiseNum): 76 | 77 | randX = random.randint(0, img.shape[0] - 1) 78 | 79 | randY = random.randint(0, img.shape[1] - 1) 80 | 81 | if random.randint(0, 1) == 0: 82 | 83 | img[randX, randY] = 0 84 | 85 | else: 86 | 87 | img[randX, randY] = 255 88 | return 
Image.fromarray(img) 89 | 90 | 91 | # dataset for training 92 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 93 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 94 | class DISDataset(data.Dataset): 95 | def __init__(self, image_root, gt_root, trainsize): 96 | self.trainsize = trainsize 97 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('tif')] 98 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 99 | or f.endswith('.png') or f.endswith('tif')] 100 | self.images = sorted(self.images) 101 | self.gts = sorted(self.gts) 102 | self.filter_files() 103 | self.size = len(self.images) 104 | self.img_transform = transforms.Compose([ 105 | transforms.Resize((self.trainsize, self.trainsize)), 106 | transforms.ToTensor(), 107 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 108 | self.gt_transform = transforms.Compose([ 109 | transforms.Resize((self.trainsize, self.trainsize)), 110 | transforms.ToTensor()]) 111 | 112 | def __getitem__(self, index): 113 | image = self.rgb_loader(self.images[index]) 114 | gt = self.binary_loader(self.gts[index]) 115 | image, gt = cv_random_flip(image, gt) 116 | image, gt = randomCrop(image, gt) 117 | image, gt = randomRotation(image, gt) 118 | image = colorEnhance(image) 119 | image = self.img_transform(image) 120 | gt = self.gt_transform(gt) 121 | 122 | 123 | return image, gt 124 | 125 | def filter_files(self): 126 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 127 | images = [] 128 | gts = [] 129 | for img_path, gt_path in zip(self.images, self.gts): 130 | img = Image.open(img_path) 131 | gt = Image.open(gt_path) 132 | if img.size == gt.size : 133 | images.append(img_path) 134 | gts.append(gt_path) 135 | self.images = images 136 | self.gts = gts 137 | 138 | def rgb_loader(self, path): 139 | with open(path, 'rb') as f: 140 | img = Image.open(f) 141 | return img.convert('RGB') 142 | 143 | def binary_loader(self, path): 144 | with open(path, 'rb') as f: 145 | img = Image.open(f) 146 | return img.convert('L') 147 | 148 | def resize(self, img, gt): 149 | assert img.size == gt.size 150 | w, h = img.size 151 | if h < self.trainsize or w < self.trainsize: 152 | h = max(h, self.trainsize) 153 | w = max(w, self.trainsize) 154 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) 155 | else: 156 | return img, gt 157 | 158 | def __len__(self): 159 | return self.size 160 | 161 | 162 | # dataloader for training 163 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False): 164 | dataset = DISDataset(image_root, gt_root, trainsize) 165 | data_loader = data.DataLoader(dataset=dataset, 166 | batch_size=batchsize, 167 | shuffle=shuffle, 168 | num_workers=num_workers, 169 | pin_memory=pin_memory) 170 | return data_loader 171 | 172 | 173 | # test dataset and loader 174 | class test_dataset: 175 | def __init__(self, image_root, depth_root, testsize): 176 | self.testsize = testsize 177 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 178 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 179 | or f.endswith('.png')] 180 | self.images = sorted(self.images) 181 | self.depths = sorted(self.depths) 182 | self.transform = transforms.Compose([ 183 | transforms.Resize((self.testsize, 
self.testsize)), 184 | transforms.ToTensor(), 185 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 186 | # self.gt_transform = transforms.Compose([ 187 | # transforms.Resize((self.trainsize, self.trainsize)), 188 | # transforms.ToTensor()]) 189 | self.depths_transform = transforms.Compose( 190 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()]) 191 | self.size = len(self.images) 192 | self.index = 0 193 | 194 | def load_data(self): 195 | image = self.rgb_loader(self.images[self.index]) 196 | HH = image.size[0] 197 | WW = image.size[1] 198 | image = self.transform(image).unsqueeze(0) 199 | depth = self.rgb_loader(self.depths[self.index]) 200 | depth = self.depths_transform(depth).unsqueeze(0) 201 | 202 | name = self.images[self.index].split('/')[-1] 203 | # image_for_post=self.rgb_loader(self.images[self.index]) 204 | # image_for_post=image_for_post.resize(gt.size) 205 | if name.endswith('.jpg'): 206 | name = name.split('.jpg')[0] + '.png' 207 | self.index += 1 208 | self.index = self.index % self.size 209 | return image, depth, HH, WW, name 210 | 211 | def rgb_loader(self, path): 212 | with open(path, 'rb') as f: 213 | img = Image.open(f) 214 | return img.convert('RGB') 215 | 216 | def binary_loader(self, path): 217 | with open(path, 'rb') as f: 218 | img = Image.open(f) 219 | return img.convert('L') 220 | 221 | def __len__(self): 222 | return self.size 223 | 224 | -------------------------------------------------------------------------------- /utils/saliency_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 4 | import cv2 5 | from skimage.morphology import skeletonize 6 | from skimage.morphology import disk 7 | from skimage.measure import label 8 | 9 | 10 | class cal_fm(object): 11 | # Fmeasure(maxFm,meanFm)---Frequency-tuned salient region detection(CVPR 2009) 12 | def __init__(self, num, thds=255): 13 | self.num = num 14 | self.thds = thds 15 | self.precision = np.zeros((self.num, self.thds)) 16 | self.recall = np.zeros((self.num, self.thds)) 17 | self.meanF = np.zeros((self.num,1)) 18 | self.changeable_fms = [] 19 | self.idx = 0 20 | 21 | def update(self, pred, gt): 22 | if gt.max() != 0: 23 | # prediction, recall, Fmeasure_temp = self.cal(pred, gt) 24 | prediction, recall, Fmeasure_temp, changeable_fms = self.cal(pred, gt) 25 | self.precision[self.idx, :] = prediction 26 | self.recall[self.idx, :] = recall 27 | self.meanF[self.idx, :] = Fmeasure_temp 28 | self.changeable_fms.append(changeable_fms) 29 | self.idx += 1 30 | 31 | def cal(self, pred, gt): 32 | ########################meanF############################## 33 | th = 2 * pred.mean() 34 | if th > 1: 35 | th = 1 36 | 37 | binary = np.zeros_like(pred) 38 | binary[pred >= th] = 1 39 | 40 | hard_gt = np.zeros_like(gt) 41 | hard_gt[gt > 0.5] = 1 42 | tp = (binary * hard_gt).sum() 43 | if tp == 0: 44 | meanF = 0 45 | else: 46 | pre = tp / binary.sum() 47 | rec = tp / hard_gt.sum() 48 | meanF = 1.3 * pre * rec / (0.3 * pre + rec) 49 | ########################maxF############################## 50 | pred = np.uint8(pred * 255) 51 | target = pred[gt > 0.5] 52 | nontarget = pred[gt <= 0.5] 53 | targetHist, _ = np.histogram(target, bins=range(256)) 54 | nontargetHist, _ = np.histogram(nontarget, bins=range(256)) 55 | targetHist = np.cumsum(np.flip(targetHist), axis=0) 56 | nontargetHist = np.cumsum(np.flip(nontargetHist), 
axis=0) 57 | precision = targetHist / (targetHist + nontargetHist + 1e-8) 58 | recall = targetHist / np.sum(gt) 59 | numerator = 1.3 * precision * recall 60 | denominator = np.where(numerator == 0, 1, 0.3 * precision + recall) 61 | changeable_fms = numerator / denominator 62 | return precision, recall, meanF, changeable_fms 63 | 64 | 65 | def show(self): 66 | assert self.num == self.idx 67 | precision = self.precision.mean(axis=0) 68 | recall = self.recall.mean(axis=0) 69 | # fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8) 70 | changeable_fm = np.mean(np.array(self.changeable_fms), axis=0) 71 | fmeasure_avg = self.meanF.mean(axis=0) 72 | return changeable_fm.max(),fmeasure_avg[0],precision,recall 73 | 74 | 75 | 76 | class cal_mae(object): 77 | # mean absolute error 78 | def __init__(self): 79 | self.prediction = [] 80 | 81 | def update(self, pred, gt): 82 | score = self.cal(pred, gt) 83 | self.prediction.append(score) 84 | 85 | def cal(self, pred, gt): 86 | return np.mean(np.abs(pred - gt)) 87 | 88 | def show(self): 89 | return np.mean(self.prediction) 90 | 91 | class cal_dice(object): 92 | # mean absolute error 93 | def __init__(self): 94 | self.prediction = [] 95 | 96 | def update(self, pred, gt): 97 | score = self.cal(pred, gt) 98 | self.prediction.append(score) 99 | 100 | def cal(self, y_pred, y_true): 101 | # smooth = 1 102 | smooth = 1e-5 103 | y_true_f = y_true.flatten() 104 | y_pred_f = y_pred.flatten() 105 | intersection = np.sum(y_true_f * y_pred_f) 106 | return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth) 107 | 108 | def show(self): 109 | return np.mean(self.prediction) 110 | 111 | class cal_ber(object): 112 | # mean absolute error 113 | def __init__(self): 114 | self.prediction = [] 115 | 116 | def update(self, pred, gt): 117 | score = self.cal(pred, gt) 118 | self.prediction.append(score) 119 | 120 | def cal(self, y_pred, y_true): 121 | binary = np.zeros_like(y_pred) 122 | binary[y_pred >= 0.5] = 1 123 | hard_gt = np.zeros_like(y_true) 124 | hard_gt[y_true > 0.5] = 1 125 | tp = (binary * hard_gt).sum() 126 | tn = ((1-binary) * (1-hard_gt)).sum() 127 | Np = hard_gt.sum() 128 | Nn = (1-hard_gt).sum() 129 | ber = (1-(tp/(Np+1e-8)+tn/(Nn+1e-8))/2) 130 | return ber 131 | 132 | def show(self): 133 | return np.mean(self.prediction) 134 | 135 | class cal_acc(object): 136 | # mean absolute error 137 | def __init__(self): 138 | self.prediction = [] 139 | 140 | def update(self, pred, gt): 141 | score = self.cal(pred, gt) 142 | self.prediction.append(score) 143 | 144 | def cal(self, y_pred, y_true): 145 | binary = np.zeros_like(y_pred) 146 | binary[y_pred >= 0.5] = 1 147 | hard_gt = np.zeros_like(y_true) 148 | hard_gt[y_true > 0.5] = 1 149 | tp = (binary * hard_gt).sum() 150 | tn = ((1-binary) * (1-hard_gt)).sum() 151 | Np = hard_gt.sum() 152 | Nn = (1-hard_gt).sum() 153 | acc = ((tp+tn)/(Np+Nn)) 154 | return acc 155 | 156 | def show(self): 157 | return np.mean(self.prediction) 158 | 159 | class cal_iou(object): 160 | # mean absolute error 161 | def __init__(self): 162 | self.prediction = [] 163 | 164 | def update(self, pred, gt): 165 | score = self.cal(pred, gt) 166 | self.prediction.append(score) 167 | 168 | # def cal(self, input, target): 169 | # classes = 1 170 | # intersection = np.logical_and(target == classes, input == classes) 171 | # # print(intersection.any()) 172 | # union = np.logical_or(target == classes, input == classes) 173 | # return np.sum(intersection) / np.sum(union) 174 | 175 | def cal(self, input, target): 176 
| smooth = 1e-5 177 | input = input > 0.5 178 | target_ = target > 0.5 179 | intersection = (input & target_).sum() 180 | union = (input | target_).sum() 181 | 182 | return (intersection + smooth) / (union + smooth) 183 | def show(self): 184 | return np.mean(self.prediction) 185 | 186 | # smooth = 1e-5 187 | # 188 | # if torch.is_tensor(output): 189 | # output = torch.sigmoid(output).data.cpu().numpy() 190 | # if torch.is_tensor(target): 191 | # target = target.data.cpu().numpy() 192 | # output_ = output > 0.5 193 | # target_ = target > 0.5 194 | # intersection = (output_ & target_).sum() 195 | # union = (output_ | target_).sum() 196 | 197 | # return (intersection + smooth) / (union + smooth) 198 | 199 | class cal_sm(object): 200 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017) 201 | def __init__(self, alpha=0.5): 202 | self.prediction = [] 203 | self.alpha = alpha 204 | 205 | def update(self, pred, gt): 206 | gt = gt > 0.5 207 | score = self.cal(pred, gt) 208 | self.prediction.append(score) 209 | 210 | def show(self): 211 | return np.mean(self.prediction) 212 | 213 | def cal(self, pred, gt): 214 | y = np.mean(gt) 215 | if y == 0: 216 | score = 1 - np.mean(pred) 217 | elif y == 1: 218 | score = np.mean(pred) 219 | else: 220 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 221 | return score 222 | 223 | def object(self, pred, gt): 224 | fg = pred * gt 225 | bg = (1 - pred) * (1 - gt) 226 | 227 | u = np.mean(gt) 228 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt)) 229 | 230 | def s_object(self, in1, in2): 231 | x = np.mean(in1[in2]) 232 | sigma_x = np.std(in1[in2]) 233 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8) 234 | 235 | def region(self, pred, gt): 236 | [y, x] = ndimage.center_of_mass(gt) 237 | y = int(round(y)) + 1 238 | x = int(round(x)) + 1 239 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y) 240 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y) 241 | 242 | score1 = self.ssim(pred1, gt1) 243 | score2 = self.ssim(pred2, gt2) 244 | score3 = self.ssim(pred3, gt3) 245 | score4 = self.ssim(pred4, gt4) 246 | 247 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 248 | 249 | def divideGT(self, gt, x, y): 250 | h, w = gt.shape 251 | area = h * w 252 | LT = gt[0:y, 0:x] 253 | RT = gt[0:y, x:w] 254 | LB = gt[y:h, 0:x] 255 | RB = gt[y:h, x:w] 256 | 257 | w1 = x * y / area 258 | w2 = y * (w - x) / area 259 | w3 = (h - y) * x / area 260 | w4 = (h - y) * (w - x) / area 261 | 262 | return LT, RT, LB, RB, w1, w2, w3, w4 263 | 264 | def dividePred(self, pred, x, y): 265 | h, w = pred.shape 266 | LT = pred[0:y, 0:x] 267 | RT = pred[0:y, x:w] 268 | LB = pred[y:h, 0:x] 269 | RB = pred[y:h, x:w] 270 | 271 | return LT, RT, LB, RB 272 | 273 | def ssim(self, in1, in2): 274 | in2 = np.float32(in2) 275 | h, w = in1.shape 276 | N = h * w 277 | 278 | x = np.mean(in1) 279 | y = np.mean(in2) 280 | sigma_x = np.var(in1) 281 | sigma_y = np.var(in2) 282 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1) 283 | 284 | alpha = 4 * x * y * sigma_xy 285 | beta = (x * x + y * y) * (sigma_x + sigma_y) 286 | 287 | if alpha != 0: 288 | score = alpha / (beta + 1e-8) 289 | elif alpha == 0 and beta == 0: 290 | score = 1 291 | else: 292 | score = 0 293 | 294 | return score 295 | 296 | class cal_em(object): 297 | #Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018) 298 | def __init__(self): 299 | self.prediction = [] 300 | 301 | def update(self, pred, gt): 302 | 
score = self.cal(pred, gt) 303 | self.prediction.append(score) 304 | 305 | def cal(self, pred, gt): 306 | th = 2 * pred.mean() 307 | if th > 1: 308 | th = 1 309 | FM = np.zeros(gt.shape) 310 | FM[pred >= th] = 1 311 | FM = np.array(FM,dtype=bool) 312 | GT = np.array(gt,dtype=bool) 313 | dFM = np.double(FM) 314 | if (sum(sum(np.double(GT)))==0): 315 | enhanced_matrix = 1.0-dFM 316 | elif (sum(sum(np.double(~GT)))==0): 317 | enhanced_matrix = dFM 318 | else: 319 | dGT = np.double(GT) 320 | align_matrix = self.AlignmentTerm(dFM, dGT) 321 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix) 322 | [w, h] = np.shape(GT) 323 | score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8) 324 | return score 325 | def AlignmentTerm(self,dFM,dGT): 326 | mu_FM = np.mean(dFM) 327 | mu_GT = np.mean(dGT) 328 | align_FM = dFM - mu_FM 329 | align_GT = dGT - mu_GT 330 | align_Matrix = 2. * (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8) 331 | return align_Matrix 332 | def EnhancedAlignmentTerm(self,align_Matrix): 333 | enhanced = np.power(align_Matrix + 1,2) / 4 334 | return enhanced 335 | def show(self): 336 | return np.mean(self.prediction) 337 | class cal_wfm(object): 338 | def __init__(self, beta=1): 339 | self.beta = beta 340 | self.eps = 1e-6 341 | self.scores_list = [] 342 | 343 | def update(self, pred, gt): 344 | assert pred.ndim == gt.ndim and pred.shape == gt.shape 345 | assert pred.max() <= 1 and pred.min() >= 0 346 | assert gt.max() <= 1 and gt.min() >= 0 347 | 348 | gt = gt > 0.5 349 | if gt.max() == 0: 350 | score = 0 351 | else: 352 | score = self.cal(pred, gt) 353 | self.scores_list.append(score) 354 | 355 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5): 356 | """ 357 | 2D gaussian mask - should give the same result as MATLAB's 358 | fspecial('gaussian',[shape],[sigma]) 359 | """ 360 | m, n = [(ss - 1.) / 2. for ss in shape] 361 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 362 | h = np.exp(-(x * x + y * y) / (2. 
* sigma * sigma)) 363 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 364 | sumh = h.sum() 365 | if sumh != 0: 366 | h /= sumh 367 | return h 368 | 369 | def cal(self, pred, gt): 370 | # [Dst,IDXT] = bwdist(dGT); 371 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 372 | 373 | # %Pixel dependency 374 | # E = abs(FG-dGT); 375 | E = np.abs(pred - gt) 376 | # Et = E; 377 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 378 | Et = np.copy(E) 379 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 380 | 381 | # K = fspecial('gaussian',7,5); 382 | # EA = imfilter(Et,K); 383 | # MIN_E_EA(GT & EA dict: 425 | hce = np.mean(np.array(self.hces)) 426 | return dict(hce=hce) 427 | 428 | 429 | def cal_hce(self, pred: np.ndarray, gt: np.ndarray, gt_ske: np.ndarray, relax=5, epsilon=2.0) -> float: 430 | # Binarize gt 431 | if(len(gt.shape)>2): 432 | gt = gt[:, :, 0] 433 | 434 | epsilon_gt = 0.5#(np.amin(gt)+np.amax(gt))/2.0 435 | gt = (gt>epsilon_gt).astype(np.uint8) 436 | 437 | # Binarize pred 438 | if(len(pred.shape)>2): 439 | pred = pred[:, :, 0] 440 | epsilon_pred = 0.5#(np.amin(pred)+np.amax(pred))/2.0 441 | pred = (pred>epsilon_pred).astype(np.uint8) 442 | 443 | Union = np.logical_or(gt, pred) 444 | TP = np.logical_and(gt, pred) 445 | FP = pred - TP 446 | FN = gt - TP 447 | 448 | # relax the Union of gt and pred 449 | Union_erode = Union.copy() 450 | Union_erode = cv2.erode(Union_erode.astype(np.uint8), disk(1), iterations=relax) 451 | 452 | # --- get the relaxed False Positive regions for computing the human efforts in correcting them --- 453 | FP_ = np.logical_and(FP, Union_erode) # get the relaxed FP 454 | for i in range(0, relax): 455 | FP_ = cv2.dilate(FP_.astype(np.uint8), disk(1)) 456 | FP_ = np.logical_and(FP_, 1-np.logical_or(TP, FN)) 457 | FP_ = np.logical_and(FP, FP_) 458 | 459 | # --- get the relaxed False Negative regions for computing the human efforts in correcting them --- 460 | FN_ = np.logical_and(FN, Union_erode) # preserve the structural components of FN 461 | ## recover the FN, where pixels are not close to the TP borders 462 | for i in range(0, relax): 463 | FN_ = cv2.dilate(FN_.astype(np.uint8), disk(1)) 464 | FN_ = np.logical_and(FN_, 1-np.logical_or(TP, FP)) 465 | FN_ = np.logical_and(FN, FN_) 466 | FN_ = np.logical_or(FN_, np.logical_xor(gt_ske, np.logical_and(TP, gt_ske))) # preserve the structural components of FN 467 | 468 | ## 2. 
=============Find exact polygon control points and independent regions============== 469 | ## find contours from FP_ 470 | ctrs_FP, hier_FP = cv2.findContours(FP_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 471 | ## find control points and independent regions for human correction 472 | bdies_FP, indep_cnt_FP = self.filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_)) 473 | ## find contours from FN_ 474 | ctrs_FN, hier_FN = cv2.findContours(FN_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 475 | ## find control points and independent regions for human correction 476 | bdies_FN, indep_cnt_FN = self.filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP, FP_), FN_)) 477 | 478 | poly_FP, poly_FP_len, poly_FP_point_cnt = self.approximate_RDP(bdies_FP, epsilon=epsilon) 479 | poly_FN, poly_FN_len, poly_FN_point_cnt = self.approximate_RDP(bdies_FN, epsilon=epsilon) 480 | 481 | # FP_points+FP_indep+FN_points+FN_indep 482 | return poly_FP_point_cnt+indep_cnt_FP+poly_FN_point_cnt+indep_cnt_FN 483 | 484 | def filter_bdy_cond(self, bdy_, mask, cond): 485 | 486 | cond = cv2.dilate(cond.astype(np.uint8), disk(1)) 487 | labels = label(mask) # find the connected regions 488 | lbls = np.unique(labels) # the indices of the connected regions 489 | indep = np.ones(lbls.shape[0]) # the label of each connected regions 490 | indep[0] = 0 # 0 indicate the background region 491 | 492 | boundaries = [] 493 | h,w = cond.shape[0:2] 494 | ind_map = np.zeros((h, w)) 495 | indep_cnt = 0 496 | 497 | for i in range(0, len(bdy_)): 498 | tmp_bdies = [] 499 | tmp_bdy = [] 500 | for j in range(0, bdy_[i].shape[0]): 501 | r, c = bdy_[i][j,0,1],bdy_[i][j,0,0] 502 | 503 | if(np.sum(cond[r, c])==0 or ind_map[r, c]!=0): 504 | if(len(tmp_bdy)>0): 505 | tmp_bdies.append(tmp_bdy) 506 | tmp_bdy = [] 507 | continue 508 | tmp_bdy.append([c, r]) 509 | ind_map[r, c] = ind_map[r, c] + 1 510 | indep[labels[r, c]] = 0 # indicates part of the boundary of this region needs human correction 511 | if(len(tmp_bdy)>0): 512 | tmp_bdies.append(tmp_bdy) 513 | 514 | # check if the first and the last boundaries are connected 515 | # if yes, invert the first boundary and attach it after the last boundary 516 | if(len(tmp_bdies)>1): 517 | first_x, first_y = tmp_bdies[0][0] 518 | last_x, last_y = tmp_bdies[-1][-1] 519 | if((abs(first_x-last_x)==1 and first_y==last_y) or 520 | (first_x==last_x and abs(first_y-last_y)==1) or 521 | (abs(first_x-last_x)==1 and abs(first_y-last_y)==1) 522 | ): 523 | tmp_bdies[-1].extend(tmp_bdies[0][::-1]) 524 | del tmp_bdies[0] 525 | 526 | for k in range(0, len(tmp_bdies)): 527 | tmp_bdies[k] = np.array(tmp_bdies[k])[:, np.newaxis, :] 528 | if(len(tmp_bdies)>0): 529 | boundaries.extend(tmp_bdies) 530 | 531 | return boundaries, np.sum(indep) 532 | 533 | # this function approximate each boundary by DP algorithm 534 | # https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm 535 | def approximate_RDP(self, boundaries, epsilon=1.0): 536 | 537 | boundaries_ = [] 538 | boundaries_len_ = [] 539 | pixel_cnt_ = 0 540 | 541 | # polygon approximate of each boundary 542 | for i in range(0, len(boundaries)): 543 | boundaries_.append(cv2.approxPolyDP(boundaries[i], epsilon, False)) 544 | 545 | # count the control points number of each boundary and the total control points number of all the boundaries 546 | for i in range(0, len(boundaries_)): 547 | boundaries_len_.append(len(boundaries_[i])) 548 | pixel_cnt_ = pixel_cnt_ + len(boundaries_[i]) 549 | 550 | return boundaries_, 
boundaries_len_, pixel_cnt_ 551 | -------------------------------------------------------------------------------- /model/MVANet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch import nn 7 | from .SwinTransformer import SwinB 8 | 9 | 10 | def get_activation_fn(activation): 11 | """Return an activation function given a string""" 12 | if activation == "relu": 13 | return F.relu 14 | if activation == "gelu": 15 | return F.gelu 16 | if activation == "glu": 17 | return F.glu 18 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 19 | 20 | 21 | def make_cbr(in_dim, out_dim): 22 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU()) 23 | 24 | 25 | def make_cbg(in_dim, out_dim): 26 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.GELU()) 27 | 28 | 29 | def rescale_to(x, scale_factor: float = 2, interpolation='nearest'): 30 | return F.interpolate(x, scale_factor=scale_factor, mode=interpolation) 31 | 32 | 33 | def resize_as(x, y, interpolation='bilinear'): 34 | return F.interpolate(x, size=y.shape[-2:], mode=interpolation) 35 | 36 | 37 | def image2patches(x): 38 | """b c (hg h) (wg w) -> (hg wg b) c h w""" 39 | x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 40 | return x 41 | 42 | 43 | def patches2image(x): 44 | """(hg wg b) c h w -> b c (hg h) (wg w)""" 45 | x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2) 46 | return x 47 | 48 | 49 | class PositionEmbeddingSine: 50 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 51 | super().__init__() 52 | self.num_pos_feats = num_pos_feats 53 | self.temperature = temperature 54 | self.normalize = normalize 55 | if scale is not None and normalize is False: 56 | raise ValueError("normalize should be True if scale is passed") 57 | if scale is None: 58 | scale = 2 * math.pi 59 | self.scale = scale 60 | self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32, device='cuda') 61 | 62 | def __call__(self, b, h, w): 63 | mask = torch.zeros([b, h, w], dtype=torch.bool, device='cuda') 64 | assert mask is not None 65 | not_mask = ~mask 66 | y_embed = not_mask.cumsum(dim=1, dtype=torch.float32) 67 | x_embed = not_mask.cumsum(dim=2, dtype=torch.float32) 68 | if self.normalize: 69 | eps = 1e-6 70 | y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale).cuda() 71 | x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale).cuda() 72 | 73 | dim_t = self.temperature ** (2 * (self.dim_t // 2) / self.num_pos_feats) 74 | 75 | pos_x = x_embed[:, :, :, None] / dim_t 76 | pos_y = y_embed[:, :, :, None] / dim_t 77 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten( 78 | 3) 79 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten( 80 | 3) 81 | return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 82 | 83 | 84 | class MCLM(nn.Module): 85 | def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]): 86 | super(MCLM, self).__init__() 87 | self.attention = nn.ModuleList([ 88 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 89 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 90 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 91 | nn.MultiheadAttention(d_model, 
num_heads, dropout=0.1), 92 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1) 93 | ]) 94 | 95 | self.linear3 = nn.Linear(d_model, d_model * 2) 96 | self.linear4 = nn.Linear(d_model * 2, d_model) 97 | self.linear5 = nn.Linear(d_model, d_model * 2) 98 | self.linear6 = nn.Linear(d_model * 2, d_model) 99 | self.norm1 = nn.LayerNorm(d_model) 100 | self.norm2 = nn.LayerNorm(d_model) 101 | self.dropout = nn.Dropout(0.1) 102 | self.dropout1 = nn.Dropout(0.1) 103 | self.dropout2 = nn.Dropout(0.1) 104 | self.activation = get_activation_fn('relu') 105 | self.pool_ratios = pool_ratios 106 | self.p_poses = [] 107 | self.g_pos = None 108 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True) 109 | 110 | def forward(self, l, g): 111 | """ 112 | l: 4,c,h,w 113 | g: 1,c,h,w 114 | """ 115 | b, c, h, w = l.size() 116 | # 4,c,h,w -> 1,c,2h,2w 117 | concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2) 118 | 119 | pools = [] 120 | for pool_ratio in self.pool_ratios: 121 | # b,c,h,w 122 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio)) 123 | pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw) 124 | pools.append(rearrange(pool, 'b c h w -> (h w) b c')) 125 | if self.g_pos is None: 126 | pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3]) 127 | pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c') 128 | self.p_poses.append(pos_emb) 129 | pools = torch.cat(pools, 0) 130 | if self.g_pos is None: 131 | self.p_poses = torch.cat(self.p_poses, dim=0) 132 | pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3]) 133 | self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c') 134 | 135 | # attention between glb (q) & multisensory concated-locs (k,v) 136 | g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c') 137 | g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0]) 138 | g_hw_b_c = self.norm1(g_hw_b_c) 139 | g_hw_b_c = g_hw_b_c + self.dropout2(self.linear6(self.dropout(self.activation(self.linear5(g_hw_b_c)).clone()))) 140 | g_hw_b_c = self.norm2(g_hw_b_c) 141 | 142 | # attention between origin locs (q) & freashed glb (k,v) 143 | l_hw_b_c = rearrange(l, "b c h w -> (h w) b c") 144 | _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w) 145 | _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2) 146 | outputs_re = [] 147 | for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))): 148 | outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c 149 | outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c 150 | 151 | l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re) 152 | l_hw_b_c = self.norm1(l_hw_b_c) 153 | l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone()))) 154 | l_hw_b_c = self.norm2(l_hw_b_c) 155 | 156 | l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c 157 | return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w) 158 | 159 | 160 | class inf_MCLM(nn.Module): 161 | def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]): 162 | super(inf_MCLM, self).__init__() 163 | self.attention = nn.ModuleList([ 164 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 165 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 166 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 167 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 168 | nn.MultiheadAttention(d_model, num_heads, 
dropout=0.1) 169 | ]) 170 | self.linear3 = nn.Linear(d_model, d_model * 2) 171 | self.linear4 = nn.Linear(d_model * 2, d_model) 172 | self.linear5 = nn.Linear(d_model, d_model * 2) 173 | self.linear6 = nn.Linear(d_model * 2, d_model) 174 | self.norm1 = nn.LayerNorm(d_model) 175 | self.norm2 = nn.LayerNorm(d_model) 176 | self.dropout = nn.Dropout(0.1) 177 | self.dropout1 = nn.Dropout(0.1) 178 | self.dropout2 = nn.Dropout(0.1) 179 | self.activation = get_activation_fn('relu') 180 | self.pool_ratios = pool_ratios 181 | self.p_poses = [] 182 | self.g_pos = None 183 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True) 184 | 185 | def forward(self, l, g): 186 | """ 187 | l: 4,c,h,w 188 | g: 1,c,h,w 189 | """ 190 | b, c, h, w = l.size() 191 | # 4,c,h,w -> 1,c,2h,2w 192 | concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2) 193 | self.p_poses = [] 194 | pools = [] 195 | for pool_ratio in self.pool_ratios: 196 | # b,c,h,w 197 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio)) 198 | pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw) 199 | pools.append(rearrange(pool, 'b c h w -> (h w) b c')) 200 | # if self.g_pos is None: 201 | pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3]) 202 | pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c') 203 | self.p_poses.append(pos_emb) 204 | pools = torch.cat(pools, 0) 205 | # if self.g_pos is None: 206 | self.p_poses = torch.cat(self.p_poses, dim=0) 207 | pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3]) 208 | self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c') 209 | 210 | # attention between glb (q) & multisensory concated-locs (k,v) 211 | g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c') 212 | g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0]) 213 | g_hw_b_c = self.norm1(g_hw_b_c) 214 | g_hw_b_c = g_hw_b_c + self.dropout2(self.linear6(self.dropout(self.activation(self.linear5(g_hw_b_c)).clone()))) 215 | g_hw_b_c = self.norm2(g_hw_b_c) 216 | 217 | # attention between origin locs (q) & freashed glb (k,v) 218 | l_hw_b_c = rearrange(l, "b c h w -> (h w) b c") 219 | _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w) 220 | _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2) 221 | outputs_re = [] 222 | for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))): 223 | outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c 224 | outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c 225 | 226 | l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re) 227 | l_hw_b_c = self.norm1(l_hw_b_c) 228 | l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone()))) 229 | l_hw_b_c = self.norm2(l_hw_b_c) 230 | 231 | l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c 232 | return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w) 233 | 234 | 235 | class MCRM(nn.Module): 236 | def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): 237 | super(MCRM, self).__init__() 238 | self.attention = nn.ModuleList([ 239 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 240 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 241 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 242 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1) 243 | ]) 244 | 245 | self.linear3 = nn.Linear(d_model, d_model * 2) 246 | self.linear4 = 
nn.Linear(d_model * 2, d_model) 247 | self.norm1 = nn.LayerNorm(d_model) 248 | self.norm2 = nn.LayerNorm(d_model) 249 | self.dropout = nn.Dropout(0.1) 250 | self.dropout1 = nn.Dropout(0.1) 251 | self.dropout2 = nn.Dropout(0.1) 252 | self.sigmoid = nn.Sigmoid() 253 | self.activation = get_activation_fn('relu') 254 | self.sal_conv = nn.Conv2d(d_model, 1, 1) 255 | self.pool_ratios = pool_ratios 256 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True) 257 | def forward(self, x): 258 | b, c, h, w = x.size() 259 | loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w 260 | # b(4),c,h,w 261 | patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 262 | 263 | # generate token attention map 264 | token_attention_map = self.sigmoid(self.sal_conv(glb)) 265 | token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest') 266 | loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 267 | pools = [] 268 | for pool_ratio in self.pool_ratios: 269 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio)) 270 | pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw) 271 | pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw 272 | # nl(4),c,nphw -> nl(4),nphw,1,c 273 | pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c") 274 | loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c') 275 | outputs = [] 276 | for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches 277 | # np*hw,1,c 278 | v = pools[i] 279 | k = v 280 | outputs.append(self.attention[i](q, k, v)[0]) 281 | outputs = torch.cat(outputs, 1) 282 | src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs) 283 | src = self.norm1(src) 284 | src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone()))) 285 | src = self.norm2(src) 286 | 287 | src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc 288 | glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb 289 | return torch.cat((src, glb), 0), token_attention_map 290 | 291 | 292 | class inf_MCRM(nn.Module): 293 | def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): 294 | super(inf_MCRM, self).__init__() 295 | self.attention = nn.ModuleList([ 296 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 297 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 298 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 299 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1) 300 | ]) 301 | 302 | self.linear3 = nn.Linear(d_model, d_model * 2) 303 | self.linear4 = nn.Linear(d_model * 2, d_model) 304 | self.norm1 = nn.LayerNorm(d_model) 305 | self.norm2 = nn.LayerNorm(d_model) 306 | self.dropout = nn.Dropout(0.1) 307 | self.dropout1 = nn.Dropout(0.1) 308 | self.dropout2 = nn.Dropout(0.1) 309 | self.sigmoid = nn.Sigmoid() 310 | self.activation = get_activation_fn('relu') 311 | self.sal_conv = nn.Conv2d(d_model, 1, 1) 312 | self.pool_ratios = pool_ratios 313 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True) 314 | def forward(self, x): 315 | b, c, h, w = x.size() 316 | loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w 317 | # b(4),c,h,w 318 | patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 319 | 320 | # generate token attention map 321 | token_attention_map = self.sigmoid(self.sal_conv(glb)) 322 | token_attention_map = 
F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest') 323 | loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 324 | pools = [] 325 | for pool_ratio in self.pool_ratios: 326 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio)) 327 | pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw) 328 | pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw 329 | # nl(4),c,nphw -> nl(4),nphw,1,c 330 | pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c") 331 | loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c') 332 | outputs = [] 333 | for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches 334 | # np*hw,1,c 335 | v = pools[i] 336 | k = v 337 | outputs.append(self.attention[i](q, k, v)[0]) 338 | outputs = torch.cat(outputs, 1) 339 | src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs) 340 | src = self.norm1(src) 341 | src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone()))) 342 | src = self.norm2(src) 343 | 344 | src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc 345 | glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb 346 | return torch.cat((src, glb), 0) 347 | 348 | # model for single-scale training 349 | class MVANet(nn.Module): 350 | def __init__(self): 351 | super().__init__() 352 | self.backbone = SwinB(pretrained=True) 353 | emb_dim = 128 354 | self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 355 | self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 356 | self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 357 | self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 358 | self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 359 | 360 | self.output5 = make_cbr(1024, emb_dim) 361 | self.output4 = make_cbr(512, emb_dim) 362 | self.output3 = make_cbr(256, emb_dim) 363 | self.output2 = make_cbr(128, emb_dim) 364 | self.output1 = make_cbr(128, emb_dim) 365 | 366 | self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8]) 367 | self.conv1 = make_cbr(emb_dim, emb_dim) 368 | self.conv2 = make_cbr(emb_dim, emb_dim) 369 | self.conv3 = make_cbr(emb_dim, emb_dim) 370 | self.conv4 = make_cbr(emb_dim, emb_dim) 371 | self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8]) 372 | self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8]) 373 | self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8]) 374 | self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8]) 375 | 376 | self.insmask_head = nn.Sequential( 377 | nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1), 378 | nn.BatchNorm2d(384), 379 | nn.PReLU(), 380 | nn.Conv2d(384, 384, kernel_size=3, padding=1), 381 | nn.BatchNorm2d(384), 382 | nn.PReLU(), 383 | nn.Conv2d(384, emb_dim, kernel_size=3, padding=1) 384 | ) 385 | 386 | self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1)) 387 | self.upsample1 = make_cbg(emb_dim, emb_dim) 388 | self.upsample2 = make_cbg(emb_dim, emb_dim) 389 | self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 390 | 391 | for m in self.modules(): 392 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout): 393 | m.inplace = True 394 | 395 | def forward(self, x): 396 | shallow = self.shallow(x) 397 | glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear') 398 | loc = image2patches(x) 399 | input = torch.cat((loc, glb), dim=0) 400 | feature = self.backbone(input) 401 | e5 = 
self.output5(feature[4]) # (5,128,16,16) 402 | e4 = self.output4(feature[3]) # (5,128,32,32) 403 | e3 = self.output3(feature[2]) # (5,128,64,64) 404 | e2 = self.output2(feature[1]) # (5,128,128,128) 405 | e1 = self.output1(feature[0]) # (5,128,128,128) 406 | loc_e5, glb_e5 = e5.split([4, 1], dim=0) 407 | e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16) 408 | 409 | e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4)) 410 | e4 = self.conv4(e4) 411 | e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3)) 412 | e3 = self.conv3(e3) 413 | e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2)) 414 | e2 = self.conv2(e2) 415 | e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1)) 416 | e1 = self.conv1(e1) 417 | loc_e1, glb_e1 = e1.split([4, 1], dim=0) 418 | output1_cat = patches2image(loc_e1) # (1,128,256,256) 419 | # add glb feat in 420 | output1_cat = output1_cat + resize_as(glb_e1, output1_cat) 421 | # merge 422 | final_output = self.insmask_head(output1_cat) # (1,128,256,256) 423 | # shallow feature merge 424 | final_output = final_output + resize_as(shallow, final_output) 425 | final_output = self.upsample1(rescale_to(final_output)) 426 | final_output = rescale_to(final_output + resize_as(shallow, final_output)) 427 | final_output = self.upsample2(final_output) 428 | final_output = self.output(final_output) 429 | #### 430 | sideout5 = self.sideout5(e5).cuda() 431 | sideout4 = self.sideout4(e4) 432 | sideout3 = self.sideout3(e3) 433 | sideout2 = self.sideout2(e2) 434 | sideout1 = self.sideout1(e1) 435 | #######glb_sideouts ###### 436 | glb5 = self.sideout5(glb_e5) 437 | glb4 = sideout4[-1,:,:,:].unsqueeze(0) 438 | glb3 = sideout3[-1,:,:,:].unsqueeze(0) 439 | glb2 = sideout2[-1,:,:,:].unsqueeze(0) 440 | glb1 = sideout1[-1,:,:,:].unsqueeze(0) 441 | ####### concat 4 to 1 ####### 442 | sideout1 = patches2image(sideout1[:-1]).cuda() 443 | sideout2 = patches2image(sideout2[:-1]).cuda()####(5,c,h,w) -> (1 c 2h,2w) 444 | sideout3 = patches2image(sideout3[:-1]).cuda() 445 | sideout4 = patches2image(sideout4[:-1]).cuda() 446 | sideout5 = patches2image(sideout5[:-1]).cuda() 447 | if self.training: 448 | return sideout5, sideout4,sideout3,sideout2,sideout1,final_output, glb5, glb4, glb3, glb2, glb1,tokenattmap4, tokenattmap3,tokenattmap2,tokenattmap1 449 | else: 450 | return final_output 451 | 452 | # model for multi-scale testing 453 | class inf_MVANet(nn.Module): 454 | def __init__(self): 455 | super().__init__() 456 | self.backbone = SwinB(pretrained=False) 457 | 458 | emb_dim = 128 459 | self.output5 = make_cbr(1024, emb_dim) 460 | self.output4 = make_cbr(512, emb_dim) 461 | self.output3 = make_cbr(256, emb_dim) 462 | self.output2 = make_cbr(128, emb_dim) 463 | self.output1 = make_cbr(128, emb_dim) 464 | 465 | self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8]) 466 | self.conv1 = make_cbr(emb_dim, emb_dim) 467 | self.conv2 = make_cbr(emb_dim, emb_dim) 468 | self.conv3 = make_cbr(emb_dim, emb_dim) 469 | self.conv4 = make_cbr(emb_dim, emb_dim) 470 | self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8]) 471 | self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8]) 472 | self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8]) 473 | self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8]) 474 | 475 | self.insmask_head = nn.Sequential( 476 | nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1), 477 | nn.BatchNorm2d(384), 478 | nn.PReLU(), 479 | nn.Conv2d(384, 384, kernel_size=3, padding=1), 480 | nn.BatchNorm2d(384), 481 | nn.PReLU(), 482 | nn.Conv2d(384, emb_dim, kernel_size=3, padding=1) 483 | ) 484 | 485 | 
self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1)) 486 | self.upsample1 = make_cbg(emb_dim, emb_dim) 487 | self.upsample2 = make_cbg(emb_dim, emb_dim) 488 | self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 489 | 490 | for m in self.modules(): 491 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout): 492 | m.inplace = True 493 | 494 | def forward(self, x): 495 | shallow = self.shallow(x) 496 | glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear') 497 | loc = image2patches(x) 498 | input = torch.cat((loc, glb), dim=0) 499 | feature = self.backbone(input) 500 | e5 = self.output5(feature[4]) 501 | e4 = self.output4(feature[3]) 502 | e3 = self.output3(feature[2]) 503 | e2 = self.output2(feature[1]) 504 | e1 = self.output1(feature[0]) 505 | loc_e5, glb_e5 = e5.split([4, 1], dim=0) 506 | e5_cat = self.multifieldcrossatt(loc_e5, glb_e5) 507 | 508 | e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4))) 509 | e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3))) 510 | e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2))) 511 | e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1))) 512 | loc_e1, glb_e1 = e1.split([4, 1], dim=0) 513 | # after decoder, concat loc features to a whole one, and merge 514 | output1_cat = patches2image(loc_e1) 515 | # add glb feat in 516 | output1_cat = output1_cat + resize_as(glb_e1, output1_cat) 517 | # merge 518 | final_output = self.insmask_head(output1_cat) 519 | # shallow feature merge 520 | final_output = final_output + resize_as(shallow, final_output) 521 | final_output = self.upsample1(rescale_to(final_output)) 522 | final_output = rescale_to(final_output + resize_as(shallow, final_output)) 523 | final_output = self.upsample2(final_output) 524 | final_output = self.output(final_output) 525 | return final_output 526 | 527 | 528 | -------------------------------------------------------------------------------- /model/SwinTransformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu, Yutong Lin, Yixuan Wei 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | import numpy as np 13 | from mmdet.utils import get_root_logger 14 | from timm.models import load_checkpoint 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | 17 | class Mlp(nn.Module): 18 | """ Multilayer perceptron.""" 19 | 20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | self.fc1 = nn.Linear(in_features, hidden_features) 25 | self.act = act_layer() 26 | self.fc2 = nn.Linear(hidden_features, out_features) 27 | self.drop = nn.Dropout(drop) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = self.act(x) 32 | x = self.drop(x) 33 | x = self.fc2(x) 34 | x = self.drop(x) 35 | return x 36 | 37 | 38 | def window_partition(x, window_size): 39 | """ 40 | Args: 41 | x: (B, H, W, C) 42 | window_size (int): window size 43 | 44 | Returns: 45 | windows: (num_windows*B, window_size, window_size, C) 46 | """ 47 | B, H, W, C = x.shape 48 | x = 
x.view(B, H // window_size, window_size, W // window_size, window_size, C) 49 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 50 | return windows 51 | 52 | 53 | def window_reverse(windows, window_size, H, W): 54 | """ 55 | Args: 56 | windows: (num_windows*B, window_size, window_size, C) 57 | window_size (int): Window size 58 | H (int): Height of image 59 | W (int): Width of image 60 | 61 | Returns: 62 | x: (B, H, W, C) 63 | """ 64 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 65 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 66 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 67 | return x 68 | 69 | 70 | class WindowAttention(nn.Module): 71 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 72 | It supports both of shifted and non-shifted window. 73 | 74 | Args: 75 | dim (int): Number of input channels. 76 | window_size (tuple[int]): The height and width of the window. 77 | num_heads (int): Number of attention heads. 78 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 79 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 80 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 81 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 82 | """ 83 | 84 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 85 | 86 | super().__init__() 87 | self.dim = dim 88 | self.window_size = window_size # Wh, Ww 89 | self.num_heads = num_heads 90 | head_dim = dim // num_heads 91 | self.scale = qk_scale or head_dim ** -0.5 92 | 93 | # define a parameter table of relative position bias 94 | self.relative_position_bias_table = nn.Parameter( 95 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 96 | 97 | # get pair-wise relative position index for each token inside the window 98 | coords_h = torch.arange(self.window_size[0]) 99 | coords_w = torch.arange(self.window_size[1]) 100 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 101 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 102 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 103 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 104 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 105 | relative_coords[:, :, 1] += self.window_size[1] - 1 106 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 107 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 108 | self.register_buffer("relative_position_index", relative_position_index) 109 | 110 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 111 | self.attn_drop = nn.Dropout(attn_drop) 112 | self.proj = nn.Linear(dim, dim) 113 | self.proj_drop = nn.Dropout(proj_drop) 114 | 115 | trunc_normal_(self.relative_position_bias_table, std=.02) 116 | self.softmax = nn.Softmax(dim=-1) 117 | 118 | def forward(self, x, mask=None): 119 | """ Forward function. 
120 | 121 | Args: 122 | x: input features with shape of (num_windows*B, N, C) 123 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 124 | """ 125 | B_, N, C = x.shape 126 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 127 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 128 | 129 | q = q * self.scale 130 | attn = (q @ k.transpose(-2, -1)) 131 | 132 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 133 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 134 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 135 | attn = attn + relative_position_bias.unsqueeze(0) 136 | 137 | if mask is not None: 138 | nW = mask.shape[0] 139 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 140 | attn = attn.view(-1, self.num_heads, N, N) 141 | attn = self.softmax(attn) 142 | else: 143 | attn = self.softmax(attn) 144 | 145 | attn = self.attn_drop(attn) 146 | 147 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 148 | x = self.proj(x) 149 | x = self.proj_drop(x) 150 | return x 151 | 152 | 153 | class SwinTransformerBlock(nn.Module): 154 | """ Swin Transformer Block. 155 | 156 | Args: 157 | dim (int): Number of input channels. 158 | num_heads (int): Number of attention heads. 159 | window_size (int): Window size. 160 | shift_size (int): Shift size for SW-MSA. 161 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 162 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 163 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 164 | drop (float, optional): Dropout rate. Default: 0.0 165 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 166 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 167 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 168 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 169 | """ 170 | 171 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 172 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 173 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 174 | super().__init__() 175 | self.dim = dim 176 | self.num_heads = num_heads 177 | self.window_size = window_size 178 | self.shift_size = shift_size 179 | self.mlp_ratio = mlp_ratio 180 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 181 | 182 | self.norm1 = norm_layer(dim) 183 | self.attn = WindowAttention( 184 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 185 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 186 | 187 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 188 | self.norm2 = norm_layer(dim) 189 | mlp_hidden_dim = int(dim * mlp_ratio) 190 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 191 | 192 | self.H = None 193 | self.W = None 194 | 195 | def forward(self, x, mask_matrix): 196 | """ Forward function. 197 | 198 | Args: 199 | x: Input feature, tensor size (B, H*W, C). 200 | H, W: Spatial resolution of the input feature. 201 | mask_matrix: Attention mask for cyclic shift. 
202 | """ 203 | B, L, C = x.shape 204 | H, W = self.H, self.W 205 | assert L == H * W, "input feature has wrong size" 206 | 207 | shortcut = x 208 | x = self.norm1(x) 209 | x = x.view(B, H, W, C) 210 | 211 | # pad feature maps to multiples of window size 212 | pad_l = pad_t = 0 213 | pad_r = (self.window_size - W % self.window_size) % self.window_size 214 | pad_b = (self.window_size - H % self.window_size) % self.window_size 215 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 216 | _, Hp, Wp, _ = x.shape 217 | 218 | # cyclic shift 219 | if self.shift_size > 0: 220 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 221 | attn_mask = mask_matrix 222 | else: 223 | shifted_x = x 224 | attn_mask = None 225 | 226 | # partition windows 227 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 228 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 229 | 230 | # W-MSA/SW-MSA 231 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C 232 | 233 | # merge windows 234 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 235 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 236 | 237 | # reverse cyclic shift 238 | if self.shift_size > 0: 239 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 240 | else: 241 | x = shifted_x 242 | 243 | if pad_r > 0 or pad_b > 0: 244 | x = x[:, :H, :W, :].contiguous() 245 | 246 | x = x.view(B, H * W, C) 247 | 248 | # FFN 249 | x = shortcut + self.drop_path(x) 250 | x = x + self.drop_path(self.mlp(self.norm2(x))) 251 | 252 | return x 253 | 254 | 255 | class PatchMerging(nn.Module): 256 | """ Patch Merging Layer 257 | 258 | Args: 259 | dim (int): Number of input channels. 260 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 261 | """ 262 | def __init__(self, dim, norm_layer=nn.LayerNorm): 263 | super().__init__() 264 | self.dim = dim 265 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 266 | self.norm = norm_layer(4 * dim) 267 | 268 | def forward(self, x, H, W): 269 | """ Forward function. 270 | 271 | Args: 272 | x: Input feature, tensor size (B, H*W, C). 273 | H, W: Spatial resolution of the input feature. 274 | """ 275 | B, L, C = x.shape 276 | assert L == H * W, "input feature has wrong size" 277 | 278 | x = x.view(B, H, W, C) 279 | 280 | # padding 281 | pad_input = (H % 2 == 1) or (W % 2 == 1) 282 | if pad_input: 283 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 284 | 285 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 286 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 287 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 288 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 289 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 290 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 291 | 292 | x = self.norm(x) 293 | x = self.reduction(x) 294 | 295 | return x 296 | 297 | 298 | class BasicLayer(nn.Module): 299 | """ A basic Swin Transformer layer for one stage. 300 | 301 | Args: 302 | dim (int): Number of feature channels 303 | depth (int): Depths of this stage. 304 | num_heads (int): Number of attention head. 305 | window_size (int): Local window size. Default: 7. 306 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 307 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True 308 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 309 | drop (float, optional): Dropout rate. Default: 0.0 310 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 311 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 312 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 313 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 314 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 315 | """ 316 | 317 | def __init__(self, 318 | dim, 319 | depth, 320 | num_heads, 321 | window_size=7, 322 | mlp_ratio=4., 323 | qkv_bias=True, 324 | qk_scale=None, 325 | drop=0., 326 | attn_drop=0., 327 | drop_path=0., 328 | norm_layer=nn.LayerNorm, 329 | downsample=None, 330 | use_checkpoint=False): 331 | super().__init__() 332 | self.window_size = window_size 333 | self.shift_size = window_size // 2 334 | self.depth = depth 335 | self.use_checkpoint = use_checkpoint 336 | 337 | # build blocks 338 | self.blocks = nn.ModuleList([ 339 | SwinTransformerBlock( 340 | dim=dim, 341 | num_heads=num_heads, 342 | window_size=window_size, 343 | shift_size=0 if (i % 2 == 0) else window_size // 2, 344 | mlp_ratio=mlp_ratio, 345 | qkv_bias=qkv_bias, 346 | qk_scale=qk_scale, 347 | drop=drop, 348 | attn_drop=attn_drop, 349 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 350 | norm_layer=norm_layer) 351 | for i in range(depth)]) 352 | 353 | # patch merging layer 354 | if downsample is not None: 355 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 356 | else: 357 | self.downsample = None 358 | 359 | def forward(self, x, H, W): 360 | """ Forward function. 361 | 362 | Args: 363 | x: Input feature, tensor size (B, H*W, C). 364 | H, W: Spatial resolution of the input feature. 365 | """ 366 | 367 | # calculate attention mask for SW-MSA 368 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 369 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 370 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 371 | h_slices = (slice(0, -self.window_size), 372 | slice(-self.window_size, -self.shift_size), 373 | slice(-self.shift_size, None)) 374 | w_slices = (slice(0, -self.window_size), 375 | slice(-self.window_size, -self.shift_size), 376 | slice(-self.shift_size, None)) 377 | cnt = 0 378 | for h in h_slices: 379 | for w in w_slices: 380 | img_mask[:, h, w, :] = cnt 381 | cnt += 1 382 | 383 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 384 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 385 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 386 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 387 | 388 | for blk in self.blocks: 389 | blk.H, blk.W = H, W 390 | if self.use_checkpoint: 391 | x = checkpoint.checkpoint(blk, x, attn_mask) 392 | else: 393 | x = blk(x, attn_mask) 394 | if self.downsample is not None: 395 | x_down = self.downsample(x, H, W) 396 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 397 | return x, H, W, x_down, Wh, Ww 398 | else: 399 | return x, H, W, x, H, W 400 | 401 | 402 | class PatchEmbed(nn.Module): 403 | """ Image to Patch Embedding 404 | 405 | Args: 406 | patch_size (int): Patch token size. Default: 4. 407 | in_chans (int): Number of input image channels. Default: 3. 
408 | embed_dim (int): Number of linear projection output channels. Default: 96. 409 | norm_layer (nn.Module, optional): Normalization layer. Default: None 410 | """ 411 | 412 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 413 | super().__init__() 414 | patch_size = to_2tuple(patch_size) 415 | self.patch_size = patch_size 416 | 417 | self.in_chans = in_chans 418 | self.embed_dim = embed_dim 419 | 420 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 421 | if norm_layer is not None: 422 | self.norm = norm_layer(embed_dim) 423 | else: 424 | self.norm = None 425 | 426 | def forward(self, x): 427 | """Forward function.""" 428 | # padding 429 | _, _, H, W = x.size() 430 | if W % self.patch_size[1] != 0: 431 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 432 | if H % self.patch_size[0] != 0: 433 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 434 | 435 | x = self.proj(x) # B C Wh Ww 436 | if self.norm is not None: 437 | Wh, Ww = x.size(2), x.size(3) 438 | x = x.flatten(2).transpose(1, 2) 439 | x = self.norm(x) 440 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) 441 | 442 | return x 443 | 444 | 445 | class SwinTransformer(nn.Module): 446 | """ Swin Transformer backbone. 447 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 448 | https://arxiv.org/pdf/2103.14030 449 | 450 | Args: 451 | pretrain_img_size (int): Input image size for training the pretrained model, 452 | used in absolute postion embedding. Default 224. 453 | patch_size (int | tuple(int)): Patch size. Default: 4. 454 | in_chans (int): Number of input image channels. Default: 3. 455 | embed_dim (int): Number of linear projection output channels. Default: 96. 456 | depths (tuple[int]): Depths of each Swin Transformer stage. 457 | num_heads (tuple[int]): Number of attention head of each stage. 458 | window_size (int): Window size. Default: 7. 459 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 460 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 461 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 462 | drop_rate (float): Dropout rate. 463 | attn_drop_rate (float): Attention dropout rate. Default: 0. 464 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 465 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 466 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. 467 | patch_norm (bool): If True, add normalization after patch embedding. Default: True. 468 | out_indices (Sequence[int]): Output from which stages. 469 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 470 | -1 means not freezing any parameters. 471 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
472 | """ 473 | 474 | def __init__(self, 475 | pretrain_img_size=224, 476 | patch_size=4, 477 | in_chans=3, 478 | embed_dim=96, 479 | depths=[2, 2, 6, 2], 480 | num_heads=[3, 6, 12, 24], 481 | window_size=7, 482 | mlp_ratio=4., 483 | qkv_bias=True, 484 | qk_scale=None, 485 | drop_rate=0., 486 | attn_drop_rate=0., 487 | drop_path_rate=0.2, 488 | norm_layer=nn.LayerNorm, 489 | ape=False, 490 | patch_norm=True, 491 | out_indices=(0, 1, 2, 3), 492 | frozen_stages=-1, 493 | use_checkpoint=False): 494 | super().__init__() 495 | 496 | self.pretrain_img_size = pretrain_img_size 497 | self.num_layers = len(depths) 498 | self.embed_dim = embed_dim 499 | self.ape = ape 500 | self.patch_norm = patch_norm 501 | self.out_indices = out_indices 502 | self.frozen_stages = frozen_stages 503 | 504 | # split image into non-overlapping patches 505 | self.patch_embed = PatchEmbed( 506 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 507 | norm_layer=norm_layer if self.patch_norm else None) 508 | 509 | # absolute position embedding 510 | if self.ape: 511 | pretrain_img_size = to_2tuple(pretrain_img_size) 512 | patch_size = to_2tuple(patch_size) 513 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] 514 | 515 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) 516 | trunc_normal_(self.absolute_pos_embed, std=.02) 517 | 518 | self.pos_drop = nn.Dropout(p=drop_rate) 519 | 520 | # stochastic depth 521 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 522 | 523 | # build layers 524 | self.layers = nn.ModuleList() 525 | for i_layer in range(self.num_layers): 526 | layer = BasicLayer( 527 | dim=int(embed_dim * 2 ** i_layer), 528 | depth=depths[i_layer], 529 | num_heads=num_heads[i_layer], 530 | window_size=window_size, 531 | mlp_ratio=mlp_ratio, 532 | qkv_bias=qkv_bias, 533 | qk_scale=qk_scale, 534 | drop=drop_rate, 535 | attn_drop=attn_drop_rate, 536 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 537 | norm_layer=norm_layer, 538 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 539 | use_checkpoint=use_checkpoint) 540 | self.layers.append(layer) 541 | 542 | num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] 543 | self.num_features = num_features 544 | 545 | # add a norm layer for each output 546 | for i_layer in out_indices: 547 | layer = norm_layer(num_features[i_layer]) 548 | layer_name = f'norm{i_layer}' 549 | self.add_module(layer_name, layer) 550 | 551 | self._freeze_stages() 552 | 553 | def _freeze_stages(self): 554 | if self.frozen_stages >= 0: 555 | self.patch_embed.eval() 556 | for param in self.patch_embed.parameters(): 557 | param.requires_grad = False 558 | 559 | if self.frozen_stages >= 1 and self.ape: 560 | self.absolute_pos_embed.requires_grad = False 561 | 562 | if self.frozen_stages >= 2: 563 | self.pos_drop.eval() 564 | for i in range(0, self.frozen_stages - 1): 565 | m = self.layers[i] 566 | m.eval() 567 | for param in m.parameters(): 568 | param.requires_grad = False 569 | 570 | def init_weights(self, pretrained=None): 571 | """Initialize the weights in backbone. 572 | 573 | Args: 574 | pretrained (str, optional): Path to pre-trained weights. 575 | Defaults to None. 
576 | """ 577 | 578 | def _init_weights(m): 579 | if isinstance(m, nn.Linear): 580 | trunc_normal_(m.weight, std=.02) 581 | if isinstance(m, nn.Linear) and m.bias is not None: 582 | nn.init.constant_(m.bias, 0) 583 | elif isinstance(m, nn.LayerNorm): 584 | nn.init.constant_(m.bias, 0) 585 | nn.init.constant_(m.weight, 1.0) 586 | 587 | if isinstance(pretrained, str): 588 | self.apply(_init_weights) 589 | logger = get_root_logger() 590 | load_checkpoint(self, pretrained, strict=False, logger=logger) 591 | elif pretrained is None: 592 | self.apply(_init_weights) 593 | else: 594 | raise TypeError('pretrained must be a str or None') 595 | 596 | def forward(self, x): 597 | x = self.patch_embed(x) 598 | 599 | Wh, Ww = x.size(2), x.size(3) 600 | if self.ape: 601 | # interpolate the position embedding to the corresponding size 602 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') 603 | x = (x + absolute_pos_embed) # B Wh*Ww C 604 | 605 | outs = [x.contiguous()] 606 | x = x.flatten(2).transpose(1, 2) 607 | x = self.pos_drop(x) 608 | for i in range(self.num_layers): 609 | layer = self.layers[i] 610 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) 611 | 612 | if i in self.out_indices: 613 | norm_layer = getattr(self, f'norm{i}') 614 | x_out = norm_layer(x_out) 615 | 616 | out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() 617 | outs.append(out) 618 | 619 | return tuple(outs) 620 | 621 | def train(self, mode=True): 622 | """Convert the model into training mode while keep layers freezed.""" 623 | super(SwinTransformer, self).train(mode) 624 | self._freeze_stages() 625 | 626 | def SwinT(pretrained=True): 627 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7) 628 | if pretrained is True: 629 | model.load_state_dict(torch.load('data/backbone_ckpt/swin_tiny_patch4_window7_224.pth', map_location='cpu')['model'], strict=False) 630 | 631 | return model 632 | 633 | def SwinS(pretrained=True): 634 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7) 635 | if pretrained is True: 636 | model.load_state_dict(torch.load('data/backbone_ckpt/swin_small_patch4_window7_224.pth', map_location='cpu')['model'], strict=False) 637 | 638 | return model 639 | 640 | def SwinB(pretrained=True): 641 | model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12) 642 | if pretrained is True: 643 | model.load_state_dict(torch.load('./swin_base_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False) 644 | 645 | return model 646 | 647 | def SwinL(pretrained=True): 648 | model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12) 649 | if pretrained is True: 650 | model.load_state_dict(torch.load('data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False) 651 | 652 | return model 653 | 654 | --------------------------------------------------------------------------------