├── requirements.txt ├── README.md ├── configs └── DenUnet_configs.py ├── models ├── Decoder.py ├── DenUnet.py ├── convert_pidinet.py ├── ops.py ├── config.py ├── pidinet.py └── Encoder.py ├── datasets └── dataset_tooth.py ├── train.py ├── test.py ├── trainer.py └── utils.py /requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | einops 3 | ml_collections 4 | wget 5 | tensorboardX 6 | SimpleITK 7 | medpy 8 | torch>=2.0.0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | DenUnet: Enhancing Dental Image Segmentation through Edge and Body Fusion. 2 | 3 | ## CheckList 4 | 5 | - [x] Network 6 | - [ ] Paper link 7 | - [ ] dataset 8 | - [ ] Pretrained Weights 9 | -------------------------------------------------------------------------------- /configs/DenUnet_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import os 3 | import wget 4 | 5 | os.makedirs('./weights', exist_ok=True) 6 | 7 | 8 | # DenUnet Configs 9 | def get_DenUnet_configs(): 10 | cfg = ml_collections.ConfigDict() 11 | 12 | # Swin Transformer Configs 13 | cfg.swin_pyramid_fm = [96, 192, 384, 768] 14 | cfg.image_size = 224 15 | cfg.patch_size = 4 16 | cfg.num_classes = 9 17 | if not os.path.isfile('./weights/swin_tiny_patch4_window7_224.pth'): 18 | print('Downloading Swin-transformer model ...') 19 | wget.download( 20 | "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth", 21 | "./weights/swin_tiny_patch4_window7_224.pth") 22 | cfg.swin_pretrained_path = './weights/swin_tiny_patch4_window7_224.pth' 23 | 24 | # CNN Configs 25 | cfg.cnn_backbone = "pidinet_small_converted" 26 | cfg.pdcs = 'carv4' 27 | cfg.cnn_pyramid_fm = [30, 60, 120, 240] 28 | # cfg.cnn_pyramid_fm = [256,512,1024] 29 | cfg.pidinet_pretrained = False 30 | 31 | # DLF Configs 32 | cfg.depth = [[1, 2, 0]] 33 | cfg.num_heads = (6, 12) 34 | cfg.mlp_ratio = (2., 2., 1.) 35 | cfg.drop_rate = 0. 36 | cfg.attn_drop_rate = 0. 37 | cfg.drop_path_rate = 0. 38 | cfg.qkv_bias = True 39 | cfg.qk_scale = None 40 | cfg.cross_pos_embed = True 41 | 42 | return cfg 43 | 44 | -------------------------------------------------------------------------------- /models/Decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati Manzari 3 | Date: Jul 2023 4 | """ 5 | import torch.nn as nn 6 | 7 | class ConvUpsample(nn.Module): 8 | def __init__(self, in_chans=384, out_chans=[128], upsample=True): 9 | super().__init__() 10 | self.in_chans = in_chans 11 | self.out_chans = out_chans 12 | 13 | self.conv_tower = nn.ModuleList() 14 | for i, out_ch in enumerate(self.out_chans): 15 | if i > 0: self.in_chans = out_ch 16 | self.conv_tower.append(nn.Conv2d( 17 | self.in_chans, out_ch, 18 | kernel_size=3, stride=1, 19 | padding=1, bias=False 20 | )) 21 | self.conv_tower.append(nn.GroupNorm(32, out_ch)) 22 | self.conv_tower.append(nn.ReLU(inplace=False)) 23 | if upsample: 24 | self.conv_tower.append(nn.Upsample( 25 | scale_factor=2, mode='bilinear', align_corners=False)) 26 | 27 | self.convs_level = nn.Sequential(*self.conv_tower) 28 | 29 | def forward(self, x): 30 | return self.convs_level(x) 31 | 32 | 33 | class SegmentationHead(nn.Sequential): 34 | def __init__(self, in_channels, out_channels, kernel_size=3): 35 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 36 | super().__init__(conv2d) 37 | -------------------------------------------------------------------------------- /models/DenUnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati Manzari 3 | Date: Jul 2023 4 | """ 5 | 6 | import torch.nn as nn 7 | from einops.layers.torch import Rearrange 8 | 9 | from models.Encoder import All2Cross 10 | from models.Decoder import ConvUpsample, SegmentationHead 11 | 12 | 13 | class DenUnet(nn.Module): 14 | def __init__(self, config, img_size=224, in_chans=3, n_classes=9): 15 | super().__init__() 16 | self.img_size = img_size 17 | self.patch_size = [4, 32] 18 | self.n_classes = n_classes 19 | self.All2Cross = All2Cross(config=config, img_size=img_size, in_chans=in_chans) 20 | 21 | self.ConvUp_s = ConvUpsample(in_chans=768, out_chans=[128, 128, 128], upsample=True) # 1 22 | self.ConvUp_l = ConvUpsample(in_chans=96, upsample=False) # 0 23 | 24 | self.segmentation_head = SegmentationHead( 25 | in_channels=16, 26 | out_channels=n_classes, 27 | kernel_size=3, 28 | ) 29 | 30 | self.conv_pred = nn.Sequential( 31 | nn.Conv2d( 32 | 128, 16, 33 | kernel_size=1, stride=1, 34 | padding=0, bias=True), 35 | # nn.GroupNorm(8, 16), 36 | nn.ReLU(inplace=True), 37 | nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) 38 | ) 39 | 40 | def forward(self, x): 41 | xs = self.All2Cross(x) 42 | embeddings = [x[:, 1:] for x in xs] 43 | reshaped_embed = [] 44 | for i, embed in enumerate(embeddings): 45 | embed = Rearrange('b (h w) d -> b d h w', h=(self.img_size // self.patch_size[i]), 46 | w=(self.img_size // self.patch_size[i]))(embed) 47 | embed = self.ConvUp_l(embed) if i == 0 else self.ConvUp_s(embed) 48 | 49 | reshaped_embed.append(embed) 50 | 51 | C = reshaped_embed[0] + reshaped_embed[1] 52 | C = self.conv_pred(C) 53 | 54 | out = self.segmentation_head(C) 55 | 56 | return out -------------------------------------------------------------------------------- /datasets/dataset_tooth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from scipy import ndimage 7 | from scipy.ndimage.interpolation import zoom 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | 11 | def random_rot_flip(image, label): 12 | k = np.random.randint(0, 4) 13 | image = np.rot90(image, k) 14 | label = np.rot90(label, k) 15 | axis = np.random.randint(0, 2) 16 | image = np.flip(image, axis=axis).copy() 17 | label = np.flip(label, axis=axis).copy() 18 | return image, label 19 | 20 | 21 | def random_rotate(image, label): 22 | angle = np.random.randint(-20, 20) 23 | image = ndimage.rotate(image, angle, order=0, reshape=False) 24 | label = ndimage.rotate(label, angle, order=0, reshape=False) 25 | return image, label 26 | 27 | 28 | class RandomGenerator(object): 29 | def __init__(self, output_size): 30 | self.output_size = output_size 31 | 32 | def __call__(self, sample): 33 | image, label = sample['image'], sample['label'] 34 | 35 | if random.random() > 0.5: 36 | image, label = random_rot_flip(image, label) 37 | elif random.random() > 0.5: 38 | image, label = random_rotate(image, label) 39 | x, y = image.shape 40 | if x != self.output_size[0] or y != self.output_size[1]: 41 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 42 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 43 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 44 | label = torch.from_numpy(label.astype(np.float32)) 45 | sample = {'image': image, 'label': label.long()} 46 | return sample 47 | 48 | 49 | class TeethDataset(Dataset): 50 | def __init__(self, root, split, transform=None): 51 | self.transform = transform # using transform in torch! 52 | self.split = split 53 | self.root = root 54 | self.imgs = list(sorted(os.listdir(os.path.join(root, "img")))) 55 | self.masks = list(sorted(os.listdir(os.path.join(root, "masks")))) 56 | 57 | def __len__(self): 58 | return len(self.imgs) 59 | 60 | def __getitem__(self, idx): 61 | if self.split == "train": 62 | img_path = os.path.join(self.root, "img", self.imgs[idx]) 63 | mask_path = os.path.join(self.root, "masks", self.masks[idx]) 64 | 65 | #image, label = data['image'], data['label'] 66 | else: 67 | img_path = os.path.join(self.root, "val\img", self.imgs[idx]) 68 | mask_path = os.path.join(self.root, "val\masks", self.masks[idx]) 69 | 70 | image = np.array( Image.open(img_path)) 71 | label = np.array( Image.open(mask_path)) 72 | sample = {'image': image, 'label': label} 73 | if self.transform: 74 | sample = self.transform(sample) 75 | 76 | return sample -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | 8 | from models.DenUnet import DenUnet 9 | import configs.DenUnet_configs as configs 10 | from trainer import trainer 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--root_path', type=str, 15 | default='./data/Toothdataset', help='root dir for data') 16 | parser.add_argument('--test_path', type=str, 17 | default='./data/Toothdataset/test', help='root dir for data') 18 | parser.add_argument('--dataset', type=str, 19 | default='Toothdataset', help='experiment_name') 20 | parser.add_argument('--num_classes', type=int, 21 | default=33, help='output channel of network') 22 | parser.add_argument('--max_iterations', type=int, 23 | default=30000, help='maximum epoch number to train') 24 | parser.add_argument('--max_epochs', type=int, 25 | default=501, help='maximum epoch number to train') 26 | parser.add_argument('--batch_size', type=int, 27 | default=10, help='batch_size per gpu') 28 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 29 | parser.add_argument('--deterministic', type=int, default=1, 30 | help='whether use deterministic training') 31 | parser.add_argument('--base_lr', type=float, default=0.01, 32 | help='segmentation network learning rate') 33 | parser.add_argument('--num_workers', type=int, default=2, 34 | help='number of workers') 35 | parser.add_argument('--img_size', type=int, 36 | default=224, help='input patch size of network input') 37 | parser.add_argument('--seed', type=int, 38 | default=1234, help='random seed') 39 | parser.add_argument('--output_dir', type=str, 40 | default='./results', help='root dir for output log') 41 | parser.add_argument('--model_name', type=str, 42 | default='DenUnet') 43 | parser.add_argument('--eval_interval', type=int, 44 | default=20, help='evaluation epoch') 45 | parser.add_argument('--z_spacing', type=int, 46 | default=1, help='z_spacing') 47 | 48 | args = parser.parse_args() 49 | 50 | args.output_dir = args.output_dir + f'/{args.model_name}' 51 | os.makedirs(args.output_dir, exist_ok=True) 52 | 53 | 54 | if __name__ == "__main__": 55 | if not args.deterministic: 56 | cudnn.benchmark = True 57 | cudnn.deterministic = False 58 | else: 59 | cudnn.benchmark = False 60 | cudnn.deterministic = True 61 | 62 | random.seed(args.seed) 63 | np.random.seed(args.seed) 64 | torch.manual_seed(args.seed) 65 | torch.cuda.manual_seed(args.seed) 66 | 67 | 68 | CONFIGS = { 69 | 'DenUnet': configs.get_DenUnet_configs(), 70 | } 71 | 72 | if args.batch_size != 24 and args.batch_size % 6 == 0: 73 | args.base_lr *= args.batch_size / 24 74 | 75 | 76 | model = DenUnet(config=CONFIGS[args.model_name], img_size=args.img_size, n_classes=args.num_classes).cuda() 77 | trainer(args, model, args.output_dir) 78 | -------------------------------------------------------------------------------- /models/convert_pidinet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .config import config_model_converted 7 | 8 | def convert_pdc(op, weight): 9 | if op == 'cv': 10 | return weight 11 | elif op == 'cd': 12 | shape = weight.shape 13 | weight_c = weight.sum(dim=[2, 3]) 14 | weight = weight.view(shape[0], shape[1], -1) 15 | weight[:, :, 4] = weight[:, :, 4] - weight_c 16 | weight = weight.view(shape) 17 | return weight 18 | elif op == 'ad': 19 | shape = weight.shape 20 | weight = weight.view(shape[0], shape[1], -1) 21 | weight_conv = (weight - weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) 22 | return weight_conv 23 | elif op == 'rd': 24 | shape = weight.shape 25 | buffer = torch.zeros(shape[0], shape[1], 5 * 5, device=weight.device) 26 | weight = weight.view(shape[0], shape[1], -1) 27 | buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weight[:, :, 1:] 28 | buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weight[:, :, 1:] 29 | buffer = buffer.view(shape[0], shape[1], 5, 5) 30 | return buffer 31 | raise ValueError("wrong op {}".format(str(op))) 32 | 33 | def convert_pidinet(state_dict, config): 34 | pdcs = config_model_converted(config) 35 | new_dict = {} 36 | for pname, p in state_dict.items(): 37 | if 'init_block.weight' in pname: 38 | new_dict[pname] = convert_pdc(pdcs[0], p) 39 | elif 'block1_1.conv1.weight' in pname: 40 | new_dict[pname] = convert_pdc(pdcs[1], p) 41 | elif 'block1_2.conv1.weight' in pname: 42 | new_dict[pname] = convert_pdc(pdcs[2], p) 43 | elif 'block1_3.conv1.weight' in pname: 44 | new_dict[pname] = convert_pdc(pdcs[3], p) 45 | elif 'block2_1.conv1.weight' in pname: 46 | new_dict[pname] = convert_pdc(pdcs[4], p) 47 | elif 'block2_2.conv1.weight' in pname: 48 | new_dict[pname] = convert_pdc(pdcs[5], p) 49 | elif 'block2_3.conv1.weight' in pname: 50 | new_dict[pname] = convert_pdc(pdcs[6], p) 51 | elif 'block2_4.conv1.weight' in pname: 52 | new_dict[pname] = convert_pdc(pdcs[7], p) 53 | elif 'block3_1.conv1.weight' in pname: 54 | new_dict[pname] = convert_pdc(pdcs[8], p) 55 | elif 'block3_2.conv1.weight' in pname: 56 | new_dict[pname] = convert_pdc(pdcs[9], p) 57 | elif 'block3_3.conv1.weight' in pname: 58 | new_dict[pname] = convert_pdc(pdcs[10], p) 59 | elif 'block3_4.conv1.weight' in pname: 60 | new_dict[pname] = convert_pdc(pdcs[11], p) 61 | elif 'block4_1.conv1.weight' in pname: 62 | new_dict[pname] = convert_pdc(pdcs[12], p) 63 | elif 'block4_2.conv1.weight' in pname: 64 | new_dict[pname] = convert_pdc(pdcs[13], p) 65 | elif 'block4_3.conv1.weight' in pname: 66 | new_dict[pname] = convert_pdc(pdcs[14], p) 67 | elif 'block4_4.conv1.weight' in pname: 68 | new_dict[pname] = convert_pdc(pdcs[15], p) 69 | else: 70 | new_dict[pname] = p 71 | 72 | return new_dict 73 | 74 | -------------------------------------------------------------------------------- /models/ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Function factory for pixel difference convolutional operations. 3 | 4 | """ 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class Conv2d(nn.Module): 12 | def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False): 13 | super(Conv2d, self).__init__() 14 | if in_channels % groups != 0: 15 | raise ValueError('in_channels must be divisible by groups') 16 | if out_channels % groups != 0: 17 | raise ValueError('out_channels must be divisible by groups') 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.kernel_size = kernel_size 21 | self.stride = stride 22 | self.padding = padding 23 | self.dilation = dilation 24 | self.groups = groups 25 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) 26 | if bias: 27 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 28 | else: 29 | self.register_parameter('bias', None) 30 | self.reset_parameters() 31 | self.pdc = pdc 32 | 33 | def reset_parameters(self): 34 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 35 | if self.bias is not None: 36 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 37 | bound = 1 / math.sqrt(fan_in) 38 | nn.init.uniform_(self.bias, -bound, bound) 39 | 40 | def forward(self, input): 41 | 42 | return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 43 | 44 | 45 | ## cd, ad, rd convolutions 46 | def createConvFunc(op_type): 47 | assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) 48 | if op_type == 'cv': 49 | return F.conv2d 50 | 51 | if op_type == 'cd': 52 | def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): 53 | assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' 54 | assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' 55 | assert padding == dilation, 'padding for cd_conv set wrong' 56 | 57 | weights_c = weights.sum(dim=[2, 3], keepdim=True) 58 | yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) 59 | y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 60 | return y - yc 61 | return func 62 | elif op_type == 'ad': 63 | def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): 64 | assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' 65 | assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' 66 | assert padding == dilation, 'padding for ad_conv set wrong' 67 | 68 | shape = weights.shape 69 | weights = weights.view(shape[0], shape[1], -1) 70 | weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise 71 | y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 72 | return y 73 | return func 74 | elif op_type == 'rd': 75 | def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): 76 | assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' 77 | assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' 78 | padding = 2 * dilation 79 | 80 | shape = weights.shape 81 | if weights.is_cuda: 82 | buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) 83 | else: 84 | buffer = torch.zeros(shape[0], shape[1], 5 * 5) 85 | weights = weights.view(shape[0], shape[1], -1) 86 | buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] 87 | buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] 88 | buffer[:, :, 12] = 0 89 | buffer = buffer.view(shape[0], shape[1], 5, 5) 90 | y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 91 | return y 92 | return func 93 | else: 94 | print('impossible to be here unless you force that') 95 | return None 96 | 97 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | from datasets.dataset_tooth import TeethDataset 12 | from utils import test_single_volume 13 | 14 | from models.DenUnet import DenUnet 15 | import configs.DenUnet_configs as configs 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--test_path', type=str, 19 | default='./data/Toothdataset/test', help='root dir for data') 20 | parser.add_argument('--dataset', type=str, 21 | default='ToothDataset', help='experiment_name') 22 | parser.add_argument('--model_weight', type=str, 23 | default='9', help='epoch number for prediction') 24 | parser.add_argument('--num_classes', type=int, 25 | default=33, help='output channel of network') 26 | parser.add_argument('--max_epochs', type=int, 27 | default=501, help='maximum epoch number to train') 28 | parser.add_argument('--deterministic', type=int, default=1, 29 | help='whether use deterministic training') 30 | parser.add_argument('--base_lr', type=float, default=0.01, 31 | help='segmentation network learning rate') 32 | parser.add_argument('--img_size', type=int, 33 | default=224, help='input patch size of network input') 34 | parser.add_argument('--seed', type=int, 35 | default=1234, help='random seed') 36 | parser.add_argument('--output_dir', type=str, 37 | default='./predictions', help='root dir for output log') 38 | parser.add_argument('--model_name', type=str, 39 | default='DenUnet') 40 | parser.add_argument('--z_spacing', type=int, 41 | default=1, help='z_spacing') 42 | parser.add_argument('--is_savenii', 43 | action="store_true", help='whether to save results during inference') 44 | parser.add_argument('--test_save_dir', type=str, 45 | default='./predictions', help='saving prediction as nii!') 46 | 47 | args = parser.parse_args() 48 | 49 | 50 | def inference(args, testloader, model, test_save_path=None): 51 | logging.info("{} test iterations per epoch".format(len(testloader))) 52 | model.eval() 53 | metric_list = 0.0 54 | 55 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 56 | h, w = sampled_batch["image"].size()[2:] 57 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 58 | metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size], 59 | test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing) 60 | metric_list += np.array(metric_i) 61 | logging.info(' idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1])) 62 | 63 | metric_list = metric_list / len(db_test) 64 | 65 | for i in range(1, args.num_classes): 66 | logging.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1])) 67 | 68 | performance = np.mean(metric_list, axis=0)[0] 69 | mean_hd95 = np.mean(metric_list, axis=0)[1] 70 | 71 | logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f' % (performance, mean_hd95)) 72 | 73 | return "Testing Finished!" 74 | 75 | 76 | 77 | if __name__ == "__main__": 78 | if not args.deterministic: 79 | cudnn.benchmark = True 80 | cudnn.deterministic = False 81 | else: 82 | cudnn.benchmark = False 83 | cudnn.deterministic = True 84 | random.seed(args.seed) 85 | np.random.seed(args.seed) 86 | torch.manual_seed(args.seed) 87 | torch.cuda.manual_seed(args.seed) 88 | 89 | 90 | CONFIGS = { 91 | 'DenUnet': configs.get_DenUnet_configs(), 92 | } 93 | 94 | 95 | args.is_pretrain = True 96 | 97 | model = DenUnet(config=CONFIGS[args.model_name], img_size=args.img_size, n_classes=args.num_classes).cuda() 98 | msg = model.load_state_dict(torch.load(args.model_weight)) 99 | print("BEFUnet Model: ", msg) 100 | 101 | log_folder = './test_log/test_log_' 102 | os.makedirs(log_folder, exist_ok=True) 103 | 104 | logging.basicConfig(filename=log_folder + '/' + args.model_name + ".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 105 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 106 | logging.info(str(args)) 107 | 108 | if args.is_savenii: 109 | args.test_save_dir = os.path.join(args.output_dir, args.model_name) 110 | test_save_path = args.test_save_dir 111 | os.makedirs(test_save_path, exist_ok=True) 112 | else: 113 | test_save_path = None 114 | 115 | db_test = Synapse_dataset(base_dir=args.test_path, split="test_vol", list_dir=args.list_dir) 116 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 117 | 118 | inference(args, testloader, model, test_save_path) -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | 2 | from .ops import createConvFunc 3 | 4 | nets = { 5 | 'baseline': { 6 | 'layer0': 'cv', 7 | 'layer1': 'cv', 8 | 'layer2': 'cv', 9 | 'layer3': 'cv', 10 | 'layer4': 'cv', 11 | 'layer5': 'cv', 12 | 'layer6': 'cv', 13 | 'layer7': 'cv', 14 | 'layer8': 'cv', 15 | 'layer9': 'cv', 16 | 'layer10': 'cv', 17 | 'layer11': 'cv', 18 | 'layer12': 'cv', 19 | 'layer13': 'cv', 20 | 'layer14': 'cv', 21 | 'layer15': 'cv', 22 | }, 23 | 'c-v15': { 24 | 'layer0': 'cd', 25 | 'layer1': 'cv', 26 | 'layer2': 'cv', 27 | 'layer3': 'cv', 28 | 'layer4': 'cv', 29 | 'layer5': 'cv', 30 | 'layer6': 'cv', 31 | 'layer7': 'cv', 32 | 'layer8': 'cv', 33 | 'layer9': 'cv', 34 | 'layer10': 'cv', 35 | 'layer11': 'cv', 36 | 'layer12': 'cv', 37 | 'layer13': 'cv', 38 | 'layer14': 'cv', 39 | 'layer15': 'cv', 40 | }, 41 | 'a-v15': { 42 | 'layer0': 'ad', 43 | 'layer1': 'cv', 44 | 'layer2': 'cv', 45 | 'layer3': 'cv', 46 | 'layer4': 'cv', 47 | 'layer5': 'cv', 48 | 'layer6': 'cv', 49 | 'layer7': 'cv', 50 | 'layer8': 'cv', 51 | 'layer9': 'cv', 52 | 'layer10': 'cv', 53 | 'layer11': 'cv', 54 | 'layer12': 'cv', 55 | 'layer13': 'cv', 56 | 'layer14': 'cv', 57 | 'layer15': 'cv', 58 | }, 59 | 'r-v15': { 60 | 'layer0': 'rd', 61 | 'layer1': 'cv', 62 | 'layer2': 'cv', 63 | 'layer3': 'cv', 64 | 'layer4': 'cv', 65 | 'layer5': 'cv', 66 | 'layer6': 'cv', 67 | 'layer7': 'cv', 68 | 'layer8': 'cv', 69 | 'layer9': 'cv', 70 | 'layer10': 'cv', 71 | 'layer11': 'cv', 72 | 'layer12': 'cv', 73 | 'layer13': 'cv', 74 | 'layer14': 'cv', 75 | 'layer15': 'cv', 76 | }, 77 | 'cvvv4': { 78 | 'layer0': 'cd', 79 | 'layer1': 'cv', 80 | 'layer2': 'cv', 81 | 'layer3': 'cv', 82 | 'layer4': 'cd', 83 | 'layer5': 'cv', 84 | 'layer6': 'cv', 85 | 'layer7': 'cv', 86 | 'layer8': 'cd', 87 | 'layer9': 'cv', 88 | 'layer10': 'cv', 89 | 'layer11': 'cv', 90 | 'layer12': 'cd', 91 | 'layer13': 'cv', 92 | 'layer14': 'cv', 93 | 'layer15': 'cv', 94 | }, 95 | 'avvv4': { 96 | 'layer0': 'ad', 97 | 'layer1': 'cv', 98 | 'layer2': 'cv', 99 | 'layer3': 'cv', 100 | 'layer4': 'ad', 101 | 'layer5': 'cv', 102 | 'layer6': 'cv', 103 | 'layer7': 'cv', 104 | 'layer8': 'ad', 105 | 'layer9': 'cv', 106 | 'layer10': 'cv', 107 | 'layer11': 'cv', 108 | 'layer12': 'ad', 109 | 'layer13': 'cv', 110 | 'layer14': 'cv', 111 | 'layer15': 'cv', 112 | }, 113 | 'rvvv4': { 114 | 'layer0': 'rd', 115 | 'layer1': 'cv', 116 | 'layer2': 'cv', 117 | 'layer3': 'cv', 118 | 'layer4': 'rd', 119 | 'layer5': 'cv', 120 | 'layer6': 'cv', 121 | 'layer7': 'cv', 122 | 'layer8': 'rd', 123 | 'layer9': 'cv', 124 | 'layer10': 'cv', 125 | 'layer11': 'cv', 126 | 'layer12': 'rd', 127 | 'layer13': 'cv', 128 | 'layer14': 'cv', 129 | 'layer15': 'cv', 130 | }, 131 | 'cccv4': { 132 | 'layer0': 'cd', 133 | 'layer1': 'cd', 134 | 'layer2': 'cd', 135 | 'layer3': 'cv', 136 | 'layer4': 'cd', 137 | 'layer5': 'cd', 138 | 'layer6': 'cd', 139 | 'layer7': 'cv', 140 | 'layer8': 'cd', 141 | 'layer9': 'cd', 142 | 'layer10': 'cd', 143 | 'layer11': 'cv', 144 | 'layer12': 'cd', 145 | 'layer13': 'cd', 146 | 'layer14': 'cd', 147 | 'layer15': 'cv', 148 | }, 149 | 'aaav4': { 150 | 'layer0': 'ad', 151 | 'layer1': 'ad', 152 | 'layer2': 'ad', 153 | 'layer3': 'cv', 154 | 'layer4': 'ad', 155 | 'layer5': 'ad', 156 | 'layer6': 'ad', 157 | 'layer7': 'cv', 158 | 'layer8': 'ad', 159 | 'layer9': 'ad', 160 | 'layer10': 'ad', 161 | 'layer11': 'cv', 162 | 'layer12': 'ad', 163 | 'layer13': 'ad', 164 | 'layer14': 'ad', 165 | 'layer15': 'cv', 166 | }, 167 | 'rrrv4': { 168 | 'layer0': 'rd', 169 | 'layer1': 'rd', 170 | 'layer2': 'rd', 171 | 'layer3': 'cv', 172 | 'layer4': 'rd', 173 | 'layer5': 'rd', 174 | 'layer6': 'rd', 175 | 'layer7': 'cv', 176 | 'layer8': 'rd', 177 | 'layer9': 'rd', 178 | 'layer10': 'rd', 179 | 'layer11': 'cv', 180 | 'layer12': 'rd', 181 | 'layer13': 'rd', 182 | 'layer14': 'rd', 183 | 'layer15': 'cv', 184 | }, 185 | 'c16': { 186 | 'layer0': 'cd', 187 | 'layer1': 'cd', 188 | 'layer2': 'cd', 189 | 'layer3': 'cd', 190 | 'layer4': 'cd', 191 | 'layer5': 'cd', 192 | 'layer6': 'cd', 193 | 'layer7': 'cd', 194 | 'layer8': 'cd', 195 | 'layer9': 'cd', 196 | 'layer10': 'cd', 197 | 'layer11': 'cd', 198 | 'layer12': 'cd', 199 | 'layer13': 'cd', 200 | 'layer14': 'cd', 201 | 'layer15': 'cd', 202 | }, 203 | 'a16': { 204 | 'layer0': 'ad', 205 | 'layer1': 'ad', 206 | 'layer2': 'ad', 207 | 'layer3': 'ad', 208 | 'layer4': 'ad', 209 | 'layer5': 'ad', 210 | 'layer6': 'ad', 211 | 'layer7': 'ad', 212 | 'layer8': 'ad', 213 | 'layer9': 'ad', 214 | 'layer10': 'ad', 215 | 'layer11': 'ad', 216 | 'layer12': 'ad', 217 | 'layer13': 'ad', 218 | 'layer14': 'ad', 219 | 'layer15': 'ad', 220 | }, 221 | 'r16': { 222 | 'layer0': 'rd', 223 | 'layer1': 'rd', 224 | 'layer2': 'rd', 225 | 'layer3': 'rd', 226 | 'layer4': 'rd', 227 | 'layer5': 'rd', 228 | 'layer6': 'rd', 229 | 'layer7': 'rd', 230 | 'layer8': 'rd', 231 | 'layer9': 'rd', 232 | 'layer10': 'rd', 233 | 'layer11': 'rd', 234 | 'layer12': 'rd', 235 | 'layer13': 'rd', 236 | 'layer14': 'rd', 237 | 'layer15': 'rd', 238 | }, 239 | 'carv4': { 240 | 'layer0': 'cd', 241 | 'layer1': 'ad', 242 | 'layer2': 'rd', 243 | 'layer3': 'cv', 244 | 'layer4': 'cd', 245 | 'layer5': 'ad', 246 | 'layer6': 'rd', 247 | 'layer7': 'cv', 248 | 'layer8': 'cd', 249 | 'layer9': 'ad', 250 | 'layer10': 'rd', 251 | 'layer11': 'cv', 252 | 'layer12': 'cd', 253 | 'layer13': 'ad', 254 | 'layer14': 'rd', 255 | 'layer15': 'cv', 256 | }, 257 | } 258 | 259 | 260 | def config_model(model): 261 | model_options = list(nets.keys()) 262 | assert model in model_options, \ 263 | 'unrecognized model, please choose from %s' % str(model_options) 264 | 265 | print(str(nets[model])) 266 | 267 | pdcs = [] 268 | for i in range(16): 269 | layer_name = 'layer%d' % i 270 | op = nets[model][layer_name] 271 | pdcs.append(createConvFunc(op)) 272 | 273 | return pdcs 274 | 275 | def config_model_converted(model): 276 | model_options = list(nets.keys()) 277 | assert model in model_options, \ 278 | 'unrecognized model, please choose from %s' % str(model_options) 279 | 280 | print(str(nets[model])) 281 | 282 | pdcs = [] 283 | for i in range(16): 284 | layer_name = 'layer%d' % i 285 | op = nets[model][layer_name] 286 | pdcs.append(op) 287 | 288 | return pdcs 289 | 290 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import sys 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from tensorboardX import SummaryWriter 10 | from torch.nn.modules.loss import CrossEntropyLoss 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | from utils import DiceLoss, test_single_volume 14 | from torchvision import transforms 15 | import matplotlib.pyplot as plt 16 | import pandas as pd 17 | import datetime 18 | 19 | from datasets.dataset_tooth import TeethDataset, RandomGenerator 20 | 21 | 22 | def inference(model, testloader, args, test_save_path=None): 23 | model.eval() 24 | metric_list = 0.0 25 | 26 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 27 | h, w = sampled_batch["image"].size()[2:] 28 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 29 | metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size], 30 | test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing) 31 | metric_list += np.array(metric_i) 32 | logging.info(' idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1])) 33 | 34 | metric_list = metric_list / len(testloader.dataset) 35 | 36 | for i in range(1, args.num_classes): 37 | logging.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1])) 38 | 39 | performance = np.mean(metric_list, axis=0)[0] 40 | mean_hd95 = np.mean(metric_list, axis=0)[1] 41 | 42 | logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f' % (performance, mean_hd95)) 43 | 44 | return performance, mean_hd95 45 | 46 | 47 | def plot_result(dice, h, snapshot_path,args): 48 | dict = {'mean_dice': dice, 'mean_hd95': h} 49 | df = pd.DataFrame(dict) 50 | plt.figure(0) 51 | df['mean_dice'].plot() 52 | resolution_value = 1200 53 | plt.title('Mean Dice') 54 | date_and_time = datetime.datetime.now() 55 | filename = f'{args.model_name}_' + str(date_and_time)+'dice'+'.png' 56 | save_mode_path = os.path.join(snapshot_path, filename) 57 | plt.savefig(save_mode_path, format="png", dpi=resolution_value) 58 | plt.figure(1) 59 | df['mean_hd95'].plot() 60 | plt.title('Mean hd95') 61 | filename = f'{args.model_name}_' + str(date_and_time)+'hd95'+'.png' 62 | save_mode_path = os.path.join(snapshot_path, filename) 63 | #save csv 64 | filename = f'{args.model_name}_' + str(date_and_time)+'results'+'.csv' 65 | save_mode_path = os.path.join(snapshot_path, filename) 66 | df.to_csv(save_mode_path, sep='\t') 67 | 68 | 69 | def trainer(args, model, snapshot_path): 70 | date_and_time = datetime.datetime.now() 71 | 72 | os.makedirs(os.path.join(snapshot_path, 'test'), exist_ok=True) 73 | test_save_path = os.path.join(snapshot_path, 'test') 74 | 75 | # Save logs 76 | logging.basicConfig(filename=snapshot_path + f"/{args.model_name}" + str(date_and_time) + "_log.txt", level=logging.INFO, 77 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 78 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 79 | logging.info(str(args)) 80 | base_lr = args.base_lr 81 | num_classes = args.num_classes 82 | batch_size = args.batch_size * args.n_gpu 83 | 84 | db_train = TeethDataset(root=args.root_path, split="train", 85 | transform=transforms.Compose( 86 | [RandomGenerator(output_size=[args.img_size, args.img_size])])) 87 | 88 | 89 | db_test = TeethDataset(root=args.test_path, split="test_vol", list_dir=args.list_dir) 90 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 91 | 92 | print("The length of train set is: {}".format(len(db_train))) 93 | 94 | def worker_init_fn(worker_id): 95 | random.seed(args.seed + worker_id) 96 | 97 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, 98 | worker_init_fn=worker_init_fn) 99 | if args.n_gpu > 1: 100 | model = nn.DataParallel(model) 101 | model.train() 102 | 103 | ce_loss = CrossEntropyLoss() 104 | dice_loss = DiceLoss(num_classes) 105 | optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 106 | 107 | writer = SummaryWriter(snapshot_path + '/log') 108 | 109 | iter_num = 0 110 | max_epoch = args.max_epochs 111 | max_iterations = args.max_epochs * len(trainloader) 112 | logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations)) 113 | 114 | 115 | best_performance = 0.0 116 | iterator = tqdm(range(max_epoch), ncols=70) 117 | dice_=[] 118 | hd95_= [] 119 | 120 | for epoch_num in iterator: 121 | for i_batch, sampled_batch in enumerate(trainloader): 122 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 123 | image_batch, label_batch = image_batch.cuda(), label_batch.cuda() 124 | 125 | B, C, H, W = image_batch.shape 126 | image_batch = image_batch.expand(B, 3, H, W) 127 | 128 | outputs = model(image_batch) 129 | loss_ce = ce_loss(outputs, label_batch[:].long()) 130 | loss_dice = dice_loss(outputs, label_batch, softmax=True) 131 | loss = 0.4 * loss_ce + 0.6 * loss_dice 132 | optimizer.zero_grad() 133 | loss.backward() 134 | optimizer.step() 135 | 136 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 137 | for param_group in optimizer.param_groups: 138 | param_group['lr'] = lr_ 139 | 140 | iter_num = iter_num + 1 141 | writer.add_scalar('info/lr', lr_, iter_num) 142 | writer.add_scalar('info/total_loss', loss, iter_num) 143 | writer.add_scalar('info/loss_ce', loss_ce, iter_num) 144 | writer.add_scalar('info/loss_dice', loss_dice, iter_num) 145 | 146 | logging.info('iteration %d : loss : %f, loss_ce: %f loss_dice: %f' % (iter_num, loss.item(), loss_ce.item(), loss_dice.item())) 147 | 148 | try: 149 | if iter_num % 10 == 0: 150 | image = image_batch[1, 0:1, :, :] 151 | image = (image - image.min()) / (image.max() - image.min()) 152 | writer.add_image('train/Image', image, iter_num) 153 | outputs = torch.argmax(torch.softmax(outputs, dim=1), dim=1, keepdim=True) 154 | writer.add_image('train/Prediction', outputs[1, ...] * 50, iter_num) 155 | labs = label_batch[1, ...].unsqueeze(0) * 50 156 | writer.add_image('train/GroundTruth', labs, iter_num) 157 | except: pass 158 | 159 | # Test 160 | if (epoch_num + 1) % args.eval_interval == 0: 161 | filename = f'{args.model_name}_epoch_{epoch_num}.pth' 162 | save_mode_path = os.path.join(snapshot_path, filename) 163 | torch.save(model.state_dict(), save_mode_path) 164 | logging.info("save model to {}".format(save_mode_path)) 165 | 166 | logging.info("*" * 20) 167 | logging.info(f"Running Inference after epoch {epoch_num}") 168 | print(f"Epoch {epoch_num}") 169 | mean_dice, mean_hd95 = inference(model, testloader, args, test_save_path=test_save_path) 170 | dice_.append(mean_dice) 171 | hd95_.append(mean_hd95) 172 | model.train() 173 | 174 | if epoch_num >= max_epoch - 1: 175 | filename = f'{args.model_name}_epoch_{epoch_num}.pth' 176 | save_mode_path = os.path.join(snapshot_path, filename) 177 | torch.save(model.state_dict(), save_mode_path) 178 | logging.info("save model to {}".format(save_mode_path)) 179 | 180 | if not (epoch_num + 1) % args.eval_interval == 0: 181 | logging.info("*" * 20) 182 | logging.info(f"Running Inference after epoch {epoch_num} (Last Epoch)") 183 | print(f"Epoch {epoch_num}, Last Epcoh") 184 | mean_dice, mean_hd95 = inference(model, testloader, args, test_save_path=test_save_path) 185 | dice_.append(mean_dice) 186 | hd95_.append(mean_hd95) 187 | model.train() 188 | 189 | iterator.close() 190 | break 191 | 192 | plot_result(dice_,hd95_,snapshot_path,args) 193 | writer.close() 194 | return "Training Finished!" -------------------------------------------------------------------------------- /models/pidinet.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import math 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .ops import Conv2d 11 | from .config import config_model, config_model_converted 12 | 13 | class CSAM(nn.Module): 14 | """ 15 | Compact Spatial Attention Module 16 | """ 17 | def __init__(self, channels): 18 | super(CSAM, self).__init__() 19 | 20 | mid_channels = 4 21 | self.relu1 = nn.ReLU() 22 | self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) 23 | self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) 24 | self.sigmoid = nn.Sigmoid() 25 | nn.init.constant_(self.conv1.bias, 0) 26 | 27 | def forward(self, x): 28 | y = self.relu1(x) 29 | y = self.conv1(y) 30 | y = self.conv2(y) 31 | y = self.sigmoid(y) 32 | 33 | return x * y 34 | 35 | class CDCM(nn.Module): 36 | """ 37 | Compact Dilation Convolution based Module 38 | """ 39 | def __init__(self, in_channels, out_channels): 40 | super(CDCM, self).__init__() 41 | 42 | self.relu1 = nn.ReLU() 43 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 44 | self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) 45 | self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) 46 | self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) 47 | self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) 48 | nn.init.constant_(self.conv1.bias, 0) 49 | 50 | def forward(self, x): 51 | x = self.relu1(x) 52 | x = self.conv1(x) 53 | x1 = self.conv2_1(x) 54 | x2 = self.conv2_2(x) 55 | x3 = self.conv2_3(x) 56 | x4 = self.conv2_4(x) 57 | return x1 + x2 + x3 + x4 58 | 59 | 60 | class MapReduce(nn.Module): 61 | """ 62 | Reduce feature maps into a single edge map 63 | """ 64 | def __init__(self, channels): 65 | super(MapReduce, self).__init__() 66 | self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) 67 | nn.init.constant_(self.conv.bias, 0) 68 | 69 | def forward(self, x): 70 | return self.conv(x) 71 | 72 | 73 | class PDCBlock(nn.Module): 74 | def __init__(self, pdc, inplane, ouplane, stride=1): 75 | super(PDCBlock, self).__init__() 76 | self.stride=stride 77 | 78 | self.stride=stride 79 | if self.stride > 1: 80 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 81 | self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) 82 | self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) 83 | self.relu2 = nn.ReLU() 84 | self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) 85 | 86 | def forward(self, x): 87 | if self.stride > 1: 88 | x = self.pool(x) 89 | y = self.conv1(x) 90 | y = self.relu2(y) 91 | y = self.conv2(y) 92 | if self.stride > 1: 93 | x = self.shortcut(x) 94 | y = y + x 95 | return y 96 | 97 | class PDCBlock_converted(nn.Module): 98 | """ 99 | CPDC, APDC can be converted to vanilla 3x3 convolution 100 | RPDC can be converted to vanilla 5x5 convolution 101 | """ 102 | def __init__(self, pdc, inplane, ouplane, stride=1): 103 | super(PDCBlock_converted, self).__init__() 104 | self.stride=stride 105 | 106 | if self.stride > 1: 107 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 108 | self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) 109 | if pdc == 'rd': 110 | self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) 111 | else: 112 | self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) 113 | self.relu2 = nn.ReLU() 114 | self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) 115 | 116 | def forward(self, x): 117 | if self.stride > 1: 118 | x = self.pool(x) 119 | y = self.conv1(x) 120 | y = self.relu2(y) 121 | y = self.conv2(y) 122 | if self.stride > 1: 123 | x = self.shortcut(x) 124 | y = y + x 125 | return y 126 | 127 | class PiDiNet(nn.Module): 128 | def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): 129 | super(PiDiNet, self).__init__() 130 | self.sa = sa 131 | if dil is not None: 132 | assert isinstance(dil, int), 'dil should be an int' 133 | self.dil = dil 134 | 135 | self.fuseplanes = [] 136 | 137 | self.inplane = inplane 138 | if convert: 139 | if pdcs[0] == 'rd': 140 | init_kernel_size = 5 141 | init_padding = 2 142 | else: 143 | init_kernel_size = 3 144 | init_padding = 1 145 | self.init_block = nn.Conv2d(3, self.inplane, 146 | kernel_size=init_kernel_size, padding=init_padding, bias=False) 147 | block_class = PDCBlock_converted 148 | else: 149 | self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) 150 | block_class = PDCBlock 151 | 152 | self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane, stride=2) 153 | self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane, stride=2) 154 | self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) 155 | self.fuseplanes.append(self.inplane) # C 156 | 157 | inplane = self.inplane 158 | self.inplane = self.inplane * 2 159 | self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) 160 | self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) 161 | self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) 162 | self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) 163 | self.fuseplanes.append(self.inplane) # 2C 164 | 165 | inplane = self.inplane 166 | self.inplane = self.inplane * 2 167 | self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) 168 | self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) 169 | self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) 170 | self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) 171 | self.fuseplanes.append(self.inplane) # 4C 172 | 173 | inplane = self.inplane 174 | self.inplane = self.inplane * 2 175 | self.block4_1 = block_class(pdcs[12], inplane, self.inplane, stride=2) 176 | self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) 177 | self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) 178 | self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) 179 | self.fuseplanes.append(self.inplane) # 8C 180 | 181 | self.conv_reduces = nn.ModuleList() 182 | if self.sa and self.dil is not None: 183 | self.attentions = nn.ModuleList() 184 | self.dilations = nn.ModuleList() 185 | for i in range(4): 186 | self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) 187 | self.attentions.append(CSAM(self.dil)) 188 | self.conv_reduces.append(MapReduce(self.dil)) 189 | elif self.sa: 190 | self.attentions = nn.ModuleList() 191 | for i in range(4): 192 | self.attentions.append(CSAM(self.fuseplanes[i])) 193 | self.conv_reduces.append(MapReduce(self.fuseplanes[i])) 194 | elif self.dil is not None: 195 | self.dilations = nn.ModuleList() 196 | for i in range(4): 197 | self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) 198 | self.conv_reduces.append(MapReduce(self.dil)) 199 | else: 200 | for i in range(4): 201 | self.conv_reduces.append(MapReduce(self.fuseplanes[i])) 202 | 203 | self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias 204 | nn.init.constant_(self.classifier.weight, 0.25) 205 | nn.init.constant_(self.classifier.bias, 0) 206 | 207 | print('initialization done') 208 | 209 | def get_weights(self): 210 | conv_weights = [] 211 | bn_weights = [] 212 | relu_weights = [] 213 | for pname, p in self.named_parameters(): 214 | if 'bn' in pname: 215 | bn_weights.append(p) 216 | elif 'relu' in pname: 217 | relu_weights.append(p) 218 | else: 219 | conv_weights.append(p) 220 | 221 | return conv_weights, bn_weights, relu_weights 222 | 223 | def forward(self, x): 224 | H, W = x.size()[2:] 225 | 226 | x = self.init_block(x) 227 | 228 | x1 = self.block1_1(x) 229 | x1 = self.block1_2(x1) 230 | x1 = self.block1_3(x1) 231 | 232 | x2 = self.block2_1(x1) 233 | x2 = self.block2_2(x2) 234 | x2 = self.block2_3(x2) 235 | x2 = self.block2_4(x2) 236 | 237 | x3 = self.block3_1(x2) 238 | x3 = self.block3_2(x3) 239 | x3 = self.block3_3(x3) 240 | x3 = self.block3_4(x3) 241 | 242 | x4 = self.block4_1(x3) 243 | x4 = self.block4_2(x4) 244 | x4 = self.block4_3(x4) 245 | x4 = self.block4_4(x4) 246 | 247 | x_fuses = [] 248 | if self.sa and self.dil is not None: 249 | for i, xi in enumerate([x1, x2, x3, x4]): 250 | x_fuses.append(self.attentions[i](self.dilations[i](xi))) 251 | elif self.sa: 252 | for i, xi in enumerate([x1, x2, x3, x4]): 253 | x_fuses.append(self.attentions[i](xi)) 254 | elif self.dil is not None: 255 | for i, xi in enumerate([x1, x2, x3, x4]): 256 | x_fuses.append(self.dilations[i](xi)) 257 | else: 258 | x_fuses = [x1, x2, x3, x4] 259 | 260 | e1 = self.conv_reduces[0](x_fuses[0]) 261 | e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) 262 | 263 | e2 = self.conv_reduces[1](x_fuses[1]) 264 | e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) 265 | 266 | e3 = self.conv_reduces[2](x_fuses[2]) 267 | e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) 268 | 269 | e4 = self.conv_reduces[3](x_fuses[3]) 270 | e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) 271 | 272 | outputs = [e1, e2, e3, e4] 273 | 274 | output = self.classifier(torch.cat(outputs, dim=1)) 275 | #if not self.training: 276 | # return torch.sigmoid(output) 277 | 278 | outputs.append(output) 279 | outputs = [torch.sigmoid(r) for r in outputs] 280 | return outputs 281 | 282 | 283 | def pidinet_tiny(args): 284 | pdcs = config_model(args.config) 285 | dil = 8 if args.dil else None 286 | return PiDiNet(20, pdcs, dil=dil, sa=args.sa) 287 | 288 | def pidinet_small(args): 289 | pdcs = config_model(args.config) 290 | dil = 12 if args.dil else None 291 | return PiDiNet(30, pdcs, dil=dil, sa=args.sa) 292 | 293 | def pidinet(args): 294 | pdcs = config_model(args.config) 295 | dil = 24 if args.dil else None 296 | return PiDiNet(60, pdcs, dil=dil, sa=args.sa) 297 | 298 | 299 | 300 | ## convert pidinet to vanilla cnn 301 | 302 | def pidinet_tiny_converted(args): 303 | pdcs = config_model_converted(args.config) 304 | dil = 8 if args.dil else None 305 | return PiDiNet(20, pdcs, dil=dil, sa=args.sa, convert=True) 306 | 307 | def pidinet_small_converted(args): 308 | pdcs = config_model_converted(args.config) 309 | dil = 12 if args.dil else None 310 | return PiDiNet(30, pdcs, dil=dil, sa=args.sa, convert=True) 311 | 312 | def pidinet_converted(args): 313 | pdcs = config_model_converted(args.config) 314 | dil = 24 if args.dil else None 315 | return PiDiNet(60, pdcs, dil=dil, sa=args.sa, convert=True) 316 | -------------------------------------------------------------------------------- /models/Encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati Manzari 3 | Date: Jun 2023 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from timm.models.layers import trunc_normal_ 10 | from utils import * 11 | from einops import rearrange 12 | from einops.layers.torch import Rearrange 13 | from .pidinet import PiDiNet 14 | from .config import config_model, config_model_converted 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | class Attention(nn.Module): 20 | def __init__(self, dim, factor, heads = 8, dim_head = 64, dropout = 0.): 21 | super().__init__() 22 | inner_dim = dim_head * heads 23 | project_out = not (heads == 1 and dim_head == dim) 24 | 25 | self.heads = heads 26 | self.scale = dim_head ** -0.5 27 | 28 | self.attend = nn.Softmax(dim = -1) 29 | self.dropout = nn.Dropout(dropout) 30 | 31 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 32 | 33 | self.to_out = nn.Sequential( 34 | nn.Linear(inner_dim, dim * factor), 35 | nn.Dropout(dropout) 36 | ) if project_out else nn.Identity() 37 | 38 | def forward(self, x): 39 | qkv = self.to_qkv(x).chunk(3, dim = -1) 40 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 41 | 42 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 43 | 44 | attn = self.attend(dots) 45 | attn = self.dropout(attn) 46 | 47 | out = torch.matmul(attn, v) 48 | out = rearrange(out, 'b h n d -> b n (h d)') 49 | return self.to_out(out) 50 | 51 | 52 | class SwinTransformer(nn.Module): 53 | def __init__(self, img_size=224, patch_size=4, embed_dim=96, 54 | depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 55 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 56 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 57 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): 58 | 59 | super().__init__() 60 | 61 | patches_resolution = [img_size // patch_size, img_size // patch_size] 62 | num_patches = patches_resolution[0] * patches_resolution[1] 63 | 64 | self.num_layers = len(depths) 65 | self.embed_dim = embed_dim 66 | self.ape = ape 67 | self.patch_norm = patch_norm 68 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 69 | self.mlp_ratio = mlp_ratio 70 | 71 | 72 | if self.ape: 73 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 74 | trunc_normal_(self.absolute_pos_embed, std=.02) 75 | 76 | self.pos_drop = nn.Dropout(p=drop_rate) 77 | 78 | # stochastic depth 79 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 80 | 81 | # build layers 82 | self.layers = nn.ModuleList() 83 | for i_layer in range(self.num_layers): 84 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 85 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 86 | patches_resolution[1] // (2 ** i_layer)), 87 | depth=depths[i_layer], 88 | num_heads=num_heads[i_layer], 89 | window_size=window_size, 90 | mlp_ratio=self.mlp_ratio, 91 | qkv_bias=qkv_bias, qk_scale=qk_scale, 92 | drop=drop_rate, attn_drop=attn_drop_rate, 93 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 94 | norm_layer=norm_layer, 95 | downsample=None) 96 | self.layers.append(layer) 97 | 98 | self.apply(self._init_weights) 99 | 100 | def _init_weights(self, m): 101 | if isinstance(m, nn.Linear): 102 | trunc_normal_(m.weight, std=.02) 103 | if isinstance(m, nn.Linear) and m.bias is not None: 104 | nn.init.constant_(m.bias, 0) 105 | elif isinstance(m, nn.LayerNorm): 106 | nn.init.constant_(m.bias, 0) 107 | nn.init.constant_(m.weight, 1.0) 108 | 109 | @torch.jit.ignore 110 | def no_weight_decay(self): 111 | return {'absolute_pos_embed'} 112 | 113 | @torch.jit.ignore 114 | def no_weight_decay_keywords(self): 115 | return {'relative_position_bias_table'} 116 | 117 | 118 | class PyramidFeatures(nn.Module): 119 | def __init__(self, config, img_size = 224, in_channels=3): 120 | super().__init__() 121 | 122 | model_path = config.swin_pretrained_path 123 | self.swin_transformer = SwinTransformer(img_size,in_chans = 3) 124 | checkpoint = torch.load(model_path, map_location=torch.device(device))['model'] 125 | unexpected = ["patch_embed.proj.weight", "patch_embed.proj.bias", "patch_embed.norm.weight", "patch_embed.norm.bias", 126 | "head.weight", "head.bias", "layers.0.downsample.norm.weight", "layers.0.downsample.norm.bias", 127 | "layers.0.downsample.reduction.weight", "layers.1.downsample.norm.weight", "layers.1.downsample.norm.bias", 128 | "layers.1.downsample.reduction.weight", "layers.2.downsample.norm.weight", "layers.2.downsample.norm.bias", 129 | "layers.2.downsample.reduction.weight", "layers.3.downsample.norm.weight", "layers.3.downsample.norm.bias", 130 | "layers.3.downsample.reduction.weight","norm.weight", "norm.bias"] 131 | 132 | 133 | pidinet = PiDiNet(30, config_model_converted(config.pdcs), dil=None, sa=False, convert=True) 134 | self.pidinet_layers = nn.ModuleList(pidinet.children())[:17] 135 | 136 | 137 | self.p1_ch = nn.Conv2d(config.cnn_pyramid_fm[0], config.swin_pyramid_fm[0] , kernel_size = 1) 138 | self.p1_pm = PatchMerging((config.image_size // config.patch_size, config.image_size // config.patch_size), config.swin_pyramid_fm[0]) 139 | self.p1_pm.state_dict()['reduction.weight'][:]= checkpoint["layers.0.downsample.reduction.weight"] 140 | self.p1_pm.state_dict()['norm.weight'][:]= checkpoint["layers.0.downsample.norm.weight"] 141 | self.p1_pm.state_dict()['norm.bias'][:]= checkpoint["layers.0.downsample.norm.bias"] 142 | self.norm_1 = nn.LayerNorm(config.swin_pyramid_fm[0]) 143 | self.avgpool_1 = nn.AdaptiveAvgPool1d(1) 144 | 145 | 146 | self.p2_ch = nn.Conv2d(config.cnn_pyramid_fm[1], config.swin_pyramid_fm[1] , kernel_size = 1) 147 | self.p2_pm = PatchMerging((config.image_size // config.patch_size // 2, config.image_size // config.patch_size // 2), config.swin_pyramid_fm[1]) 148 | self.p2_pm.state_dict()['reduction.weight'][:]= checkpoint["layers.1.downsample.reduction.weight"] 149 | self.p2_pm.state_dict()['norm.weight'][:]= checkpoint["layers.1.downsample.norm.weight"] 150 | self.p2_pm.state_dict()['norm.bias'][:]= checkpoint["layers.1.downsample.norm.bias"] 151 | 152 | 153 | self.p3_ch = nn.Conv2d(config.cnn_pyramid_fm[2] , config.swin_pyramid_fm[2] , kernel_size = 1) 154 | self.p3_pm = PatchMerging((config.image_size // config.patch_size // 4, config.image_size // config.patch_size // 4), config.swin_pyramid_fm[2]) 155 | self.p3_pm.state_dict()['reduction.weight'][:] = checkpoint["layers.2.downsample.reduction.weight"] 156 | self.p3_pm.state_dict()['norm.weight'][:] = checkpoint["layers.2.downsample.norm.weight"] 157 | self.p3_pm.state_dict()['norm.bias'][:] = checkpoint["layers.2.downsample.norm.bias"] 158 | 159 | self.p4_ch = nn.Conv2d(config.cnn_pyramid_fm[3], config.swin_pyramid_fm[3], kernel_size=1) 160 | self.norm_2 = nn.LayerNorm(config.swin_pyramid_fm[3]) 161 | self.avgpool_2 = nn.AdaptiveAvgPool1d(1) 162 | 163 | 164 | for key in list(checkpoint.keys()): 165 | if key in unexpected : 166 | del checkpoint[key] 167 | self.swin_transformer.load_state_dict(checkpoint) 168 | 169 | 170 | def forward(self, x): 171 | 172 | 173 | 174 | for i in range(4): 175 | x = self.pidinet_layers[i](x) 176 | 177 | 178 | # Level 1 179 | fm1 = x 180 | fm1_ch = self.p1_ch(x) 181 | fm1_reshaped = Rearrange('b c h w -> b (h w) c')(fm1_ch) 182 | sw1 = self.swin_transformer.layers[0](fm1_reshaped) 183 | sw1_skipped = fm1_reshaped + sw1 184 | norm1 = self.norm_1(sw1_skipped) 185 | sw1_CLS = self.avgpool_1(norm1.transpose(1, 2)) 186 | sw1_CLS_reshaped = Rearrange('b c 1 -> b 1 c')(sw1_CLS) 187 | fm1_sw1 = self.p1_pm(sw1_skipped) 188 | 189 | # Level 2 190 | fm1_sw2 = self.swin_transformer.layers[1](fm1_sw1) 191 | for i in range(4, 8): 192 | fm1 = self.pidinet_layers[i](fm1) 193 | 194 | fm2 = fm1 195 | fm2_ch = self.p2_ch(fm2) 196 | fm2_reshaped = Rearrange('b c h w -> b (h w) c')(fm2_ch) 197 | fm2_sw2_skipped = fm2_reshaped + fm1_sw2 198 | fm2_sw2 = self.p2_pm(fm2_sw2_skipped) 199 | 200 | # Level 3 201 | fm2_sw3 = self.swin_transformer.layers[2](fm2_sw2) 202 | for i in range(8, 12): 203 | fm2 = self.pidinet_layers[i](fm2) 204 | 205 | fm3 = fm2 206 | fm3_ch = self.p3_ch(fm3) 207 | fm3_reshaped = Rearrange('b c h w -> b (h w) c')(fm3_ch) 208 | fm3_sw3_skipped = fm3_reshaped + fm2_sw3 209 | fm3_sw3 = self.p3_pm(fm3_sw3_skipped) 210 | 211 | # Level 4 212 | fm3_sw4 = self.swin_transformer.layers[3](fm3_sw3) 213 | for i in range(12, 16): 214 | fm3 = self.pidinet_layers[i](fm3) 215 | 216 | fm4 = fm3 217 | fm4_ch = self.p4_ch(fm4) 218 | fm4_reshaped = Rearrange('b c h w -> b (h w) c')(fm4_ch) 219 | fm4_sw4_skipped = fm4_reshaped + fm3_sw4 220 | norm2 = self.norm_2(fm4_sw4_skipped) 221 | sw4_CLS = self.avgpool_2(norm2.transpose(1, 2)) 222 | sw4_CLS_reshaped = Rearrange('b c 1 -> b 1 c')(sw4_CLS) 223 | 224 | return [torch.cat((sw1_CLS_reshaped, sw1_skipped), dim=1), torch.cat((sw4_CLS_reshaped, fm4_sw4_skipped), dim=1)] 225 | 226 | # DLF Module 227 | class All2Cross(nn.Module): 228 | def __init__(self, config, img_size = 224 , in_chans=3, embed_dim=(96, 768), norm_layer=nn.LayerNorm): 229 | super().__init__() 230 | self.cross_pos_embed = config.cross_pos_embed 231 | self.pyramid = PyramidFeatures(config=config, img_size= img_size, in_channels=in_chans) 232 | 233 | n_p1 = (config.image_size // config.patch_size ) ** 2 # default: 3136 234 | n_p2 = (config.image_size // config.patch_size // 8) ** 2 # default: 49 235 | num_patches = (n_p1, n_p2) 236 | self.num_branches = 2 237 | 238 | self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)]) 239 | 240 | total_depth = sum([sum(x[-2:]) for x in config.depth]) 241 | dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth)] # stochastic depth decay rule 242 | dpr_ptr = 0 243 | self.blocks = nn.ModuleList() 244 | for idx, block_config in enumerate(config.depth): 245 | curr_depth = max(block_config[:-1]) + block_config[-1] 246 | dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth] 247 | blk = MultiScaleBlock(embed_dim, num_patches, block_config, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, 248 | qkv_bias=config.qkv_bias, qk_scale=config.qk_scale, drop=config.drop_rate, 249 | attn_drop=config.attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer) 250 | dpr_ptr += curr_depth 251 | self.blocks.append(blk) 252 | 253 | self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)]) 254 | 255 | for i in range(self.num_branches): 256 | if self.pos_embed[i].requires_grad: 257 | trunc_normal_(self.pos_embed[i], std=.02) 258 | 259 | self.apply(self._init_weights) 260 | 261 | def _init_weights(self, m): 262 | if isinstance(m, nn.Linear): 263 | trunc_normal_(m.weight, std=.02) 264 | if isinstance(m, nn.Linear) and m.bias is not None: 265 | nn.init.constant_(m.bias, 0) 266 | elif isinstance(m, nn.LayerNorm): 267 | nn.init.constant_(m.bias, 0) 268 | nn.init.constant_(m.weight, 1.0) 269 | 270 | @torch.jit.ignore 271 | def no_weight_decay(self): 272 | out = {'cls_token'} 273 | if self.pos_embed[0].requires_grad: 274 | out.add('pos_embed') 275 | return out 276 | 277 | def forward(self, x): 278 | xs = self.pyramid(x) 279 | 280 | if self.cross_pos_embed: 281 | for i in range(self.num_branches): 282 | xs[i] += self.pos_embed[i] 283 | 284 | for blk in self.blocks: 285 | xs = blk(xs) 286 | xs = [self.norm[i](x) for i, x in enumerate(xs)] 287 | 288 | return xs -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.checkpoint as checkpoint 7 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 8 | from timm.models.vision_transformer import _cfg, Mlp, Block 9 | 10 | def get_n_params(model): 11 | pp=0 12 | for p in list(model.parameters()): 13 | nn=1 14 | for s in list(p.size()): 15 | nn = nn*s 16 | pp += nn 17 | return pp 18 | 19 | 20 | ############ Swin Transformer ############ 21 | class Mlp(nn.Module): 22 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.fc1 = nn.Linear(in_features, hidden_features) 27 | self.act = act_layer() 28 | self.fc2 = nn.Linear(hidden_features, out_features) 29 | self.drop = nn.Dropout(drop) 30 | 31 | def forward(self, x): 32 | x = self.fc1(x) 33 | x = self.act(x) 34 | x = self.drop(x) 35 | x = self.fc2(x) 36 | x = self.drop(x) 37 | return x 38 | 39 | 40 | def window_partition(x, window_size): 41 | """ 42 | Args: 43 | x: (B, H, W, C) 44 | window_size (int): window size 45 | Returns: 46 | windows: (num_windows*B, window_size, window_size, C) 47 | """ 48 | B, H, W, C = x.shape 49 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 50 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 51 | return windows 52 | 53 | 54 | def window_reverse(windows, window_size, H, W): 55 | """ 56 | Args: 57 | windows: (num_windows*B, window_size, window_size, C) 58 | window_size (int): Window size 59 | H (int): Height of image 60 | W (int): Width of image 61 | Returns: 62 | x: (B, H, W, C) 63 | """ 64 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 65 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 66 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 67 | return x 68 | 69 | 70 | class WindowAttention(nn.Module): # W-MSA in the paper 71 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 72 | It supports both of shifted and non-shifted window. 73 | Args: 74 | dim (int): Number of input channels. 75 | window_size (tuple[int]): The height and width of the window. 76 | num_heads (int): Number of attention heads. 77 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 78 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 79 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 80 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 81 | """ 82 | 83 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 84 | 85 | super().__init__() 86 | self.dim = dim 87 | self.window_size = window_size # Wh, Ww 88 | self.num_heads = num_heads 89 | head_dim = dim // num_heads 90 | self.scale = qk_scale or head_dim ** -0.5 91 | 92 | # define a parameter table of relative position bias 93 | self.relative_position_bias_table = nn.Parameter( 94 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 95 | 96 | # get pair-wise relative position index for each token inside the window 97 | coords_h = torch.arange(self.window_size[0]) 98 | coords_w = torch.arange(self.window_size[1]) 99 | coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww 100 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 101 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 102 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 103 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 104 | relative_coords[:, :, 1] += self.window_size[1] - 1 105 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 106 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 107 | self.register_buffer("relative_position_index", relative_position_index) 108 | 109 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 110 | self.attn_drop = nn.Dropout(attn_drop) 111 | self.proj = nn.Linear(dim, dim) 112 | self.proj_drop = nn.Dropout(proj_drop) 113 | 114 | trunc_normal_(self.relative_position_bias_table, std=.02) 115 | self.softmax = nn.Softmax(dim=-1) 116 | 117 | def forward(self, x, mask=None): 118 | """ 119 | Args: 120 | x: input features with shape of (num_windows*B, N, C) >>> (B * 32*32, 4*4, 192) 121 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 122 | """ 123 | B_, N, C = x.shape 124 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) #AMBIGUOUS X) 125 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 126 | 127 | q = q * self.scale 128 | attn = (q @ k.transpose(-2, -1)) 129 | 130 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 131 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 132 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 133 | attn = attn + relative_position_bias.unsqueeze(0) 134 | 135 | if mask is not None: 136 | nW = mask.shape[0] 137 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 138 | attn = attn.view(-1, self.num_heads, N, N) 139 | attn = self.softmax(attn) 140 | else: 141 | attn = self.softmax(attn) 142 | 143 | attn = self.attn_drop(attn) 144 | 145 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 146 | x = self.proj(x) 147 | x = self.proj_drop(x) 148 | return x 149 | 150 | def extra_repr(self) -> str: 151 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 152 | 153 | def flops(self, N): 154 | # calculate flops for 1 window with token length of N 155 | flops = 0 156 | # qkv = self.qkv(x) 157 | flops += N * self.dim * 3 * self.dim 158 | # attn = (q @ k.transpose(-2, -1)) 159 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 160 | # x = (attn @ v) 161 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 162 | # x = self.proj(x) 163 | flops += N * self.dim * self.dim 164 | return flops 165 | 166 | 167 | class SwinTransformerBlock(nn.Module): 168 | r""" Swin Transformer Block. 169 | Args: 170 | dim (int): Number of input channels. 171 | input_resolution (tuple[int]): Input resulotion. 172 | num_heads (int): Number of attention heads. 173 | window_size (int): Window size. 174 | shift_size (int): Shift size for SW-MSA. 175 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 176 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 177 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 178 | drop (float, optional): Dropout rate. Default: 0.0 179 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 180 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 181 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 182 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 183 | """ 184 | 185 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 186 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 187 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 188 | super().__init__() 189 | self.dim = dim 190 | self.input_resolution = input_resolution 191 | self.num_heads = num_heads 192 | self.window_size = window_size 193 | self.shift_size = shift_size 194 | self.mlp_ratio = mlp_ratio 195 | if min(self.input_resolution) <= self.window_size: 196 | # if window size is larger than input resolution, we don't partition windows 197 | self.shift_size = 0 198 | self.window_size = min(self.input_resolution) 199 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 200 | 201 | self.norm1 = norm_layer(dim) 202 | self.attn = WindowAttention( 203 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 204 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 205 | 206 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 207 | self.norm2 = norm_layer(dim) 208 | mlp_hidden_dim = int(dim * mlp_ratio) 209 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 210 | 211 | if self.shift_size > 0: 212 | # calculate attention mask for SW-MSA 213 | H, W = self.input_resolution 214 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 215 | h_slices = (slice(0, -self.window_size), 216 | slice(-self.window_size, -self.shift_size), 217 | slice(-self.shift_size, None)) 218 | w_slices = (slice(0, -self.window_size), 219 | slice(-self.window_size, -self.shift_size), 220 | slice(-self.shift_size, None)) 221 | cnt = 0 222 | for h in h_slices: 223 | for w in w_slices: 224 | img_mask[:, h, w, :] = cnt 225 | cnt += 1 226 | 227 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 228 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 229 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 230 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 231 | else: 232 | attn_mask = None 233 | 234 | self.register_buffer("attn_mask", attn_mask) 235 | 236 | def forward(self, x): 237 | H, W = self.input_resolution 238 | B, L, C = x.shape 239 | assert L == H * W, "input feature has wrong size" 240 | 241 | shortcut = x 242 | x = self.norm1(x) 243 | x = x.view(B, H, W, C) 244 | 245 | # cyclic shift 246 | if self.shift_size > 0: 247 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 248 | else: 249 | shifted_x = x 250 | 251 | # partition windows 252 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 253 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 254 | 255 | # W-MSA/SW-MSA 256 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 257 | 258 | # merge windows 259 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 260 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 261 | 262 | # reverse cyclic shift 263 | if self.shift_size > 0: 264 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 265 | else: 266 | x = shifted_x 267 | x = x.view(B, H * W, C) 268 | 269 | # FFN 270 | x = shortcut + self.drop_path(x) 271 | x = x + self.drop_path(self.mlp(self.norm2(x))) 272 | 273 | return x 274 | 275 | def extra_repr(self) -> str: 276 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 277 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 278 | 279 | def flops(self): 280 | flops = 0 281 | H, W = self.input_resolution 282 | # norm1 283 | flops += self.dim * H * W 284 | # W-MSA/SW-MSA 285 | nW = H * W / self.window_size / self.window_size 286 | flops += nW * self.attn.flops(self.window_size * self.window_size) 287 | # mlp 288 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 289 | # norm2 290 | flops += self.dim * H * W 291 | return flops 292 | 293 | 294 | class PatchMerging(nn.Module): 295 | r""" Patch Merging Layer. 296 | Args: 297 | input_resolution (tuple[int]): Resolution of input feature. 298 | dim (int): Number of input channels. 299 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 300 | """ 301 | 302 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 303 | super().__init__() 304 | self.input_resolution = input_resolution 305 | self.dim = dim 306 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 307 | self.norm = norm_layer(4 * dim) 308 | 309 | def forward(self, x): 310 | """ 311 | x: B, H*W, C 312 | """ 313 | H, W = self.input_resolution 314 | B, L, C = x.shape 315 | assert L == H * W, "input feature has wrong size" 316 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 317 | 318 | x = x.view(B, H, W, C) 319 | 320 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 321 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 322 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 323 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 324 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 325 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 326 | 327 | x = self.norm(x) 328 | x = self.reduction(x) 329 | 330 | return x 331 | 332 | def extra_repr(self) -> str: 333 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 334 | 335 | def flops(self): 336 | H, W = self.input_resolution 337 | flops = H * W * self.dim 338 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 339 | return flops 340 | 341 | 342 | class BasicLayer(nn.Module): 343 | """ A basic Swin Transformer layer for one stage. 344 | Args: 345 | dim (int): Number of input channels. 346 | input_resolution (tuple[int]): Input resolution. 347 | depth (int): Number of blocks. 348 | num_heads (int): Number of attention heads. 349 | window_size (int): Local window size. 350 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 351 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 352 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 353 | drop (float, optional): Dropout rate. Default: 0.0 354 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 355 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 356 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 357 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 358 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 359 | """ 360 | 361 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 362 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 363 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 364 | 365 | super().__init__() 366 | self.dim = dim 367 | self.input_resolution = input_resolution 368 | self.depth = depth 369 | self.use_checkpoint = use_checkpoint 370 | 371 | # build blocks 372 | self.blocks = nn.ModuleList([ 373 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 374 | num_heads=num_heads, window_size=window_size, 375 | shift_size=0 if (i % 2 == 0) else window_size // 2, 376 | mlp_ratio=mlp_ratio, 377 | qkv_bias=qkv_bias, qk_scale=qk_scale, 378 | drop=drop, attn_drop=attn_drop, 379 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 380 | norm_layer=norm_layer) 381 | for i in range(depth)]) 382 | 383 | # patch merging layer 384 | if downsample is not None: 385 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 386 | else: 387 | self.downsample = None 388 | 389 | def forward(self, x): 390 | # print(x.shape, end = " | ") 391 | for blk in self.blocks: 392 | if self.use_checkpoint: 393 | x = checkpoint.checkpoint(blk, x) 394 | else: 395 | x = blk(x) 396 | if self.downsample is not None: 397 | x = self.downsample(x) 398 | # print(x.shape) 399 | return x 400 | 401 | def extra_repr(self) -> str: 402 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 403 | 404 | def flops(self): 405 | flops = 0 406 | for blk in self.blocks: 407 | flops += blk.flops() 408 | if self.downsample is not None: 409 | flops += self.downsample.flops() 410 | return flops 411 | 412 | 413 | ############ DLF ############ 414 | class CrossAttention(nn.Module): 415 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 416 | super().__init__() 417 | self.num_heads = num_heads 418 | head_dim = dim // num_heads 419 | self.scale = qk_scale or head_dim ** -0.5 420 | 421 | self.wq = nn.Linear(dim, dim, bias=qkv_bias) 422 | self.wk = nn.Linear(dim, dim, bias=qkv_bias) 423 | self.wv = nn.Linear(dim, dim, bias=qkv_bias) 424 | self.attn_drop = nn.Dropout(attn_drop) 425 | self.proj = nn.Linear(dim, dim) 426 | self.proj_drop = nn.Dropout(proj_drop) 427 | 428 | def forward(self, x): 429 | 430 | B, N, C = x.shape 431 | q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H) 432 | k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) 433 | v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) 434 | 435 | attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N 436 | attn = attn.softmax(dim=-1) 437 | attn = self.attn_drop(attn) 438 | 439 | x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C 440 | x = self.proj(x) 441 | x = self.proj_drop(x) 442 | return x 443 | 444 | 445 | class CrossAttentionBlock(nn.Module): 446 | 447 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 448 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True): 449 | super().__init__() 450 | self.norm1 = norm_layer(dim) 451 | self.attn = CrossAttention( 452 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 453 | 454 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 455 | self.has_mlp = has_mlp 456 | if has_mlp: 457 | self.norm2 = norm_layer(dim) 458 | mlp_hidden_dim = int(dim * mlp_ratio) 459 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 460 | 461 | def forward(self, x): 462 | x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) 463 | if self.has_mlp: 464 | x = x + self.drop_path(self.mlp(self.norm2(x))) 465 | 466 | return x 467 | 468 | 469 | class MultiScaleBlock(nn.Module): 470 | 471 | def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 472 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 473 | super().__init__() 474 | 475 | num_branches = len(dim) 476 | self.num_branches = num_branches 477 | 478 | self.blocks = nn.ModuleList() 479 | for d in range(num_branches): 480 | tmp = [] 481 | for i in range(depth[d]): 482 | tmp.append( 483 | Block(dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, 484 | attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer)) 485 | if len(tmp) != 0: 486 | self.blocks.append(nn.Sequential(*tmp)) 487 | 488 | if len(self.blocks) == 0: 489 | self.blocks = None 490 | 491 | self.projs = nn.ModuleList() 492 | for d in range(num_branches): 493 | if dim[d] == dim[(d+1) % num_branches] and False: 494 | tmp = [nn.Identity()] 495 | else: 496 | tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d+1) % num_branches])] 497 | self.projs.append(nn.Sequential(*tmp)) 498 | 499 | self.fusion = nn.ModuleList() 500 | for d in range(num_branches): 501 | d_ = (d+1) % num_branches 502 | nh = num_heads[d_] 503 | if depth[-1] == 0: # backward capability: 504 | self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, 505 | drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, 506 | has_mlp=False)) 507 | else: 508 | tmp = [] 509 | for _ in range(depth[-1]): 510 | tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, 511 | drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, 512 | has_mlp=False)) 513 | self.fusion.append(nn.Sequential(*tmp)) 514 | 515 | self.revert_projs = nn.ModuleList() 516 | for d in range(num_branches): 517 | if dim[(d+1) % num_branches] == dim[d] and False: 518 | tmp = [nn.Identity()] 519 | else: 520 | tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])] 521 | self.revert_projs.append(nn.Sequential(*tmp)) 522 | 523 | def forward(self, x): 524 | inp = x 525 | 526 | # only take the cls token out 527 | proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(inp, self.projs)] 528 | 529 | # cross attention 530 | outs = [] 531 | for i in range(self.num_branches): 532 | tmp = torch.cat((proj_cls_token[i], inp[(i + 1) % self.num_branches][:, 1:, ...]), dim=1) 533 | tmp = self.fusion[i](tmp) 534 | reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...]) 535 | tmp = torch.cat((reverted_proj_cls_token, inp[i][:, 1:, ...]), dim=1) 536 | outs.append(tmp) 537 | 538 | outs_b = [block(x_) for x_, block in zip(outs, self.blocks)] 539 | return outs 540 | 541 | 542 | ############ Test ############ 543 | import numpy as np 544 | import torch 545 | from medpy import metric 546 | from scipy.ndimage import zoom 547 | import torch.nn as nn 548 | import SimpleITK as sitk 549 | 550 | 551 | class DiceLoss(nn.Module): 552 | def __init__(self, n_classes): 553 | super(DiceLoss, self).__init__() 554 | self.n_classes = n_classes 555 | 556 | def _one_hot_encoder(self, input_tensor): 557 | tensor_list = [] 558 | for i in range(self.n_classes): 559 | temp_prob = input_tensor == i # * torch.ones_like(input_tensor) 560 | tensor_list.append(temp_prob.unsqueeze(1)) 561 | output_tensor = torch.cat(tensor_list, dim=1) 562 | return output_tensor.float() 563 | 564 | def _dice_loss(self, score, target): 565 | target = target.float() 566 | smooth = 1e-5 567 | intersect = torch.sum(score * target) 568 | y_sum = torch.sum(target * target) 569 | z_sum = torch.sum(score * score) 570 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 571 | loss = 1 - loss 572 | return loss 573 | 574 | def forward(self, inputs, target, weight=None, softmax=False): 575 | if softmax: 576 | inputs = torch.softmax(inputs, dim=1) 577 | target = self._one_hot_encoder(target) 578 | if weight is None: 579 | weight = [1] * self.n_classes 580 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size()) 581 | class_wise_dice = [] 582 | loss = 0.0 583 | for i in range(0, self.n_classes): 584 | dice = self._dice_loss(inputs[:, i], target[:, i]) 585 | class_wise_dice.append(1.0 - dice.item()) 586 | loss += dice * weight[i] 587 | return loss / self.n_classes 588 | 589 | 590 | def calculate_metric_percase(pred, gt): 591 | pred[pred > 0] = 1 592 | gt[gt > 0] = 1 593 | if pred.sum() > 0 and gt.sum()>0: 594 | dice = metric.binary.dc(pred, gt) 595 | hd95 = metric.binary.hd95(pred, gt) 596 | return dice, hd95 597 | elif pred.sum() > 0 and gt.sum()==0: 598 | return 1, 0 599 | else: 600 | return 0, 0 601 | 602 | 603 | def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1): 604 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 605 | if len(image.shape) == 3: 606 | prediction = np.zeros_like(label) 607 | for ind in range(image.shape[0]): 608 | slice = image[ind, :, :] 609 | x, y = slice.shape[0], slice.shape[1] 610 | if x != patch_size[0] or y != patch_size[1]: 611 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0 612 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 613 | 614 | B, C, H, W = input.shape 615 | input = input.expand(B, 3, H, W) 616 | 617 | net.eval() 618 | with torch.no_grad(): 619 | outputs = net(input) 620 | out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0) 621 | out = out.cpu().detach().numpy() 622 | if x != patch_size[0] or y != patch_size[1]: 623 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 624 | else: 625 | pred = out 626 | prediction[ind] = pred 627 | else: 628 | input = torch.from_numpy(image).unsqueeze( 629 | 0).unsqueeze(0).float().cuda() 630 | net.eval() 631 | with torch.no_grad(): 632 | out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0) 633 | prediction = out.cpu().detach().numpy() 634 | metric_list = [] 635 | for i in range(1, classes): 636 | metric_list.append(calculate_metric_percase(prediction == i, label == i)) 637 | 638 | if test_save_path is not None: 639 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 640 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 641 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 642 | img_itk.SetSpacing((1, 1, z_spacing)) 643 | prd_itk.SetSpacing((1, 1, z_spacing)) 644 | lab_itk.SetSpacing((1, 1, z_spacing)) 645 | sitk.WriteImage(prd_itk, test_save_path + '/'+case + "_pred.nii.gz") 646 | sitk.WriteImage(img_itk, test_save_path + '/'+ case + "_img.nii.gz") 647 | sitk.WriteImage(lab_itk, test_save_path + '/'+ case + "_gt.nii.gz") 648 | return metric_list 649 | --------------------------------------------------------------------------------