├── test_fps.py
├── utils
│   ├── config.py
│   ├── misc.py
│   ├── test_data.py
│   ├── dataset_strategy_fpn.py
│   └── saliency_metric.py
├── LICENSE
├── requirements.txt
├── predict.py
├── test.py
├── README.md
├── train.py
└── model
    ├── MVANet.py
    └── SwinTransformer.py
/test_fps.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 |
5 | from model.MVANet import MVANet
6 |
7 | data = torch.randn(1, 3, 1024, 1024, device='cuda', dtype=torch.float32)
8 |
9 | model = MVANet().eval().cuda()
10 | torch.set_grad_enabled(False)  # inference only: skip autograd bookkeeping
11 | for _ in range(10):  # warm-up so CUDA kernels are compiled and cached before timing
12 |     model(data)
13 |
14 | torch.cuda.synchronize()  # wait for the warm-up kernels before starting the clock
15 | start_time = time.perf_counter()
16 | for _ in range(100):
17 |     model(data)
18 | torch.cuda.synchronize()  # wait for all timed kernels to finish
19 | print(f"FPS: {100 / (time.perf_counter() - start_time):.3f}")
20 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | diste1 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE1/'
4 | diste2 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE2/'
5 | diste3 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE3/'
6 | diste4 = '/home/vanessa/code/DIS-main/DIS5K/DIS-TE4/'
7 | disvd = '/home/vanessa/code/DIS-main/DIS5K/DIS-VD/'
8 |
9 | # optional sanity check: warn early if a dataset root does not exist
10 | for _p in (diste1, diste2, diste3, diste4, disvd):
11 |     if not os.path.isdir(_p):
12 |         print(f'Warning: dataset path not found: {_p}')
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 qianyu-dlut
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.20.3
3 | addict==2.4.0
4 | aiohttp==3.8.6
5 | aiosignal==1.3.1
6 | aliyun-python-sdk-core==2.16.0
7 | aliyun-python-sdk-kms==2.16.5
8 | async-timeout==4.0.3
9 | asynctest==0.13.0
10 | attrs==24.2.0
11 | cachetools==5.5.0
12 | certifi==2022.12.7
13 | cffi==1.15.1
14 | charset-normalizer==3.4.0
15 | click==8.1.7
16 | colorama==0.4.6
17 | crcmod==1.7
18 | cryptography==43.0.1
19 | cycler==0.11.0
20 | datasets==2.13.2
21 | dill==0.3.6
22 | einops==0.6.1
23 | filelock==3.12.2
24 | fonttools==4.38.0
25 | frozenlist==1.3.3
26 | fsspec==2023.1.0
27 | google-auth==2.35.0
28 | google-auth-oauthlib==0.4.6
29 | grpcio==1.62.3
30 | huggingface-hub==0.16.4
31 | idna==3.10
32 | importlib-metadata==6.7.0
33 | jmespath==0.10.0
34 | kiwisolver==1.4.5
35 | Markdown==3.4.4
36 | markdown-it-py==2.2.0
37 | MarkupSafe==2.1.5
38 | matplotlib==3.5.3
39 | mdurl==0.1.2
40 | mmdet==2.17.0
41 | mmengine==0.8.1
42 | mmsegmentation==0.19.0
43 | model-index==0.1.11
44 | multidict==6.0.5
45 | multiprocess==0.70.14
46 | numpy==1.21.6
47 | oauthlib==3.2.2
48 | opencv-python==4.10.0.84
49 | opendatalab==0.0.10
50 | openmim==0.3.9
51 | openxlab==0.0.10
52 | ordered-set==4.1.0
53 | oss2==2.17.0
54 | packaging==24.0
55 | pandas==1.1.5
56 | Pillow==9.5.0
57 | pip==22.3.1
58 | platformdirs==4.0.0
59 | prettytable==3.7.0
60 | protobuf==3.20.3
61 | psutil==6.1.0
62 | pyarrow==12.0.1
63 | pyasn1==0.5.1
64 | pyasn1-modules==0.3.0
65 | pycocotools==2.0.7
66 | pycparser==2.21
67 | pycryptodome==3.21.0
68 | Pygments==2.17.2
69 | pyparsing==3.1.4
70 | python-dateutil==2.9.0.post0
71 | pytz==2023.4
72 | PyYAML==6.0.1
73 | regex==2024.4.16
74 | requests==2.28.2
75 | requests-oauthlib==2.0.0
76 | rich==13.8.1
77 | rsa==4.9
78 | safetensors==0.4.5
79 | scipy==1.7.3
80 | setuptools==60.2.0
81 | six==1.16.0
82 | tabulate==0.9.0
83 | tensorboard==2.11.2
84 | tensorboard-data-server==0.6.1
85 | tensorboard-plugin-wit==1.8.1
86 | termcolor==2.3.0
87 | terminaltables==3.1.10
88 | timm==0.9.12
89 | tokenizers==0.13.3
90 | tomli==2.0.1
91 | tqdm==4.65.2
92 | transformers==4.30.2
93 | ttach==0.0.3
94 | typing_extensions==4.7.1
95 | urllib3==1.26.20
96 | wcwidth==0.2.13
97 | Werkzeug==2.2.3
98 | wheel==0.38.4
99 | xxhash==3.5.0
100 | yapf==0.40.2
101 | yarl==1.9.4
102 | zipp==3.15.0
103 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 | import os
5 |
6 | def clip_gradient(optimizer, grad_clip):
7 | for group in optimizer.param_groups:
8 | for param in group['params']:
9 | if param.grad is not None:
10 | param.grad.data.clamp_(-grad_clip, grad_clip)
11 |
12 |
13 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=5):
14 | decay = decay_rate ** (epoch // decay_epoch)
15 | for param_group in optimizer.param_groups:
16 |         param_group['lr'] = decay * init_lr  # derive from init_lr so repeated per-epoch calls do not compound
17 |
18 |
19 | def truncated_normal_(tensor, mean=0, std=1):  # truncated normal via resampling: per element, keep one of 4 standard-normal draws inside (-2, 2)
20 | size = tensor.shape
21 | tmp = tensor.new_empty(size + (4,)).normal_()
22 | valid = (tmp < 2) & (tmp > -2)
23 | ind = valid.max(-1, keepdim=True)[1]
24 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
25 | tensor.data.mul_(std).add_(mean)
26 |
27 | def init_weights(m):
28 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
29 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
30 | #nn.init.normal_(m.weight, std=0.001)
31 | #nn.init.normal_(m.bias, std=0.001)
32 | truncated_normal_(m.bias, mean=0, std=0.001)
33 |
34 | def init_weights_orthogonal_normal(m):
35 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
36 | nn.init.orthogonal_(m.weight)
37 | truncated_normal_(m.bias, mean=0, std=0.001)
38 | #nn.init.normal_(m.bias, std=0.001)
39 |
40 | def l2_regularisation(m):
41 | l2_reg = None
42 |
43 | for W in m.parameters():
44 | if l2_reg is None:
45 | l2_reg = W.norm(2)
46 | else:
47 | l2_reg = l2_reg + W.norm(2)
48 | return l2_reg
49 |
50 | def check_mkdir(dir_name):
51 | if not os.path.isdir(dir_name):
52 | os.makedirs(dir_name)
53 |
54 | class AvgMeter(object):
55 | def __init__(self, num=40):
56 | self.num = num
57 | self.reset()
58 |
59 | def reset(self):
60 | self.val = 0
61 | self.avg = 0
62 | self.sum = 0
63 | self.count = 0
64 | self.losses = []
65 |
66 | def update(self, val, n=1):
67 | self.val = val
68 | self.sum += val * n
69 | self.count += n
70 | self.avg = self.sum / self.count
71 | self.losses.append(val)
72 |
73 | def show(self):
74 | a = len(self.losses)
75 | b = np.maximum(a-self.num, 0)
76 | c = self.losses[b:]
77 | #print(c)
78 | #d = torch.mean(torch.stack(c))
79 | #print(d)
80 | return torch.mean(torch.stack(c))
81 |
82 | # def save_mask_prediction_example(mask, pred, iter):
83 | # plt.imshow(pred[0,:,:],cmap='Greys')
84 | # plt.savefig('images/'+str(iter)+"_prediction.png")
85 | # plt.imshow(mask[0,:,:],cmap='Greys')
86 | # plt.savefig('images/'+str(iter)+"_mask.png")
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 | import torch
5 | from PIL import Image
6 | from torch.autograd import Variable
7 | from torchvision import transforms
8 | from utils.config import diste1,diste2,diste3,diste4,disvd
9 | from utils.misc import check_mkdir
10 | from model.MVANet import inf_MVANet
11 | import ttach as tta
12 |
13 | torch.cuda.set_device(0)
14 | ckpt_path = './saved_model/'
15 | args = {
16 | 'save_results': True
17 | }
18 |
19 |
20 | img_transform = transforms.Compose([
21 | transforms.ToTensor(),
22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
23 |
24 | ])
25 |
26 | depth_transform = transforms.ToTensor()
27 | target_transform = transforms.ToTensor()
28 | to_pil = transforms.ToPILImage()
29 |
30 | to_test = {
31 | 'DIS-TE1':diste1,
32 | 'DIS-TE2':diste2,
33 | 'DIS-TE3':diste3,
34 | 'DIS-TE4':diste4,
35 | 'DIS-VD':disvd,
36 | }
37 |
38 | tta_transforms = tta.Compose(  # distinct name so the torchvision 'transforms' module is not shadowed
39 | [
40 | tta.HorizontalFlip(),
41 | tta.Scale(scales=[0.75, 1, 1.125], interpolation='bilinear', align_corners=False),
42 | ]
43 | )
44 |
45 | def main(item):
46 | net = inf_MVANet().cuda()
47 | pretrained_dict = torch.load(os.path.join(ckpt_path, item + '.pth'), map_location='cuda')
48 | model_dict = net.state_dict()
49 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
50 | model_dict.update(pretrained_dict)
51 | net.load_state_dict(model_dict)
52 | net.eval()
53 | with torch.no_grad():
54 | for name, root in to_test.items():
55 | root1 = os.path.join(root, 'images')
56 | img_list = [os.path.splitext(f) for f in os.listdir(root1)]
57 | for idx, img_name in enumerate(img_list):
58 | print ('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
59 | rgb_png_path = os.path.join(root, 'images', img_name[0] + '.png')
60 | rgb_jpg_path = os.path.join(root, 'images', img_name[0] + '.jpg')
61 | if os.path.exists(rgb_png_path):
62 | img = Image.open(rgb_png_path).convert('RGB')
63 | else:
64 | img = Image.open(rgb_jpg_path).convert('RGB')
65 | w_,h_ = img.size
66 | img_resize = img.resize([1024,1024],Image.BILINEAR)
67 |                 img_var = Variable(img_transform(img_resize).unsqueeze(0)).cuda()  # 'volatile' is long deprecated; torch.no_grad() above already disables autograd
68 | mask = []
69 |                 for transformer in tta_transforms:
70 | rgb_trans = transformer.augment_image(img_var)
71 | model_output = net(rgb_trans)
72 | deaug_mask = transformer.deaugment_mask(model_output)
73 | mask.append(deaug_mask)
74 |
75 | prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
76 | prediction = prediction.sigmoid()
77 | prediction = to_pil(prediction.data.squeeze(0).cpu())
78 | prediction = prediction.resize((w_, h_), Image.BILINEAR)
79 | if args['save_results']:
80 | check_mkdir(os.path.join(ckpt_path, item, name))
81 | prediction.save(os.path.join(ckpt_path, item, name, img_name[0] + '.png'))
82 |
83 |
84 |
85 | if __name__ == '__main__':
86 | files = os.listdir(ckpt_path)
87 | files.sort()
88 | for items in files:
89 | if '80.pth' in items:
90 | item = items.split('.')[0]
91 | main(item)
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from utils.test_data import test_dataset
4 | from utils.saliency_metric import cal_mae,cal_fm,cal_sm,cal_em,cal_wfm, cal_dice, cal_iou,cal_ber,cal_acc, HCEMeasure
5 | from utils.config import diste1,diste2,diste3,diste4,disvd
6 | from tqdm import tqdm
7 | import cv2
8 | from skimage.morphology import skeletonize
9 |
10 | test_datasets = {
11 | 'DIS-TE1':diste1,
12 | 'DIS-TE2':diste2,
13 | 'DIS-TE3':diste3,
14 | 'DIS-TE4':diste4,
15 | 'DIS-VD':disvd,
16 | }
17 |
18 |
19 | dirs = [  # directories containing the predicted saliency maps; avoids shadowing the built-in dir()
20 |     './saved_model/MVANet/Model_80/'
21 | ]
22 | for d in dirs:
23 | for name, root in test_datasets.items():
24 | print(name)
25 | sal_root = os.path.join(d, name)
26 | print(sal_root)
27 | gt_root = root + 'masks'
28 | print(gt_root)
29 | if os.path.exists(sal_root):
30 | test_loader = test_dataset(sal_root, gt_root)
31 | mae, fm, sm, em, wfm, m_dice, m_iou, ber, acc, hce = cal_mae(), cal_fm(
32 | test_loader.size), cal_sm(), cal_em(), cal_wfm(), cal_dice(), cal_iou(), cal_ber(), cal_acc(), HCEMeasure()
33 | for i in tqdm(range(test_loader.size)):
34 | # print ('predicting for %d / %d' % ( i + 1, test_loader.size))
35 | sal, gt, gt_path = test_loader.load_data()
36 |
37 | if sal.size != gt.size:
38 | x, y = gt.size
39 | sal = sal.resize((x, y))
40 | gt = np.asarray(gt, np.float64)
41 | gt /= (gt.max() + 1e-8)
42 | gt[gt > 0.5] = 1
43 | gt[gt != 1] = 0
44 | res = sal
45 | res = np.array(res, np.float64)
46 | if res.max() == res.min():
47 | res = res / 255
48 | else:
49 | res = (res - res.min()) / (res.max() - res.min())
50 |
51 | ske_path = gt_path.replace("/masks/", "/ske/")
52 | if os.path.exists(ske_path):
53 | ske_ary = cv2.imread(ske_path, cv2.IMREAD_GRAYSCALE)
54 | ske_ary = ske_ary > 128
55 | else:
56 | ske_ary = skeletonize(gt > 0.5)
57 |                         # write the skeleton next to the masks, under a parallel 'ske' directory
58 |                         ske_save_dir = os.path.dirname(ske_path)
59 |
60 | os.makedirs(ske_save_dir, exist_ok=True)
61 | cv2.imwrite(ske_path, ske_ary.astype(np.uint8) * 255)
62 |
63 | mae.update(res, gt)
64 | sm.update(res, gt)
65 | fm.update(res, gt)
66 | em.update(res, gt)
67 | wfm.update(res, gt)
68 | m_dice.update(res, gt)
69 | m_iou.update(res, gt)
70 | ber.update(res, gt)
71 | acc.update(res, gt)
72 | hce.step(pred=res, gt=gt, gt_ske=ske_ary)
73 |
74 |
75 | MAE = mae.show()
76 | maxf, meanf, _, _ = fm.show()
77 | sm = sm.show()
78 | em = em.show()
79 | wfm = wfm.show()
80 | m_dice = m_dice.show()
81 | m_iou = m_iou.show()
82 | ber = ber.show()
83 | acc = acc.show()
84 | hce = hce.get_results()["hce"]
85 |
86 | print(
87 | 'dataset: {} MAE: {:.4f} Ber: {:.4f} maxF: {:.4f} avgF: {:.4f} wfm: {:.4f} Sm: {:.4f} adpEm: {:.4f} M_dice: {:.4f} M_iou: {:.4f} Acc: {:.4f} HCE:{}'.format(
88 | name, MAE, ber, maxf, meanf, wfm, sm, em, m_dice, m_iou, acc, int(hce)))
89 |
--------------------------------------------------------------------------------
/utils/test_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 |
6 | class test_dataset:
7 | def __init__(self, image_root, gt_root):
8 | self.img_list_1 = [os.path.splitext(f)[0] for f in os.listdir(image_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')]
9 | self.img_list_2 = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')]
10 | self.img_list = list(set(self.img_list_1).intersection(set(self.img_list_2)))
11 |
12 | self.image_root = image_root
13 | self.gt_root = gt_root
14 | self.transform = transforms.Compose([
15 | transforms.ToTensor(),
16 | ])
17 | self.gt_transform = transforms.ToTensor()
18 | self.size = len(self.img_list)
19 | self.index = 0
20 |
21 |     def load_data(self):
22 |         # predictions and ground truths may be stored as .png, .jpg or .bmp
23 |         rgb_png_path = os.path.join(self.image_root, self.img_list[self.index] + '.png')
24 |         rgb_jpg_path = os.path.join(self.image_root, self.img_list[self.index] + '.jpg')
25 |         rgb_bmp_path = os.path.join(self.image_root, self.img_list[self.index] + '.bmp')
26 |         if os.path.exists(rgb_png_path):
27 |             image = self.binary_loader(rgb_png_path)
28 |         elif os.path.exists(rgb_jpg_path):
29 |             image = self.binary_loader(rgb_jpg_path)
30 |         else:
31 |             image = self.binary_loader(rgb_bmp_path)
32 |         gt_path = None
33 |         for ext in ('.png', '.jpg', '.bmp'):
34 |             candidate = os.path.join(self.gt_root, self.img_list[self.index] + ext)
35 |             if os.path.exists(candidate):
36 |                 gt_path = candidate
37 |                 break
38 |         gt = self.binary_loader(gt_path)
39 |         self.index += 1
40 |         return image, gt, gt_path  # return the gt path actually loaded, not a hard-coded .jpg
41 |
42 | def rgb_loader(self, path):
43 | with open(path, 'rb') as f:
44 | img = Image.open(f)
45 | return img.convert('RGB')
46 |
47 | def binary_loader(self, path):
48 | with open(path, 'rb') as f:
49 | img = Image.open(f)
50 | return img.convert('L')
51 |
52 | class val_dataset:
53 | def __init__(self, image_root, gt_root):
54 | self.img_list_1 = [os.path.splitext(f)[0] for f in os.listdir(image_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')]
55 | self.img_list_2 = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')]
56 | self.img_list = list(set(self.img_list_1).intersection(set(self.img_list_2)))
57 |
58 | self.image_root = image_root
59 | self.gt_root = gt_root
60 | self.transform = transforms.Compose([
61 | transforms.ToTensor(),
62 | ])
63 | self.gt_transform = transforms.ToTensor()
64 | self.size = len(self.img_list)
65 | self.index = 0
66 |
67 | def load_data(self):
68 | #image = self.rgb_loader(self.images[self.index])
69 | rgb_png_path = os.path.join(self.image_root,self.img_list[self.index]+ '.png')
70 | rgb_jpg_path = os.path.join(self.image_root,self.img_list[self.index]+ '.jpg')
71 | rgb_bmp_path = os.path.join(self.image_root,self.img_list[self.index]+ '.bmp')
72 | if os.path.exists(rgb_png_path):
73 | image = self.rgb_loader(rgb_png_path)
74 | elif os.path.exists(rgb_jpg_path):
75 | image = self.rgb_loader(rgb_jpg_path)
76 | else:
77 | image = self.rgb_loader(rgb_bmp_path)
78 | if os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.png')):
79 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.png'))
80 | elif os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')):
81 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg'))
82 | else:
83 | gt = self.binary_loader(os.path.join(self.gt_root, self.img_list[self.index] + '.bmp'))
84 |
85 | self.index += 1
86 | return image, gt
87 |
88 | def rgb_loader(self, path):
89 | with open(path, 'rb') as f:
90 | img = Image.open(f)
91 | return img.convert('RGB')
92 |
93 | def binary_loader(self, path):
94 | with open(path, 'rb') as f:
95 | img = Image.open(f)
96 | return img.convert('L')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MVANet
2 | The official repo of the CVPR 2024 paper (Highlight), [Multi-view Aggregation Network for Dichotomous Image Segmentation](https://arxiv.org/abs/2404.07445)
3 |
4 |
5 | [](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te1?p=multi-view-aggregation-network-for)
6 | [](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te2?p=multi-view-aggregation-network-for)
7 | [](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te3?p=multi-view-aggregation-network-for)
8 | [](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te4?p=multi-view-aggregation-network-for)
9 | [](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=multi-view-aggregation-network-for)
10 | ## Introduction
11 | Dichotomous Image Segmentation (DIS) has recently emerged, targeting high-precision object segmentation from high-resolution natural images. When designing an effective DIS model, the main challenge is balancing the semantic dispersion of high-resolution targets in a small receptive field against the loss of high-precision details in a large receptive field. Existing methods rely on tedious combinations of multiple encoder-decoder streams and stages to gradually complete global localization and local refinement.
12 |
13 | The human visual system captures regions of interest by observing them from multiple views. Inspired by this, we model DIS as a multi-view object perception problem and provide a parsimonious multi-view aggregation network (MVANet), which unifies the feature fusion of the distant view and close-up views into a single stream with one encoder-decoder structure. Specifically, we split the high-resolution input image from the original view into a distant-view image carrying global information and close-up view images carrying local details, which together constitute a set of complementary multi-view low-resolution input patches, as sketched below.
14 |
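A minimal sketch of this view decomposition, assuming a 1024x1024 input as in `train.py` (the 2x2 split mirrors `image2patches()` in `./model/MVANet.py`; the 0.5x distant view is an assumption for illustration):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.randn(1, 3, 1024, 1024)  # original high-resolution view

# distant view: the whole image downscaled to the shared patch resolution
glb = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)

# close-up views: a 2x2 grid of non-overlapping crops, as in image2patches()
loc = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

print(glb.shape)  # torch.Size([1, 3, 512, 512])
print(loc.shape)  # torch.Size([4, 3, 512, 512])
```

The inverse `patches2image()` rearrange stitches the four close-up outputs back into a single map.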
15 |
16 |
17 |
18 | Moreover, two efficient transformer-based multi-view complementary localization and refinement modules (MCLM & MCRM) are proposed to jointly capture the location of the targets and restore their boundary details, as sketched below.
19 |
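A rough illustration of how MCLM builds its attention tokens, following the pooling steps visible in `MCLM.forward` in `./model/MVANet.py` (the channel width and feature size here are illustrative assumptions, not the model's actual configuration):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

d_model = 128                           # assumed channel width, for illustration only
loc = torch.randn(4, d_model, 32, 32)   # four close-up feature maps (a 2x2 grid)
b, c, h, w = loc.shape

# stitch the local features back into one map, as MCLM.forward does
concated_locs = rearrange(loc, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)

# multi-scale pooled tokens for the cross-attention between global and local views
pools = []
for pool_ratio in (1, 4, 8):            # the default pool_ratios of MCLM
    pooled = F.adaptive_avg_pool2d(concated_locs,
                                   (round(h / pool_ratio), round(w / pool_ratio)))
    pools.append(rearrange(pooled, 'b c h w -> (h w) b c'))
pools = torch.cat(pools, dim=0)

print(pools.shape)  # torch.Size([1104, 1, 128]); 32*32 + 8*8 + 4*4 tokens
```

These pooled sequences, together with the sinusoidal position embeddings, are what the module's `nn.MultiheadAttention` layers then consume.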
20 |
21 |
22 |
23 |
24 | NOTE: Initially, we calculated Fm by averaging precision and recall over the whole dataset and then computing Fm from these averages. Thanks to feedback, we identified this bug and revised the approach to compute Fm for each image individually before averaging. We have updated the `./utils/saliency_metric.py` file to fix this issue. Additionally, we have updated the results on the DIS-VD dataset and included the HCE metric. The updated results are shown below:
25 |
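A toy numeric illustration of the fix described in the note above (β² = 0.3 as in `cal_fm`; the per-image precision/recall values are made up):

```python
import numpy as np

# made-up precision and recall for two images at one threshold
p = np.array([0.9, 0.3])
r = np.array([0.6, 0.8])
beta2 = 0.3  # beta^2, as used in cal_fm

# fixed approach: compute Fm per image, then average (~0.579 here)
fm_per_image = (1 + beta2) * p * r / (beta2 * p + r)
print(fm_per_image.mean())

# old, buggy approach: average P and R first, then one Fm (~0.620 here)
print((1 + beta2) * p.mean() * r.mean() / (beta2 * p.mean() + r.mean()))
```

The two values differ, which is why the reported Fm scores changed after the fix.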
26 |
27 |
28 |
29 | Here are some of our visual results:
30 |
31 |
32 |
33 |
34 |
35 | ## I. Requirements
36 |
37 | 1. Clone this repository
38 | ```
39 | git clone git@github.com:qianyu-dlut/MVANet.git
40 | cd MVANet
41 | ```
42 |
43 | 2. Install packages
44 |
45 | ```
46 | conda create -n mvanet python=3.7
47 | conda activate mvanet
48 | pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu102/torch_stable.html
49 | pip install -U openmim
50 | mim install mmcv-full==1.3.17
51 | pip install -r requirements.txt
52 | ```
53 |
54 | ## II. Training
55 | 1. Download the dataset [DIS5K](https://drive.google.com/file/d/1O1eIuXX1hlGsV7qx4eSkjH231q7G1by1/view?usp=sharing) and update `image_root` and `gt_root` in `./train.py` (lines 39-40).
56 | 2. Download the pretrained model at [Google Drive](https://drive.google.com/file/d/1-Zi_DtCT8oC2UAZpB3_XoFOIxIweIAyk/view?usp=sharing) and update the pretrained model path in `./model/SwinTransformer.py` (line 643).
57 | 3. Then, you can start training by simply running:
58 | ```
59 | python train.py
60 | ```
61 |
62 | ## III. Testing
63 | 1. Update the data paths in the config file `./utils/config.py` (lines 3-7).
64 | 2. Replace the existing path with the path to your saved model in `./predict.py` (line 14).
65 |
66 |    You can also download our trained model at [Google Drive](https://drive.google.com/file/d/1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv/view?usp=sharing).
67 | 3. Start predicting by:
68 | ```
69 | python predict.py
70 | ```
71 | 4. Change the predicted map path in `./test.py` (line 19) and start testing:
72 | ```
73 | python test.py
74 | ```
75 |
76 | You can get our prediction maps at [Google Drive](https://drive.google.com/file/d/1qN9mVNK9hfS_a1radFQ9QNYsAQo9FpYS/view?usp=sharing).
77 |
78 | 5. You can measure the FPS of the model by running:
79 | ```
80 | python test_fps.py
81 | ```
82 |
83 | ## Contact
84 | If you have any questions, please feel free to contact me (ms.yuqian AT mail DOT dlut DOT edu DOT cn).
85 |
86 | ## Citations
87 | ```
88 | @inproceedings{MVANet,
89 | title={Multi-view Aggregation Network for Dichotomous Image Segmentation},
90 | author={Yu, Qian and Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and Lu, Huchuan},
91 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
92 | pages={3921--3930},
93 | year={2024}
94 | }
95 |
96 | ```
97 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os, argparse
3 | os.environ["CUDA_VISIBLE_DEVICES"] ='0'
4 | from datetime import datetime
5 | from model.MVANet import MVANet
6 | from utils.dataset_strategy_fpn import get_loader
7 | from utils.misc import adjust_lr, AvgMeter
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 | from torch.backends import cudnn
11 | from torchvision import transforms
12 | import torch.nn as nn
13 | from torch.cuda import amp
14 | from torch.utils.tensorboard import SummaryWriter
15 |
16 | writer = SummaryWriter()
17 |
18 | cudnn.benchmark = True
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--epoch', type=int, default=80, help='epoch number')
22 | parser.add_argument('--lr_gen', type=float, default=1e-5, help='learning rate')
23 | parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
24 | parser.add_argument('--trainsize', type=int, default=1024, help='training dataset size')
25 | parser.add_argument('--decay_rate', type=float, default=0.9, help='decay rate of learning rate')
26 | parser.add_argument('--decay_epoch', type=int, default=60, help='every n epochs decay learning rate')
27 |
28 | opt = parser.parse_args()
29 | print('Generator Learning Rate: {}'.format(opt.lr_gen))
30 | # build models
31 | if hasattr(torch.cuda, 'empty_cache'):
32 | torch.cuda.empty_cache()
33 | generator = MVANet()
34 | generator.cuda()
35 |
36 | generator_params = generator.parameters()
37 | generator_optimizer = torch.optim.Adam(generator_params, opt.lr_gen)
38 |
39 | image_root = './data/DIS5K/DIS-TR/images/'
40 | gt_root = './data/DIS5K/DIS-TR/masks/'
41 |
42 | train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
43 | total_step = len(train_loader)
44 | to_pil = transforms.ToPILImage()
45 | ## define loss
46 |
47 | CE = torch.nn.BCELoss()
48 | mse_loss = torch.nn.MSELoss()  # size_average/reduce are deprecated; the default already averages
49 | size_rates = [1]
50 | criterion = nn.BCEWithLogitsLoss().cuda()
51 | criterion_mae = nn.L1Loss().cuda()
52 | criterion_mse = nn.MSELoss().cuda()
53 | use_fp16 = True
54 | scaler = amp.GradScaler(enabled=use_fp16)
55 |
56 | def structure_loss(pred, mask):
57 |     # weight pixels near mask boundaries more heavily: the 31x31 local average differs from the mask there
58 |     weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
59 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
60 |     wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
61 |
62 |     # weighted IoU computed on the sigmoid probabilities
63 |     pred = torch.sigmoid(pred)
64 |     inter = ((pred * mask) * weit).sum(dim=(2, 3))
65 |     union = ((pred + mask) * weit).sum(dim=(2, 3))
66 |     wiou = 1 - (inter + 1) / (union - inter + 1)
67 |
68 |     return (wbce + wiou).mean()
69 |
70 |
71 |
72 | for epoch in range(1, opt.epoch+1):
73 | torch.cuda.empty_cache()
74 | generator.train()
75 | loss_record = AvgMeter()
76 | print('Generator Learning Rate: {}'.format(generator_optimizer.param_groups[0]['lr']))
77 | for i, pack in enumerate(train_loader, start=1):
78 | torch.cuda.empty_cache()
79 | for rate in size_rates:
80 | torch.cuda.empty_cache()
81 | generator_optimizer.zero_grad()
82 | images, gts = pack
83 | images = Variable(images)
84 | gts = Variable(gts)
85 | images = images.cuda()
86 | gts = gts.cuda()
87 | trainsize = int(round(opt.trainsize * rate / 32) * 32)
88 | if rate != 1:
89 |                 images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear',
90 |                                        align_corners=True)
91 |                 gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
92 |
93 |             b, c, h, w = gts.size()
94 |             target_1 = F.interpolate(gts, size=h // 4, mode='nearest')  # multi-scale targets for the side outputs
95 |             target_2 = F.interpolate(gts, size=h // 8, mode='nearest')
96 |             target_3 = F.interpolate(gts, size=h // 16, mode='nearest')
97 |             target_4 = F.interpolate(gts, size=h // 32, mode='nearest')
98 |             target_5 = F.interpolate(gts, size=h // 64, mode='nearest')
99 |
100 | with amp.autocast(enabled=use_fp16):
101 |                 sideout5, sideout4, sideout3, sideout2, sideout1, final, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1 = generator(images)
102 | loss1 = structure_loss(sideout5, target_4)
103 | loss2 = structure_loss(sideout4, target_3)
104 | loss3 = structure_loss(sideout3, target_2)
105 | loss4 = structure_loss(sideout2, target_1)
106 | loss5 = structure_loss(sideout1, target_1)
107 | loss6 = structure_loss(final, gts)
108 | loss7 = structure_loss(glb5, target_5)
109 | loss8 = structure_loss(glb4, target_4)
110 | loss9 = structure_loss(glb3, target_3)
111 | loss10 = structure_loss(glb2, target_2)
112 | loss11 = structure_loss(glb1, target_2)
113 | loss12 = structure_loss(tokenattmap4, target_3)
114 | loss13 = structure_loss(tokenattmap3, target_2)
115 | loss14 = structure_loss(tokenattmap2, target_1)
116 | loss15 = structure_loss(tokenattmap1, target_1)
117 | loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + 0.3*(loss7 + loss8 + loss9 + loss10 + loss11)+ 0.3*(loss12 + loss13 + loss14 + loss15)
118 | Loss_loc = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
119 | Loss_glb = loss7 + loss8 + loss9 + loss10 + loss11
120 | Loss_map = loss12 + loss13 + loss14 + loss15
121 | writer.add_scalar('loss', loss.item(), epoch * len(train_loader) + i)
122 |
123 | generator_optimizer.zero_grad()
124 | scaler.scale(loss).backward()
125 | scaler.step(generator_optimizer)
126 | scaler.update()
127 |
128 | if rate == 1:
129 | loss_record.update(loss.data, opt.batchsize)
130 |
131 |
132 | if i % 10 == 0 or i == total_step:
133 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], gen Loss: {:.4f}'.
134 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record.show()))
135 |
136 | adjust_lr(generator_optimizer, opt.lr_gen, epoch, opt.decay_rate, opt.decay_epoch)
137 | # save checkpoints every 20 epochs
138 |     if epoch % 20 == 0:
139 |         save_path = './saved_model/MVANet/'
140 |         if not os.path.exists(save_path):
141 |             os.makedirs(save_path)  # makedirs also creates the parent ./saved_model/ directory
142 |         torch.save(generator.state_dict(), save_path + 'Model' + '_%d' % epoch + '.pth')
143 |
144 |
145 |
--------------------------------------------------------------------------------
/utils/dataset_strategy_fpn.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | # top bottom flip
19 | # if flip_flag2==1:
20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
23 | return img, label
24 |
25 |
26 | def randomCrop(image, label):
27 | border = 30
28 | image_width = image.size[0]
29 | image_height = image.size[1]
30 | crop_win_width = np.random.randint(image_width - border, image_width)
31 | crop_win_height = np.random.randint(image_height - border, image_height)
32 | random_region = (
33 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
34 | (image_height + crop_win_height) >> 1)
35 | return image.crop(random_region), label.crop(random_region)
36 |
37 |
38 | def randomRotation(image, label):
39 | mode = Image.BICUBIC
40 | if random.random() > 0.8:
41 | random_angle = np.random.randint(-15, 15)
42 | image = image.rotate(random_angle, mode)
43 | label = label.rotate(random_angle, mode)
44 | return image, label
45 |
46 |
47 | def colorEnhance(image):
48 | bright_intensity = random.randint(5, 15) / 10.0
49 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
50 | contrast_intensity = random.randint(5, 15) / 10.0
51 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
52 | color_intensity = random.randint(0, 20) / 10.0
53 | image = ImageEnhance.Color(image).enhance(color_intensity)
54 | sharp_intensity = random.randint(0, 30) / 10.0
55 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
56 | return image
57 |
58 |
59 | def randomGaussian(image, mean=0.1, sigma=0.35):  # expects a single-channel (grayscale) PIL image
60 | def gaussianNoisy(im, mean=mean, sigma=sigma):
61 | for _i in range(len(im)):
62 | im[_i] += random.gauss(mean, sigma)
63 | return im
64 |
65 | img = np.asarray(image)
66 | width, height = img.shape
67 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
68 | img = img.reshape([width, height])
69 | return Image.fromarray(np.uint8(img))
70 |
71 |
72 | def randomPeper(img):
73 | img = np.array(img)
74 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
75 | for i in range(noiseNum):
76 |
77 | randX = random.randint(0, img.shape[0] - 1)
78 |
79 | randY = random.randint(0, img.shape[1] - 1)
80 |
81 | if random.randint(0, 1) == 0:
82 |
83 | img[randX, randY] = 0
84 |
85 | else:
86 |
87 | img[randX, randY] = 255
88 | return Image.fromarray(img)
89 |
90 |
91 | # dataset for training
92 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps
93 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.
94 | class DISDataset(data.Dataset):
95 | def __init__(self, image_root, gt_root, trainsize):
96 | self.trainsize = trainsize
97 |         self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.tif')]
98 |         self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
99 |                     or f.endswith('.png') or f.endswith('.tif')]
100 | self.images = sorted(self.images)
101 | self.gts = sorted(self.gts)
102 | self.filter_files()
103 | self.size = len(self.images)
104 | self.img_transform = transforms.Compose([
105 | transforms.Resize((self.trainsize, self.trainsize)),
106 | transforms.ToTensor(),
107 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
108 | self.gt_transform = transforms.Compose([
109 | transforms.Resize((self.trainsize, self.trainsize)),
110 | transforms.ToTensor()])
111 |
112 | def __getitem__(self, index):
113 | image = self.rgb_loader(self.images[index])
114 | gt = self.binary_loader(self.gts[index])
115 | image, gt = cv_random_flip(image, gt)
116 | image, gt = randomCrop(image, gt)
117 | image, gt = randomRotation(image, gt)
118 | image = colorEnhance(image)
119 | image = self.img_transform(image)
120 | gt = self.gt_transform(gt)
121 |
122 |
123 | return image, gt
124 |
125 | def filter_files(self):
126 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images)
127 | images = []
128 | gts = []
129 | for img_path, gt_path in zip(self.images, self.gts):
130 | img = Image.open(img_path)
131 | gt = Image.open(gt_path)
132 | if img.size == gt.size :
133 | images.append(img_path)
134 | gts.append(gt_path)
135 | self.images = images
136 | self.gts = gts
137 |
138 | def rgb_loader(self, path):
139 | with open(path, 'rb') as f:
140 | img = Image.open(f)
141 | return img.convert('RGB')
142 |
143 | def binary_loader(self, path):
144 | with open(path, 'rb') as f:
145 | img = Image.open(f)
146 | return img.convert('L')
147 |
148 | def resize(self, img, gt):
149 | assert img.size == gt.size
150 | w, h = img.size
151 | if h < self.trainsize or w < self.trainsize:
152 | h = max(h, self.trainsize)
153 | w = max(w, self.trainsize)
154 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
155 | else:
156 | return img, gt
157 |
158 | def __len__(self):
159 | return self.size
160 |
161 |
162 | # dataloader for training
163 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False):
164 | dataset = DISDataset(image_root, gt_root, trainsize)
165 | data_loader = data.DataLoader(dataset=dataset,
166 | batch_size=batchsize,
167 | shuffle=shuffle,
168 | num_workers=num_workers,
169 | pin_memory=pin_memory)
170 | return data_loader
171 |
172 |
173 | # test dataset and loader
174 | class test_dataset:
175 | def __init__(self, image_root, depth_root, testsize):
176 | self.testsize = testsize
177 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
178 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
179 | or f.endswith('.png')]
180 | self.images = sorted(self.images)
181 | self.depths = sorted(self.depths)
182 | self.transform = transforms.Compose([
183 | transforms.Resize((self.testsize, self.testsize)),
184 | transforms.ToTensor(),
185 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
186 | # self.gt_transform = transforms.Compose([
187 | # transforms.Resize((self.trainsize, self.trainsize)),
188 | # transforms.ToTensor()])
189 | self.depths_transform = transforms.Compose(
190 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()])
191 | self.size = len(self.images)
192 | self.index = 0
193 |
194 | def load_data(self):
195 | image = self.rgb_loader(self.images[self.index])
196 | HH = image.size[0]
197 | WW = image.size[1]
198 | image = self.transform(image).unsqueeze(0)
199 | depth = self.rgb_loader(self.depths[self.index])
200 | depth = self.depths_transform(depth).unsqueeze(0)
201 |
202 | name = self.images[self.index].split('/')[-1]
203 | # image_for_post=self.rgb_loader(self.images[self.index])
204 | # image_for_post=image_for_post.resize(gt.size)
205 | if name.endswith('.jpg'):
206 | name = name.split('.jpg')[0] + '.png'
207 | self.index += 1
208 | self.index = self.index % self.size
209 | return image, depth, HH, WW, name
210 |
211 | def rgb_loader(self, path):
212 | with open(path, 'rb') as f:
213 | img = Image.open(f)
214 | return img.convert('RGB')
215 |
216 | def binary_loader(self, path):
217 | with open(path, 'rb') as f:
218 | img = Image.open(f)
219 | return img.convert('L')
220 |
221 | def __len__(self):
222 | return self.size
223 |
224 |
--------------------------------------------------------------------------------
/utils/saliency_metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import ndimage
3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist
4 | import cv2
5 | from skimage.morphology import skeletonize
6 | from skimage.morphology import disk
7 | from skimage.measure import label
8 |
9 |
10 | class cal_fm(object):
11 | # Fmeasure(maxFm,meanFm)---Frequency-tuned salient region detection(CVPR 2009)
12 | def __init__(self, num, thds=255):
13 | self.num = num
14 | self.thds = thds
15 | self.precision = np.zeros((self.num, self.thds))
16 | self.recall = np.zeros((self.num, self.thds))
17 | self.meanF = np.zeros((self.num,1))
18 | self.changeable_fms = []
19 | self.idx = 0
20 |
21 | def update(self, pred, gt):
22 | if gt.max() != 0:
23 | # prediction, recall, Fmeasure_temp = self.cal(pred, gt)
24 | prediction, recall, Fmeasure_temp, changeable_fms = self.cal(pred, gt)
25 | self.precision[self.idx, :] = prediction
26 | self.recall[self.idx, :] = recall
27 | self.meanF[self.idx, :] = Fmeasure_temp
28 | self.changeable_fms.append(changeable_fms)
29 | self.idx += 1
30 |
31 | def cal(self, pred, gt):
32 | ########################meanF##############################
33 | th = 2 * pred.mean()
34 | if th > 1:
35 | th = 1
36 |
37 | binary = np.zeros_like(pred)
38 | binary[pred >= th] = 1
39 |
40 | hard_gt = np.zeros_like(gt)
41 | hard_gt[gt > 0.5] = 1
42 | tp = (binary * hard_gt).sum()
43 | if tp == 0:
44 | meanF = 0
45 | else:
46 | pre = tp / binary.sum()
47 | rec = tp / hard_gt.sum()
48 | meanF = 1.3 * pre * rec / (0.3 * pre + rec)
49 | ########################maxF##############################
50 | pred = np.uint8(pred * 255)
51 | target = pred[gt > 0.5]
52 | nontarget = pred[gt <= 0.5]
53 | targetHist, _ = np.histogram(target, bins=range(256))
54 | nontargetHist, _ = np.histogram(nontarget, bins=range(256))
55 | targetHist = np.cumsum(np.flip(targetHist), axis=0)
56 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0)
57 | precision = targetHist / (targetHist + nontargetHist + 1e-8)
58 | recall = targetHist / np.sum(gt)
59 | numerator = 1.3 * precision * recall
60 | denominator = np.where(numerator == 0, 1, 0.3 * precision + recall)
61 | changeable_fms = numerator / denominator
62 | return precision, recall, meanF, changeable_fms
63 |
64 |
65 | def show(self):
66 | assert self.num == self.idx
67 | precision = self.precision.mean(axis=0)
68 | recall = self.recall.mean(axis=0)
69 | # fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8)
70 | changeable_fm = np.mean(np.array(self.changeable_fms), axis=0)
71 | fmeasure_avg = self.meanF.mean(axis=0)
72 | return changeable_fm.max(),fmeasure_avg[0],precision,recall
73 |
74 |
75 |
76 | class cal_mae(object):
77 | # mean absolute error
78 | def __init__(self):
79 | self.prediction = []
80 |
81 | def update(self, pred, gt):
82 | score = self.cal(pred, gt)
83 | self.prediction.append(score)
84 |
85 | def cal(self, pred, gt):
86 | return np.mean(np.abs(pred - gt))
87 |
88 | def show(self):
89 | return np.mean(self.prediction)
90 |
91 | class cal_dice(object):
92 |     # mean Dice coefficient
93 | def __init__(self):
94 | self.prediction = []
95 |
96 | def update(self, pred, gt):
97 | score = self.cal(pred, gt)
98 | self.prediction.append(score)
99 |
100 | def cal(self, y_pred, y_true):
101 | # smooth = 1
102 | smooth = 1e-5
103 | y_true_f = y_true.flatten()
104 | y_pred_f = y_pred.flatten()
105 | intersection = np.sum(y_true_f * y_pred_f)
106 | return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)
107 |
108 | def show(self):
109 | return np.mean(self.prediction)
110 |
111 | class cal_ber(object):
112 |     # balanced error rate (BER)
113 | def __init__(self):
114 | self.prediction = []
115 |
116 | def update(self, pred, gt):
117 | score = self.cal(pred, gt)
118 | self.prediction.append(score)
119 |
120 | def cal(self, y_pred, y_true):
121 | binary = np.zeros_like(y_pred)
122 | binary[y_pred >= 0.5] = 1
123 | hard_gt = np.zeros_like(y_true)
124 | hard_gt[y_true > 0.5] = 1
125 | tp = (binary * hard_gt).sum()
126 | tn = ((1-binary) * (1-hard_gt)).sum()
127 | Np = hard_gt.sum()
128 | Nn = (1-hard_gt).sum()
129 | ber = (1-(tp/(Np+1e-8)+tn/(Nn+1e-8))/2)
130 | return ber
131 |
132 | def show(self):
133 | return np.mean(self.prediction)
134 |
135 | class cal_acc(object):
136 |     # pixel accuracy
137 | def __init__(self):
138 | self.prediction = []
139 |
140 | def update(self, pred, gt):
141 | score = self.cal(pred, gt)
142 | self.prediction.append(score)
143 |
144 | def cal(self, y_pred, y_true):
145 | binary = np.zeros_like(y_pred)
146 | binary[y_pred >= 0.5] = 1
147 | hard_gt = np.zeros_like(y_true)
148 | hard_gt[y_true > 0.5] = 1
149 | tp = (binary * hard_gt).sum()
150 | tn = ((1-binary) * (1-hard_gt)).sum()
151 | Np = hard_gt.sum()
152 | Nn = (1-hard_gt).sum()
153 | acc = ((tp+tn)/(Np+Nn))
154 | return acc
155 |
156 | def show(self):
157 | return np.mean(self.prediction)
158 |
159 | class cal_iou(object):
160 |     # mean intersection over union (IoU)
161 | def __init__(self):
162 | self.prediction = []
163 |
164 | def update(self, pred, gt):
165 | score = self.cal(pred, gt)
166 | self.prediction.append(score)
167 |
168 | # def cal(self, input, target):
169 | # classes = 1
170 | # intersection = np.logical_and(target == classes, input == classes)
171 | # # print(intersection.any())
172 | # union = np.logical_or(target == classes, input == classes)
173 | # return np.sum(intersection) / np.sum(union)
174 |
175 | def cal(self, input, target):
176 | smooth = 1e-5
177 | input = input > 0.5
178 | target_ = target > 0.5
179 | intersection = (input & target_).sum()
180 | union = (input | target_).sum()
181 |
182 | return (intersection + smooth) / (union + smooth)
183 | def show(self):
184 | return np.mean(self.prediction)
185 |
186 | # smooth = 1e-5
187 | #
188 | # if torch.is_tensor(output):
189 | # output = torch.sigmoid(output).data.cpu().numpy()
190 | # if torch.is_tensor(target):
191 | # target = target.data.cpu().numpy()
192 | # output_ = output > 0.5
193 | # target_ = target > 0.5
194 | # intersection = (output_ & target_).sum()
195 | # union = (output_ | target_).sum()
196 |
197 | # return (intersection + smooth) / (union + smooth)
198 |
199 | class cal_sm(object):
200 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017)
201 | def __init__(self, alpha=0.5):
202 | self.prediction = []
203 | self.alpha = alpha
204 |
205 | def update(self, pred, gt):
206 | gt = gt > 0.5
207 | score = self.cal(pred, gt)
208 | self.prediction.append(score)
209 |
210 | def show(self):
211 | return np.mean(self.prediction)
212 |
213 | def cal(self, pred, gt):
214 | y = np.mean(gt)
215 | if y == 0:
216 | score = 1 - np.mean(pred)
217 | elif y == 1:
218 | score = np.mean(pred)
219 | else:
220 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
221 | return score
222 |
223 | def object(self, pred, gt):
224 | fg = pred * gt
225 | bg = (1 - pred) * (1 - gt)
226 |
227 | u = np.mean(gt)
228 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt))
229 |
230 | def s_object(self, in1, in2):
231 | x = np.mean(in1[in2])
232 | sigma_x = np.std(in1[in2])
233 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8)
234 |
235 | def region(self, pred, gt):
236 | [y, x] = ndimage.center_of_mass(gt)
237 | y = int(round(y)) + 1
238 | x = int(round(x)) + 1
239 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y)
240 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y)
241 |
242 | score1 = self.ssim(pred1, gt1)
243 | score2 = self.ssim(pred2, gt2)
244 | score3 = self.ssim(pred3, gt3)
245 | score4 = self.ssim(pred4, gt4)
246 |
247 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
248 |
249 | def divideGT(self, gt, x, y):
250 | h, w = gt.shape
251 | area = h * w
252 | LT = gt[0:y, 0:x]
253 | RT = gt[0:y, x:w]
254 | LB = gt[y:h, 0:x]
255 | RB = gt[y:h, x:w]
256 |
257 | w1 = x * y / area
258 | w2 = y * (w - x) / area
259 | w3 = (h - y) * x / area
260 | w4 = (h - y) * (w - x) / area
261 |
262 | return LT, RT, LB, RB, w1, w2, w3, w4
263 |
264 | def dividePred(self, pred, x, y):
265 | h, w = pred.shape
266 | LT = pred[0:y, 0:x]
267 | RT = pred[0:y, x:w]
268 | LB = pred[y:h, 0:x]
269 | RB = pred[y:h, x:w]
270 |
271 | return LT, RT, LB, RB
272 |
273 | def ssim(self, in1, in2):
274 | in2 = np.float32(in2)
275 | h, w = in1.shape
276 | N = h * w
277 |
278 | x = np.mean(in1)
279 | y = np.mean(in2)
280 | sigma_x = np.var(in1)
281 | sigma_y = np.var(in2)
282 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1)
283 |
284 | alpha = 4 * x * y * sigma_xy
285 | beta = (x * x + y * y) * (sigma_x + sigma_y)
286 |
287 | if alpha != 0:
288 | score = alpha / (beta + 1e-8)
289 | elif alpha == 0 and beta == 0:
290 | score = 1
291 | else:
292 | score = 0
293 |
294 | return score
295 |
296 | class cal_em(object):
297 | #Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018)
298 | def __init__(self):
299 | self.prediction = []
300 |
301 | def update(self, pred, gt):
302 | score = self.cal(pred, gt)
303 | self.prediction.append(score)
304 |
305 | def cal(self, pred, gt):
306 | th = 2 * pred.mean()
307 | if th > 1:
308 | th = 1
309 | FM = np.zeros(gt.shape)
310 | FM[pred >= th] = 1
311 | FM = np.array(FM,dtype=bool)
312 | GT = np.array(gt,dtype=bool)
313 | dFM = np.double(FM)
314 | if (sum(sum(np.double(GT)))==0):
315 | enhanced_matrix = 1.0-dFM
316 | elif (sum(sum(np.double(~GT)))==0):
317 | enhanced_matrix = dFM
318 | else:
319 | dGT = np.double(GT)
320 | align_matrix = self.AlignmentTerm(dFM, dGT)
321 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix)
322 | [w, h] = np.shape(GT)
323 | score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8)
324 | return score
325 | def AlignmentTerm(self,dFM,dGT):
326 | mu_FM = np.mean(dFM)
327 | mu_GT = np.mean(dGT)
328 | align_FM = dFM - mu_FM
329 | align_GT = dGT - mu_GT
330 | align_Matrix = 2. * (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8)
331 | return align_Matrix
332 | def EnhancedAlignmentTerm(self,align_Matrix):
333 | enhanced = np.power(align_Matrix + 1,2) / 4
334 | return enhanced
335 | def show(self):
336 | return np.mean(self.prediction)
337 | class cal_wfm(object):
338 | def __init__(self, beta=1):
339 | self.beta = beta
340 | self.eps = 1e-6
341 | self.scores_list = []
342 |
343 | def update(self, pred, gt):
344 | assert pred.ndim == gt.ndim and pred.shape == gt.shape
345 | assert pred.max() <= 1 and pred.min() >= 0
346 | assert gt.max() <= 1 and gt.min() >= 0
347 |
348 | gt = gt > 0.5
349 | if gt.max() == 0:
350 | score = 0
351 | else:
352 | score = self.cal(pred, gt)
353 | self.scores_list.append(score)
354 |
355 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5):
356 | """
357 | 2D gaussian mask - should give the same result as MATLAB's
358 | fspecial('gaussian',[shape],[sigma])
359 | """
360 | m, n = [(ss - 1.) / 2. for ss in shape]
361 | y, x = np.ogrid[-m:m + 1, -n:n + 1]
362 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
363 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
364 | sumh = h.sum()
365 | if sumh != 0:
366 | h /= sumh
367 | return h
368 |
369 | def cal(self, pred, gt):
370 | # [Dst,IDXT] = bwdist(dGT);
371 | Dst, Idxt = bwdist(gt == 0, return_indices=True)
372 |
373 | # %Pixel dependency
374 | # E = abs(FG-dGT);
375 | E = np.abs(pred - gt)
376 | # Et = E;
377 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region
378 | Et = np.copy(E)
379 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
380 |
381 | # K = fspecial('gaussian',7,5);
382 | # EA = imfilter(Et,K);
383 |         # MIN_E_EA(GT & EA<E) = EA(GT & EA<E);
384 |         K = self.matlab_style_gauss2D((7, 7), sigma=5)
385 |         EA = convolve(Et, weights=K, mode='constant', cval=0)
386 |         MIN_E_EA = np.where(gt & (EA < E), EA, E)
387 |
388 |         # %Pixel importance
389 |         # B = ones(size(GT));
390 |         # B(~GT) = 2-1*exp(log(1-0.5)/5.*Dst(~GT));
391 |         B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
392 |         Ew = MIN_E_EA * B
393 |
394 |         # TPw = sum(dGT(:)) - sum(sum(Ew(GT)));
395 |         # FPw = sum(sum(Ew(~GT)));
396 |         TPw = np.sum(gt) - np.sum(Ew[gt == 1])
397 |         FPw = np.sum(Ew[gt == 0])
398 |
399 |         # R = 1 - mean2(Ew(GT)); %Weighted Recall
400 |         # P = TPw./(eps+TPw+FPw); %Weighted Precision
401 |         R = 1 - np.mean(Ew[gt == 1])
402 |         P = TPw / (self.eps + TPw + FPw)
403 |
404 |         # Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P));
405 |         Q = (1 + self.beta) * R * P / (self.eps + R + self.beta * P)
406 |
407 |         return Q
408 |
409 |     def show(self):
410 |         return np.mean(self.scores_list)
411 |
412 |
413 | class HCEMeasure(object):
414 |     # Human Correction Efforts (HCE): estimates how many polygon control
415 |     # points and independent regions a human would need to edit to fix the
416 |     # prediction (introduced with the DIS5K benchmark).
417 |     def __init__(self):
418 |         self.hces = []
419 |
420 |     def step(self, pred: np.ndarray, gt: np.ndarray, gt_ske: np.ndarray):
421 |         hce = self.cal_hce(pred, gt, gt_ske)
422 |         self.hces.append(hce)
423 |
424 |     def get_results(self) -> dict:
425 | hce = np.mean(np.array(self.hces))
426 | return dict(hce=hce)
427 |
428 |
429 | def cal_hce(self, pred: np.ndarray, gt: np.ndarray, gt_ske: np.ndarray, relax=5, epsilon=2.0) -> float:
430 | # Binarize gt
431 | if(len(gt.shape)>2):
432 | gt = gt[:, :, 0]
433 |
434 | epsilon_gt = 0.5#(np.amin(gt)+np.amax(gt))/2.0
435 | gt = (gt>epsilon_gt).astype(np.uint8)
436 |
437 | # Binarize pred
438 | if(len(pred.shape)>2):
439 | pred = pred[:, :, 0]
440 | epsilon_pred = 0.5#(np.amin(pred)+np.amax(pred))/2.0
441 | pred = (pred>epsilon_pred).astype(np.uint8)
442 |
443 | Union = np.logical_or(gt, pred)
444 | TP = np.logical_and(gt, pred)
445 | FP = pred - TP
446 | FN = gt - TP
447 |
448 | # relax the Union of gt and pred
449 | Union_erode = Union.copy()
450 | Union_erode = cv2.erode(Union_erode.astype(np.uint8), disk(1), iterations=relax)
451 |
452 | # --- get the relaxed False Positive regions for computing the human efforts in correcting them ---
453 | FP_ = np.logical_and(FP, Union_erode) # get the relaxed FP
454 | for i in range(0, relax):
455 | FP_ = cv2.dilate(FP_.astype(np.uint8), disk(1))
456 | FP_ = np.logical_and(FP_, 1-np.logical_or(TP, FN))
457 | FP_ = np.logical_and(FP, FP_)
458 |
459 | # --- get the relaxed False Negative regions for computing the human efforts in correcting them ---
460 | FN_ = np.logical_and(FN, Union_erode) # preserve the structural components of FN
461 | ## recover the FN, where pixels are not close to the TP borders
462 | for i in range(0, relax):
463 | FN_ = cv2.dilate(FN_.astype(np.uint8), disk(1))
464 | FN_ = np.logical_and(FN_, 1-np.logical_or(TP, FP))
465 | FN_ = np.logical_and(FN, FN_)
466 | FN_ = np.logical_or(FN_, np.logical_xor(gt_ske, np.logical_and(TP, gt_ske))) # preserve the structural components of FN
467 |
468 | ## 2. =============Find exact polygon control points and independent regions==============
469 | ## find contours from FP_
470 | ctrs_FP, hier_FP = cv2.findContours(FP_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
471 | ## find control points and independent regions for human correction
472 | bdies_FP, indep_cnt_FP = self.filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_))
473 | ## find contours from FN_
474 | ctrs_FN, hier_FN = cv2.findContours(FN_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
475 | ## find control points and independent regions for human correction
476 | bdies_FN, indep_cnt_FN = self.filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP, FP_), FN_))
477 |
478 | poly_FP, poly_FP_len, poly_FP_point_cnt = self.approximate_RDP(bdies_FP, epsilon=epsilon)
479 | poly_FN, poly_FN_len, poly_FN_point_cnt = self.approximate_RDP(bdies_FN, epsilon=epsilon)
480 |
481 | # FP_points+FP_indep+FN_points+FN_indep
482 | return poly_FP_point_cnt+indep_cnt_FP+poly_FN_point_cnt+indep_cnt_FN
483 |
484 | def filter_bdy_cond(self, bdy_, mask, cond):
485 |
486 | cond = cv2.dilate(cond.astype(np.uint8), disk(1))
487 | labels = label(mask) # find the connected regions
488 | lbls = np.unique(labels) # the indices of the connected regions
489 | indep = np.ones(lbls.shape[0]) # the label of each connected regions
490 | indep[0] = 0 # 0 indicate the background region
491 |
492 | boundaries = []
493 | h,w = cond.shape[0:2]
494 | ind_map = np.zeros((h, w))
495 | indep_cnt = 0
496 |
497 | for i in range(0, len(bdy_)):
498 | tmp_bdies = []
499 | tmp_bdy = []
500 | for j in range(0, bdy_[i].shape[0]):
501 | r, c = bdy_[i][j,0,1],bdy_[i][j,0,0]
502 |
503 | if(np.sum(cond[r, c])==0 or ind_map[r, c]!=0):
504 | if(len(tmp_bdy)>0):
505 | tmp_bdies.append(tmp_bdy)
506 | tmp_bdy = []
507 | continue
508 | tmp_bdy.append([c, r])
509 | ind_map[r, c] = ind_map[r, c] + 1
510 | indep[labels[r, c]] = 0 # indicates part of the boundary of this region needs human correction
511 | if(len(tmp_bdy)>0):
512 | tmp_bdies.append(tmp_bdy)
513 |
514 | # check if the first and the last boundaries are connected
515 | # if yes, invert the first boundary and attach it after the last boundary
516 | if(len(tmp_bdies)>1):
517 | first_x, first_y = tmp_bdies[0][0]
518 | last_x, last_y = tmp_bdies[-1][-1]
519 | if((abs(first_x-last_x)==1 and first_y==last_y) or
520 | (first_x==last_x and abs(first_y-last_y)==1) or
521 | (abs(first_x-last_x)==1 and abs(first_y-last_y)==1)
522 | ):
523 | tmp_bdies[-1].extend(tmp_bdies[0][::-1])
524 | del tmp_bdies[0]
525 |
526 | for k in range(0, len(tmp_bdies)):
527 | tmp_bdies[k] = np.array(tmp_bdies[k])[:, np.newaxis, :]
528 | if(len(tmp_bdies)>0):
529 | boundaries.extend(tmp_bdies)
530 |
531 | return boundaries, np.sum(indep)
532 |
533 | # this function approximate each boundary by DP algorithm
534 | # https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm
535 | def approximate_RDP(self, boundaries, epsilon=1.0):
536 |
537 | boundaries_ = []
538 | boundaries_len_ = []
539 | pixel_cnt_ = 0
540 |
541 | # polygon approximate of each boundary
542 | for i in range(0, len(boundaries)):
543 | boundaries_.append(cv2.approxPolyDP(boundaries[i], epsilon, False))
544 |
545 | # count the control points number of each boundary and the total control points number of all the boundaries
546 | for i in range(0, len(boundaries_)):
547 | boundaries_len_.append(len(boundaries_[i]))
548 | pixel_cnt_ = pixel_cnt_ + len(boundaries_[i])
549 |
550 | return boundaries_, boundaries_len_, pixel_cnt_
551 |
--------------------------------------------------------------------------------
/model/MVANet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch import nn
7 | from .SwinTransformer import SwinB
8 |
9 |
10 | def get_activation_fn(activation):
11 | """Return an activation function given a string"""
12 | if activation == "relu":
13 | return F.relu
14 | if activation == "gelu":
15 | return F.gelu
16 | if activation == "glu":
17 | return F.glu
18 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
19 |
20 |
21 | def make_cbr(in_dim, out_dim):
22 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU())
23 |
24 |
25 | def make_cbg(in_dim, out_dim):
26 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.GELU())
27 |
28 |
29 | def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
30 | return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
31 |
32 |
33 | def resize_as(x, y, interpolation='bilinear'):
34 | return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
35 |
36 |
37 | def image2patches(x):
38 | """b c (hg h) (wg w) -> (hg wg b) c h w"""
39 | x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
40 | return x
41 |
42 |
43 | def patches2image(x):
44 | """(hg wg b) c h w -> b c (hg h) (wg w)"""
45 | x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
46 | return x
47 |
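# A quick self-check sketch (illustrative, not part of the original file): for
# even H and W the two rearranges above are exact inverses, which the decoder
# relies on when stitching the four local patches back into a single map.
def _patch_roundtrip_check():
    x = torch.randn(1, 3, 8, 8)
    assert torch.equal(patches2image(image2patches(x)), x)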
48 |
49 | class PositionEmbeddingSine:
50 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
51 | super().__init__()
52 | self.num_pos_feats = num_pos_feats
53 | self.temperature = temperature
54 | self.normalize = normalize
55 | if scale is not None and normalize is False:
56 | raise ValueError("normalize should be True if scale is passed")
57 | if scale is None:
58 | scale = 2 * math.pi
59 | self.scale = scale
60 | self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32, device='cuda')
61 |
62 | def __call__(self, b, h, w):
63 | mask = torch.zeros([b, h, w], dtype=torch.bool, device='cuda')
64 | assert mask is not None
65 | not_mask = ~mask
66 | y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
67 | x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
68 | if self.normalize:
69 | eps = 1e-6
70 | y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale).cuda()
71 | x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale).cuda()
72 |
73 | dim_t = self.temperature ** (2 * (self.dim_t // 2) / self.num_pos_feats)
74 |
75 | pos_x = x_embed[:, :, :, None] / dim_t
76 | pos_y = y_embed[:, :, :, None] / dim_t
77 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(
78 | 3)
79 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(
80 | 3)
81 | return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
82 |
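# Usage sketch (illustrative; a CUDA device is required because the buffers
# above are created on 'cuda'): with num_pos_feats = d_model // 2 = 64, the
# call below returns a (1, 128, 16, 16) map with the sine/cosine channels for
# the y and x axes concatenated along dim 1.
def _pos_embed_demo():
    pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
    return pe(1, 16, 16)  # (1, 128, 16, 16)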
83 |
84 | class MCLM(nn.Module):
85 | def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
86 | super(MCLM, self).__init__()
87 | self.attention = nn.ModuleList([
88 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
89 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
90 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
91 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
92 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
93 | ])
94 |
95 | self.linear3 = nn.Linear(d_model, d_model * 2)
96 | self.linear4 = nn.Linear(d_model * 2, d_model)
97 | self.linear5 = nn.Linear(d_model, d_model * 2)
98 | self.linear6 = nn.Linear(d_model * 2, d_model)
99 | self.norm1 = nn.LayerNorm(d_model)
100 | self.norm2 = nn.LayerNorm(d_model)
101 | self.dropout = nn.Dropout(0.1)
102 | self.dropout1 = nn.Dropout(0.1)
103 | self.dropout2 = nn.Dropout(0.1)
104 | self.activation = get_activation_fn('relu')
105 | self.pool_ratios = pool_ratios
106 | self.p_poses = []
107 | self.g_pos = None
108 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
109 |
110 | def forward(self, l, g):
111 | """
112 | l: 4,c,h,w
113 | g: 1,c,h,w
114 | """
115 | b, c, h, w = l.size()
116 | # 4,c,h,w -> 1,c,2h,2w
117 | concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
118 |
119 | pools = []
120 | for pool_ratio in self.pool_ratios:
121 | # b,c,h,w
122 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
123 | pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
124 | pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
125 | if self.g_pos is None:
126 | pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
127 | pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
128 | self.p_poses.append(pos_emb)
129 | pools = torch.cat(pools, 0)
130 | if self.g_pos is None:
131 | self.p_poses = torch.cat(self.p_poses, dim=0)
132 | pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
133 | self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
134 |
135 | # attention between glb (q) & multi-scale pooled concated-locs (k,v)
136 | g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
137 | g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
138 | g_hw_b_c = self.norm1(g_hw_b_c)
139 | g_hw_b_c = g_hw_b_c + self.dropout2(self.linear6(self.dropout(self.activation(self.linear5(g_hw_b_c)).clone())))
140 | g_hw_b_c = self.norm2(g_hw_b_c)
141 |
142 | # attention between original locs (q) & refreshed glb (k,v)
143 | l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
144 | _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
145 | _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
146 | outputs_re = []
147 | for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
148 | outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
149 | outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
150 |
151 | l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
152 | l_hw_b_c = self.norm1(l_hw_b_c)
153 | l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
154 | l_hw_b_c = self.norm2(l_hw_b_c)
155 |
156 | l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # (h w), b(5), c
157 | return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)
158 |
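# Usage sketch (illustrative; CUDA assumed, matching the hard-coded positional
# encoding): four local patches plus one global view of the same spatial size
# go in, and one (5, c, h, w) tensor of refreshed features comes out.
def _mclm_demo():
    mclm = MCLM(d_model=128, num_heads=1).cuda().eval()
    l = torch.randn(4, 128, 16, 16, device='cuda')
    g = torch.randn(1, 128, 16, 16, device='cuda')
    with torch.no_grad():
        return mclm(l, g)  # (5, 128, 16, 16)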
159 |
160 | class inf_MCLM(nn.Module):
161 | def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
162 | super(inf_MCLM, self).__init__()
163 | self.attention = nn.ModuleList([
164 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
165 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
166 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
167 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
168 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
169 | ])
170 | self.linear3 = nn.Linear(d_model, d_model * 2)
171 | self.linear4 = nn.Linear(d_model * 2, d_model)
172 | self.linear5 = nn.Linear(d_model, d_model * 2)
173 | self.linear6 = nn.Linear(d_model * 2, d_model)
174 | self.norm1 = nn.LayerNorm(d_model)
175 | self.norm2 = nn.LayerNorm(d_model)
176 | self.dropout = nn.Dropout(0.1)
177 | self.dropout1 = nn.Dropout(0.1)
178 | self.dropout2 = nn.Dropout(0.1)
179 | self.activation = get_activation_fn('relu')
180 | self.pool_ratios = pool_ratios
181 | self.p_poses = []
182 | self.g_pos = None
183 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
184 |
185 | def forward(self, l, g):
186 | """
187 | l: 4,c,h,w
188 | g: 1,c,h,w
189 | """
190 | b, c, h, w = l.size()
191 | # 4,c,h,w -> 1,c,2h,2w
192 | concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
193 | self.p_poses = []
194 | pools = []
195 | for pool_ratio in self.pool_ratios:
196 | # b,c,h,w
197 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
198 | pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
199 | pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
200 | # if self.g_pos is None:
201 | pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
202 | pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
203 | self.p_poses.append(pos_emb)
204 | pools = torch.cat(pools, 0)
205 | # if self.g_pos is None:
206 | self.p_poses = torch.cat(self.p_poses, dim=0)
207 | pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
208 | self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
209 |
210 | # attention between glb (q) & multi-scale pooled concated-locs (k,v)
211 | g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
212 | g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
213 | g_hw_b_c = self.norm1(g_hw_b_c)
214 | g_hw_b_c = g_hw_b_c + self.dropout2(self.linear6(self.dropout(self.activation(self.linear5(g_hw_b_c)).clone())))
215 | g_hw_b_c = self.norm2(g_hw_b_c)
216 |
217 | # attention between original locs (q) & refreshed glb (k,v)
218 | l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
219 | _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
220 | _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
221 | outputs_re = []
222 | for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
223 | outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
224 | outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
225 |
226 | l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
227 | l_hw_b_c = self.norm1(l_hw_b_c)
228 | l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
229 | l_hw_b_c = self.norm2(l_hw_b_c)
230 |
231 | l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # (h w), b(5), c
232 | return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)
233 |
234 |
235 | class MCRM(nn.Module):
236 | def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
237 | super(MCRM, self).__init__()
238 | self.attention = nn.ModuleList([
239 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
240 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
241 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
242 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
243 | ])
244 |
245 | self.linear3 = nn.Linear(d_model, d_model * 2)
246 | self.linear4 = nn.Linear(d_model * 2, d_model)
247 | self.norm1 = nn.LayerNorm(d_model)
248 | self.norm2 = nn.LayerNorm(d_model)
249 | self.dropout = nn.Dropout(0.1)
250 | self.dropout1 = nn.Dropout(0.1)
251 | self.dropout2 = nn.Dropout(0.1)
252 | self.sigmoid = nn.Sigmoid()
253 | self.activation = get_activation_fn('relu')
254 | self.sal_conv = nn.Conv2d(d_model, 1, 1)
255 | self.pool_ratios = pool_ratios
256 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
257 | def forward(self, x):
258 | b, c, h, w = x.size()
259 | loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
260 | # b(4),c,h,w
261 | patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
262 |
263 | # generate token attention map
264 | token_attention_map = self.sigmoid(self.sal_conv(glb))
265 | token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
266 | loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
267 | pools = []
268 | for pool_ratio in self.pool_ratios:
269 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
270 | pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
271 | pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
272 | # nl(4),c,nphw -> nl(4),nphw,1,c
273 | pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
274 | loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
275 | outputs = []
276 | for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
277 | # np*hw,1,c
278 | v = pools[i]
279 | k = v
280 | outputs.append(self.attention[i](q, k, v)[0])
281 | outputs = torch.cat(outputs, 1)
282 | src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
283 | src = self.norm1(src)
284 | src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
285 | src = self.norm2(src)
286 |
287 | src = src.permute(1, 2, 0).reshape(4, c, h, w)  # refreshed loc
288 | glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # refreshed glb
289 | return torch.cat((src, glb), 0), token_attention_map
290 |
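# Usage sketch (illustrative; CUDA assumed): MCRM consumes the stacked
# (4 local + 1 global) feature map and returns the refreshed stack together
# with the token attention map that was used to gate the local patches.
def _mcrm_demo():
    mcrm = MCRM(d_model=128, num_heads=1).cuda().eval()
    x = torch.randn(5, 128, 32, 32, device='cuda')
    with torch.no_grad():
        out, attn = mcrm(x)  # out: (5, 128, 32, 32); attn: (1, 1, 64, 64)
    return out, attn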
291 |
292 | class inf_MCRM(nn.Module):
293 | def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
294 | super(inf_MCRM, self).__init__()
295 | self.attention = nn.ModuleList([
296 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
297 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
298 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
299 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
300 | ])
301 |
302 | self.linear3 = nn.Linear(d_model, d_model * 2)
303 | self.linear4 = nn.Linear(d_model * 2, d_model)
304 | self.norm1 = nn.LayerNorm(d_model)
305 | self.norm2 = nn.LayerNorm(d_model)
306 | self.dropout = nn.Dropout(0.1)
307 | self.dropout1 = nn.Dropout(0.1)
308 | self.dropout2 = nn.Dropout(0.1)
309 | self.sigmoid = nn.Sigmoid()
310 | self.activation = get_activation_fn('relu')
311 | self.sal_conv = nn.Conv2d(d_model, 1, 1)
312 | self.pool_ratios = pool_ratios
313 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
314 | def forward(self, x):
315 | b, c, h, w = x.size()
316 | loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
317 | # b(4),c,h,w
318 | patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
319 |
320 | # generate token attention map
321 | token_attention_map = self.sigmoid(self.sal_conv(glb))
322 | token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
323 | loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
324 | pools = []
325 | for pool_ratio in self.pool_ratios:
326 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
327 | pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
328 | pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
329 | # nl(4),c,nphw -> nl(4),nphw,1,c
330 | pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
331 | loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
332 | outputs = []
333 | for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
334 | # np*hw,1,c
335 | v = pools[i]
336 | k = v
337 | outputs.append(self.attention[i](q, k, v)[0])
338 | outputs = torch.cat(outputs, 1)
339 | src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
340 | src = self.norm1(src)
341 | src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
342 | src = self.norm2(src)
343 |
344 | src = src.permute(1, 2, 0).reshape(4, c, h, w)  # refreshed loc
345 | glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # refreshed glb
346 | return torch.cat((src, glb), 0)
347 |
348 | # model for single-scale training
349 | class MVANet(nn.Module):
350 | def __init__(self):
351 | super().__init__()
352 | self.backbone = SwinB(pretrained=True)
353 | emb_dim = 128
354 | self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
355 | self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
356 | self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
357 | self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
358 | self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
359 |
360 | self.output5 = make_cbr(1024, emb_dim)
361 | self.output4 = make_cbr(512, emb_dim)
362 | self.output3 = make_cbr(256, emb_dim)
363 | self.output2 = make_cbr(128, emb_dim)
364 | self.output1 = make_cbr(128, emb_dim)
365 |
366 | self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
367 | self.conv1 = make_cbr(emb_dim, emb_dim)
368 | self.conv2 = make_cbr(emb_dim, emb_dim)
369 | self.conv3 = make_cbr(emb_dim, emb_dim)
370 | self.conv4 = make_cbr(emb_dim, emb_dim)
371 | self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
372 | self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
373 | self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
374 | self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
375 |
376 | self.insmask_head = nn.Sequential(
377 | nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
378 | nn.BatchNorm2d(384),
379 | nn.PReLU(),
380 | nn.Conv2d(384, 384, kernel_size=3, padding=1),
381 | nn.BatchNorm2d(384),
382 | nn.PReLU(),
383 | nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
384 | )
385 |
386 | self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
387 | self.upsample1 = make_cbg(emb_dim, emb_dim)
388 | self.upsample2 = make_cbg(emb_dim, emb_dim)
389 | self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
390 |
391 | for m in self.modules():
392 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
393 | m.inplace = True
394 |
395 | def forward(self, x):
396 | shallow = self.shallow(x)
397 | glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
398 | loc = image2patches(x)
399 | inputs = torch.cat((loc, glb), dim=0)  # avoid shadowing the built-in `input`
400 | feature = self.backbone(inputs)
401 | e5 = self.output5(feature[4]) # (5,128,16,16)
402 | e4 = self.output4(feature[3]) # (5,128,32,32)
403 | e3 = self.output3(feature[2]) # (5,128,64,64)
404 | e2 = self.output2(feature[1]) # (5,128,128,128)
405 | e1 = self.output1(feature[0]) # (5,128,128,128)
406 | loc_e5, glb_e5 = e5.split([4, 1], dim=0)
407 | e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (5,128,16,16): 4 refreshed locs + 1 refreshed glb
408 |
409 | e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
410 | e4 = self.conv4(e4)
411 | e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
412 | e3 = self.conv3(e3)
413 | e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
414 | e2 = self.conv2(e2)
415 | e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
416 | e1 = self.conv1(e1)
417 | loc_e1, glb_e1 = e1.split([4, 1], dim=0)
418 | output1_cat = patches2image(loc_e1) # (1,128,256,256)
419 | # add glb feat in
420 | output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
421 | # merge
422 | final_output = self.insmask_head(output1_cat) # (1,128,256,256)
423 | # shallow feature merge
424 | final_output = final_output + resize_as(shallow, final_output)
425 | final_output = self.upsample1(rescale_to(final_output))
426 | final_output = rescale_to(final_output + resize_as(shallow, final_output))
427 | final_output = self.upsample2(final_output)
428 | final_output = self.output(final_output)
429 | # side outputs
430 | sideout5 = self.sideout5(e5).cuda()
431 | sideout4 = self.sideout4(e4)
432 | sideout3 = self.sideout3(e3)
433 | sideout2 = self.sideout2(e2)
434 | sideout1 = self.sideout1(e1)
435 | # global side outputs
436 | glb5 = self.sideout5(glb_e5)
437 | glb4 = sideout4[-1,:,:,:].unsqueeze(0)
438 | glb3 = sideout3[-1,:,:,:].unsqueeze(0)
439 | glb2 = sideout2[-1,:,:,:].unsqueeze(0)
440 | glb1 = sideout1[-1,:,:,:].unsqueeze(0)
441 | # merge the 4 local side outputs into one full-resolution map
442 | sideout1 = patches2image(sideout1[:-1]).cuda()
443 | sideout2 = patches2image(sideout2[:-1]).cuda()  # (4,1,h,w) -> (1,1,2h,2w)
444 | sideout3 = patches2image(sideout3[:-1]).cuda()
445 | sideout4 = patches2image(sideout4[:-1]).cuda()
446 | sideout5 = patches2image(sideout5[:-1]).cuda()
447 | if self.training:
448 | return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
449 | else:
450 | return final_output
451 |
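# Usage sketch (illustrative; assumes CUDA and the SwinB checkpoint referenced
# in SwinTransformer.py on disk): in training mode the forward pass returns the
# final map plus local/global side outputs and token attention maps for deep
# supervision; in eval mode it returns only the final (1, 1, 1024, 1024) logits.
def _mvanet_demo():
    model = MVANet().cuda().eval()
    x = torch.randn(1, 3, 1024, 1024, device='cuda')
    with torch.no_grad():
        return model(x)  # final saliency logits only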
452 | # model for multi-scale testing
453 | class inf_MVANet(nn.Module):
454 | def __init__(self):
455 | super().__init__()
456 | self.backbone = SwinB(pretrained=False)
457 |
458 | emb_dim = 128
459 | self.output5 = make_cbr(1024, emb_dim)
460 | self.output4 = make_cbr(512, emb_dim)
461 | self.output3 = make_cbr(256, emb_dim)
462 | self.output2 = make_cbr(128, emb_dim)
463 | self.output1 = make_cbr(128, emb_dim)
464 |
465 | self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
466 | self.conv1 = make_cbr(emb_dim, emb_dim)
467 | self.conv2 = make_cbr(emb_dim, emb_dim)
468 | self.conv3 = make_cbr(emb_dim, emb_dim)
469 | self.conv4 = make_cbr(emb_dim, emb_dim)
470 | self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
471 | self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
472 | self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
473 | self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
474 |
475 | self.insmask_head = nn.Sequential(
476 | nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
477 | nn.BatchNorm2d(384),
478 | nn.PReLU(),
479 | nn.Conv2d(384, 384, kernel_size=3, padding=1),
480 | nn.BatchNorm2d(384),
481 | nn.PReLU(),
482 | nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
483 | )
484 |
485 | self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
486 | self.upsample1 = make_cbg(emb_dim, emb_dim)
487 | self.upsample2 = make_cbg(emb_dim, emb_dim)
488 | self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
489 |
490 | for m in self.modules():
491 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
492 | m.inplace = True
493 |
494 | def forward(self, x):
495 | shallow = self.shallow(x)
496 | glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
497 | loc = image2patches(x)
498 | inputs = torch.cat((loc, glb), dim=0)  # avoid shadowing the built-in `input`
499 | feature = self.backbone(inputs)
500 | e5 = self.output5(feature[4])
501 | e4 = self.output4(feature[3])
502 | e3 = self.output3(feature[2])
503 | e2 = self.output2(feature[1])
504 | e1 = self.output1(feature[0])
505 | loc_e5, glb_e5 = e5.split([4, 1], dim=0)
506 | e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
507 |
508 | e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
509 | e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
510 | e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
511 | e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
512 | loc_e1, glb_e1 = e1.split([4, 1], dim=0)
513 | # after decoder, concat loc features to a whole one, and merge
514 | output1_cat = patches2image(loc_e1)
515 | # add glb feat in
516 | output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
517 | # merge
518 | final_output = self.insmask_head(output1_cat)
519 | # shallow feature merge
520 | final_output = final_output + resize_as(shallow, final_output)
521 | final_output = self.upsample1(rescale_to(final_output))
522 | final_output = rescale_to(final_output + resize_as(shallow, final_output))
523 | final_output = self.upsample2(final_output)
524 | final_output = self.output(final_output)
525 | return final_output
526 |
527 |
528 |
--------------------------------------------------------------------------------
/model/SwinTransformer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu, Yutong Lin, Yixuan Wei
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.utils.checkpoint as checkpoint
12 | import numpy as np
13 | from mmdet.utils import get_root_logger
14 | from timm.models import load_checkpoint
15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16 |
17 | class Mlp(nn.Module):
18 | """ Multilayer perceptron."""
19 |
20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
21 | super().__init__()
22 | out_features = out_features or in_features
23 | hidden_features = hidden_features or in_features
24 | self.fc1 = nn.Linear(in_features, hidden_features)
25 | self.act = act_layer()
26 | self.fc2 = nn.Linear(hidden_features, out_features)
27 | self.drop = nn.Dropout(drop)
28 |
29 | def forward(self, x):
30 | x = self.fc1(x)
31 | x = self.act(x)
32 | x = self.drop(x)
33 | x = self.fc2(x)
34 | x = self.drop(x)
35 | return x
36 |
37 |
38 | def window_partition(x, window_size):
39 | """
40 | Args:
41 | x: (B, H, W, C)
42 | window_size (int): window size
43 |
44 | Returns:
45 | windows: (num_windows*B, window_size, window_size, C)
46 | """
47 | B, H, W, C = x.shape
48 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
49 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
50 | return windows
51 |
52 |
53 | def window_reverse(windows, window_size, H, W):
54 | """
55 | Args:
56 | windows: (num_windows*B, window_size, window_size, C)
57 | window_size (int): Window size
58 | H (int): Height of image
59 | W (int): Width of image
60 |
61 | Returns:
62 | x: (B, H, W, C)
63 | """
64 | B = int(windows.shape[0] / (H * W / window_size / window_size))
65 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
66 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
67 | return x
68 |
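# A quick self-check sketch (illustrative, not part of the original file):
# partitioning into windows and reversing the partition reproduce the input
# exactly whenever H and W are multiples of the window size.
def _window_roundtrip_check():
    x = torch.randn(2, 14, 14, 32)    # B, H, W, C
    windows = window_partition(x, 7)  # (2 * 2 * 2, 7, 7, 32)
    assert torch.equal(window_reverse(windows, 7, 14, 14), x)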
69 |
70 | class WindowAttention(nn.Module):
71 | """ Window based multi-head self attention (W-MSA) module with relative position bias.
72 | It supports both of shifted and non-shifted window.
73 |
74 | Args:
75 | dim (int): Number of input channels.
76 | window_size (tuple[int]): The height and width of the window.
77 | num_heads (int): Number of attention heads.
78 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
79 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
80 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
81 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
82 | """
83 |
84 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
85 |
86 | super().__init__()
87 | self.dim = dim
88 | self.window_size = window_size # Wh, Ww
89 | self.num_heads = num_heads
90 | head_dim = dim // num_heads
91 | self.scale = qk_scale or head_dim ** -0.5
92 |
93 | # define a parameter table of relative position bias
94 | self.relative_position_bias_table = nn.Parameter(
95 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
96 |
97 | # get pair-wise relative position index for each token inside the window
98 | coords_h = torch.arange(self.window_size[0])
99 | coords_w = torch.arange(self.window_size[1])
100 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
101 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
102 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
103 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
104 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
105 | relative_coords[:, :, 1] += self.window_size[1] - 1
106 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
107 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
108 | self.register_buffer("relative_position_index", relative_position_index)
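# e.g. for a 2x2 window the 4x4 index above takes values in [0, 8],
# addressing the 9-row relative_position_bias_table defined earlier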
109 |
110 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
111 | self.attn_drop = nn.Dropout(attn_drop)
112 | self.proj = nn.Linear(dim, dim)
113 | self.proj_drop = nn.Dropout(proj_drop)
114 |
115 | trunc_normal_(self.relative_position_bias_table, std=.02)
116 | self.softmax = nn.Softmax(dim=-1)
117 |
118 | def forward(self, x, mask=None):
119 | """ Forward function.
120 |
121 | Args:
122 | x: input features with shape of (num_windows*B, N, C)
123 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
124 | """
125 | B_, N, C = x.shape
126 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
127 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
128 |
129 | q = q * self.scale
130 | attn = (q @ k.transpose(-2, -1))
131 |
132 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
133 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
134 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
135 | attn = attn + relative_position_bias.unsqueeze(0)
136 |
137 | if mask is not None:
138 | nW = mask.shape[0]
139 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
140 | attn = attn.view(-1, self.num_heads, N, N)
141 | attn = self.softmax(attn)
142 | else:
143 | attn = self.softmax(attn)
144 |
145 | attn = self.attn_drop(attn)
146 |
147 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
148 | x = self.proj(x)
149 | x = self.proj_drop(x)
150 | return x
151 |
152 |
153 | class SwinTransformerBlock(nn.Module):
154 | """ Swin Transformer Block.
155 |
156 | Args:
157 | dim (int): Number of input channels.
158 | num_heads (int): Number of attention heads.
159 | window_size (int): Window size.
160 | shift_size (int): Shift size for SW-MSA.
161 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
162 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
163 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
164 | drop (float, optional): Dropout rate. Default: 0.0
165 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
166 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
167 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
168 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
169 | """
170 |
171 | def __init__(self, dim, num_heads, window_size=7, shift_size=0,
172 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
173 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
174 | super().__init__()
175 | self.dim = dim
176 | self.num_heads = num_heads
177 | self.window_size = window_size
178 | self.shift_size = shift_size
179 | self.mlp_ratio = mlp_ratio
180 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
181 |
182 | self.norm1 = norm_layer(dim)
183 | self.attn = WindowAttention(
184 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
185 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
186 |
187 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
188 | self.norm2 = norm_layer(dim)
189 | mlp_hidden_dim = int(dim * mlp_ratio)
190 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
191 |
192 | self.H = None
193 | self.W = None
194 |
195 | def forward(self, x, mask_matrix):
196 | """ Forward function.
197 |
198 | Args:
199 | x: Input feature, tensor size (B, H*W, C).
200 | H, W: Spatial resolution of the input feature.
201 | mask_matrix: Attention mask for cyclic shift.
202 | """
203 | B, L, C = x.shape
204 | H, W = self.H, self.W
205 | assert L == H * W, "input feature has wrong size"
206 |
207 | shortcut = x
208 | x = self.norm1(x)
209 | x = x.view(B, H, W, C)
210 |
211 | # pad feature maps to multiples of window size
212 | pad_l = pad_t = 0
213 | pad_r = (self.window_size - W % self.window_size) % self.window_size
214 | pad_b = (self.window_size - H % self.window_size) % self.window_size
215 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
216 | _, Hp, Wp, _ = x.shape
217 |
218 | # cyclic shift
219 | if self.shift_size > 0:
220 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
221 | attn_mask = mask_matrix
222 | else:
223 | shifted_x = x
224 | attn_mask = None
225 |
226 | # partition windows
227 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
228 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
229 |
230 | # W-MSA/SW-MSA
231 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
232 |
233 | # merge windows
234 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
235 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
236 |
237 | # reverse cyclic shift
238 | if self.shift_size > 0:
239 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
240 | else:
241 | x = shifted_x
242 |
243 | if pad_r > 0 or pad_b > 0:
244 | x = x[:, :H, :W, :].contiguous()
245 |
246 | x = x.view(B, H * W, C)
247 |
248 | # FFN
249 | x = shortcut + self.drop_path(x)
250 | x = x + self.drop_path(self.mlp(self.norm2(x)))
251 |
252 | return x
253 |
254 |
255 | class PatchMerging(nn.Module):
256 | """ Patch Merging Layer
257 |
258 | Args:
259 | dim (int): Number of input channels.
260 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
261 | """
262 | def __init__(self, dim, norm_layer=nn.LayerNorm):
263 | super().__init__()
264 | self.dim = dim
265 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
266 | self.norm = norm_layer(4 * dim)
267 |
268 | def forward(self, x, H, W):
269 | """ Forward function.
270 |
271 | Args:
272 | x: Input feature, tensor size (B, H*W, C).
273 | H, W: Spatial resolution of the input feature.
274 | """
275 | B, L, C = x.shape
276 | assert L == H * W, "input feature has wrong size"
277 |
278 | x = x.view(B, H, W, C)
279 |
280 | # padding
281 | pad_input = (H % 2 == 1) or (W % 2 == 1)
282 | if pad_input:
283 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
284 |
285 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
286 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
287 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
288 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
289 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
290 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
291 |
292 | x = self.norm(x)
293 | x = self.reduction(x)
294 |
295 | return x
296 |
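# Shape sketch (illustrative, not part of the original file): patch merging
# halves each spatial dimension and doubles the channel width.
def _patch_merging_demo():
    pm = PatchMerging(dim=96)
    x = torch.randn(1, 56 * 56, 96)   # (B, H*W, C)
    return pm(x, 56, 56)              # (1, 28 * 28, 192)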
297 |
298 | class BasicLayer(nn.Module):
299 | """ A basic Swin Transformer layer for one stage.
300 |
301 | Args:
302 | dim (int): Number of feature channels
303 | depth (int): Depths of this stage.
304 | num_heads (int): Number of attention head.
305 | window_size (int): Local window size. Default: 7.
306 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
307 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
308 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
309 | drop (float, optional): Dropout rate. Default: 0.0
310 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
311 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
312 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
313 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
314 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
315 | """
316 |
317 | def __init__(self,
318 | dim,
319 | depth,
320 | num_heads,
321 | window_size=7,
322 | mlp_ratio=4.,
323 | qkv_bias=True,
324 | qk_scale=None,
325 | drop=0.,
326 | attn_drop=0.,
327 | drop_path=0.,
328 | norm_layer=nn.LayerNorm,
329 | downsample=None,
330 | use_checkpoint=False):
331 | super().__init__()
332 | self.window_size = window_size
333 | self.shift_size = window_size // 2
334 | self.depth = depth
335 | self.use_checkpoint = use_checkpoint
336 |
337 | # build blocks
338 | self.blocks = nn.ModuleList([
339 | SwinTransformerBlock(
340 | dim=dim,
341 | num_heads=num_heads,
342 | window_size=window_size,
343 | shift_size=0 if (i % 2 == 0) else window_size // 2,
344 | mlp_ratio=mlp_ratio,
345 | qkv_bias=qkv_bias,
346 | qk_scale=qk_scale,
347 | drop=drop,
348 | attn_drop=attn_drop,
349 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
350 | norm_layer=norm_layer)
351 | for i in range(depth)])
352 |
353 | # patch merging layer
354 | if downsample is not None:
355 | self.downsample = downsample(dim=dim, norm_layer=norm_layer)
356 | else:
357 | self.downsample = None
358 |
359 | def forward(self, x, H, W):
360 | """ Forward function.
361 |
362 | Args:
363 | x: Input feature, tensor size (B, H*W, C).
364 | H, W: Spatial resolution of the input feature.
365 | """
366 |
367 | # calculate attention mask for SW-MSA
368 | Hp = int(np.ceil(H / self.window_size)) * self.window_size
369 | Wp = int(np.ceil(W / self.window_size)) * self.window_size
370 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
371 | h_slices = (slice(0, -self.window_size),
372 | slice(-self.window_size, -self.shift_size),
373 | slice(-self.shift_size, None))
374 | w_slices = (slice(0, -self.window_size),
375 | slice(-self.window_size, -self.shift_size),
376 | slice(-self.shift_size, None))
377 | cnt = 0
378 | for h in h_slices:
379 | for w in w_slices:
380 | img_mask[:, h, w, :] = cnt
381 | cnt += 1
382 |
383 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
384 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
385 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
386 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
387 |
388 | for blk in self.blocks:
389 | blk.H, blk.W = H, W
390 | if self.use_checkpoint:
391 | x = checkpoint.checkpoint(blk, x, attn_mask)
392 | else:
393 | x = blk(x, attn_mask)
394 | if self.downsample is not None:
395 | x_down = self.downsample(x, H, W)
396 | Wh, Ww = (H + 1) // 2, (W + 1) // 2
397 | return x, H, W, x_down, Wh, Ww
398 | else:
399 | return x, H, W, x, H, W
400 |
401 |
402 | class PatchEmbed(nn.Module):
403 | """ Image to Patch Embedding
404 |
405 | Args:
406 | patch_size (int): Patch token size. Default: 4.
407 | in_chans (int): Number of input image channels. Default: 3.
408 | embed_dim (int): Number of linear projection output channels. Default: 96.
409 | norm_layer (nn.Module, optional): Normalization layer. Default: None
410 | """
411 |
412 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
413 | super().__init__()
414 | patch_size = to_2tuple(patch_size)
415 | self.patch_size = patch_size
416 |
417 | self.in_chans = in_chans
418 | self.embed_dim = embed_dim
419 |
420 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
421 | if norm_layer is not None:
422 | self.norm = norm_layer(embed_dim)
423 | else:
424 | self.norm = None
425 |
426 | def forward(self, x):
427 | """Forward function."""
428 | # padding
429 | _, _, H, W = x.size()
430 | if W % self.patch_size[1] != 0:
431 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
432 | if H % self.patch_size[0] != 0:
433 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
434 |
435 | x = self.proj(x) # B C Wh Ww
436 | if self.norm is not None:
437 | Wh, Ww = x.size(2), x.size(3)
438 | x = x.flatten(2).transpose(1, 2)
439 | x = self.norm(x)
440 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
441 |
442 | return x
443 |
444 |
445 | class SwinTransformer(nn.Module):
446 | """ Swin Transformer backbone.
447 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
448 | https://arxiv.org/pdf/2103.14030
449 |
450 | Args:
451 | pretrain_img_size (int): Input image size for training the pretrained model,
452 | used in absolute position embedding. Default: 224.
453 | patch_size (int | tuple(int)): Patch size. Default: 4.
454 | in_chans (int): Number of input image channels. Default: 3.
455 | embed_dim (int): Number of linear projection output channels. Default: 96.
456 | depths (tuple[int]): Depths of each Swin Transformer stage.
457 | num_heads (tuple[int]): Number of attention head of each stage.
458 | window_size (int): Window size. Default: 7.
459 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
460 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
461 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
462 | drop_rate (float): Dropout rate.
463 | attn_drop_rate (float): Attention dropout rate. Default: 0.
464 | drop_path_rate (float): Stochastic depth rate. Default: 0.2.
465 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
466 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
467 | patch_norm (bool): If True, add normalization after patch embedding. Default: True.
468 | out_indices (Sequence[int]): Output from which stages.
469 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
470 | -1 means not freezing any parameters.
471 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
472 | """
473 |
474 | def __init__(self,
475 | pretrain_img_size=224,
476 | patch_size=4,
477 | in_chans=3,
478 | embed_dim=96,
479 | depths=[2, 2, 6, 2],
480 | num_heads=[3, 6, 12, 24],
481 | window_size=7,
482 | mlp_ratio=4.,
483 | qkv_bias=True,
484 | qk_scale=None,
485 | drop_rate=0.,
486 | attn_drop_rate=0.,
487 | drop_path_rate=0.2,
488 | norm_layer=nn.LayerNorm,
489 | ape=False,
490 | patch_norm=True,
491 | out_indices=(0, 1, 2, 3),
492 | frozen_stages=-1,
493 | use_checkpoint=False):
494 | super().__init__()
495 |
496 | self.pretrain_img_size = pretrain_img_size
497 | self.num_layers = len(depths)
498 | self.embed_dim = embed_dim
499 | self.ape = ape
500 | self.patch_norm = patch_norm
501 | self.out_indices = out_indices
502 | self.frozen_stages = frozen_stages
503 |
504 | # split image into non-overlapping patches
505 | self.patch_embed = PatchEmbed(
506 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
507 | norm_layer=norm_layer if self.patch_norm else None)
508 |
509 | # absolute position embedding
510 | if self.ape:
511 | pretrain_img_size = to_2tuple(pretrain_img_size)
512 | patch_size = to_2tuple(patch_size)
513 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
514 |
515 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
516 | trunc_normal_(self.absolute_pos_embed, std=.02)
517 |
518 | self.pos_drop = nn.Dropout(p=drop_rate)
519 |
520 | # stochastic depth
521 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
522 |
523 | # build layers
524 | self.layers = nn.ModuleList()
525 | for i_layer in range(self.num_layers):
526 | layer = BasicLayer(
527 | dim=int(embed_dim * 2 ** i_layer),
528 | depth=depths[i_layer],
529 | num_heads=num_heads[i_layer],
530 | window_size=window_size,
531 | mlp_ratio=mlp_ratio,
532 | qkv_bias=qkv_bias,
533 | qk_scale=qk_scale,
534 | drop=drop_rate,
535 | attn_drop=attn_drop_rate,
536 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
537 | norm_layer=norm_layer,
538 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
539 | use_checkpoint=use_checkpoint)
540 | self.layers.append(layer)
541 |
542 | num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
543 | self.num_features = num_features
544 |
545 | # add a norm layer for each output
546 | for i_layer in out_indices:
547 | layer = norm_layer(num_features[i_layer])
548 | layer_name = f'norm{i_layer}'
549 | self.add_module(layer_name, layer)
550 |
551 | self._freeze_stages()
552 |
553 | def _freeze_stages(self):
554 | if self.frozen_stages >= 0:
555 | self.patch_embed.eval()
556 | for param in self.patch_embed.parameters():
557 | param.requires_grad = False
558 |
559 | if self.frozen_stages >= 1 and self.ape:
560 | self.absolute_pos_embed.requires_grad = False
561 |
562 | if self.frozen_stages >= 2:
563 | self.pos_drop.eval()
564 | for i in range(0, self.frozen_stages - 1):
565 | m = self.layers[i]
566 | m.eval()
567 | for param in m.parameters():
568 | param.requires_grad = False
569 |
570 | def init_weights(self, pretrained=None):
571 | """Initialize the weights in backbone.
572 |
573 | Args:
574 | pretrained (str, optional): Path to pre-trained weights.
575 | Defaults to None.
576 | """
577 |
578 | def _init_weights(m):
579 | if isinstance(m, nn.Linear):
580 | trunc_normal_(m.weight, std=.02)
581 | if isinstance(m, nn.Linear) and m.bias is not None:
582 | nn.init.constant_(m.bias, 0)
583 | elif isinstance(m, nn.LayerNorm):
584 | nn.init.constant_(m.bias, 0)
585 | nn.init.constant_(m.weight, 1.0)
586 |
587 | if isinstance(pretrained, str):
588 | self.apply(_init_weights)
589 | logger = get_root_logger()
590 | load_checkpoint(self, pretrained, strict=False, logger=logger)
591 | elif pretrained is None:
592 | self.apply(_init_weights)
593 | else:
594 | raise TypeError('pretrained must be a str or None')
595 |
596 | def forward(self, x):
597 | x = self.patch_embed(x)
598 |
599 | Wh, Ww = x.size(2), x.size(3)
600 | if self.ape:
601 | # interpolate the position embedding to the corresponding size
602 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
603 | x = (x + absolute_pos_embed) # B Wh*Ww C
604 |
605 | outs = [x.contiguous()]
606 | x = x.flatten(2).transpose(1, 2)
607 | x = self.pos_drop(x)
608 | for i in range(self.num_layers):
609 | layer = self.layers[i]
610 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
611 |
612 | if i in self.out_indices:
613 | norm_layer = getattr(self, f'norm{i}')
614 | x_out = norm_layer(x_out)
615 |
616 | out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
617 | outs.append(out)
618 |
619 | return tuple(outs)
620 |
621 | def train(self, mode=True):
622 | """Convert the model into training mode while keep layers freezed."""
623 | super(SwinTransformer, self).train(mode)
624 | self._freeze_stages()
625 |
626 | def SwinT(pretrained=True):
627 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
628 | if pretrained is True:
629 | model.load_state_dict(torch.load('data/backbone_ckpt/swin_tiny_patch4_window7_224.pth', map_location='cpu')['model'], strict=False)
630 |
631 | return model
632 |
633 | def SwinS(pretrained=True):
634 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7)
635 | if pretrained is True:
636 | model.load_state_dict(torch.load('data/backbone_ckpt/swin_small_patch4_window7_224.pth', map_location='cpu')['model'], strict=False)
637 |
638 | return model
639 |
640 | def SwinB(pretrained=True):
641 | model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
642 | if pretrained is True:
643 | model.load_state_dict(torch.load('./swin_base_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False)
644 |
645 | return model
646 |
647 | def SwinL(pretrained=True):
648 | model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12)
649 | if pretrained is True:
650 | model.load_state_dict(torch.load('data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth', map_location='cpu')['model'], strict=False)
651 |
652 | return model
653 |
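# Loading sketch (illustrative; the checkpoint path is the one hard-coded in
# SwinB above): strict=False tolerates the ImageNet classification-head keys
# that have no counterpart in this backbone, so it is worth inspecting what
# was actually matched.
def _load_swinb_verbose(ckpt_path='./swin_base_patch4_window12_384_22kto1k.pth'):
    model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2],
                            num_heads=[4, 8, 16, 32], window_size=12)
    state = torch.load(ckpt_path, map_location='cpu')['model']
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f'missing: {len(missing)} keys, unexpected: {len(unexpected)} keys')
    return model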
654 |
--------------------------------------------------------------------------------