├── frame.jpg
├── isic2016.png
├── .gitignore
├── .ipynb_checkpoints
│   └── infer-checkpoint.ipynb
├── .vscode
│   └── settings.json
├── utils
│   ├── loss.py
│   ├── format_conversion.py
│   ├── utils.py
│   ├── polar_generate.py
│   ├── polar_transformations.py
│   ├── resize.py
│   ├── isic2016_dataloader.py
│   ├── dataloader.py
│   ├── point_generate.py
│   ├── isbi2016_new.py
│   ├── isic2018_dataloader.py
│   ├── isbi2018_new.py
│   └── isic2018_polar.py
├── lib
│   ├── Position_embedding.py
│   ├── vision_transformers.py
│   ├── TransFuse
│   │   ├── DeiT.py
│   │   ├── vision_transformer.py
│   │   └── TransFuse.py
│   ├── modules.py
│   ├── xboundformer_v0.py
│   ├── baseline.py
│   ├── pvt.py
│   ├── xboundformer.py
│   ├── deeplabv3.py
│   ├── replknet.py
│   └── transformer.py
├── README.md
└── src
    ├── test.py
    └── train.py
/frame.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/xboundformer/HEAD/frame.jpg
--------------------------------------------------------------------------------
/isic2016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/xboundformer/HEAD/isic2016.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 | results/
3 | exps/
4 | *.pth
5 | scripts/
6 | *.pyc
7 | figures/
8 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/infer-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 5
6 | }
7 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "git.ignoreLimitWarning": true,
3 | "python.pythonPath": "/home/wjc/.conda/envs/pytorch/bin/python"
4 | }
5 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def structure_loss(pred, mask):
6 | weit = 1 + 5 * torch.abs(
7 | F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
8 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
9 | wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
10 |
11 | pred = torch.sigmoid(pred)
12 | inter = ((pred * mask) * weit).sum(dim=(2, 3))
13 | union = ((pred + mask) * weit).sum(dim=(2, 3))
14 | wiou = 1 - (inter + 1) / (union - inter + 1)
15 |
16 | return (wbce + wiou).mean()
17 |
18 |
19 | def dice_loss(pred, mask, act=False):
20 | if not act:
21 | pred = torch.sigmoid(pred)
22 | inter = (pred * mask).sum(dim=(2, 3))
23 | union = (pred + mask).sum(dim=(2, 3))
24 | wiou = 1 - (inter + 1) / (union - inter + 1)
25 | return wiou.mean()
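26 |
27 |
28 | if __name__ == '__main__':
29 |     # Minimal self-check sketch (illustrative, not part of the original repo):
30 |     # run both losses on random tensors following the (B, 1, H, W) layout used
31 |     # by the dataloaders, and confirm they return scalars.
32 |     pred = torch.randn(2, 1, 64, 64)                   # raw logits
33 |     mask = (torch.rand(2, 1, 64, 64) > 0.5).float()    # binary ground truth
34 |     print('structure_loss:', structure_loss(pred, mask).item())
35 |     print('dice_loss:', dice_loss(pred, mask).item())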
--------------------------------------------------------------------------------
/utils/format_conversion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from libtiff import TIFF # pip install libtiff
4 | from scipy import misc
5 | import random
6 |
7 |
8 | def tif2png(_src_path, _dst_path):
9 | """
10 | Usage:
11 |         convert `tif/tiff` files to `jpg/png` files
12 | :param _src_path:
13 | :param _dst_path:
14 | :return:
15 | """
16 | tif = TIFF.open(_src_path, mode='r')
17 | image = tif.read_image()
18 | misc.imsave(_dst_path, image)
19 |
20 |
21 | def data_split(src_list):
22 | """
23 | Usage:
24 |         randomly splitting the dataset
25 | :param src_list:
26 | :return:
27 | """
28 | counter_list = random.sample(range(0, len(src_list)), 550)
29 |
30 | return counter_list
31 |
32 |
33 | if __name__ == '__main__':
34 | src_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks_tif'
35 | dst_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks'
36 |
37 | os.makedirs(dst_dir, exist_ok=True)
38 | for img_name in os.listdir(src_dir):
39 | tif2png(os.path.join(src_dir, img_name),
40 | os.path.join(dst_dir, img_name.replace('.tif', '.png')))
--------------------------------------------------------------------------------
/lib/Position_embedding.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import torch
4 |
5 | from torch import nn, Tensor
6 | import torch.nn.functional as F
7 |
8 |
9 | class PositionEmbeddingLearned(nn.Module):
10 | """
11 | Absolute pos embedding, learned.
12 | """
13 | def __init__(self, num_pos_feats=128):
14 | super().__init__()
15 | self.row_embed = nn.Embedding(50, num_pos_feats)
16 | self.col_embed = nn.Embedding(50, num_pos_feats)
17 | self.reset_parameters()
18 |
19 | def reset_parameters(self):
20 | nn.init.uniform_(self.row_embed.weight)
21 | nn.init.uniform_(self.col_embed.weight)
22 |
23 | def forward(self, x):
24 | h, w = x.shape[-2:]
25 | # print(h, w)
26 | i = torch.arange(w, device=x.device)
27 | j = torch.arange(h, device=x.device)
28 | x_emb = self.col_embed(i)
29 | y_emb = self.row_embed(j)
30 | pos = torch.cat([
31 | x_emb.unsqueeze(0).repeat(h, 1, 1),
32 | y_emb.unsqueeze(1).repeat(1, w, 1),
33 | ],
34 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(
35 | x.shape[0], 1, 1, 1)
36 | return pos
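37 |
38 |
39 | if __name__ == '__main__':
40 |     # Quick shape check (illustrative sketch, not part of the original repo):
41 |     # with num_pos_feats = d_model // 2, the returned positional embedding has
42 |     # d_model channels and the same spatial size as the input feature map.
43 |     pe = PositionEmbeddingLearned(num_pos_feats=64)
44 |     feat = torch.randn(2, 128, 16, 16)
45 |     print(pe(feat).shape)  # torch.Size([2, 128, 16, 16])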
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from thop import profile
4 | from thop import clever_format
5 |
6 |
7 | def clip_gradient(optimizer, grad_clip):
8 | """
9 |     Clip gradients element-wise to the range [-grad_clip, grad_clip] (gradient clipping)
10 | :param optimizer:
11 | :param grad_clip:
12 | :return:
13 | """
14 | for group in optimizer.param_groups:
15 | for param in group['params']:
16 | if param.grad is not None:
17 | param.grad.data.clamp_(-grad_clip, grad_clip)
18 |
19 |
20 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
21 | decay = decay_rate**(epoch // decay_epoch)
22 | for param_group in optimizer.param_groups:
23 |         param_group['lr'] = init_lr * decay
24 |
25 |
26 | class AvgMeter(object):
27 | def __init__(self, num=40):
28 | self.num = num
29 | self.reset()
30 |
31 | def reset(self):
32 | self.val = 0
33 | self.avg = 0
34 | self.sum = 0
35 | self.count = 0
36 | self.losses = []
37 |
38 | def update(self, val, n=1):
39 | self.val = val
40 | self.sum += val * n
41 | self.count += n
42 | self.avg = self.sum / self.count
43 | self.losses.append(val)
44 |
45 | def show(self):
46 | return torch.mean(
47 | torch.stack(self.losses[np.maximum(len(self.losses) -
48 | self.num, 0):]))
49 |
50 |
51 | def CalParams(model, input_tensor):
52 | """
53 | Usage:
54 | Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter)
55 |     Requires:
56 | from thop import profile
57 | from thop import clever_format
58 | :param model:
59 | :param input_tensor:
60 | :return:
61 | """
62 | flops, params = profile(model, inputs=(input_tensor, ))
63 | flops, params = clever_format([flops, params], "%.3f")
64 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(
65 | flops, params))
66 |
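67 |
68 | if __name__ == '__main__':
69 |     # Small usage sketch (illustrative, not part of the original repo): AvgMeter
70 |     # keeps a running average plus a window of the last `num` values for logging.
71 |     meter = AvgMeter(num=3)
72 |     for v in [1.0, 2.0, 3.0, 4.0]:
73 |         meter.update(torch.tensor(v))
74 |     print('avg over all updates:', meter.avg.item())   # 2.5
75 |     print('avg over last 3:', meter.show().item())     # 3.0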
--------------------------------------------------------------------------------
/utils/polar_generate.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import random
4 | import torch
5 | import numpy as np
6 | import skimage.draw
7 | from tqdm import tqdm
8 | import matplotlib.pyplot as plt
9 | import torch.nn.functional as F
10 | from polar_transformations import to_polar, centroid
11 | from point_generate import kpm_gen  # used by point_gen_isic2016 below; defined in utils/point_generate.py
12 |
13 | def polar_gen_isic2018():
14 | data_dir = '/raid/wjc/data/skin_lesion/isic2018_jpg_smooth/'
15 |
16 | os.makedirs(data_dir + '/PolarImage', exist_ok=True)
17 | os.makedirs(data_dir + '/PolarLabel', exist_ok=True)
18 |
19 | path_list = os.listdir(data_dir + '/Label/')
20 | path_list.sort()
21 | num = 0
22 | for path in tqdm(path_list):
23 | image_data = cv2.imread(os.path.join(data_dir, 'Image', path))
24 |
25 | label_data = cv2.imread(os.path.join(data_dir, 'Label', path),
26 | cv2.IMREAD_GRAYSCALE)
27 | center = centroid(image_data)
28 | image_data = to_polar(image_data, center)
29 | label_data = to_polar(label_data, center)
30 | # print(image_data.max(), label_data.max())
31 |
32 | cv2.imwrite(data_dir + '/PolarImage/' + path, image_data)
33 | cv2.imwrite(data_dir + '/PolarLabel/' + path, label_data)
34 | # break
35 |
36 |
37 | def point_gen_isic2016():
38 | R = 10
39 | N = 25
40 | for split in ['Train', 'Test', 'Validation']:
41 | data_dir = '/raid/wjc/data/skin_lesion/isic2016/{}/Label'.format(split)
42 |
43 | save_dir = data_dir.replace('Label', 'Point')
44 | os.makedirs(save_dir, exist_ok=True)
45 |
46 | path_list = os.listdir(data_dir)
47 | path_list.sort()
48 | num = 0
49 | for path in tqdm(path_list):
50 | name = path[:-4]
51 | label_path = os.path.join(data_dir, path)
52 | print(label_path)
53 | label_ori, point_heatmap = kpm_gen(label_path, R, N)
54 | save_path = os.path.join(save_dir, name + '.npy')
55 | np.save(save_path, point_heatmap)
56 | num += 1
57 |
58 |
59 | if __name__ == '__main__':
60 | polar_gen_isic2018()
--------------------------------------------------------------------------------
/utils/polar_transformations.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import cv2 as cv
6 |
7 |
8 | def centroid(img, lcc=False):
9 | if lcc:
10 | img = img.astype(np.uint8)
11 | nb_components, output, stats, centroids = cv.connectedComponentsWithStats(
12 | img, connectivity=4)
13 | sizes = stats[:, -1]
14 | if len(sizes) > 2:
15 | max_label = 1
16 | max_size = sizes[1]
17 |
18 | for i in range(2, nb_components):
19 | if sizes[i] > max_size:
20 | max_label = i
21 | max_size = sizes[i]
22 |
23 | img2 = np.zeros(output.shape)
24 | img2[output == max_label] = 255
25 | img = img2
26 |
27 | if len(img.shape) > 2:
28 | M = cv.moments(img[:, :, 1])
29 | else:
30 | M = cv.moments(img)
31 |
32 | if M["m00"] == 0:
33 | return (img.shape[0] // 2, img.shape[1] // 2)
34 |
35 | cX = int(M["m10"] / M["m00"])
36 | cY = int(M["m01"] / M["m00"])
37 | return (cX, cY)
38 |
39 |
40 | def to_polar(input_img, center):
41 | input_img = input_img.astype(np.float32)
42 | value = np.sqrt(((input_img.shape[0] / 2.0)**2.0) +
43 | ((input_img.shape[1] / 2.0)**2.0))
44 | polar_image = cv.linearPolar(input_img, center, value,
45 | cv.WARP_FILL_OUTLIERS)
46 | polar_image = cv.rotate(polar_image, cv.ROTATE_90_COUNTERCLOCKWISE)
47 | return polar_image
48 |
49 |
50 | def to_cart(input_img, center):
51 | input_img = input_img.astype(np.float32)
52 | input_img = cv.rotate(input_img, cv.ROTATE_90_CLOCKWISE)
53 | value = np.sqrt(((input_img.shape[1] / 2.0)**2.0) +
54 | ((input_img.shape[0] / 2.0)**2.0))
55 | polar_image = cv.linearPolar(input_img, center, value,
56 | cv.WARP_FILL_OUTLIERS + cv.WARP_INVERSE_MAP)
57 | polar_image = polar_image.astype(np.uint8)
58 | return polar_image
59 |
60 |
61 | if __name__ == "__main__":
62 | image = cv.imread('test_images/30.tif')
63 | plt.imshow(image)
64 |
65 | center = centroid(image)
66 | plt.scatter(center[0], center[1])
67 | plt.show()
68 |
69 | polar = to_polar(image, center)
70 | plt.imshow(polar)
71 | plt.show()
72 |
73 | cart = to_cart(polar, center)
74 | plt.imshow(cart)
75 | plt.show()
76 |
--------------------------------------------------------------------------------
/lib/vision_transformers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from lib.transformer import BoundaryAwareTransformer
5 | from lib.Position_embedding import PositionEmbeddingLearned
6 |
7 |
8 | class in_scale_transformer(nn.Module):
9 | def __init__(self,
10 | point_pred_layers=1,
11 | num_queries=1,
12 | d_model=512,
13 | nhead=8,
14 | num_encoder_layers=6,
15 | num_decoder_layers=6,
16 | dim_feedforward=2048,
17 | dropout=0.1,
18 | activation=nn.LeakyReLU,
19 | normalize_before=False,
20 | return_intermediate_dec=False,
21 | BAG_type='2D',
22 | Atrous=False):
23 |
24 | super().__init__()
25 |
26 | self.query_embed = nn.Embedding(num_queries, d_model)
27 | self.pos_embed = PositionEmbeddingLearned(d_model // 2)
28 | self.num_queries = num_queries
29 |
30 | self.transformer = BoundaryAwareTransformer(
31 | point_pred_layers=point_pred_layers,
32 | d_model=d_model,
33 | nhead=nhead,
34 | num_encoder_layers=num_encoder_layers,
35 | num_decoder_layers=num_decoder_layers,
36 | dim_feedforward=dim_feedforward,
37 | dropout=dropout,
38 | activation=activation,
39 | normalize_before=normalize_before,
40 | return_intermediate_dec=return_intermediate_dec,
41 | BAG_type=BAG_type,
42 | Atrous=Atrous)
43 |
44 | def forward(self, x):
45 |
46 | pos_embed = self.pos_embed(x).to(x.dtype)
47 |
48 | latent_tensor, features_encoded, point_maps = self.transformer(
49 | x, None, self.query_embed.weight, pos_embed)
50 |
51 | return latent_tensor, features_encoded, point_maps
52 |
53 |
54 | # def detr_Transformer(pretrained=False, **kwargs):
55 |
56 | # transformer = DETR_Transformer(num_encoder_layers=6,
57 | # num_decoder_layers=6,
58 | # d_model=256,
59 | # nhead=8)
60 |
61 | # if pretrained:
62 | # print("Loaded DETR Pretrained Parameters From ImageNet...")
63 | # ckpt = torch.load(
64 | # '/home/chenfei/my_codes/TransformerCode-master/Ours/pretrained/detr-r50-e632da11.pth'
65 | # )
66 | # state_dict = ckpt['model']
67 |
68 | # transformer.query_embed = nn.Embedding(100, 256)
69 | # unParalled_state_dict = {}
70 | # for key in state_dict.keys():
71 | # if key.startswith("transformer"):
72 | # unParalled_state_dict[key] = state_dict[key]
73 | # #elif key.startswith("query_embed"):
74 | # # unParalled_state_dict[key] = state_dict[key]
75 | # ### Without positional embedding for detr and mismatch for query_embed
76 | # transformer.load_state_dict(unParalled_state_dict, strict=False)
77 | # return transformer
78 |
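79 |
80 | # Usage sketch (commented out; illustrative only, not part of the original code).
81 | # It assumes lib.transformer.BoundaryAwareTransformer behaves as used in forward()
82 | # above: given a (B, d_model, H, W) feature map, the module returns the decoder
83 | # output, the encoded feature map, and the boundary key-point maps.
84 | #
85 | #   model = in_scale_transformer(d_model=128, nhead=8,
86 | #                                num_encoder_layers=2, num_decoder_layers=2)
87 | #   feat = torch.randn(2, 128, 16, 16)
88 | #   latent_tensor, features_encoded, point_maps = model(feat)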
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # XBound-Former: Toward Cross-scale Boundary Modeling in Transformers
2 |
3 | ## Introduction
4 |
5 | This is an official release of the paper **XBound-Former: Toward Cross-scale Boundary Modeling in Transformers**, including the network implementation and the training scripts.
6 |
7 | > [**XBound-Former: Toward Cross-scale Boundary Modeling in Transformers**]
8 | > **Jiacheng Wang**, Fei Chen, Yuxi Ma, Liansheng Wang, Zhaodong Fei, Jianwei Shuai, Xiangdong Tang, Qichao Zhou, Jing Qin
9 | > In: IEEE Transactions on Medical Imaging (TMI), 2023
10 | > [[arXiv](https://arxiv.org/abs/2206.00806)][[Bibtex](https://github.com/jcwang123/xboundformer#Citation)]
11 |
12 |
13 |
14 | ## News
15 | - **[1/11 2022] This paper has been accepted to TMI.**
16 | - **[5/27 2022] We have released the training scripts.**
17 | - **[5/19 2022] We have created this repo.**
18 |
19 | ## Code List
20 |
21 | - [x] Network
22 | - [x] Pre-processing
23 | - [x] Training Codes
24 | - [ ] Pretrained Weights
25 |
26 | For more details or any questions, please feel free to contact us by email (jiachengw@stu.xmu.edu.cn).
27 |
28 | ## Usage
29 |
30 | ### Dataset
31 |
32 | Please download the datasets from the [ISIC](https://www.isic-archive.com/) challenge and the [PH2](https://www.fc.up.pt/addi/ph2%20database.html) website.
33 |
34 | ### Pre-processing
35 |
36 | Please run:
37 |
38 | ```bash
39 | $ python utils/resize.py
40 | ```
41 |
42 | You need to change the **File Path** to your own and select the correct function.
43 |
44 | ### Training
45 |
46 | Please run:
47 |
48 | ```bash
49 | $ python src/train.py
50 | ```
51 | You need to change the **File Path** to your own and select the correct function.
52 |
53 | ### Testing
54 |
55 | Download the pretrained weights for the ISIC-2016 & $PH^2$ datasets from [Google Drive](https://drive.google.com/file/d/1-eMHYX1fr-QvI3n50S0xqWcxc3FGsMgE/view?usp=sharing) and move them to the log directory (`logs/`).
56 |
57 | Then, please run:
58 |
59 | ```bash
60 | $ python src/test.py
61 | ```
62 |
63 | ### Result
64 | Results on the ISIC-2016 & $PH^2$ datasets:
65 |
66 |
67 | ## Citation
68 |
69 | If you find XBound-Former useful in your research, please consider citing:
70 | ```
71 | @article{wang2023xbound,
72 | title={XBound-Former: Toward Cross-scale Boundary Modeling in Transformers},
73 | author={Wang, Jiacheng and Chen, Fei and Ma, Yuxi and Wang, Liansheng and Fei, Zhaodong and Shuai, Jianwei and Tang, Xiangdong and Zhou, Qichao and Qin, Jing},
74 | journal={IEEE Transactions on Medical Imaging},
75 | year={2023},
76 | publisher={IEEE}
77 | }
78 | ```
79 | and the prior work, BAT, as:
80 | ```
81 | @inproceedings{wang2021boundary,
82 | title={Boundary-Aware Transformers for Skin Lesion Segmentation},
83 | author={Wang, Jiacheng and Wei, Lan and Wang, Liansheng and Zhou, Qichao and Zhu, Lei and Qin, Jing},
84 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
85 | pages={206--216},
86 | year={2021},
87 | organization={Springer}
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
/lib/TransFuse/DeiT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 |
5 | from .vision_transformer import VisionTransformer, _cfg
6 | from timm.models.registry import register_model
7 | from timm.models.layers import trunc_normal_
8 | import torch.nn.functional as F
9 | import numpy as np
10 |
11 | __all__ = [
12 | 'deit_tiny_patch16_224',
13 | 'deit_small_patch16_224',
14 | 'deit_base_patch16_224',
15 | 'deit_tiny_distilled_patch16_224',
16 | 'deit_small_distilled_patch16_224',
17 | 'deit_base_distilled_patch16_224',
18 | 'deit_base_patch16_384',
19 | 'deit_base_distilled_patch16_384',
20 | ]
21 |
22 |
23 | class DeiT(VisionTransformer):
24 | def __init__(self, *args, **kwargs):
25 | super().__init__(*args, **kwargs)
26 | num_patches = self.patch_embed.num_patches
27 | self.pos_embed = nn.Parameter(
28 | torch.zeros(1, num_patches + 1, self.embed_dim))
29 |
30 | def forward(self, x):
31 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
32 | # with slight modifications to add the dist_token
33 | B = x.shape[0]
34 | x = self.patch_embed(x)
35 | pe = self.pos_embed
36 |
37 | x = x + pe
38 | x = self.pos_drop(x)
39 |
40 | for blk in self.blocks:
41 | x = blk(x)
42 |
43 | x = self.norm(x)
44 | return x
45 |
46 |
47 | @register_model
48 | def deit_small_patch16_224(pretrained=False, **kwargs):
49 | model = DeiT(patch_size=16,
50 | embed_dim=384,
51 | depth=8,
52 | num_heads=6,
53 | mlp_ratio=4,
54 | qkv_bias=True,
55 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
56 | **kwargs)
57 | model.default_cfg = _cfg()
58 | if pretrained:
59 | ckpt = torch.load(
60 | '/raid/wjc/code/xbound_former/lib/TransFuse/pretrained/deit_small_patch16_224-cd65a155.pth'
61 | )
62 | model.load_state_dict(ckpt['model'], strict=False)
63 |
64 | pe = model.pos_embed[:, 1:, :].detach()
65 | pe = pe.transpose(-1, -2)
66 | pe = pe.view(pe.shape[0], pe.shape[1], int(np.sqrt(pe.shape[2])),
67 | int(np.sqrt(pe.shape[2])))
68 | pe = F.interpolate(pe,
69 | size=(512 // 16, 512 // 16),
70 | mode='bilinear',
71 | align_corners=True) # (12, 16)
72 | pe = pe.flatten(2)
73 | pe = pe.transpose(-1, -2)
74 | model.pos_embed = nn.Parameter(pe)
75 | model.head = nn.Identity()
76 | return model
77 |
78 |
79 | @register_model
80 | def deit_base_patch16_224(pretrained=False, **kwargs):
81 | model = DeiT(patch_size=16,
82 | embed_dim=768,
83 | depth=10,
84 | num_heads=12,
85 | mlp_ratio=4,
86 | qkv_bias=True,
87 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
88 | **kwargs)
89 | model.default_cfg = _cfg()
90 | if pretrained:
91 | ckpt = torch.load(
92 | '/home/chenfei/my_codes/TransformerCode-master/transformer_family/TransFuse/pretrained/deit_base_patch16_224-b5f2ef4d.pth'
93 | )
94 | model.load_state_dict(ckpt['model'], strict=False)
95 |
96 | pe = model.pos_embed[:, 1:, :].detach()
97 | pe = pe.transpose(-1, -2)
98 | pe = pe.view(pe.shape[0], pe.shape[1], int(np.sqrt(pe.shape[2])),
99 | int(np.sqrt(pe.shape[2])))
100 | pe = F.interpolate(pe, size=(32, 32), mode='bilinear',
101 | align_corners=True) # (12, 16)
102 | pe = pe.flatten(2)
103 | pe = pe.transpose(-1, -2)
104 | model.pos_embed = nn.Parameter(pe)
105 | model.head = nn.Identity()
106 | return model
--------------------------------------------------------------------------------
/utils/resize.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import random
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def process_isic2018(
11 | dim=(352, 352), save_dir='/raid/wjc/data/skin_lesion/isic2018/'):
12 | image_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1-2_Training_Input/'
13 | mask_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1_Training_GroundTruth/'
14 |
15 | image_path_list = os.listdir(image_dir_path)
16 | mask_path_list = os.listdir(mask_dir_path)
17 |
18 | image_path_list = list(filter(lambda x: x[-3:] == 'jpg', image_path_list))
19 | mask_path_list = list(filter(lambda x: x[-3:] == 'png', mask_path_list))
20 |
21 | image_path_list.sort()
22 | mask_path_list.sort()
23 |
24 | print(len(image_path_list), len(mask_path_list))
25 |
26 | # ISBI Dataset
27 | for image_path, mask_path in zip(image_path_list, mask_path_list):
28 | if image_path[-3:] == 'jpg':
29 | print(image_path)
30 | assert os.path.basename(image_path)[:-4].split(
31 | '_')[1] == os.path.basename(mask_path)[:-4].split('_')[1]
32 | _id = os.path.basename(image_path)[:-4].split('_')[1]
33 | image_path = os.path.join(image_dir_path, image_path)
34 | mask_path = os.path.join(mask_dir_path, mask_path)
35 | image = cv2.imread(image_path)
36 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
37 |
38 | image_new = cv2.resize(image, dim, interpolation=cv2.INTER_CUBIC)
39 | image_new = np.array(image_new, dtype=np.uint8)
40 | mask_new = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
41 | mask_new = cv2.erode(mask_new, (5, 5))
42 | mask_new = cv2.dilate(mask_new, (5, 5))
43 | mask_new = np.array(mask_new, dtype=np.uint8)
44 | # print(np.unique(mask_new))
45 |
46 | save_dir_path = save_dir + '/Image'
47 | os.makedirs(save_dir_path, exist_ok=True)
48 | # np.save(os.path.join(save_dir_path, _id + '.npy'), image_new)
49 | print(image_new.shape)
50 | cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'),
51 | image_new)
52 |
53 | save_dir_path = save_dir + '/Label'
54 | os.makedirs(save_dir_path, exist_ok=True)
55 | # np.save(os.path.join(save_dir_path, _id + '.npy'), mask_new)
56 | cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'),
57 | mask_new)
58 |
59 |
60 | def process_ph2():
61 | PH2_images_path = '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2_Dataset_images'
62 |
63 | path_list = os.listdir(PH2_images_path)
64 | path_list.sort()
65 |
66 | for path in path_list:
67 | image_path = os.path.join(PH2_images_path, path,
68 | path + '_Dermoscopic_Image', path + '.bmp')
69 | label_path = os.path.join(PH2_images_path, path, path + '_lesion',
70 | path + '_lesion.bmp')
71 | image = plt.imread(image_path)
72 | label = plt.imread(label_path)
73 | label = label[:, :, 0]
74 |
75 | dim = (352, 352)
76 | image_new = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
77 | label_new = cv2.resize(label, dim, interpolation=cv2.INTER_AREA)
78 |
79 | image_save_path = os.path.join(
80 | '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Image',
81 | path + '.npy')
82 | label_save_path = os.path.join(
83 | '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Label',
84 | path + '.npy')
85 |
86 | np.save(image_save_path, image_new)
87 | np.save(label_save_path, label_new)
88 |
89 |
90 | if __name__ == '__main__':
91 | process_isic2018(
92 | dim=(352, 352),
93 | save_dir='/raid/wjc/data/skin_lesion/isic2018_jpg_352_smooth/')
94 |
--------------------------------------------------------------------------------
/utils/isic2016_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | import random
7 | import torch
8 | import cv2
9 | from sklearn.model_selection import KFold
10 | import albumentations as A
11 | from albumentations.pytorch import ToTensorV2
12 | import json
13 |
14 |
15 | class isic2016Dataset(data.Dataset):
16 | """
17 | dataloader for isic2016 segmentation tasks
18 | """
19 | def __init__(self, image_root, gt_root, image_index, trainsize,
20 | augmentations):
21 | self.trainsize = trainsize
22 | self.augmentations = augmentations
23 | print(self.augmentations)
24 | self.image_root = image_root
25 | self.gt_root = gt_root
26 | self.images = image_index
27 | self.size = len(self.images)
28 |
29 | if self.augmentations:
30 | print('Using RandomRotation, RandomFlip')
31 |
32 | self.transform = A.Compose([
33 | A.Rotate(90),
34 | A.VerticalFlip(p=0.5),
35 | A.HorizontalFlip(p=0.5),
36 | A.Resize(self.trainsize, self.trainsize),
37 | ToTensorV2()
38 | ])
39 | else:
40 | print('no augmentation')
41 | self.transform = A.Compose(
42 | [A.Resize(self.trainsize, self.trainsize),
43 | ToTensorV2()])
44 |
45 | def __getitem__(self, idx):
46 | file_name = self.images[idx]
47 | # gt_name = file_name[:-4] + '_segmentation.png'
48 | img_root = os.path.join(self.image_root, file_name)
49 | gt_root = os.path.join(self.gt_root, file_name[:-4] + '_label.npy')
50 | # image = cv2.imread(img_root)
51 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
52 | # gt = (cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE))
53 | image = np.load(img_root)
54 | gt = (np.load(gt_root) * 255).astype(np.uint8)
55 |
56 | point_heatmap = cv2.Canny(gt, 0, 255) / 255.0
57 | gt = gt // 255.0
58 | gt = np.concatenate(
59 | [gt[..., np.newaxis], point_heatmap[..., np.newaxis]], axis=-1)
60 | pair = self.transform(image=image, mask=gt)
61 | gt = pair['mask'][:, :, 0]
62 | point_heatmap = pair['mask'][:, :, 1]
63 | gt = torch.unsqueeze(gt, 0)
64 | point_heatmap = torch.unsqueeze(point_heatmap, 0)
65 | image = pair['image'] / 255.0
66 |
67 | return image, gt, point_heatmap
68 |
69 | def resize(self, img, gt):
70 | assert img.size == gt.size
71 | w, h = img.size
72 | if h < self.trainsize or w < self.trainsize:
73 | h = max(h, self.trainsize)
74 | w = max(w, self.trainsize)
75 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h),
76 | Image.NEAREST)
77 | else:
78 | return img, gt
79 |
80 | def __len__(self):
81 | return self.size
82 |
83 |
84 | def get_loader(root_path,
85 | batchsize,
86 | trainsize,
87 | shuffle=True,
88 | num_workers=8,
89 | pin_memory=True,
90 | augmentation=False):
91 |
92 | dataset = isic2016Dataset(root_path + 'Train/Image',
93 | root_path + 'Train/Label',
94 | os.listdir(root_path + 'Train/Image'), trainsize,
95 | augmentation)
96 | data_loader = data.DataLoader(dataset=dataset,
97 | batch_size=batchsize,
98 | shuffle=shuffle,
99 | num_workers=num_workers,
100 | pin_memory=pin_memory)
101 |
102 | validset = isic2016Dataset(root_path + 'Validation/Image',
103 | root_path + 'Validation/Label',
104 | os.listdir(root_path + 'Validation/Image'),
105 | trainsize, False)
106 | valid_loader = data.DataLoader(dataset=validset,
107 | batch_size=1,
108 | shuffle=shuffle,
109 | num_workers=num_workers,
110 | pin_memory=pin_memory)
111 |
112 | testset = isic2016Dataset(root_path + 'Test/Image',
113 | root_path + 'Test/Label',
114 | os.listdir(root_path + 'Test/Image'), trainsize,
115 | False)
116 | test_loader = data.DataLoader(dataset=testset,
117 | batch_size=1,
118 | shuffle=shuffle,
119 | num_workers=num_workers,
120 | pin_memory=pin_memory)
121 | return data_loader, valid_loader, test_loader
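122 |
123 |
124 | # Usage sketch (commented out; '/path/to/isic2016/' is a placeholder, not a path
125 | # from this repo). get_loader expects Train/Validation/Test splits laid out as
126 | # <root>/<split>/Image (.npy images) and <root>/<split>/Label (*_label.npy masks).
127 | #
128 | #   train_loader, valid_loader, test_loader = get_loader(
129 | #       '/path/to/isic2016/', batchsize=8, trainsize=352, augmentation=True)
130 | #   for image, gt, point_heatmap in train_loader:
131 | #       pass  # image: (B, 3, H, W) in [0, 1]; gt, point_heatmap: (B, 1, H, W)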
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | from medpy.metric.binary import hd, hd95, dc, jc, assd
2 | import numpy as np
3 | from tqdm import tqdm
4 | import torch
5 | import torch.nn.functional as F
6 | import os
7 | import sys
8 | import cv2
9 | import matplotlib.pyplot as plt
10 |
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
12 | from lib.xboundformer import _segm_pvtv2
13 |
14 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 | save_point_pred = True
18 |
19 |
20 | def isbi2016():
21 | target_size = (512, 512)
22 |
23 | for model_name in ['xboundformer']:
24 | if model_name == 'xboundformer':
25 | model = _segm_pvtv2(1, 2, 2, 1, 352).to(device)
26 | else:
27 | # TODO
28 | raise NotImplementedError
29 | model.load_state_dict(
30 | torch.load(
31 | f'logs/isbi2016/test_loss_1_aug_1/{model_name}/fold_None/model/best.pkl'
32 | ))
33 | for fold in ['PH2', 'Test']:
34 | save_dir = f'results/ISIC-2016-pictures/{model_name}/{fold}'
35 | os.makedirs(save_dir, exist_ok=True)
36 | from utils.isbi2016_new import norm01, myDataset
37 | if fold == 'PH2':
38 | dataset = myDataset(split='test', aug=False)
39 | else:
40 | dataset = myDataset(split='valid', aug=False)
41 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
42 |
43 | model.eval()
44 | for batch_idx, batch_data in tqdm(enumerate(test_loader)):
45 | data = batch_data['image'].to(device).float()
46 | label = batch_data['label'].to(device).float()
47 | path = batch_data['image_path'][0]
48 | with torch.no_grad():
49 | output, point_pred1, point_pred2, point_pred3 = model(data)
50 | if save_point_pred:
51 | os.makedirs(save_dir.replace('pictures', 'point_maps'),
52 | exist_ok=True)
53 | point_pred1 = F.interpolate(point_pred1[-1], target_size)
54 | point_pred1 = point_pred1.cpu().numpy()[0, 0]
55 | plt.imsave(
56 | save_dir.replace('pictures', 'point_maps') + '/' +
57 | os.path.basename(path)[:-4] + '.png', point_pred1)
58 | output = torch.sigmoid(output)[0][0]
59 | output = (output.cpu().numpy() > 0.5).astype('uint8')
60 |                 output = (cv2.resize(output, target_size,
61 |                                      interpolation=cv2.INTER_NEAREST) > 0.5) * 1
62 | plt.imsave(
63 | save_dir + '/' + os.path.basename(path)[:-4] + '.png',
64 | output)
65 |
66 |
67 | def isbi2018():
68 | model = _segm_pvtv2(1, 1, 1, 1, 352).to(device)
69 | target_size = (512, 512)
70 | for fold in range(5):
71 | model.load_state_dict(
72 | torch.load(
73 | f'logs/isbi2018/test_loss_1_aug_1/xboundformer/fold_{fold}/model/best.pkl'
74 | ))
75 | save_dir = f'results/ISIC-2018-pictures/xboundformer/fold-{int(fold)+1}'
76 | os.makedirs(save_dir, exist_ok=True)
77 | from utils.isbi2018_new import norm01, myDataset
78 | dataset = myDataset(fold=str(fold), split='valid', aug=False)
79 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
80 |
81 | model.eval()
82 | for batch_idx, batch_data in tqdm(enumerate(test_loader)):
83 | data = batch_data['image'].to(device).float()
84 | label = batch_data['label'].to(device).float()
85 | path = batch_data['image_path'][0]
86 | with torch.no_grad():
87 | output, _, _, _ = model(data)
88 | output = torch.sigmoid(output)[0][0]
89 | output = (output.cpu().numpy() > 0.5).astype('uint8')
90 |             output = (cv2.resize(output, target_size,
91 |                                  interpolation=cv2.INTER_NEAREST) > 0.5) * 1
92 | plt.imsave(
93 | save_dir + '/' + os.path.basename(path).split('_')[1][:-4] +
94 | '.png', output)
95 |
96 |
97 | def isbi2018_ablation(folder_name):
98 | vs = list(map(int, folder_name.split('_')[1:]))
99 | model = _segm_pvtv2(1, vs[0], vs[1], vs[2], 352).to(device)
100 | target_size = (512, 512)
101 | for fold in range(5):
102 | model.load_state_dict(
103 | torch.load(
104 | f'logs/isbi2018/test_loss_1_aug_1/{folder_name}/fold_{fold}/model/best.pkl'
105 | ))
106 | save_dir = f'results/ISIC-2018-pictures/{folder_name}/fold-{int(fold)+1}'
107 | os.makedirs(save_dir, exist_ok=True)
108 | from utils.isbi2018_new import norm01, myDataset
109 | dataset = myDataset(fold=str(fold), split='valid', aug=False)
110 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
111 |
112 | model.eval()
113 | for batch_idx, batch_data in tqdm(enumerate(test_loader)):
114 | data = batch_data['image'].to(device).float()
115 | label = batch_data['label'].to(device).float()
116 | path = batch_data['image_path'][0]
117 | with torch.no_grad():
118 | output, _, _, _ = model(data)
119 | output = torch.sigmoid(output)[0][0]
120 | output = (output.cpu().numpy() > 0.5).astype('uint8')
121 |             output = (cv2.resize(output, target_size,
122 |                                  interpolation=cv2.INTER_NEAREST) > 0.5) * 1
123 | plt.imsave(
124 | save_dir + '/' + os.path.basename(path).split('_')[1][:-4] +
125 | '.png', output)
126 |
127 |
128 | if __name__ == '__main__':
129 | # isbi2016()
130 | isbi2018_ablation('bl_0_0_0')
131 | isbi2018_ablation('bl_1_0_0')
132 | isbi2018_ablation('bl_1_1_0')
133 | isbi2018_ablation('bl_1_1_1')
134 |
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | import random
7 | import torch
8 | import cv2
9 | import albumentations as A
10 | from albumentations.pytorch import ToTensorV2
11 |
12 |
13 | class PolypDataset(data.Dataset):
14 | """
15 | dataloader for polyp segmentation tasks
16 | """
17 | def __init__(self, image_root, gt_root, trainsize, augmentations):
18 | self.trainsize = trainsize
19 | self.augmentations = augmentations
20 | print(self.augmentations)
21 | self.images = [
22 | image_root + f for f in os.listdir(image_root)
23 | if f.endswith('.jpg') or f.endswith('.png')
24 | ]
25 | self.gts = [
26 | gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')
27 | ]
28 | self.images = sorted(self.images)
29 | self.gts = sorted(self.gts)
30 | self.filter_files()
31 | self.size = len(self.images)
32 | self.color1, self.color2 = [], []
33 | for name in self.images:
34 | if os.path.basename(name)[:-4].isdigit():
35 | self.color1.append(name)
36 | else:
37 | self.color2.append(name)
38 | if self.augmentations:
39 | self.transform = A.Compose([
40 | A.Rotate(90),
41 | A.VerticalFlip(p=0.5),
42 | A.HorizontalFlip(p=0.5),
43 | A.Resize(self.trainsize, self.trainsize),
44 | ToTensorV2()
45 | ])
46 | else:
47 | print('no augmentation')
48 | self.transform = A.Compose(
49 | [A.Resize(self.trainsize, self.trainsize),
50 | ToTensorV2()])
51 |
52 | def __getitem__(self, idx):
53 | image = cv2.imread(self.images[idx])
54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
55 |
56 | name2 = self.color1[idx % len(self.color1)] if np.random.rand(
57 | ) < 0.7 else self.color2[idx % len(self.color2)]
58 | image2 = cv2.imread(name2)
59 | image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2LAB)
60 |
61 | mean, std = image.mean(axis=(0, 1),
62 | keepdims=True), image.std(axis=(0, 1),
63 | keepdims=True)
64 | mean2, std2 = image2.mean(axis=(0, 1),
65 | keepdims=True), image2.std(axis=(0, 1),
66 | keepdims=True)
67 | image = np.uint8((image - mean) / std * std2 + mean2)
68 | image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB)
69 | gt = (cv2.imread(self.gts[idx], cv2.IMREAD_GRAYSCALE))
70 | point_heatmap = cv2.Canny(gt, 0, 255) / 255.0
71 | gt = gt // 255.0
72 | gt = np.concatenate(
73 | [gt[..., np.newaxis], point_heatmap[..., np.newaxis]], axis=-1)
74 | pair = self.transform(image=image, mask=gt)
75 | gt = pair['mask'][:, :, 0]
76 | point_heatmap = pair['mask'][:, :, 1]
77 | gt = torch.unsqueeze(gt, 0)
78 | point_heatmap = torch.unsqueeze(point_heatmap, 0)
79 | image = pair['image'] / 255.0
80 |
81 | return image, gt, point_heatmap
82 |
83 | def filter_files(self):
84 | assert len(self.images) == len(self.gts)
85 | images = []
86 | gts = []
87 | for img_path, gt_path in zip(self.images, self.gts):
88 | img = Image.open(img_path)
89 | gt = Image.open(gt_path)
90 | if img.size == gt.size:
91 | images.append(img_path)
92 | gts.append(gt_path)
93 | self.images = images
94 | self.gts = gts
95 |
96 | def resize(self, img, gt):
97 | assert img.size == gt.size
98 | w, h = img.size
99 | if h < self.trainsize or w < self.trainsize:
100 | h = max(h, self.trainsize)
101 | w = max(w, self.trainsize)
102 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h),
103 | Image.NEAREST)
104 | else:
105 | return img, gt
106 |
107 | def __len__(self):
108 | return self.size
109 |
110 |
111 | def get_loader(image_root,
112 | gt_root,
113 | batchsize,
114 | trainsize,
115 | shuffle=True,
116 | num_workers=4,
117 | pin_memory=True,
118 | augmentation=False):
119 |
120 | dataset = PolypDataset(image_root, gt_root, trainsize, augmentation)
121 | data_loader = data.DataLoader(dataset=dataset,
122 | batch_size=batchsize,
123 | shuffle=shuffle,
124 | num_workers=num_workers,
125 | pin_memory=pin_memory)
126 | return data_loader
127 |
128 |
129 | class test_dataset:
130 | def __init__(self, image_root, gt_root, testsize):
131 | self.testsize = testsize
132 | self.images = [
133 | image_root + f for f in os.listdir(image_root)
134 | if f.endswith('.jpg') or f.endswith('.png')
135 | ]
136 | self.gts = [
137 | gt_root + f for f in os.listdir(gt_root)
138 | if f.endswith('.tif') or f.endswith('.png')
139 | ]
140 | self.images = sorted(self.images)
141 | self.gts = sorted(self.gts)
142 | self.transform = A.Compose(
143 | [A.Resize(self.testsize, self.testsize),
144 | ToTensorV2()])
145 | self.size = len(self.images)
146 | self.index = 0
147 |
148 | def load_data(self):
149 | image = cv2.imread(self.images[self.index])
150 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
151 | gt = cv2.imread(self.gts[self.index], cv2.IMREAD_GRAYSCALE)
152 | pair = self.transform(image=image, mask=gt)
153 | image = pair['image'].unsqueeze(0) / 255
154 | gt = pair['mask'] / 255
155 | name = self.images[self.index].split('/')[-1]
156 | if name.endswith('.jpg'):
157 | name = name.split('.jpg')[0] + '.png'
158 | self.index += 1
159 | return image, gt, name
160 |
161 | def rgb_loader(self, path):
162 | with open(path, 'rb') as f:
163 | img = Image.open(f)
164 | return img.convert('RGB')
165 |
166 | def binary_loader(self, path):
167 | with open(path, 'rb') as f:
168 | img = Image.open(f)
169 | return img.convert('L')
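170 |
171 |
172 | # Usage sketch (commented out; the directory paths are placeholders, not paths
173 | # from this repo). PolypDataset/test_dataset read .jpg/.png images and masks
174 | # directly from folders, with the Canny boundary map generated on the fly.
175 | #
176 | #   train_loader = get_loader('/path/to/TrainDataset/images/',
177 | #                             '/path/to/TrainDataset/masks/',
178 | #                             batchsize=16, trainsize=352, augmentation=True)
179 | #   test_set = test_dataset('/path/to/TestDataset/images/',
180 | #                           '/path/to/TestDataset/masks/', testsize=352)
181 | #   image, gt, name = test_set.load_data()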
--------------------------------------------------------------------------------
/lib/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 |
5 |
6 | class _simple_learner(nn.Module):
7 | def __init__(self, d_model):
8 | super().__init__()
9 | self.mlp = nn.Conv2d(d_model * 2, d_model, 1, 1)
10 |
11 | def forward(self, f_low, f_high):
12 | low_size = f_low.shape[2:]
13 | f2_high = F.interpolate(f_high, size=low_size)
14 |
15 | f2_low = torch.cat([f_low, f2_high], dim=1)
16 | f2_low = self.mlp(f2_low)
17 | return f2_low
18 |
19 |
20 | class xboundlearnerv2(nn.Module):
21 | def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.0):
22 | super().__init__()
23 |
24 | self.xbl = xboundlearner(d_model,
25 | nhead,
26 | dim_feedforward=dim_feedforward,
27 | dropout=dropout)
28 | self.xbl1 = xboundlearner(d_model,
29 | nhead,
30 | dim_feedforward=dim_feedforward,
31 | dropout=dropout)
32 | self.mlp = nn.Conv2d(d_model * 2, d_model, 1, 1)
33 |
34 | def forward(self, f_low, f_high, xi_low, xi_high):
35 | f2_low = self.xbl(f_low, xi_high)
36 | f2_high = self.xbl1(f_high, xi_low)
37 |
38 | low_size = f2_low.shape[2:]
39 | f2_high = F.interpolate(f2_high, size=low_size)
40 |
41 | f2_low = torch.cat([f2_low, f2_high], dim=1)
42 | f2_low = self.mlp(f2_low)
43 | return f2_low + f2_low
44 |
45 |
46 | class xboundlearner(nn.Module):
47 | def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.0):
48 | super().__init__()
49 |
50 | self.cross_attn = nn.MultiheadAttention(d_model,
51 | nhead,
52 | dropout=dropout)
53 |
54 | self.linear1 = nn.Linear(d_model, dim_feedforward)
55 | self.dropout = nn.Dropout(dropout)
56 | self.linear2 = nn.Linear(dim_feedforward, d_model)
57 |
58 | self.norm1 = nn.LayerNorm(d_model)
59 | self.norm2 = nn.LayerNorm(d_model)
60 |
61 | self.dropout1 = nn.Dropout(dropout)
62 | self.dropout2 = nn.Dropout(dropout)
63 |
64 | self.activation = nn.LeakyReLU()
65 |
66 | def forward(self, tgt, src):
67 | "tgt shape: Batch_size, C, H, W "
68 | "src shape: Batch_size, 1, C "
69 |
70 | B, C, h, w = tgt.shape
71 | tgt = tgt.view(B, C, h * w).permute(2, 0, 1) # shape: L, B, C
72 |
73 | src = src.permute(1, 0, 2) # shape: Q:1, B, C
74 |
75 | fusion_feature = self.cross_attn(query=tgt, key=src, value=src)[0]
76 | tgt = tgt + self.dropout1(fusion_feature)
77 | tgt = self.norm1(tgt)
78 |
79 | tgt1 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
80 | tgt = tgt + self.dropout2(tgt1)
81 | tgt = self.norm2(tgt)
82 | return tgt.permute(1, 2, 0).view(B, C, h, w)
83 |
84 |
85 | class BoundaryWiseAttentionGateAtrous2D(nn.Module):
86 | def __init__(self, in_channels, hidden_channels=None):
87 |
88 | super(BoundaryWiseAttentionGateAtrous2D, self).__init__()
89 |
90 | modules = []
91 |
92 | if hidden_channels == None:
93 | hidden_channels = in_channels // 2
94 |
95 | modules.append(
96 | nn.Sequential(
97 | nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
98 | nn.BatchNorm2d(hidden_channels), nn.ReLU(inplace=True)))
99 | modules.append(
100 | nn.Sequential(
101 | nn.Conv2d(in_channels,
102 | hidden_channels,
103 | 3,
104 | padding=1,
105 | dilation=1,
106 | bias=False), nn.BatchNorm2d(hidden_channels),
107 | nn.ReLU(inplace=True)))
108 | modules.append(
109 | nn.Sequential(
110 | nn.Conv2d(in_channels,
111 | hidden_channels,
112 | 3,
113 | padding=2,
114 | dilation=2,
115 | bias=False), nn.BatchNorm2d(hidden_channels),
116 | nn.ReLU(inplace=True)))
117 | modules.append(
118 | nn.Sequential(
119 | nn.Conv2d(in_channels,
120 | hidden_channels,
121 | 3,
122 | padding=4,
123 | dilation=4,
124 | bias=False), nn.BatchNorm2d(hidden_channels),
125 | nn.ReLU(inplace=True)))
126 | modules.append(
127 | nn.Sequential(
128 | nn.Conv2d(in_channels,
129 | hidden_channels,
130 | 3,
131 | padding=6,
132 | dilation=6,
133 | bias=False), nn.BatchNorm2d(hidden_channels),
134 | nn.ReLU(inplace=True)))
135 |
136 | self.convs = nn.ModuleList(modules)
137 |
138 | self.conv_out = nn.Conv2d(5 * hidden_channels, 1, 1, bias=False)
139 |
140 | def forward(self, x):
141 | " x.shape: B, C, H, W "
142 | " return: feature, weight (B,C,H,W) "
143 | res = []
144 | for conv in self.convs:
145 | res.append(conv(x))
146 | res = torch.cat(res, dim=1)
147 | weight = torch.sigmoid(self.conv_out(res))
148 | x = x * weight + x
149 | return x, weight
150 |
151 |
152 | class BoundaryWiseAttentionGate2D(nn.Sequential):
153 | def __init__(self, in_channels, hidden_channels=None):
154 | super(BoundaryWiseAttentionGate2D, self).__init__(
155 | nn.Conv2d(in_channels,
156 | in_channels,
157 | kernel_size=3,
158 | padding=1,
159 | bias=False), nn.BatchNorm2d(in_channels),
160 | nn.ReLU(inplace=False),
161 | nn.Conv2d(in_channels,
162 | in_channels,
163 | kernel_size=3,
164 | padding=1,
165 | bias=False), nn.BatchNorm2d(in_channels),
166 | nn.ReLU(inplace=False), nn.Conv2d(in_channels, 1, kernel_size=1))
167 |
168 | def forward(self, x):
169 | " x.shape: B, C, H, W "
170 | " return: feature, weight (B,C,H,W) "
171 | weight = torch.sigmoid(
172 | super(BoundaryWiseAttentionGate2D, self).forward(x))
173 | x = x * weight + x
174 | return x, weight
175 |
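176 |
177 | if __name__ == '__main__':
178 |     # Shape-check sketch (illustrative, not part of the original repo): the
179 |     # cross-boundary learner fuses a (B, C, H, W) feature map with a (B, 1, C)
180 |     # boundary embedding, and the attention gates return the gated feature map
181 |     # plus a single-channel boundary weight map.
182 |     feat = torch.randn(2, 128, 16, 16)
183 |     key = torch.randn(2, 1, 128)
184 |     print(xboundlearner(d_model=128, nhead=8)(feat, key).shape)  # (2, 128, 16, 16)
185 |     out, weight = BoundaryWiseAttentionGateAtrous2D(128)(feat)
186 |     print(out.shape, weight.shape)  # (2, 128, 16, 16) (2, 1, 16, 16)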
--------------------------------------------------------------------------------
/utils/point_generate.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import random
4 | import torch
5 | import numpy as np
6 | import skimage.draw
7 | from tqdm import tqdm
8 | import matplotlib.pyplot as plt
9 | import torch.nn.functional as F
10 |
11 |
12 | def create_circular_mask(h, w, center, radius):
13 | Y, X = np.ogrid[:h, :w]
14 | dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
15 | mask = dist_from_center <= radius
16 | return mask
17 |
18 |
19 | def NMS(heatmap, kernel=13):
20 | hmax = F.max_pool2d(heatmap, kernel, stride=1, padding=(kernel - 1) // 2)
21 | keep = (hmax == heatmap).float()
22 | return heatmap * keep, hmax, keep
23 |
24 |
25 | def draw_msra_gaussian(heatmap, center, sigma):
26 | tmp_size = sigma * 3
27 | mu_x = int(center[0] + 0.5)
28 | mu_y = int(center[1] + 0.5)
29 | w, h = heatmap.shape[0], heatmap.shape[1]
30 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
31 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
32 | if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
33 | return heatmap
34 | size = 2 * tmp_size + 1
35 | x = np.arange(0, size, 1, np.float32)
36 | y = x[:, np.newaxis]
37 | x0 = y0 = size // 2
38 | g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
39 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
40 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
41 | img_x = max(0, ul[0]), min(br[0], h)
42 | img_y = max(0, ul[1]), min(br[1], w)
43 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum(
44 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], g[g_y[0]:g_y[1],
45 | g_x[0]:g_x[1]])
46 | return heatmap
47 |
48 |
49 | def kpm_gen(label_path, R, N):
50 | label = np.load(label_path)
51 | # label = label[0]
52 | label_ori = label.copy()
53 | label = label[::4, ::4]
54 | label = np.uint8(label * 255)
55 | contours, hierarchy = cv2.findContours(label, cv2.RETR_LIST,
56 | cv2.CHAIN_APPROX_NONE)
57 | contour_len = len(contours)
58 |
59 | label = np.repeat(label[..., np.newaxis], 3, axis=-1)
60 | draw_label = cv2.drawContours(label.copy(), contours, -1, (0, 0, 255), 1)
61 |
62 | point_file = []
63 | if contour_len == 0:
64 | point_heatmap = np.zeros((512, 512))
65 | else:
66 | point_heatmap = np.zeros((512, 512))
67 | for contour in contours:
68 | stds = []
69 | points = contour[:, 0] # (N,2)
70 | points = points * 4
71 | points_number = contour.shape[0]
72 | if points_number < 30:
73 | continue
74 |
75 | if points_number < 100:
76 | radius = 6
77 | neighbor_points_n_oneside = 3
78 | elif points_number < 200:
79 | radius = 10
80 | neighbor_points_n_oneside = 15
81 | elif points_number < 300:
82 | radius = 10
83 | neighbor_points_n_oneside = 20
84 | elif points_number < 350:
85 | radius = 10
86 | neighbor_points_n_oneside = 20
87 | else:
88 | radius = 10
89 | neighbor_points_n_oneside = 20
90 |
91 |
92 | for i in range(points_number):
93 | current_point = points[i]
94 | mask = create_circular_mask(512, 512, points[i], radius)
95 | overlap_area = np.sum(
96 | mask * label_ori) / (np.pi * radius * radius)
97 | stds.append(overlap_area)
98 | print("stds len: ", len(stds))
99 |
100 | # show
101 | selected_points = []
102 | stds = np.array(stds)
103 | neighbor_points = []
104 | for i in range(len(points)):
105 | current_point = points[i]
106 | neighbor_points_index = np.concatenate([
107 | np.arange(-neighbor_points_n_oneside, 0),
108 | np.arange(1, neighbor_points_n_oneside + 1)
109 | ]) + i
110 | neighbor_points_index[np.where(
111 | neighbor_points_index < 0)[0]] += len(points)
112 | neighbor_points_index[np.where(
113 | neighbor_points_index > len(points) - 1)[0]] -= len(points)
114 | if stds[i] < np.min(
115 | stds[neighbor_points_index]) or stds[i] > np.max(
116 | stds[neighbor_points_index]):
117 | # print(points[i])
118 | point_heatmap = draw_msra_gaussian(
119 | point_heatmap, (points[i, 0], points[i, 1]), 5)
120 | selected_points.append(points[i])
121 |
122 | print("selected_points num: ", len(selected_points))
123 | # print(selected_points)
124 | maskk = np.zeros((512, 512))
125 | rr, cc = skimage.draw.polygon(
126 | np.array(selected_points)[:, 1],
127 | np.array(selected_points)[:, 0])
128 | maskk[rr, cc] = 1
129 | intersection = np.logical_and(label_ori, maskk)
130 | union = np.logical_or(label_ori, maskk)
131 | iou_score = np.sum(intersection) / np.sum(union)
132 | print(iou_score)
133 | return label_ori, point_heatmap
134 |
135 |
136 | def point_gen_isic2018():
137 | R = 10
138 | N = 25
139 | data_dir = '/raid/wjc/data/skin_lesion/isic2018/Label'
140 |
141 | save_dir = data_dir.replace('Label', 'Point')
142 | os.makedirs(save_dir, exist_ok=True)
143 |
144 | path_list = os.listdir(data_dir)
145 | path_list.sort()
146 | num = 0
147 | for path in tqdm(path_list):
148 | name = path[:-4]
149 | label_path = os.path.join(data_dir, path)
150 | print(label_path)
151 | label_ori, point_heatmap = kpm_gen(label_path, R, N)
152 |
153 | save_path = os.path.join(save_dir, name + '.npy')
154 | np.save(save_path, point_heatmap)
155 | num += 1
156 |
157 |
158 | def point_gen_isic2016():
159 | R = 10
160 | N = 25
161 | for split in ['Train', 'Test', 'Validation']:
162 | data_dir = '/raid/wjc/data/skin_lesion/isic2016/{}/Label'.format(split)
163 |
164 | save_dir = data_dir.replace('Label', 'Point')
165 | os.makedirs(save_dir, exist_ok=True)
166 |
167 | path_list = os.listdir(data_dir)
168 | path_list.sort()
169 | num = 0
170 | for path in tqdm(path_list):
171 | name = path[:-4]
172 | label_path = os.path.join(data_dir, path)
173 | print(label_path)
174 | label_ori, point_heatmap = kpm_gen(label_path, R, N)
175 | save_path = os.path.join(save_dir, name + '.npy')
176 | np.save(save_path, point_heatmap)
177 | num += 1
178 |
179 |
180 | if __name__ == '__main__':
181 | # point_gen_isic2018()
182 | point_gen_isic2016()
--------------------------------------------------------------------------------
/utils/isbi2016_new.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import torch
5 | import random
6 | import torch.nn as nn
7 | import numpy as np
8 | import torch.utils.data
9 | from torchvision import transforms
10 | import torch.utils.data as data
11 | import torch.nn.functional as F
12 | import cv2
13 |
14 | import albumentations as A
15 | from sklearn.model_selection import KFold
16 | from .polar_transformations import centroid, to_polar
17 |
18 |
19 | def norm01(x):
20 | return np.clip(x, 0, 255) / 255
21 |
22 |
23 | seperable_indexes = json.load(open('utils/data_split.json', 'r'))
24 |
25 |
26 | # cross validation
27 | class myDataset(data.Dataset):
28 | def __init__(self, split, size=352, aug=False, polar=False):
29 | super(myDataset, self).__init__()
30 | self.polar = polar
31 | self.split = split
32 |
33 | # load images, label, point
34 | self.image_paths = []
35 | self.label_paths = []
36 | self.point_paths = []
37 | self.dist_paths = []
38 |
39 | root_dir = '/raid/wjc/data/skin_lesion/isic2016'
40 |
41 | if split == 'train':
42 | indexes = os.listdir(root_dir + '/Train/Image/')
43 | self.image_paths = [
44 | f'{root_dir}/Train/Image/{_id}' for _id in indexes
45 | ]
46 | self.label_paths = [
47 | f'{root_dir}/Train/Label/{_id[:-4]}_label.npy'
48 | for _id in indexes
49 | ]
50 | self.point_paths = [
51 | f'{root_dir}/Train/Point/{_id[:-4]}_label.npy'
52 | for _id in indexes
53 | ]
54 |
55 | elif split == 'valid':
56 | indexes = os.listdir(root_dir + '/Validation/Image/')
57 | self.image_paths = [
58 | f'{root_dir}/Validation/Image/{_id}' for _id in indexes
59 | ]
60 | self.label_paths = [
61 | f'{root_dir}/Validation/Label/{_id[:-4]}_label.npy'
62 | for _id in indexes
63 | ]
64 | else:
65 | indexes = os.listdir(root_dir + '/Test/Image/')
66 | self.image_paths = [
67 | f'{root_dir}/Test/Image/{_id}' for _id in indexes
68 | ]
69 | self.label_paths = [
70 | f'{root_dir}/Test/Label/{_id[:-4]}_label.npy'
71 | for _id in indexes
72 | ]
73 |
74 | print('Loaded {} frames'.format(len(self.image_paths)))
75 | self.num_samples = len(self.image_paths)
76 | self.aug = aug
77 | self.size = size
78 |
79 | p = 0.5
80 | self.transf = A.Compose([
81 | A.GaussNoise(p=p),
82 | A.HorizontalFlip(p=p),
83 | A.VerticalFlip(p=p),
84 | A.ShiftScaleRotate(p=p),
85 | # A.RandomBrightnessContrast(p=p),
86 | ])
87 |
88 | def __getitem__(self, index):
89 | # print(self.image_paths[index])
90 | # image = cv2.imread(self.image_paths[index])
91 | # image_data = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
92 | # label_data = cv2.imread(self.label_paths[index], cv2.IMREAD_GRAYSCALE)
93 | image_data = np.load(self.image_paths[index])
94 | label_data = (np.load(self.label_paths[index]) * 255).astype('uint8')
95 |
96 |         label_data = np.array(
97 |             cv2.resize(label_data, (self.size, self.size), interpolation=cv2.INTER_NEAREST))
98 | point_data = cv2.Canny(label_data, 0, 255) / 255.0 > 0.5
99 | label_data = label_data / 255. > 0.5
100 |         image_data = np.array(
101 |             cv2.resize(image_data, (self.size, self.size), interpolation=cv2.INTER_LINEAR))
102 | if self.split == 'train':
103 | filter_point_data = (np.load(self.point_paths[index]) >
104 | 0.7).astype('uint8')
105 |             filter_point_data = np.array(
106 |                 cv2.resize(filter_point_data, (self.size, self.size),
107 |                            interpolation=cv2.INTER_NEAREST))
108 | else:
109 | filter_point_data = point_data.copy()
110 |
111 | # image_data = np.load(self.image_paths[index])
112 | # label_data = np.load(self.label_paths[index]) > 0.5
113 | # point_data = np.load(self.point_paths[index]) > 0.5
114 | # point_All_data = np.load(self.point_All_paths[index]) > 0.5 #
115 |
116 | # label_data = np.expand_dims(label_data,-1)
117 | # point_data = np.expand_dims(point_data,-1)
118 | if self.aug and self.split == 'train':
119 | mask = np.concatenate([
120 | label_data[..., np.newaxis].astype('uint8'),
121 | point_data[..., np.newaxis], filter_point_data[..., np.newaxis]
122 | ],
123 | axis=-1)
124 | # print(mask.shape)
125 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask)
126 | image_data, mask_aug = tsf['image'], tsf['mask']
127 | label_data = mask_aug[:, :, 0]
128 | point_data = mask_aug[:, :, 1]
129 | filter_point_data = mask_aug[:, :, 2]
130 |
131 | image_data = norm01(image_data)
132 |
133 | if self.polar:
134 | center = centroid(image_data)
135 | image_data = to_polar(image_data, center)
136 | label_data = to_polar(label_data, center) > 0.5
137 |
138 | label_data = np.expand_dims(label_data, 0)
139 | point_data = np.expand_dims(point_data, 0)
140 | filter_point_data = np.expand_dims(filter_point_data, 0) #
141 |
142 | image_data = torch.from_numpy(image_data).float()
143 | label_data = torch.from_numpy(label_data).float()
144 | point_data = torch.from_numpy(point_data).float()
145 | filter_point_data = torch.from_numpy(filter_point_data).float() #
146 |
147 | image_data = image_data.permute(2, 0, 1)
148 | return {
149 | 'image_path': self.image_paths[index],
150 | 'label_path': self.label_paths[index],
151 | # 'point_path': self.point_paths[index],
152 | 'image': image_data,
153 | 'label': label_data,
154 | 'point': point_data,
155 | 'filter_point_data': filter_point_data
156 | }
157 |
158 | def __len__(self):
159 | return self.num_samples
160 |
161 |
162 | if __name__ == '__main__':
163 | from tqdm import tqdm
164 | dataset = myDataset(split='train', aug=True)
165 |
166 | train_loader = torch.utils.data.DataLoader(dataset,
167 | batch_size=8,
168 | shuffle=False,
169 | num_workers=2,
170 | pin_memory=True,
171 | drop_last=True)
172 | import matplotlib.pyplot as plt
173 | for d in dataset:
174 | print(d['image'].shape, d['image'].max())
175 | print(d['point'].shape, d['point'].max())
176 | print(d['filter_point_data'].shape, d['filter_point_data'].max())
177 | image = d['image'].permute(1, 2, 0).cpu()
178 | point = d['point'][0].cpu()
179 | filter_point_data = d['filter_point_data'][0].cpu()
180 | plt.imshow(filter_point_data)
181 | plt.show()
--------------------------------------------------------------------------------
/lib/xboundformer_v0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | from lib.modules import xboundlearner
7 | from lib.vision_transformers import in_scale_transformer
8 |
9 | from lib.pvtv2 import pvt_v2_b2 #
10 |
11 |
12 | def _segm_pvtv2(num_classes, im_num, ex_num, xbound):
13 | backbone = pvt_v2_b2()
14 |
15 | if 1:
16 | path = 'pvt_v2_b2.pth'
17 | save_model = torch.load(path)
18 | model_dict = backbone.state_dict()
19 | state_dict = {
20 | k: v
21 | for k, v in save_model.items() if k in model_dict.keys()
22 | }
23 | model_dict.update(state_dict)
24 | backbone.load_state_dict(model_dict)
25 | classifier = _simple_classifier(num_classes)
26 | model = _SimpleSegmentationModel(backbone, classifier, im_num, ex_num)
27 | return model
28 |
29 |
30 | class _simple_classifier(nn.Module):
31 | def __init__(self, num_classes):
32 | super(_simple_classifier, self).__init__()
33 | self.classifier = nn.Sequential(
34 | nn.Conv2d(576, 256, 3, padding=1, bias=False), #560
35 | nn.BatchNorm2d(256),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(256, num_classes, 1))
38 |
39 | def forward(self, feature):
40 | low_level_feature = feature[0]
41 | output_feature = feature[3]
42 | output_feature = F.interpolate(output_feature,
43 | size=low_level_feature.shape[2:],
44 | mode='bilinear',
45 | align_corners=False)
46 | return self.classifier(
47 | torch.cat([low_level_feature, output_feature], dim=1))
48 |
49 |
50 | class _SimpleSegmentationModel(nn.Module):
51 | # general segmentation model
52 | def __init__(self, backbone, classifier, im_num, ex_num):
53 | super(_SimpleSegmentationModel, self).__init__()
54 | self.backbone = backbone
55 | self.classifier = classifier
56 | self.bat_low = _bound_learner(hidden_features=128,
57 | im_num=im_num,
58 | ex_num=ex_num)
59 |
60 | def forward(self, x):
61 | input_shape = x.shape[-2:]
62 | features = self.backbone(
63 | x
64 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) ([8, 512, 8, 8])
65 | features, point_pre1, point_pre2, point_pre3 = self.bat_low(features)
66 | x = self.classifier(features)
67 | x = F.interpolate(x,
68 | size=input_shape,
69 | mode='bilinear',
70 | align_corners=False)
71 | return x, point_pre1, point_pre2, point_pre3
72 |
73 |
74 | class _bound_learner(nn.Module):
75 | def __init__(self, point_pred=1, hidden_features=128, im_num=2, ex_num=2):
76 |
77 | super().__init__()
78 |
79 | self.point_pred = point_pred
80 |
81 | self.convolution_mapping_1 = nn.Conv2d(in_channels=128,
82 | out_channels=hidden_features,
83 | kernel_size=(1, 1),
84 | stride=(1, 1),
85 | padding=(0, 0),
86 | bias=True)
87 | self.convolution_mapping_2 = nn.Conv2d(in_channels=320,
88 | out_channels=hidden_features,
89 | kernel_size=(1, 1),
90 | stride=(1, 1),
91 | padding=(0, 0),
92 | bias=True)
93 | self.convolution_mapping_3 = nn.Conv2d(in_channels=512,
94 | out_channels=hidden_features,
95 | kernel_size=(1, 1),
96 | stride=(1, 1),
97 | padding=(0, 0),
98 | bias=True)
99 | normalize_before = True
100 | self.im_ex_boud1 = in_scale_transformer(
101 | point_pred_layers=1,
102 | num_encoder_layers=im_num,
103 | num_decoder_layers=ex_num,
104 | d_model=hidden_features,
105 | nhead=8,
106 | normalize_before=normalize_before)
107 | self.im_ex_boud2 = in_scale_transformer(
108 | point_pred_layers=1,
109 | num_encoder_layers=im_num,
110 | num_decoder_layers=ex_num,
111 | d_model=hidden_features,
112 | nhead=8,
113 | normalize_before=normalize_before)
114 | self.im_ex_boud3 = in_scale_transformer(
115 | point_pred_layers=1,
116 | num_encoder_layers=im_num,
117 | num_decoder_layers=ex_num,
118 | d_model=hidden_features,
119 | nhead=8,
120 | normalize_before=normalize_before)
121 | self.cross_attention_3_1 = xboundlearner(hidden_features, 8)
122 | self.cross_attention_3_2 = xboundlearner(hidden_features, 8)
123 | self.trans_out_conv = nn.Conv2d(hidden_features * 2, 512, 1, 1) #
124 |
125 | def forward(self, x):
126 | features_1 = x[1]
127 | features_2 = x[2]
128 | features_3 = x[3]
129 | features_1 = self.convolution_mapping_1(features_1)
130 | features_2 = self.convolution_mapping_2(features_2)
131 | features_3 = self.convolution_mapping_3(features_3)
132 |
133 | # in-scale attention
134 | latent_tensor_1, features_encoded_1, point_maps_1 = self.im_ex_boud1(
135 | features_1)
136 |
137 | latent_tensor_2, features_encoded_2, point_maps_2 = self.im_ex_boud2(
138 | features_2)
139 |
140 | latent_tensor_3, features_encoded_3, point_maps_3 = self.im_ex_boud3(
141 | features_3)
142 |
143 | # cross-scale attention6
144 | latent_tensor_1 = latent_tensor_1.permute(2, 0, 1)
145 | latent_tensor_2 = latent_tensor_2.permute(2, 0, 1)
146 | latent_tensor_3 = latent_tensor_3.permute(2, 0, 1)
147 |
148 | # ''' point map Upsample '''
149 | features_encoded_3_1 = self.cross_attention_3_1(
150 | features_encoded_3, latent_tensor_1)
151 | features_encoded_3_2 = self.cross_attention_3_2(
152 | features_encoded_3, latent_tensor_2)
153 |
154 | trans_feature_maps = self.trans_out_conv(
155 | torch.cat([features_encoded_3_1, features_encoded_3_2], dim=1))
156 |
157 | x[3] = trans_feature_maps
158 | x[2] = torch.cat([x[2], features_encoded_2], dim=1)
159 | x[1] = torch.cat([x[1], features_encoded_1], dim=1)
160 |
161 | if self.point_pred == 1:
162 | return x, point_maps_1, point_maps_2, point_maps_3 #
163 | else:
164 | return trans_feature_maps
165 |
166 |
167 | if __name__ == '__main__':
168 |     import os
169 |     os.environ['CUDA_VISIBLE_DEVICES'] = '4'
170 |     # illustrative arguments: num_classes=1, im_num=1, ex_num=1; xbound is unused in this v0 builder
171 |     model = _segm_pvtv2(1, 1, 1, True).cuda()
172 |     input_tensor = torch.randn(1, 3, 352, 352).cuda()
173 |     prediction1 = model(input_tensor)
174 |
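
Note: the weight-loading block in `_segm_pvtv2` keeps only the checkpoint entries whose keys exist in the freshly built backbone, so the added boundary modules stay randomly initialised. The same idiom isolated as a helper (the checkpoint path is whatever you pass in):

    import torch

    def load_matching_weights(model, ckpt_path):
        # keep only checkpoint tensors whose keys the model already has;
        # everything else in the model keeps its random initialisation
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model_dict = model.state_dict()
        model_dict.update({k: v for k, v in checkpoint.items() if k in model_dict})
        model.load_state_dict(model_dict)
        return model
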
--------------------------------------------------------------------------------
/utils/isic2018_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | import random
7 | import torch
8 | import cv2
9 | from sklearn.model_selection import KFold
10 | import albumentations as A
11 | from albumentations.pytorch import ToTensorV2
12 | import json
13 |
14 |
15 | class isic2018Dataset(data.Dataset):
16 | """
17 | dataloader for isic2018 segmentation tasks
18 | """
19 | def __init__(self, image_root, gt_root, image_index, trainsize,
20 | augmentations):
21 | self.trainsize = trainsize
22 | self.augmentations = augmentations
23 | print(self.augmentations)
24 | self.image_root = image_root
25 | self.gt_root = gt_root
26 | self.images = image_index
27 | self.size = len(self.images)
28 |
29 | if self.augmentations:
30 | print('Using RandomRotation, RandomFlip')
31 |
32 | self.transform = A.Compose([
33 | A.Rotate(90),
34 | # A.GaussNoise(p=0.5),
35 | A.VerticalFlip(p=0.5),
36 | A.HorizontalFlip(p=0.5),
37 | # A.ShiftScaleRotate(p=0.5),
38 | # A.Resize(self.trainsize, self.trainsize),
39 | ToTensorV2()
40 | ])
41 | else:
42 | print('no augmentation')
43 | self.transform = A.Compose([
44 | # A.Resize(self.trainsize, self.trainsize),
45 | ToTensorV2()
46 | ])
47 |
48 | def __getitem__(self, idx):
49 | file_name = self.images[idx]
50 | # gt_name = file_name[:-4] + '_segmentation.png'
51 | img_root = os.path.join(self.image_root, file_name)
52 | gt_root = os.path.join(self.gt_root, file_name)
53 | image = cv2.imread(img_root)
54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
55 | gt = (cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE))
56 | point_heatmap = cv2.Canny(gt, 0, 255) / 255.0
57 | gt = gt // 255.0
58 | gt = np.concatenate(
59 | [gt[..., np.newaxis], point_heatmap[..., np.newaxis]], axis=-1)
60 | pair = self.transform(image=image, mask=gt)
61 | gt = pair['mask'][:, :, 0]
62 | point_heatmap = pair['mask'][:, :, 1]
63 | gt = torch.unsqueeze(gt, 0)
64 | point_heatmap = torch.unsqueeze(point_heatmap, 0)
65 | image = pair['image'] / 255.0
66 |
67 | return image, gt, point_heatmap
68 |
69 | def resize(self, img, gt):
70 | assert img.size == gt.size
71 | w, h = img.size
72 | if h < self.trainsize or w < self.trainsize:
73 | h = max(h, self.trainsize)
74 | w = max(w, self.trainsize)
75 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h),
76 | Image.NEAREST)
77 | else:
78 | return img, gt
79 |
80 | def __len__(self):
81 | return self.size
82 |
83 |
84 | # def get_loader(image_root,batchsize, trainsize, floder,shuffle=True, num_workers=8, pin_memory=True, augmentation=False):
85 | # train_image_index_all,test_image_index_all=create_k_fold_division(image_root)
86 | # image_index=train_image_index_all[floder]
87 | # test_index=test_image_index_all[floder]
88 | # dataset = isic2018Dataset(image_root, image_index, trainsize, augmentation)
89 | # data_loader = data.DataLoader(dataset=dataset,
90 | # batch_size=batchsize,
91 | # shuffle=shuffle,
92 | # num_workers=num_workers,
93 | # pin_memory=pin_memory)
94 |
95 | # testset=test_dataset(image_root,test_index,trainsize)
96 | # test_loader = data.DataLoader(dataset=testset,
97 | # batch_size=1,
98 | # shuffle=shuffle,
99 | # num_workers=num_workers,
100 | # pin_memory=pin_memory)
101 | # return data_loader,testset
102 |
103 |
104 | def get_loader(image_root,
105 | gt_root,
106 | batchsize,
107 | trainsize,
108 | floder,
109 | shuffle=True,
110 | num_workers=8,
111 | pin_memory=True,
112 | augmentation=False):
113 | js = json.load(open('utils/data_split.json'))
114 | # print(js)
115 | # train / test
116 | all_index = [f for f in os.listdir(image_root) if f.endswith('.jpg')]
117 |
118 | test_index = ['ISIC_' + i + '.jpg' for i in js[str(floder)]]
119 | image_index = list(filter(lambda x: x not in test_index, all_index))
120 | print(len(all_index), len(image_index), len(test_index))
121 |
122 | dataset = isic2018Dataset(image_root, gt_root, image_index, trainsize,
123 | augmentation)
124 | data_loader = data.DataLoader(dataset=dataset,
125 | batch_size=batchsize,
126 | shuffle=shuffle,
127 | num_workers=num_workers,
128 | pin_memory=pin_memory)
129 |
130 | testset = isic2018Dataset(image_root, gt_root, test_index, trainsize,
131 | False)
132 | test_loader = data.DataLoader(dataset=testset,
133 | batch_size=1,
134 | shuffle=shuffle,
135 | num_workers=num_workers,
136 | pin_memory=pin_memory)
137 | return data_loader, test_loader
138 |
139 |
140 | class test_dataset:
141 | def __init__(self, image_root, gt_root, test_index, testsize):
142 | self.testsize = testsize
143 | self.image_root = image_root
144 | self.gt_root = gt_root
145 | self.images = test_index
146 | self.transform = A.Compose(
147 | [A.Resize(self.testsize, self.testsize),
148 | ToTensorV2()])
149 | self.size = len(self.images)
150 | self.index = 0
151 |
152 | def __getitem__(self, idx):
153 | image = cv2.imread(os.path.join(self.image_root, self.images[idx]))
154 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
155 | # gt_ind=self.images[idx].split('_')
156 | # gt_name=gt_ind[0]+'_Task1_'+gt_ind[2]+'_GroundTruth/ISIC_'+gt_ind[4][:-4]+'_segmentation.png'
157 | gt_name = self.images[idx][:-4] + '_segmentation.png'
158 | gt_root = os.path.join(self.gt_root, gt_name)
159 | gt = cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE)
160 | pair = self.transform(image=image, mask=gt)
161 | name = self.images[idx].split('/')[-1]
162 | if name.endswith('.jpg'):
163 | name = name.split('.jpg')[0] + '.png'
164 | image = pair['image'].unsqueeze(0) / 255
165 | gt = pair['mask'] / 255
166 | return image, gt, name
167 |
168 | def rgb_loader(self, path):
169 | with open(path, 'rb') as f:
170 | img = Image.open(f)
171 | return img.convert('RGB')
172 |
173 | def binary_loader(self, path):
174 | with open(path, 'rb') as f:
175 | img = Image.open(f)
176 | return img.convert('L')
177 |
178 | def __len__(self):
179 | return self.size
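
Note: a hedged example of wiring `get_loader` into a script; the directory paths are placeholders and fold `0` must exist as a key in `utils/data_split.json`:

    from utils.isic2018_dataloader import get_loader

    # placeholder directories; point them at the ISIC-2018 images and ground-truth masks
    train_loader, test_loader = get_loader(image_root='<isic2018>/images',
                                           gt_root='<isic2018>/masks',
                                           batchsize=8,
                                           trainsize=352,
                                           floder=0,
                                           augmentation=True)
    image, gt, point_heatmap = next(iter(train_loader))
    print(image.shape, gt.shape, point_heatmap.shape)
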
--------------------------------------------------------------------------------
/utils/isbi2018_new.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import torch
5 | import random
6 | import torch.nn as nn
7 | import numpy as np
8 | import torch.utils.data
9 | from torchvision import transforms
10 | import torch.utils.data as data
11 | import torch.nn.functional as F
12 | import cv2
13 | import sys
14 |
15 | sys.path.insert(0, os.path.dirname(__file__) + '/../')
16 | # from utils.polar_transformations import centroid, to_polar
17 | import albumentations as A
18 | from sklearn.model_selection import KFold
19 |
20 |
21 | def norm01(x):
22 | return np.clip(x, 0, 255) / 255
23 |
24 |
25 | seperable_indexes = json.load(open('utils/data_split.json', 'r'))
26 |
27 |
28 | # cross validation
29 | class myDataset(data.Dataset):
30 | def __init__(self, fold, split, size=352, aug=False, polar=False):
31 | super(myDataset, self).__init__()
32 | self.split = split
33 | self.polar = polar
34 |
35 | # load images, label, point
36 | self.image_paths = []
37 | self.label_paths = []
38 | self.point_paths = []
39 | self.dist_paths = []
40 |
41 | indexes = os.listdir(
42 | '/raid/wjc/data/skin_lesion/isic2018_jpg_smooth/Image')
43 |
44 | valid_indexes = [
45 | 'ISIC_' + i + '.jpg' for i in seperable_indexes[str(fold)]
46 | ]
47 | train_indexes = list(filter(lambda x: x not in valid_indexes, indexes))
48 | print(len(indexes), len(train_indexes), len(valid_indexes))
49 |
50 | #valid_indexes = indexes[:260]
51 | #train_indexes = indexes[260:]
52 | print('Fold {}: train: {} valid: {}'.format(fold, len(train_indexes),
53 | len(valid_indexes)))
54 |
55 | root_dir = '/raid/wjc/data/skin_lesion/isic2018_jpg_smooth'
56 | if self.polar:
57 | if split == 'train':
58 | self.image_paths = [
59 | f'{root_dir}/PolarImage/{_id}' for _id in train_indexes
60 | ]
61 | self.label_paths = [
62 | f'{root_dir}/PolarLabel/{_id}' for _id in train_indexes
63 | ]
64 | elif split == 'valid':
65 | self.image_paths = [
66 | f'{root_dir}/PolarImage/{_id}' for _id in valid_indexes
67 | ]
68 | self.label_paths = [
69 | f'{root_dir}/PolarLabel/{_id}' for _id in valid_indexes
70 | ]
71 | else:
72 | if split == 'train':
73 | self.image_paths = [
74 | f'{root_dir}/Image/{_id}' for _id in train_indexes
75 | ]
76 | self.label_paths = [
77 | f'{root_dir}/Label/{_id}' for _id in train_indexes
78 | ]
79 | elif split == 'valid':
80 | self.image_paths = [
81 | f'{root_dir}/Image/{_id}' for _id in valid_indexes
82 | ]
83 | self.label_paths = [
84 | f'{root_dir}/Label/{_id}' for _id in valid_indexes
85 | ]
86 |
87 | print('Loaded {} frames'.format(len(self.image_paths)))
88 | self.num_samples = len(self.image_paths)
89 | self.aug = aug
90 | self.size = size
91 |
92 | p = 0.5
93 | self.transf = A.Compose([
94 | A.GaussNoise(p=p),
95 | A.HorizontalFlip(p=p),
96 | A.VerticalFlip(p=p),
97 | A.ShiftScaleRotate(p=p),
98 | # A.RandomBrightnessContrast(p=p),
99 | ])
100 |
101 | def __getitem__(self, index):
102 | # print(self.image_paths[index])
103 | image = cv2.imread(self.image_paths[index])
104 | image_data = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
105 |
106 | label_data = cv2.imread(self.label_paths[index], cv2.IMREAD_GRAYSCALE)
107 |         label_data = cv2.imread(self.label_paths[index], cv2.IMREAD_GRAYSCALE)
108 |         label_data = np.array(
109 |             cv2.resize(label_data, (self.size, self.size), interpolation=cv2.INTER_NEAREST))
110 |         point_data = cv2.Canny(label_data, 0, 255) / 255.0 > 0.5
111 |         label_data = label_data / 255. > 0.5
112 |         image_data = np.array(
113 |             cv2.resize(image_data, (self.size, self.size), interpolation=cv2.INTER_LINEAR))
113 |
114 | # image_data = np.load(self.image_paths[index])
115 | # label_data = np.load(self.label_paths[index]) > 0.5
116 | # point_data = np.load(self.point_paths[index]) > 0.5
117 | # point_All_data = np.load(self.point_All_paths[index]) > 0.5 #
118 |
119 | # label_data = np.expand_dims(label_data,-1)
120 | # point_data = np.expand_dims(point_data,-1)
121 | if self.aug and self.split == 'train':
122 | mask = np.concatenate([
123 | label_data[..., np.newaxis].astype('uint8'),
124 | point_data[..., np.newaxis]
125 | ],
126 | axis=-1)
127 | # print(mask.shape)
128 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask)
129 | image_data, mask_aug = tsf['image'], tsf['mask']
130 | label_data = mask_aug[:, :, 0]
131 | point_data = mask_aug[:, :, 1]
132 |
133 | image_data = norm01(image_data)
134 |
135 | label_data = np.expand_dims(label_data, 0)
136 | point_data = np.expand_dims(point_data, 0)
137 | # point_All_data = np.expand_dims(point_All_data, 0) #
138 |
139 | image_data = torch.from_numpy(image_data).float()
140 | label_data = torch.from_numpy(label_data).float()
141 | point_data = torch.from_numpy(point_data).float()
142 | # point_All_data = torch.from_numpy(point_All_data).float() #
143 |
144 | image_data = image_data.permute(2, 0, 1)
145 | return {
146 | 'image_path': self.image_paths[index],
147 | 'label_path': self.label_paths[index],
148 | # 'point_path': self.point_paths[index],
149 | 'image': image_data,
150 | 'label': label_data,
151 | 'point': point_data,
152 | 'point_All': label_data
153 | }
154 |
155 | def __len__(self):
156 | return self.num_samples
157 |
158 |
159 | if __name__ == '__main__':
160 | from tqdm import tqdm
161 | import sys
162 | dataset = myDataset(fold='0', split='valid', aug=False, polar=False)
163 | print(dataset.image_paths[:5])
164 | print(seperable_indexes['0'][:5])
165 | # for d in dataset:
166 | # print(d)
167 | # train_loader = torch.utils.data.DataLoader(dataset,
168 | # batch_size=8,
169 | # shuffle=False,
170 | # num_workers=2,
171 | # pin_memory=True,
172 | # drop_last=True)
173 | # import matplotlib.pyplot as plt
174 | # for d in dataset:
175 | # print(d['image'].shape, d['image'].max())
176 | # print(d['point'].shape, d['point'].max())
177 | # image = d['image'].permute(1, 2, 0).cpu()
178 | # label = d['label'].permute(1, 2, 0).cpu()
179 | # point = d['point'][0].cpu()
180 | # plt.figure()
181 | # plt.imshow(image)
182 | # plt.show()
183 | # plt.figure()
184 | # plt.imshow(label)
185 | # plt.show()
186 | # break
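
Note: the boundary "point" supervision here is simply the Canny edge of the binary mask, rescaled to {0, 1}. A standalone sketch of that step on a toy mask:

    import cv2
    import numpy as np

    # toy 0/255 lesion mask: a filled rectangle
    mask = np.zeros((352, 352), dtype=np.uint8)
    mask[100:250, 120:260] = 255

    point_map = cv2.Canny(mask, 0, 255) / 255.0   # edge pixels -> {0.0, 1.0}
    print(int(point_map.sum()), 'boundary pixels')
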
--------------------------------------------------------------------------------
/utils/isic2018_polar.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | import random
7 | import torch
8 | import cv2
9 | from sklearn.model_selection import KFold
10 | import albumentations as A
11 | from albumentations.pytorch import ToTensorV2
12 | import json
13 | from utils.polar_transformations import centroid, to_polar
14 |
15 |
16 | class isic2018Dataset(data.Dataset):
17 | """
18 | dataloader for isic2018 segmentation tasks
19 | """
20 | def __init__(self, image_root, gt_root, image_index, trainsize,
21 | augmentations):
22 | self.trainsize = trainsize
23 | self.augmentations = augmentations
24 | print(self.augmentations)
25 | self.image_root = image_root
26 | self.gt_root = gt_root
27 | self.images = image_index
28 | self.size = len(self.images)
29 |
30 | if self.augmentations:
31 | print('Using RandomRotation, RandomFlip')
32 |
33 | self.transform = A.Compose([ToTensorV2()])
34 | else:
35 | print('no augmentation')
36 | self.transform = A.Compose([
37 | # A.Resize(self.trainsize, self.trainsize),
38 | ToTensorV2()
39 | ])
40 |
41 | def __getitem__(self, idx):
42 | file_name = self.images[idx]
43 | # gt_name = file_name[:-4] + '_segmentation.png'
44 | img_root = os.path.join(self.image_root, file_name)
45 | gt_root = os.path.join(self.gt_root, file_name)
46 | image = cv2.imread(img_root)
47 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
48 | gt = (cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE))
49 | gt = gt // 255.0
50 |
51 | # if self.polar:
52 | # if self.manual_centers is not None:
53 | # center = self.manual_centers[idx]
54 | # else:
55 | # center = polar_transformations.centroid(label)
56 | center = centroid(gt)
57 |
58 | image = to_polar(image, center)
59 | gt = to_polar(gt, center)
60 |
61 | gt = np.concatenate([gt[..., np.newaxis], gt[..., np.newaxis]],
62 | axis=-1)
63 | pair = self.transform(image=image, mask=gt)
64 | gt = pair['mask'][:, :, 0]
65 | point_heatmap = pair['mask'][:, :, 1]
66 | gt = torch.unsqueeze(gt, 0)
67 | point_heatmap = torch.unsqueeze(point_heatmap, 0)
68 | image = pair['image'] / 255.0
69 |
70 | return image, gt, point_heatmap
71 |
72 | def resize(self, img, gt):
73 | assert img.size == gt.size
74 | w, h = img.size
75 | if h < self.trainsize or w < self.trainsize:
76 | h = max(h, self.trainsize)
77 | w = max(w, self.trainsize)
78 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h),
79 | Image.NEAREST)
80 | else:
81 | return img, gt
82 |
83 | def __len__(self):
84 | return self.size
85 |
86 |
87 | # def get_loader(image_root,batchsize, trainsize, floder,shuffle=True, num_workers=8, pin_memory=True, augmentation=False):
88 | # train_image_index_all,test_image_index_all=create_k_fold_division(image_root)
89 | # image_index=train_image_index_all[floder]
90 | # test_index=test_image_index_all[floder]
91 | # dataset = isic2018Dataset(image_root, image_index, trainsize, augmentation)
92 | # data_loader = data.DataLoader(dataset=dataset,
93 | # batch_size=batchsize,
94 | # shuffle=shuffle,
95 | # num_workers=num_workers,
96 | # pin_memory=pin_memory)
97 |
98 | # testset=test_dataset(image_root,test_index,trainsize)
99 | # test_loader = data.DataLoader(dataset=testset,
100 | # batch_size=1,
101 | # shuffle=shuffle,
102 | # num_workers=num_workers,
103 | # pin_memory=pin_memory)
104 | # return data_loader,testset
105 |
106 |
107 | def get_loader(image_root,
108 | gt_root,
109 | batchsize,
110 | trainsize,
111 | floder,
112 | shuffle=True,
113 | num_workers=8,
114 | pin_memory=True,
115 | augmentation=False):
116 | js = json.load(open('utils/data_split.json'))
117 | # print(js)
118 | # train / test
119 | all_index = [f for f in os.listdir(image_root) if f.endswith('.jpg')]
120 |
121 | test_index = ['ISIC_' + i + '.jpg' for i in js[str(floder)]]
122 | image_index = list(filter(lambda x: x not in test_index, all_index))
123 | print(len(all_index), len(image_index), len(test_index))
124 |
125 | dataset = isic2018Dataset(image_root, gt_root, image_index, trainsize,
126 | augmentation)
127 | data_loader = data.DataLoader(dataset=dataset,
128 | batch_size=batchsize,
129 | shuffle=shuffle,
130 | num_workers=num_workers,
131 | pin_memory=pin_memory)
132 |
133 | testset = isic2018Dataset(image_root, gt_root, test_index, trainsize,
134 | False)
135 | test_loader = data.DataLoader(dataset=testset,
136 | batch_size=1,
137 | shuffle=shuffle,
138 | num_workers=num_workers,
139 | pin_memory=pin_memory)
140 | return data_loader, test_loader
141 |
142 |
143 | class test_dataset:
144 | def __init__(self, image_root, gt_root, test_index, testsize):
145 | self.testsize = testsize
146 | self.image_root = image_root
147 | self.gt_root = gt_root
148 | self.images = test_index
149 | self.transform = A.Compose(
150 | [A.Resize(self.testsize, self.testsize),
151 | ToTensorV2()])
152 | self.size = len(self.images)
153 | self.index = 0
154 |
155 | def __getitem__(self, idx):
156 | image = cv2.imread(os.path.join(self.image_root, self.images[idx]))
157 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
158 | # gt_ind=self.images[idx].split('_')
159 | # gt_name=gt_ind[0]+'_Task1_'+gt_ind[2]+'_GroundTruth/ISIC_'+gt_ind[4][:-4]+'_segmentation.png'
160 | gt_name = self.images[idx][:-4] + '_segmentation.png'
161 | gt_root = os.path.join(self.gt_root, gt_name)
162 | gt = cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE)
163 | pair = self.transform(image=image, mask=gt)
164 | name = self.images[idx].split('/')[-1]
165 | if name.endswith('.jpg'):
166 | name = name.split('.jpg')[0] + '.png'
167 | image = pair['image'].unsqueeze(0) / 255
168 | gt = pair['mask'] / 255
169 | return image, gt, name
170 |
171 | def rgb_loader(self, path):
172 | with open(path, 'rb') as f:
173 | img = Image.open(f)
174 | return img.convert('RGB')
175 |
176 | def binary_loader(self, path):
177 | with open(path, 'rb') as f:
178 | img = Image.open(f)
179 | return img.convert('L')
180 |
181 | def __len__(self):
182 | return self.size
183 |
184 |
185 | if __name__ == '__main__':
186 |     # smoke test: all five constructor arguments are required; the paths below are placeholders
187 |     isic2018Dataset('<image_root>', '<gt_root>', [], 352, False)
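
Note: a sketch of how the polar variant prepares a sample, assuming `centroid` returns the lesion centre and `to_polar` warps an array around that centre, as the import above suggests; the file paths are placeholders:

    import cv2
    from utils.polar_transformations import centroid, to_polar

    image = cv2.cvtColor(cv2.imread('<image>.jpg'), cv2.COLOR_BGR2RGB)
    gt = cv2.imread('<mask>.png', cv2.IMREAD_GRAYSCALE)

    center = centroid(gt)              # lesion centre estimated from the mask
    image_polar = to_polar(image, center)
    gt_polar = to_polar(gt, center)
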
--------------------------------------------------------------------------------
/lib/baseline.py:
--------------------------------------------------------------------------------
1 | # Baseline variant: PVT-v2 backbone with in-scale transformers and cascaded cross-scale boundary learners
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from lib.modules import xboundlearner, xboundlearnerv2
7 | from lib.vision_transformers import in_scale_transformer
8 |
9 | from lib.pvtv2 import pvt_v2_b2 #
10 |
11 |
12 | def _segm_pvtv2(num_classes, im_num, ex_num, xbound, trainsize):
13 | backbone = pvt_v2_b2(img_size=trainsize)
14 |
15 |     if 1:  # always initialise the backbone from the local PVT-v2 checkpoint
16 | path = 'pvt_v2_b2.pth'
17 | save_model = torch.load(path)
18 | model_dict = backbone.state_dict()
19 | state_dict = {
20 | k: v
21 | for k, v in save_model.items() if k in model_dict.keys()
22 | }
23 | model_dict.update(state_dict)
24 | backbone.load_state_dict(model_dict)
25 | classifier = _simple_classifier(num_classes)
26 | model = _SimpleSegmentationModel(backbone, classifier, im_num, ex_num)
27 | return model
28 |
29 |
30 | class _simple_classifier(nn.Module):
31 | def __init__(self, num_classes):
32 | super(_simple_classifier, self).__init__()
33 | self.classifier = nn.Sequential(
34 | nn.Conv2d(192, 64, 1, padding=1, bias=False), #560
35 | nn.BatchNorm2d(64),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(64, num_classes, 1))
38 | self.classifier1 = nn.Sequential(nn.Conv2d(128, num_classes, 1))
39 | self.classifier2 = nn.Sequential(nn.Conv2d(128, num_classes, 1))
40 | self.classifier3 = nn.Sequential(nn.Conv2d(128, num_classes, 1))
41 |
42 | def forward(self, feature):
43 | low_level_feature = feature[0]
44 | output_feature = feature[1]
45 | output_feature = F.interpolate(output_feature,
46 | size=low_level_feature.shape[2:],
47 | mode='bilinear',
48 | align_corners=False)
49 | if self.training:
50 | return [
51 | self.classifier(
52 | torch.cat([low_level_feature, output_feature], dim=1)),
53 | self.classifier1(feature[1]),
54 | self.classifier2(feature[2]),
55 | self.classifier3(feature[3])
56 | ]
57 | else:
58 | return self.classifier(
59 | torch.cat([low_level_feature, output_feature], dim=1))
60 |
61 |
62 | class _SimpleSegmentationModel(nn.Module):
63 | # general segmentation model
64 | def __init__(self, backbone, classifier, im_num, ex_num):
65 | super(_SimpleSegmentationModel, self).__init__()
66 | self.backbone = backbone
67 | self.classifier = classifier
68 | self.bat_low = _bound_learner(hidden_features=128,
69 | im_num=im_num,
70 | ex_num=ex_num)
71 |
72 | def forward(self, x):
73 | input_shape = x.shape[-2:]
74 | features = self.backbone(
75 | x
76 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) ([8, 512, 8, 8])
77 | features, point_pre1, point_pre2, point_pre3 = self.bat_low(features)
78 | outputs = self.classifier(features)
79 | if self.training:
80 | outputs = [
81 | F.interpolate(o,
82 | size=input_shape,
83 | mode='bilinear',
84 | align_corners=False) for o in outputs
85 | ]
86 | else:
87 | outputs = F.interpolate(outputs,
88 | size=input_shape,
89 | mode='bilinear',
90 | align_corners=False)
91 | return outputs, point_pre1, point_pre2, point_pre3
92 |
93 |
94 | class _bound_learner(nn.Module):
95 | def __init__(self, point_pred=1, hidden_features=128, im_num=2, ex_num=2):
96 |
97 | super().__init__()
98 |
99 | self.point_pred = point_pred
100 |
101 | self.convolution_mapping_1 = nn.Conv2d(in_channels=128,
102 | out_channels=hidden_features,
103 | kernel_size=(1, 1),
104 | stride=(1, 1),
105 | padding=(0, 0),
106 | bias=True)
107 | self.convolution_mapping_2 = nn.Conv2d(in_channels=320,
108 | out_channels=hidden_features,
109 | kernel_size=(1, 1),
110 | stride=(1, 1),
111 | padding=(0, 0),
112 | bias=True)
113 | self.convolution_mapping_3 = nn.Conv2d(in_channels=512,
114 | out_channels=hidden_features,
115 | kernel_size=(1, 1),
116 | stride=(1, 1),
117 | padding=(0, 0),
118 | bias=True)
119 | normalize_before = True
120 | self.im_ex_boud1 = in_scale_transformer(
121 | point_pred_layers=1,
122 | num_encoder_layers=im_num,
123 | num_decoder_layers=ex_num,
124 | d_model=hidden_features,
125 | nhead=8,
126 | normalize_before=normalize_before)
127 | self.im_ex_boud2 = in_scale_transformer(
128 | point_pred_layers=1,
129 | num_encoder_layers=im_num,
130 | num_decoder_layers=ex_num,
131 | d_model=hidden_features,
132 | nhead=8,
133 | normalize_before=normalize_before)
134 | self.im_ex_boud3 = in_scale_transformer(
135 | point_pred_layers=1,
136 | num_encoder_layers=im_num,
137 | num_decoder_layers=ex_num,
138 | d_model=hidden_features,
139 | nhead=8,
140 | normalize_before=normalize_before)
141 | # self.cross_attention_3_1 = xboundlearner(hidden_features, 8)
142 | # self.cross_attention_3_2 = xboundlearner(hidden_features, 8)
143 |
144 | self.cross_attention_3_1 = xboundlearnerv2(hidden_features, 8)
145 | self.cross_attention_3_2 = xboundlearnerv2(hidden_features, 8)
146 |
147 | self.trans_out_conv = nn.Conv2d(hidden_features * 2, 512, 1, 1) #
148 |
149 | def forward(self, x):
150 | # for tmp in x:
151 | # print(tmp.size())
152 | features_1 = x[1]
153 | features_2 = x[2]
154 | features_3 = x[3]
155 | features_1 = self.convolution_mapping_1(features_1)
156 | features_2 = self.convolution_mapping_2(features_2)
157 | features_3 = self.convolution_mapping_3(features_3)
158 |
159 | # in-scale attention
160 | latent_tensor_1, features_encoded_1, point_maps_1 = self.im_ex_boud1(
161 | features_1)
162 |
163 | latent_tensor_2, features_encoded_2, point_maps_2 = self.im_ex_boud2(
164 | features_2)
165 |
166 | latent_tensor_3, features_encoded_3, point_maps_3 = self.im_ex_boud3(
167 | features_3)
168 |
169 | # cross-scale attention6
170 | latent_tensor_1 = latent_tensor_1.permute(2, 0, 1)
171 | latent_tensor_2 = latent_tensor_2.permute(2, 0, 1)
172 | latent_tensor_3 = latent_tensor_3.permute(2, 0, 1)
173 |
174 | # ''' point map Upsample '''
175 | features_encoded_2_2 = self.cross_attention_3_2(
176 | features_encoded_2, features_encoded_3, latent_tensor_2,
177 | latent_tensor_3)
178 | features_encoded_1_2 = self.cross_attention_3_1(
179 | features_encoded_1, features_encoded_2_2, latent_tensor_1,
180 | latent_tensor_2)
181 |
182 | # trans_feature_maps = self.trans_out_conv(
183 | # torch.cat([features_encoded_3_1, features_encoded_3_2], dim=1))
184 |
185 | # x[3] = trans_feature_maps
186 | # x[2] = torch.cat([x[2], features_encoded_2], dim=1)
187 | # x[1] = torch.cat([x[1], features_encoded_1], dim=1)
188 |
189 | features_stage2 = [
190 | x[0], features_encoded_1_2, features_encoded_2_2,
191 | features_encoded_3
192 | ]
193 |
194 | return features_stage2, point_maps_1, point_maps_2, point_maps_3 #
195 |
196 |
197 | if __name__ == '__main__':
198 |     import os
199 |     os.environ['CUDA_VISIBLE_DEVICES'] = '4'
200 |     # illustrative arguments: num_classes=1, im_num=2, ex_num=2, xbound (unused by this builder), trainsize=352
201 |     model = _segm_pvtv2(1, 2, 2, True, 352).cuda()
202 |     input_tensor = torch.randn(1, 3, 352, 352).cuda()
203 |     prediction1 = model(input_tensor)
204 |
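
Note: `_simple_classifier` returns a list of deeply-supervised logits while the module is in training mode and a single tensor in eval mode, so calling code has to branch on `model.training`. A minimal sketch of both paths (hyper-parameters are illustrative; the constructor also expects `pvt_v2_b2.pth` in the working directory):

    import torch
    from lib.baseline import _segm_pvtv2

    model = _segm_pvtv2(1, 2, 2, True, 352)   # illustrative hyper-parameters
    x = torch.randn(2, 3, 352, 352)

    model.train()
    outputs, p1, p2, p3 = model(x)            # list of deeply-supervised logits + three point-map outputs
    print(len(outputs), outputs[0].shape)

    model.eval()
    with torch.no_grad():
        output, *_ = model(x)                 # single full-resolution logits map
        print(output.shape)
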
--------------------------------------------------------------------------------
/lib/pvt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from lib.pvtv2 import pvt_v2_b2 #
5 | import os
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class BasicConv2d(nn.Module):
12 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
13 | super(BasicConv2d, self).__init__()
14 |
15 | self.conv = nn.Conv2d(in_planes, out_planes,
16 | kernel_size=kernel_size, stride=stride,
17 | padding=padding, dilation=dilation, bias=False)
18 | self.bn = nn.BatchNorm2d(out_planes)
19 | self.relu = nn.ReLU(inplace=True)
20 |
21 | def forward(self, x):
22 | x = self.conv(x)
23 | x = self.bn(x)
24 |         return x  # note: self.relu is defined above but not applied in this forward
25 |
26 |
27 | class CFM(nn.Module):
28 | def __init__(self, channel):
29 | super(CFM, self).__init__()
30 | self.relu = nn.ReLU(True)
31 |
32 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
33 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
34 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
35 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
36 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
37 | self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
38 |
39 | self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
40 | self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
41 | self.conv4 = BasicConv2d(3 * channel, channel, 3, padding=1)
42 |
43 | def forward(self, x1, x2, x3):
44 | x1_1 = x1
45 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
46 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
47 | * self.conv_upsample3(self.upsample(x2)) * x3
48 |
49 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
50 | x2_2 = self.conv_concat2(x2_2)
51 |
52 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
53 | x3_2 = self.conv_concat3(x3_2)
54 |
55 | x1 = self.conv4(x3_2)
56 |
57 | return x1
58 |
59 |
60 |
61 |
62 | class GCN(nn.Module):
63 | def __init__(self, num_state, num_node, bias=False):
64 | super(GCN, self).__init__()
65 | self.conv1 = nn.Conv1d(num_node, num_node, kernel_size=1)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, bias=bias)
68 |
69 | def forward(self, x):
70 | h = self.conv1(x.permute(0, 2, 1)).permute(0, 2, 1)
71 | h = h - x
72 | h = self.relu(self.conv2(h))
73 | return h
74 |
75 |
76 | class SAM(nn.Module):
77 | def __init__(self, num_in=32, plane_mid=16, mids=4, normalize=False):
78 | super(SAM, self).__init__()
79 |
80 | self.normalize = normalize
81 | self.num_s = int(plane_mid)
82 | self.num_n = (mids) * (mids)
83 | self.priors = nn.AdaptiveAvgPool2d(output_size=(mids + 2, mids + 2))
84 |
85 | self.conv_state = nn.Conv2d(num_in, self.num_s, kernel_size=1)
86 | self.conv_proj = nn.Conv2d(num_in, self.num_s, kernel_size=1)
87 | self.gcn = GCN(num_state=self.num_s, num_node=self.num_n)
88 | self.conv_extend = nn.Conv2d(self.num_s, num_in, kernel_size=1, bias=False)
89 |
90 | def forward(self, x, edge):
91 | edge = F.interpolate(edge, (x.size()[-2], x.size()[-1]))
92 |
93 | n, c, h, w = x.size()
94 | edge = torch.nn.functional.softmax(edge, dim=1)[:, 1, :, :].unsqueeze(1)
95 |
96 | x_state_reshaped = self.conv_state(x).view(n, self.num_s, -1)
97 | x_proj = self.conv_proj(x)
98 | x_mask = x_proj * edge
99 |
100 | x_anchor1 = self.priors(x_mask)
101 | x_anchor2 = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1)
102 | x_anchor = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1)
103 |
104 | x_proj_reshaped = torch.matmul(x_anchor.permute(0, 2, 1), x_proj.reshape(n, self.num_s, -1))
105 | x_proj_reshaped = torch.nn.functional.softmax(x_proj_reshaped, dim=1)
106 |
107 | x_rproj_reshaped = x_proj_reshaped
108 |
109 | x_n_state = torch.matmul(x_state_reshaped, x_proj_reshaped.permute(0, 2, 1))
110 | if self.normalize:
111 | x_n_state = x_n_state * (1. / x_state_reshaped.size(2))
112 | x_n_rel = self.gcn(x_n_state)
113 |
114 | x_state_reshaped = torch.matmul(x_n_rel, x_rproj_reshaped)
115 | x_state = x_state_reshaped.view(n, self.num_s, *x.size()[2:])
116 | out = x + (self.conv_extend(x_state))
117 |
118 | return out
119 |
120 |
121 | class ChannelAttention(nn.Module):
122 | def __init__(self, in_planes, ratio=16):
123 | super(ChannelAttention, self).__init__()
124 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
125 | self.max_pool = nn.AdaptiveMaxPool2d(1)
126 |
127 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
128 | self.relu1 = nn.ReLU()
129 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
130 |
131 | self.sigmoid = nn.Sigmoid()
132 |
133 | def forward(self, x):
134 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
135 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
136 | out = avg_out + max_out
137 | return self.sigmoid(out)
138 |
139 |
140 | class SpatialAttention(nn.Module):
141 | def __init__(self, kernel_size=7):
142 | super(SpatialAttention, self).__init__()
143 |
144 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
145 | padding = 3 if kernel_size == 7 else 1
146 |
147 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
148 | self.sigmoid = nn.Sigmoid()
149 |
150 | def forward(self, x):
151 | avg_out = torch.mean(x, dim=1, keepdim=True)
152 | max_out, _ = torch.max(x, dim=1, keepdim=True)
153 | x = torch.cat([avg_out, max_out], dim=1)
154 | x = self.conv1(x)
155 | return self.sigmoid(x)
156 |
157 |
158 | class PolypPVT(nn.Module):
159 | def __init__(self, channel=32):
160 | super(PolypPVT, self).__init__()
161 |
162 | self.backbone = pvt_v2_b2() # [64, 128, 320, 512]
163 | path = './pretrained_pth/pvt_v2_b2.pth'
164 | save_model = torch.load(path)
165 | model_dict = self.backbone.state_dict()
166 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
167 | model_dict.update(state_dict)
168 | self.backbone.load_state_dict(model_dict)
169 |
170 | self.Translayer2_0 = BasicConv2d(64, channel, 1)
171 | self.Translayer2_1 = BasicConv2d(128, channel, 1)
172 | self.Translayer3_1 = BasicConv2d(320, channel, 1)
173 | self.Translayer4_1 = BasicConv2d(512, channel, 1)
174 |
175 | self.CFM = CFM(channel)
176 | self.ca = ChannelAttention(64)
177 | self.sa = SpatialAttention()
178 | self.SAM = SAM()
179 |
180 | self.down05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
181 | self.out_SAM = nn.Conv2d(channel, 1, 1)
182 | self.out_CFM = nn.Conv2d(channel, 1, 1)
183 |
184 |
185 | def forward(self, x):
186 |
187 | # backbone
188 | pvt = self.backbone(x)
189 | x1 = pvt[0]
190 | x2 = pvt[1]
191 | x3 = pvt[2]
192 | x4 = pvt[3]
193 |
194 | # CIM
195 | x1 = self.ca(x1) * x1 # channel attention
196 | cim_feature = self.sa(x1) * x1 # spatial attention
197 |
198 |
199 | # CFM
200 | x2_t = self.Translayer2_1(x2)
201 | x3_t = self.Translayer3_1(x3)
202 | x4_t = self.Translayer4_1(x4)
203 | cfm_feature = self.CFM(x4_t, x3_t, x2_t)
204 |
205 | # SAM
206 | T2 = self.Translayer2_0(cim_feature)
207 | T2 = self.down05(T2)
208 | sam_feature = self.SAM(cfm_feature, T2)
209 |
210 | prediction1 = self.out_CFM(cfm_feature)
211 | prediction2 = self.out_SAM(sam_feature)
212 |
213 | prediction1_8 = F.interpolate(prediction1, scale_factor=8, mode='bilinear')
214 | prediction2_8 = F.interpolate(prediction2, scale_factor=8, mode='bilinear')
215 | return prediction1_8, prediction2_8
216 |
217 |
218 | if __name__ == '__main__':
219 |     import os
220 |     os.environ['CUDA_VISIBLE_DEVICES'] = '4'
221 |     model = PolypPVT().cuda()
222 |     input_tensor = torch.randn(1, 3, 352, 352).cuda()
223 | 
224 |     prediction1, prediction2 = model(input_tensor)
225 |     print(prediction1.size(), prediction2.size())
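
Note: the CIM step is the usual channel-then-spatial attention composition; both modules work stand-alone. A small sketch, assuming the repository root is on `PYTHONPATH`:

    import torch
    from lib.pvt import ChannelAttention, SpatialAttention

    x = torch.randn(2, 64, 88, 88)
    ca, sa = ChannelAttention(64), SpatialAttention()
    x = ca(x) * x      # reweight channels
    x = sa(x) * x      # reweight spatial positions
    print(x.shape)     # torch.Size([2, 64, 88, 88])
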
--------------------------------------------------------------------------------
/lib/xboundformer.py:
--------------------------------------------------------------------------------
1 | # xBoundFormer: PVT-v2 backbone with in-scale boundary transformers and cross-scale boundary learners (with ablation switches)
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from lib.modules import xboundlearner, xboundlearnerv2, _simple_learner
7 | from lib.vision_transformers import in_scale_transformer
8 |
9 | from lib.pvtv2 import pvt_v2_b2 #
10 |
11 |
12 | def _segm_pvtv2(num_classes, im_num, ex_num, xbound, trainsize):
13 | backbone = pvt_v2_b2(img_size=trainsize)
14 |
15 |     if 1:  # always initialise the backbone from the local PVT-v2 checkpoint
16 | path = 'pvt_v2_b2.pth'
17 | save_model = torch.load(path)
18 | model_dict = backbone.state_dict()
19 | state_dict = {
20 | k: v
21 | for k, v in save_model.items() if k in model_dict.keys()
22 | }
23 | model_dict.update(state_dict)
24 | backbone.load_state_dict(model_dict)
25 | classifier = _simple_classifier(num_classes)
26 | model = _SimpleSegmentationModel(backbone, classifier, im_num, ex_num,
27 | xbound)
28 | return model
29 |
30 |
31 | class _simple_classifier(nn.Module):
32 | def __init__(self, num_classes):
33 | super(_simple_classifier, self).__init__()
34 | self.classifier = nn.Sequential(
35 | nn.Conv2d(192, 64, 1, padding=1, bias=False), #560
36 | nn.BatchNorm2d(64),
37 | nn.ReLU(inplace=True),
38 | nn.Conv2d(64, num_classes, 1))
39 | self.classifier1 = nn.Sequential(nn.Conv2d(128, num_classes, 1))
40 | self.classifier2 = nn.Sequential(nn.Conv2d(128, num_classes, 1))
41 | self.classifier3 = nn.Sequential(nn.Conv2d(128, num_classes, 1))
42 |
43 | def forward(self, feature):
44 | low_level_feature = feature[0]
45 | output_feature = feature[1]
46 | output_feature = F.interpolate(output_feature,
47 | size=low_level_feature.shape[2:],
48 | mode='bilinear',
49 | align_corners=False)
50 | if self.training:
51 | return [
52 | self.classifier(
53 | torch.cat([low_level_feature, output_feature], dim=1)),
54 | self.classifier1(feature[1]),
55 | self.classifier2(feature[2]),
56 | self.classifier3(feature[3])
57 | ]
58 | else:
59 | return self.classifier(
60 | torch.cat([low_level_feature, output_feature], dim=1))
61 |
62 |
63 | class _SimpleSegmentationModel(nn.Module):
64 | # general segmentation model
65 | def __init__(self, backbone, classifier, im_num, ex_num, xbound):
66 | super(_SimpleSegmentationModel, self).__init__()
67 | self.backbone = backbone
68 | self.classifier = classifier
69 | self.bat_low = _bound_learner(hidden_features=128,
70 | im_num=im_num,
71 | ex_num=ex_num,
72 | xbound=xbound)
73 |
74 | def forward(self, x):
75 | input_shape = x.shape[-2:]
76 | features = self.backbone(
77 | x
78 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) ([8, 512, 8, 8])
79 | features, point_pre1, point_pre2, point_pre3 = self.bat_low(features)
80 | outputs = self.classifier(features)
81 | if self.training:
82 | outputs = [
83 | F.interpolate(o,
84 | size=input_shape,
85 | mode='bilinear',
86 | align_corners=False) for o in outputs
87 | ]
88 | else:
89 | outputs = F.interpolate(outputs,
90 | size=input_shape,
91 | mode='bilinear',
92 | align_corners=False)
93 | return outputs, point_pre1, point_pre2, point_pre3
94 |
95 |
96 | class _bound_learner(nn.Module):
97 | def __init__(self,
98 | point_pred=1,
99 | hidden_features=128,
100 | im_num=2,
101 | ex_num=2,
102 | xbound=True):
103 |
104 | super().__init__()
105 | self.im_num = im_num
106 | self.ex_num = ex_num
107 |
108 | self.point_pred = point_pred
109 |
110 | self.convolution_mapping_1 = nn.Conv2d(in_channels=128,
111 | out_channels=hidden_features,
112 | kernel_size=(1, 1),
113 | stride=(1, 1),
114 | padding=(0, 0),
115 | bias=True)
116 | self.convolution_mapping_2 = nn.Conv2d(in_channels=320,
117 | out_channels=hidden_features,
118 | kernel_size=(1, 1),
119 | stride=(1, 1),
120 | padding=(0, 0),
121 | bias=True)
122 | self.convolution_mapping_3 = nn.Conv2d(in_channels=512,
123 | out_channels=hidden_features,
124 | kernel_size=(1, 1),
125 | stride=(1, 1),
126 | padding=(0, 0),
127 | bias=True)
128 | normalize_before = True
129 |
130 | if im_num + ex_num > 0:
131 | self.im_ex_boud1 = in_scale_transformer(
132 | point_pred_layers=1,
133 | num_encoder_layers=im_num,
134 | num_decoder_layers=ex_num,
135 | d_model=hidden_features,
136 | nhead=8,
137 | normalize_before=normalize_before)
138 | self.im_ex_boud2 = in_scale_transformer(
139 | point_pred_layers=1,
140 | num_encoder_layers=im_num,
141 | num_decoder_layers=ex_num,
142 | d_model=hidden_features,
143 | nhead=8,
144 | normalize_before=normalize_before)
145 | self.im_ex_boud3 = in_scale_transformer(
146 | point_pred_layers=1,
147 | num_encoder_layers=im_num,
148 | num_decoder_layers=ex_num,
149 | d_model=hidden_features,
150 | nhead=8,
151 | normalize_before=normalize_before)
152 | # self.cross_attention_3_1 = xboundlearner(hidden_features, 8)
153 | # self.cross_attention_3_2 = xboundlearner(hidden_features, 8)
154 |
155 | self.xbound = xbound
156 | if xbound:
157 | self.cross_attention_3_1 = xboundlearnerv2(hidden_features, 8)
158 | self.cross_attention_3_2 = xboundlearnerv2(hidden_features, 8)
159 | else:
160 | self.cross_attention_3_1 = _simple_learner(hidden_features)
161 | self.cross_attention_3_2 = _simple_learner(hidden_features)
162 |
163 | self.trans_out_conv = nn.Conv2d(hidden_features * 2, 512, 1, 1) #
164 |
165 | def forward(self, x):
166 | # for tmp in x:
167 | # print(tmp.size())
168 | features_1 = x[1]
169 | features_2 = x[2]
170 | features_3 = x[3]
171 | features_1 = self.convolution_mapping_1(features_1)
172 | features_2 = self.convolution_mapping_2(features_2)
173 | features_3 = self.convolution_mapping_3(features_3)
174 |
175 | # in-scale attention
176 | if self.im_num + self.ex_num > 0:
177 | latent_tensor_1, features_encoded_1, point_maps_1 = self.im_ex_boud1(
178 | features_1)
179 |
180 | latent_tensor_2, features_encoded_2, point_maps_2 = self.im_ex_boud2(
181 | features_2)
182 |
183 | latent_tensor_3, features_encoded_3, point_maps_3 = self.im_ex_boud3(
184 | features_3)
185 |
186 | # cross-scale attention6
187 | if self.ex_num > 0:
188 | latent_tensor_1 = latent_tensor_1.permute(2, 0, 1)
189 | latent_tensor_2 = latent_tensor_2.permute(2, 0, 1)
190 | latent_tensor_3 = latent_tensor_3.permute(2, 0, 1)
191 |
192 | else:
193 | features_encoded_1 = features_1
194 | features_encoded_2 = features_2
195 | features_encoded_3 = features_3
196 |
197 | # ''' point map Upsample '''
198 | if self.xbound:
199 | features_encoded_2_2 = self.cross_attention_3_2(
200 | features_encoded_2, features_encoded_3, latent_tensor_2,
201 | latent_tensor_3)
202 | features_encoded_1_2 = self.cross_attention_3_1(
203 | features_encoded_1, features_encoded_2_2, latent_tensor_1,
204 | latent_tensor_2)
205 | else:
206 | features_encoded_2_2 = self.cross_attention_3_2(
207 | features_encoded_2, features_encoded_3)
208 | features_encoded_1_2 = self.cross_attention_3_1(
209 | features_encoded_1, features_encoded_2_2)
210 |
211 | # trans_feature_maps = self.trans_out_conv(
212 | # torch.cat([features_encoded_3_1, features_encoded_3_2], dim=1))
213 |
214 | # x[3] = trans_feature_maps
215 | # x[2] = torch.cat([x[2], features_encoded_2], dim=1)
216 | # x[1] = torch.cat([x[1], features_encoded_1], dim=1)
217 |
218 | features_stage2 = [
219 | x[0], features_encoded_1_2, features_encoded_2_2,
220 | features_encoded_3
221 | ]
222 |
223 | if self.im_num + self.ex_num > 0:
224 | return features_stage2, point_maps_1, point_maps_2, point_maps_3 #
225 | else:
226 | return features_stage2, None, None, None #
227 |
228 |
229 | if __name__ == '__main__':
230 |     import os
231 |     os.environ['CUDA_VISIBLE_DEVICES'] = '4'
232 |     # illustrative arguments matching the train.py defaults: im_num=1, ex_num=1, xbound=1, trainsize=352
233 |     model = _segm_pvtv2(1, 1, 1, True, 352).cuda()
234 |     input_tensor = torch.randn(1, 3, 352, 352).cuda()
235 |     prediction1 = model(input_tensor)
236 |
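
Note: `im_num`, `ex_num` and `xbound` act as ablation switches: with `im_num + ex_num == 0` the in-scale transformers are skipped and the point maps are returned as `None`, and with `xbound=False` the cross-scale step falls back to `_simple_learner`. A sketch of the three configurations (argument values are illustrative; `pvt_v2_b2.pth` must be present):

    from lib.xboundformer import _segm_pvtv2

    full_model     = _segm_pvtv2(1, 1, 1, True, 352)   # in-scale transformers + cross-boundary attention
    no_xbound      = _segm_pvtv2(1, 1, 1, False, 352)  # cross-scale fusion via _simple_learner instead
    no_transformer = _segm_pvtv2(1, 0, 0, False, 352)  # PVT-v2 features only; point maps come back as None
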
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os, argparse, math
2 | import numpy as np
3 | from glob import glob
4 | from tqdm import tqdm
5 | import sys
6 |
7 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.data
11 | import torch.optim as optim
12 |
13 | from medpy.metric.binary import hd, dc, assd, jc
14 |
15 | from skimage import segmentation as skimage_seg
16 | from scipy.ndimage import distance_transform_edt as distance
17 | from torch.utils.tensorboard import SummaryWriter
18 |
19 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, CosineAnnealingLR
20 | import time
21 |
22 |
23 | def get_cfg():
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--arch', type=str, default='xboundformer')
26 | parser.add_argument('--gpu', type=str, default='4')
27 | parser.add_argument('--net_layer', type=int, default=50)
28 | parser.add_argument('--dataset', type=str, default='isic2016')
29 | parser.add_argument('--exp_name', type=str, default='test')
30 | parser.add_argument('--fold', type=str)
31 | parser.add_argument('--lr_seg', type=float, default=1e-4) #0.0003
32 | parser.add_argument('--n_epochs', type=int, default=150) #100
33 | parser.add_argument('--bt_size', type=int, default=12) #36
34 | parser.add_argument('--seg_loss', type=int, default=1, choices=[0, 1])
35 | parser.add_argument('--aug', type=int, default=1)
36 | parser.add_argument('--patience', type=int, default=500) #50
37 |
38 | # transformer
39 | parser.add_argument('--filter', type=int, default=0)
40 | parser.add_argument('--im_num', type=int, default=1)
41 | parser.add_argument('--ex_num', type=int, default=1)
42 | parser.add_argument('--xbound', type=int, default=1)
43 | parser.add_argument('--point_w', type=float, default=1)
44 |
45 | #log_dir name
46 | parser.add_argument('--folder_name', type=str, default='Default_folder')
47 |
48 | parse_config = parser.parse_args()
49 | print(parse_config)
50 | return parse_config
51 |
52 |
53 | def ce_loss(pred, gt):
54 | pred = torch.clamp(pred, 1e-6, 1 - 1e-6)
55 | return (-gt * torch.log(pred) - (1 - gt) * torch.log(1 - pred)).mean()
56 |
57 |
58 | def structure_loss(pred, mask):
59 | """ TransFuse train loss """
60 | """ Without sigmoid """
61 | weit = 1 + 5 * torch.abs(
62 | F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
63 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
64 | wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
65 |
66 | pred = torch.sigmoid(pred)
67 | inter = ((pred * mask) * weit).sum(dim=(2, 3))
68 | union = ((pred + mask) * weit).sum(dim=(2, 3))
69 | wiou = 1 - (inter + 1) / (union - inter + 1)
70 | return (wbce + wiou).mean()
71 |
72 |
73 | #-------------------------- train func --------------------------#
74 | def train(epoch):
75 | model.train()
76 | iteration = 0
77 | for batch_idx, batch_data in enumerate(train_loader):
78 | # print(epoch, batch_idx)
79 | data = batch_data['image'].cuda().float()
80 | label = batch_data['label'].cuda().float()
81 | if parse_config.filter:
82 | point = (batch_data['filter_point_data'] > 0).cuda().float()
83 | else:
84 | point = (batch_data['point'] > 0).cuda().float()
85 | #point_All = (batch_data['point_All'] > 0).cuda().float()
86 |
87 | if parse_config.arch == 'transfuse':
88 | lateral_map_4, lateral_map_3, lateral_map_2 = model(data)
89 |
90 | loss4 = structure_loss(lateral_map_4, label)
91 | loss3 = structure_loss(lateral_map_3, label)
92 | loss2 = structure_loss(lateral_map_2, label)
93 |
94 | loss = 0.5 * loss2 + 0.3 * loss3 + 0.2 * loss4
95 |
96 | optimizer.zero_grad()
97 | loss.backward()
98 | optimizer.step()
99 | if (batch_idx + 1) % 10 == 0:
100 | print(
101 | 'Train Epoch: {} [{}/{} ({:.0f}%)]\t[lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}]'
102 | .format(epoch, batch_idx * len(data),
103 | len(train_loader.dataset),
104 | 100. * batch_idx / len(train_loader), loss2.item(),
105 | loss3.item(), loss4.item()))
106 |
107 | else:
108 | P2, point_maps_pre, point_maps_pre1, point_maps_pre2 = model(data)
109 | if parse_config.im_num + parse_config.ex_num > 0:
110 | point_loss = 0.0
111 | point3 = F.max_pool2d(point, (32, 32), (32, 32))
112 | point2 = F.max_pool2d(point, (16, 16), (16, 16))
113 | point1 = F.max_pool2d(point, (8, 8), (8, 8))
114 |
115 | for point_pre, point_pre1, point_pre2 in zip(
116 | point_maps_pre, point_maps_pre1, point_maps_pre2):
117 | point_loss = point_loss + criteon(
118 | point_pre, point1) + criteon(
119 | point_pre1, point2) + criteon(point_pre2, point3)
120 | point_loss = point_loss / (3 * len(point_maps_pre1))
121 | seg_loss = 0.0
122 | for p in P2:
123 | seg_loss = seg_loss + structure_loss(p, label)
124 | seg_loss = seg_loss / len(P2)
125 | loss = seg_loss + parse_config.point_w * point_loss
126 |
127 | if batch_idx % 50 == 0:
128 | show_image = [label[0], F.sigmoid(P2[0][0])]
129 | for point_map in [
130 | point_maps_pre, point_maps_pre1, point_maps_pre2
131 | ]:
132 | tmp = F.interpolate(point_map[-1], size=(352, 352))[0]
133 | show_image.append(tmp)
134 | show_image = torch.cat(show_image, dim=2)
135 | show_image = show_image.repeat(3, 1, 1)
136 | show_image = torch.cat([data[0], show_image], dim=2)
137 |
138 | writer.add_image('pred/all',
139 | show_image,
140 | epoch * len(train_loader) + batch_idx,
141 | dataformats='CHW')
142 |
143 | else:
144 | point_loss = 0.0
145 | seg_loss = 0.0
146 | for p in P2:
147 | seg_loss = seg_loss + structure_loss(p, label)
148 | seg_loss = seg_loss / len(P2)
149 | loss = seg_loss + 2 * point_loss
150 |
151 | optimizer.zero_grad()
152 | loss.backward()
153 | optimizer.step()
154 | if (batch_idx + 1) % 10 == 0:
155 | print(
156 |                 'Train Epoch: {} [{}/{} ({:.0f}%)]\t[total: {:.4f}, seg: {:0.4f}, point: {:0.4f}]'
157 | .format(epoch, batch_idx * len(data),
158 | len(train_loader.dataset),
159 | 100. * batch_idx / len(train_loader), loss,
160 | seg_loss, point_loss))
161 |
162 | print("Iteration numbers: ", iteration)
163 |
164 |
165 | #-------------------------- eval func --------------------------#
166 | def evaluation(epoch, loader):
167 | model.eval()
168 | dice_value = 0
169 | iou_value = 0
170 | dice_average = 0
171 | iou_average = 0
172 | numm = 0
173 | for batch_idx, batch_data in enumerate(loader):
174 | data = batch_data['image'].cuda().float()
175 | label = batch_data['label'].cuda().float()
176 | point = (batch_data['point'] > 0).cuda().float()
177 | # point_All = (batch_data['point_data'] > 0).cuda().float()
178 | #point_All = nn.functional.max_pool2d(point_All,
179 | # kernel_size=(16, 16),
180 | # stride=(16, 16))
181 |
182 | with torch.no_grad():
183 | if parse_config.arch == 'transfuse':
184 | _, _, output = model(data)
185 | loss_fuse = structure_loss(output, label)
186 | elif parse_config.arch == 'xboundformer':
187 | output, point_maps_pre, point_maps_pre1, point_maps_pre2 = model(
188 | data)
189 | loss = 0
190 |
191 | if parse_config.arch == 'transfuse':
192 | loss = loss_fuse
193 |
194 | output = output.cpu().numpy() > 0.5
195 |
196 | label = label.cpu().numpy()
197 | assert (output.shape == label.shape)
198 | dice_ave = dc(output, label)
199 | iou_ave = jc(output, label)
200 | dice_value += dice_ave
201 | iou_value += iou_ave
202 | numm += 1
203 |
204 | dice_average = dice_value / numm
205 | iou_average = iou_value / numm
206 | writer.add_scalar('val_metrics/val_dice', dice_average, epoch)
207 | writer.add_scalar('val_metrics/val_iou', iou_average, epoch)
208 | print("Average dice value of evaluation dataset = ", dice_average)
209 | print("Average iou value of evaluation dataset = ", iou_average)
210 | return dice_average, iou_average, loss
211 |
212 |
213 | if __name__ == '__main__':
214 | #-------------------------- get args --------------------------#
215 | parse_config = get_cfg()
216 |
217 | #-------------------------- build loggers and savers --------------------------#
218 | exp_name = parse_config.dataset + '/' + parse_config.exp_name + '_loss_' + str(
219 | parse_config.seg_loss) + '_aug_' + str(
220 | parse_config.aug
221 | ) + '/' + parse_config.folder_name + '/fold_' + str(parse_config.fold)
222 |
223 | os.makedirs('logs/{}'.format(exp_name), exist_ok=True)
224 | os.makedirs('logs/{}/model'.format(exp_name), exist_ok=True)
225 | writer = SummaryWriter('logs/{}/log'.format(exp_name))
226 | save_path = 'logs/{}/model/best.pkl'.format(exp_name)
227 | latest_path = 'logs/{}/model/latest.pkl'.format(exp_name)
228 |
229 | EPOCHS = parse_config.n_epochs
230 | os.environ['CUDA_VISIBLE_DEVICES'] = parse_config.gpu
231 | device_ids = range(torch.cuda.device_count())
232 |
233 | #-------------------------- build dataloaders --------------------------#
234 | if parse_config.dataset == 'isic2018':
235 | from utils.isbi2018_new import norm01, myDataset
236 |
237 | dataset = myDataset(fold=parse_config.fold,
238 | split='train',
239 | aug=parse_config.aug)
240 | dataset2 = myDataset(fold=parse_config.fold, split='valid', aug=False)
241 | elif parse_config.dataset == 'isic2016':
242 | from utils.isbi2016_new import norm01, myDataset
243 |
244 | dataset = myDataset(split='train', aug=parse_config.aug)
245 | dataset2 = myDataset(split='valid', aug=False)
246 | else:
247 | raise NotImplementedError
248 |
249 | train_loader = torch.utils.data.DataLoader(dataset,
250 | batch_size=parse_config.bt_size,
251 | shuffle=True,
252 | num_workers=2,
253 | pin_memory=True,
254 | drop_last=True)
255 | val_loader = torch.utils.data.DataLoader(
256 | dataset2,
257 | batch_size=1, #parse_config.bt_size
258 | shuffle=False, #True
259 | num_workers=2,
260 | pin_memory=True,
261 | drop_last=False) #True
262 |
263 | #-------------------------- build models --------------------------#
264 |     if parse_config.arch == 'xboundformer':  # string comparison must use ==, not `is`
265 | from lib.xboundformer import _segm_pvtv2
266 | model = _segm_pvtv2(1, parse_config.im_num, parse_config.ex_num,
267 | parse_config.xbound, 352).cuda()
268 | elif parse_config.arch == 'transfuse':
269 | from lib.TransFuse.TransFuse import TransFuse_S
270 | model = TransFuse_S(pretrained=True).cuda()
271 |
272 |     if len(device_ids) > 1:  # multi-GPU training
273 | model = torch.nn.DataParallel(model).cuda()
274 |
275 | optimizer = torch.optim.Adam(model.parameters(), lr=parse_config.lr_seg)
276 |
277 | #scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
278 | scheduler = CosineAnnealingLR(optimizer, T_max=20)
279 |
280 |     criteon = [None, ce_loss][parse_config.seg_loss]  # seg_loss=1 selects ce_loss for the point-map supervision
281 |
282 | #-------------------------- start training --------------------------#
283 |
284 | max_dice = 0
285 | max_iou = 0
286 | best_ep = 0
287 |
288 | min_loss = 10
289 | min_epoch = 0
290 |
291 | for epoch in range(1, EPOCHS + 1):
292 | print(optimizer.state_dict()['param_groups'][0]['lr'])
293 | start = time.time()
294 | train(epoch)
295 | dice, iou, loss = evaluation(epoch, val_loader)
296 | scheduler.step()
297 |
298 | if loss < min_loss:
299 | min_epoch = epoch
300 | min_loss = loss
301 | else:
302 | if epoch - min_epoch >= parse_config.patience:
303 | print('Early stopping!')
304 | break
305 | if iou > max_iou:
306 | max_iou = iou
307 | best_ep = epoch
308 | torch.save(model.state_dict(), save_path)
309 | else:
310 | if epoch - best_ep >= parse_config.patience:
311 | print('Early stopping!')
312 | break
313 | torch.save(model.state_dict(), latest_path)
314 | time_elapsed = time.time() - start
315 | print(
316 | 'Training and evaluating on epoch:{} complete in {:.0f}m {:.0f}s'.
317 | format(epoch, time_elapsed // 60, time_elapsed % 60))
318 |
--------------------------------------------------------------------------------
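The evaluation loop above averages per-image Dice and Jaccard (IoU) scores returned by dc and jc, which appear to be the binary-metric helpers from medpy. As a hedged, self-contained NumPy sketch (not part of the repo) of what those two metrics compute on masks thresholded at 0.5, the same threshold used in the loop:

import numpy as np

def dice_coefficient(pred, target, eps=1e-7):
    # Dice = 2 * |A ∩ B| / (|A| + |B|) on boolean masks
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def jaccard_index(pred, target, eps=1e-7):
    # IoU = |A ∩ B| / |A ∪ B| on boolean masks
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return (inter + eps) / (union + eps)

# hypothetical usage mirroring the evaluation loop: threshold, then score per image
pred = np.random.rand(1, 1, 352, 352) > 0.5
gt = np.random.rand(1, 1, 352, 352) > 0.5
print(dice_coefficient(pred, gt), jaccard_index(pred, gt))

These per-image scores are then averaged over the validation set, which matches the dice_value / numm and iou_value / numm accumulation in the loop.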
/lib/deeplabv3.py:
--------------------------------------------------------------------------------
1 | from lib.pvtv2 import pvt_v2_b2 #
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | # from lib.Vision_Transformer import detr_Transformer, detr_BA_Transformer
6 | from lib.modules import BoundaryCrossAttention
7 | from lib.transformer import BoundaryAwareTransformer
8 | from lib.replknet import RepLKBlock
9 |
10 |
11 | def _segm_pvtv2(name, backbone_name, num_classes, output_stride,
12 | pretrained_backbone):
13 |
14 | if output_stride == 8:
15 | aspp_dilate = [12, 24, 36]
16 | else:
17 | aspp_dilate = [6, 12, 18]
18 |
19 | backbone = pvt_v2_b2()
20 | if pretrained_backbone:
21 | path = './pretrained_pth/pvt_v2_b2.pth'
22 | save_model = torch.load(path)
23 | model_dict = backbone.state_dict()
24 | state_dict = {
25 | k: v
26 | for k, v in save_model.items() if k in model_dict.keys()
27 | }
28 | model_dict.update(state_dict)
29 | backbone.load_state_dict(model_dict)
30 |
31 | inplanes = 512
32 | low_level_planes = 64
33 |
34 | if name == 'deeplabv3plus':
35 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes,
36 | aspp_dilate)
37 |
38 | model = DeepLabV3(backbone, classifier)
39 | return model
40 |
41 |
42 | class DeepLabHeadV3Plus(nn.Module):
43 | def __init__(self,
44 | in_channels,
45 | low_level_channels,
46 | num_classes,
47 | aspp_dilate=[12, 24, 36]):
48 | super(DeepLabHeadV3Plus, self).__init__()
49 | self.project = nn.Sequential(
50 | nn.Conv2d(low_level_channels, 48, 1, bias=False),
51 | nn.BatchNorm2d(48),
52 | nn.ReLU(inplace=True),
53 | )
54 |
55 | # self.aspp = ASPP(in_channels, aspp_dilate)
56 |
57 | self.classifier = nn.Sequential(
58 | nn.Conv2d(560, 256, 3, padding=1, bias=False), #
59 | nn.BatchNorm2d(256),
60 | nn.ReLU(inplace=True),
61 | nn.Conv2d(256, num_classes, 1))
62 | self._init_weight()
63 |
64 | def forward(self, feature):
65 | # low_level_feature = self.project( feature['low_level'] )
66 | # output_feature = self.aspp(feature['out'])
67 | low_level_feature = self.project(feature[0])
68 | output_feature = feature[3]
69 | # output_feature = self.aspp(feature[3])
70 | output_feature = F.interpolate(output_feature,
71 | size=low_level_feature.shape[2:],
72 | mode='bilinear',
73 | align_corners=False)
74 | return self.classifier(
75 | torch.cat([low_level_feature, output_feature], dim=1))
76 |
77 | def _init_weight(self):
78 | for m in self.modules():
79 | if isinstance(m, nn.Conv2d):
80 | nn.init.kaiming_normal_(m.weight)
81 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
82 | nn.init.constant_(m.weight, 1)
83 | nn.init.constant_(m.bias, 0)
84 |
85 |
86 | class _SimpleSegmentationModel(nn.Module):
87 | def __init__(self, backbone, classifier):
88 | super(_SimpleSegmentationModel, self).__init__()
89 | self.backbone = backbone
90 | self.classifier = classifier
91 | # self.bat=BAT(num_classes=1,point_pred=1, in_channels=512,decoder=True, transformer_type_index=0)
92 | # self.bat_low=BAT(num_classes=1,point_pred=1,in_channels=320,hidden_features=320, decoder=False, transformer_type_index=0)
93 | # self.bat=BAT(in_channel=512)
94 | self.bat_low = BAT(in_channel=320)
95 | # self.bat0=BAT(in_channel=64)
96 | def forward(self, x):
97 | input_shape = x.shape[-2:]
98 | features = self.backbone(
99 | x
100 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) ([8, 512, 8, 8])
101 | # features[3],point_pre=self.bat(features[3])
102 | features[2], point_pre = self.bat_low(features[2])
103 | # features[0],point_pre=self.bat0(features[0])
104 | x = self.classifier(features)
105 | x = F.interpolate(x,
106 | size=input_shape,
107 | mode='bilinear',
108 | align_corners=False)
109 | # p=F.interpolate(point_pre, size=input_shape, mode='bilinear', align_corners=False)
110 | return x, point_pre
111 |
112 |
113 | class DeepLabV3(_SimpleSegmentationModel):
114 | """
115 | Implements DeepLabV3 model from
116 | `"Rethinking Atrous Convolution for Semantic Image Segmentation"
117 |     <https://arxiv.org/abs/1706.05587>`_.
118 | Arguments:
119 | backbone (nn.Module): the network used to compute the features for the model.
120 | The backbone should return an OrderedDict[Tensor], with the key being
121 | "out" for the last feature map used, and "aux" if an auxiliary classifier
122 | is used.
123 | classifier (nn.Module): module that takes the "out" element returned from
124 | the backbone and returns a dense prediction.
125 | aux_classifier (nn.Module, optional): auxiliary classifier used during training
126 | """
127 | pass
128 |
129 |
130 | class ASPP(nn.Module):
131 | def __init__(self, in_channels, atrous_rates):
132 | super(ASPP, self).__init__()
133 | out_channels = 256
134 | modules = []
135 | modules.append(
136 | nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
137 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)))
138 |
139 | rate1, rate2, rate3 = tuple(atrous_rates)
140 | modules.append(ASPPConv(in_channels, out_channels, rate1))
141 | modules.append(ASPPConv(in_channels, out_channels, rate2))
142 | modules.append(ASPPConv(in_channels, out_channels, rate3))
143 | modules.append(ASPPPooling(in_channels, out_channels))
144 |
145 | self.convs = nn.ModuleList(modules)
146 |
147 | self.project = nn.Sequential(
148 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
149 | nn.BatchNorm2d(out_channels),
150 | nn.ReLU(inplace=True),
151 | nn.Dropout(0.1),
152 | )
153 |
154 | def forward(self, x):
155 | res = []
156 | for conv in self.convs:
157 | res.append(conv(x))
158 | res = torch.cat(res, dim=1)
159 | return self.project(res)
160 |
161 |
162 | class ASPPConv(nn.Sequential):
163 | def __init__(self, in_channels, out_channels, dilation):
164 | modules = [
165 | nn.Conv2d(in_channels,
166 | out_channels,
167 | 3,
168 | padding=dilation,
169 | dilation=dilation,
170 | bias=False),
171 | nn.BatchNorm2d(out_channels),
172 | nn.ReLU(inplace=True)
173 | ]
174 | super(ASPPConv, self).__init__(*modules)
175 |
176 |
177 | class ASPPPooling(nn.Sequential):
178 | def __init__(self, in_channels, out_channels):
179 | super(ASPPPooling, self).__init__(
180 | nn.AdaptiveAvgPool2d(11), ###
181 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
182 | nn.BatchNorm2d(out_channels),
183 | nn.ReLU(inplace=True))
184 |
185 | def forward(self, x):
186 | size = x.shape[-2:]
187 | x = super(ASPPPooling, self).forward(x)
188 | return F.interpolate(x,
189 | size=size,
190 | mode='bilinear',
191 | align_corners=False)
192 |
193 |
194 | def deeplabv3plus_pvtv2(num_classes=1,
195 | output_stride=8,
196 | pretrained_backbone=True):
197 | """Constructs a DeepLabV3 model with a pvtv2 backbone.
198 | Args:
199 | num_classes (int): number of classes.
200 | output_stride (int): output stride for deeplab.
201 | pretrained_backbone (bool): If True, use the pretrained backbone.
202 | """
203 | return _segm_pvtv2('deeplabv3plus',
204 | 'pvtv2',
205 | num_classes,
206 | output_stride=output_stride,
207 | pretrained_backbone=pretrained_backbone)
208 |
209 |
210 | class BAT(nn.Module):
211 | def __init__(self, in_channel=512):
212 | super(BAT, self).__init__()
213 | self.in_channel = in_channel
214 | self.conv1 = nn.Sequential(
215 | nn.Conv2d(in_channels=self.in_channel,
216 | out_channels=512,
217 | kernel_size=(3, 3),
218 | padding=1), nn.BatchNorm2d(512), nn.ReLU())
219 | self.conv2 = nn.Sequential(
220 | nn.Conv2d(in_channels=512,
221 | out_channels=256,
222 | kernel_size=(3, 3),
223 | padding=1), nn.BatchNorm2d(256), nn.ReLU())
224 | self.conv3 = nn.Sequential(
225 | nn.Conv2d(in_channels=256,
226 | out_channels=self.in_channel,
227 | kernel_size=(3, 3),
228 | padding=1), nn.BatchNorm2d(self.in_channel), nn.ReLU())
229 | # self.conv1=RepLKBlock(self.in_channel,128,31,5,0.)
230 | # self.conv2=RepLKBlock(self.in_channel,256,29,5,0.1)
231 | # self.conv3=RepLKBlock(self.in_channel,512,27,5,0.2)
232 | # self.conv5=RepLKBlock(self.in_channel,1024,13,5,0.3)
233 | self.conv4 = nn.Sequential(
234 | nn.Conv2d(in_channels=self.in_channel,
235 | out_channels=1,
236 | kernel_size=(1, 1)), nn.BatchNorm2d(1), nn.ReLU())
237 | # self.sigmoid=nn.Sigmoid()
238 | def forward(self, x):
239 | point1 = self.conv1(x)
240 | # # point1=point1+x
241 | point1 = self.conv2(point1)
242 | point1 = self.conv3(point1)
243 | # # point1=point1+point2
244 | # point2=self.conv5(point1)
245 | point1 = point1 + x
246 | point1 = self.conv4(point1)
247 | # point1=self.sigmoid(point1)
248 | return x, point1
249 |
250 |
251 | # class BAT(nn.Module):
252 | # def __init__(
253 | # self,
254 | # num_classes,
255 | # point_pred,
256 | # in_channels=512,
257 | # decoder=False,
258 | # transformer_type_index=0,
259 | # hidden_features=256, # 256
260 | # number_of_query_positions=1,
261 | # segmentation_attention_heads=8):
262 |
263 | # super(BAT, self).__init__()
264 |
265 | # self.num_classes = num_classes
266 | # self.point_pred = point_pred
267 | # self.transformer_type = "BoundaryAwareTransformer" if transformer_type_index == 0 else "Transformer"
268 | # self.use_decoder = decoder
269 |
270 | # self.in_channels = in_channels
271 |
272 | # # self.convolution_mapping = nn.Conv2d(in_channels=in_channels,
273 | # # out_channels=hidden_features,
274 | # # kernel_size=(1, 1),
275 | # # stride=(1, 1),
276 | # # padding=(0, 0),
277 | # # bias=True)
278 |
279 | # self.query_positions = nn.Parameter(data=torch.randn(
280 | # number_of_query_positions, hidden_features, dtype=torch.float),
281 | # requires_grad=True)
282 |
283 | # self.row_embedding = nn.Parameter(data=torch.randn(100,
284 | # hidden_features //
285 | # 2,
286 | # dtype=torch.float),
287 | # requires_grad=True)
288 | # self.column_embedding = nn.Parameter(data=torch.randn(
289 | # 100, hidden_features // 2, dtype=torch.float),
290 | # requires_grad=True)
291 |
292 | # # self.transformer =BoundaryAwareTransformer(d_model=hidden_features,normalize_before=False,num_encoder_layers=6,num_decoder_layers=2,Atrous=False)
293 | # self.transformer =BoundaryAwareTransformer(d_model=hidden_features,normalize_before=False,num_decoder_layers=0,point_pred_layers=1,Atrous=False)
294 |
295 | # if self.use_decoder:
296 | # self.BCA = BoundaryCrossAttention(hidden_features, 8)
297 |
298 | # # self.trans_out_conv = nn.Conv2d(in_channels=hidden_features,
299 | # # out_channels=in_channels,
300 | # # kernel_size=(1, 1),
301 | # # stride=(1, 1),
302 | # # padding=(0, 0),
303 | # # bias=True)
304 |
305 | # def forward(self, x):
306 | # h = x.size()[2]
307 | # w = x.size()[3]
308 | # feature_map = x
309 | # features=x
310 | # # features = self.convolution_mapping(feature_map)
311 | # height, width = features.shape[2:]
312 | # batch_size = features.shape[0]
313 | # positional_embeddings = torch.cat([
314 | # self.column_embedding[:height].unsqueeze(dim=0).repeat(
315 | # height, 1, 1),
316 | # self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)
317 | # ],dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)
318 |
319 | # if self.transformer_type == 'BoundaryAwareTransformer':
320 | # latent_tensor, features_encoded, point_maps = self.transformer(
321 | # features, None, self.query_positions, positional_embeddings)
322 | # else:
323 | # latent_tensor, features_encoded = self.transformer(
324 | # features, None, self.query_positions, positional_embeddings)
325 | # point_maps = []
326 |
327 | # latent_tensor = latent_tensor.permute(2, 0, 1)
328 | # # shape:(bs, 1 , 128)
329 |
330 | # # if self.use_decoder:
331 | # # features_encoded, point_dec = self.BCA(features_encoded,
332 | # # latent_tensor)
333 | # # point_maps.append(point_dec)
334 |
335 | # # trans_feature_maps = self.trans_out_conv(
336 | # # features_encoded.contiguous()) #.contiguous()
337 | # trans_feature_maps = features_encoded.contiguous()
338 | # trans_feature_maps = trans_feature_maps + feature_map
339 |
340 | # output = F.interpolate(
341 | # trans_feature_maps, size=(h, w),
342 | # mode="bilinear") # (shape: (batch_size, num_classes, h, w))
343 |
344 | # if self.point_pred == 1:
345 | # return output, point_maps[0]
346 |
347 | # return output
348 |
349 | if __name__ == '__main__':
350 | import os
351 | os.environ['CUDA_VISIBLE_DEVICES'] = '4'
352 | model = deeplabv3plus_pvtv2().cuda()
353 | input_tensor = torch.randn(1, 3, 352, 352).cuda()
354 |
355 |     prediction, point_map = model(input_tensor)  # forward() returns (segmentation logits, point map)
356 |     print(prediction.size())
--------------------------------------------------------------------------------
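In DeepLabHeadV3Plus above, the classifier's first convolution expects 560 input channels because the projected low-level feature (64 -> 48 channels) is concatenated with the upsampled stage-4 feature (512 channels). A minimal standalone sketch of that channel arithmetic, using dummy tensors rather than the actual pvt_v2_b2 backbone:

import torch
import torch.nn.functional as F

low_level = torch.randn(2, 48, 88, 88)     # stage-1 feature after self.project (64 -> 48)
high_level = torch.randn(2, 512, 11, 11)   # stage-4 feature from the backbone
high_level = F.interpolate(high_level, size=low_level.shape[2:],
                           mode='bilinear', align_corners=False)
fused = torch.cat([low_level, high_level], dim=1)
print(fused.shape)  # torch.Size([2, 560, 88, 88]) -> matches nn.Conv2d(560, 256, 3, padding=1)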
/lib/TransFuse/vision_transformer.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 | A PyTorch implementation of Vision Transformers as described in
3 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
4 | The official jax code is released and available at https://github.com/google-research/vision_transformer
5 | Status/TODO:
6 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
7 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
8 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
9 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
10 | Acknowledgments:
11 | * The paper authors for releasing code and weights, thanks!
12 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
13 | for some einops/einsum fun
14 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
15 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
16 | Hacked together by / Copyright 2020 Ross Wightman
17 | """
18 | import torch
19 | import torch.nn as nn
20 | from functools import partial
21 |
22 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
23 | from timm.models.helpers import load_pretrained
24 |
25 | from timm.models.registry import register_model
26 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27 |
28 |
29 |
30 | def _cfg(url='', **kwargs):
31 | return {
32 | 'url': url,
33 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
34 | 'crop_pct': .9, 'interpolation': 'bicubic',
35 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
36 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
37 | **kwargs
38 | }
39 |
40 |
41 | default_cfgs = {
42 | # patch models
43 | 'vit_small_patch16_224': _cfg(
44 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
45 | ),
46 | 'vit_base_patch16_224': _cfg(
47 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
48 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
49 | ),
50 | 'vit_base_patch16_384': _cfg(
51 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
52 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
53 | 'vit_base_patch32_384': _cfg(
54 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
55 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
56 | 'vit_large_patch16_224': _cfg(
57 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
58 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
59 | 'vit_large_patch16_384': _cfg(
60 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
61 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
62 | 'vit_large_patch32_384': _cfg(
63 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
64 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
65 | 'vit_huge_patch16_224': _cfg(),
66 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
67 | # hybrid models
68 | 'vit_small_resnet26d_224': _cfg(),
69 | 'vit_small_resnet50d_s3_224': _cfg(),
70 | 'vit_base_resnet26d_224': _cfg(),
71 | 'vit_base_resnet50d_224': _cfg(),
72 | }
73 |
74 |
75 | class Mlp(nn.Module):
76 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
77 | super().__init__()
78 | out_features = out_features or in_features
79 | hidden_features = hidden_features or in_features
80 | self.fc1 = nn.Linear(in_features, hidden_features)
81 | self.act = act_layer()
82 | self.fc2 = nn.Linear(hidden_features, out_features)
83 | self.drop = nn.Dropout(drop)
84 |
85 | def forward(self, x):
86 | x = self.fc1(x)
87 | x = self.act(x)
88 | x = self.drop(x)
89 | x = self.fc2(x)
90 | x = self.drop(x)
91 | return x
92 |
93 |
94 | class Attention(nn.Module):
95 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
96 | super().__init__()
97 | self.num_heads = num_heads
98 | head_dim = dim // num_heads
99 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
100 | self.scale = qk_scale or head_dim ** -0.5
101 |
102 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
103 | self.attn_drop = nn.Dropout(attn_drop)
104 | self.proj = nn.Linear(dim, dim)
105 | self.proj_drop = nn.Dropout(proj_drop)
106 |
107 | def forward(self, x):
108 | B, N, C = x.shape
109 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
110 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
111 |
112 | attn = (q @ k.transpose(-2, -1)) * self.scale
113 | attn = attn.softmax(dim=-1)
114 | attn = self.attn_drop(attn)
115 |
116 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
117 | x = self.proj(x)
118 | x = self.proj_drop(x)
119 | return x
120 |
121 |
122 | class Block(nn.Module):
123 |
124 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
125 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
126 | super().__init__()
127 | self.norm1 = norm_layer(dim)
128 | self.attn = Attention(
129 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
130 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
131 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
132 | self.norm2 = norm_layer(dim)
133 | mlp_hidden_dim = int(dim * mlp_ratio)
134 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
135 |
136 | def forward(self, x):
137 | x = x + self.drop_path(self.attn(self.norm1(x)))
138 | x = x + self.drop_path(self.mlp(self.norm2(x)))
139 | return x
140 |
141 |
142 | class PatchEmbed(nn.Module):
143 | """ Image to Patch Embedding
144 | """
145 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
146 | super().__init__()
147 | img_size = to_2tuple(img_size)
148 | patch_size = to_2tuple(patch_size)
149 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
150 | self.img_size = img_size
151 | self.patch_size = patch_size
152 | self.num_patches = num_patches
153 |
154 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
155 |
156 | def forward(self, x):
157 | B, C, H, W = x.shape
158 |
159 | # FIXME look at relaxing size constraints
160 | #assert H == self.img_size[0] and W == self.img_size[1], \
161 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
162 | x = self.proj(x).flatten(2).transpose(1, 2)
163 | return x
164 |
165 |
166 |
167 | class VisionTransformer(nn.Module):
168 | """ Vision Transformer with support for patch or hybrid CNN input stage
169 | """
170 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
171 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
172 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
173 | super().__init__()
174 | self.num_classes = num_classes
175 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
176 |
177 | if hybrid_backbone is not None:
178 | self.patch_embed = HybridEmbed(
179 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
180 | else:
181 | self.patch_embed = PatchEmbed(
182 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
183 | num_patches = self.patch_embed.num_patches
184 |
185 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
186 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
187 | self.pos_drop = nn.Dropout(p=drop_rate)
188 |
189 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
190 | self.blocks = nn.ModuleList([
191 | Block(
192 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
193 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
194 | for i in range(depth)])
195 | self.norm = norm_layer(embed_dim)
196 |
197 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
198 | #self.repr = nn.Linear(embed_dim, representation_size)
199 | #self.repr_act = nn.Tanh()
200 |
201 | # Classifier head
202 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
203 |
204 | trunc_normal_(self.pos_embed, std=.02)
205 | trunc_normal_(self.cls_token, std=.02)
206 | self.apply(self._init_weights)
207 |
208 | def _init_weights(self, m):
209 | if isinstance(m, nn.Linear):
210 | trunc_normal_(m.weight, std=.02)
211 | if isinstance(m, nn.Linear) and m.bias is not None:
212 | nn.init.constant_(m.bias, 0)
213 | elif isinstance(m, nn.LayerNorm):
214 | nn.init.constant_(m.bias, 0)
215 | nn.init.constant_(m.weight, 1.0)
216 |
217 | @torch.jit.ignore
218 | def no_weight_decay(self):
219 | return {'pos_embed', 'cls_token'}
220 |
221 | def get_classifier(self):
222 | return self.head
223 |
224 | def reset_classifier(self, num_classes, global_pool=''):
225 | self.num_classes = num_classes
226 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
227 |
228 | def forward_features(self, x):
229 | B = x.shape[0]
230 | x = self.patch_embed(x)
231 |
232 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
233 | x = torch.cat((cls_tokens, x), dim=1)
234 | x = x + self.pos_embed
235 | x = self.pos_drop(x)
236 |
237 | for blk in self.blocks:
238 | x = blk(x)
239 |
240 | x = self.norm(x)
241 | return x[:, 0]
242 |
243 | def forward(self, x):
244 | x = self.forward_features(x)
245 | x = self.head(x)
246 | return x
247 |
248 |
249 | def _conv_filter(state_dict, patch_size=16):
250 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
251 | out_dict = {}
252 | for k, v in state_dict.items():
253 | if 'patch_embed.proj.weight' in k:
254 | v = v.reshape((v.shape[0], 3, patch_size, patch_size))
255 | out_dict[k] = v
256 | return out_dict
257 |
258 |
259 | @register_model
260 | def vit_small_patch16_224(pretrained=False, **kwargs):
261 | if pretrained:
262 | # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
263 | kwargs.setdefault('qk_scale', 768 ** -0.5)
264 | model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
265 | model.default_cfg = default_cfgs['vit_small_patch16_224']
266 | if pretrained:
267 | load_pretrained(
268 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
269 | return model
270 |
271 |
272 | @register_model
273 | def vit_base_patch16_224(pretrained=False, **kwargs):
274 | model = VisionTransformer(
275 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
276 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
277 | model.default_cfg = default_cfgs['vit_base_patch16_224']
278 | if pretrained:
279 | load_pretrained(
280 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
281 | return model
282 |
283 |
284 | @register_model
285 | def vit_base_patch16_384(pretrained=False, **kwargs):
286 | model = VisionTransformer(
287 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
288 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
289 | model.default_cfg = default_cfgs['vit_base_patch16_384']
290 | if pretrained:
291 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
292 | return model
293 |
294 |
295 | @register_model
296 | def vit_base_patch32_384(pretrained=False, **kwargs):
297 | model = VisionTransformer(
298 | img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
299 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
300 | model.default_cfg = default_cfgs['vit_base_patch32_384']
301 | if pretrained:
302 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
303 | return model
304 |
305 |
306 | @register_model
307 | def vit_large_patch16_224(pretrained=False, **kwargs):
308 | model = VisionTransformer(
309 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
310 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
311 | model.default_cfg = default_cfgs['vit_large_patch16_224']
312 | if pretrained:
313 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
314 | return model
315 |
316 |
317 | @register_model
318 | def vit_large_patch16_384(pretrained=False, **kwargs):
319 | model = VisionTransformer(
320 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
321 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
322 | model.default_cfg = default_cfgs['vit_large_patch16_384']
323 | if pretrained:
324 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
325 | return model
326 |
327 |
328 | @register_model
329 | def vit_large_patch32_384(pretrained=False, **kwargs):
330 | model = VisionTransformer(
331 | img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
332 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
333 | model.default_cfg = default_cfgs['vit_large_patch32_384']
334 | if pretrained:
335 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
336 | return model
337 |
338 |
339 | @register_model
340 | def vit_huge_patch16_224(pretrained=False, **kwargs):
341 | model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
342 | model.default_cfg = default_cfgs['vit_huge_patch16_224']
343 | return model
344 |
345 |
346 | @register_model
347 | def vit_huge_patch32_384(pretrained=False, **kwargs):
348 | model = VisionTransformer(
349 | img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
350 | model.default_cfg = default_cfgs['vit_huge_patch32_384']
351 | return model
--------------------------------------------------------------------------------
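The reshaping inside Attention.forward is easiest to follow as a shape trace: the (B, N, C) token sequence is projected to q, k, v of shape (B, num_heads, N, C // num_heads), the attention map is (B, num_heads, N, N), and the output is merged back to (B, N, C). A small sketch with random tensors standing in for the qkv projection (the 197 tokens assume a 224x224 image with 16x16 patches plus one class token):

import torch

B, N, C, num_heads = 2, 197, 768, 12
x = torch.randn(B, N, 3 * C)              # stand-in for self.qkv(x)
qkv = x.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]          # each (B, num_heads, N, C // num_heads)
attn = (q @ k.transpose(-2, -1)) * (C // num_heads) ** -0.5
attn = attn.softmax(dim=-1)               # (B, num_heads, N, N)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
print(q.shape, attn.shape, out.shape)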
/lib/TransFuse/TransFuse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models import resnet34 as resnet
4 | from torchvision.models import resnet50
5 | from .DeiT import deit_small_patch16_224 as deit
6 | from .DeiT import deit_base_patch16_224 as deit_b
7 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
8 | import torch.nn.functional as F
9 | import numpy as np
10 | import math
11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12 |
13 |
14 | class ChannelPool(nn.Module):
15 | def forward(self, x):
16 | return torch.cat(
17 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)),
18 | dim=1)
19 |
20 |
21 | class BiFusion_block(nn.Module):
22 | def __init__(self, ch_1, ch_2, r_2, ch_int, ch_out, drop_rate=0.):
23 | super(BiFusion_block, self).__init__()
24 |
25 | # channel attention for F_g, use SE Block
26 | self.fc1 = nn.Conv2d(ch_2, ch_2 // r_2, kernel_size=1)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.fc2 = nn.Conv2d(ch_2 // r_2, ch_2, kernel_size=1)
29 | self.sigmoid = nn.Sigmoid()
30 |
31 | # spatial attention for F_l
32 | self.compress = ChannelPool()
33 | self.spatial = Conv(2, 1, 7, bn=True, relu=False, bias=False)
34 |
35 | # bi-linear modelling for both
36 | self.W_g = Conv(ch_1, ch_int, 1, bn=True, relu=False)
37 | self.W_x = Conv(ch_2, ch_int, 1, bn=True, relu=False)
38 | self.W = Conv(ch_int, ch_int, 3, bn=True, relu=True)
39 |
40 | self.relu = nn.ReLU(inplace=True)
41 |
42 | self.residual = Residual(ch_1 + ch_2 + ch_int, ch_out)
43 |
44 | self.dropout = nn.Dropout2d(drop_rate)
45 | self.drop_rate = drop_rate
46 |
47 | def forward(self, g, x):
48 | # bilinear pooling
49 | W_g = self.W_g(g)
50 | W_x = self.W_x(x)
51 | bp = self.W(W_g * W_x)
52 |
53 | # spatial attention for cnn branch
54 | g_in = g
55 | g = self.compress(g)
56 | g = self.spatial(g)
57 | g = self.sigmoid(g) * g_in
58 |
59 |         # channel attention for transformer branch
60 | x_in = x
61 | x = x.mean((2, 3), keepdim=True)
62 | x = self.fc1(x)
63 | x = self.relu(x)
64 | x = self.fc2(x)
65 | x = self.sigmoid(x) * x_in
66 | fuse = self.residual(torch.cat([g, x, bp], 1))
67 |
68 | if self.drop_rate > 0:
69 | return self.dropout(fuse)
70 | else:
71 | return fuse
72 |
73 |
74 | class TransFuse_S(nn.Module):
75 | def __init__(self,
76 | num_classes=1,
77 | drop_rate=0.2,
78 | normal_init=True,
79 | pretrained=False):
80 | super(TransFuse_S, self).__init__()
81 |
82 | self.resnet = resnet()
83 | if pretrained:
84 | self.resnet.load_state_dict(
85 | torch.load(
86 | '/raid/wjc/code/xbound_former/lib/TransFuse/pretrained/resnet34-43635321.pth'
87 | ))
88 | self.resnet.fc = nn.Identity()
89 | self.resnet.layer4 = nn.Identity()
90 |
91 | self.transformer = deit(pretrained=pretrained)
92 |
93 | self.up1 = Up(in_ch1=384, out_ch=128)
94 | self.up2 = Up(128, 64)
95 |
96 | self.final_x = nn.Sequential(
97 | Conv(256, 64, 1, bn=True, relu=True),
98 | Conv(64, 64, 3, bn=True, relu=True),
99 | Conv(64, num_classes, 3, bn=False, relu=False))
100 |
101 | self.final_1 = nn.Sequential(
102 | Conv(64, 64, 3, bn=True, relu=True),
103 | Conv(64, num_classes, 3, bn=False, relu=False))
104 |
105 | self.final_2 = nn.Sequential(
106 | Conv(64, 64, 3, bn=True, relu=True),
107 | Conv(64, num_classes, 3, bn=False, relu=False))
108 |
109 | self.up_c = BiFusion_block(ch_1=256,
110 | ch_2=384,
111 | r_2=4,
112 | ch_int=256,
113 | ch_out=256,
114 | drop_rate=drop_rate / 2)
115 |
116 | self.up_c_1_1 = BiFusion_block(ch_1=128,
117 | ch_2=128,
118 | r_2=2,
119 | ch_int=128,
120 | ch_out=128,
121 | drop_rate=drop_rate / 2)
122 | self.up_c_1_2 = Up(in_ch1=256, out_ch=128, in_ch2=128, attn=True)
123 |
124 | self.up_c_2_1 = BiFusion_block(ch_1=64,
125 | ch_2=64,
126 | r_2=1,
127 | ch_int=64,
128 | ch_out=64,
129 | drop_rate=drop_rate / 2)
130 | self.up_c_2_2 = Up(128, 64, 64, attn=True)
131 |
132 | self.drop = nn.Dropout2d(drop_rate)
133 |
134 | if normal_init:
135 | self.init_weights()
136 |
137 | def forward(self, imgs, labels=None):
138 | # bottom-up path
139 | x_b = self.transformer(imgs)
140 | x_b = torch.transpose(x_b, 1, 2)
141 | x_b = x_b.view(x_b.shape[0], -1,
142 | imgs.size(2) // 16,
143 | imgs.size(3) // 16) # (x_b.shape[0], -1, 12, 16)
144 | x_b = self.drop(x_b)
145 |
146 | x_b_1 = self.up1(x_b)
147 | x_b_1 = self.drop(x_b_1)
148 |
149 | x_b_2 = self.up2(x_b_1) # transformer pred supervise here
150 | x_b_2 = self.drop(x_b_2)
151 |
152 | # top-down path
153 | x_u = self.resnet.conv1(imgs)
154 | x_u = self.resnet.bn1(x_u)
155 | x_u = self.resnet.relu(x_u)
156 | x_u = self.resnet.maxpool(x_u)
157 |
158 | x_u_2 = self.resnet.layer1(x_u)
159 | x_u_2 = self.drop(x_u_2)
160 |
161 | x_u_1 = self.resnet.layer2(x_u_2)
162 | x_u_1 = self.drop(x_u_1)
163 |
164 | x_u = self.resnet.layer3(x_u_1)
165 | x_u = self.drop(x_u)
166 |
167 | # joint path
168 | x_c = self.up_c(x_u, x_b)
169 |
170 | x_c_1_1 = self.up_c_1_1(x_u_1, x_b_1)
171 | x_c_1 = self.up_c_1_2(x_c, x_c_1_1)
172 |
173 | x_c_2_1 = self.up_c_2_1(x_u_2, x_b_2)
174 | x_c_2 = self.up_c_2_2(x_c_1,
175 | x_c_2_1) # joint predict low supervise here
176 |
177 | # decoder part
178 | map_x = F.interpolate(self.final_x(x_c),
179 | scale_factor=16,
180 | mode='bilinear')
181 | map_1 = F.interpolate(self.final_1(x_b_2),
182 | scale_factor=4,
183 | mode='bilinear')
184 | map_2 = F.interpolate(self.final_2(x_c_2),
185 | scale_factor=4,
186 | mode='bilinear')
187 | return map_x, map_1, map_2
188 |
189 | def init_weights(self):
190 | self.up1.apply(init_weights)
191 | self.up2.apply(init_weights)
192 | self.final_x.apply(init_weights)
193 | self.final_1.apply(init_weights)
194 | self.final_2.apply(init_weights)
195 | self.up_c.apply(init_weights)
196 | self.up_c_1_1.apply(init_weights)
197 | self.up_c_1_2.apply(init_weights)
198 | self.up_c_2_1.apply(init_weights)
199 | self.up_c_2_2.apply(init_weights)
200 |
201 |
202 | class TransFuse_L(nn.Module):
203 | def __init__(self,
204 | num_classes=1,
205 | drop_rate=0.2,
206 | normal_init=True,
207 | pretrained=False):
208 | super(TransFuse_L, self).__init__()
209 |
210 | self.resnet = resnet50()
211 | if pretrained:
212 | self.resnet.load_state_dict(
213 | torch.load(
214 | '/home/chenfei/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'
215 | ))
216 | self.resnet.fc = nn.Identity()
217 | self.resnet.layer4 = nn.Identity()
218 |
219 | self.transformer = deit_b(pretrained=pretrained)
220 |
221 | self.up1 = Up(in_ch1=768, out_ch=512)
222 | self.up2 = Up(512, 256)
223 |
224 | self.final_x = nn.Sequential(
225 | Conv(1024, 256, 1, bn=True, relu=True),
226 | Conv(256, 256, 3, bn=True, relu=True),
227 | Conv(256, num_classes, 3, bn=False, relu=False))
228 |
229 | self.final_1 = nn.Sequential(
230 | Conv(256, 256, 3, bn=True, relu=True),
231 | Conv(256, num_classes, 3, bn=False, relu=False))
232 |
233 | self.final_2 = nn.Sequential(
234 | Conv(256, 256, 3, bn=True, relu=True),
235 | Conv(256, num_classes, 3, bn=False, relu=False))
236 |
237 | self.up_c = BiFusion_block(ch_1=1024,
238 | ch_2=768,
239 | r_2=4,
240 | ch_int=1024,
241 | ch_out=1024,
242 | drop_rate=drop_rate / 2)
243 |
244 | self.up_c_1_1 = BiFusion_block(ch_1=512,
245 | ch_2=512,
246 | r_2=2,
247 | ch_int=512,
248 | ch_out=512,
249 | drop_rate=drop_rate / 2)
250 | self.up_c_1_2 = Up(in_ch1=1024, out_ch=512, in_ch2=512, attn=True)
251 |
252 | self.up_c_2_1 = BiFusion_block(ch_1=256,
253 | ch_2=256,
254 | r_2=1,
255 | ch_int=256,
256 | ch_out=256,
257 | drop_rate=drop_rate / 2)
258 | self.up_c_2_2 = Up(512, 256, 256, attn=True)
259 |
260 | self.drop = nn.Dropout2d(drop_rate)
261 |
262 | if normal_init:
263 | self.init_weights()
264 |
265 | def forward(self, imgs, labels=None):
266 | # bottom-up path
267 | x_b = self.transformer(imgs)
268 | x_b = torch.transpose(x_b, 1, 2)
269 | x_b = x_b.view(x_b.shape[0], -1, 32, 32) # (x_b.shape[0], -1, 12, 16)
270 | x_b = self.drop(x_b)
271 |
272 | x_b_1 = self.up1(x_b)
273 | x_b_1 = self.drop(x_b_1)
274 |
275 | x_b_2 = self.up2(x_b_1) # transformer pred supervise here
276 | x_b_2 = self.drop(x_b_2)
277 |
278 | # top-down path
279 | x_u = self.resnet.conv1(imgs)
280 | x_u = self.resnet.bn1(x_u)
281 | x_u = self.resnet.relu(x_u)
282 | x_u = self.resnet.maxpool(x_u)
283 |
284 | x_u_2 = self.resnet.layer1(x_u)
285 | x_u_2 = self.drop(x_u_2)
286 |
287 | x_u_1 = self.resnet.layer2(x_u_2)
288 | x_u_1 = self.drop(x_u_1)
289 |
290 | x_u = self.resnet.layer3(x_u_1)
291 | x_u = self.drop(x_u)
292 |
293 | # joint path
294 | x_c = self.up_c(x_u, x_b)
295 |
296 | x_c_1_1 = self.up_c_1_1(x_u_1, x_b_1)
297 | x_c_1 = self.up_c_1_2(x_c, x_c_1_1)
298 |
299 | x_c_2_1 = self.up_c_2_1(x_u_2, x_b_2)
300 | x_c_2 = self.up_c_2_2(x_c_1,
301 | x_c_2_1) # joint predict low supervise here
302 |
303 | # decoder part
304 | map_x = F.interpolate(self.final_x(x_c),
305 | scale_factor=16,
306 | mode='bilinear')
307 | map_1 = F.interpolate(self.final_1(x_b_2),
308 | scale_factor=4,
309 | mode='bilinear')
310 | map_2 = F.interpolate(self.final_2(x_c_2),
311 | scale_factor=4,
312 | mode='bilinear')
313 | return map_x, map_1, map_2
314 |
315 | def init_weights(self):
316 | self.up1.apply(init_weights)
317 | self.up2.apply(init_weights)
318 | self.final_x.apply(init_weights)
319 | self.final_1.apply(init_weights)
320 | self.final_2.apply(init_weights)
321 | self.up_c.apply(init_weights)
322 | self.up_c_1_1.apply(init_weights)
323 | self.up_c_1_2.apply(init_weights)
324 | self.up_c_2_1.apply(init_weights)
325 | self.up_c_2_2.apply(init_weights)
326 |
327 |
328 | def init_weights(m):
329 | """
330 |     Initialize layer weights with Kaiming Normal (He et al.); intended to be passed to the
331 |     "apply" method of "nn.Module".
332 | :param m: Layer to initialize
333 | :return: None
334 | """
335 | if isinstance(m, nn.Conv2d):
336 | '''
337 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
338 | trunc_normal_(m.weight, std=math.sqrt(1.0/fan_in)/.87962566103423978)
339 | if m.bias is not None:
340 | nn.init.zeros_(m.bias)
341 | '''
342 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
343 | if m.bias is not None:
344 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
345 | bound = 1 / math.sqrt(fan_in)
346 | nn.init.uniform_(m.bias, -bound, bound)
347 |
348 | elif isinstance(m, nn.BatchNorm2d):
349 | nn.init.constant_(m.weight, 1)
350 | nn.init.constant_(m.bias, 0)
351 |
352 |
353 | class Up(nn.Module):
354 | """Upscaling then double conv"""
355 | def __init__(self, in_ch1, out_ch, in_ch2=0, attn=False):
356 | super().__init__()
357 |
358 | self.up = nn.Upsample(scale_factor=2,
359 | mode='bilinear',
360 | align_corners=True)
361 | self.conv = DoubleConv(in_ch1 + in_ch2, out_ch)
362 |
363 | if attn:
364 | self.attn_block = Attention_block(in_ch1, in_ch2, out_ch)
365 | else:
366 | self.attn_block = None
367 |
368 | def forward(self, x1, x2=None):
369 |
370 | x1 = self.up(x1)
371 | # input is CHW
372 | if x2 is not None:
373 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
374 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
375 |
376 | x1 = F.pad(x1, [
377 | diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2
378 | ])
379 |
380 | if self.attn_block is not None:
381 | x2 = self.attn_block(x1, x2)
382 | x1 = torch.cat([x2, x1], dim=1)
383 | x = x1
384 | return self.conv(x)
385 |
386 |
387 | class Attention_block(nn.Module):
388 | def __init__(self, F_g, F_l, F_int):
389 | super(Attention_block, self).__init__()
390 | self.W_g = nn.Sequential(
391 | nn.Conv2d(F_g,
392 | F_int,
393 | kernel_size=1,
394 | stride=1,
395 | padding=0,
396 | bias=True), nn.BatchNorm2d(F_int))
397 | self.W_x = nn.Sequential(
398 | nn.Conv2d(F_l,
399 | F_int,
400 | kernel_size=1,
401 | stride=1,
402 | padding=0,
403 | bias=True), nn.BatchNorm2d(F_int))
404 | self.psi = nn.Sequential(
405 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
406 | nn.BatchNorm2d(1), nn.Sigmoid())
407 | self.relu = nn.ReLU(inplace=True)
408 |
409 | def forward(self, g, x):
410 | g1 = self.W_g(g)
411 | x1 = self.W_x(x)
412 | psi = self.relu(g1 + x1)
413 | psi = self.psi(psi)
414 | return x * psi
415 |
416 |
417 | class DoubleConv(nn.Module):
418 | def __init__(self, in_channels, out_channels):
419 | super().__init__()
420 | self.double_conv = nn.Sequential(
421 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
422 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
423 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
424 | nn.BatchNorm2d(out_channels))
425 | self.identity = nn.Sequential(
426 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
427 | nn.BatchNorm2d(out_channels))
428 | self.relu = nn.ReLU(inplace=True)
429 |
430 | def forward(self, x):
431 | return self.relu(self.double_conv(x) + self.identity(x))
432 |
433 |
434 | class Residual(nn.Module):
435 | def __init__(self, inp_dim, out_dim):
436 | super(Residual, self).__init__()
437 | self.relu = nn.ReLU(inplace=True)
438 | self.bn1 = nn.BatchNorm2d(inp_dim)
439 | self.conv1 = Conv(inp_dim, int(out_dim / 2), 1, relu=False)
440 | self.bn2 = nn.BatchNorm2d(int(out_dim / 2))
441 | self.conv2 = Conv(int(out_dim / 2), int(out_dim / 2), 3, relu=False)
442 | self.bn3 = nn.BatchNorm2d(int(out_dim / 2))
443 | self.conv3 = Conv(int(out_dim / 2), out_dim, 1, relu=False)
444 | self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
445 | if inp_dim == out_dim:
446 | self.need_skip = False
447 | else:
448 | self.need_skip = True
449 |
450 | def forward(self, x):
451 | if self.need_skip:
452 | residual = self.skip_layer(x)
453 | else:
454 | residual = x
455 | out = self.bn1(x)
456 | out = self.relu(out)
457 | out = self.conv1(out)
458 | out = self.bn2(out)
459 | out = self.relu(out)
460 | out = self.conv2(out)
461 | out = self.bn3(out)
462 | out = self.relu(out)
463 | out = self.conv3(out)
464 | out += residual
465 | return out
466 |
467 |
468 | class Conv(nn.Module):
469 | def __init__(self,
470 | inp_dim,
471 | out_dim,
472 | kernel_size=3,
473 | stride=1,
474 | bn=False,
475 | relu=True,
476 | bias=True):
477 | super(Conv, self).__init__()
478 | self.inp_dim = inp_dim
479 | self.conv = nn.Conv2d(inp_dim,
480 | out_dim,
481 | kernel_size,
482 | stride,
483 | padding=(kernel_size - 1) // 2,
484 | bias=bias)
485 | self.relu = None
486 | self.bn = None
487 | if relu:
488 | self.relu = nn.ReLU(inplace=True)
489 | if bn:
490 | self.bn = nn.BatchNorm2d(out_dim)
491 |
492 | def forward(self, x):
493 | assert x.size()[1] == self.inp_dim, "{} {}".format(
494 | x.size()[1], self.inp_dim)
495 | x = self.conv(x)
496 | if self.bn is not None:
497 | x = self.bn(x)
498 | if self.relu is not None:
499 | x = self.relu(x)
500 | return x
--------------------------------------------------------------------------------
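The transformer branch in BiFusion_block above is gated with an SE-style channel attention: global average pool, a 1x1 bottleneck conv, ReLU, a 1x1 expansion conv, and a sigmoid gate applied to the input. A standalone sketch of just that step (ch and r chosen to match the deepest fusion block, ch_2=384 and r_2=4; the module itself also adds spatial attention and a bilinear interaction term):

import torch
import torch.nn as nn

ch, r = 384, 4
fc1, fc2 = nn.Conv2d(ch, ch // r, 1), nn.Conv2d(ch // r, ch, 1)
x = torch.randn(2, ch, 14, 14)              # transformer-branch feature
w = x.mean((2, 3), keepdim=True)            # squeeze: global average pool
w = torch.sigmoid(fc2(torch.relu(fc1(w))))  # excite: bottleneck MLP + sigmoid gate
print((w * x).shape)                        # gated feature, same shape as x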
/lib/replknet.py:
--------------------------------------------------------------------------------
1 | # Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs (https://arxiv.org/abs/2203.06717)
2 | # Github source: https://github.com/DingXiaoH/RepLKNet-pytorch
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Based on ConvNeXt, timm, DINO and DeiT code bases
5 | # https://github.com/facebookresearch/ConvNeXt
6 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
7 | # https://github.com/facebookresearch/deit/
8 | # https://github.com/facebookresearch/dino
9 | # --------------------------------------------------------'
10 | import torch
11 | import torch.nn as nn
12 | import torch.utils.checkpoint as checkpoint
13 | from timm.models.layers import DropPath
14 | import sys
15 | import os
16 |
17 | def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
18 | if type(kernel_size) is int:
19 | use_large_impl = kernel_size > 5
20 | else:
21 | assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
22 | use_large_impl = kernel_size[0] > 5
23 | has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
24 | if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1:
25 | sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL'])
26 | # Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
27 | # export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
28 | # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
29 | # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
30 | from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
31 | return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
32 | else:
33 | return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
34 | padding=padding, dilation=dilation, groups=groups, bias=bias)
35 |
36 | use_sync_bn = False
37 |
38 | def enable_sync_bn():
39 | global use_sync_bn
40 | use_sync_bn = True
41 |
42 | def get_bn(channels):
43 | if use_sync_bn:
44 | return nn.SyncBatchNorm(channels)
45 | else:
46 | return nn.BatchNorm2d(channels)
47 |
48 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
49 | if padding is None:
50 | padding = kernel_size // 2
51 | result = nn.Sequential()
52 | result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
53 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False))
54 | result.add_module('bn', get_bn(out_channels))
55 | return result
56 |
57 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
58 | if padding is None:
59 | padding = kernel_size // 2
60 | result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
61 | stride=stride, padding=padding, groups=groups, dilation=dilation)
62 | result.add_module('nonlinear', nn.ReLU())
63 | return result
64 |
65 | def fuse_bn(conv, bn):
66 | kernel = conv.weight
67 | running_mean = bn.running_mean
68 | running_var = bn.running_var
69 | gamma = bn.weight
70 | beta = bn.bias
71 | eps = bn.eps
72 | std = (running_var + eps).sqrt()
73 | t = (gamma / std).reshape(-1, 1, 1, 1)
74 | return kernel * t, beta - running_mean * gamma / std
75 |
76 | class ReparamLargeKernelConv(nn.Module):
77 |
78 | def __init__(self, in_channels, out_channels, kernel_size,
79 | stride, groups,
80 | small_kernel,
81 | small_kernel_merged=False):
82 | super(ReparamLargeKernelConv, self).__init__()
83 | self.kernel_size = kernel_size
84 | self.small_kernel = small_kernel
85 | # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
86 | padding = kernel_size // 2
87 | if small_kernel_merged:
88 | self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
89 | stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
90 | else:
91 | self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
92 | stride=stride, padding=padding, dilation=1, groups=groups)
93 | if small_kernel is not None:
94 | assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
95 | self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel,
96 | stride=stride, padding=small_kernel//2, groups=groups, dilation=1)
97 |
98 | def forward(self, inputs):
99 | if hasattr(self, 'lkb_reparam'):
100 | out = self.lkb_reparam(inputs)
101 | else:
102 | out = self.lkb_origin(inputs)
103 | if hasattr(self, 'small_conv'):
104 | out += self.small_conv(inputs)
105 | return out
106 |
107 | def get_equivalent_kernel_bias(self):
108 | eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
109 | if hasattr(self, 'small_conv'):
110 | small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
111 | eq_b += small_b
112 | # add to the central part
113 | eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
114 | return eq_k, eq_b
115 |
116 | def merge_kernel(self):
117 | eq_k, eq_b = self.get_equivalent_kernel_bias()
118 | self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels,
119 | out_channels=self.lkb_origin.conv.out_channels,
120 | kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride,
121 | padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation,
122 | groups=self.lkb_origin.conv.groups, bias=True)
123 | self.lkb_reparam.weight.data = eq_k
124 | self.lkb_reparam.bias.data = eq_b
125 | self.__delattr__('lkb_origin')
126 | if hasattr(self, 'small_conv'):
127 | self.__delattr__('small_conv')
128 |
129 |
130 | class ConvFFN(nn.Module):
131 |
132 | def __init__(self, in_channels, internal_channels, out_channels, drop_path):
133 | super().__init__()
134 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
135 | self.preffn_bn = get_bn(in_channels)
136 | self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0, groups=1)
137 | self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1)
138 | self.nonlinear = nn.GELU()
139 |
140 | def forward(self, x):
141 | out = self.preffn_bn(x)
142 | out = self.pw1(out)
143 | out = self.nonlinear(out)
144 | out = self.pw2(out)
145 | return x + self.drop_path(out)
146 |
147 |
148 | class RepLKBlock(nn.Module):
149 |
150 | def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False):
151 | super().__init__()
152 | self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
153 | self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
154 | self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels, kernel_size=block_lk_size,
155 | stride=1, groups=dw_channels, small_kernel=small_kernel, small_kernel_merged=small_kernel_merged)
156 | self.lk_nonlinear = nn.ReLU()
157 | self.prelkb_bn = get_bn(in_channels)
158 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
159 | print('drop path:', self.drop_path)
160 |
161 | def forward(self, x):
162 | out = self.prelkb_bn(x)
163 | out = self.pw1(out)
164 | out = self.large_kernel(out)
165 | out = self.lk_nonlinear(out)
166 | out = self.pw2(out)
167 | return x + self.drop_path(out)
168 |
169 |
170 | class RepLKNetStage(nn.Module):
171 |
172 | def __init__(self, channels, num_blocks, stage_lk_size, drop_path,
173 | small_kernel, dw_ratio=1, ffn_ratio=4,
174 | use_checkpoint=False, # train with torch.utils.checkpoint to save memory
175 | small_kernel_merged=False,
176 | norm_intermediate_features=False):
177 | super().__init__()
178 | self.use_checkpoint = use_checkpoint
179 | blks = []
180 | for i in range(num_blocks):
181 | block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
182 | # Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model.
183 | replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio), block_lk_size=stage_lk_size,
184 | small_kernel=small_kernel, drop_path=block_drop_path, small_kernel_merged=small_kernel_merged)
185 | convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio), out_channels=channels,
186 | drop_path=block_drop_path)
187 | blks.append(replk_block)
188 | blks.append(convffn_block)
189 | self.blocks = nn.ModuleList(blks)
190 | if norm_intermediate_features:
191 | self.norm = get_bn(channels) # Only use this with RepLKNet-XL on downstream tasks
192 | else:
193 | self.norm = nn.Identity()
194 |
195 | def forward(self, x):
196 | for blk in self.blocks:
197 | if self.use_checkpoint:
198 | x = checkpoint.checkpoint(blk, x) # Save training memory
199 | else:
200 | x = blk(x)
201 | return x
202 |
203 | class RepLKNet(nn.Module):
204 |
205 | def __init__(self, large_kernel_sizes, layers, channels, drop_path_rate, small_kernel,
206 | dw_ratio=1, ffn_ratio=4, in_channels=3, num_classes=1000, out_indices=None,
207 | use_checkpoint=False,
208 | small_kernel_merged=False,
209 | use_sync_bn=True,
210 | norm_intermediate_features=False # for RepLKNet-XL on COCO and ADE20K, use an extra BN to normalize the intermediate feature maps then feed them into the heads
211 | ):
212 | super().__init__()
213 |
214 | if num_classes is None and out_indices is None:
215 | raise ValueError('must specify one of num_classes (for pretraining) and out_indices (for downstream tasks)')
216 | elif num_classes is not None and out_indices is not None:
217 | raise ValueError('cannot specify both num_classes (for pretraining) and out_indices (for downstream tasks)')
218 | elif num_classes is not None and norm_intermediate_features:
219 | raise ValueError('for pretraining, no need to normalize the intermediate feature maps')
220 | self.out_indices = out_indices
221 | if use_sync_bn:
222 | enable_sync_bn()
223 |
224 | base_width = channels[0]
225 | self.use_checkpoint = use_checkpoint
226 | self.norm_intermediate_features = norm_intermediate_features
227 | self.num_stages = len(layers)
228 | self.stem = nn.ModuleList([
229 | conv_bn_relu(in_channels=in_channels, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=1),
230 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=1, padding=1, groups=base_width),
231 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=1, stride=1, padding=0, groups=1),
232 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=base_width)])
233 | # stochastic depth. We set block-wise drop-path rate. The higher level blocks are more likely to be dropped. This implementation follows Swin.
234 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
235 | self.stages = nn.ModuleList()
236 | self.transitions = nn.ModuleList()
237 | for stage_idx in range(self.num_stages):
238 | layer = RepLKNetStage(channels=channels[stage_idx], num_blocks=layers[stage_idx],
239 | stage_lk_size=large_kernel_sizes[stage_idx],
240 | drop_path=dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])],
241 | small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
242 | use_checkpoint=use_checkpoint, small_kernel_merged=small_kernel_merged,
243 | norm_intermediate_features=norm_intermediate_features)
244 | self.stages.append(layer)
245 | if stage_idx < len(layers) - 1:
246 | transition = nn.Sequential(
247 | conv_bn_relu(channels[stage_idx], channels[stage_idx + 1], 1, 1, 0, groups=1),
248 | conv_bn_relu(channels[stage_idx + 1], channels[stage_idx + 1], 3, stride=2, padding=1, groups=channels[stage_idx + 1]))
249 | self.transitions.append(transition)
250 |
251 | if num_classes is not None:
252 | self.norm = get_bn(channels[-1])
253 | self.avgpool = nn.AdaptiveAvgPool2d(1)
254 | self.head = nn.Linear(channels[-1], num_classes)
255 |
256 |
257 |
258 | def forward_features(self, x):
259 | x = self.stem[0](x)
260 | for stem_layer in self.stem[1:]:
261 | if self.use_checkpoint:
262 | x = checkpoint.checkpoint(stem_layer, x) # save memory
263 | else:
264 | x = stem_layer(x)
265 |
266 | if self.out_indices is None:
267 | # Just need the final output
268 | for stage_idx in range(self.num_stages):
269 | x = self.stages[stage_idx](x)
270 | if stage_idx < self.num_stages - 1:
271 | x = self.transitions[stage_idx](x)
272 | return x
273 | else:
274 | # Need the intermediate feature maps
275 | outs = []
276 | for stage_idx in range(self.num_stages):
277 | x = self.stages[stage_idx](x)
278 | if stage_idx in self.out_indices:
279 | outs.append(self.stages[stage_idx].norm(x)) # For RepLKNet-XL normalize the features before feeding them into the heads
280 | if stage_idx < self.num_stages - 1:
281 | x = self.transitions[stage_idx](x)
282 | return outs
283 |
284 | def forward(self, x):
285 | x = self.forward_features(x)
286 | if self.out_indices:
287 | return x
288 | else:
289 | x = self.norm(x)
290 | x = self.avgpool(x)
291 | x = torch.flatten(x, 1)
292 | x = self.head(x)
293 | return x
294 |
295 | def structural_reparam(self):
296 | for m in self.modules():
297 | if hasattr(m, 'merge_kernel'):
298 | m.merge_kernel()
299 |
300 | # If your framework cannot automatically fuse BN for inference, you may do it manually.
301 |     # The BN layers before and after the conv layers can then be removed.
302 |     # No need to call this if your framework supports automatic BN fusion.
303 | def deep_fuse_BN(self):
304 | for m in self.modules():
305 | if not isinstance(m, nn.Sequential):
306 | continue
307 |             if len(m) not in [2, 3]: # Only handle conv-BN or conv-BN-relu
308 | continue
309 | # If you use a custom Conv2d impl, assume it also has 'kernel_size' and 'weight'
310 | if hasattr(m[0], 'kernel_size') and hasattr(m[0], 'weight') and isinstance(m[1], nn.BatchNorm2d):
311 | conv = m[0]
312 | bn = m[1]
313 | fused_kernel, fused_bias = fuse_bn(conv, bn)
314 | fused_conv = get_conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size,
315 | stride=conv.stride,
316 | padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=True)
317 | fused_conv.weight.data = fused_kernel
318 | fused_conv.bias.data = fused_bias
319 | m[0] = fused_conv
320 | m[1] = nn.Identity()
321 |
322 |
323 | def create_RepLKNet31B(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
324 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[128,256,512,1024],
325 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
326 | small_kernel_merged=small_kernel_merged)
327 |
328 | def create_RepLKNet31L(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
329 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[192,384,768,1536],
330 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
331 | small_kernel_merged=small_kernel_merged)
332 |
333 | def create_RepLKNetXL(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
334 | return RepLKNet(large_kernel_sizes=[27,27,27,13], layers=[2,2,18,2], channels=[256,512,1024,2048],
335 | drop_path_rate=drop_path_rate, small_kernel=None, dw_ratio=1.5,
336 | num_classes=num_classes, use_checkpoint=use_checkpoint,
337 | small_kernel_merged=small_kernel_merged)
338 |
339 | if __name__ == '__main__':
340 |     os.environ['CUDA_VISIBLE_DEVICES'] = '3'
341 |     model = create_RepLKNet31B(small_kernel_merged=False).cuda()
342 | model.eval()
343 | # print('------------------- training-time model -------------')
344 | # print(model)
345 | # x = torch.randn(2, 3, 224, 224).cuda()
346 | # origin_y = model(x)
347 | # model.structural_reparam()
348 | # print('------------------- after re-param -------------')
349 | # print(model)
350 | # reparam_y = model(x)
351 | # print('------------------- the difference is ------------------------')
352 | # print((origin_y - reparam_y).abs().sum())
353 |
354 |     x = torch.randn(1, 3, 224, 224).cuda()  # the model expects 3-channel image input
355 | # dpr = [x.item() for x in torch.linspace(0, 0.3, 1)]
356 | # dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])]
357 | # model=RepLKBlock(320,512,31,5,0.3).cuda()
358 | origin_y = model(x)
359 | print(origin_y.shape)
--------------------------------------------------------------------------------
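
A usage note on the constructor logic above: RepLKNet requires exactly one of num_classes (classification head for pretraining) and out_indices (intermediate feature maps for downstream tasks). The sketch below is hypothetical and not part of the repository; it assumes it runs in the same namespace as the RepLKNet class above, mirrors the argument values of create_RepLKNet31B, and passes use_sync_bn=False so no distributed process group is needed.

import torch

# RepLKNet-31B as a feature extractor: num_classes=None plus out_indices selects the
# branch of forward_features that collects one (optionally BN-normalized) map per stage.
backbone = RepLKNet(large_kernel_sizes=[31, 29, 27, 13],
                    layers=[2, 2, 18, 2],
                    channels=[128, 256, 512, 1024],
                    drop_path_rate=0.3,
                    small_kernel=5,
                    num_classes=None,                 # no classification head
                    out_indices=(0, 1, 2, 3),         # return a feature map per stage
                    use_checkpoint=False,
                    use_sync_bn=False,
                    norm_intermediate_features=True)  # extra BN on the returned maps
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # feature maps at strides 4, 8, 16, 32 of the 224x224 input
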
/lib/transformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Optional, List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, Tensor
7 |
8 | from .modules import BoundaryWiseAttentionGate2D, BoundaryWiseAttentionGateAtrous2D
9 |
10 |
11 | class Transformer(nn.Module):
12 | def __init__(self,
13 | d_model=512,
14 | nhead=8,
15 | num_encoder_layers=6,
16 | num_decoder_layers=2,
17 | dim_feedforward=2048,
18 | dropout=0.1,
19 | activation=nn.LeakyReLU,
20 | normalize_before=False,
21 | return_intermediate_dec=False):
22 | super().__init__()
23 |
24 | encoder_layer = TransformerEncoderLayer(d_model, nhead,
25 | dim_feedforward, dropout,
26 | activation, normalize_before)
27 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
28 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers,
29 | encoder_norm)
30 | decoder_layer = TransformerDecoderLayer(d_model, nhead,
31 | dim_feedforward, dropout,
32 | activation, normalize_before)
33 | decoder_norm = nn.LayerNorm(d_model)
34 | self.decoder = TransformerDecoder(
35 | decoder_layer,
36 | num_decoder_layers,
37 | decoder_norm,
38 | return_intermediate=return_intermediate_dec)
39 | self._reset_parameters()
40 |
41 | self.d_model = d_model
42 | self.nhead = nhead
43 |
44 | def _reset_parameters(self):
45 | for p in self.parameters():
46 | if p.dim() > 1:
47 | nn.init.xavier_uniform_(p)
48 |
49 | def forward(self, src, mask, query_embed, pos_embed):
50 | bs, c, h, w = src.shape
51 | src = src.flatten(2).permute(2, 0, 1)
52 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
53 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
54 | if mask is not None:
55 | mask = mask.flatten(1)
56 |
57 | tgt = torch.zeros_like(query_embed)
58 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
59 | # print("Trans Encoder",memory.shape)
60 | hs = self.decoder(tgt,
61 | memory,
62 | memory_key_padding_mask=mask,
63 | pos=pos_embed,
64 | query_pos=query_embed)
65 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
66 |
67 |
68 | class BoundaryAwareTransformer(nn.Module):
69 | def __init__(self,
70 | point_pred_layers=2,
71 | d_model=512,
72 | nhead=8,
73 | num_encoder_layers=2,
74 | num_decoder_layers=2,
75 | dim_feedforward=2048,
76 | dropout=0.1,
77 | activation=nn.LeakyReLU,
78 | normalize_before=False,
79 | return_intermediate_dec=False,
80 | BAG_type='2D',
81 | Atrous=False):
82 | super().__init__()
83 | self.num_decoder_layers = num_decoder_layers
84 |
85 | encoder_layer = BoundaryAwareTransformerEncoderLayer(
86 | d_model, nhead, BAG_type, Atrous, dim_feedforward, dropout,
87 | activation, normalize_before)
88 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
89 | self.encoder = BoundaryAwareTransformerEncoder(point_pred_layers,
90 | encoder_layer,
91 | num_encoder_layers,
92 | encoder_norm)
93 | if num_decoder_layers > 0:
94 | decoder_layer = TransformerDecoderLayer(d_model, nhead,
95 | dim_feedforward, dropout,
96 | activation,
97 | normalize_before)
98 | decoder_norm = nn.LayerNorm(d_model)
99 | self.decoder = TransformerDecoder(
100 | decoder_layer,
101 | num_decoder_layers,
102 | decoder_norm,
103 | return_intermediate=return_intermediate_dec)
104 | self._reset_parameters()
105 |
106 | self.d_model = d_model
107 | self.nhead = nhead
108 |
109 | def _reset_parameters(self):
110 | for p in self.parameters():
111 | if p.dim() > 1:
112 | nn.init.xavier_uniform_(p)
113 |
114 | def forward(self, src, mask, query_embed, pos_embed):
115 | bs, c, h, w = src.shape
116 | src = src.flatten(2).permute(2, 0, 1)
117 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
118 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
119 | if mask is not None:
120 | mask = mask.flatten(1)
121 |
122 | tgt = torch.zeros_like(query_embed)
123 | memory, weights = self.encoder(src,
124 | src_key_padding_mask=mask,
125 | pos=pos_embed,
126 | height=h,
127 | width=w)
128 | if self.num_decoder_layers > 0:
129 | hs = self.decoder(tgt,
130 | memory,
131 | memory_key_padding_mask=mask,
132 | pos=pos_embed,
133 | query_pos=query_embed)
134 |             return (hs.transpose(1, 2),
135 |                     memory.permute(1, 2, 0).view(bs, c, h, w),
136 |                     weights)
137 |         else:
138 |             return (tgt.transpose(1, 2),
139 |                     memory.permute(1, 2, 0).view(bs, c, h, w),
140 |                     weights)
141 |
142 |
143 | class TransformerEncoder(nn.Module):
144 | def __init__(self, encoder_layer, num_layers, norm=None):
145 | super().__init__()
146 | self.layers = _get_clones(encoder_layer, num_layers)
147 | self.num_layers = num_layers
148 | self.norm = norm
149 |
150 | def forward(self,
151 | src,
152 | mask: Optional[Tensor] = None,
153 | src_key_padding_mask: Optional[Tensor] = None,
154 | pos: Optional[Tensor] = None):
155 | output = src
156 |
157 | for layer in self.layers:
158 | output = layer(output,
159 | src_mask=mask,
160 | src_key_padding_mask=src_key_padding_mask,
161 | pos=pos)
162 |
163 | if self.norm is not None:
164 | output = self.norm(output)
165 |
166 | return output
167 |
168 |
169 | class BoundaryAwareTransformerEncoder(nn.Module):
170 | def __init__(self,
171 | point_pred_layers,
172 | encoder_layer,
173 | num_layers,
174 | norm=None):
175 | super().__init__()
176 | self.point_pred_layers = point_pred_layers
177 | self.layers = _get_clones(encoder_layer, num_layers)
178 | self.num_layers = num_layers
179 | self.norm = norm
180 |
181 | def forward(self,
182 | src,
183 | mask: Optional[Tensor] = None,
184 | src_key_padding_mask: Optional[Tensor] = None,
185 | pos: Optional[Tensor] = None,
186 | height: int = 32,
187 | width: int = 32):
188 | output = src
189 | weights = []
190 |
191 | for layer_i, layer in enumerate(self.layers):
192 | output, weight = layer(True,
193 | output,
194 | src_mask=mask,
195 | src_key_padding_mask=src_key_padding_mask,
196 | pos=pos,
197 | height=height,
198 | width=width)
199 | weights.append(weight)
200 |
201 | if self.norm is not None:
202 | output = self.norm(output)
203 |
204 | return output, weights
205 |
206 |
207 | class TransformerEncoderLayer(nn.Module):
208 | def __init__(self,
209 | d_model,
210 | nhead,
211 | dim_feedforward=2048,
212 | dropout=0.1,
213 | activation=nn.LeakyReLU,
214 | normalize_before=False):
215 | super().__init__()
216 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
217 | # Implementation of Feedforward model
218 | self.linear1 = nn.Linear(d_model, dim_feedforward)
219 | self.dropout = nn.Dropout(dropout)
220 | self.linear2 = nn.Linear(dim_feedforward, d_model)
221 |
222 | self.norm1 = nn.LayerNorm(d_model)
223 | self.norm2 = nn.LayerNorm(d_model)
224 | self.dropout1 = nn.Dropout(dropout)
225 | self.dropout2 = nn.Dropout(dropout)
226 |
227 | self.activation = activation()
228 | self.normalize_before = normalize_before
229 |
230 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
231 | return tensor if pos is None else tensor + pos
232 |
233 | def forward_post(self,
234 | src,
235 | src_mask: Optional[Tensor] = None,
236 | src_key_padding_mask: Optional[Tensor] = None,
237 | pos: Optional[Tensor] = None):
238 | q = k = self.with_pos_embed(src, pos)
239 | src2 = self.self_attn(q,
240 | k,
241 | value=src,
242 | attn_mask=src_mask,
243 | key_padding_mask=src_key_padding_mask)[0]
244 | src = src + self.dropout1(src2)
245 | src = self.norm1(src)
246 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
247 | src = src + self.dropout2(src2)
248 | src = self.norm2(src)
249 | return src
250 |
251 | def forward_pre(self,
252 | src,
253 | src_mask: Optional[Tensor] = None,
254 | src_key_padding_mask: Optional[Tensor] = None,
255 | pos: Optional[Tensor] = None):
256 | src2 = self.norm1(src)
257 | q = k = self.with_pos_embed(src2, pos)
258 | src2 = self.self_attn(q,
259 | k,
260 | value=src2,
261 | attn_mask=src_mask,
262 | key_padding_mask=src_key_padding_mask)[0]
263 | src = src + self.dropout1(src2)
264 | src2 = self.norm2(src)
265 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
266 | src = src + self.dropout2(src2)
267 | return src
268 |
269 | def forward(self,
270 | src,
271 | src_mask: Optional[Tensor] = None,
272 | src_key_padding_mask: Optional[Tensor] = None,
273 | pos: Optional[Tensor] = None):
274 | if self.normalize_before:
275 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
276 | return self.forward_post(src, src_mask, src_key_padding_mask, pos)
277 |
278 |
279 | class BoundaryAwareTransformerEncoderLayer(TransformerEncoderLayer):
280 |     """Add a Boundary-wise Attention Gate (BAG) to the Transformer encoder layer."""
281 |
282 | def __init__(self,
283 | d_model,
284 | nhead,
285 | BAG_type='2D',
286 | Atrous=True,
287 | dim_feedforward=2048,
288 | dropout=0.1,
289 | activation=nn.LeakyReLU,
290 | normalize_before=False):
291 | super().__init__(d_model, nhead, dim_feedforward, dropout, activation,
292 | normalize_before)
293 | if BAG_type == '2D':
294 | if Atrous:
295 | self.BAG = BoundaryWiseAttentionGateAtrous2D(d_model)
296 | else:
297 | self.BAG = BoundaryWiseAttentionGate2D(d_model)
298 | self.BAG_type = BAG_type
299 |
300 | def forward(self,
301 | use_bag,
302 | src,
303 | src_mask: Optional[Tensor] = None,
304 | src_key_padding_mask: Optional[Tensor] = None,
305 | pos: Optional[Tensor] = None,
306 | height: int = 32,
307 | width: int = 32):
308 | if self.normalize_before:
309 | features = self.forward_pre(src, src_mask, src_key_padding_mask,
310 | pos)
311 | if use_bag:
312 | b, c = features.shape[1:]
313 | if self.BAG_type == '1D':
314 | features = features.permute(1, 2, 0)
315 | features, weights = self.BAG(features)
316 | features = features.permute(2, 0, 1).contiguous()
317 | weights = weights.view(b, 1, height, width)
318 | elif self.BAG_type == '2D':
319 | features = features.permute(1, 2,
320 | 0).view(b, c, height, width)
321 | features, weights = self.BAG(features)
322 | features = features.flatten(2).permute(2, 0,
323 | 1).contiguous()
324 | return features, weights
325 | else:
326 | return features
327 | features = self.forward_post(src, src_mask, src_key_padding_mask, pos)
328 | if use_bag:
329 | b, c = features.shape[1:]
330 | if self.BAG_type == '1D':
331 | features = features.permute(1, 2, 0)
332 | features, weights = self.BAG(features)
333 | features = features.permute(2, 0, 1).contiguous()
334 | weights = weights.view(b, 1, height, width)
335 | elif self.BAG_type == '2D':
336 | features = features.permute(1, 2, 0).view(b, c, height, width)
337 | features, weights = self.BAG(features)
338 | features = features.flatten(2).permute(2, 0, 1).contiguous()
339 | return features, weights
340 | else:
341 | return features
342 |
343 |
344 | class TransformerDecoder(nn.Module):
345 | def __init__(self,
346 | decoder_layer,
347 | num_layers,
348 | norm=None,
349 | return_intermediate=False):
350 | super().__init__()
351 | self.layers = _get_clones(decoder_layer, num_layers)
352 | self.num_layers = num_layers
353 | self.norm = norm
354 | self.return_intermediate = return_intermediate
355 |
356 | def forward(self,
357 | tgt,
358 | memory,
359 | tgt_mask: Optional[Tensor] = None,
360 | memory_mask: Optional[Tensor] = None,
361 | tgt_key_padding_mask: Optional[Tensor] = None,
362 | memory_key_padding_mask: Optional[Tensor] = None,
363 | pos: Optional[Tensor] = None,
364 | query_pos: Optional[Tensor] = None):
365 | output = tgt
366 |
367 | intermediate = []
368 |
369 | for layer in self.layers:
370 | output = layer(output,
371 | memory,
372 | tgt_mask=tgt_mask,
373 | memory_mask=memory_mask,
374 | tgt_key_padding_mask=tgt_key_padding_mask,
375 | memory_key_padding_mask=memory_key_padding_mask,
376 | pos=pos,
377 | query_pos=query_pos)
378 | if self.return_intermediate:
379 | intermediate.append(self.norm(output))
380 |
381 | if self.norm is not None:
382 | output = self.norm(output)
383 | if self.return_intermediate:
384 | intermediate.pop()
385 | intermediate.append(output)
386 |
387 | if self.return_intermediate:
388 | return torch.stack(intermediate)
389 |
390 | return output
391 |
392 |
393 | class TransformerDecoderLayer(nn.Module):
394 | def __init__(self,
395 | d_model,
396 | nhead,
397 | dim_feedforward=2048,
398 | dropout=0.1,
399 | activation=nn.LeakyReLU,
400 | normalize_before=False):
401 | super().__init__()
402 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
403 | self.multihead_attn = nn.MultiheadAttention(d_model,
404 | nhead,
405 | dropout=dropout)
406 | # Implementation of Feedforward model
407 | self.linear1 = nn.Linear(d_model, dim_feedforward)
408 | self.dropout = nn.Dropout(dropout)
409 | self.linear2 = nn.Linear(dim_feedforward, d_model)
410 |
411 | self.norm1 = nn.LayerNorm(d_model)
412 | self.norm2 = nn.LayerNorm(d_model)
413 | self.norm3 = nn.LayerNorm(d_model)
414 | self.dropout1 = nn.Dropout(dropout)
415 | self.dropout2 = nn.Dropout(dropout)
416 | self.dropout3 = nn.Dropout(dropout)
417 |
418 | self.activation = activation()
419 | self.normalize_before = normalize_before
420 |
421 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
422 | return tensor if pos is None else tensor + pos
423 |
424 | def forward_post(self,
425 | tgt,
426 | memory,
427 | tgt_mask: Optional[Tensor] = None,
428 | memory_mask: Optional[Tensor] = None,
429 | tgt_key_padding_mask: Optional[Tensor] = None,
430 | memory_key_padding_mask: Optional[Tensor] = None,
431 | pos: Optional[Tensor] = None,
432 | query_pos: Optional[Tensor] = None):
433 | q = k = self.with_pos_embed(tgt, query_pos)
434 | tgt2 = self.self_attn(q,
435 | k,
436 | value=tgt,
437 | attn_mask=tgt_mask,
438 | key_padding_mask=tgt_key_padding_mask)[0]
439 | tgt = tgt + self.dropout1(tgt2)
440 | tgt = self.norm1(tgt)
441 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
442 | key=self.with_pos_embed(memory, pos),
443 | value=memory,
444 | attn_mask=memory_mask,
445 | key_padding_mask=memory_key_padding_mask)[0]
446 | tgt = tgt + self.dropout2(tgt2)
447 | tgt = self.norm2(tgt)
448 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
449 | tgt = tgt + self.dropout3(tgt2)
450 | tgt = self.norm3(tgt)
451 | return tgt
452 |
453 | def forward_pre(self,
454 | tgt,
455 | memory,
456 | tgt_mask: Optional[Tensor] = None,
457 | memory_mask: Optional[Tensor] = None,
458 | tgt_key_padding_mask: Optional[Tensor] = None,
459 | memory_key_padding_mask: Optional[Tensor] = None,
460 | pos: Optional[Tensor] = None,
461 | query_pos: Optional[Tensor] = None):
462 | tgt2 = self.norm1(tgt)
463 | q = k = self.with_pos_embed(tgt2, query_pos)
464 | tgt2 = self.self_attn(q,
465 | k,
466 | value=tgt2,
467 | attn_mask=tgt_mask,
468 | key_padding_mask=tgt_key_padding_mask)[0]
469 | tgt = tgt + self.dropout1(tgt2)
470 | tgt2 = self.norm2(tgt)
471 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
472 | key=self.with_pos_embed(memory, pos),
473 | value=memory,
474 | attn_mask=memory_mask,
475 | key_padding_mask=memory_key_padding_mask)[0]
476 | tgt = tgt + self.dropout2(tgt2)
477 | tgt2 = self.norm3(tgt)
478 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
479 | tgt = tgt + self.dropout3(tgt2)
480 | return tgt
481 |
482 | def forward(self,
483 | tgt,
484 | memory,
485 | tgt_mask: Optional[Tensor] = None,
486 | memory_mask: Optional[Tensor] = None,
487 | tgt_key_padding_mask: Optional[Tensor] = None,
488 | memory_key_padding_mask: Optional[Tensor] = None,
489 | pos: Optional[Tensor] = None,
490 | query_pos: Optional[Tensor] = None):
491 | if self.normalize_before:
492 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
493 | tgt_key_padding_mask,
494 | memory_key_padding_mask, pos, query_pos)
495 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
496 | tgt_key_padding_mask, memory_key_padding_mask,
497 | pos, query_pos)
498 |
499 |
500 | def _get_clones(module, N):
501 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
502 |
503 |
504 | def build_transformer(args):
505 | return Transformer(
506 | d_model=args.hidden_dim,
507 | dropout=args.dropout,
508 | nhead=args.nheads,
509 | dim_feedforward=args.dim_feedforward,
510 | num_encoder_layers=args.enc_layers,
511 | num_decoder_layers=args.dec_layers,
512 | normalize_before=args.pre_norm,
513 | return_intermediate_dec=True,
514 | )
515 |
516 |
517 | def _get_activation_fn(activation):
518 | """Return an activation function given a string"""
519 | if activation == "leaky relu":
520 | return F.leaky_relu
521 | if activation == "selu":
522 | return F.selu
523 | if activation == "relu":
524 | return F.relu
525 | if activation == "gelu":
526 | return F.gelu
527 | if activation == "glu":
528 | return F.glu
529 | raise RuntimeError(
530 | F"activation should be relu, gelu, glu, leaky relu or selu, not {activation}."
531 | )
--------------------------------------------------------------------------------
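
A usage note for lib/transformer.py: BoundaryAwareTransformer.forward expects a backbone feature map src of shape (B, C, H, W), an optional padding mask, learnable queries of shape (num_queries, d_model), and a positional encoding shaped like src; it returns the decoder output, the encoder memory reshaped back to (B, C, H, W), and one boundary-attention weight map per encoder layer. The sketch below is hypothetical and not part of the repository; it assumes the repository root is on PYTHONPATH and that lib is importable as a package, so the relative import of lib/modules.py resolves.

import torch
from lib.transformer import BoundaryAwareTransformer

d_model, num_queries = 256, 4
bat = BoundaryAwareTransformer(d_model=d_model, nhead=8,
                               num_encoder_layers=2, num_decoder_layers=2,
                               dim_feedforward=1024, BAG_type='2D', Atrous=False)
bat.eval()

B, H, W = 2, 16, 16
src = torch.randn(B, d_model, H, W)         # backbone feature map
pos_embed = torch.randn(B, d_model, H, W)   # positional encoding, same shape as src
query_embed = torch.randn(num_queries, d_model)

with torch.no_grad():
    hs, memory, weights = bat(src, None, query_embed, pos_embed)
print(memory.shape)   # torch.Size([2, 256, 16, 16]): refined feature map
print(len(weights))   # 2: one boundary-attention map per encoder layer
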