├── frame.jpg ├── isic2016.png ├── .gitignore ├── .ipynb_checkpoints └── infer-checkpoint.ipynb ├── .vscode └── settings.json ├── utils ├── loss.py ├── format_conversion.py ├── utils.py ├── polar_generate.py ├── polar_transformations.py ├── resize.py ├── isic2016_dataloader.py ├── dataloader.py ├── point_generate.py ├── isbi2016_new.py ├── isic2018_dataloader.py ├── isbi2018_new.py └── isic2018_polar.py ├── lib ├── Position_embedding.py ├── vision_transformers.py ├── TransFuse │ ├── DeiT.py │ ├── vision_transformer.py │ └── TransFuse.py ├── modules.py ├── xboundformer_v0.py ├── baseline.py ├── pvt.py ├── xboundformer.py ├── deeplabv3.py ├── replknet.py └── transformer.py ├── README.md └── src ├── test.py └── train.py /frame.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/xboundformer/HEAD/frame.jpg -------------------------------------------------------------------------------- /isic2016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/xboundformer/HEAD/isic2016.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | results/ 3 | exps/ 4 | *.pth 5 | scripts/ 6 | *.pyc 7 | figures/ 8 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/infer-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 5 6 | } 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "git.ignoreLimitWarning": true, 3 | "python.pythonPath": "/home/wjc/.conda/envs/pytorch/bin/python" 4 | } 5 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def structure_loss(pred, mask): 6 | weit = 1 + 5 * torch.abs( 7 | F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 8 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 9 | wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 10 | 11 | pred = torch.sigmoid(pred) 12 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 13 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 14 | wiou = 1 - (inter + 1) / (union - inter + 1) 15 | 16 | return (wbce + wiou).mean() 17 | 18 | 19 | def dice_loss(pred, mask, act=False): 20 | if not act: 21 | pred = torch.sigmoid(pred) 22 | inter = (pred * mask).sum(dim=(2, 3)) 23 | union = (pred + mask).sum(dim=(2, 3)) 24 | wiou = 1 - (inter + 1) / (union - inter + 1) 25 | return wiou.mean() -------------------------------------------------------------------------------- /utils/format_conversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from libtiff import TIFF # pip install libtiff 4 | from scipy import misc 5 | import random 6 | 7 | 8 | def tif2png(_src_path, _dst_path): 9 | """ 10 | Usage: 11 | formatting `tif/tiff` files to `jpg/png` files 12 | :param _src_path: 13 | :param _dst_path: 14 | :return: 15 | """ 16 | tif = TIFF.open(_src_path, 
mode='r') 17 | image = tif.read_image() 18 | misc.imsave(_dst_path, image) 19 | 20 | 21 | def data_split(src_list): 22 | """ 23 | Usage: 24 | randomly spliting dataset 25 | :param src_list: 26 | :return: 27 | """ 28 | counter_list = random.sample(range(0, len(src_list)), 550) 29 | 30 | return counter_list 31 | 32 | 33 | if __name__ == '__main__': 34 | src_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks_tif' 35 | dst_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks' 36 | 37 | os.makedirs(dst_dir, exist_ok=True) 38 | for img_name in os.listdir(src_dir): 39 | tif2png(os.path.join(src_dir, img_name), 40 | os.path.join(dst_dir, img_name.replace('.tif', '.png'))) -------------------------------------------------------------------------------- /lib/Position_embedding.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | 5 | from torch import nn, Tensor 6 | import torch.nn.functional as F 7 | 8 | 9 | class PositionEmbeddingLearned(nn.Module): 10 | """ 11 | Absolute pos embedding, learned. 12 | """ 13 | def __init__(self, num_pos_feats=128): 14 | super().__init__() 15 | self.row_embed = nn.Embedding(50, num_pos_feats) 16 | self.col_embed = nn.Embedding(50, num_pos_feats) 17 | self.reset_parameters() 18 | 19 | def reset_parameters(self): 20 | nn.init.uniform_(self.row_embed.weight) 21 | nn.init.uniform_(self.col_embed.weight) 22 | 23 | def forward(self, x): 24 | h, w = x.shape[-2:] 25 | # print(h, w) 26 | i = torch.arange(w, device=x.device) 27 | j = torch.arange(h, device=x.device) 28 | x_emb = self.col_embed(i) 29 | y_emb = self.row_embed(j) 30 | pos = torch.cat([ 31 | x_emb.unsqueeze(0).repeat(h, 1, 1), 32 | y_emb.unsqueeze(1).repeat(1, w, 1), 33 | ], 34 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat( 35 | x.shape[0], 1, 1, 1) 36 | return pos -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from thop import profile 4 | from thop import clever_format 5 | 6 | 7 | def clip_gradient(optimizer, grad_clip): 8 | """ 9 | For calibrating misalignment gradient via cliping gradient technique 10 | :param optimizer: 11 | :param grad_clip: 12 | :return: 13 | """ 14 | for group in optimizer.param_groups: 15 | for param in group['params']: 16 | if param.grad is not None: 17 | param.grad.data.clamp_(-grad_clip, grad_clip) 18 | 19 | 20 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 21 | decay = decay_rate**(epoch // decay_epoch) 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] *= decay 24 | 25 | 26 | class AvgMeter(object): 27 | def __init__(self, num=40): 28 | self.num = num 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | self.losses = [] 37 | 38 | def update(self, val, n=1): 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = self.sum / self.count 43 | self.losses.append(val) 44 | 45 | def show(self): 46 | return torch.mean( 47 | torch.stack(self.losses[np.maximum(len(self.losses) - 48 | self.num, 0):])) 49 | 50 | 51 | def CalParams(model, input_tensor): 52 | """ 53 | Usage: 54 | Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter) 55 | Necessarity: 56 | from thop import profile 57 | from thop import clever_format 58 | 
:param model: 59 | :param input_tensor: 60 | :return: 61 | """ 62 | flops, params = profile(model, inputs=(input_tensor, )) 63 | flops, params = clever_format([flops, params], "%.3f") 64 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format( 65 | flops, params)) 66 | -------------------------------------------------------------------------------- /utils/polar_generate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | import torch 5 | import numpy as np 6 | import skimage.draw 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | import torch.nn.functional as F 10 | from polar_transformations import to_polar, centroid 11 | 12 | 13 | def polar_gen_isic2018(): 14 | data_dir = '/raid/wjc/data/skin_lesion/isic2018_jpg_smooth/' 15 | 16 | os.makedirs(data_dir + '/PolarImage', exist_ok=True) 17 | os.makedirs(data_dir + '/PolarLabel', exist_ok=True) 18 | 19 | path_list = os.listdir(data_dir + '/Label/') 20 | path_list.sort() 21 | num = 0 22 | for path in tqdm(path_list): 23 | image_data = cv2.imread(os.path.join(data_dir, 'Image', path)) 24 | 25 | label_data = cv2.imread(os.path.join(data_dir, 'Label', path), 26 | cv2.IMREAD_GRAYSCALE) 27 | center = centroid(image_data) 28 | image_data = to_polar(image_data, center) 29 | label_data = to_polar(label_data, center) 30 | # print(image_data.max(), label_data.max()) 31 | 32 | cv2.imwrite(data_dir + '/PolarImage/' + path, image_data) 33 | cv2.imwrite(data_dir + '/PolarLabel/' + path, label_data) 34 | # break 35 | 36 | 37 | def point_gen_isic2016(): 38 | R = 10 39 | N = 25 40 | for split in ['Train', 'Test', 'Validation']: 41 | data_dir = '/raid/wjc/data/skin_lesion/isic2016/{}/Label'.format(split) 42 | 43 | save_dir = data_dir.replace('Label', 'Point') 44 | os.makedirs(save_dir, exist_ok=True) 45 | 46 | path_list = os.listdir(data_dir) 47 | path_list.sort() 48 | num = 0 49 | for path in tqdm(path_list): 50 | name = path[:-4] 51 | label_path = os.path.join(data_dir, path) 52 | print(label_path) 53 | label_ori, point_heatmap = kpm_gen(label_path, R, N) 54 | save_path = os.path.join(save_dir, name + '.npy') 55 | np.save(save_path, point_heatmap) 56 | num += 1 57 | 58 | 59 | if __name__ == '__main__': 60 | polar_gen_isic2018() -------------------------------------------------------------------------------- /utils/polar_transformations.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import cv2 as cv 6 | 7 | 8 | def centroid(img, lcc=False): 9 | if lcc: 10 | img = img.astype(np.uint8) 11 | nb_components, output, stats, centroids = cv.connectedComponentsWithStats( 12 | img, connectivity=4) 13 | sizes = stats[:, -1] 14 | if len(sizes) > 2: 15 | max_label = 1 16 | max_size = sizes[1] 17 | 18 | for i in range(2, nb_components): 19 | if sizes[i] > max_size: 20 | max_label = i 21 | max_size = sizes[i] 22 | 23 | img2 = np.zeros(output.shape) 24 | img2[output == max_label] = 255 25 | img = img2 26 | 27 | if len(img.shape) > 2: 28 | M = cv.moments(img[:, :, 1]) 29 | else: 30 | M = cv.moments(img) 31 | 32 | if M["m00"] == 0: 33 | return (img.shape[0] // 2, img.shape[1] // 2) 34 | 35 | cX = int(M["m10"] / M["m00"]) 36 | cY = int(M["m01"] / M["m00"]) 37 | return (cX, cY) 38 | 39 | 40 | def to_polar(input_img, center): 41 | input_img = input_img.astype(np.float32) 42 | value = np.sqrt(((input_img.shape[0] / 2.0)**2.0) + 43 | ((input_img.shape[1] / 
2.0)**2.0)) 44 | polar_image = cv.linearPolar(input_img, center, value, 45 | cv.WARP_FILL_OUTLIERS) 46 | polar_image = cv.rotate(polar_image, cv.ROTATE_90_COUNTERCLOCKWISE) 47 | return polar_image 48 | 49 | 50 | def to_cart(input_img, center): 51 | input_img = input_img.astype(np.float32) 52 | input_img = cv.rotate(input_img, cv.ROTATE_90_CLOCKWISE) 53 | value = np.sqrt(((input_img.shape[1] / 2.0)**2.0) + 54 | ((input_img.shape[0] / 2.0)**2.0)) 55 | polar_image = cv.linearPolar(input_img, center, value, 56 | cv.WARP_FILL_OUTLIERS + cv.WARP_INVERSE_MAP) 57 | polar_image = polar_image.astype(np.uint8) 58 | return polar_image 59 | 60 | 61 | if __name__ == "__main__": 62 | image = cv.imread('test_images/30.tif') 63 | plt.imshow(image) 64 | 65 | center = centroid(image) 66 | plt.scatter(center[0], center[1]) 67 | plt.show() 68 | 69 | polar = to_polar(image, center) 70 | plt.imshow(polar) 71 | plt.show() 72 | 73 | cart = to_cart(polar, center) 74 | plt.imshow(cart) 75 | plt.show() 76 | -------------------------------------------------------------------------------- /lib/vision_transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from lib.transformer import BoundaryAwareTransformer 5 | from lib.Position_embedding import PositionEmbeddingLearned 6 | 7 | 8 | class in_scale_transformer(nn.Module): 9 | def __init__(self, 10 | point_pred_layers=1, 11 | num_queries=1, 12 | d_model=512, 13 | nhead=8, 14 | num_encoder_layers=6, 15 | num_decoder_layers=6, 16 | dim_feedforward=2048, 17 | dropout=0.1, 18 | activation=nn.LeakyReLU, 19 | normalize_before=False, 20 | return_intermediate_dec=False, 21 | BAG_type='2D', 22 | Atrous=False): 23 | 24 | super().__init__() 25 | 26 | self.query_embed = nn.Embedding(num_queries, d_model) 27 | self.pos_embed = PositionEmbeddingLearned(d_model // 2) 28 | self.num_queries = num_queries 29 | 30 | self.transformer = BoundaryAwareTransformer( 31 | point_pred_layers=point_pred_layers, 32 | d_model=d_model, 33 | nhead=nhead, 34 | num_encoder_layers=num_encoder_layers, 35 | num_decoder_layers=num_decoder_layers, 36 | dim_feedforward=dim_feedforward, 37 | dropout=dropout, 38 | activation=activation, 39 | normalize_before=normalize_before, 40 | return_intermediate_dec=return_intermediate_dec, 41 | BAG_type=BAG_type, 42 | Atrous=Atrous) 43 | 44 | def forward(self, x): 45 | 46 | pos_embed = self.pos_embed(x).to(x.dtype) 47 | 48 | latent_tensor, features_encoded, point_maps = self.transformer( 49 | x, None, self.query_embed.weight, pos_embed) 50 | 51 | return latent_tensor, features_encoded, point_maps 52 | 53 | 54 | # def detr_Transformer(pretrained=False, **kwargs): 55 | 56 | # transformer = DETR_Transformer(num_encoder_layers=6, 57 | # num_decoder_layers=6, 58 | # d_model=256, 59 | # nhead=8) 60 | 61 | # if pretrained: 62 | # print("Loaded DETR Pretrained Parameters From ImageNet...") 63 | # ckpt = torch.load( 64 | # '/home/chenfei/my_codes/TransformerCode-master/Ours/pretrained/detr-r50-e632da11.pth' 65 | # ) 66 | # state_dict = ckpt['model'] 67 | 68 | # transformer.query_embed = nn.Embedding(100, 256) 69 | # unParalled_state_dict = {} 70 | # for key in state_dict.keys(): 71 | # if key.startswith("transformer"): 72 | # unParalled_state_dict[key] = state_dict[key] 73 | # #elif key.startswith("query_embed"): 74 | # # unParalled_state_dict[key] = state_dict[key] 75 | # ### Without positional embedding for detr and mismatch for query_embed 76 | # 
transformer.load_state_dict(unParalled_state_dict, strict=False) 77 | # return transformer 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # XBound-Former: Toward Cross-scale Boundary Modeling in Transformers 2 | 3 | ## Introduction 4 | 5 | This is an official release of the paper **XBound-Former: Toward Cross-scale Boundary Modeling in Transformers**, including the network implementation and the training scripts. 6 | 7 | > [**XBound-Former: Toward Cross-scale Boundary Modeling in Transformers**]
8 | > **Jiacheng Wang**, Fei Chen, Yuxi Ma, Liansheng Wang, Zhaodong Fei, Jianwei Shuai, Xiangdong Tang, Qichao Zhou, Jing Qin
9 | > In: IEEE Transactions on Medical Imaging (TMI), 2023
10 | > [[arXiv](https://arxiv.org/abs/2206.00806)][[Bibtex](https://github.com/jcwang123/xboundformer#Citation)] 11 | 12 | 
13 | 14 | ## News 15 | - **[1/11 2022] This paper has been accepted to TMI.** 16 | - **[5/27 2022] We have released the training scripts.** 17 | - **[5/19 2022] We have created this repo.** 18 | 19 | ## Code List 20 | 21 | - [x] Network 22 | - [x] Pre-processing 23 | - [x] Training Codes 24 | - [ ] Pretrained Weights 25 | 26 | For more details or any questions, please feel free to contact us by email (jiachengw@stu.xmu.edu.cn). 27 | 28 | ## Usage 29 | 30 | ### Dataset 31 | 32 | Please download the datasets from the [ISIC](https://www.isic-archive.com/) challenge and the [PH2](https://www.fc.up.pt/addi/ph2%20database.html) website. 33 | 34 | ### Pre-processing 35 | 36 | Please run: 37 | 38 | ```bash 39 | $ python utils/resize.py 40 | ``` 41 | 42 | You need to change the **File Path** to your own and select the correct function. 43 | 44 | ### Training 45 | 46 | Please run: 47 | 48 | ```bash 49 | $ python src/train.py 50 | ``` 51 | You need to change the **File Path** to your own and select the correct function. 52 | 53 | ### Testing 54 | 55 | Download the pretrained weights for the ISIC-2016 & $PH^2$ dataset from [Google Drive](https://drive.google.com/file/d/1-eMHYX1fr-QvI3n50S0xqWcxc3FGsMgE/view?usp=sharing) and move them to the log directory (`logs/`). 56 | 57 | Then, please run: 58 | 59 | ```bash 60 | $ python src/test.py 61 | ``` 62 | 63 | ### Result 64 | Results on the ISIC-2016 & $PH^2$ dataset: 65 |
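The reported numbers are computed with the `medpy` metrics imported in `src/test.py` (Dice, Jaccard, HD95, etc.) on the prediction masks that `src/test.py` writes under `results/`. The snippet below is a minimal evaluation sketch, not part of the released code: the prediction and label directories are assumptions taken from the defaults in `src/test.py` and `utils/isbi2016_new.py`, and may need to be adapted to your own paths.

```python
import os
import cv2
import numpy as np
from medpy.metric.binary import dc, jc

# assumed locations: masks written by src/test.py and the ISIC-2016 test labels
pred_dir = 'results/ISIC-2016-pictures/xboundformer/Test'
label_dir = '/raid/wjc/data/skin_lesion/isic2016/Test/Label'

dices, jaccards = [], []
for name in sorted(os.listdir(pred_dir)):
    # plt.imsave stores a colormapped PNG, so a grayscale threshold recovers the binary mask
    pred = cv2.imread(os.path.join(pred_dir, name), cv2.IMREAD_GRAYSCALE) > 127
    gt = np.load(os.path.join(label_dir, name[:-4] + '_label.npy')).astype(np.uint8)
    gt = cv2.resize(gt, pred.shape[::-1], interpolation=cv2.INTER_NEAREST) > 0
    dices.append(dc(pred, gt))
    jaccards.append(jc(pred, gt))

print('Dice: {:.4f}, IoU: {:.4f}'.format(np.mean(dices), np.mean(jaccards)))
```

The same idea applies to the ISIC-2018 folds saved under `results/ISIC-2018-pictures/`.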
66 | 67 | ## Citation 68 | 69 | If you find XBound-Former useful in your research, please consider citing: 70 | ``` 71 | @article{wang2023xbound, 72 | title={XBound-Former: Toward Cross-scale Boundary Modeling in Transformers}, 73 | author={Wang, Jiacheng and Chen, Fei and Ma, Yuxi and Wang, Liansheng and Fei, Zhaodong and Shuai, Jianwei and Tang, Xiangdong and Zhou, Qichao and Qin, Jing}, 74 | journal={IEEE Transactions on Medical Imaging}, 75 | year={2023}, 76 | publisher={IEEE} 77 | } 78 | ``` 79 | and the prior work, BAT, as: 80 | ``` 81 | @inproceedings{wang2021boundary, 82 | title={Boundary-Aware Transformers for Skin Lesion Segmentation}, 83 | author={Wang, Jiacheng and Wei, Lan and Wang, Liansheng and Zhou, Qichao and Zhu, Lei and Qin, Jing}, 84 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 85 | pages={206--216}, 86 | year={2021}, 87 | organization={Springer} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /lib/TransFuse/DeiT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from .vision_transformer import VisionTransformer, _cfg 6 | from timm.models.registry import register_model 7 | from timm.models.layers import trunc_normal_ 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | __all__ = [ 12 | 'deit_tiny_patch16_224', 13 | 'deit_small_patch16_224', 14 | 'deit_base_patch16_224', 15 | 'deit_tiny_distilled_patch16_224', 16 | 'deit_small_distilled_patch16_224', 17 | 'deit_base_distilled_patch16_224', 18 | 'deit_base_patch16_384', 19 | 'deit_base_distilled_patch16_384', 20 | ] 21 | 22 | 23 | class DeiT(VisionTransformer): 24 | def __init__(self, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | num_patches = self.patch_embed.num_patches 27 | self.pos_embed = nn.Parameter( 28 | torch.zeros(1, num_patches + 1, self.embed_dim)) 29 | 30 | def forward(self, x): 31 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 32 | # with slight modifications to add the dist_token 33 | B = x.shape[0] 34 | x = self.patch_embed(x) 35 | pe = self.pos_embed 36 | 37 | x = x + pe 38 | x = self.pos_drop(x) 39 | 40 | for blk in self.blocks: 41 | x = blk(x) 42 | 43 | x = self.norm(x) 44 | return x 45 | 46 | 47 | @register_model 48 | def deit_small_patch16_224(pretrained=False, **kwargs): 49 | model = DeiT(patch_size=16, 50 | embed_dim=384, 51 | depth=8, 52 | num_heads=6, 53 | mlp_ratio=4, 54 | qkv_bias=True, 55 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 56 | **kwargs) 57 | model.default_cfg = _cfg() 58 | if pretrained: 59 | ckpt = torch.load( 60 | '/raid/wjc/code/xbound_former/lib/TransFuse/pretrained/deit_small_patch16_224-cd65a155.pth' 61 | ) 62 | model.load_state_dict(ckpt['model'], strict=False) 63 | 64 | pe = model.pos_embed[:, 1:, :].detach() 65 | pe = pe.transpose(-1, -2) 66 | pe = pe.view(pe.shape[0], pe.shape[1], int(np.sqrt(pe.shape[2])), 67 | int(np.sqrt(pe.shape[2]))) 68 | pe = F.interpolate(pe, 69 | size=(512 // 16, 512 // 16), 70 | mode='bilinear', 71 | align_corners=True) # (12, 16) 72 | pe = pe.flatten(2) 73 | pe = pe.transpose(-1, -2) 74 | model.pos_embed = nn.Parameter(pe) 75 | model.head = nn.Identity() 76 | return model 77 | 78 | 79 | @register_model 80 | def deit_base_patch16_224(pretrained=False, **kwargs): 81 | model = DeiT(patch_size=16, 82 | embed_dim=768, 83 | 
depth=10, 84 | num_heads=12, 85 | mlp_ratio=4, 86 | qkv_bias=True, 87 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 88 | **kwargs) 89 | model.default_cfg = _cfg() 90 | if pretrained: 91 | ckpt = torch.load( 92 | '/home/chenfei/my_codes/TransformerCode-master/transformer_family/TransFuse/pretrained/deit_base_patch16_224-b5f2ef4d.pth' 93 | ) 94 | model.load_state_dict(ckpt['model'], strict=False) 95 | 96 | pe = model.pos_embed[:, 1:, :].detach() 97 | pe = pe.transpose(-1, -2) 98 | pe = pe.view(pe.shape[0], pe.shape[1], int(np.sqrt(pe.shape[2])), 99 | int(np.sqrt(pe.shape[2]))) 100 | pe = F.interpolate(pe, size=(32, 32), mode='bilinear', 101 | align_corners=True) # (12, 16) 102 | pe = pe.flatten(2) 103 | pe = pe.transpose(-1, -2) 104 | model.pos_embed = nn.Parameter(pe) 105 | model.head = nn.Identity() 106 | return model -------------------------------------------------------------------------------- /utils/resize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | def process_isic2018( 11 | dim=(352, 352), save_dir='/raid/wjc/data/skin_lesion/isic2018/'): 12 | image_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1-2_Training_Input/' 13 | mask_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1_Training_GroundTruth/' 14 | 15 | image_path_list = os.listdir(image_dir_path) 16 | mask_path_list = os.listdir(mask_dir_path) 17 | 18 | image_path_list = list(filter(lambda x: x[-3:] == 'jpg', image_path_list)) 19 | mask_path_list = list(filter(lambda x: x[-3:] == 'png', mask_path_list)) 20 | 21 | image_path_list.sort() 22 | mask_path_list.sort() 23 | 24 | print(len(image_path_list), len(mask_path_list)) 25 | 26 | # ISBI Dataset 27 | for image_path, mask_path in zip(image_path_list, mask_path_list): 28 | if image_path[-3:] == 'jpg': 29 | print(image_path) 30 | assert os.path.basename(image_path)[:-4].split( 31 | '_')[1] == os.path.basename(mask_path)[:-4].split('_')[1] 32 | _id = os.path.basename(image_path)[:-4].split('_')[1] 33 | image_path = os.path.join(image_dir_path, image_path) 34 | mask_path = os.path.join(mask_dir_path, mask_path) 35 | image = cv2.imread(image_path) 36 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 37 | 38 | image_new = cv2.resize(image, dim, interpolation=cv2.INTER_CUBIC) 39 | image_new = np.array(image_new, dtype=np.uint8) 40 | mask_new = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST) 41 | mask_new = cv2.erode(mask_new, (5, 5)) 42 | mask_new = cv2.dilate(mask_new, (5, 5)) 43 | mask_new = np.array(mask_new, dtype=np.uint8) 44 | # print(np.unique(mask_new)) 45 | 46 | save_dir_path = save_dir + '/Image' 47 | os.makedirs(save_dir_path, exist_ok=True) 48 | # np.save(os.path.join(save_dir_path, _id + '.npy'), image_new) 49 | print(image_new.shape) 50 | cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'), 51 | image_new) 52 | 53 | save_dir_path = save_dir + '/Label' 54 | os.makedirs(save_dir_path, exist_ok=True) 55 | # np.save(os.path.join(save_dir_path, _id + '.npy'), mask_new) 56 | cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'), 57 | mask_new) 58 | 59 | 60 | def process_ph2(): 61 | PH2_images_path = '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2_Dataset_images' 62 | 63 | path_list = os.listdir(PH2_images_path) 64 | path_list.sort() 65 | 66 | for path in path_list: 67 | image_path = os.path.join(PH2_images_path, path, 68 | path + 
'_Dermoscopic_Image', path + '.bmp') 69 | label_path = os.path.join(PH2_images_path, path, path + '_lesion', 70 | path + '_lesion.bmp') 71 | image = plt.imread(image_path) 72 | label = plt.imread(label_path) 73 | label = label[:, :, 0] 74 | 75 | dim = (352, 352) 76 | image_new = cv2.resize(image, dim, interpolation=cv2.INTER_AREA) 77 | label_new = cv2.resize(label, dim, interpolation=cv2.INTER_AREA) 78 | 79 | image_save_path = os.path.join( 80 | '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Image', 81 | path + '.npy') 82 | label_save_path = os.path.join( 83 | '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Label', 84 | path + '.npy') 85 | 86 | np.save(image_save_path, image_new) 87 | np.save(label_save_path, label_new) 88 | 89 | 90 | if __name__ == '__main__': 91 | process_isic2018( 92 | dim=(352, 352), 93 | save_dir='/raid/wjc/data/skin_lesion/isic2018_jpg_352_smooth/') 94 | -------------------------------------------------------------------------------- /utils/isic2016_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import random 7 | import torch 8 | import cv2 9 | from sklearn.model_selection import KFold 10 | import albumentations as A 11 | from albumentations.pytorch import ToTensorV2 12 | import json 13 | 14 | 15 | class isic2016Dataset(data.Dataset): 16 | """ 17 | dataloader for isic2016 segmentation tasks 18 | """ 19 | def __init__(self, image_root, gt_root, image_index, trainsize, 20 | augmentations): 21 | self.trainsize = trainsize 22 | self.augmentations = augmentations 23 | print(self.augmentations) 24 | self.image_root = image_root 25 | self.gt_root = gt_root 26 | self.images = image_index 27 | self.size = len(self.images) 28 | 29 | if self.augmentations: 30 | print('Using RandomRotation, RandomFlip') 31 | 32 | self.transform = A.Compose([ 33 | A.Rotate(90), 34 | A.VerticalFlip(p=0.5), 35 | A.HorizontalFlip(p=0.5), 36 | A.Resize(self.trainsize, self.trainsize), 37 | ToTensorV2() 38 | ]) 39 | else: 40 | print('no augmentation') 41 | self.transform = A.Compose( 42 | [A.Resize(self.trainsize, self.trainsize), 43 | ToTensorV2()]) 44 | 45 | def __getitem__(self, idx): 46 | file_name = self.images[idx] 47 | # gt_name = file_name[:-4] + '_segmentation.png' 48 | img_root = os.path.join(self.image_root, file_name) 49 | gt_root = os.path.join(self.gt_root, file_name[:-4] + '_label.npy') 50 | # image = cv2.imread(img_root) 51 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 52 | # gt = (cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE)) 53 | image = np.load(img_root) 54 | gt = (np.load(gt_root) * 255).astype(np.uint8) 55 | 56 | point_heatmap = cv2.Canny(gt, 0, 255) / 255.0 57 | gt = gt // 255.0 58 | gt = np.concatenate( 59 | [gt[..., np.newaxis], point_heatmap[..., np.newaxis]], axis=-1) 60 | pair = self.transform(image=image, mask=gt) 61 | gt = pair['mask'][:, :, 0] 62 | point_heatmap = pair['mask'][:, :, 1] 63 | gt = torch.unsqueeze(gt, 0) 64 | point_heatmap = torch.unsqueeze(point_heatmap, 0) 65 | image = pair['image'] / 255.0 66 | 67 | return image, gt, point_heatmap 68 | 69 | def resize(self, img, gt): 70 | assert img.size == gt.size 71 | w, h = img.size 72 | if h < self.trainsize or w < self.trainsize: 73 | h = max(h, self.trainsize) 74 | w = max(w, self.trainsize) 75 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), 76 | Image.NEAREST) 77 | else: 78 | return img, gt 79 | 
80 | def __len__(self): 81 | return self.size 82 | 83 | 84 | def get_loader(root_path, 85 | batchsize, 86 | trainsize, 87 | shuffle=True, 88 | num_workers=8, 89 | pin_memory=True, 90 | augmentation=False): 91 | 92 | dataset = isic2016Dataset(root_path + 'Train/Image', 93 | root_path + 'Train/Label', 94 | os.listdir(root_path + 'Train/Image'), trainsize, 95 | augmentation) 96 | data_loader = data.DataLoader(dataset=dataset, 97 | batch_size=batchsize, 98 | shuffle=shuffle, 99 | num_workers=num_workers, 100 | pin_memory=pin_memory) 101 | 102 | validset = isic2016Dataset(root_path + 'Validation/Image', 103 | root_path + 'Validation/Label', 104 | os.listdir(root_path + 'Validation/Image'), 105 | trainsize, False) 106 | valid_loader = data.DataLoader(dataset=validset, 107 | batch_size=1, 108 | shuffle=shuffle, 109 | num_workers=num_workers, 110 | pin_memory=pin_memory) 111 | 112 | testset = isic2016Dataset(root_path + 'Test/Image', 113 | root_path + 'Test/Label', 114 | os.listdir(root_path + 'Test/Image'), trainsize, 115 | False) 116 | test_loader = data.DataLoader(dataset=testset, 117 | batch_size=1, 118 | shuffle=shuffle, 119 | num_workers=num_workers, 120 | pin_memory=pin_memory) 121 | return data_loader, valid_loader, test_loader -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | from medpy.metric.binary import hd, hd95, dc, jc, assd 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torch 5 | import torch.nn.functional as F 6 | import os 7 | import sys 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) 12 | from lib.xboundformer import _segm_pvtv2 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | save_point_pred = True 18 | 19 | 20 | def isbi2016(): 21 | target_size = (512, 512) 22 | 23 | for model_name in ['xboundformer']: 24 | if model_name == 'xboundformer': 25 | model = _segm_pvtv2(1, 2, 2, 1, 352).to(device) 26 | else: 27 | # TODO 28 | raise NotImplementedError 29 | model.load_state_dict( 30 | torch.load( 31 | f'logs/isbi2016/test_loss_1_aug_1/{model_name}/fold_None/model/best.pkl' 32 | )) 33 | for fold in ['PH2', 'Test']: 34 | save_dir = f'results/ISIC-2016-pictures/{model_name}/{fold}' 35 | os.makedirs(save_dir, exist_ok=True) 36 | from utils.isbi2016_new import norm01, myDataset 37 | if fold == 'PH2': 38 | dataset = myDataset(split='test', aug=False) 39 | else: 40 | dataset = myDataset(split='valid', aug=False) 41 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=1) 42 | 43 | model.eval() 44 | for batch_idx, batch_data in tqdm(enumerate(test_loader)): 45 | data = batch_data['image'].to(device).float() 46 | label = batch_data['label'].to(device).float() 47 | path = batch_data['image_path'][0] 48 | with torch.no_grad(): 49 | output, point_pred1, point_pred2, point_pred3 = model(data) 50 | if save_point_pred: 51 | os.makedirs(save_dir.replace('pictures', 'point_maps'), 52 | exist_ok=True) 53 | point_pred1 = F.interpolate(point_pred1[-1], target_size) 54 | point_pred1 = point_pred1.cpu().numpy()[0, 0] 55 | plt.imsave( 56 | save_dir.replace('pictures', 'point_maps') + '/' + 57 | os.path.basename(path)[:-4] + '.png', point_pred1) 58 | output = torch.sigmoid(output)[0][0] 59 | output = (output.cpu().numpy() > 0.5).astype('uint8') 60 | output = (cv2.resize(output, target_size, 
cv2.INTER_NEAREST) > 61 | 0.5) * 1 62 | plt.imsave( 63 | save_dir + '/' + os.path.basename(path)[:-4] + '.png', 64 | output) 65 | 66 | 67 | def isbi2018(): 68 | model = _segm_pvtv2(1, 1, 1, 1, 352).to(device) 69 | target_size = (512, 512) 70 | for fold in range(5): 71 | model.load_state_dict( 72 | torch.load( 73 | f'logs/isbi2018/test_loss_1_aug_1/xboundformer/fold_{fold}/model/best.pkl' 74 | )) 75 | save_dir = f'results/ISIC-2018-pictures/xboundformer/fold-{int(fold)+1}' 76 | os.makedirs(save_dir, exist_ok=True) 77 | from utils.isbi2018_new import norm01, myDataset 78 | dataset = myDataset(fold=str(fold), split='valid', aug=False) 79 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=1) 80 | 81 | model.eval() 82 | for batch_idx, batch_data in tqdm(enumerate(test_loader)): 83 | data = batch_data['image'].to(device).float() 84 | label = batch_data['label'].to(device).float() 85 | path = batch_data['image_path'][0] 86 | with torch.no_grad(): 87 | output, _, _, _ = model(data) 88 | output = torch.sigmoid(output)[0][0] 89 | output = (output.cpu().numpy() > 0.5).astype('uint8') 90 | output = (cv2.resize(output, target_size, cv2.INTER_NEAREST) > 91 | 0.5) * 1 92 | plt.imsave( 93 | save_dir + '/' + os.path.basename(path).split('_')[1][:-4] + 94 | '.png', output) 95 | 96 | 97 | def isbi2018_ablation(folder_name): 98 | vs = list(map(int, folder_name.split('_')[1:])) 99 | model = _segm_pvtv2(1, vs[0], vs[1], vs[2], 352).to(device) 100 | target_size = (512, 512) 101 | for fold in range(5): 102 | model.load_state_dict( 103 | torch.load( 104 | f'logs/isbi2018/test_loss_1_aug_1/{folder_name}/fold_{fold}/model/best.pkl' 105 | )) 106 | save_dir = f'results/ISIC-2018-pictures/{folder_name}/fold-{int(fold)+1}' 107 | os.makedirs(save_dir, exist_ok=True) 108 | from utils.isbi2018_new import norm01, myDataset 109 | dataset = myDataset(fold=str(fold), split='valid', aug=False) 110 | test_loader = torch.utils.data.DataLoader(dataset, batch_size=1) 111 | 112 | model.eval() 113 | for batch_idx, batch_data in tqdm(enumerate(test_loader)): 114 | data = batch_data['image'].to(device).float() 115 | label = batch_data['label'].to(device).float() 116 | path = batch_data['image_path'][0] 117 | with torch.no_grad(): 118 | output, _, _, _ = model(data) 119 | output = torch.sigmoid(output)[0][0] 120 | output = (output.cpu().numpy() > 0.5).astype('uint8') 121 | output = (cv2.resize(output, target_size, cv2.INTER_NEAREST) > 122 | 0.5) * 1 123 | plt.imsave( 124 | save_dir + '/' + os.path.basename(path).split('_')[1][:-4] + 125 | '.png', output) 126 | 127 | 128 | if __name__ == '__main__': 129 | # isbi2016() 130 | isbi2018_ablation('bl_0_0_0') 131 | isbi2018_ablation('bl_1_0_0') 132 | isbi2018_ablation('bl_1_1_0') 133 | isbi2018_ablation('bl_1_1_1') 134 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import random 7 | import torch 8 | import cv2 9 | import albumentations as A 10 | from albumentations.pytorch import ToTensorV2 11 | 12 | 13 | class PolypDataset(data.Dataset): 14 | """ 15 | dataloader for polyp segmentation tasks 16 | """ 17 | def __init__(self, image_root, gt_root, trainsize, augmentations): 18 | self.trainsize = trainsize 19 | self.augmentations = augmentations 20 | print(self.augmentations) 21 | self.images = [ 22 
| image_root + f for f in os.listdir(image_root) 23 | if f.endswith('.jpg') or f.endswith('.png') 24 | ] 25 | self.gts = [ 26 | gt_root + f for f in os.listdir(gt_root) if f.endswith('.png') 27 | ] 28 | self.images = sorted(self.images) 29 | self.gts = sorted(self.gts) 30 | self.filter_files() 31 | self.size = len(self.images) 32 | self.color1, self.color2 = [], [] 33 | for name in self.images: 34 | if os.path.basename(name)[:-4].isdigit(): 35 | self.color1.append(name) 36 | else: 37 | self.color2.append(name) 38 | if self.augmentations: 39 | self.transform = A.Compose([ 40 | A.Rotate(90), 41 | A.VerticalFlip(p=0.5), 42 | A.HorizontalFlip(p=0.5), 43 | A.Resize(self.trainsize, self.trainsize), 44 | ToTensorV2() 45 | ]) 46 | else: 47 | print('no augmentation') 48 | self.transform = A.Compose( 49 | [A.Resize(self.trainsize, self.trainsize), 50 | ToTensorV2()]) 51 | 52 | def __getitem__(self, idx): 53 | image = cv2.imread(self.images[idx]) 54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) 55 | 56 | name2 = self.color1[idx % len(self.color1)] if np.random.rand( 57 | ) < 0.7 else self.color2[idx % len(self.color2)] 58 | image2 = cv2.imread(name2) 59 | image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2LAB) 60 | 61 | mean, std = image.mean(axis=(0, 1), 62 | keepdims=True), image.std(axis=(0, 1), 63 | keepdims=True) 64 | mean2, std2 = image2.mean(axis=(0, 1), 65 | keepdims=True), image2.std(axis=(0, 1), 66 | keepdims=True) 67 | image = np.uint8((image - mean) / std * std2 + mean2) 68 | image = cv2.cvtColor(image, cv2.COLOR_LAB2RGB) 69 | gt = (cv2.imread(self.gts[idx], cv2.IMREAD_GRAYSCALE)) 70 | point_heatmap = cv2.Canny(gt, 0, 255) / 255.0 71 | gt = gt // 255.0 72 | gt = np.concatenate( 73 | [gt[..., np.newaxis], point_heatmap[..., np.newaxis]], axis=-1) 74 | pair = self.transform(image=image, mask=gt) 75 | gt = pair['mask'][:, :, 0] 76 | point_heatmap = pair['mask'][:, :, 1] 77 | gt = torch.unsqueeze(gt, 0) 78 | point_heatmap = torch.unsqueeze(point_heatmap, 0) 79 | image = pair['image'] / 255.0 80 | 81 | return image, gt, point_heatmap 82 | 83 | def filter_files(self): 84 | assert len(self.images) == len(self.gts) 85 | images = [] 86 | gts = [] 87 | for img_path, gt_path in zip(self.images, self.gts): 88 | img = Image.open(img_path) 89 | gt = Image.open(gt_path) 90 | if img.size == gt.size: 91 | images.append(img_path) 92 | gts.append(gt_path) 93 | self.images = images 94 | self.gts = gts 95 | 96 | def resize(self, img, gt): 97 | assert img.size == gt.size 98 | w, h = img.size 99 | if h < self.trainsize or w < self.trainsize: 100 | h = max(h, self.trainsize) 101 | w = max(w, self.trainsize) 102 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), 103 | Image.NEAREST) 104 | else: 105 | return img, gt 106 | 107 | def __len__(self): 108 | return self.size 109 | 110 | 111 | def get_loader(image_root, 112 | gt_root, 113 | batchsize, 114 | trainsize, 115 | shuffle=True, 116 | num_workers=4, 117 | pin_memory=True, 118 | augmentation=False): 119 | 120 | dataset = PolypDataset(image_root, gt_root, trainsize, augmentation) 121 | data_loader = data.DataLoader(dataset=dataset, 122 | batch_size=batchsize, 123 | shuffle=shuffle, 124 | num_workers=num_workers, 125 | pin_memory=pin_memory) 126 | return data_loader 127 | 128 | 129 | class test_dataset: 130 | def __init__(self, image_root, gt_root, testsize): 131 | self.testsize = testsize 132 | self.images = [ 133 | image_root + f for f in os.listdir(image_root) 134 | if f.endswith('.jpg') or f.endswith('.png') 135 | ] 136 | self.gts = [ 137 | gt_root + f 
for f in os.listdir(gt_root) 138 | if f.endswith('.tif') or f.endswith('.png') 139 | ] 140 | self.images = sorted(self.images) 141 | self.gts = sorted(self.gts) 142 | self.transform = A.Compose( 143 | [A.Resize(self.testsize, self.testsize), 144 | ToTensorV2()]) 145 | self.size = len(self.images) 146 | self.index = 0 147 | 148 | def load_data(self): 149 | image = cv2.imread(self.images[self.index]) 150 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 151 | gt = cv2.imread(self.gts[self.index], cv2.IMREAD_GRAYSCALE) 152 | pair = self.transform(image=image, mask=gt) 153 | image = pair['image'].unsqueeze(0) / 255 154 | gt = pair['mask'] / 255 155 | name = self.images[self.index].split('/')[-1] 156 | if name.endswith('.jpg'): 157 | name = name.split('.jpg')[0] + '.png' 158 | self.index += 1 159 | return image, gt, name 160 | 161 | def rgb_loader(self, path): 162 | with open(path, 'rb') as f: 163 | img = Image.open(f) 164 | return img.convert('RGB') 165 | 166 | def binary_loader(self, path): 167 | with open(path, 'rb') as f: 168 | img = Image.open(f) 169 | return img.convert('L') -------------------------------------------------------------------------------- /lib/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class _simple_learner(nn.Module): 7 | def __init__(self, d_model): 8 | super().__init__() 9 | self.mlp = nn.Conv2d(d_model * 2, d_model, 1, 1) 10 | 11 | def forward(self, f_low, f_high): 12 | low_size = f_low.shape[2:] 13 | f2_high = F.interpolate(f_high, size=low_size) 14 | 15 | f2_low = torch.cat([f_low, f2_high], dim=1) 16 | f2_low = self.mlp(f2_low) 17 | return f2_low 18 | 19 | 20 | class xboundlearnerv2(nn.Module): 21 | def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.0): 22 | super().__init__() 23 | 24 | self.xbl = xboundlearner(d_model, 25 | nhead, 26 | dim_feedforward=dim_feedforward, 27 | dropout=dropout) 28 | self.xbl1 = xboundlearner(d_model, 29 | nhead, 30 | dim_feedforward=dim_feedforward, 31 | dropout=dropout) 32 | self.mlp = nn.Conv2d(d_model * 2, d_model, 1, 1) 33 | 34 | def forward(self, f_low, f_high, xi_low, xi_high): 35 | f2_low = self.xbl(f_low, xi_high) 36 | f2_high = self.xbl1(f_high, xi_low) 37 | 38 | low_size = f2_low.shape[2:] 39 | f2_high = F.interpolate(f2_high, size=low_size) 40 | 41 | f2_low = torch.cat([f2_low, f2_high], dim=1) 42 | f2_low = self.mlp(f2_low) 43 | return f2_low + f2_low 44 | 45 | 46 | class xboundlearner(nn.Module): 47 | def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.0): 48 | super().__init__() 49 | 50 | self.cross_attn = nn.MultiheadAttention(d_model, 51 | nhead, 52 | dropout=dropout) 53 | 54 | self.linear1 = nn.Linear(d_model, dim_feedforward) 55 | self.dropout = nn.Dropout(dropout) 56 | self.linear2 = nn.Linear(dim_feedforward, d_model) 57 | 58 | self.norm1 = nn.LayerNorm(d_model) 59 | self.norm2 = nn.LayerNorm(d_model) 60 | 61 | self.dropout1 = nn.Dropout(dropout) 62 | self.dropout2 = nn.Dropout(dropout) 63 | 64 | self.activation = nn.LeakyReLU() 65 | 66 | def forward(self, tgt, src): 67 | "tgt shape: Batch_size, C, H, W " 68 | "src shape: Batch_size, 1, C " 69 | 70 | B, C, h, w = tgt.shape 71 | tgt = tgt.view(B, C, h * w).permute(2, 0, 1) # shape: L, B, C 72 | 73 | src = src.permute(1, 0, 2) # shape: Q:1, B, C 74 | 75 | fusion_feature = self.cross_attn(query=tgt, key=src, value=src)[0] 76 | tgt = tgt + self.dropout1(fusion_feature) 77 | tgt = self.norm1(tgt) 78 | 79 | 
tgt1 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 80 | tgt = tgt + self.dropout2(tgt1) 81 | tgt = self.norm2(tgt) 82 | return tgt.permute(1, 2, 0).view(B, C, h, w) 83 | 84 | 85 | class BoundaryWiseAttentionGateAtrous2D(nn.Module): 86 | def __init__(self, in_channels, hidden_channels=None): 87 | 88 | super(BoundaryWiseAttentionGateAtrous2D, self).__init__() 89 | 90 | modules = [] 91 | 92 | if hidden_channels == None: 93 | hidden_channels = in_channels // 2 94 | 95 | modules.append( 96 | nn.Sequential( 97 | nn.Conv2d(in_channels, hidden_channels, 1, bias=False), 98 | nn.BatchNorm2d(hidden_channels), nn.ReLU(inplace=True))) 99 | modules.append( 100 | nn.Sequential( 101 | nn.Conv2d(in_channels, 102 | hidden_channels, 103 | 3, 104 | padding=1, 105 | dilation=1, 106 | bias=False), nn.BatchNorm2d(hidden_channels), 107 | nn.ReLU(inplace=True))) 108 | modules.append( 109 | nn.Sequential( 110 | nn.Conv2d(in_channels, 111 | hidden_channels, 112 | 3, 113 | padding=2, 114 | dilation=2, 115 | bias=False), nn.BatchNorm2d(hidden_channels), 116 | nn.ReLU(inplace=True))) 117 | modules.append( 118 | nn.Sequential( 119 | nn.Conv2d(in_channels, 120 | hidden_channels, 121 | 3, 122 | padding=4, 123 | dilation=4, 124 | bias=False), nn.BatchNorm2d(hidden_channels), 125 | nn.ReLU(inplace=True))) 126 | modules.append( 127 | nn.Sequential( 128 | nn.Conv2d(in_channels, 129 | hidden_channels, 130 | 3, 131 | padding=6, 132 | dilation=6, 133 | bias=False), nn.BatchNorm2d(hidden_channels), 134 | nn.ReLU(inplace=True))) 135 | 136 | self.convs = nn.ModuleList(modules) 137 | 138 | self.conv_out = nn.Conv2d(5 * hidden_channels, 1, 1, bias=False) 139 | 140 | def forward(self, x): 141 | " x.shape: B, C, H, W " 142 | " return: feature, weight (B,C,H,W) " 143 | res = [] 144 | for conv in self.convs: 145 | res.append(conv(x)) 146 | res = torch.cat(res, dim=1) 147 | weight = torch.sigmoid(self.conv_out(res)) 148 | x = x * weight + x 149 | return x, weight 150 | 151 | 152 | class BoundaryWiseAttentionGate2D(nn.Sequential): 153 | def __init__(self, in_channels, hidden_channels=None): 154 | super(BoundaryWiseAttentionGate2D, self).__init__( 155 | nn.Conv2d(in_channels, 156 | in_channels, 157 | kernel_size=3, 158 | padding=1, 159 | bias=False), nn.BatchNorm2d(in_channels), 160 | nn.ReLU(inplace=False), 161 | nn.Conv2d(in_channels, 162 | in_channels, 163 | kernel_size=3, 164 | padding=1, 165 | bias=False), nn.BatchNorm2d(in_channels), 166 | nn.ReLU(inplace=False), nn.Conv2d(in_channels, 1, kernel_size=1)) 167 | 168 | def forward(self, x): 169 | " x.shape: B, C, H, W " 170 | " return: feature, weight (B,C,H,W) " 171 | weight = torch.sigmoid( 172 | super(BoundaryWiseAttentionGate2D, self).forward(x)) 173 | x = x * weight + x 174 | return x, weight 175 | -------------------------------------------------------------------------------- /utils/point_generate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | import torch 5 | import numpy as np 6 | import skimage.draw 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | import torch.nn.functional as F 10 | 11 | 12 | def create_circular_mask(h, w, center, radius): 13 | Y, X = np.ogrid[:h, :w] 14 | dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2) 15 | mask = dist_from_center <= radius 16 | return mask 17 | 18 | 19 | def NMS(heatmap, kernel=13): 20 | hmax = F.max_pool2d(heatmap, kernel, stride=1, padding=(kernel - 1) // 2) 21 | keep = (hmax == heatmap).float() 
22 | return heatmap * keep, hmax, keep 23 | 24 | 25 | def draw_msra_gaussian(heatmap, center, sigma): 26 | tmp_size = sigma * 3 27 | mu_x = int(center[0] + 0.5) 28 | mu_y = int(center[1] + 0.5) 29 | w, h = heatmap.shape[0], heatmap.shape[1] 30 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] 31 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] 32 | if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0: 33 | return heatmap 34 | size = 2 * tmp_size + 1 35 | x = np.arange(0, size, 1, np.float32) 36 | y = x[:, np.newaxis] 37 | x0 = y0 = size // 2 38 | g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) 39 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0] 40 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1] 41 | img_x = max(0, ul[0]), min(br[0], h) 42 | img_y = max(0, ul[1]), min(br[1], w) 43 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum( 44 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], g[g_y[0]:g_y[1], 45 | g_x[0]:g_x[1]]) 46 | return heatmap 47 | 48 | 49 | def kpm_gen(label_path, R, N): 50 | label = np.load(label_path) 51 | # label = label[0] 52 | label_ori = label.copy() 53 | label = label[::4, ::4] 54 | label = np.uint8(label * 255) 55 | contours, hierarchy = cv2.findContours(label, cv2.RETR_LIST, 56 | cv2.CHAIN_APPROX_NONE) 57 | contour_len = len(contours) 58 | 59 | label = np.repeat(label[..., np.newaxis], 3, axis=-1) 60 | draw_label = cv2.drawContours(label.copy(), contours, -1, (0, 0, 255), 1) 61 | 62 | point_file = [] 63 | if contour_len == 0: 64 | point_heatmap = np.zeros((512, 512)) 65 | else: 66 | point_heatmap = np.zeros((512, 512)) 67 | for contour in contours: 68 | stds = [] 69 | points = contour[:, 0] # (N,2) 70 | points = points * 4 71 | points_number = contour.shape[0] 72 | if points_number < 30: 73 | continue 74 | 75 | if points_number < 100: 76 | radius = 6 77 | neighbor_points_n_oneside = 3 78 | elif points_number < 200: 79 | radius = 10 80 | neighbor_points_n_oneside = 15 81 | elif points_number < 300: 82 | radius = 10 83 | neighbor_points_n_oneside = 20 84 | elif points_number < 350: 85 | radius = 10 86 | neighbor_points_n_oneside = 20 87 | else: 88 | radius = 10 89 | neighbor_points_n_oneside = 20 90 | 91 | radius 92 | for i in range(points_number): 93 | current_point = points[i] 94 | mask = create_circular_mask(512, 512, points[i], radius) 95 | overlap_area = np.sum( 96 | mask * label_ori) / (np.pi * radius * radius) 97 | stds.append(overlap_area) 98 | print("stds len: ", len(stds)) 99 | 100 | # show 101 | selected_points = [] 102 | stds = np.array(stds) 103 | neighbor_points = [] 104 | for i in range(len(points)): 105 | current_point = points[i] 106 | neighbor_points_index = np.concatenate([ 107 | np.arange(-neighbor_points_n_oneside, 0), 108 | np.arange(1, neighbor_points_n_oneside + 1) 109 | ]) + i 110 | neighbor_points_index[np.where( 111 | neighbor_points_index < 0)[0]] += len(points) 112 | neighbor_points_index[np.where( 113 | neighbor_points_index > len(points) - 1)[0]] -= len(points) 114 | if stds[i] < np.min( 115 | stds[neighbor_points_index]) or stds[i] > np.max( 116 | stds[neighbor_points_index]): 117 | # print(points[i]) 118 | point_heatmap = draw_msra_gaussian( 119 | point_heatmap, (points[i, 0], points[i, 1]), 5) 120 | selected_points.append(points[i]) 121 | 122 | print("selected_points num: ", len(selected_points)) 123 | # print(selected_points) 124 | maskk = np.zeros((512, 512)) 125 | rr, cc = skimage.draw.polygon( 126 | np.array(selected_points)[:, 1], 127 | np.array(selected_points)[:, 0]) 128 | maskk[rr, cc] = 1 129 | 
intersection = np.logical_and(label_ori, maskk) 130 | union = np.logical_or(label_ori, maskk) 131 | iou_score = np.sum(intersection) / np.sum(union) 132 | print(iou_score) 133 | return label_ori, point_heatmap 134 | 135 | 136 | def point_gen_isic2018(): 137 | R = 10 138 | N = 25 139 | data_dir = '/raid/wjc/data/skin_lesion/isic2018/Label' 140 | 141 | save_dir = data_dir.replace('Label', 'Point') 142 | os.makedirs(save_dir, exist_ok=True) 143 | 144 | path_list = os.listdir(data_dir) 145 | path_list.sort() 146 | num = 0 147 | for path in tqdm(path_list): 148 | name = path[:-4] 149 | label_path = os.path.join(data_dir, path) 150 | print(label_path) 151 | label_ori, point_heatmap = kpm_gen(label_path, R, N) 152 | 153 | save_path = os.path.join(save_dir, name + '.npy') 154 | np.save(save_path, point_heatmap) 155 | num += 1 156 | 157 | 158 | def point_gen_isic2016(): 159 | R = 10 160 | N = 25 161 | for split in ['Train', 'Test', 'Validation']: 162 | data_dir = '/raid/wjc/data/skin_lesion/isic2016/{}/Label'.format(split) 163 | 164 | save_dir = data_dir.replace('Label', 'Point') 165 | os.makedirs(save_dir, exist_ok=True) 166 | 167 | path_list = os.listdir(data_dir) 168 | path_list.sort() 169 | num = 0 170 | for path in tqdm(path_list): 171 | name = path[:-4] 172 | label_path = os.path.join(data_dir, path) 173 | print(label_path) 174 | label_ori, point_heatmap = kpm_gen(label_path, R, N) 175 | save_path = os.path.join(save_dir, name + '.npy') 176 | np.save(save_path, point_heatmap) 177 | num += 1 178 | 179 | 180 | if __name__ == '__main__': 181 | # point_gen_isic2018() 182 | point_gen_isic2016() -------------------------------------------------------------------------------- /utils/isbi2016_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import torch 5 | import random 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.utils.data 9 | from torchvision import transforms 10 | import torch.utils.data as data 11 | import torch.nn.functional as F 12 | import cv2 13 | 14 | import albumentations as A 15 | from sklearn.model_selection import KFold 16 | from .polar_transformations import centroid, to_polar 17 | 18 | 19 | def norm01(x): 20 | return np.clip(x, 0, 255) / 255 21 | 22 | 23 | seperable_indexes = json.load(open('utils/data_split.json', 'r')) 24 | 25 | 26 | # cross validation 27 | class myDataset(data.Dataset): 28 | def __init__(self, split, size=352, aug=False, polar=False): 29 | super(myDataset, self).__init__() 30 | self.polar = polar 31 | self.split = split 32 | 33 | # load images, label, point 34 | self.image_paths = [] 35 | self.label_paths = [] 36 | self.point_paths = [] 37 | self.dist_paths = [] 38 | 39 | root_dir = '/raid/wjc/data/skin_lesion/isic2016' 40 | 41 | if split == 'train': 42 | indexes = os.listdir(root_dir + '/Train/Image/') 43 | self.image_paths = [ 44 | f'{root_dir}/Train/Image/{_id}' for _id in indexes 45 | ] 46 | self.label_paths = [ 47 | f'{root_dir}/Train/Label/{_id[:-4]}_label.npy' 48 | for _id in indexes 49 | ] 50 | self.point_paths = [ 51 | f'{root_dir}/Train/Point/{_id[:-4]}_label.npy' 52 | for _id in indexes 53 | ] 54 | 55 | elif split == 'valid': 56 | indexes = os.listdir(root_dir + '/Validation/Image/') 57 | self.image_paths = [ 58 | f'{root_dir}/Validation/Image/{_id}' for _id in indexes 59 | ] 60 | self.label_paths = [ 61 | f'{root_dir}/Validation/Label/{_id[:-4]}_label.npy' 62 | for _id in indexes 63 | ] 64 | else: 65 | indexes = os.listdir(root_dir + 
'/Test/Image/') 66 | self.image_paths = [ 67 | f'{root_dir}/Test/Image/{_id}' for _id in indexes 68 | ] 69 | self.label_paths = [ 70 | f'{root_dir}/Test/Label/{_id[:-4]}_label.npy' 71 | for _id in indexes 72 | ] 73 | 74 | print('Loaded {} frames'.format(len(self.image_paths))) 75 | self.num_samples = len(self.image_paths) 76 | self.aug = aug 77 | self.size = size 78 | 79 | p = 0.5 80 | self.transf = A.Compose([ 81 | A.GaussNoise(p=p), 82 | A.HorizontalFlip(p=p), 83 | A.VerticalFlip(p=p), 84 | A.ShiftScaleRotate(p=p), 85 | # A.RandomBrightnessContrast(p=p), 86 | ]) 87 | 88 | def __getitem__(self, index): 89 | # print(self.image_paths[index]) 90 | # image = cv2.imread(self.image_paths[index]) 91 | # image_data = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 92 | # label_data = cv2.imread(self.label_paths[index], cv2.IMREAD_GRAYSCALE) 93 | image_data = np.load(self.image_paths[index]) 94 | label_data = (np.load(self.label_paths[index]) * 255).astype('uint8') 95 | 96 | label_data = np.array( 97 | cv2.resize(label_data, (self.size, self.size), cv2.INTER_NEAREST)) 98 | point_data = cv2.Canny(label_data, 0, 255) / 255.0 > 0.5 99 | label_data = label_data / 255. > 0.5 100 | image_data = np.array( 101 | cv2.resize(image_data, (self.size, self.size), cv2.INTER_LINEAR)) 102 | if self.split == 'train': 103 | filter_point_data = (np.load(self.point_paths[index]) > 104 | 0.7).astype('uint8') 105 | filter_point_data = np.array( 106 | cv2.resize(filter_point_data, (self.size, self.size), 107 | cv2.INTER_NEAREST)) 108 | else: 109 | filter_point_data = point_data.copy() 110 | 111 | # image_data = np.load(self.image_paths[index]) 112 | # label_data = np.load(self.label_paths[index]) > 0.5 113 | # point_data = np.load(self.point_paths[index]) > 0.5 114 | # point_All_data = np.load(self.point_All_paths[index]) > 0.5 # 115 | 116 | # label_data = np.expand_dims(label_data,-1) 117 | # point_data = np.expand_dims(point_data,-1) 118 | if self.aug and self.split == 'train': 119 | mask = np.concatenate([ 120 | label_data[..., np.newaxis].astype('uint8'), 121 | point_data[..., np.newaxis], filter_point_data[..., np.newaxis] 122 | ], 123 | axis=-1) 124 | # print(mask.shape) 125 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask) 126 | image_data, mask_aug = tsf['image'], tsf['mask'] 127 | label_data = mask_aug[:, :, 0] 128 | point_data = mask_aug[:, :, 1] 129 | filter_point_data = mask_aug[:, :, 2] 130 | 131 | image_data = norm01(image_data) 132 | 133 | if self.polar: 134 | center = centroid(image_data) 135 | image_data = to_polar(image_data, center) 136 | label_data = to_polar(label_data, center) > 0.5 137 | 138 | label_data = np.expand_dims(label_data, 0) 139 | point_data = np.expand_dims(point_data, 0) 140 | filter_point_data = np.expand_dims(filter_point_data, 0) # 141 | 142 | image_data = torch.from_numpy(image_data).float() 143 | label_data = torch.from_numpy(label_data).float() 144 | point_data = torch.from_numpy(point_data).float() 145 | filter_point_data = torch.from_numpy(filter_point_data).float() # 146 | 147 | image_data = image_data.permute(2, 0, 1) 148 | return { 149 | 'image_path': self.image_paths[index], 150 | 'label_path': self.label_paths[index], 151 | # 'point_path': self.point_paths[index], 152 | 'image': image_data, 153 | 'label': label_data, 154 | 'point': point_data, 155 | 'filter_point_data': filter_point_data 156 | } 157 | 158 | def __len__(self): 159 | return self.num_samples 160 | 161 | 162 | if __name__ == '__main__': 163 | from tqdm import tqdm 164 | dataset = 
myDataset(split='train', aug=True) 165 | 166 | train_loader = torch.utils.data.DataLoader(dataset, 167 | batch_size=8, 168 | shuffle=False, 169 | num_workers=2, 170 | pin_memory=True, 171 | drop_last=True) 172 | import matplotlib.pyplot as plt 173 | for d in dataset: 174 | print(d['image'].shape, d['image'].max()) 175 | print(d['point'].shape, d['point'].max()) 176 | print(d['filter_point_data'].shape, d['filter_point_data'].max()) 177 | image = d['image'].permute(1, 2, 0).cpu() 178 | point = d['point'][0].cpu() 179 | filter_point_data = d['filter_point_data'][0].cpu() 180 | plt.imshow(filter_point_data) 181 | plt.show() -------------------------------------------------------------------------------- /lib/xboundformer_v0.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from lib.modules import xboundlearner 7 | from lib.vision_transformers import in_scale_transformer 8 | 9 | from lib.pvtv2 import pvt_v2_b2 # 10 | 11 | 12 | def _segm_pvtv2(num_classes, im_num, ex_num, xbound): 13 | backbone = pvt_v2_b2() 14 | 15 | if 1: 16 | path = 'pvt_v2_b2.pth' 17 | save_model = torch.load(path) 18 | model_dict = backbone.state_dict() 19 | state_dict = { 20 | k: v 21 | for k, v in save_model.items() if k in model_dict.keys() 22 | } 23 | model_dict.update(state_dict) 24 | backbone.load_state_dict(model_dict) 25 | classifier = _simple_classifier(num_classes) 26 | model = _SimpleSegmentationModel(backbone, classifier, im_num, ex_num) 27 | return model 28 | 29 | 30 | class _simple_classifier(nn.Module): 31 | def __init__(self, num_classes): 32 | super(_simple_classifier, self).__init__() 33 | self.classifier = nn.Sequential( 34 | nn.Conv2d(576, 256, 3, padding=1, bias=False), #560 35 | nn.BatchNorm2d(256), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(256, num_classes, 1)) 38 | 39 | def forward(self, feature): 40 | low_level_feature = feature[0] 41 | output_feature = feature[3] 42 | output_feature = F.interpolate(output_feature, 43 | size=low_level_feature.shape[2:], 44 | mode='bilinear', 45 | align_corners=False) 46 | return self.classifier( 47 | torch.cat([low_level_feature, output_feature], dim=1)) 48 | 49 | 50 | class _SimpleSegmentationModel(nn.Module): 51 | # general segmentation model 52 | def __init__(self, backbone, classifier, im_num, ex_num): 53 | super(_SimpleSegmentationModel, self).__init__() 54 | self.backbone = backbone 55 | self.classifier = classifier 56 | self.bat_low = _bound_learner(hidden_features=128, 57 | im_num=im_num, 58 | ex_num=ex_num) 59 | 60 | def forward(self, x): 61 | input_shape = x.shape[-2:] 62 | features = self.backbone( 63 | x 64 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) ([8, 512, 8, 8]) 65 | features, point_pre1, point_pre2, point_pre3 = self.bat_low(features) 66 | x = self.classifier(features) 67 | x = F.interpolate(x, 68 | size=input_shape, 69 | mode='bilinear', 70 | align_corners=False) 71 | return x, point_pre1, point_pre2, point_pre3 72 | 73 | 74 | class _bound_learner(nn.Module): 75 | def __init__(self, point_pred=1, hidden_features=128, im_num=2, ex_num=2): 76 | 77 | super().__init__() 78 | 79 | self.point_pred = point_pred 80 | 81 | self.convolution_mapping_1 = nn.Conv2d(in_channels=128, 82 | out_channels=hidden_features, 83 | kernel_size=(1, 1), 84 | stride=(1, 1), 85 | padding=(0, 0), 86 | bias=True) 87 | self.convolution_mapping_2 = nn.Conv2d(in_channels=320, 88 | out_channels=hidden_features, 
89 | kernel_size=(1, 1), 90 | stride=(1, 1), 91 | padding=(0, 0), 92 | bias=True) 93 | self.convolution_mapping_3 = nn.Conv2d(in_channels=512, 94 | out_channels=hidden_features, 95 | kernel_size=(1, 1), 96 | stride=(1, 1), 97 | padding=(0, 0), 98 | bias=True) 99 | normalize_before = True 100 | self.im_ex_boud1 = in_scale_transformer( 101 | point_pred_layers=1, 102 | num_encoder_layers=im_num, 103 | num_decoder_layers=ex_num, 104 | d_model=hidden_features, 105 | nhead=8, 106 | normalize_before=normalize_before) 107 | self.im_ex_boud2 = in_scale_transformer( 108 | point_pred_layers=1, 109 | num_encoder_layers=im_num, 110 | num_decoder_layers=ex_num, 111 | d_model=hidden_features, 112 | nhead=8, 113 | normalize_before=normalize_before) 114 | self.im_ex_boud3 = in_scale_transformer( 115 | point_pred_layers=1, 116 | num_encoder_layers=im_num, 117 | num_decoder_layers=ex_num, 118 | d_model=hidden_features, 119 | nhead=8, 120 | normalize_before=normalize_before) 121 | self.cross_attention_3_1 = xboundlearner(hidden_features, 8) 122 | self.cross_attention_3_2 = xboundlearner(hidden_features, 8) 123 | self.trans_out_conv = nn.Conv2d(hidden_features * 2, 512, 1, 1) # 124 | 125 | def forward(self, x): 126 | features_1 = x[1] 127 | features_2 = x[2] 128 | features_3 = x[3] 129 | features_1 = self.convolution_mapping_1(features_1) 130 | features_2 = self.convolution_mapping_2(features_2) 131 | features_3 = self.convolution_mapping_3(features_3) 132 | 133 | # in-scale attention 134 | latent_tensor_1, features_encoded_1, point_maps_1 = self.im_ex_boud1( 135 | features_1) 136 | 137 | latent_tensor_2, features_encoded_2, point_maps_2 = self.im_ex_boud2( 138 | features_2) 139 | 140 | latent_tensor_3, features_encoded_3, point_maps_3 = self.im_ex_boud3( 141 | features_3) 142 | 143 | # cross-scale attention6 144 | latent_tensor_1 = latent_tensor_1.permute(2, 0, 1) 145 | latent_tensor_2 = latent_tensor_2.permute(2, 0, 1) 146 | latent_tensor_3 = latent_tensor_3.permute(2, 0, 1) 147 | 148 | # ''' point map Upsample ''' 149 | features_encoded_3_1 = self.cross_attention_3_1( 150 | features_encoded_3, latent_tensor_1) 151 | features_encoded_3_2 = self.cross_attention_3_2( 152 | features_encoded_3, latent_tensor_2) 153 | 154 | trans_feature_maps = self.trans_out_conv( 155 | torch.cat([features_encoded_3_1, features_encoded_3_2], dim=1)) 156 | 157 | x[3] = trans_feature_maps 158 | x[2] = torch.cat([x[2], features_encoded_2], dim=1) 159 | x[1] = torch.cat([x[1], features_encoded_1], dim=1) 160 | 161 | if self.point_pred == 1: 162 | return x, point_maps_1, point_maps_2, point_maps_3 # 163 | else: 164 | return trans_feature_maps 165 | 166 | 167 | if __name__ == '__main__': 168 | import os 169 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 170 | model = _segm_pvtv2(1).cuda() 171 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 172 | 173 | prediction1 = model(input_tensor) 174 | -------------------------------------------------------------------------------- /utils/isic2018_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import random 7 | import torch 8 | import cv2 9 | from sklearn.model_selection import KFold 10 | import albumentations as A 11 | from albumentations.pytorch import ToTensorV2 12 | import json 13 | 14 | 15 | class isic2018Dataset(data.Dataset): 16 | """ 17 | dataloader for isic2018 segmentation tasks 18 | """ 19 
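    # Summary of the pipeline implemented below: the image and mask are read with
    # OpenCV, a boundary "point" heatmap is derived from the mask via cv2.Canny,
    # and mask + heatmap are stacked into a single two-channel target so that one
    # albumentations Compose applies identical spatial augmentation to both;
    # __getitem__ then returns (image, gt, point_heatmap) tensors scaled to [0, 1].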
| def __init__(self, image_root, gt_root, image_index, trainsize, 20 | augmentations): 21 | self.trainsize = trainsize 22 | self.augmentations = augmentations 23 | print(self.augmentations) 24 | self.image_root = image_root 25 | self.gt_root = gt_root 26 | self.images = image_index 27 | self.size = len(self.images) 28 | 29 | if self.augmentations: 30 | print('Using RandomRotation, RandomFlip') 31 | 32 | self.transform = A.Compose([ 33 | A.Rotate(90), 34 | # A.GaussNoise(p=0.5), 35 | A.VerticalFlip(p=0.5), 36 | A.HorizontalFlip(p=0.5), 37 | # A.ShiftScaleRotate(p=0.5), 38 | # A.Resize(self.trainsize, self.trainsize), 39 | ToTensorV2() 40 | ]) 41 | else: 42 | print('no augmentation') 43 | self.transform = A.Compose([ 44 | # A.Resize(self.trainsize, self.trainsize), 45 | ToTensorV2() 46 | ]) 47 | 48 | def __getitem__(self, idx): 49 | file_name = self.images[idx] 50 | # gt_name = file_name[:-4] + '_segmentation.png' 51 | img_root = os.path.join(self.image_root, file_name) 52 | gt_root = os.path.join(self.gt_root, file_name) 53 | image = cv2.imread(img_root) 54 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 55 | gt = (cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE)) 56 | point_heatmap = cv2.Canny(gt, 0, 255) / 255.0 57 | gt = gt // 255.0 58 | gt = np.concatenate( 59 | [gt[..., np.newaxis], point_heatmap[..., np.newaxis]], axis=-1) 60 | pair = self.transform(image=image, mask=gt) 61 | gt = pair['mask'][:, :, 0] 62 | point_heatmap = pair['mask'][:, :, 1] 63 | gt = torch.unsqueeze(gt, 0) 64 | point_heatmap = torch.unsqueeze(point_heatmap, 0) 65 | image = pair['image'] / 255.0 66 | 67 | return image, gt, point_heatmap 68 | 69 | def resize(self, img, gt): 70 | assert img.size == gt.size 71 | w, h = img.size 72 | if h < self.trainsize or w < self.trainsize: 73 | h = max(h, self.trainsize) 74 | w = max(w, self.trainsize) 75 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), 76 | Image.NEAREST) 77 | else: 78 | return img, gt 79 | 80 | def __len__(self): 81 | return self.size 82 | 83 | 84 | # def get_loader(image_root,batchsize, trainsize, floder,shuffle=True, num_workers=8, pin_memory=True, augmentation=False): 85 | # train_image_index_all,test_image_index_all=create_k_fold_division(image_root) 86 | # image_index=train_image_index_all[floder] 87 | # test_index=test_image_index_all[floder] 88 | # dataset = isic2018Dataset(image_root, image_index, trainsize, augmentation) 89 | # data_loader = data.DataLoader(dataset=dataset, 90 | # batch_size=batchsize, 91 | # shuffle=shuffle, 92 | # num_workers=num_workers, 93 | # pin_memory=pin_memory) 94 | 95 | # testset=test_dataset(image_root,test_index,trainsize) 96 | # test_loader = data.DataLoader(dataset=testset, 97 | # batch_size=1, 98 | # shuffle=shuffle, 99 | # num_workers=num_workers, 100 | # pin_memory=pin_memory) 101 | # return data_loader,testset 102 | 103 | 104 | def get_loader(image_root, 105 | gt_root, 106 | batchsize, 107 | trainsize, 108 | floder, 109 | shuffle=True, 110 | num_workers=8, 111 | pin_memory=True, 112 | augmentation=False): 113 | js = json.load(open('utils/data_split.json')) 114 | # print(js) 115 | # train / test 116 | all_index = [f for f in os.listdir(image_root) if f.endswith('.jpg')] 117 | 118 | test_index = ['ISIC_' + i + '.jpg' for i in js[str(floder)]] 119 | image_index = list(filter(lambda x: x not in test_index, all_index)) 120 | print(len(all_index), len(image_index), len(test_index)) 121 | 122 | dataset = isic2018Dataset(image_root, gt_root, image_index, trainsize, 123 | augmentation) 124 | data_loader = 
data.DataLoader(dataset=dataset, 125 | batch_size=batchsize, 126 | shuffle=shuffle, 127 | num_workers=num_workers, 128 | pin_memory=pin_memory) 129 | 130 | testset = isic2018Dataset(image_root, gt_root, test_index, trainsize, 131 | False) 132 | test_loader = data.DataLoader(dataset=testset, 133 | batch_size=1, 134 | shuffle=shuffle, 135 | num_workers=num_workers, 136 | pin_memory=pin_memory) 137 | return data_loader, test_loader 138 | 139 | 140 | class test_dataset: 141 | def __init__(self, image_root, gt_root, test_index, testsize): 142 | self.testsize = testsize 143 | self.image_root = image_root 144 | self.gt_root = gt_root 145 | self.images = test_index 146 | self.transform = A.Compose( 147 | [A.Resize(self.testsize, self.testsize), 148 | ToTensorV2()]) 149 | self.size = len(self.images) 150 | self.index = 0 151 | 152 | def __getitem__(self, idx): 153 | image = cv2.imread(os.path.join(self.image_root, self.images[idx])) 154 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 155 | # gt_ind=self.images[idx].split('_') 156 | # gt_name=gt_ind[0]+'_Task1_'+gt_ind[2]+'_GroundTruth/ISIC_'+gt_ind[4][:-4]+'_segmentation.png' 157 | gt_name = self.images[idx][:-4] + '_segmentation.png' 158 | gt_root = os.path.join(self.gt_root, gt_name) 159 | gt = cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE) 160 | pair = self.transform(image=image, mask=gt) 161 | name = self.images[idx].split('/')[-1] 162 | if name.endswith('.jpg'): 163 | name = name.split('.jpg')[0] + '.png' 164 | image = pair['image'].unsqueeze(0) / 255 165 | gt = pair['mask'] / 255 166 | return image, gt, name 167 | 168 | def rgb_loader(self, path): 169 | with open(path, 'rb') as f: 170 | img = Image.open(f) 171 | return img.convert('RGB') 172 | 173 | def binary_loader(self, path): 174 | with open(path, 'rb') as f: 175 | img = Image.open(f) 176 | return img.convert('L') 177 | 178 | def __len__(self): 179 | return self.size -------------------------------------------------------------------------------- /utils/isbi2018_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import torch 5 | import random 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.utils.data 9 | from torchvision import transforms 10 | import torch.utils.data as data 11 | import torch.nn.functional as F 12 | import cv2 13 | import sys 14 | 15 | sys.path.insert(0, os.path.dirname(__file__) + '/../') 16 | # from utils.polar_transformations import centroid, to_polar 17 | import albumentations as A 18 | from sklearn.model_selection import KFold 19 | 20 | 21 | def norm01(x): 22 | return np.clip(x, 0, 255) / 255 23 | 24 | 25 | seperable_indexes = json.load(open('utils/data_split.json', 'r')) 26 | 27 | 28 | # cross validation 29 | class myDataset(data.Dataset): 30 | def __init__(self, fold, split, size=352, aug=False, polar=False): 31 | super(myDataset, self).__init__() 32 | self.split = split 33 | self.polar = polar 34 | 35 | # load images, label, point 36 | self.image_paths = [] 37 | self.label_paths = [] 38 | self.point_paths = [] 39 | self.dist_paths = [] 40 | 41 | indexes = os.listdir( 42 | '/raid/wjc/data/skin_lesion/isic2018_jpg_smooth/Image') 43 | 44 | valid_indexes = [ 45 | 'ISIC_' + i + '.jpg' for i in seperable_indexes[str(fold)] 46 | ] 47 | train_indexes = list(filter(lambda x: x not in valid_indexes, indexes)) 48 | print(len(indexes), len(train_indexes), len(valid_indexes)) 49 | 50 | #valid_indexes = indexes[:260] 51 | #train_indexes = indexes[260:] 52 | print('Fold {}: 
train: {} valid: {}'.format(fold, len(train_indexes), 53 | len(valid_indexes))) 54 | 55 | root_dir = '/raid/wjc/data/skin_lesion/isic2018_jpg_smooth' 56 | if self.polar: 57 | if split == 'train': 58 | self.image_paths = [ 59 | f'{root_dir}/PolarImage/{_id}' for _id in train_indexes 60 | ] 61 | self.label_paths = [ 62 | f'{root_dir}/PolarLabel/{_id}' for _id in train_indexes 63 | ] 64 | elif split == 'valid': 65 | self.image_paths = [ 66 | f'{root_dir}/PolarImage/{_id}' for _id in valid_indexes 67 | ] 68 | self.label_paths = [ 69 | f'{root_dir}/PolarLabel/{_id}' for _id in valid_indexes 70 | ] 71 | else: 72 | if split == 'train': 73 | self.image_paths = [ 74 | f'{root_dir}/Image/{_id}' for _id in train_indexes 75 | ] 76 | self.label_paths = [ 77 | f'{root_dir}/Label/{_id}' for _id in train_indexes 78 | ] 79 | elif split == 'valid': 80 | self.image_paths = [ 81 | f'{root_dir}/Image/{_id}' for _id in valid_indexes 82 | ] 83 | self.label_paths = [ 84 | f'{root_dir}/Label/{_id}' for _id in valid_indexes 85 | ] 86 | 87 | print('Loaded {} frames'.format(len(self.image_paths))) 88 | self.num_samples = len(self.image_paths) 89 | self.aug = aug 90 | self.size = size 91 | 92 | p = 0.5 93 | self.transf = A.Compose([ 94 | A.GaussNoise(p=p), 95 | A.HorizontalFlip(p=p), 96 | A.VerticalFlip(p=p), 97 | A.ShiftScaleRotate(p=p), 98 | # A.RandomBrightnessContrast(p=p), 99 | ]) 100 | 101 | def __getitem__(self, index): 102 | # print(self.image_paths[index]) 103 | image = cv2.imread(self.image_paths[index]) 104 | image_data = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 105 | 106 | label_data = cv2.imread(self.label_paths[index], cv2.IMREAD_GRAYSCALE) 107 | label_data = np.array( 108 | cv2.resize(label_data, (self.size, self.size), cv2.INTER_NEAREST)) 109 | point_data = cv2.Canny(label_data, 0, 255) / 255.0 > 0.5 110 | label_data = label_data / 255. 
> 0.5 111 | image_data = np.array( 112 | cv2.resize(image_data, (self.size, self.size), cv2.INTER_LINEAR)) 113 | 114 | # image_data = np.load(self.image_paths[index]) 115 | # label_data = np.load(self.label_paths[index]) > 0.5 116 | # point_data = np.load(self.point_paths[index]) > 0.5 117 | # point_All_data = np.load(self.point_All_paths[index]) > 0.5 # 118 | 119 | # label_data = np.expand_dims(label_data,-1) 120 | # point_data = np.expand_dims(point_data,-1) 121 | if self.aug and self.split == 'train': 122 | mask = np.concatenate([ 123 | label_data[..., np.newaxis].astype('uint8'), 124 | point_data[..., np.newaxis] 125 | ], 126 | axis=-1) 127 | # print(mask.shape) 128 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask) 129 | image_data, mask_aug = tsf['image'], tsf['mask'] 130 | label_data = mask_aug[:, :, 0] 131 | point_data = mask_aug[:, :, 1] 132 | 133 | image_data = norm01(image_data) 134 | 135 | label_data = np.expand_dims(label_data, 0) 136 | point_data = np.expand_dims(point_data, 0) 137 | # point_All_data = np.expand_dims(point_All_data, 0) # 138 | 139 | image_data = torch.from_numpy(image_data).float() 140 | label_data = torch.from_numpy(label_data).float() 141 | point_data = torch.from_numpy(point_data).float() 142 | # point_All_data = torch.from_numpy(point_All_data).float() # 143 | 144 | image_data = image_data.permute(2, 0, 1) 145 | return { 146 | 'image_path': self.image_paths[index], 147 | 'label_path': self.label_paths[index], 148 | # 'point_path': self.point_paths[index], 149 | 'image': image_data, 150 | 'label': label_data, 151 | 'point': point_data, 152 | 'point_All': label_data 153 | } 154 | 155 | def __len__(self): 156 | return self.num_samples 157 | 158 | 159 | if __name__ == '__main__': 160 | from tqdm import tqdm 161 | import sys 162 | dataset = myDataset(fold='0', split='valid', aug=False, polar=False) 163 | print(dataset.image_paths[:5]) 164 | print(seperable_indexes['0'][:5]) 165 | # for d in dataset: 166 | # print(d) 167 | # train_loader = torch.utils.data.DataLoader(dataset, 168 | # batch_size=8, 169 | # shuffle=False, 170 | # num_workers=2, 171 | # pin_memory=True, 172 | # drop_last=True) 173 | # import matplotlib.pyplot as plt 174 | # for d in dataset: 175 | # print(d['image'].shape, d['image'].max()) 176 | # print(d['point'].shape, d['point'].max()) 177 | # image = d['image'].permute(1, 2, 0).cpu() 178 | # label = d['label'].permute(1, 2, 0).cpu() 179 | # point = d['point'][0].cpu() 180 | # plt.figure() 181 | # plt.imshow(image) 182 | # plt.show() 183 | # plt.figure() 184 | # plt.imshow(label) 185 | # plt.show() 186 | # break -------------------------------------------------------------------------------- /utils/isic2018_polar.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import random 7 | import torch 8 | import cv2 9 | from sklearn.model_selection import KFold 10 | import albumentations as A 11 | from albumentations.pytorch import ToTensorV2 12 | import json 13 | from utils.polar_transformations import centroid, to_polar 14 | 15 | 16 | class isic2018Dataset(data.Dataset): 17 | """ 18 | dataloader for isic2018 segmentation tasks 19 | """ 20 | def __init__(self, image_root, gt_root, image_index, trainsize, 21 | augmentations): 22 | self.trainsize = trainsize 23 | self.augmentations = augmentations 24 | print(self.augmentations) 25 | self.image_root = 
image_root 26 | self.gt_root = gt_root 27 | self.images = image_index 28 | self.size = len(self.images) 29 | 30 | if self.augmentations: 31 | print('Using RandomRotation, RandomFlip') 32 | 33 | self.transform = A.Compose([ToTensorV2()]) 34 | else: 35 | print('no augmentation') 36 | self.transform = A.Compose([ 37 | # A.Resize(self.trainsize, self.trainsize), 38 | ToTensorV2() 39 | ]) 40 | 41 | def __getitem__(self, idx): 42 | file_name = self.images[idx] 43 | # gt_name = file_name[:-4] + '_segmentation.png' 44 | img_root = os.path.join(self.image_root, file_name) 45 | gt_root = os.path.join(self.gt_root, file_name) 46 | image = cv2.imread(img_root) 47 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 48 | gt = (cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE)) 49 | gt = gt // 255.0 50 | 51 | # if self.polar: 52 | # if self.manual_centers is not None: 53 | # center = self.manual_centers[idx] 54 | # else: 55 | # center = polar_transformations.centroid(label) 56 | center = centroid(gt) 57 | 58 | image = to_polar(image, center) 59 | gt = to_polar(gt, center) 60 | 61 | gt = np.concatenate([gt[..., np.newaxis], gt[..., np.newaxis]], 62 | axis=-1) 63 | pair = self.transform(image=image, mask=gt) 64 | gt = pair['mask'][:, :, 0] 65 | point_heatmap = pair['mask'][:, :, 1] 66 | gt = torch.unsqueeze(gt, 0) 67 | point_heatmap = torch.unsqueeze(point_heatmap, 0) 68 | image = pair['image'] / 255.0 69 | 70 | return image, gt, point_heatmap 71 | 72 | def resize(self, img, gt): 73 | assert img.size == gt.size 74 | w, h = img.size 75 | if h < self.trainsize or w < self.trainsize: 76 | h = max(h, self.trainsize) 77 | w = max(w, self.trainsize) 78 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), 79 | Image.NEAREST) 80 | else: 81 | return img, gt 82 | 83 | def __len__(self): 84 | return self.size 85 | 86 | 87 | # def get_loader(image_root,batchsize, trainsize, floder,shuffle=True, num_workers=8, pin_memory=True, augmentation=False): 88 | # train_image_index_all,test_image_index_all=create_k_fold_division(image_root) 89 | # image_index=train_image_index_all[floder] 90 | # test_index=test_image_index_all[floder] 91 | # dataset = isic2018Dataset(image_root, image_index, trainsize, augmentation) 92 | # data_loader = data.DataLoader(dataset=dataset, 93 | # batch_size=batchsize, 94 | # shuffle=shuffle, 95 | # num_workers=num_workers, 96 | # pin_memory=pin_memory) 97 | 98 | # testset=test_dataset(image_root,test_index,trainsize) 99 | # test_loader = data.DataLoader(dataset=testset, 100 | # batch_size=1, 101 | # shuffle=shuffle, 102 | # num_workers=num_workers, 103 | # pin_memory=pin_memory) 104 | # return data_loader,testset 105 | 106 | 107 | def get_loader(image_root, 108 | gt_root, 109 | batchsize, 110 | trainsize, 111 | floder, 112 | shuffle=True, 113 | num_workers=8, 114 | pin_memory=True, 115 | augmentation=False): 116 | js = json.load(open('utils/data_split.json')) 117 | # print(js) 118 | # train / test 119 | all_index = [f for f in os.listdir(image_root) if f.endswith('.jpg')] 120 | 121 | test_index = ['ISIC_' + i + '.jpg' for i in js[str(floder)]] 122 | image_index = list(filter(lambda x: x not in test_index, all_index)) 123 | print(len(all_index), len(image_index), len(test_index)) 124 | 125 | dataset = isic2018Dataset(image_root, gt_root, image_index, trainsize, 126 | augmentation) 127 | data_loader = data.DataLoader(dataset=dataset, 128 | batch_size=batchsize, 129 | shuffle=shuffle, 130 | num_workers=num_workers, 131 | pin_memory=pin_memory) 132 | 133 | testset = isic2018Dataset(image_root, gt_root, 
test_index, trainsize, 134 | False) 135 | test_loader = data.DataLoader(dataset=testset, 136 | batch_size=1, 137 | shuffle=shuffle, 138 | num_workers=num_workers, 139 | pin_memory=pin_memory) 140 | return data_loader, test_loader 141 | 142 | 143 | class test_dataset: 144 | def __init__(self, image_root, gt_root, test_index, testsize): 145 | self.testsize = testsize 146 | self.image_root = image_root 147 | self.gt_root = gt_root 148 | self.images = test_index 149 | self.transform = A.Compose( 150 | [A.Resize(self.testsize, self.testsize), 151 | ToTensorV2()]) 152 | self.size = len(self.images) 153 | self.index = 0 154 | 155 | def __getitem__(self, idx): 156 | image = cv2.imread(os.path.join(self.image_root, self.images[idx])) 157 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 158 | # gt_ind=self.images[idx].split('_') 159 | # gt_name=gt_ind[0]+'_Task1_'+gt_ind[2]+'_GroundTruth/ISIC_'+gt_ind[4][:-4]+'_segmentation.png' 160 | gt_name = self.images[idx][:-4] + '_segmentation.png' 161 | gt_root = os.path.join(self.gt_root, gt_name) 162 | gt = cv2.imread(gt_root, cv2.IMREAD_GRAYSCALE) 163 | pair = self.transform(image=image, mask=gt) 164 | name = self.images[idx].split('/')[-1] 165 | if name.endswith('.jpg'): 166 | name = name.split('.jpg')[0] + '.png' 167 | image = pair['image'].unsqueeze(0) / 255 168 | gt = pair['mask'] / 255 169 | return image, gt, name 170 | 171 | def rgb_loader(self, path): 172 | with open(path, 'rb') as f: 173 | img = Image.open(f) 174 | return img.convert('RGB') 175 | 176 | def binary_loader(self, path): 177 | with open(path, 'rb') as f: 178 | img = Image.open(f) 179 | return img.convert('L') 180 | 181 | def __len__(self): 182 | return self.size 183 | 184 | 185 | if __name__ == '__main__': 186 | isic2018Dataset() -------------------------------------------------------------------------------- /lib/baseline.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from lib.modules import xboundlearner, xboundlearnerv2 7 | from lib.vision_transformers import in_scale_transformer 8 | 9 | from lib.pvtv2 import pvt_v2_b2 # 10 | 11 | 12 | def _segm_pvtv2(num_classes, im_num, ex_num, xbound, trainsize): 13 | backbone = pvt_v2_b2(img_size=trainsize) 14 | 15 | if 1: 16 | path = 'pvt_v2_b2.pth' 17 | save_model = torch.load(path) 18 | model_dict = backbone.state_dict() 19 | state_dict = { 20 | k: v 21 | for k, v in save_model.items() if k in model_dict.keys() 22 | } 23 | model_dict.update(state_dict) 24 | backbone.load_state_dict(model_dict) 25 | classifier = _simple_classifier(num_classes) 26 | model = _SimpleSegmentationModel(backbone, classifier, im_num, ex_num) 27 | return model 28 | 29 | 30 | class _simple_classifier(nn.Module): 31 | def __init__(self, num_classes): 32 | super(_simple_classifier, self).__init__() 33 | self.classifier = nn.Sequential( 34 | nn.Conv2d(192, 64, 1, padding=1, bias=False), #560 35 | nn.BatchNorm2d(64), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(64, num_classes, 1)) 38 | self.classifier1 = nn.Sequential(nn.Conv2d(128, num_classes, 1)) 39 | self.classifier2 = nn.Sequential(nn.Conv2d(128, num_classes, 1)) 40 | self.classifier3 = nn.Sequential(nn.Conv2d(128, num_classes, 1)) 41 | 42 | def forward(self, feature): 43 | low_level_feature = feature[0] 44 | output_feature = feature[1] 45 | output_feature = F.interpolate(output_feature, 46 | size=low_level_feature.shape[2:], 47 | mode='bilinear', 48 | 
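        # deep supervision: in training mode this head returns a list containing
        # the fused low-level + upsampled deep prediction plus one auxiliary
        # 1x1-conv prediction per refined scale (classifier1/2/3); at eval time
        # only the fused prediction is returned.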
align_corners=False) 49 | if self.training: 50 | return [ 51 | self.classifier( 52 | torch.cat([low_level_feature, output_feature], dim=1)), 53 | self.classifier1(feature[1]), 54 | self.classifier2(feature[2]), 55 | self.classifier3(feature[3]) 56 | ] 57 | else: 58 | return self.classifier( 59 | torch.cat([low_level_feature, output_feature], dim=1)) 60 | 61 | 62 | class _SimpleSegmentationModel(nn.Module): 63 | # general segmentation model 64 | def __init__(self, backbone, classifier, im_num, ex_num): 65 | super(_SimpleSegmentationModel, self).__init__() 66 | self.backbone = backbone 67 | self.classifier = classifier 68 | self.bat_low = _bound_learner(hidden_features=128, 69 | im_num=im_num, 70 | ex_num=ex_num) 71 | 72 | def forward(self, x): 73 | input_shape = x.shape[-2:] 74 | features = self.backbone( 75 | x 76 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) ([8, 512, 8, 8]) 77 | features, point_pre1, point_pre2, point_pre3 = self.bat_low(features) 78 | outputs = self.classifier(features) 79 | if self.training: 80 | outputs = [ 81 | F.interpolate(o, 82 | size=input_shape, 83 | mode='bilinear', 84 | align_corners=False) for o in outputs 85 | ] 86 | else: 87 | outputs = F.interpolate(outputs, 88 | size=input_shape, 89 | mode='bilinear', 90 | align_corners=False) 91 | return outputs, point_pre1, point_pre2, point_pre3 92 | 93 | 94 | class _bound_learner(nn.Module): 95 | def __init__(self, point_pred=1, hidden_features=128, im_num=2, ex_num=2): 96 | 97 | super().__init__() 98 | 99 | self.point_pred = point_pred 100 | 101 | self.convolution_mapping_1 = nn.Conv2d(in_channels=128, 102 | out_channels=hidden_features, 103 | kernel_size=(1, 1), 104 | stride=(1, 1), 105 | padding=(0, 0), 106 | bias=True) 107 | self.convolution_mapping_2 = nn.Conv2d(in_channels=320, 108 | out_channels=hidden_features, 109 | kernel_size=(1, 1), 110 | stride=(1, 1), 111 | padding=(0, 0), 112 | bias=True) 113 | self.convolution_mapping_3 = nn.Conv2d(in_channels=512, 114 | out_channels=hidden_features, 115 | kernel_size=(1, 1), 116 | stride=(1, 1), 117 | padding=(0, 0), 118 | bias=True) 119 | normalize_before = True 120 | self.im_ex_boud1 = in_scale_transformer( 121 | point_pred_layers=1, 122 | num_encoder_layers=im_num, 123 | num_decoder_layers=ex_num, 124 | d_model=hidden_features, 125 | nhead=8, 126 | normalize_before=normalize_before) 127 | self.im_ex_boud2 = in_scale_transformer( 128 | point_pred_layers=1, 129 | num_encoder_layers=im_num, 130 | num_decoder_layers=ex_num, 131 | d_model=hidden_features, 132 | nhead=8, 133 | normalize_before=normalize_before) 134 | self.im_ex_boud3 = in_scale_transformer( 135 | point_pred_layers=1, 136 | num_encoder_layers=im_num, 137 | num_decoder_layers=ex_num, 138 | d_model=hidden_features, 139 | nhead=8, 140 | normalize_before=normalize_before) 141 | # self.cross_attention_3_1 = xboundlearner(hidden_features, 8) 142 | # self.cross_attention_3_2 = xboundlearner(hidden_features, 8) 143 | 144 | self.cross_attention_3_1 = xboundlearnerv2(hidden_features, 8) 145 | self.cross_attention_3_2 = xboundlearnerv2(hidden_features, 8) 146 | 147 | self.trans_out_conv = nn.Conv2d(hidden_features * 2, 512, 1, 1) # 148 | 149 | def forward(self, x): 150 | # for tmp in x: 151 | # print(tmp.size()) 152 | features_1 = x[1] 153 | features_2 = x[2] 154 | features_3 = x[3] 155 | features_1 = self.convolution_mapping_1(features_1) 156 | features_2 = self.convolution_mapping_2(features_2) 157 | features_3 = self.convolution_mapping_3(features_3) 158 | 159 | # in-scale attention 160 | 
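        # stage 1 (in-scale): each projected feature map passes through its own
        # in_scale_transformer, yielding boundary latent tokens, encoded features
        # and a point (boundary) map for that scale.
        # stage 2 (cross-scale, below): the latent tokens are permuted to
        # (seq, batch, dim) and the xboundlearnerv2 blocks propagate boundary
        # cues from the coarsest scale (3) down to scales 2 and 1, producing the
        # refined pyramid handed to the classifier.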
latent_tensor_1, features_encoded_1, point_maps_1 = self.im_ex_boud1( 161 | features_1) 162 | 163 | latent_tensor_2, features_encoded_2, point_maps_2 = self.im_ex_boud2( 164 | features_2) 165 | 166 | latent_tensor_3, features_encoded_3, point_maps_3 = self.im_ex_boud3( 167 | features_3) 168 | 169 | # cross-scale attention6 170 | latent_tensor_1 = latent_tensor_1.permute(2, 0, 1) 171 | latent_tensor_2 = latent_tensor_2.permute(2, 0, 1) 172 | latent_tensor_3 = latent_tensor_3.permute(2, 0, 1) 173 | 174 | # ''' point map Upsample ''' 175 | features_encoded_2_2 = self.cross_attention_3_2( 176 | features_encoded_2, features_encoded_3, latent_tensor_2, 177 | latent_tensor_3) 178 | features_encoded_1_2 = self.cross_attention_3_1( 179 | features_encoded_1, features_encoded_2_2, latent_tensor_1, 180 | latent_tensor_2) 181 | 182 | # trans_feature_maps = self.trans_out_conv( 183 | # torch.cat([features_encoded_3_1, features_encoded_3_2], dim=1)) 184 | 185 | # x[3] = trans_feature_maps 186 | # x[2] = torch.cat([x[2], features_encoded_2], dim=1) 187 | # x[1] = torch.cat([x[1], features_encoded_1], dim=1) 188 | 189 | features_stage2 = [ 190 | x[0], features_encoded_1_2, features_encoded_2_2, 191 | features_encoded_3 192 | ] 193 | 194 | return features_stage2, point_maps_1, point_maps_2, point_maps_3 # 195 | 196 | 197 | if __name__ == '__main__': 198 | import os 199 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 200 | model = _segm_pvtv2(1).cuda() 201 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 202 | 203 | prediction1 = model(input_tensor) 204 | -------------------------------------------------------------------------------- /lib/pvt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.pvtv2 import pvt_v2_b2 # 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class BasicConv2d(nn.Module): 12 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 13 | super(BasicConv2d, self).__init__() 14 | 15 | self.conv = nn.Conv2d(in_planes, out_planes, 16 | kernel_size=kernel_size, stride=stride, 17 | padding=padding, dilation=dilation, bias=False) 18 | self.bn = nn.BatchNorm2d(out_planes) 19 | self.relu = nn.ReLU(inplace=True) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | x = self.bn(x) 24 | return x 25 | 26 | 27 | class CFM(nn.Module): 28 | def __init__(self, channel): 29 | super(CFM, self).__init__() 30 | self.relu = nn.ReLU(True) 31 | 32 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 33 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1) 34 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1) 35 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1) 36 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1) 37 | self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1) 38 | 39 | self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1) 40 | self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1) 41 | self.conv4 = BasicConv2d(3 * channel, channel, 3, padding=1) 42 | 43 | def forward(self, x1, x2, x3): 44 | x1_1 = x1 45 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2 46 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \ 47 | * self.conv_upsample3(self.upsample(x2)) * x3 48 | 49 | x2_2 = torch.cat((x2_1, 
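        # cascaded fusion step: the gated mid-scale features (x2_1) are
        # concatenated with upsampled coarsest-scale features, mixed by
        # conv_concat2/conv_concat3, and reduced back to `channel` dims by conv4.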
self.conv_upsample4(self.upsample(x1_1))), 1) 50 | x2_2 = self.conv_concat2(x2_2) 51 | 52 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1) 53 | x3_2 = self.conv_concat3(x3_2) 54 | 55 | x1 = self.conv4(x3_2) 56 | 57 | return x1 58 | 59 | 60 | 61 | 62 | class GCN(nn.Module): 63 | def __init__(self, num_state, num_node, bias=False): 64 | super(GCN, self).__init__() 65 | self.conv1 = nn.Conv1d(num_node, num_node, kernel_size=1) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, bias=bias) 68 | 69 | def forward(self, x): 70 | h = self.conv1(x.permute(0, 2, 1)).permute(0, 2, 1) 71 | h = h - x 72 | h = self.relu(self.conv2(h)) 73 | return h 74 | 75 | 76 | class SAM(nn.Module): 77 | def __init__(self, num_in=32, plane_mid=16, mids=4, normalize=False): 78 | super(SAM, self).__init__() 79 | 80 | self.normalize = normalize 81 | self.num_s = int(plane_mid) 82 | self.num_n = (mids) * (mids) 83 | self.priors = nn.AdaptiveAvgPool2d(output_size=(mids + 2, mids + 2)) 84 | 85 | self.conv_state = nn.Conv2d(num_in, self.num_s, kernel_size=1) 86 | self.conv_proj = nn.Conv2d(num_in, self.num_s, kernel_size=1) 87 | self.gcn = GCN(num_state=self.num_s, num_node=self.num_n) 88 | self.conv_extend = nn.Conv2d(self.num_s, num_in, kernel_size=1, bias=False) 89 | 90 | def forward(self, x, edge): 91 | edge = F.interpolate(edge, (x.size()[-2], x.size()[-1])) 92 | 93 | n, c, h, w = x.size() 94 | edge = torch.nn.functional.softmax(edge, dim=1)[:, 1, :, :].unsqueeze(1) 95 | 96 | x_state_reshaped = self.conv_state(x).view(n, self.num_s, -1) 97 | x_proj = self.conv_proj(x) 98 | x_mask = x_proj * edge 99 | 100 | x_anchor1 = self.priors(x_mask) 101 | x_anchor2 = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1) 102 | x_anchor = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1) 103 | 104 | x_proj_reshaped = torch.matmul(x_anchor.permute(0, 2, 1), x_proj.reshape(n, self.num_s, -1)) 105 | x_proj_reshaped = torch.nn.functional.softmax(x_proj_reshaped, dim=1) 106 | 107 | x_rproj_reshaped = x_proj_reshaped 108 | 109 | x_n_state = torch.matmul(x_state_reshaped, x_proj_reshaped.permute(0, 2, 1)) 110 | if self.normalize: 111 | x_n_state = x_n_state * (1. 
/ x_state_reshaped.size(2)) 112 | x_n_rel = self.gcn(x_n_state) 113 | 114 | x_state_reshaped = torch.matmul(x_n_rel, x_rproj_reshaped) 115 | x_state = x_state_reshaped.view(n, self.num_s, *x.size()[2:]) 116 | out = x + (self.conv_extend(x_state)) 117 | 118 | return out 119 | 120 | 121 | class ChannelAttention(nn.Module): 122 | def __init__(self, in_planes, ratio=16): 123 | super(ChannelAttention, self).__init__() 124 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 125 | self.max_pool = nn.AdaptiveMaxPool2d(1) 126 | 127 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 128 | self.relu1 = nn.ReLU() 129 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 130 | 131 | self.sigmoid = nn.Sigmoid() 132 | 133 | def forward(self, x): 134 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 135 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 136 | out = avg_out + max_out 137 | return self.sigmoid(out) 138 | 139 | 140 | class SpatialAttention(nn.Module): 141 | def __init__(self, kernel_size=7): 142 | super(SpatialAttention, self).__init__() 143 | 144 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 145 | padding = 3 if kernel_size == 7 else 1 146 | 147 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 148 | self.sigmoid = nn.Sigmoid() 149 | 150 | def forward(self, x): 151 | avg_out = torch.mean(x, dim=1, keepdim=True) 152 | max_out, _ = torch.max(x, dim=1, keepdim=True) 153 | x = torch.cat([avg_out, max_out], dim=1) 154 | x = self.conv1(x) 155 | return self.sigmoid(x) 156 | 157 | 158 | class PolypPVT(nn.Module): 159 | def __init__(self, channel=32): 160 | super(PolypPVT, self).__init__() 161 | 162 | self.backbone = pvt_v2_b2() # [64, 128, 320, 512] 163 | path = './pretrained_pth/pvt_v2_b2.pth' 164 | save_model = torch.load(path) 165 | model_dict = self.backbone.state_dict() 166 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 167 | model_dict.update(state_dict) 168 | self.backbone.load_state_dict(model_dict) 169 | 170 | self.Translayer2_0 = BasicConv2d(64, channel, 1) 171 | self.Translayer2_1 = BasicConv2d(128, channel, 1) 172 | self.Translayer3_1 = BasicConv2d(320, channel, 1) 173 | self.Translayer4_1 = BasicConv2d(512, channel, 1) 174 | 175 | self.CFM = CFM(channel) 176 | self.ca = ChannelAttention(64) 177 | self.sa = SpatialAttention() 178 | self.SAM = SAM() 179 | 180 | self.down05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 181 | self.out_SAM = nn.Conv2d(channel, 1, 1) 182 | self.out_CFM = nn.Conv2d(channel, 1, 1) 183 | 184 | 185 | def forward(self, x): 186 | 187 | # backbone 188 | pvt = self.backbone(x) 189 | x1 = pvt[0] 190 | x2 = pvt[1] 191 | x3 = pvt[2] 192 | x4 = pvt[3] 193 | 194 | # CIM 195 | x1 = self.ca(x1) * x1 # channel attention 196 | cim_feature = self.sa(x1) * x1 # spatial attention 197 | 198 | 199 | # CFM 200 | x2_t = self.Translayer2_1(x2) 201 | x3_t = self.Translayer3_1(x3) 202 | x4_t = self.Translayer4_1(x4) 203 | cfm_feature = self.CFM(x4_t, x3_t, x2_t) 204 | 205 | # SAM 206 | T2 = self.Translayer2_0(cim_feature) 207 | T2 = self.down05(T2) 208 | sam_feature = self.SAM(cfm_feature, T2) 209 | 210 | prediction1 = self.out_CFM(cfm_feature) 211 | prediction2 = self.out_SAM(sam_feature) 212 | 213 | prediction1_8 = F.interpolate(prediction1, scale_factor=8, mode='bilinear') 214 | prediction2_8 = F.interpolate(prediction2, scale_factor=8, mode='bilinear') 215 | return prediction1_8, prediction2_8 216 | 217 | 218 | if __name__ == '__main__': 219 | import 
os 220 | os.environ['CUDA_VISIBLE_DEVICES']='4' 221 | model = PolypPVT().cuda() 222 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 223 | 224 | prediction1, prediction2 = model(input_tensor) 225 | print(prediction1.size(), prediction2.size()) -------------------------------------------------------------------------------- /lib/xboundformer.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from lib.modules import xboundlearner, xboundlearnerv2, _simple_learner 7 | from lib.vision_transformers import in_scale_transformer 8 | 9 | from lib.pvtv2 import pvt_v2_b2 # 10 | 11 | 12 | def _segm_pvtv2(num_classes, im_num, ex_num, xbound, trainsize): 13 | backbone = pvt_v2_b2(img_size=trainsize) 14 | 15 | if 1: 16 | path = 'pvt_v2_b2.pth' 17 | save_model = torch.load(path) 18 | model_dict = backbone.state_dict() 19 | state_dict = { 20 | k: v 21 | for k, v in save_model.items() if k in model_dict.keys() 22 | } 23 | model_dict.update(state_dict) 24 | backbone.load_state_dict(model_dict) 25 | classifier = _simple_classifier(num_classes) 26 | model = _SimpleSegmentationModel(backbone, classifier, im_num, ex_num, 27 | xbound) 28 | return model 29 | 30 | 31 | class _simple_classifier(nn.Module): 32 | def __init__(self, num_classes): 33 | super(_simple_classifier, self).__init__() 34 | self.classifier = nn.Sequential( 35 | nn.Conv2d(192, 64, 1, padding=1, bias=False), #560 36 | nn.BatchNorm2d(64), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(64, num_classes, 1)) 39 | self.classifier1 = nn.Sequential(nn.Conv2d(128, num_classes, 1)) 40 | self.classifier2 = nn.Sequential(nn.Conv2d(128, num_classes, 1)) 41 | self.classifier3 = nn.Sequential(nn.Conv2d(128, num_classes, 1)) 42 | 43 | def forward(self, feature): 44 | low_level_feature = feature[0] 45 | output_feature = feature[1] 46 | output_feature = F.interpolate(output_feature, 47 | size=low_level_feature.shape[2:], 48 | mode='bilinear', 49 | align_corners=False) 50 | if self.training: 51 | return [ 52 | self.classifier( 53 | torch.cat([low_level_feature, output_feature], dim=1)), 54 | self.classifier1(feature[1]), 55 | self.classifier2(feature[2]), 56 | self.classifier3(feature[3]) 57 | ] 58 | else: 59 | return self.classifier( 60 | torch.cat([low_level_feature, output_feature], dim=1)) 61 | 62 | 63 | class _SimpleSegmentationModel(nn.Module): 64 | # general segmentation model 65 | def __init__(self, backbone, classifier, im_num, ex_num, xbound): 66 | super(_SimpleSegmentationModel, self).__init__() 67 | self.backbone = backbone 68 | self.classifier = classifier 69 | self.bat_low = _bound_learner(hidden_features=128, 70 | im_num=im_num, 71 | ex_num=ex_num, 72 | xbound=xbound) 73 | 74 | def forward(self, x): 75 | input_shape = x.shape[-2:] 76 | features = self.backbone( 77 | x 78 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) ([8, 512, 8, 8]) 79 | features, point_pre1, point_pre2, point_pre3 = self.bat_low(features) 80 | outputs = self.classifier(features) 81 | if self.training: 82 | outputs = [ 83 | F.interpolate(o, 84 | size=input_shape, 85 | mode='bilinear', 86 | align_corners=False) for o in outputs 87 | ] 88 | else: 89 | outputs = F.interpolate(outputs, 90 | size=input_shape, 91 | mode='bilinear', 92 | align_corners=False) 93 | return outputs, point_pre1, point_pre2, point_pre3 94 | 95 | 96 | class _bound_learner(nn.Module): 97 | def __init__(self, 98 | point_pred=1, 99 | 
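                 # im_num / ex_num give the number of encoder / decoder layers in
                 # each in-scale transformer; when im_num + ex_num == 0 the
                 # in-scale stage is skipped and the raw projected features are
                 # used instead. `xbound` selects boundary-token cross-attention
                 # (xboundlearnerv2) for the cross-scale stage, falling back to a
                 # plain _simple_learner when it is False.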
hidden_features=128, 100 | im_num=2, 101 | ex_num=2, 102 | xbound=True): 103 | 104 | super().__init__() 105 | self.im_num = im_num 106 | self.ex_num = ex_num 107 | 108 | self.point_pred = point_pred 109 | 110 | self.convolution_mapping_1 = nn.Conv2d(in_channels=128, 111 | out_channels=hidden_features, 112 | kernel_size=(1, 1), 113 | stride=(1, 1), 114 | padding=(0, 0), 115 | bias=True) 116 | self.convolution_mapping_2 = nn.Conv2d(in_channels=320, 117 | out_channels=hidden_features, 118 | kernel_size=(1, 1), 119 | stride=(1, 1), 120 | padding=(0, 0), 121 | bias=True) 122 | self.convolution_mapping_3 = nn.Conv2d(in_channels=512, 123 | out_channels=hidden_features, 124 | kernel_size=(1, 1), 125 | stride=(1, 1), 126 | padding=(0, 0), 127 | bias=True) 128 | normalize_before = True 129 | 130 | if im_num + ex_num > 0: 131 | self.im_ex_boud1 = in_scale_transformer( 132 | point_pred_layers=1, 133 | num_encoder_layers=im_num, 134 | num_decoder_layers=ex_num, 135 | d_model=hidden_features, 136 | nhead=8, 137 | normalize_before=normalize_before) 138 | self.im_ex_boud2 = in_scale_transformer( 139 | point_pred_layers=1, 140 | num_encoder_layers=im_num, 141 | num_decoder_layers=ex_num, 142 | d_model=hidden_features, 143 | nhead=8, 144 | normalize_before=normalize_before) 145 | self.im_ex_boud3 = in_scale_transformer( 146 | point_pred_layers=1, 147 | num_encoder_layers=im_num, 148 | num_decoder_layers=ex_num, 149 | d_model=hidden_features, 150 | nhead=8, 151 | normalize_before=normalize_before) 152 | # self.cross_attention_3_1 = xboundlearner(hidden_features, 8) 153 | # self.cross_attention_3_2 = xboundlearner(hidden_features, 8) 154 | 155 | self.xbound = xbound 156 | if xbound: 157 | self.cross_attention_3_1 = xboundlearnerv2(hidden_features, 8) 158 | self.cross_attention_3_2 = xboundlearnerv2(hidden_features, 8) 159 | else: 160 | self.cross_attention_3_1 = _simple_learner(hidden_features) 161 | self.cross_attention_3_2 = _simple_learner(hidden_features) 162 | 163 | self.trans_out_conv = nn.Conv2d(hidden_features * 2, 512, 1, 1) # 164 | 165 | def forward(self, x): 166 | # for tmp in x: 167 | # print(tmp.size()) 168 | features_1 = x[1] 169 | features_2 = x[2] 170 | features_3 = x[3] 171 | features_1 = self.convolution_mapping_1(features_1) 172 | features_2 = self.convolution_mapping_2(features_2) 173 | features_3 = self.convolution_mapping_3(features_3) 174 | 175 | # in-scale attention 176 | if self.im_num + self.ex_num > 0: 177 | latent_tensor_1, features_encoded_1, point_maps_1 = self.im_ex_boud1( 178 | features_1) 179 | 180 | latent_tensor_2, features_encoded_2, point_maps_2 = self.im_ex_boud2( 181 | features_2) 182 | 183 | latent_tensor_3, features_encoded_3, point_maps_3 = self.im_ex_boud3( 184 | features_3) 185 | 186 | # cross-scale attention6 187 | if self.ex_num > 0: 188 | latent_tensor_1 = latent_tensor_1.permute(2, 0, 1) 189 | latent_tensor_2 = latent_tensor_2.permute(2, 0, 1) 190 | latent_tensor_3 = latent_tensor_3.permute(2, 0, 1) 191 | 192 | else: 193 | features_encoded_1 = features_1 194 | features_encoded_2 = features_2 195 | features_encoded_3 = features_3 196 | 197 | # ''' point map Upsample ''' 198 | if self.xbound: 199 | features_encoded_2_2 = self.cross_attention_3_2( 200 | features_encoded_2, features_encoded_3, latent_tensor_2, 201 | latent_tensor_3) 202 | features_encoded_1_2 = self.cross_attention_3_1( 203 | features_encoded_1, features_encoded_2_2, latent_tensor_1, 204 | latent_tensor_2) 205 | else: 206 | features_encoded_2_2 = self.cross_attention_3_2( 207 | features_encoded_2, 
features_encoded_3) 208 | features_encoded_1_2 = self.cross_attention_3_1( 209 | features_encoded_1, features_encoded_2_2) 210 | 211 | # trans_feature_maps = self.trans_out_conv( 212 | # torch.cat([features_encoded_3_1, features_encoded_3_2], dim=1)) 213 | 214 | # x[3] = trans_feature_maps 215 | # x[2] = torch.cat([x[2], features_encoded_2], dim=1) 216 | # x[1] = torch.cat([x[1], features_encoded_1], dim=1) 217 | 218 | features_stage2 = [ 219 | x[0], features_encoded_1_2, features_encoded_2_2, 220 | features_encoded_3 221 | ] 222 | 223 | if self.im_num + self.ex_num > 0: 224 | return features_stage2, point_maps_1, point_maps_2, point_maps_3 # 225 | else: 226 | return features_stage2, None, None, None # 227 | 228 | 229 | if __name__ == '__main__': 230 | import os 231 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 232 | model = _segm_pvtv2(1).cuda() 233 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 234 | 235 | prediction1 = model(input_tensor) 236 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os, argparse, math 2 | import numpy as np 3 | from glob import glob 4 | from tqdm import tqdm 5 | import sys 6 | 7 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.data 11 | import torch.optim as optim 12 | 13 | from medpy.metric.binary import hd, dc, assd, jc 14 | 15 | from skimage import segmentation as skimage_seg 16 | from scipy.ndimage import distance_transform_edt as distance 17 | from torch.utils.tensorboard import SummaryWriter 18 | 19 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, CosineAnnealingLR 20 | import time 21 | 22 | 23 | def get_cfg(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--arch', type=str, default='xboundformer') 26 | parser.add_argument('--gpu', type=str, default='4') 27 | parser.add_argument('--net_layer', type=int, default=50) 28 | parser.add_argument('--dataset', type=str, default='isic2016') 29 | parser.add_argument('--exp_name', type=str, default='test') 30 | parser.add_argument('--fold', type=str) 31 | parser.add_argument('--lr_seg', type=float, default=1e-4) #0.0003 32 | parser.add_argument('--n_epochs', type=int, default=150) #100 33 | parser.add_argument('--bt_size', type=int, default=12) #36 34 | parser.add_argument('--seg_loss', type=int, default=1, choices=[0, 1]) 35 | parser.add_argument('--aug', type=int, default=1) 36 | parser.add_argument('--patience', type=int, default=500) #50 37 | 38 | # transformer 39 | parser.add_argument('--filter', type=int, default=0) 40 | parser.add_argument('--im_num', type=int, default=1) 41 | parser.add_argument('--ex_num', type=int, default=1) 42 | parser.add_argument('--xbound', type=int, default=1) 43 | parser.add_argument('--point_w', type=float, default=1) 44 | 45 | #log_dir name 46 | parser.add_argument('--folder_name', type=str, default='Default_folder') 47 | 48 | parse_config = parser.parse_args() 49 | print(parse_config) 50 | return parse_config 51 | 52 | 53 | def ce_loss(pred, gt): 54 | pred = torch.clamp(pred, 1e-6, 1 - 1e-6) 55 | return (-gt * torch.log(pred) - (1 - gt) * torch.log(1 - pred)).mean() 56 | 57 | 58 | def structure_loss(pred, mask): 59 | """ TransFuse train loss """ 60 | """ Without sigmoid """ 61 | weit = 1 + 5 * torch.abs( 62 | F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 63 | wbce = 
F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 64 | wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 65 | 66 | pred = torch.sigmoid(pred) 67 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 68 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 69 | wiou = 1 - (inter + 1) / (union - inter + 1) 70 | return (wbce + wiou).mean() 71 | 72 | 73 | #-------------------------- train func --------------------------# 74 | def train(epoch): 75 | model.train() 76 | iteration = 0 77 | for batch_idx, batch_data in enumerate(train_loader): 78 | # print(epoch, batch_idx) 79 | data = batch_data['image'].cuda().float() 80 | label = batch_data['label'].cuda().float() 81 | if parse_config.filter: 82 | point = (batch_data['filter_point_data'] > 0).cuda().float() 83 | else: 84 | point = (batch_data['point'] > 0).cuda().float() 85 | #point_All = (batch_data['point_All'] > 0).cuda().float() 86 | 87 | if parse_config.arch == 'transfuse': 88 | lateral_map_4, lateral_map_3, lateral_map_2 = model(data) 89 | 90 | loss4 = structure_loss(lateral_map_4, label) 91 | loss3 = structure_loss(lateral_map_3, label) 92 | loss2 = structure_loss(lateral_map_2, label) 93 | 94 | loss = 0.5 * loss2 + 0.3 * loss3 + 0.2 * loss4 95 | 96 | optimizer.zero_grad() 97 | loss.backward() 98 | optimizer.step() 99 | if (batch_idx + 1) % 10 == 0: 100 | print( 101 | 'Train Epoch: {} [{}/{} ({:.0f}%)]\t[lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}]' 102 | .format(epoch, batch_idx * len(data), 103 | len(train_loader.dataset), 104 | 100. * batch_idx / len(train_loader), loss2.item(), 105 | loss3.item(), loss4.item())) 106 | 107 | else: 108 | P2, point_maps_pre, point_maps_pre1, point_maps_pre2 = model(data) 109 | if parse_config.im_num + parse_config.ex_num > 0: 110 | point_loss = 0.0 111 | point3 = F.max_pool2d(point, (32, 32), (32, 32)) 112 | point2 = F.max_pool2d(point, (16, 16), (16, 16)) 113 | point1 = F.max_pool2d(point, (8, 8), (8, 8)) 114 | 115 | for point_pre, point_pre1, point_pre2 in zip( 116 | point_maps_pre, point_maps_pre1, point_maps_pre2): 117 | point_loss = point_loss + criteon( 118 | point_pre, point1) + criteon( 119 | point_pre1, point2) + criteon(point_pre2, point3) 120 | point_loss = point_loss / (3 * len(point_maps_pre1)) 121 | seg_loss = 0.0 122 | for p in P2: 123 | seg_loss = seg_loss + structure_loss(p, label) 124 | seg_loss = seg_loss / len(P2) 125 | loss = seg_loss + parse_config.point_w * point_loss 126 | 127 | if batch_idx % 50 == 0: 128 | show_image = [label[0], F.sigmoid(P2[0][0])] 129 | for point_map in [ 130 | point_maps_pre, point_maps_pre1, point_maps_pre2 131 | ]: 132 | tmp = F.interpolate(point_map[-1], size=(352, 352))[0] 133 | show_image.append(tmp) 134 | show_image = torch.cat(show_image, dim=2) 135 | show_image = show_image.repeat(3, 1, 1) 136 | show_image = torch.cat([data[0], show_image], dim=2) 137 | 138 | writer.add_image('pred/all', 139 | show_image, 140 | epoch * len(train_loader) + batch_idx, 141 | dataformats='CHW') 142 | 143 | else: 144 | point_loss = 0.0 145 | seg_loss = 0.0 146 | for p in P2: 147 | seg_loss = seg_loss + structure_loss(p, label) 148 | seg_loss = seg_loss / len(P2) 149 | loss = seg_loss + 2 * point_loss 150 | 151 | optimizer.zero_grad() 152 | loss.backward() 153 | optimizer.step() 154 | if (batch_idx + 1) % 10 == 0: 155 | print( 156 | 'Train Epoch: {} [{}/{} ({:.0f}%)]\t[lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}]' 157 | .format(epoch, batch_idx * len(data), 158 | len(train_loader.dataset), 159 | 100. 
* batch_idx / len(train_loader), loss, 160 | seg_loss, point_loss)) 161 | 162 | print("Iteration numbers: ", iteration) 163 | 164 | 165 | #-------------------------- eval func --------------------------# 166 | def evaluation(epoch, loader): 167 | model.eval() 168 | dice_value = 0 169 | iou_value = 0 170 | dice_average = 0 171 | iou_average = 0 172 | numm = 0 173 | for batch_idx, batch_data in enumerate(loader): 174 | data = batch_data['image'].cuda().float() 175 | label = batch_data['label'].cuda().float() 176 | point = (batch_data['point'] > 0).cuda().float() 177 | # point_All = (batch_data['point_data'] > 0).cuda().float() 178 | #point_All = nn.functional.max_pool2d(point_All, 179 | # kernel_size=(16, 16), 180 | # stride=(16, 16)) 181 | 182 | with torch.no_grad(): 183 | if parse_config.arch == 'transfuse': 184 | _, _, output = model(data) 185 | loss_fuse = structure_loss(output, label) 186 | elif parse_config.arch == 'xboundformer': 187 | output, point_maps_pre, point_maps_pre1, point_maps_pre2 = model( 188 | data) 189 | loss = 0 190 | 191 | if parse_config.arch == 'transfuse': 192 | loss = loss_fuse 193 | 194 | output = output.cpu().numpy() > 0.5 195 | 196 | label = label.cpu().numpy() 197 | assert (output.shape == label.shape) 198 | dice_ave = dc(output, label) 199 | iou_ave = jc(output, label) 200 | dice_value += dice_ave 201 | iou_value += iou_ave 202 | numm += 1 203 | 204 | dice_average = dice_value / numm 205 | iou_average = iou_value / numm 206 | writer.add_scalar('val_metrics/val_dice', dice_average, epoch) 207 | writer.add_scalar('val_metrics/val_iou', iou_average, epoch) 208 | print("Average dice value of evaluation dataset = ", dice_average) 209 | print("Average iou value of evaluation dataset = ", iou_average) 210 | return dice_average, iou_average, loss 211 | 212 | 213 | if __name__ == '__main__': 214 | #-------------------------- get args --------------------------# 215 | parse_config = get_cfg() 216 | 217 | #-------------------------- build loggers and savers --------------------------# 218 | exp_name = parse_config.dataset + '/' + parse_config.exp_name + '_loss_' + str( 219 | parse_config.seg_loss) + '_aug_' + str( 220 | parse_config.aug 221 | ) + '/' + parse_config.folder_name + '/fold_' + str(parse_config.fold) 222 | 223 | os.makedirs('logs/{}'.format(exp_name), exist_ok=True) 224 | os.makedirs('logs/{}/model'.format(exp_name), exist_ok=True) 225 | writer = SummaryWriter('logs/{}/log'.format(exp_name)) 226 | save_path = 'logs/{}/model/best.pkl'.format(exp_name) 227 | latest_path = 'logs/{}/model/latest.pkl'.format(exp_name) 228 | 229 | EPOCHS = parse_config.n_epochs 230 | os.environ['CUDA_VISIBLE_DEVICES'] = parse_config.gpu 231 | device_ids = range(torch.cuda.device_count()) 232 | 233 | #-------------------------- build dataloaders --------------------------# 234 | if parse_config.dataset == 'isic2018': 235 | from utils.isbi2018_new import norm01, myDataset 236 | 237 | dataset = myDataset(fold=parse_config.fold, 238 | split='train', 239 | aug=parse_config.aug) 240 | dataset2 = myDataset(fold=parse_config.fold, split='valid', aug=False) 241 | elif parse_config.dataset == 'isic2016': 242 | from utils.isbi2016_new import norm01, myDataset 243 | 244 | dataset = myDataset(split='train', aug=parse_config.aug) 245 | dataset2 = myDataset(split='valid', aug=False) 246 | else: 247 | raise NotImplementedError 248 | 249 | train_loader = torch.utils.data.DataLoader(dataset, 250 | batch_size=parse_config.bt_size, 251 | shuffle=True, 252 | num_workers=2, 253 | pin_memory=True, 254 | 
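        # the training loader drops the last incomplete batch; the validation
        # loader built just below uses batch_size=1, so the medpy dc/jc metrics
        # in evaluation() are computed per image and then averaged.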
drop_last=True) 255 | val_loader = torch.utils.data.DataLoader( 256 | dataset2, 257 | batch_size=1, #parse_config.bt_size 258 | shuffle=False, #True 259 | num_workers=2, 260 | pin_memory=True, 261 | drop_last=False) #True 262 | 263 | #-------------------------- build models --------------------------# 264 | if parse_config.arch is 'xboundformer': 265 | from lib.xboundformer import _segm_pvtv2 266 | model = _segm_pvtv2(1, parse_config.im_num, parse_config.ex_num, 267 | parse_config.xbound, 352).cuda() 268 | elif parse_config.arch == 'transfuse': 269 | from lib.TransFuse.TransFuse import TransFuse_S 270 | model = TransFuse_S(pretrained=True).cuda() 271 | 272 | if len(device_ids) > 1: # 多卡训练 273 | model = torch.nn.DataParallel(model).cuda() 274 | 275 | optimizer = torch.optim.Adam(model.parameters(), lr=parse_config.lr_seg) 276 | 277 | #scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10) 278 | scheduler = CosineAnnealingLR(optimizer, T_max=20) 279 | 280 | criteon = [None, ce_loss][parse_config.seg_loss] 281 | 282 | #-------------------------- start training --------------------------# 283 | 284 | max_dice = 0 285 | max_iou = 0 286 | best_ep = 0 287 | 288 | min_loss = 10 289 | min_epoch = 0 290 | 291 | for epoch in range(1, EPOCHS + 1): 292 | print(optimizer.state_dict()['param_groups'][0]['lr']) 293 | start = time.time() 294 | train(epoch) 295 | dice, iou, loss = evaluation(epoch, val_loader) 296 | scheduler.step() 297 | 298 | if loss < min_loss: 299 | min_epoch = epoch 300 | min_loss = loss 301 | else: 302 | if epoch - min_epoch >= parse_config.patience: 303 | print('Early stopping!') 304 | break 305 | if iou > max_iou: 306 | max_iou = iou 307 | best_ep = epoch 308 | torch.save(model.state_dict(), save_path) 309 | else: 310 | if epoch - best_ep >= parse_config.patience: 311 | print('Early stopping!') 312 | break 313 | torch.save(model.state_dict(), latest_path) 314 | time_elapsed = time.time() - start 315 | print( 316 | 'Training and evaluating on epoch:{} complete in {:.0f}m {:.0f}s'. 
317 | format(epoch, time_elapsed // 60, time_elapsed % 60)) 318 | -------------------------------------------------------------------------------- /lib/deeplabv3.py: -------------------------------------------------------------------------------- 1 | from lib.pvtv2 import pvt_v2_b2 # 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | # from lib.Vision_Transformer import detr_Transformer, detr_BA_Transformer 6 | from lib.modules import BoundaryCrossAttention 7 | from lib.transformer import BoundaryAwareTransformer 8 | from lib.replknet import RepLKBlock 9 | 10 | 11 | def _segm_pvtv2(name, backbone_name, num_classes, output_stride, 12 | pretrained_backbone): 13 | 14 | if output_stride == 8: 15 | aspp_dilate = [12, 24, 36] 16 | else: 17 | aspp_dilate = [6, 12, 18] 18 | 19 | backbone = pvt_v2_b2() 20 | if pretrained_backbone: 21 | path = './pretrained_pth/pvt_v2_b2.pth' 22 | save_model = torch.load(path) 23 | model_dict = backbone.state_dict() 24 | state_dict = { 25 | k: v 26 | for k, v in save_model.items() if k in model_dict.keys() 27 | } 28 | model_dict.update(state_dict) 29 | backbone.load_state_dict(model_dict) 30 | 31 | inplanes = 512 32 | low_level_planes = 64 33 | 34 | if name == 'deeplabv3plus': 35 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, 36 | aspp_dilate) 37 | 38 | model = DeepLabV3(backbone, classifier) 39 | return model 40 | 41 | 42 | class DeepLabHeadV3Plus(nn.Module): 43 | def __init__(self, 44 | in_channels, 45 | low_level_channels, 46 | num_classes, 47 | aspp_dilate=[12, 24, 36]): 48 | super(DeepLabHeadV3Plus, self).__init__() 49 | self.project = nn.Sequential( 50 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 51 | nn.BatchNorm2d(48), 52 | nn.ReLU(inplace=True), 53 | ) 54 | 55 | # self.aspp = ASPP(in_channels, aspp_dilate) 56 | 57 | self.classifier = nn.Sequential( 58 | nn.Conv2d(560, 256, 3, padding=1, bias=False), # 59 | nn.BatchNorm2d(256), 60 | nn.ReLU(inplace=True), 61 | nn.Conv2d(256, num_classes, 1)) 62 | self._init_weight() 63 | 64 | def forward(self, feature): 65 | # low_level_feature = self.project( feature['low_level'] ) 66 | # output_feature = self.aspp(feature['out']) 67 | low_level_feature = self.project(feature[0]) 68 | output_feature = feature[3] 69 | # output_feature = self.aspp(feature[3]) 70 | output_feature = F.interpolate(output_feature, 71 | size=low_level_feature.shape[2:], 72 | mode='bilinear', 73 | align_corners=False) 74 | return self.classifier( 75 | torch.cat([low_level_feature, output_feature], dim=1)) 76 | 77 | def _init_weight(self): 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | nn.init.kaiming_normal_(m.weight) 81 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 82 | nn.init.constant_(m.weight, 1) 83 | nn.init.constant_(m.bias, 0) 84 | 85 | 86 | class _SimpleSegmentationModel(nn.Module): 87 | def __init__(self, backbone, classifier): 88 | super(_SimpleSegmentationModel, self).__init__() 89 | self.backbone = backbone 90 | self.classifier = classifier 91 | # self.bat=BAT(num_classes=1,point_pred=1, in_channels=512,decoder=True, transformer_type_index=0) 92 | # self.bat_low=BAT(num_classes=1,point_pred=1,in_channels=320,hidden_features=320, decoder=False, transformer_type_index=0) 93 | # self.bat=BAT(in_channel=512) 94 | self.bat_low = BAT(in_channel=320) 95 | # self.bat0=BAT(in_channel=64) 96 | def forward(self, x): 97 | input_shape = x.shape[-2:] 98 | features = self.backbone( 99 | x 100 | ) # ([8, 64, 64, 64]) ([8, 128, 32, 32]) ([8, 320, 16, 16]) 
([8, 512, 8, 8]) 101 | # features[3],point_pre=self.bat(features[3]) 102 | features[2], point_pre = self.bat_low(features[2]) 103 | # features[0],point_pre=self.bat0(features[0]) 104 | x = self.classifier(features) 105 | x = F.interpolate(x, 106 | size=input_shape, 107 | mode='bilinear', 108 | align_corners=False) 109 | # p=F.interpolate(point_pre, size=input_shape, mode='bilinear', align_corners=False) 110 | return x, point_pre 111 | 112 | 113 | class DeepLabV3(_SimpleSegmentationModel): 114 | """ 115 | Implements DeepLabV3 model from 116 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 117 | `_. 118 | Arguments: 119 | backbone (nn.Module): the network used to compute the features for the model. 120 | The backbone should return an OrderedDict[Tensor], with the key being 121 | "out" for the last feature map used, and "aux" if an auxiliary classifier 122 | is used. 123 | classifier (nn.Module): module that takes the "out" element returned from 124 | the backbone and returns a dense prediction. 125 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 126 | """ 127 | pass 128 | 129 | 130 | class ASPP(nn.Module): 131 | def __init__(self, in_channels, atrous_rates): 132 | super(ASPP, self).__init__() 133 | out_channels = 256 134 | modules = [] 135 | modules.append( 136 | nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), 137 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))) 138 | 139 | rate1, rate2, rate3 = tuple(atrous_rates) 140 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 141 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 142 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 143 | modules.append(ASPPPooling(in_channels, out_channels)) 144 | 145 | self.convs = nn.ModuleList(modules) 146 | 147 | self.project = nn.Sequential( 148 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 149 | nn.BatchNorm2d(out_channels), 150 | nn.ReLU(inplace=True), 151 | nn.Dropout(0.1), 152 | ) 153 | 154 | def forward(self, x): 155 | res = [] 156 | for conv in self.convs: 157 | res.append(conv(x)) 158 | res = torch.cat(res, dim=1) 159 | return self.project(res) 160 | 161 | 162 | class ASPPConv(nn.Sequential): 163 | def __init__(self, in_channels, out_channels, dilation): 164 | modules = [ 165 | nn.Conv2d(in_channels, 166 | out_channels, 167 | 3, 168 | padding=dilation, 169 | dilation=dilation, 170 | bias=False), 171 | nn.BatchNorm2d(out_channels), 172 | nn.ReLU(inplace=True) 173 | ] 174 | super(ASPPConv, self).__init__(*modules) 175 | 176 | 177 | class ASPPPooling(nn.Sequential): 178 | def __init__(self, in_channels, out_channels): 179 | super(ASPPPooling, self).__init__( 180 | nn.AdaptiveAvgPool2d(11), ### 181 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 182 | nn.BatchNorm2d(out_channels), 183 | nn.ReLU(inplace=True)) 184 | 185 | def forward(self, x): 186 | size = x.shape[-2:] 187 | x = super(ASPPPooling, self).forward(x) 188 | return F.interpolate(x, 189 | size=size, 190 | mode='bilinear', 191 | align_corners=False) 192 | 193 | 194 | def deeplabv3plus_pvtv2(num_classes=1, 195 | output_stride=8, 196 | pretrained_backbone=True): 197 | """Constructs a DeepLabV3 model with a pvtv2 backbone. 198 | Args: 199 | num_classes (int): number of classes. 200 | output_stride (int): output stride for deeplab. 201 | pretrained_backbone (bool): If True, use the pretrained backbone. 
202 | """ 203 | return _segm_pvtv2('deeplabv3plus', 204 | 'pvtv2', 205 | num_classes, 206 | output_stride=output_stride, 207 | pretrained_backbone=pretrained_backbone) 208 | 209 | 210 | class BAT(nn.Module): 211 | def __init__(self, in_channel=512): 212 | super(BAT, self).__init__() 213 | self.in_channel = in_channel 214 | self.conv1 = nn.Sequential( 215 | nn.Conv2d(in_channels=self.in_channel, 216 | out_channels=512, 217 | kernel_size=(3, 3), 218 | padding=1), nn.BatchNorm2d(512), nn.ReLU()) 219 | self.conv2 = nn.Sequential( 220 | nn.Conv2d(in_channels=512, 221 | out_channels=256, 222 | kernel_size=(3, 3), 223 | padding=1), nn.BatchNorm2d(256), nn.ReLU()) 224 | self.conv3 = nn.Sequential( 225 | nn.Conv2d(in_channels=256, 226 | out_channels=self.in_channel, 227 | kernel_size=(3, 3), 228 | padding=1), nn.BatchNorm2d(self.in_channel), nn.ReLU()) 229 | # self.conv1=RepLKBlock(self.in_channel,128,31,5,0.) 230 | # self.conv2=RepLKBlock(self.in_channel,256,29,5,0.1) 231 | # self.conv3=RepLKBlock(self.in_channel,512,27,5,0.2) 232 | # self.conv5=RepLKBlock(self.in_channel,1024,13,5,0.3) 233 | self.conv4 = nn.Sequential( 234 | nn.Conv2d(in_channels=self.in_channel, 235 | out_channels=1, 236 | kernel_size=(1, 1)), nn.BatchNorm2d(1), nn.ReLU()) 237 | # self.sigmoid=nn.Sigmoid() 238 | def forward(self, x): 239 | point1 = self.conv1(x) 240 | # # point1=point1+x 241 | point1 = self.conv2(point1) 242 | point1 = self.conv3(point1) 243 | # # point1=point1+point2 244 | # point2=self.conv5(point1) 245 | point1 = point1 + x 246 | point1 = self.conv4(point1) 247 | # point1=self.sigmoid(point1) 248 | return x, point1 249 | 250 | 251 | # class BAT(nn.Module): 252 | # def __init__( 253 | # self, 254 | # num_classes, 255 | # point_pred, 256 | # in_channels=512, 257 | # decoder=False, 258 | # transformer_type_index=0, 259 | # hidden_features=256, # 256 260 | # number_of_query_positions=1, 261 | # segmentation_attention_heads=8): 262 | 263 | # super(BAT, self).__init__() 264 | 265 | # self.num_classes = num_classes 266 | # self.point_pred = point_pred 267 | # self.transformer_type = "BoundaryAwareTransformer" if transformer_type_index == 0 else "Transformer" 268 | # self.use_decoder = decoder 269 | 270 | # self.in_channels = in_channels 271 | 272 | # # self.convolution_mapping = nn.Conv2d(in_channels=in_channels, 273 | # # out_channels=hidden_features, 274 | # # kernel_size=(1, 1), 275 | # # stride=(1, 1), 276 | # # padding=(0, 0), 277 | # # bias=True) 278 | 279 | # self.query_positions = nn.Parameter(data=torch.randn( 280 | # number_of_query_positions, hidden_features, dtype=torch.float), 281 | # requires_grad=True) 282 | 283 | # self.row_embedding = nn.Parameter(data=torch.randn(100, 284 | # hidden_features // 285 | # 2, 286 | # dtype=torch.float), 287 | # requires_grad=True) 288 | # self.column_embedding = nn.Parameter(data=torch.randn( 289 | # 100, hidden_features // 2, dtype=torch.float), 290 | # requires_grad=True) 291 | 292 | # # self.transformer =BoundaryAwareTransformer(d_model=hidden_features,normalize_before=False,num_encoder_layers=6,num_decoder_layers=2,Atrous=False) 293 | # self.transformer =BoundaryAwareTransformer(d_model=hidden_features,normalize_before=False,num_decoder_layers=0,point_pred_layers=1,Atrous=False) 294 | 295 | # if self.use_decoder: 296 | # self.BCA = BoundaryCrossAttention(hidden_features, 8) 297 | 298 | # # self.trans_out_conv = nn.Conv2d(in_channels=hidden_features, 299 | # # out_channels=in_channels, 300 | # # kernel_size=(1, 1), 301 | # # stride=(1, 1), 302 | # # padding=(0, 0), 
303 | # # bias=True) 304 | 305 | # def forward(self, x): 306 | # h = x.size()[2] 307 | # w = x.size()[3] 308 | # feature_map = x 309 | # features=x 310 | # # features = self.convolution_mapping(feature_map) 311 | # height, width = features.shape[2:] 312 | # batch_size = features.shape[0] 313 | # positional_embeddings = torch.cat([ 314 | # self.column_embedding[:height].unsqueeze(dim=0).repeat( 315 | # height, 1, 1), 316 | # self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1) 317 | # ],dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1) 318 | 319 | # if self.transformer_type == 'BoundaryAwareTransformer': 320 | # latent_tensor, features_encoded, point_maps = self.transformer( 321 | # features, None, self.query_positions, positional_embeddings) 322 | # else: 323 | # latent_tensor, features_encoded = self.transformer( 324 | # features, None, self.query_positions, positional_embeddings) 325 | # point_maps = [] 326 | 327 | # latent_tensor = latent_tensor.permute(2, 0, 1) 328 | # # shape:(bs, 1 , 128) 329 | 330 | # # if self.use_decoder: 331 | # # features_encoded, point_dec = self.BCA(features_encoded, 332 | # # latent_tensor) 333 | # # point_maps.append(point_dec) 334 | 335 | # # trans_feature_maps = self.trans_out_conv( 336 | # # features_encoded.contiguous()) #.contiguous() 337 | # trans_feature_maps = features_encoded.contiguous() 338 | # trans_feature_maps = trans_feature_maps + feature_map 339 | 340 | # output = F.interpolate( 341 | # trans_feature_maps, size=(h, w), 342 | # mode="bilinear") # (shape: (batch_size, num_classes, h, w)) 343 | 344 | # if self.point_pred == 1: 345 | # return output, point_maps[0] 346 | 347 | # return output 348 | 349 | if __name__ == '__main__': 350 | import os 351 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 352 | model = deeplabv3plus_pvtv2().cuda() 353 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 354 | 355 | prediction1, point_pred1 = model(input_tensor) 356 | print(prediction1.size(), point_pred1.size()) -------------------------------------------------------------------------------- /lib/TransFuse/vision_transformer.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | A PyTorch implementation of Vision Transformers as described in 3 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 4 | The official jax code is released and available at https://github.com/google-research/vision_transformer 5 | Status/TODO: 6 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. 7 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. 8 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. 9 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. 10 | Acknowledgments: 11 | * The paper authors for releasing code and weights, thanks! 12 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ...
check it out 13 | for some einops/einsum fun 14 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 15 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 16 | Hacked together by / Copyright 2020 Ross Wightman 17 | """ 18 | import torch 19 | import torch.nn as nn 20 | from functools import partial 21 | 22 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 23 | from timm.models.helpers import load_pretrained 24 | 25 | from timm.models.registry import register_model 26 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 27 | 28 | 29 | 30 | def _cfg(url='', **kwargs): 31 | return { 32 | 'url': url, 33 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 34 | 'crop_pct': .9, 'interpolation': 'bicubic', 35 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 36 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 37 | **kwargs 38 | } 39 | 40 | 41 | default_cfgs = { 42 | # patch models 43 | 'vit_small_patch16_224': _cfg( 44 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 45 | ), 46 | 'vit_base_patch16_224': _cfg( 47 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 48 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 49 | ), 50 | 'vit_base_patch16_384': _cfg( 51 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', 52 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 53 | 'vit_base_patch32_384': _cfg( 54 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', 55 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 56 | 'vit_large_patch16_224': _cfg( 57 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 58 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 59 | 'vit_large_patch16_384': _cfg( 60 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', 61 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 62 | 'vit_large_patch32_384': _cfg( 63 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 64 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 65 | 'vit_huge_patch16_224': _cfg(), 66 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), 67 | # hybrid models 68 | 'vit_small_resnet26d_224': _cfg(), 69 | 'vit_small_resnet50d_s3_224': _cfg(), 70 | 'vit_base_resnet26d_224': _cfg(), 71 | 'vit_base_resnet50d_224': _cfg(), 72 | } 73 | 74 | 75 | class Mlp(nn.Module): 76 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 77 | super().__init__() 78 | out_features = out_features or in_features 79 | hidden_features = hidden_features or in_features 80 | self.fc1 = nn.Linear(in_features, hidden_features) 81 | self.act = act_layer() 82 | self.fc2 = nn.Linear(hidden_features, out_features) 83 | self.drop = nn.Dropout(drop) 84 | 85 | def forward(self, x): 86 | x = self.fc1(x) 87 | x = self.act(x) 88 | x = self.drop(x) 89 | x = self.fc2(x) 90 | x = self.drop(x) 91 | return x 92 | 93 | 94 | class Attention(nn.Module): 95 | 
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 96 | super().__init__() 97 | self.num_heads = num_heads 98 | head_dim = dim // num_heads 99 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 100 | self.scale = qk_scale or head_dim ** -0.5 101 | 102 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 103 | self.attn_drop = nn.Dropout(attn_drop) 104 | self.proj = nn.Linear(dim, dim) 105 | self.proj_drop = nn.Dropout(proj_drop) 106 | 107 | def forward(self, x): 108 | B, N, C = x.shape 109 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 110 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 111 | 112 | attn = (q @ k.transpose(-2, -1)) * self.scale 113 | attn = attn.softmax(dim=-1) 114 | attn = self.attn_drop(attn) 115 | 116 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 117 | x = self.proj(x) 118 | x = self.proj_drop(x) 119 | return x 120 | 121 | 122 | class Block(nn.Module): 123 | 124 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 125 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 126 | super().__init__() 127 | self.norm1 = norm_layer(dim) 128 | self.attn = Attention( 129 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 130 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 131 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 132 | self.norm2 = norm_layer(dim) 133 | mlp_hidden_dim = int(dim * mlp_ratio) 134 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 135 | 136 | def forward(self, x): 137 | x = x + self.drop_path(self.attn(self.norm1(x))) 138 | x = x + self.drop_path(self.mlp(self.norm2(x))) 139 | return x 140 | 141 | 142 | class PatchEmbed(nn.Module): 143 | """ Image to Patch Embedding 144 | """ 145 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 146 | super().__init__() 147 | img_size = to_2tuple(img_size) 148 | patch_size = to_2tuple(patch_size) 149 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 150 | self.img_size = img_size 151 | self.patch_size = patch_size 152 | self.num_patches = num_patches 153 | 154 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 155 | 156 | def forward(self, x): 157 | B, C, H, W = x.shape 158 | 159 | # FIXME look at relaxing size constraints 160 | #assert H == self.img_size[0] and W == self.img_size[1], \ 161 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
162 | x = self.proj(x).flatten(2).transpose(1, 2) 163 | return x 164 | 165 | 166 | 167 | class VisionTransformer(nn.Module): 168 | """ Vision Transformer with support for patch or hybrid CNN input stage 169 | """ 170 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 171 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 172 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): 173 | super().__init__() 174 | self.num_classes = num_classes 175 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 176 | 177 | if hybrid_backbone is not None: 178 | self.patch_embed = HybridEmbed( 179 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 180 | else: 181 | self.patch_embed = PatchEmbed( 182 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 183 | num_patches = self.patch_embed.num_patches 184 | 185 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 186 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 187 | self.pos_drop = nn.Dropout(p=drop_rate) 188 | 189 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 190 | self.blocks = nn.ModuleList([ 191 | Block( 192 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 193 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 194 | for i in range(depth)]) 195 | self.norm = norm_layer(embed_dim) 196 | 197 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 198 | #self.repr = nn.Linear(embed_dim, representation_size) 199 | #self.repr_act = nn.Tanh() 200 | 201 | # Classifier head 202 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 203 | 204 | trunc_normal_(self.pos_embed, std=.02) 205 | trunc_normal_(self.cls_token, std=.02) 206 | self.apply(self._init_weights) 207 | 208 | def _init_weights(self, m): 209 | if isinstance(m, nn.Linear): 210 | trunc_normal_(m.weight, std=.02) 211 | if isinstance(m, nn.Linear) and m.bias is not None: 212 | nn.init.constant_(m.bias, 0) 213 | elif isinstance(m, nn.LayerNorm): 214 | nn.init.constant_(m.bias, 0) 215 | nn.init.constant_(m.weight, 1.0) 216 | 217 | @torch.jit.ignore 218 | def no_weight_decay(self): 219 | return {'pos_embed', 'cls_token'} 220 | 221 | def get_classifier(self): 222 | return self.head 223 | 224 | def reset_classifier(self, num_classes, global_pool=''): 225 | self.num_classes = num_classes 226 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 227 | 228 | def forward_features(self, x): 229 | B = x.shape[0] 230 | x = self.patch_embed(x) 231 | 232 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 233 | x = torch.cat((cls_tokens, x), dim=1) 234 | x = x + self.pos_embed 235 | x = self.pos_drop(x) 236 | 237 | for blk in self.blocks: 238 | x = blk(x) 239 | 240 | x = self.norm(x) 241 | return x[:, 0] 242 | 243 | def forward(self, x): 244 | x = self.forward_features(x) 245 | x = self.head(x) 246 | return x 247 | 248 | 249 | def _conv_filter(state_dict, patch_size=16): 250 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 251 | out_dict = {} 252 | for k, v in state_dict.items(): 253 | if 'patch_embed.proj.weight' in k: 254 | v = v.reshape((v.shape[0], 
3, patch_size, patch_size)) 255 | out_dict[k] = v 256 | return out_dict 257 | 258 | 259 | @register_model 260 | def vit_small_patch16_224(pretrained=False, **kwargs): 261 | if pretrained: 262 | # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model 263 | kwargs.setdefault('qk_scale', 768 ** -0.5) 264 | model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs) 265 | model.default_cfg = default_cfgs['vit_small_patch16_224'] 266 | if pretrained: 267 | load_pretrained( 268 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) 269 | return model 270 | 271 | 272 | @register_model 273 | def vit_base_patch16_224(pretrained=False, **kwargs): 274 | model = VisionTransformer( 275 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 276 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 277 | model.default_cfg = default_cfgs['vit_base_patch16_224'] 278 | if pretrained: 279 | load_pretrained( 280 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) 281 | return model 282 | 283 | 284 | @register_model 285 | def vit_base_patch16_384(pretrained=False, **kwargs): 286 | model = VisionTransformer( 287 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 288 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 289 | model.default_cfg = default_cfgs['vit_base_patch16_384'] 290 | if pretrained: 291 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 292 | return model 293 | 294 | 295 | @register_model 296 | def vit_base_patch32_384(pretrained=False, **kwargs): 297 | model = VisionTransformer( 298 | img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 299 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 300 | model.default_cfg = default_cfgs['vit_base_patch32_384'] 301 | if pretrained: 302 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 303 | return model 304 | 305 | 306 | @register_model 307 | def vit_large_patch16_224(pretrained=False, **kwargs): 308 | model = VisionTransformer( 309 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 310 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 311 | model.default_cfg = default_cfgs['vit_large_patch16_224'] 312 | if pretrained: 313 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 314 | return model 315 | 316 | 317 | @register_model 318 | def vit_large_patch16_384(pretrained=False, **kwargs): 319 | model = VisionTransformer( 320 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 321 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 322 | model.default_cfg = default_cfgs['vit_large_patch16_384'] 323 | if pretrained: 324 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 325 | return model 326 | 327 | 328 | @register_model 329 | def vit_large_patch32_384(pretrained=False, **kwargs): 330 | model = VisionTransformer( 331 | img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 332 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 333 | model.default_cfg = default_cfgs['vit_large_patch32_384'] 334 | if pretrained: 335 | load_pretrained(model, 
num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 336 | return model 337 | 338 | 339 | @register_model 340 | def vit_huge_patch16_224(pretrained=False, **kwargs): 341 | model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) 342 | model.default_cfg = default_cfgs['vit_huge_patch16_224'] 343 | return model 344 | 345 | 346 | @register_model 347 | def vit_huge_patch32_384(pretrained=False, **kwargs): 348 | model = VisionTransformer( 349 | img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) 350 | model.default_cfg = default_cfgs['vit_huge_patch32_384'] 351 | return model -------------------------------------------------------------------------------- /lib/TransFuse/TransFuse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import resnet34 as resnet 4 | from torchvision.models import resnet50 5 | from .DeiT import deit_small_patch16_224 as deit 6 | from .DeiT import deit_base_patch16_224 as deit_b 7 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import math 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | 13 | 14 | class ChannelPool(nn.Module): 15 | def forward(self, x): 16 | return torch.cat( 17 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), 18 | dim=1) 19 | 20 | 21 | class BiFusion_block(nn.Module): 22 | def __init__(self, ch_1, ch_2, r_2, ch_int, ch_out, drop_rate=0.): 23 | super(BiFusion_block, self).__init__() 24 | 25 | # channel attention for F_g, use SE Block 26 | self.fc1 = nn.Conv2d(ch_2, ch_2 // r_2, kernel_size=1) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.fc2 = nn.Conv2d(ch_2 // r_2, ch_2, kernel_size=1) 29 | self.sigmoid = nn.Sigmoid() 30 | 31 | # spatial attention for F_l 32 | self.compress = ChannelPool() 33 | self.spatial = Conv(2, 1, 7, bn=True, relu=False, bias=False) 34 | 35 | # bi-linear modelling for both 36 | self.W_g = Conv(ch_1, ch_int, 1, bn=True, relu=False) 37 | self.W_x = Conv(ch_2, ch_int, 1, bn=True, relu=False) 38 | self.W = Conv(ch_int, ch_int, 3, bn=True, relu=True) 39 | 40 | self.relu = nn.ReLU(inplace=True) 41 | 42 | self.residual = Residual(ch_1 + ch_2 + ch_int, ch_out) 43 | 44 | self.dropout = nn.Dropout2d(drop_rate) 45 | self.drop_rate = drop_rate 46 | 47 | def forward(self, g, x): 48 | # bilinear pooling 49 | W_g = self.W_g(g) 50 | W_x = self.W_x(x) 51 | bp = self.W(W_g * W_x) 52 | 53 | # spatial attention for cnn branch 54 | g_in = g 55 | g = self.compress(g) 56 | g = self.spatial(g) 57 | g = self.sigmoid(g) * g_in 58 | 59 | # channel attetion for transformer branch 60 | x_in = x 61 | x = x.mean((2, 3), keepdim=True) 62 | x = self.fc1(x) 63 | x = self.relu(x) 64 | x = self.fc2(x) 65 | x = self.sigmoid(x) * x_in 66 | fuse = self.residual(torch.cat([g, x, bp], 1)) 67 | 68 | if self.drop_rate > 0: 69 | return self.dropout(fuse) 70 | else: 71 | return fuse 72 | 73 | 74 | class TransFuse_S(nn.Module): 75 | def __init__(self, 76 | num_classes=1, 77 | drop_rate=0.2, 78 | normal_init=True, 79 | pretrained=False): 80 | super(TransFuse_S, self).__init__() 81 | 82 | self.resnet = resnet() 83 | if pretrained: 84 | self.resnet.load_state_dict( 85 | torch.load( 86 | '/raid/wjc/code/xbound_former/lib/TransFuse/pretrained/resnet34-43635321.pth' 87 | )) 88 | self.resnet.fc = nn.Identity() 89 | 
self.resnet.layer4 = nn.Identity() 90 | 91 | self.transformer = deit(pretrained=pretrained) 92 | 93 | self.up1 = Up(in_ch1=384, out_ch=128) 94 | self.up2 = Up(128, 64) 95 | 96 | self.final_x = nn.Sequential( 97 | Conv(256, 64, 1, bn=True, relu=True), 98 | Conv(64, 64, 3, bn=True, relu=True), 99 | Conv(64, num_classes, 3, bn=False, relu=False)) 100 | 101 | self.final_1 = nn.Sequential( 102 | Conv(64, 64, 3, bn=True, relu=True), 103 | Conv(64, num_classes, 3, bn=False, relu=False)) 104 | 105 | self.final_2 = nn.Sequential( 106 | Conv(64, 64, 3, bn=True, relu=True), 107 | Conv(64, num_classes, 3, bn=False, relu=False)) 108 | 109 | self.up_c = BiFusion_block(ch_1=256, 110 | ch_2=384, 111 | r_2=4, 112 | ch_int=256, 113 | ch_out=256, 114 | drop_rate=drop_rate / 2) 115 | 116 | self.up_c_1_1 = BiFusion_block(ch_1=128, 117 | ch_2=128, 118 | r_2=2, 119 | ch_int=128, 120 | ch_out=128, 121 | drop_rate=drop_rate / 2) 122 | self.up_c_1_2 = Up(in_ch1=256, out_ch=128, in_ch2=128, attn=True) 123 | 124 | self.up_c_2_1 = BiFusion_block(ch_1=64, 125 | ch_2=64, 126 | r_2=1, 127 | ch_int=64, 128 | ch_out=64, 129 | drop_rate=drop_rate / 2) 130 | self.up_c_2_2 = Up(128, 64, 64, attn=True) 131 | 132 | self.drop = nn.Dropout2d(drop_rate) 133 | 134 | if normal_init: 135 | self.init_weights() 136 | 137 | def forward(self, imgs, labels=None): 138 | # bottom-up path 139 | x_b = self.transformer(imgs) 140 | x_b = torch.transpose(x_b, 1, 2) 141 | x_b = x_b.view(x_b.shape[0], -1, 142 | imgs.size(2) // 16, 143 | imgs.size(3) // 16) # (x_b.shape[0], -1, 12, 16) 144 | x_b = self.drop(x_b) 145 | 146 | x_b_1 = self.up1(x_b) 147 | x_b_1 = self.drop(x_b_1) 148 | 149 | x_b_2 = self.up2(x_b_1) # transformer pred supervise here 150 | x_b_2 = self.drop(x_b_2) 151 | 152 | # top-down path 153 | x_u = self.resnet.conv1(imgs) 154 | x_u = self.resnet.bn1(x_u) 155 | x_u = self.resnet.relu(x_u) 156 | x_u = self.resnet.maxpool(x_u) 157 | 158 | x_u_2 = self.resnet.layer1(x_u) 159 | x_u_2 = self.drop(x_u_2) 160 | 161 | x_u_1 = self.resnet.layer2(x_u_2) 162 | x_u_1 = self.drop(x_u_1) 163 | 164 | x_u = self.resnet.layer3(x_u_1) 165 | x_u = self.drop(x_u) 166 | 167 | # joint path 168 | x_c = self.up_c(x_u, x_b) 169 | 170 | x_c_1_1 = self.up_c_1_1(x_u_1, x_b_1) 171 | x_c_1 = self.up_c_1_2(x_c, x_c_1_1) 172 | 173 | x_c_2_1 = self.up_c_2_1(x_u_2, x_b_2) 174 | x_c_2 = self.up_c_2_2(x_c_1, 175 | x_c_2_1) # joint predict low supervise here 176 | 177 | # decoder part 178 | map_x = F.interpolate(self.final_x(x_c), 179 | scale_factor=16, 180 | mode='bilinear') 181 | map_1 = F.interpolate(self.final_1(x_b_2), 182 | scale_factor=4, 183 | mode='bilinear') 184 | map_2 = F.interpolate(self.final_2(x_c_2), 185 | scale_factor=4, 186 | mode='bilinear') 187 | return map_x, map_1, map_2 188 | 189 | def init_weights(self): 190 | self.up1.apply(init_weights) 191 | self.up2.apply(init_weights) 192 | self.final_x.apply(init_weights) 193 | self.final_1.apply(init_weights) 194 | self.final_2.apply(init_weights) 195 | self.up_c.apply(init_weights) 196 | self.up_c_1_1.apply(init_weights) 197 | self.up_c_1_2.apply(init_weights) 198 | self.up_c_2_1.apply(init_weights) 199 | self.up_c_2_2.apply(init_weights) 200 | 201 | 202 | class TransFuse_L(nn.Module): 203 | def __init__(self, 204 | num_classes=1, 205 | drop_rate=0.2, 206 | normal_init=True, 207 | pretrained=False): 208 | super(TransFuse_L, self).__init__() 209 | 210 | self.resnet = resnet50() 211 | if pretrained: 212 | self.resnet.load_state_dict( 213 | torch.load( 214 | 
'/home/chenfei/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth' 215 | )) 216 | self.resnet.fc = nn.Identity() 217 | self.resnet.layer4 = nn.Identity() 218 | 219 | self.transformer = deit_b(pretrained=pretrained) 220 | 221 | self.up1 = Up(in_ch1=768, out_ch=512) 222 | self.up2 = Up(512, 256) 223 | 224 | self.final_x = nn.Sequential( 225 | Conv(1024, 256, 1, bn=True, relu=True), 226 | Conv(256, 256, 3, bn=True, relu=True), 227 | Conv(256, num_classes, 3, bn=False, relu=False)) 228 | 229 | self.final_1 = nn.Sequential( 230 | Conv(256, 256, 3, bn=True, relu=True), 231 | Conv(256, num_classes, 3, bn=False, relu=False)) 232 | 233 | self.final_2 = nn.Sequential( 234 | Conv(256, 256, 3, bn=True, relu=True), 235 | Conv(256, num_classes, 3, bn=False, relu=False)) 236 | 237 | self.up_c = BiFusion_block(ch_1=1024, 238 | ch_2=768, 239 | r_2=4, 240 | ch_int=1024, 241 | ch_out=1024, 242 | drop_rate=drop_rate / 2) 243 | 244 | self.up_c_1_1 = BiFusion_block(ch_1=512, 245 | ch_2=512, 246 | r_2=2, 247 | ch_int=512, 248 | ch_out=512, 249 | drop_rate=drop_rate / 2) 250 | self.up_c_1_2 = Up(in_ch1=1024, out_ch=512, in_ch2=512, attn=True) 251 | 252 | self.up_c_2_1 = BiFusion_block(ch_1=256, 253 | ch_2=256, 254 | r_2=1, 255 | ch_int=256, 256 | ch_out=256, 257 | drop_rate=drop_rate / 2) 258 | self.up_c_2_2 = Up(512, 256, 256, attn=True) 259 | 260 | self.drop = nn.Dropout2d(drop_rate) 261 | 262 | if normal_init: 263 | self.init_weights() 264 | 265 | def forward(self, imgs, labels=None): 266 | # bottom-up path 267 | x_b = self.transformer(imgs) 268 | x_b = torch.transpose(x_b, 1, 2) 269 | x_b = x_b.view(x_b.shape[0], -1, 32, 32) # (x_b.shape[0], -1, 12, 16) 270 | x_b = self.drop(x_b) 271 | 272 | x_b_1 = self.up1(x_b) 273 | x_b_1 = self.drop(x_b_1) 274 | 275 | x_b_2 = self.up2(x_b_1) # transformer pred supervise here 276 | x_b_2 = self.drop(x_b_2) 277 | 278 | # top-down path 279 | x_u = self.resnet.conv1(imgs) 280 | x_u = self.resnet.bn1(x_u) 281 | x_u = self.resnet.relu(x_u) 282 | x_u = self.resnet.maxpool(x_u) 283 | 284 | x_u_2 = self.resnet.layer1(x_u) 285 | x_u_2 = self.drop(x_u_2) 286 | 287 | x_u_1 = self.resnet.layer2(x_u_2) 288 | x_u_1 = self.drop(x_u_1) 289 | 290 | x_u = self.resnet.layer3(x_u_1) 291 | x_u = self.drop(x_u) 292 | 293 | # joint path 294 | x_c = self.up_c(x_u, x_b) 295 | 296 | x_c_1_1 = self.up_c_1_1(x_u_1, x_b_1) 297 | x_c_1 = self.up_c_1_2(x_c, x_c_1_1) 298 | 299 | x_c_2_1 = self.up_c_2_1(x_u_2, x_b_2) 300 | x_c_2 = self.up_c_2_2(x_c_1, 301 | x_c_2_1) # joint predict low supervise here 302 | 303 | # decoder part 304 | map_x = F.interpolate(self.final_x(x_c), 305 | scale_factor=16, 306 | mode='bilinear') 307 | map_1 = F.interpolate(self.final_1(x_b_2), 308 | scale_factor=4, 309 | mode='bilinear') 310 | map_2 = F.interpolate(self.final_2(x_c_2), 311 | scale_factor=4, 312 | mode='bilinear') 313 | return map_x, map_1, map_2 314 | 315 | def init_weights(self): 316 | self.up1.apply(init_weights) 317 | self.up2.apply(init_weights) 318 | self.final_x.apply(init_weights) 319 | self.final_1.apply(init_weights) 320 | self.final_2.apply(init_weights) 321 | self.up_c.apply(init_weights) 322 | self.up_c_1_1.apply(init_weights) 323 | self.up_c_1_2.apply(init_weights) 324 | self.up_c_2_1.apply(init_weights) 325 | self.up_c_2_2.apply(init_weights) 326 | 327 | 328 | def init_weights(m): 329 | """ 330 | Initialize weights of layers using Kaiming Normal (He et al.) 
as argument of "Apply" function of 331 | "nn.Module" 332 | :param m: Layer to initialize 333 | :return: None 334 | """ 335 | if isinstance(m, nn.Conv2d): 336 | ''' 337 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 338 | trunc_normal_(m.weight, std=math.sqrt(1.0/fan_in)/.87962566103423978) 339 | if m.bias is not None: 340 | nn.init.zeros_(m.bias) 341 | ''' 342 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 343 | if m.bias is not None: 344 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 345 | bound = 1 / math.sqrt(fan_in) 346 | nn.init.uniform_(m.bias, -bound, bound) 347 | 348 | elif isinstance(m, nn.BatchNorm2d): 349 | nn.init.constant_(m.weight, 1) 350 | nn.init.constant_(m.bias, 0) 351 | 352 | 353 | class Up(nn.Module): 354 | """Upscaling then double conv""" 355 | def __init__(self, in_ch1, out_ch, in_ch2=0, attn=False): 356 | super().__init__() 357 | 358 | self.up = nn.Upsample(scale_factor=2, 359 | mode='bilinear', 360 | align_corners=True) 361 | self.conv = DoubleConv(in_ch1 + in_ch2, out_ch) 362 | 363 | if attn: 364 | self.attn_block = Attention_block(in_ch1, in_ch2, out_ch) 365 | else: 366 | self.attn_block = None 367 | 368 | def forward(self, x1, x2=None): 369 | 370 | x1 = self.up(x1) 371 | # input is CHW 372 | if x2 is not None: 373 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]]) 374 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]]) 375 | 376 | x1 = F.pad(x1, [ 377 | diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2 378 | ]) 379 | 380 | if self.attn_block is not None: 381 | x2 = self.attn_block(x1, x2) 382 | x1 = torch.cat([x2, x1], dim=1) 383 | x = x1 384 | return self.conv(x) 385 | 386 | 387 | class Attention_block(nn.Module): 388 | def __init__(self, F_g, F_l, F_int): 389 | super(Attention_block, self).__init__() 390 | self.W_g = nn.Sequential( 391 | nn.Conv2d(F_g, 392 | F_int, 393 | kernel_size=1, 394 | stride=1, 395 | padding=0, 396 | bias=True), nn.BatchNorm2d(F_int)) 397 | self.W_x = nn.Sequential( 398 | nn.Conv2d(F_l, 399 | F_int, 400 | kernel_size=1, 401 | stride=1, 402 | padding=0, 403 | bias=True), nn.BatchNorm2d(F_int)) 404 | self.psi = nn.Sequential( 405 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 406 | nn.BatchNorm2d(1), nn.Sigmoid()) 407 | self.relu = nn.ReLU(inplace=True) 408 | 409 | def forward(self, g, x): 410 | g1 = self.W_g(g) 411 | x1 = self.W_x(x) 412 | psi = self.relu(g1 + x1) 413 | psi = self.psi(psi) 414 | return x * psi 415 | 416 | 417 | class DoubleConv(nn.Module): 418 | def __init__(self, in_channels, out_channels): 419 | super().__init__() 420 | self.double_conv = nn.Sequential( 421 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 422 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), 423 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 424 | nn.BatchNorm2d(out_channels)) 425 | self.identity = nn.Sequential( 426 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0), 427 | nn.BatchNorm2d(out_channels)) 428 | self.relu = nn.ReLU(inplace=True) 429 | 430 | def forward(self, x): 431 | return self.relu(self.double_conv(x) + self.identity(x)) 432 | 433 | 434 | class Residual(nn.Module): 435 | def __init__(self, inp_dim, out_dim): 436 | super(Residual, self).__init__() 437 | self.relu = nn.ReLU(inplace=True) 438 | self.bn1 = nn.BatchNorm2d(inp_dim) 439 | self.conv1 = Conv(inp_dim, int(out_dim / 2), 1, relu=False) 440 | self.bn2 = nn.BatchNorm2d(int(out_dim / 2)) 441 | self.conv2 = Conv(int(out_dim / 2), 
int(out_dim / 2), 3, relu=False) 442 | self.bn3 = nn.BatchNorm2d(int(out_dim / 2)) 443 | self.conv3 = Conv(int(out_dim / 2), out_dim, 1, relu=False) 444 | self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False) 445 | if inp_dim == out_dim: 446 | self.need_skip = False 447 | else: 448 | self.need_skip = True 449 | 450 | def forward(self, x): 451 | if self.need_skip: 452 | residual = self.skip_layer(x) 453 | else: 454 | residual = x 455 | out = self.bn1(x) 456 | out = self.relu(out) 457 | out = self.conv1(out) 458 | out = self.bn2(out) 459 | out = self.relu(out) 460 | out = self.conv2(out) 461 | out = self.bn3(out) 462 | out = self.relu(out) 463 | out = self.conv3(out) 464 | out += residual 465 | return out 466 | 467 | 468 | class Conv(nn.Module): 469 | def __init__(self, 470 | inp_dim, 471 | out_dim, 472 | kernel_size=3, 473 | stride=1, 474 | bn=False, 475 | relu=True, 476 | bias=True): 477 | super(Conv, self).__init__() 478 | self.inp_dim = inp_dim 479 | self.conv = nn.Conv2d(inp_dim, 480 | out_dim, 481 | kernel_size, 482 | stride, 483 | padding=(kernel_size - 1) // 2, 484 | bias=bias) 485 | self.relu = None 486 | self.bn = None 487 | if relu: 488 | self.relu = nn.ReLU(inplace=True) 489 | if bn: 490 | self.bn = nn.BatchNorm2d(out_dim) 491 | 492 | def forward(self, x): 493 | assert x.size()[1] == self.inp_dim, "{} {}".format( 494 | x.size()[1], self.inp_dim) 495 | x = self.conv(x) 496 | if self.bn is not None: 497 | x = self.bn(x) 498 | if self.relu is not None: 499 | x = self.relu(x) 500 | return x -------------------------------------------------------------------------------- /lib/replknet.py: -------------------------------------------------------------------------------- 1 | # Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs (https://arxiv.org/abs/2203.06717) 2 | # Github source: https://github.com/DingXiaoH/RepLKNet-pytorch 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Based on ConvNeXt, timm, DINO and DeiT code bases 5 | # https://github.com/facebookresearch/ConvNeXt 6 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 7 | # https://github.com/facebookresearch/deit/ 8 | # https://github.com/facebookresearch/dino 9 | # --------------------------------------------------------' 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.checkpoint as checkpoint 13 | from timm.models.layers import DropPath 14 | import sys 15 | import os 16 | 17 | def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias): 18 | if type(kernel_size) is int: 19 | use_large_impl = kernel_size > 5 20 | else: 21 | assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1] 22 | use_large_impl = kernel_size[0] > 5 23 | has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ 24 | if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1: 25 | sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL']) 26 | # Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md 27 | # export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py) 28 | # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed. 29 | # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it. 
30 | from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM 31 | return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias) 32 | else: 33 | return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 34 | padding=padding, dilation=dilation, groups=groups, bias=bias) 35 | 36 | use_sync_bn = False 37 | 38 | def enable_sync_bn(): 39 | global use_sync_bn 40 | use_sync_bn = True 41 | 42 | def get_bn(channels): 43 | if use_sync_bn: 44 | return nn.SyncBatchNorm(channels) 45 | else: 46 | return nn.BatchNorm2d(channels) 47 | 48 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1): 49 | if padding is None: 50 | padding = kernel_size // 2 51 | result = nn.Sequential() 52 | result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 53 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)) 54 | result.add_module('bn', get_bn(out_channels)) 55 | return result 56 | 57 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1): 58 | if padding is None: 59 | padding = kernel_size // 2 60 | result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 61 | stride=stride, padding=padding, groups=groups, dilation=dilation) 62 | result.add_module('nonlinear', nn.ReLU()) 63 | return result 64 | 65 | def fuse_bn(conv, bn): 66 | kernel = conv.weight 67 | running_mean = bn.running_mean 68 | running_var = bn.running_var 69 | gamma = bn.weight 70 | beta = bn.bias 71 | eps = bn.eps 72 | std = (running_var + eps).sqrt() 73 | t = (gamma / std).reshape(-1, 1, 1, 1) 74 | return kernel * t, beta - running_mean * gamma / std 75 | 76 | class ReparamLargeKernelConv(nn.Module): 77 | 78 | def __init__(self, in_channels, out_channels, kernel_size, 79 | stride, groups, 80 | small_kernel, 81 | small_kernel_merged=False): 82 | super(ReparamLargeKernelConv, self).__init__() 83 | self.kernel_size = kernel_size 84 | self.small_kernel = small_kernel 85 | # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly. 86 | padding = kernel_size // 2 87 | if small_kernel_merged: 88 | self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 89 | stride=stride, padding=padding, dilation=1, groups=groups, bias=True) 90 | else: 91 | self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 92 | stride=stride, padding=padding, dilation=1, groups=groups) 93 | if small_kernel is not None: 94 | assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!' 
95 | self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel, 96 | stride=stride, padding=small_kernel//2, groups=groups, dilation=1) 97 | 98 | def forward(self, inputs): 99 | if hasattr(self, 'lkb_reparam'): 100 | out = self.lkb_reparam(inputs) 101 | else: 102 | out = self.lkb_origin(inputs) 103 | if hasattr(self, 'small_conv'): 104 | out += self.small_conv(inputs) 105 | return out 106 | 107 | def get_equivalent_kernel_bias(self): 108 | eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) 109 | if hasattr(self, 'small_conv'): 110 | small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn) 111 | eq_b += small_b 112 | # add to the central part 113 | eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4) 114 | return eq_k, eq_b 115 | 116 | def merge_kernel(self): 117 | eq_k, eq_b = self.get_equivalent_kernel_bias() 118 | self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels, 119 | out_channels=self.lkb_origin.conv.out_channels, 120 | kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride, 121 | padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation, 122 | groups=self.lkb_origin.conv.groups, bias=True) 123 | self.lkb_reparam.weight.data = eq_k 124 | self.lkb_reparam.bias.data = eq_b 125 | self.__delattr__('lkb_origin') 126 | if hasattr(self, 'small_conv'): 127 | self.__delattr__('small_conv') 128 | 129 | 130 | class ConvFFN(nn.Module): 131 | 132 | def __init__(self, in_channels, internal_channels, out_channels, drop_path): 133 | super().__init__() 134 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 135 | self.preffn_bn = get_bn(in_channels) 136 | self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0, groups=1) 137 | self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1) 138 | self.nonlinear = nn.GELU() 139 | 140 | def forward(self, x): 141 | out = self.preffn_bn(x) 142 | out = self.pw1(out) 143 | out = self.nonlinear(out) 144 | out = self.pw2(out) 145 | return x + self.drop_path(out) 146 | 147 | 148 | class RepLKBlock(nn.Module): 149 | 150 | def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False): 151 | super().__init__() 152 | self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1) 153 | self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1) 154 | self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels, kernel_size=block_lk_size, 155 | stride=1, groups=dw_channels, small_kernel=small_kernel, small_kernel_merged=small_kernel_merged) 156 | self.lk_nonlinear = nn.ReLU() 157 | self.prelkb_bn = get_bn(in_channels) 158 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 159 | print('drop path:', self.drop_path) 160 | 161 | def forward(self, x): 162 | out = self.prelkb_bn(x) 163 | out = self.pw1(out) 164 | out = self.large_kernel(out) 165 | out = self.lk_nonlinear(out) 166 | out = self.pw2(out) 167 | return x + self.drop_path(out) 168 | 169 | 170 | class RepLKNetStage(nn.Module): 171 | 172 | def __init__(self, channels, num_blocks, stage_lk_size, drop_path, 173 | small_kernel, dw_ratio=1, ffn_ratio=4, 174 | use_checkpoint=False, # train with torch.utils.checkpoint to save memory 175 | small_kernel_merged=False, 176 | norm_intermediate_features=False): 177 | super().__init__() 178 | self.use_checkpoint = use_checkpoint 179 | blks = [] 180 | for i in range(num_blocks): 181 | block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path 182 | # Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model. 183 | replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio), block_lk_size=stage_lk_size, 184 | small_kernel=small_kernel, drop_path=block_drop_path, small_kernel_merged=small_kernel_merged) 185 | convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio), out_channels=channels, 186 | drop_path=block_drop_path) 187 | blks.append(replk_block) 188 | blks.append(convffn_block) 189 | self.blocks = nn.ModuleList(blks) 190 | if norm_intermediate_features: 191 | self.norm = get_bn(channels) # Only use this with RepLKNet-XL on downstream tasks 192 | else: 193 | self.norm = nn.Identity() 194 | 195 | def forward(self, x): 196 | for blk in self.blocks: 197 | if self.use_checkpoint: 198 | x = checkpoint.checkpoint(blk, x) # Save training memory 199 | else: 200 | x = blk(x) 201 | return x 202 | 203 | class RepLKNet(nn.Module): 204 | 205 | def __init__(self, large_kernel_sizes, layers, channels, drop_path_rate, small_kernel, 206 | dw_ratio=1, ffn_ratio=4, in_channels=3, num_classes=1000, out_indices=None, 207 | use_checkpoint=False, 208 | small_kernel_merged=False, 209 | use_sync_bn=True, 210 | norm_intermediate_features=False # for RepLKNet-XL on COCO and ADE20K, use an extra BN to normalize the intermediate feature maps then feed them into the heads 211 | ): 212 | super().__init__() 213 | 214 | if num_classes is None and out_indices is None: 215 | raise ValueError('must specify one of num_classes (for pretraining) and out_indices (for downstream tasks)') 216 | elif num_classes is not None and out_indices is not None: 217 | raise ValueError('cannot specify both num_classes (for pretraining) and out_indices (for downstream tasks)') 218 | elif num_classes is not None and norm_intermediate_features: 219 | raise ValueError('for pretraining, no need to normalize the intermediate feature maps') 220 | self.out_indices = out_indices 221 | if use_sync_bn: 222 | enable_sync_bn() 223 | 224 | base_width = channels[0] 225 | self.use_checkpoint = use_checkpoint 226 | self.norm_intermediate_features = norm_intermediate_features 227 | self.num_stages = len(layers) 228 | self.stem = nn.ModuleList([ 229 | conv_bn_relu(in_channels=in_channels, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=1), 230 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=1, padding=1, groups=base_width), 231 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=1, stride=1, padding=0, groups=1), 232 | conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=2, padding=1, 
groups=base_width)]) 233 | # stochastic depth. We set block-wise drop-path rate. The higher level blocks are more likely to be dropped. This implementation follows Swin. 234 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] 235 | self.stages = nn.ModuleList() 236 | self.transitions = nn.ModuleList() 237 | for stage_idx in range(self.num_stages): 238 | layer = RepLKNetStage(channels=channels[stage_idx], num_blocks=layers[stage_idx], 239 | stage_lk_size=large_kernel_sizes[stage_idx], 240 | drop_path=dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])], 241 | small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio, 242 | use_checkpoint=use_checkpoint, small_kernel_merged=small_kernel_merged, 243 | norm_intermediate_features=norm_intermediate_features) 244 | self.stages.append(layer) 245 | if stage_idx < len(layers) - 1: 246 | transition = nn.Sequential( 247 | conv_bn_relu(channels[stage_idx], channels[stage_idx + 1], 1, 1, 0, groups=1), 248 | conv_bn_relu(channels[stage_idx + 1], channels[stage_idx + 1], 3, stride=2, padding=1, groups=channels[stage_idx + 1])) 249 | self.transitions.append(transition) 250 | 251 | if num_classes is not None: 252 | self.norm = get_bn(channels[-1]) 253 | self.avgpool = nn.AdaptiveAvgPool2d(1) 254 | self.head = nn.Linear(channels[-1], num_classes) 255 | 256 | 257 | 258 | def forward_features(self, x): 259 | x = self.stem[0](x) 260 | for stem_layer in self.stem[1:]: 261 | if self.use_checkpoint: 262 | x = checkpoint.checkpoint(stem_layer, x) # save memory 263 | else: 264 | x = stem_layer(x) 265 | 266 | if self.out_indices is None: 267 | # Just need the final output 268 | for stage_idx in range(self.num_stages): 269 | x = self.stages[stage_idx](x) 270 | if stage_idx < self.num_stages - 1: 271 | x = self.transitions[stage_idx](x) 272 | return x 273 | else: 274 | # Need the intermediate feature maps 275 | outs = [] 276 | for stage_idx in range(self.num_stages): 277 | x = self.stages[stage_idx](x) 278 | if stage_idx in self.out_indices: 279 | outs.append(self.stages[stage_idx].norm(x)) # For RepLKNet-XL normalize the features before feeding them into the heads 280 | if stage_idx < self.num_stages - 1: 281 | x = self.transitions[stage_idx](x) 282 | return outs 283 | 284 | def forward(self, x): 285 | x = self.forward_features(x) 286 | if self.out_indices: 287 | return x 288 | else: 289 | x = self.norm(x) 290 | x = self.avgpool(x) 291 | x = torch.flatten(x, 1) 292 | x = self.head(x) 293 | return x 294 | 295 | def structural_reparam(self): 296 | for m in self.modules(): 297 | if hasattr(m, 'merge_kernel'): 298 | m.merge_kernel() 299 | 300 | # If your framework cannot automatically fuse BN for inference, you may do it manually. 301 | # The BNs after and before conv layers can be removed. 302 | # No need to call this if your framework support automatic BN fusion. 
303 | def deep_fuse_BN(self): 304 | for m in self.modules(): 305 | if not isinstance(m, nn.Sequential): 306 | continue 307 | if not len(m) in [2, 3]: # Only handle conv-BN or conv-BN-relu 308 | continue 309 | # If you use a custom Conv2d impl, assume it also has 'kernel_size' and 'weight' 310 | if hasattr(m[0], 'kernel_size') and hasattr(m[0], 'weight') and isinstance(m[1], nn.BatchNorm2d): 311 | conv = m[0] 312 | bn = m[1] 313 | fused_kernel, fused_bias = fuse_bn(conv, bn) 314 | fused_conv = get_conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size, 315 | stride=conv.stride, 316 | padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=True) 317 | fused_conv.weight.data = fused_kernel 318 | fused_conv.bias.data = fused_bias 319 | m[0] = fused_conv 320 | m[1] = nn.Identity() 321 | 322 | 323 | def create_RepLKNet31B(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False): 324 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[128,256,512,1024], 325 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint, 326 | small_kernel_merged=small_kernel_merged) 327 | 328 | def create_RepLKNet31L(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False): 329 | return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[192,384,768,1536], 330 | drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint, 331 | small_kernel_merged=small_kernel_merged) 332 | 333 | def create_RepLKNetXL(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False): 334 | return RepLKNet(large_kernel_sizes=[27,27,27,13], layers=[2,2,18,2], channels=[256,512,1024,2048], 335 | drop_path_rate=drop_path_rate, small_kernel=None, dw_ratio=1.5, 336 | num_classes=num_classes, use_checkpoint=use_checkpoint, 337 | small_kernel_merged=small_kernel_merged) 338 | 339 | if __name__ == '__main__': 340 | os.environ['CUDA_VISIBLE_DEVICES']='3' 341 | model = create_RepLKNet31B(small_kernel_merged=False) 342 | model.eval() 343 | # print('------------------- training-time model -------------') 344 | # print(model) 345 | # x = torch.randn(2, 3, 224, 224).cuda() 346 | # origin_y = model(x) 347 | # model.structural_reparam() 348 | # print('------------------- after re-param -------------') 349 | # print(model) 350 | # reparam_y = model(x) 351 | # print('------------------- the difference is ------------------------') 352 | # print((origin_y - reparam_y).abs().sum()) 353 | 354 | x = torch.randn(1, 320, 16, 16).cuda() 355 | # dpr = [x.item() for x in torch.linspace(0, 0.3, 1)] 356 | # dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])] 357 | # model=RepLKBlock(320,512,31,5,0.3).cuda() 358 | origin_y = model(x) 359 | print(origin_y.shape) -------------------------------------------------------------------------------- /lib/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | 8 | from .modules import BoundaryWiseAttentionGate2D, BoundaryWiseAttentionGateAtrous2D 9 | 10 | 11 | class Transformer(nn.Module): 12 | def __init__(self, 13 | d_model=512, 14 | nhead=8, 15 | num_encoder_layers=6, 16 | num_decoder_layers=2, 17 | dim_feedforward=2048, 18 | dropout=0.1, 19 | activation=nn.LeakyReLU, 20 | 
normalize_before=False, 21 | return_intermediate_dec=False): 22 | super().__init__() 23 | 24 | encoder_layer = TransformerEncoderLayer(d_model, nhead, 25 | dim_feedforward, dropout, 26 | activation, normalize_before) 27 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 28 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, 29 | encoder_norm) 30 | decoder_layer = TransformerDecoderLayer(d_model, nhead, 31 | dim_feedforward, dropout, 32 | activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(d_model) 34 | self.decoder = TransformerDecoder( 35 | decoder_layer, 36 | num_decoder_layers, 37 | decoder_norm, 38 | return_intermediate=return_intermediate_dec) 39 | self._reset_parameters() 40 | 41 | self.d_model = d_model 42 | self.nhead = nhead 43 | 44 | def _reset_parameters(self): 45 | for p in self.parameters(): 46 | if p.dim() > 1: 47 | nn.init.xavier_uniform_(p) 48 | 49 | def forward(self, src, mask, query_embed, pos_embed): 50 | bs, c, h, w = src.shape 51 | src = src.flatten(2).permute(2, 0, 1) 52 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 53 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 54 | if mask is not None: 55 | mask = mask.flatten(1) 56 | 57 | tgt = torch.zeros_like(query_embed) 58 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 59 | # print("Trans Encoder",memory.shape) 60 | hs = self.decoder(tgt, 61 | memory, 62 | memory_key_padding_mask=mask, 63 | pos=pos_embed, 64 | query_pos=query_embed) 65 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 66 | 67 | 68 | class BoundaryAwareTransformer(nn.Module): 69 | def __init__(self, 70 | point_pred_layers=2, 71 | d_model=512, 72 | nhead=8, 73 | num_encoder_layers=2, 74 | num_decoder_layers=2, 75 | dim_feedforward=2048, 76 | dropout=0.1, 77 | activation=nn.LeakyReLU, 78 | normalize_before=False, 79 | return_intermediate_dec=False, 80 | BAG_type='2D', 81 | Atrous=False): 82 | super().__init__() 83 | self.num_decoder_layers = num_decoder_layers 84 | 85 | encoder_layer = BoundaryAwareTransformerEncoderLayer( 86 | d_model, nhead, BAG_type, Atrous, dim_feedforward, dropout, 87 | activation, normalize_before) 88 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 89 | self.encoder = BoundaryAwareTransformerEncoder(point_pred_layers, 90 | encoder_layer, 91 | num_encoder_layers, 92 | encoder_norm) 93 | if num_decoder_layers > 0: 94 | decoder_layer = TransformerDecoderLayer(d_model, nhead, 95 | dim_feedforward, dropout, 96 | activation, 97 | normalize_before) 98 | decoder_norm = nn.LayerNorm(d_model) 99 | self.decoder = TransformerDecoder( 100 | decoder_layer, 101 | num_decoder_layers, 102 | decoder_norm, 103 | return_intermediate=return_intermediate_dec) 104 | self._reset_parameters() 105 | 106 | self.d_model = d_model 107 | self.nhead = nhead 108 | 109 | def _reset_parameters(self): 110 | for p in self.parameters(): 111 | if p.dim() > 1: 112 | nn.init.xavier_uniform_(p) 113 | 114 | def forward(self, src, mask, query_embed, pos_embed): 115 | bs, c, h, w = src.shape 116 | src = src.flatten(2).permute(2, 0, 1) 117 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 118 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 119 | if mask is not None: 120 | mask = mask.flatten(1) 121 | 122 | tgt = torch.zeros_like(query_embed) 123 | memory, weights = self.encoder(src, 124 | src_key_padding_mask=mask, 125 | pos=pos_embed, 126 | height=h, 127 | width=w) 128 | if self.num_decoder_layers > 0: 129 | hs = self.decoder(tgt, 130 | 
memory, 131 | memory_key_padding_mask=mask, 132 | pos=pos_embed, 133 | query_pos=query_embed) 134 | return hs.transpose(1, 2), memory.permute(1, 2, 135 | 0).view(bs, c, h, 136 | w), weights 137 | else: 138 | return tgt.transpose(1, 2), memory.permute(1, 2, 139 | 0).view(bs, c, h, 140 | w), weights 141 | 142 | 143 | class TransformerEncoder(nn.Module): 144 | def __init__(self, encoder_layer, num_layers, norm=None): 145 | super().__init__() 146 | self.layers = _get_clones(encoder_layer, num_layers) 147 | self.num_layers = num_layers 148 | self.norm = norm 149 | 150 | def forward(self, 151 | src, 152 | mask: Optional[Tensor] = None, 153 | src_key_padding_mask: Optional[Tensor] = None, 154 | pos: Optional[Tensor] = None): 155 | output = src 156 | 157 | for layer in self.layers: 158 | output = layer(output, 159 | src_mask=mask, 160 | src_key_padding_mask=src_key_padding_mask, 161 | pos=pos) 162 | 163 | if self.norm is not None: 164 | output = self.norm(output) 165 | 166 | return output 167 | 168 | 169 | class BoundaryAwareTransformerEncoder(nn.Module): 170 | def __init__(self, 171 | point_pred_layers, 172 | encoder_layer, 173 | num_layers, 174 | norm=None): 175 | super().__init__() 176 | self.point_pred_layers = point_pred_layers 177 | self.layers = _get_clones(encoder_layer, num_layers) 178 | self.num_layers = num_layers 179 | self.norm = norm 180 | 181 | def forward(self, 182 | src, 183 | mask: Optional[Tensor] = None, 184 | src_key_padding_mask: Optional[Tensor] = None, 185 | pos: Optional[Tensor] = None, 186 | height: int = 32, 187 | width: int = 32): 188 | output = src 189 | weights = [] 190 | 191 | for layer_i, layer in enumerate(self.layers): 192 | output, weight = layer(True, 193 | output, 194 | src_mask=mask, 195 | src_key_padding_mask=src_key_padding_mask, 196 | pos=pos, 197 | height=height, 198 | width=width) 199 | weights.append(weight) 200 | 201 | if self.norm is not None: 202 | output = self.norm(output) 203 | 204 | return output, weights 205 | 206 | 207 | class TransformerEncoderLayer(nn.Module): 208 | def __init__(self, 209 | d_model, 210 | nhead, 211 | dim_feedforward=2048, 212 | dropout=0.1, 213 | activation=nn.LeakyReLU, 214 | normalize_before=False): 215 | super().__init__() 216 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 217 | # Implementation of Feedforward model 218 | self.linear1 = nn.Linear(d_model, dim_feedforward) 219 | self.dropout = nn.Dropout(dropout) 220 | self.linear2 = nn.Linear(dim_feedforward, d_model) 221 | 222 | self.norm1 = nn.LayerNorm(d_model) 223 | self.norm2 = nn.LayerNorm(d_model) 224 | self.dropout1 = nn.Dropout(dropout) 225 | self.dropout2 = nn.Dropout(dropout) 226 | 227 | self.activation = activation() 228 | self.normalize_before = normalize_before 229 | 230 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 231 | return tensor if pos is None else tensor + pos 232 | 233 | def forward_post(self, 234 | src, 235 | src_mask: Optional[Tensor] = None, 236 | src_key_padding_mask: Optional[Tensor] = None, 237 | pos: Optional[Tensor] = None): 238 | q = k = self.with_pos_embed(src, pos) 239 | src2 = self.self_attn(q, 240 | k, 241 | value=src, 242 | attn_mask=src_mask, 243 | key_padding_mask=src_key_padding_mask)[0] 244 | src = src + self.dropout1(src2) 245 | src = self.norm1(src) 246 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 247 | src = src + self.dropout2(src2) 248 | src = self.norm2(src) 249 | return src 250 | 251 | def forward_pre(self, 252 | src, 253 | src_mask: Optional[Tensor] = None, 
254 | src_key_padding_mask: Optional[Tensor] = None, 255 | pos: Optional[Tensor] = None): 256 | src2 = self.norm1(src) 257 | q = k = self.with_pos_embed(src2, pos) 258 | src2 = self.self_attn(q, 259 | k, 260 | value=src2, 261 | attn_mask=src_mask, 262 | key_padding_mask=src_key_padding_mask)[0] 263 | src = src + self.dropout1(src2) 264 | src2 = self.norm2(src) 265 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 266 | src = src + self.dropout2(src2) 267 | return src 268 | 269 | def forward(self, 270 | src, 271 | src_mask: Optional[Tensor] = None, 272 | src_key_padding_mask: Optional[Tensor] = None, 273 | pos: Optional[Tensor] = None): 274 | if self.normalize_before: 275 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 276 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 277 | 278 | 279 | class BoundaryAwareTransformerEncoderLayer(TransformerEncoderLayer): 280 | " Add Boundary-wise Attention Gate to Transformer's Encoder" 281 | 282 | def __init__(self, 283 | d_model, 284 | nhead, 285 | BAG_type='2D', 286 | Atrous=True, 287 | dim_feedforward=2048, 288 | dropout=0.1, 289 | activation=nn.LeakyReLU, 290 | normalize_before=False): 291 | super().__init__(d_model, nhead, dim_feedforward, dropout, activation, 292 | normalize_before) 293 | if BAG_type == '2D': 294 | if Atrous: 295 | self.BAG = BoundaryWiseAttentionGateAtrous2D(d_model) 296 | else: 297 | self.BAG = BoundaryWiseAttentionGate2D(d_model) 298 | self.BAG_type = BAG_type 299 | 300 | def forward(self, 301 | use_bag, 302 | src, 303 | src_mask: Optional[Tensor] = None, 304 | src_key_padding_mask: Optional[Tensor] = None, 305 | pos: Optional[Tensor] = None, 306 | height: int = 32, 307 | width: int = 32): 308 | if self.normalize_before: 309 | features = self.forward_pre(src, src_mask, src_key_padding_mask, 310 | pos) 311 | if use_bag: 312 | b, c = features.shape[1:] 313 | if self.BAG_type == '1D': 314 | features = features.permute(1, 2, 0) 315 | features, weights = self.BAG(features) 316 | features = features.permute(2, 0, 1).contiguous() 317 | weights = weights.view(b, 1, height, width) 318 | elif self.BAG_type == '2D': 319 | features = features.permute(1, 2, 320 | 0).view(b, c, height, width) 321 | features, weights = self.BAG(features) 322 | features = features.flatten(2).permute(2, 0, 323 | 1).contiguous() 324 | return features, weights 325 | else: 326 | return features 327 | features = self.forward_post(src, src_mask, src_key_padding_mask, pos) 328 | if use_bag: 329 | b, c = features.shape[1:] 330 | if self.BAG_type == '1D': 331 | features = features.permute(1, 2, 0) 332 | features, weights = self.BAG(features) 333 | features = features.permute(2, 0, 1).contiguous() 334 | weights = weights.view(b, 1, height, width) 335 | elif self.BAG_type == '2D': 336 | features = features.permute(1, 2, 0).view(b, c, height, width) 337 | features, weights = self.BAG(features) 338 | features = features.flatten(2).permute(2, 0, 1).contiguous() 339 | return features, weights 340 | else: 341 | return features 342 | 343 | 344 | class TransformerDecoder(nn.Module): 345 | def __init__(self, 346 | decoder_layer, 347 | num_layers, 348 | norm=None, 349 | return_intermediate=False): 350 | super().__init__() 351 | self.layers = _get_clones(decoder_layer, num_layers) 352 | self.num_layers = num_layers 353 | self.norm = norm 354 | self.return_intermediate = return_intermediate 355 | 356 | def forward(self, 357 | tgt, 358 | memory, 359 | tgt_mask: Optional[Tensor] = None, 360 | memory_mask: Optional[Tensor] 
= None, 361 | tgt_key_padding_mask: Optional[Tensor] = None, 362 | memory_key_padding_mask: Optional[Tensor] = None, 363 | pos: Optional[Tensor] = None, 364 | query_pos: Optional[Tensor] = None): 365 | output = tgt 366 | 367 | intermediate = [] 368 | 369 | for layer in self.layers: 370 | output = layer(output, 371 | memory, 372 | tgt_mask=tgt_mask, 373 | memory_mask=memory_mask, 374 | tgt_key_padding_mask=tgt_key_padding_mask, 375 | memory_key_padding_mask=memory_key_padding_mask, 376 | pos=pos, 377 | query_pos=query_pos) 378 | if self.return_intermediate: 379 | intermediate.append(self.norm(output)) 380 | 381 | if self.norm is not None: 382 | output = self.norm(output) 383 | if self.return_intermediate: 384 | intermediate.pop() 385 | intermediate.append(output) 386 | 387 | if self.return_intermediate: 388 | return torch.stack(intermediate) 389 | 390 | return output 391 | 392 | 393 | class TransformerDecoderLayer(nn.Module): 394 | def __init__(self, 395 | d_model, 396 | nhead, 397 | dim_feedforward=2048, 398 | dropout=0.1, 399 | activation=nn.LeakyReLU, 400 | normalize_before=False): 401 | super().__init__() 402 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 403 | self.multihead_attn = nn.MultiheadAttention(d_model, 404 | nhead, 405 | dropout=dropout) 406 | # Implementation of Feedforward model 407 | self.linear1 = nn.Linear(d_model, dim_feedforward) 408 | self.dropout = nn.Dropout(dropout) 409 | self.linear2 = nn.Linear(dim_feedforward, d_model) 410 | 411 | self.norm1 = nn.LayerNorm(d_model) 412 | self.norm2 = nn.LayerNorm(d_model) 413 | self.norm3 = nn.LayerNorm(d_model) 414 | self.dropout1 = nn.Dropout(dropout) 415 | self.dropout2 = nn.Dropout(dropout) 416 | self.dropout3 = nn.Dropout(dropout) 417 | 418 | self.activation = activation() 419 | self.normalize_before = normalize_before 420 | 421 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 422 | return tensor if pos is None else tensor + pos 423 | 424 | def forward_post(self, 425 | tgt, 426 | memory, 427 | tgt_mask: Optional[Tensor] = None, 428 | memory_mask: Optional[Tensor] = None, 429 | tgt_key_padding_mask: Optional[Tensor] = None, 430 | memory_key_padding_mask: Optional[Tensor] = None, 431 | pos: Optional[Tensor] = None, 432 | query_pos: Optional[Tensor] = None): 433 | q = k = self.with_pos_embed(tgt, query_pos) 434 | tgt2 = self.self_attn(q, 435 | k, 436 | value=tgt, 437 | attn_mask=tgt_mask, 438 | key_padding_mask=tgt_key_padding_mask)[0] 439 | tgt = tgt + self.dropout1(tgt2) 440 | tgt = self.norm1(tgt) 441 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 442 | key=self.with_pos_embed(memory, pos), 443 | value=memory, 444 | attn_mask=memory_mask, 445 | key_padding_mask=memory_key_padding_mask)[0] 446 | tgt = tgt + self.dropout2(tgt2) 447 | tgt = self.norm2(tgt) 448 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 449 | tgt = tgt + self.dropout3(tgt2) 450 | tgt = self.norm3(tgt) 451 | return tgt 452 | 453 | def forward_pre(self, 454 | tgt, 455 | memory, 456 | tgt_mask: Optional[Tensor] = None, 457 | memory_mask: Optional[Tensor] = None, 458 | tgt_key_padding_mask: Optional[Tensor] = None, 459 | memory_key_padding_mask: Optional[Tensor] = None, 460 | pos: Optional[Tensor] = None, 461 | query_pos: Optional[Tensor] = None): 462 | tgt2 = self.norm1(tgt) 463 | q = k = self.with_pos_embed(tgt2, query_pos) 464 | tgt2 = self.self_attn(q, 465 | k, 466 | value=tgt2, 467 | attn_mask=tgt_mask, 468 | key_padding_mask=tgt_key_padding_mask)[0] 469 | tgt = tgt + 
self.dropout1(tgt2) 470 | tgt2 = self.norm2(tgt) 471 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 472 | key=self.with_pos_embed(memory, pos), 473 | value=memory, 474 | attn_mask=memory_mask, 475 | key_padding_mask=memory_key_padding_mask)[0] 476 | tgt = tgt + self.dropout2(tgt2) 477 | tgt2 = self.norm3(tgt) 478 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 479 | tgt = tgt + self.dropout3(tgt2) 480 | return tgt 481 | 482 | def forward(self, 483 | tgt, 484 | memory, 485 | tgt_mask: Optional[Tensor] = None, 486 | memory_mask: Optional[Tensor] = None, 487 | tgt_key_padding_mask: Optional[Tensor] = None, 488 | memory_key_padding_mask: Optional[Tensor] = None, 489 | pos: Optional[Tensor] = None, 490 | query_pos: Optional[Tensor] = None): 491 | if self.normalize_before: 492 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 493 | tgt_key_padding_mask, 494 | memory_key_padding_mask, pos, query_pos) 495 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 496 | tgt_key_padding_mask, memory_key_padding_mask, 497 | pos, query_pos) 498 | 499 | 500 | def _get_clones(module, N): 501 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 502 | 503 | 504 | def build_transformer(args): 505 | return Transformer( 506 | d_model=args.hidden_dim, 507 | dropout=args.dropout, 508 | nhead=args.nheads, 509 | dim_feedforward=args.dim_feedforward, 510 | num_encoder_layers=args.enc_layers, 511 | num_decoder_layers=args.dec_layers, 512 | normalize_before=args.pre_norm, 513 | return_intermediate_dec=True, 514 | ) 515 | 516 | 517 | def _get_activation_fn(activation): 518 | """Return an activation function given a string""" 519 | if activation == "leaky relu": 520 | return F.leaky_relu 521 | if activation == "selu": 522 | return F.selu 523 | if activation == "relu": 524 | return F.relu 525 | if activation == "gelu": 526 | return F.gelu 527 | if activation == "glu": 528 | return F.glu 529 | raise RuntimeError( 530 | F"activation should be relu, gelu, glu, leaky relu or selu, not {activation}." 531 | ) --------------------------------------------------------------------------------
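Usage note (illustrative sketch, not part of the repository): the snippet below exercises the `Transformer` module from `lib/transformer.py` on a dummy feature map. The sizes (batch of 2, 256 channels, a 16x16 grid, 100 queries) and the import path are assumptions chosen for demonstration; they presuppose the repository root is on `PYTHONPATH` so that `lib` is importable as a package.

import torch
from lib.transformer import Transformer  # assumed import path; adjust to your setup

# Build a small Transformer; all other arguments keep the defaults defined in the class above.
model = Transformer(d_model=256, nhead=8, num_encoder_layers=2, num_decoder_layers=2)
model.eval()

src = torch.randn(2, 256, 16, 16)   # encoder input feature map: (batch, d_model, height, width)
pos = torch.randn(2, 256, 16, 16)   # positional encoding, same shape as src
queries = torch.randn(100, 256)     # query embeddings: (num_queries, d_model)

with torch.no_grad():
    hs, memory = model(src, mask=None, query_embed=queries, pos_embed=pos)

print(hs.shape)       # per-query decoder output
print(memory.shape)   # encoder memory reshaped back to (2, 256, 16, 16)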