├── .gitignore ├── Ours ├── ASPP.py ├── Base_transformer.py ├── __init__.py ├── __pycache__ │ ├── ASPP.cpython-37.pyc │ ├── Base_transformer.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── base.cpython-37.pyc │ ├── cell_DETR.cpython-37.pyc │ └── resnet.cpython-37.pyc ├── base.py ├── cell_DETR.py ├── non_local.py └── resnet.py ├── README.md ├── dataset ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── isbi2016.cpython-37.pyc │ └── isic2016.cpython-37.pyc ├── isic2016.py └── isic2018.py ├── framework.jpg ├── infer.py ├── lib ├── Cell_DETR_master │ ├── .gitignore │ ├── DETR_Test.py │ ├── DeepDetr.py │ ├── LICENSE │ ├── __pycache__ │ │ ├── backbone.cpython-37.pyc │ │ ├── bounding_box_head.cpython-37.pyc │ │ ├── detr_new.cpython-37.pyc │ │ ├── segmentation.cpython-37.pyc │ │ └── transformer.cpython-37.pyc │ ├── augmentation.py │ ├── backbone.py │ ├── botr.py │ ├── botr2.py │ ├── bounding_box_head.py │ ├── ccccode.py │ ├── dataset.py │ ├── dcn2 │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── build.py │ │ ├── build_modulated.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ ├── deform_conv.py │ │ │ └── modulated_dcn_func.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── deform_conv.py │ │ │ └── modulated_dcn.py │ │ ├── src │ │ │ ├── cuda │ │ │ │ ├── deform_psroi_pooling_cuda.cu │ │ │ │ ├── deform_psroi_pooling_cuda.h │ │ │ │ ├── modulated_deform_im2col_cuda.cu │ │ │ │ └── modulated_deform_im2col_cuda.h │ │ │ ├── deform_conv.c │ │ │ ├── deform_conv.h │ │ │ ├── deform_conv_cuda.c │ │ │ ├── deform_conv_cuda.h │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ ├── deform_conv_cuda_kernel.h │ │ │ ├── modulated_dcn.c │ │ │ ├── modulated_dcn.h │ │ │ ├── modulated_dcn_cuda.c │ │ │ └── modulated_dcn_cuda.h │ │ ├── test.py │ │ └── test_modulated.py │ ├── detr.py │ ├── detr_new.py │ ├── images │ │ └── CELL_DETR.PNG │ ├── lossfunction.py │ ├── main.py │ ├── matcher.py │ ├── misc.py │ ├── model_wrapper.py │ ├── pade_activation_unit │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── utils.cpython-37.pyc │ │ ├── cuda │ │ │ ├── pau_cuda.cpp │ │ │ ├── pau_cuda_kernels.cu │ │ │ ├── python_imp │ │ │ │ ├── Pade.py │ │ │ │ └── __init__.py │ │ │ └── setup.py │ │ ├── torchsummary.py │ │ └── utils.py │ ├── pixel_adaptive_convolution │ │ ├── README.md │ │ ├── __pycache__ │ │ │ └── pac.cpython-37.pyc │ │ ├── pac.py │ │ ├── paccrf.py │ │ ├── requirements.txt │ │ ├── test_pac.py │ │ └── tools │ │ │ ├── flowlib.py │ │ │ └── plot_log.py │ ├── requirements.txt │ ├── segmentation.py │ ├── sgtr.py │ ├── transformer.py │ └── validation_metric.py └── non_local │ ├── non_local_concatenation.py │ ├── non_local_dot_product.py │ ├── non_local_embedded_gaussian.py │ └── non_local_gaussian.py ├── src ├── BAT_Modules.py ├── __pycache__ │ ├── BAT_Modules.cpython-37.pyc │ ├── losses.cpython-37.pyc │ ├── transformer.cpython-37.pyc │ └── utils.cpython-37.pyc ├── losses.py ├── process_point.py ├── process_resize.py ├── transformer.py └── utils.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # test module (personal habit: the test directory is excluded) 2 | Cell_DETR 3 | weights 4 | logs/ 5 | **/__pycache__ 6 | **/*.pyc 7 | build/ 8 | -------------------------------------------------------------------------------- /Ours/ASPP.py:
-------------------------------------------------------------------------------- 1 | # camera-ready 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class ASPP(nn.Module): 8 | def __init__(self, num_classes, head = True): 9 | super(ASPP, self).__init__() 10 | 11 | self.conv_1x1_1 = nn.Conv2d(512, 256, kernel_size=1) 12 | self.bn_conv_1x1_1 = nn.BatchNorm2d(256) 13 | 14 | self.conv_3x3_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=6, dilation=6) 15 | self.bn_conv_3x3_1 = nn.BatchNorm2d(256) 16 | 17 | self.conv_3x3_2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=12, dilation=12) 18 | self.bn_conv_3x3_2 = nn.BatchNorm2d(256) 19 | 20 | self.conv_3x3_3 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=18, dilation=18) 21 | self.bn_conv_3x3_3 = nn.BatchNorm2d(256) 22 | 23 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 24 | 25 | self.conv_1x1_2 = nn.Conv2d(512, 256, kernel_size=1) 26 | self.bn_conv_1x1_2 = nn.BatchNorm2d(256) 27 | 28 | self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256) 29 | self.bn_conv_1x1_3 = nn.BatchNorm2d(256) 30 | 31 | if head: 32 | self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1) 33 | self.head = head 34 | 35 | def forward(self, feature_map): 36 | # (feature_map has shape (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. If self.resnet instead is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8)) 37 | 38 | feature_map_h = feature_map.size()[2] # (== h/16) 39 | feature_map_w = feature_map.size()[3] # (== w/16) 40 | 41 | out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 42 | out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 43 | out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 44 | out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 45 | 46 | out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1)) 47 | out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1)) 48 | out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/16, w/16)) 49 | 50 | out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16)) 51 | out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16)) 52 | if self.head: 53 | out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16)) 54 | 55 | return out 56 | 57 | class ASPP_Bottleneck(nn.Module): 58 | def __init__(self, num_classes): 59 | super(ASPP_Bottleneck, self).__init__() 60 | 61 | self.conv_1x1_1 = nn.Conv2d(4*512, 256, kernel_size=1) 62 | self.bn_conv_1x1_1 = nn.BatchNorm2d(256) 63 | 64 | self.conv_3x3_1 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=6, dilation=6) 65 | self.bn_conv_3x3_1 = nn.BatchNorm2d(256) 66 | 67 | self.conv_3x3_2 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=12, dilation=12) 68 | self.bn_conv_3x3_2 = nn.BatchNorm2d(256) 69 | 70 | self.conv_3x3_3 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=18, dilation=18) 71 | self.bn_conv_3x3_3 = nn.BatchNorm2d(256) 72 | 73 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 74 | 75 | self.conv_1x1_2 = nn.Conv2d(4*512, 256, kernel_size=1) 76 | self.bn_conv_1x1_2 = 
nn.BatchNorm2d(256) 77 | 78 | self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256) 79 | self.bn_conv_1x1_3 = nn.BatchNorm2d(256) 80 | 81 | self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1) 82 | 83 | def forward(self, feature_map): 84 | # (feature_map has shape (batch_size, 4*512, h/16, w/16)) 85 | 86 | feature_map_h = feature_map.size()[2] # (== h/16) 87 | feature_map_w = feature_map.size()[3] # (== w/16) 88 | 89 | out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 90 | out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 91 | out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 92 | out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 93 | 94 | out_img = self.avg_pool(feature_map) # (shape: (batch_size, 4*512, 1, 1)) 95 | out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1)) 96 | out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/16, w/16)) 97 | 98 | out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16)) 99 | out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16)) 100 | out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16)) 101 | 102 | return out 103 | -------------------------------------------------------------------------------- /Ours/Base_transformer.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..') 4 | sys.path.insert(0, root_path) 5 | sys.path.insert(0, os.path.join(root_path, 'lib')) 6 | sys.path.insert(0, os.path.join(root_path, 'lib/Cell_DETR_master')) 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from Ours.base import DeepLabV3 as base 13 | 14 | from src.BAT_Modules import BoundaryCrossAttention, CrossAttention 15 | from src.BAT_Modules import MultiHeadAttention as Attention_head 16 | from src.transformer import BoundaryAwareTransformer, Transformer 17 | 18 | 19 | class BAT(nn.Module): 20 | def __init__( 21 | self, 22 | num_classes, 23 | num_layers, 24 | point_pred, 25 | decoder=False, 26 | transformer_type_index=0, 27 | hidden_features=128, # 256 28 | number_of_query_positions=1, 29 | segmentation_attention_heads=8): 30 | 31 | super(BAT, self).__init__() 32 | 33 | self.num_classes = num_classes 34 | self.point_pred = point_pred 35 | self.transformer_type = "BoundaryAwareTransformer" if transformer_type_index == 0 else "Transformer" 36 | self.use_decoder = decoder 37 | 38 | self.deeplab = base(num_classes, num_layers) 39 | 40 | in_channels = 2048 if num_layers == 50 else 512 41 | 42 | self.convolution_mapping = nn.Conv2d(in_channels=in_channels, 43 | out_channels=hidden_features, 44 | kernel_size=(1, 1), 45 | stride=(1, 1), 46 | padding=(0, 0), 47 | bias=True) 48 | 49 | self.query_positions = nn.Parameter(data=torch.randn( 50 | number_of_query_positions, hidden_features, dtype=torch.float), 51 | requires_grad=True) 52 | 53 | self.row_embedding = nn.Parameter(data=torch.randn(100, 54 | hidden_features // 55 | 2, 56 | dtype=torch.float), 57 | requires_grad=True) 58 | self.column_embedding =
nn.Parameter(data=torch.randn( 59 | 100, hidden_features // 2, dtype=torch.float), 60 | requires_grad=True) 61 | 62 | self.transformer = [ 63 | Transformer(d_model=hidden_features), 64 | BoundaryAwareTransformer(d_model=hidden_features) 65 | ][point_pred] 66 | 67 | if self.use_decoder: 68 | self.BCA = BoundaryCrossAttention(hidden_features, 8) 69 | 70 | self.trans_out_conv = nn.Conv2d(in_channels=hidden_features, 71 | out_channels=in_channels, 72 | kernel_size=(1, 1), 73 | stride=(1, 1), 74 | padding=(0, 0), 75 | bias=True) 76 | 77 | def forward(self, x): 78 | h = x.size()[2] 79 | w = x.size()[3] 80 | feature_map = self.deeplab.resnet(x) 81 | 82 | features = self.convolution_mapping(feature_map) 83 | height, width = features.shape[2:] 84 | batch_size = features.shape[0] 85 | positional_embeddings = torch.cat([ 86 | self.column_embedding[:height].unsqueeze(dim=0).repeat( 87 | height, 1, 1), 88 | self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1) 89 | ], 90 | dim=-1).permute( 91 | 2, 0, 1).unsqueeze(0).repeat( 92 | batch_size, 1, 1, 1) 93 | 94 | if self.transformer_type == 'BoundaryAwareTransformer': 95 | latent_tensor, features_encoded, point_maps = self.transformer( 96 | features, None, self.query_positions, positional_embeddings) 97 | else: 98 | latent_tensor, features_encoded = self.transformer( 99 | features, None, self.query_positions, positional_embeddings) 100 | point_maps = [] 101 | 102 | latent_tensor = latent_tensor.permute(2, 0, 1) 103 | # shape:(bs, 1 , 128) 104 | 105 | if self.use_decoder: 106 | features_encoded, point_dec = self.BCA(features_encoded, 107 | latent_tensor) 108 | point_maps.append(point_dec) 109 | 110 | trans_feature_maps = self.trans_out_conv( 111 | features_encoded.contiguous()) #.contiguous() 112 | 113 | trans_feature_maps = trans_feature_maps + feature_map 114 | 115 | output = self.deeplab.aspp( 116 | trans_feature_maps 117 | ) # (shape: (batch_size, num_classes, h/16, w/16)) 118 | output = F.interpolate( 119 | output, size=(h, w), 120 | mode="bilinear") # (shape: (batch_size, num_classes, h, w)) 121 | 122 | if self.point_pred == 1: 123 | return output, point_maps 124 | 125 | return output -------------------------------------------------------------------------------- /Ours/__init__.py: -------------------------------------------------------------------------------- 1 | # from .Cell_DETR_master.sgtr import CellDETR 2 | -------------------------------------------------------------------------------- /Ours/__pycache__/ASPP.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/ASPP.cpython-37.pyc -------------------------------------------------------------------------------- /Ours/__pycache__/Base_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/Base_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /Ours/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- 
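For reference, a minimal usage sketch of the `BAT` model defined in `Ours/Base_transformer.py` above. This is illustrative only: it assumes the repository root is on `PYTHONPATH`, a CUDA device is available, and a square input (the learned row/column embeddings are concatenated, which requires equal feature height and width); the argument values are examples, not the paper's training configuration.

```python
import torch

from Ours.Base_transformer import BAT

# point_pred=1 selects the BoundaryAwareTransformer branch, so the forward
# pass returns the segmentation logits plus the boundary point maps.
model = BAT(num_classes=1, num_layers=18, point_pred=1).cuda()
x = torch.rand(2, 3, 352, 352).cuda()
output, point_maps = model(x)
print(output.shape)  # torch.Size([2, 1, 352, 352]) -- logits at input resolution
```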
/Ours/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /Ours/__pycache__/cell_DETR.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/cell_DETR.cpython-37.pyc -------------------------------------------------------------------------------- /Ours/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /Ours/base.py: -------------------------------------------------------------------------------- 1 | # camera-ready 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import os 7 | import sys 8 | # sys.path.insert(0, '../') 9 | 10 | from Ours.resnet import ResNet18_OS16, ResNet34_OS16, ResNet50_OS16, ResNet101_OS16, ResNet152_OS16, ResNet18_OS8, ResNet34_OS8 11 | from Ours.ASPP import ASPP, ASPP_Bottleneck 12 | 13 | 14 | class DeepLabV3(nn.Module): 15 | def __init__(self, num_classes, num_layers): 16 | super(DeepLabV3, self).__init__() 17 | 18 | self.num_classes = num_classes 19 | layers = num_layers 20 | # NOTE! specify the type of ResNet here 21 | # NOTE! if you use ResNet50-152, set self.aspp = ASPP_Bottleneck(num_classes=self.num_classes) instead 22 | if layers == 18: 23 | self.resnet = ResNet18_OS16() 24 | self.aspp = ASPP(num_classes=self.num_classes) 25 | elif layers == 50: 26 | self.resnet = ResNet50_OS16() 27 | self.aspp = ASPP_Bottleneck(num_classes=self.num_classes) 28 | 29 | def forward(self, x): 30 | # (x has shape (batch_size, 3, h, w)) 31 | h = x.size()[2] 32 | w = x.size()[3] 33 | feature_map = self.resnet(x) 34 | 35 | # (shape: (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. 36 | # If self.resnet is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8). 
37 | # If self.resnet is ResNet50-152, it will be (batch_size, 4*512, h/16, w/16)) 38 | output = self.aspp( 39 | feature_map) # (shape: (batch_size, num_classes, h/16, w/16)) 40 | output = F.interpolate( 41 | output, size=(h, w), 42 | mode="bilinear") # (shape: (batch_size, num_classes, h, w)) 43 | return output 44 | 45 | 46 | if __name__ == '__main__': 47 | model = DeepLabV3(num_classes=12, num_layers=18).cuda() # both arguments are required 48 | d = torch.rand((2, 3, 384, 384)).cuda() 49 | o = model(d) 50 | print(o.size()) -------------------------------------------------------------------------------- /Ours/cell_DETR.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..') 4 | sys.path.insert(0, root_path) 5 | sys.path.insert(0, os.path.join(root_path, 'lib/Cell_DETR_master')) 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from detr_new import CellDETR 11 | from segmentation import ResFeaturePyramidBlock, ResPACFeaturePyramidBlock 12 | #from transformer import BoundaryAwareTransformer, Transformer 13 | from src.transformer import BoundaryAwareTransformer, Transformer 14 | 15 | type_of_transformer = [BoundaryAwareTransformer, Transformer] 16 | 17 | 18 | def cell_detr_128(pretrained=False, type_index=0): 19 | 20 | detr = CellDETR(num_classes=2, 21 | transformer_type=type_of_transformer[type_index], 22 | segmentation_head_block=ResFeaturePyramidBlock, 23 | segmentation_head_final_activation=nn.Softmax, 24 | backbone_convolution=nn.Conv2d, 25 | segmentation_head_convolution=nn.Conv2d, 26 | transformer_activation=nn.LeakyReLU, 27 | backbone_activation=nn.LeakyReLU, 28 | bounding_box_head_activation=nn.LeakyReLU, 29 | classification_head_activation=nn.LeakyReLU, 30 | segmentation_head_activation=nn.LeakyReLU) 31 | if pretrained: 32 | ckpt = torch.load( 33 | "/home/chenfei/my_codes/TransformerCode-master/lib/Cell_DETR_master/trained_models/Cell_DETR_A/detr_99.pt" 34 | ) 35 | detr.load_state_dict(ckpt, strict=False) 36 | print("Loaded the pretrained Cell-DETR transformer!") 37 | return detr.transformer -------------------------------------------------------------------------------- /Ours/non_local.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class _NonLocalBlockND(nn.Module): 6 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 7 | """ 8 | :param in_channels: number of input channels 9 | :param inter_channels: number of intermediate channels (defaults to in_channels // 2) 10 | :param dimension: 1, 2 or 3, selecting Conv1d/Conv2d/Conv3d 11 | :param sub_sample: whether to max-pool the g and phi branches 12 | :param bn_layer: whether to end the output projection W with batch norm 13 | """ 14 | 15 | super(_NonLocalBlockND, self).__init__() 16 | 17 | assert dimension in [1, 2, 3] 18 | 19 | self.dimension = dimension 20 | self.sub_sample = sub_sample 21 | 22 | self.in_channels = in_channels 23 | self.inter_channels = inter_channels 24 | 25 | if self.inter_channels is None: 26 | self.inter_channels = in_channels // 2 27 | if self.inter_channels == 0: 28 | self.inter_channels = 1 29 | 30 | if dimension == 3: 31 | conv_nd = nn.Conv3d 32 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 33 | bn = nn.BatchNorm3d 34 | elif dimension == 2: 35 | conv_nd = nn.Conv2d 36 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 37 | bn = nn.BatchNorm2d 38 | else: 39 | conv_nd = nn.Conv1d 40 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 41 | bn = nn.BatchNorm1d 42 | 43 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 44 | kernel_size=1, stride=1,
padding=0) 45 | 46 | if bn_layer: 47 | self.W = nn.Sequential( 48 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0), 50 | bn(self.in_channels) 51 | ) 52 | nn.init.constant_(self.W[1].weight, 0) 53 | nn.init.constant_(self.W[1].bias, 0) 54 | else: 55 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 56 | kernel_size=1, stride=1, padding=0) 57 | nn.init.constant_(self.W.weight, 0) 58 | nn.init.constant_(self.W.bias, 0) 59 | 60 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 61 | kernel_size=1, stride=1, padding=0) 62 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 63 | kernel_size=1, stride=1, padding=0) 64 | 65 | if sub_sample: 66 | self.g = nn.Sequential(self.g, max_pool_layer) 67 | self.phi = nn.Sequential(self.phi, max_pool_layer) 68 | 69 | def forward(self, x, y, return_nl_map=False): 70 | """ 71 | :param x: (b, c, h, w) 72 | :param y: (b, c, 1) 73 | :param return_nl_map: if True return z, nl_map, else only return z. 74 | :return: 75 | """ 76 | 77 | batch_size = x.size(0) 78 | h, w = x.shape[2:] 79 | 80 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 81 | #g_x = g_x.permute(0, 2, 1) 82 | 83 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 84 | theta_x = theta_x.permute(0, 2, 1) 85 | 86 | #phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 87 | phi_x = self.phi(y.unsqueeze(-1)).view(batch_size, self.inter_channels, -1) 88 | f = torch.matmul(theta_x, phi_x) 89 | #f_div_C = F.softmax(f, dim=-1) 90 | f_div_C = torch.sigmoid(f) 91 | f_div_C = f_div_C.permute(0,2,1).contiguous() 92 | 93 | #y = torch.matmul(f_div_C, g_x) 94 | #y = y.permute(0, 2, 1).contiguous() 95 | y = g_x * f_div_C 96 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 97 | W_y = self.W(y) 98 | z = W_y + x 99 | 100 | if return_nl_map: 101 | return z, f_div_C.view(batch_size, 1, h, w) 102 | return z 103 | 104 | class NONLocalBlock2D(_NonLocalBlockND): 105 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 106 | super(NONLocalBlock2D, self).__init__(in_channels, 107 | inter_channels=inter_channels, 108 | dimension=2, sub_sample=sub_sample, 109 | bn_layer=bn_layer,) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Boundary-aware Transformers for Skin Lesion Segmentation 2 | 3 | ## Introduction 4 | 5 | This is an official release of the paper **Boundary-aware Transformers for Skin Lesion Segmentation**. 6 | 7 | > [**Boundary-aware Transformers for Skin Lesion Segmentation**](https://arxiv.org/abs/2110.03864),
8 | > **Jiacheng Wang**, Lan Wei, Liansheng Wang, Qichao Zhou, Lei Zhu, Jing Qin
9 | > In: Medical Image Computing and Computer Assisted Intervention (MICCAI), 2021
10 | > [[arXiv](https://arxiv.org/abs/2110.03864)][[BibTeX](https://github.com/jcwang123/BA-Transformer#citation)] 11 | 12 |
13 | 14 | ## News 15 | - **[5/27 2022] We have released a more powerful [XBound-Former](https://github.com/jcwang123/xboundformer) with clearer concept and codes.** 16 | - **[11/15 2021] We have released the point map data.** 17 | - **[11/08 2021] We have released the training / testing codes.** 18 | 19 | ## Code List 20 | 21 | - [x] Network 22 | - [x] Pre-processing 23 | - [x] Training Codes 24 | - [ ] MS 25 | 26 | For more details or any questions, please feel easy to contact us by email (jiachengw@stu.xmu.edu.cn). 27 | 28 | 29 | ## Usage 30 | 31 | ### Dataset 32 | 33 | Please download the dataset from [ISIC](https://www.isic-archive.com/) challenge and [PH2](https://www.fc.up.pt/addi/ph2%20database.html) website. 34 | 35 | ### Pre-processing 36 | 37 | Please run: 38 | 39 | ```bash 40 | $ python src/process_resize.py 41 | $ python src/process_point.py 42 | ``` 43 | 44 | You need to change the **File Path** to your own. 45 | 46 | ### Point Maps 47 | 48 | For your convenience, we release the processed maps and the dataset division. 49 | 50 | Please download them from [Baidu Disk](https://pan.baidu.com/s/1pNbH5zUI8Dw_ZAC8Iq9f7w) (code:**kmqr**) or [Google Drive](https://drive.google.com/file/d/1mSLt-ipLM9CxrfvwgjJr5V9NKrpnQaQ5/view?usp=sharing) 51 | 52 | The file names are equal to the original image names. 53 | 54 | ### Training 55 | 56 | ### Testing 57 | 58 | Download the pretrained weight for PH2 dataset from [Google Drive](https://drive.google.com/file/d/1-eMHYX1fr-QvI3n50S0xqWcxc3FGsMgE/view?usp=sharing). 59 | 60 | ```bash 61 | $ python test.py --dataset isic2016 62 | ``` 63 | 64 | ### Result 65 | 66 | |Method | Dice | IoU | HD95 | ASSD| 67 | | ------ | ------ | ------ |------ |------ | 68 | Lee *et al.* | 0.918 | 0.843 | - | - | 69 | BAT (paper)| 0.921 | 0.858 | - | - | 70 | 71 | 72 | 73 | 74 | ## Citation 75 | 76 | If you find BAT useful in your research, please consider citing: 77 | 78 | ``` 79 | @inproceedings{wang2021boundary, 80 | title={Boundary-Aware Transformers for Skin Lesion Segmentation}, 81 | author={Wang, Jiacheng and Wei, Lan and Wang, Liansheng and Zhou, Qichao and Zhu, Lei and Qin, Jing}, 82 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 83 | pages={206--216}, 84 | year={2021}, 85 | organization={Springer} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/isbi2016.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__pycache__/isbi2016.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/isic2016.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__pycache__/isic2016.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/isic2016.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import torch.utils.data 8 | from torchvision import transforms 9 | import torch.utils.data as data 10 | import torch.nn.functional as F 11 | 12 | import albumentations as A 13 | 14 | 15 | def norm01(x): 16 | return np.clip(x, 0, 255) / 255 17 | 18 | 19 | def filter_image(p): 20 | label_data = np.load(p.replace('image', 'label')) 21 | return np.max(label_data) == 1 22 | 23 | 24 | class myDataset(data.Dataset): 25 | def __init__(self, split, aug=False): 26 | super(myDataset, self).__init__() 27 | 28 | self.image_paths = [] 29 | self.label_paths = [] 30 | self.point_paths = [] 31 | self.dist_paths = [] 32 | 33 | root_dir = '/raid/wjc/data/skin_lesion/isic2016/' 34 | if split == 'train': 35 | self.image_paths = glob.glob(root_dir + '/Train/Image/*.npy') 36 | self.label_paths = glob.glob(root_dir + '/Train/Label/*.npy') 37 | self.point_paths = glob.glob(root_dir + '/Train/Point/*.npy') 38 | elif split == 'valid': 39 | self.image_paths = glob.glob(root_dir + '/Validation/Image/*.npy') 40 | self.label_paths = glob.glob(root_dir + '/Validation/Label/*.npy') 41 | self.point_paths = glob.glob(root_dir + '/Validation/Point/*.npy') 42 | elif split == 'test': 43 | self.image_paths = glob.glob(root_dir + '/Test/Image/*.npy') 44 | self.label_paths = glob.glob(root_dir + '/Test/Label/*.npy') 45 | self.point_paths = glob.glob(root_dir + '/Test/Point/*.npy') 46 | self.image_paths.sort() 47 | self.label_paths.sort() 48 | self.point_paths.sort() 49 | 50 | print('Loaded {} frames'.format(len(self.image_paths))) 51 | self.num_samples = len(self.image_paths) 52 | self.aug = aug 53 | 54 | self.transf = A.Compose([ 55 | A.HorizontalFlip(p=0.5), 56 | A.VerticalFlip(p=0.5), 57 | A.RandomBrightnessContrast(p=0.2), 58 | A.Rotate() 59 | ]) 60 | 61 | def __getitem__(self, index): 62 | 63 | image_data = np.load(self.image_paths[index]) 64 | label_data = np.load(self.label_paths[index]) > 0.5 65 | point_data = np.load(self.point_paths[index]) 66 | 67 | if self.aug: 68 | mask = np.concatenate([ 69 | label_data[..., np.newaxis].astype('uint8'), 70 | point_data[..., np.newaxis] 71 | ], 72 | axis=-1) 73 | # print(mask.shape) 74 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask) 75 | image_data, mask_aug = tsf['image'], tsf['mask'] 76 | label_data = mask_aug[:, :, 0] 77 | point_data = mask_aug[:, :, 1] 78 | 79 | image_data = norm01(image_data) 80 | label_data = np.expand_dims(label_data, 0) 81 | point_data = np.expand_dims(point_data, 0) 82 | image_data = torch.from_numpy(image_data).float() 83 | label_data = torch.from_numpy(label_data).float() 84 | point_data = torch.from_numpy(point_data).float() 85 | 86 | image_data = image_data.permute(2, 0, 1) 87 | return { 88 | 'image_path': self.image_paths[index], 89 | 'label_path': self.label_paths[index], 90 | 'point_path': self.point_paths[index], 91 | 'image': image_data, 92 | 'label': label_data, 93 | 'point': point_data 94 | } 95 | 96 | def __len__(self): 97 | return self.num_samples 
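# --- Illustrative usage sketch (not part of the original file) ---
# Assumes `root_dir` above exists with the Train/{Image,Label,Point} .npy layout.
if __name__ == '__main__':
    dataset = myDataset(split='train', aug=True)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=8,
                                         shuffle=True,
                                         num_workers=2)
    batch = next(iter(loader))
    print(batch['image'].shape, batch['label'].shape, batch['point'].shape)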
-------------------------------------------------------------------------------- /dataset/isic2018.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import torch 5 | import random 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.utils.data 9 | from torchvision import transforms 10 | import torch.utils.data as data 11 | import torch.nn.functional as F 12 | 13 | import albumentations as A 14 | from sklearn.model_selection import KFold 15 | 16 | 17 | def norm01(x): 18 | return np.clip(x, 0, 255) / 255 19 | 20 | 21 | seperable_indexes = json.load(open('dataset/data_split.json', 'r')) 22 | 23 | 24 | # cross validation 25 | class myDataset(data.Dataset): 26 | def __init__(self, fold, split, aug=False): 27 | super(myDataset, self).__init__() 28 | self.split = split 29 | root_data_dir = '/raid/wjc/data/skin_lesion/isic2018/' 30 | 31 | # load images, label, point 32 | self.image_paths = [] 33 | self.label_paths = [] 34 | self.point_paths = [] 35 | self.dist_paths = [] 36 | 37 | indexes = [l[:-4] for l in os.listdir(root_data_dir + 'Image/')] 38 | valid_indexes = seperable_indexes[fold] 39 | 40 | train_indexes = list(filter(lambda x: x not in valid_indexes, indexes)) 41 | print('Fold {}: train: {} valid: {}'.format(fold, len(train_indexes), 42 | len(valid_indexes))) 43 | 44 | indexes = train_indexes if split == 'train' else valid_indexes 45 | 46 | self.image_paths = [ 47 | root_data_dir + '/Image/{}.npy'.format(_id) for _id in indexes 48 | ] 49 | self.label_paths = [ 50 | root_data_dir + '/Label/{}.npy'.format(_id) for _id in indexes 51 | ] 52 | self.point_paths = [ 53 | root_data_dir + '/Point/{}.npy'.format(_id) for _id in indexes 54 | ] 55 | # self.point_All_paths = [ 56 | # '/data2/cf_data/skinlesion_segment/ISIC2018_rawdata/ISBI_2018/Train/Point_All/{}.npy' 57 | # .format(_id) for _id in indexes 58 | # ] 59 | 60 | print('Loaded {} frames'.format(len(self.image_paths))) 61 | self.num_samples = len(self.image_paths) 62 | self.aug = aug 63 | 64 | p = 0.5 65 | self.transf = A.Compose([ 66 | A.GaussNoise(p=p), 67 | A.HorizontalFlip(p=p), 68 | A.VerticalFlip(p=p), 69 | A.ShiftScaleRotate(p=p), 70 | # A.RandomBrightnessContrast(p=p), 71 | ]) 72 | 73 | def __getitem__(self, index): 74 | 75 | image_data = np.load(self.image_paths[index]) 76 | label_data = np.load(self.label_paths[index]) > 0.5 77 | point_data = np.load(self.point_paths[index]) > 0.5 78 | 79 | if self.aug and self.split == 'train': 80 | mask = np.concatenate([ 81 | label_data[..., np.newaxis].astype('uint8'), 82 | point_data[..., np.newaxis] 83 | ], 84 | axis=-1) 85 | # print(mask.shape) 86 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask) 87 | image_data, mask_aug = tsf['image'], tsf['mask'] 88 | label_data = mask_aug[:, :, 0] 89 | point_data = mask_aug[:, :, 1] 90 | 91 | image_data = norm01(image_data) 92 | label_data = np.expand_dims(label_data, 0) 93 | point_data = np.expand_dims(point_data, 0) 94 | 95 | image_data = torch.from_numpy(image_data).float() 96 | label_data = torch.from_numpy(label_data).float() 97 | point_data = torch.from_numpy(point_data).float() 98 | 99 | image_data = image_data.permute(2, 0, 1) 100 | return { 101 | 'image_path': self.image_paths[index], 102 | 'label_path': self.label_paths[index], 103 | 'point_path': self.point_paths[index], 104 | 'image': image_data, 105 | 'label': label_data, 106 | 'point': point_data 107 | } 108 | 109 | def __len__(self): 110 | return self.num_samples 111 | 112 | 
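# Note: the 'dataset/data_split.json' consumed at module import (seperable_indexes)
# can be produced once with dataset_kfold below; illustrative call (paths are examples):
#   dataset_kfold('/raid/wjc/data/skin_lesion/isic2018/Image/', 'dataset/data_split.json', k=5)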
113 | def dataset_kfold(dataset_dir, save_path, k=5): 114 | indexes = [l[:-4] for l in os.listdir(dataset_dir)] 115 | 116 | kf = KFold(k, shuffle=True) # k-fold cross-validation 117 | val_index = dict() 118 | for i in range(k): 119 | val_index[str(i)] = [] 120 | 121 | for i, (tr, val) in enumerate(kf.split(indexes)): 122 | for item in val: 123 | val_index[str(i)].append(indexes[item]) 124 | print('fold:{},train_len:{},val_len:{}'.format(i, len(tr), len(val))) 125 | 126 | with open(save_path, 'w') as f: 127 | json.dump(val_index, f) 128 | 129 | 130 | def random_seperate_dataset(): 131 | # indexes = [l[:-4] for l in os.listdir('/raid/wl/ISBI_2018/Train/Image/')] 132 | # random.shuffle(indexes) 133 | # names = {'0':indexes[:500], '1':indexes[500:1000], '2':indexes[1000:1500], '3':indexes[1500:2000],'4':indexes[2000:]} 134 | # names = [indexes[:500], indexes[500:1000], indexes[1000:1500], indexes[1500:2000], indexes[2000:]] 135 | # with open('/raid/wl/ISBI_2018/data_split.json','w') as f: 136 | # json.dump(names, f) 137 | return 138 | 139 | 140 | if __name__ == '__main__': 141 | from tqdm import tqdm 142 | dataset = myDataset(fold='0', split='train', aug=True) 143 | 144 | train_loader = torch.utils.data.DataLoader(dataset, 145 | batch_size=8, 146 | shuffle=False, 147 | num_workers=2, 148 | pin_memory=True, 149 | drop_last=True) 150 | for d in train_loader: 151 | pass 152 | -------------------------------------------------------------------------------- /framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/framework.jpg -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def main(): 5 | print('hello') 6 | 7 | 8 | if __name__ == '__main__': 9 | main() -------------------------------------------------------------------------------- /lib/Cell_DETR_master/.gitignore: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 4 | 5 | # User-specific stuff 6 | .idea 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/usage.statistics.xml 10 | .idea/**/dictionaries 11 | .idea/**/shelf 12 | 13 | # Generated files 14 | .idea/**/contentModel.xml 15 | 16 | # Sensitive or high-churn files 17 | .idea/**/dataSources/ 18 | .idea/**/dataSources.ids 19 | .idea/**/dataSources.local.xml 20 | .idea/**/sqlDataSources.xml 21 | .idea/**/dynamic.xml 22 | .idea/**/uiDesigner.xml 23 | .idea/**/dbnavigator.xml 24 | 25 | # Gradle 26 | .idea/**/gradle.xml 27 | .idea/**/libraries 28 | 29 | # Gradle and Maven with auto-import 30 | # When using Gradle or Maven with auto-import, you should exclude module files, 31 | # since they will be recreated, and may cause churn. Uncomment if using 32 | # auto-import.
33 | # .idea/artifacts 34 | # .idea/compiler.xml 35 | # .idea/jarRepositories.xml 36 | # .idea/modules.xml 37 | # .idea/*.iml 38 | # .idea/modules 39 | # *.iml 40 | # *.ipr 41 | 42 | # CMake 43 | cmake-build-*/ 44 | 45 | # Mongo Explorer plugin 46 | .idea/**/mongoSettings.xml 47 | 48 | # File-based project format 49 | *.iws 50 | 51 | # IntelliJ 52 | out/ 53 | 54 | # mpeltonen/sbt-idea plugin 55 | .idea_modules/ 56 | 57 | # JIRA plugin 58 | atlassian-ide-plugin.xml 59 | 60 | # Cursive Clojure plugin 61 | .idea/replstate.xml 62 | 63 | # Crashlytics plugin (for Android Studio and IntelliJ) 64 | com_crashlytics_export_strings.xml 65 | crashlytics.properties 66 | crashlytics-build.properties 67 | fabric.properties 68 | 69 | # Editor-based Rest Client 70 | .idea/httpRequests 71 | 72 | # Android studio 3.1+ serialized cache file 73 | .idea/caches/build_file_checksums.ser 74 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/DETR_Test.py: -------------------------------------------------------------------------------- 1 | import os,argparse, math 2 | import sys 3 | import torch 4 | import torch.nn 5 | import tqdm 6 | import logging 7 | import numpy as np 8 | from glob import glob 9 | from HNgtv_dataset import norm01,myDataset,crop_array 10 | from medpy.metric.binary import hd, hd95, dc, jc, assd 11 | sys.path.append('Ours/Cell_DETR_master/') 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--crop_size', type=int,default=128) 15 | # parser.add_argument('--with_BPB', type=int, default=0,choices = [0,1]) 16 | parse_config = parser.parse_args() 17 | print(parse_config) 18 | 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | # model = torch.load('./logs/test5_aaai_sdm_loss/model/best.pkl') 22 | dir_path = "./logs/test_transformer_loss_0_ver_1/" 23 | model = torch.load(dir_path + 'model/best.pkl') 24 | # model = torch.load(dir_path + 'model/latest.pkl') 25 | txt_path = os.path.join(dir_path + 'parameter.txt') 26 | 27 | def test(): 28 | dice_value = 0 29 | hd_value = 0 30 | hd95_value = 0 31 | jc_value = 0 32 | assd_value = 0 33 | numm = 0 34 | 35 | logging.basicConfig(filename = txt_path, level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 36 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 37 | 38 | for num in range(41,51): 39 | prediction = [] 40 | labels = [] 41 | img_dir = os.path.join( '/home/wl/gtv_data/Patient' + str(num).zfill(3)+ '/') 42 | n = len(glob(img_dir + 'image_*.npy')) 43 | for i in range(n): 44 | img = np.load(img_dir + 'image_{:03d}.npy'.format(i)) 45 | label = np.load(img_dir + 'label_{:03d}.npy'.format(i)) 46 | point = np.load(img_dir + 'point_{:03d}.npy'.format(i)) 47 | point = np.expand_dims(point,0) 48 | label_sum = np.sum(label) 49 | img = norm01(img) 50 | img = crop_array(img, parse_config.crop_size) 51 | # img = np.repeat(img,3,0) 52 | label = crop_array(label, parse_config.crop_size) 53 | labels.append(label) 54 | img = torch.from_numpy(img).unsqueeze(0).float().cuda() 55 | 56 | with torch.no_grad(): 57 | # if parse_config.with_BPB == 0: 58 | # output = model(img) 59 | # if parse_config.with_BPB == 1: 60 | # output,maps = model(img) 61 | # output = torch.sigmoid(output)[0] 62 | 63 | output = model(img) 64 | output = torch.max(output, dim=0, keepdim=True).values 65 | output = output.cpu().numpy()>0.5 66 | 67 | # writer.add_image('val_label', label, val_num) 68 | # 
writer.add_image('val_point', point, val_num) 69 | # writer.add_image('val_output', output, val_num) 70 | 71 | # if parse_config.with_BPB == 1: 72 | # for j, m in enumerate(maps): 73 | # if m is not None: 74 | # writer.add_image('val_m{}_img'.format(j+1), maps[j][0,...], val_num) 75 | # val_num += 1 76 | prediction.append(output) 77 | 78 | prediction = np.array(prediction) 79 | # prediction = prediction.squeeze(1) 80 | labels = np.array(labels) 81 | 82 | assert(prediction.shape==labels.shape) 83 | # calculate metric 84 | dice_ave = dc(prediction, labels) 85 | hd_ave = hd(prediction, labels) 86 | hd95_ave = hd95(prediction, labels) 87 | jc_ave = jc(prediction, labels) 88 | assd_ave = assd(prediction, labels) 89 | 90 | logging.info('patient %d : dice_value : %f' % (num, dice_ave)) 91 | logging.info('patient %d : hd_value : %f' % (num, hd_ave)) 92 | logging.info('patient %d : hd95_value : %f' % (num, hd95_ave)) 93 | logging.info('patient %d : jc_value : %f' % (num, jc_ave)) 94 | logging.info('patient %d : assd_value : %f' % (num, assd_ave)) 95 | 96 | # print("Dice value for patient{} = ".format(num),dice_ave) 97 | # print("HD value for patient{} = ".format(num),hd_ave) 98 | dice_value += dice_ave 99 | hd_value += hd_ave 100 | hd95_value += hd95_ave 101 | jc_value += jc_ave 102 | assd_value += assd_ave 103 | numm += 1 104 | 105 | dice_average = dice_value / numm 106 | hd_average = hd_value / numm 107 | hd95_average = hd95_value / numm 108 | jc_average = jc_value / numm 109 | assd_average = assd_value / numm 110 | 111 | logging.info('Dice value of test dataset : %f' % (dice_average)) 112 | logging.info('HD value of test dataset : %f' % (hd_average)) 113 | logging.info('HD95 value of test dataset : %f' % (hd95_average)) 114 | logging.info('JC value of test dataset : %f' % (jc_average)) 115 | logging.info('ASSD value of test dataset : %f' % (assd_average)) 116 | # print("Average dice value of evaluation dataset = ",dice_average) 117 | return dice_average 118 | 119 | if __name__ == '__main__': 120 | test() -------------------------------------------------------------------------------- /lib/Cell_DETR_master/DeepDetr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Tuple, Type, Iterable 3 | sys.path.insert(0, '/home/chenfei/my_codes/TransformerCode-master/lib/Cell_DETR_master/') 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | from torchvision import models 8 | import torch.nn.functional as F 9 | # from modules.modulated_deform_conv import ModulatedDeformConvPack 10 | # from pade_activation_unit.utils import PAU 11 | from torch.nn.modules import Conv2d,LeakyReLU 12 | 13 | # A 14 | conv = Conv2d 15 | act = LeakyReLU 16 | 17 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock 18 | from segmentation import MultiHeadAttention, SegmentationHead, SingleSegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock 19 | from transformer import Transformer,TransformerEncoder,TransformerEncoderLayer 20 | 21 | class CellDETR(nn.Module): 22 | def __init__(self, 23 | num_classes: int = 1, 24 | number_of_query_positions: int = 1, 25 | hidden_features=128, 26 | backbone_channels: Tuple[Tuple[int, int], ...] 
= ( 27 | (3, 64), (64, 128), (128, 256), (256, 256)), 28 | backbone_block: Type = ResNetBlock, backbone_convolution: Type = conv, 29 | backbone_normalization: Type = nn.BatchNorm2d, backbone_activation: Type = act, 30 | backbone_pooling: Type = nn.AvgPool2d, 31 | bounding_box_head_features: Tuple[Tuple[int, int], ...] = ((128, 64), (64, 16), (16, 4)), 32 | bounding_box_head_activation: Type = act, 33 | classification_head_activation: Type = act, 34 | num_encoder_layers: int = 3, 35 | num_decoder_layers: int = 2, 36 | dropout: float = 0.0, 37 | transformer_attention_heads: int = 8, 38 | transformer_activation: Type = act, 39 | segmentation_attention_heads: int = 8, 40 | segmentation_head_channels: Tuple[Tuple[int, int], ...] = ( 41 | (128+8, 64), (64, 32),(32, 16)), 42 | segmentation_head_feature_channels: Tuple[int, ...] = (256, 128, 64), 43 | segmentation_head_block: Type = ResPACFeaturePyramidBlock, 44 | segmentation_head_convolution: Type = conv, 45 | segmentation_head_normalization: Type = nn.InstanceNorm2d, 46 | segmentation_head_activation: Type = act, 47 | segmentation_head_final_activation: Type = nn.Sigmoid) -> None: 48 | 49 | # Call super constructor 50 | super(CellDETR, self).__init__() 51 | # Init backbone 52 | self.backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution, 53 | normalization=backbone_normalization, activation=backbone_activation, 54 | pooling=backbone_pooling) 55 | # Init convolution mapping to match transformer dims 56 | self.convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features, 57 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True) 58 | # Init query positions 59 | self.query_positions = nn.Parameter( 60 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float), 61 | requires_grad=True) 62 | # Init embeddings 63 | self.row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 64 | requires_grad=True) 65 | self.column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 66 | requires_grad=True) 67 | # Init transformer 68 | self.transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads, 69 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, 70 | dropout=dropout, dim_feedforward=4 * hidden_features, 71 | activation=transformer_activation) 72 | 73 | # Init segmentation attention head 74 | self.segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features, 75 | hidden_features=hidden_features, 76 | number_of_heads=segmentation_attention_heads, 77 | dropout=dropout) 78 | # Init segmentation head 79 | self.segmentation_head = SegmentationHead(channels=segmentation_head_channels, 80 | feature_channels=segmentation_head_feature_channels, 81 | convolution=segmentation_head_convolution, 82 | normalization=segmentation_head_normalization, 83 | activation=segmentation_head_activation, 84 | block=segmentation_head_block, 85 | number_of_query_positions=number_of_query_positions, 86 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax)) 87 | # Init final segmentation activation 88 | self.segmentation_final_activation = segmentation_head_final_activation(dim=1) if isinstance( 89 | segmentation_head_final_activation(), nn.Softmax) else segmentation_head_final_activation() 90 | 91 | self.point_pre_layer = nn.Conv2d(hidden_features, 1, kernel_size = 1) 92 | 93 | def get_parameters(self, lr_main: 
float = 1e-04, lr_backbone: float = 1e-05) -> Iterable: 94 | return [{'params': self.backbone.parameters(), 'lr': lr_backbone}, 95 | {'params': self.convolution_mapping.parameters(), 'lr': lr_main}, 96 | {'params': [self.row_embedding], 'lr': lr_main}, 97 | {'params': [self.column_embedding], 'lr': lr_main}, 98 | {'params': self.transformer.parameters(), 'lr': lr_main}, 99 | {'params': self.bounding_box_head.parameters(), 'lr': lr_main}, 100 | {'params': self.class_head.parameters(), 'lr': lr_main}, 101 | {'params': self.segmentation_attention_head.parameters(), 'lr': lr_main}, 102 | {'params': self.segmentation_head.parameters(), 'lr': lr_main}] 103 | 104 | def get_segmentation_head_parameters(self, lr: float = 1e-05) -> Iterable: 105 | return [{'params': self.segmentation_attention_head.parameters(), 'lr': lr}, 106 | {'params': self.segmentation_head.parameters(), 'lr': lr}] 107 | 108 | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 109 | features, feature_list = self.backbone(input) 110 | features = self.convolution_mapping(features) 111 | height, width = features.shape[2:] 112 | batch_size = features.shape[0] 113 | positional_embeddings = torch.cat([self.column_embedding[:height].unsqueeze(dim=0).repeat(height, 1, 1), 114 | self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)], 115 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1) 116 | latent_tensor, features_encoded = self.transformer(features, None, self.query_positions, positional_embeddings) 117 | latent_tensor = latent_tensor.permute(2, 0, 1) 118 | # point = self.point_pre_layer(features_encoded) 119 | # point = torch.sigmoid(point) 120 | # Feature = point*features_encoded + features_encoded 121 | bounding_box_attention_masks = self.segmentation_attention_head( 122 | latent_tensor, features_encoded.contiguous()) 123 | instance_segmentation_prediction = self.segmentation_head(features.contiguous(), 124 | bounding_box_attention_masks.contiguous(), 125 | feature_list[-2::-1]) 126 | return self.segmentation_final_activation(instance_segmentation_prediction).clone() 127 | 128 | if __name__ == '__main__': 129 | # Init model 130 | detr = CellDETR() 131 | # Print number of parameters 132 | print("DETR # parameters", sum([p.numel() for p in detr.parameters()])) 133 | # Model into eval mode 134 | # detr.eval() 135 | image = torch.randn(5,3,512,512) 136 | # point = torch.randn(5,1,128,128) 137 | # Predict 138 | segmentation_prediction = detr(image) 139 | 140 | # Print shapes 141 | # print(segmentation_prediction.shape) 142 | # print(segmentation_prediction.max(), segmentation_prediction.min()) 143 | # loss = segmentation_prediction.sum() 144 | # loss.backward() 145 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Christoph Reich & Tim Prangemeier (Bioinspired Communication Systems Lab, TU Darmstadt) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright 
notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /lib/Cell_DETR_master/__pycache__/backbone.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/backbone.cpython-37.pyc -------------------------------------------------------------------------------- /lib/Cell_DETR_master/__pycache__/bounding_box_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/bounding_box_head.cpython-37.pyc -------------------------------------------------------------------------------- /lib/Cell_DETR_master/__pycache__/detr_new.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/detr_new.cpython-37.pyc -------------------------------------------------------------------------------- /lib/Cell_DETR_master/__pycache__/segmentation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/segmentation.cpython-37.pyc -------------------------------------------------------------------------------- /lib/Cell_DETR_master/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /lib/Cell_DETR_master/augmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from scipy.ndimage.interpolation import map_coordinates 5 | from scipy.ndimage.filters import gaussian_filter 6 | import numpy as np 7 | 8 | 9 | class Augmentation(object): 10 | """ 11 | Super class for all augmentations. 
12 | """ 13 | 14 | def __init__(self) -> None: 15 | """ 16 | Constructor method 17 | """ 18 | pass 19 | 20 | def need_labels(self) -> None: 21 | """ 22 | Method should return if the labels are needed for the augmentation 23 | """ 24 | raise NotImplementedError() 25 | 26 | def __call__(self, *args, **kwargs) -> None: 27 | """ 28 | Call method is used to apply the augmentation 29 | :param args: Will be ignored 30 | :param kwargs: Will be ignored 31 | """ 32 | raise NotImplementedError() 33 | 34 | 35 | class VerticalFlip(Augmentation): 36 | """ 37 | This class implements vertical flipping for instance segmentation. 38 | """ 39 | 40 | def __init__(self) -> None: 41 | """ 42 | Constructor method 43 | """ 44 | # Call super constructor 45 | super(VerticalFlip, self).__init__() 46 | 47 | def need_labels(self) -> bool: 48 | """ 49 | Method returns that the labels are needed for the augmentation 50 | :return: (Bool) True will be returned 51 | """ 52 | return True 53 | 54 | def __call__(self, input: torch.tensor, instances: torch.tensor, 55 | bounding_boxes: torch.tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 56 | """ 57 | Flipping augmentation (only horizontal) 58 | :param image: (torch.Tensor) Input image of shape [channels, height, width] 59 | :param instances: (torch.Tenor) Instances segmentation maps of shape [instances, height, width] 60 | :param bounding_boxes: (torch.Tensor) Bounding boxes of shape [instances, 4 (x1, y1, x2, y2)] 61 | :return: (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) Input flipped, instances flipped & BBs flipped 62 | """ 63 | # Flip input 64 | input_flipped = input.flip(dims=(2,)) 65 | # Flip instances 66 | instances_flipped = instances.flip(dims=(2,)) 67 | # Flip bounding boxes 68 | image_center = torch.tensor((input.shape[2] // 2, input.shape[1] // 2)) 69 | bounding_boxes[:, [0, 2]] += 2 * (image_center - bounding_boxes[:, [0, 2]]) 70 | bounding_boxes_w = torch.abs(bounding_boxes[:, 0] - bounding_boxes[:, 2]) 71 | bounding_boxes[:, 0] -= bounding_boxes_w 72 | bounding_boxes[:, 2] += bounding_boxes_w 73 | return input_flipped, instances_flipped, bounding_boxes 74 | 75 | 76 | class ElasticDeformation(Augmentation): 77 | """ 78 | This class implement random elastic deformation of a given input image 79 | """ 80 | 81 | def __init__(self, alpha: float = 125, sigma: float = 20) -> None: 82 | """ 83 | Constructor method 84 | :param alpha: (float) Alpha coefficient which represents the scaling 85 | :param sigma: (float) Sigma coefficient which represents the elastic factor 86 | """ 87 | # Call super constructor 88 | super(ElasticDeformation, self).__init__() 89 | # Save parameters 90 | self.alpha = alpha 91 | self.sigma = sigma 92 | 93 | def need_labels(self) -> bool: 94 | """ 95 | Method returns that the labels are needed for the augmentation 96 | :return: (Bool) True will be returned 97 | """ 98 | return False 99 | 100 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 101 | """ 102 | Method applies the random elastic deformation 103 | :param image: (torch.Tensor) Input image 104 | :return: (torch.Tensor) Transformed input image 105 | """ 106 | # Convert torch tensor to numpy array for scipy 107 | image = image.numpy() 108 | # Save basic shape 109 | shape = image.shape[1:] 110 | # Sample offsets 111 | dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha 112 | dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha 113 | x, y = 
np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 114 | indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)) 115 | # Perform deformation 116 | for index in range(image.shape[0]): 117 | image[index] = map_coordinates(image[index], indices, order=1).reshape(shape) 118 | return torch.from_numpy(image) 119 | 120 | 121 | class NoiseInjection(Augmentation): 122 | """ 123 | This class implements additive gaussian noise injection for instance segmentation. 124 | """ 125 | 126 | def __init__(self, mean: float = 0.0, std: float = 0.25) -> None: 127 | """ 128 | Constructor method 129 | :param mean: (float) Mean of the gaussian noise 130 | :param std: (float) Standard deviation of the gaussian noise 131 | """ 132 | # Call super constructor 133 | super(NoiseInjection, self).__init__() 134 | # Save parameters 135 | self.mean = mean 136 | self.std = std 137 | 138 | def need_labels(self) -> bool: 139 | """ 140 | Method returns that the labels are not needed for the augmentation 141 | :return: (bool) False will be returned 142 | """ 143 | return False 144 | 145 | def __call__(self, input: torch.Tensor) -> torch.Tensor: 146 | """ 147 | Method injects gaussian noise into the given input image 148 | :param input: (torch.Tensor) Input image 149 | :return: (torch.Tensor) Transformed input image 150 | """ 151 | # Get noise 152 | noise = self.mean + torch.randn_like(input) * self.std 153 | # Apply noise to image 154 | input = input + noise 155 | return input 156 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/botr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Tuple, Type, Iterable 3 | sys.path.insert(0, '/home/wl/File/SegGTV/Ours/Cell_DETR_master/') 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | from torchvision import models 8 | import torch.nn.functional as F 9 | from modules.modulated_deform_conv import ModulatedDeformConvPack 10 | from pade_activation_unit.utils import PAU 11 | from torch.nn.modules import Conv2d,LeakyReLU 12 | 13 | # Aliases for the convolution and activation types used throughout this file 14 | conv = Conv2d 15 | act = LeakyReLU 16 | 17 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock 18 | from modules.modulated_deform_conv import ModulatedDeformConvPack 19 | from pade_activation_unit.utils import PAU 20 | from segmentation import MultiHeadAttention, SegmentationHead, SingleSegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock 21 | from transformer import Transformer,TransformerEncoder,TransformerEncoderLayer 22 | 23 | class CellDETR(nn.Module): 24 | def __init__(self, 25 | num_classes: int = 1, 26 | number_of_query_positions: int = 1, 27 | hidden_features=128, 28 | backbone_channels: Tuple[Tuple[int, int], ...] = ( 29 | (1, 64), (64, 128), (128, 256), (256, 256)), 30 | backbone_block: Type = ResNetBlock, backbone_convolution: Type = conv, 31 | backbone_normalization: Type = nn.BatchNorm2d, backbone_activation: Type = act, 32 | backbone_pooling: Type = nn.AvgPool2d, 33 | bounding_box_head_features: Tuple[Tuple[int, int], ...] = ((128, 64), (64, 16), (16, 4)), 34 | bounding_box_head_activation: Type = act, 35 | classification_head_activation: Type = act, 36 | num_encoder_layers: int = 3, 37 | num_decoder_layers: int = 2, 38 | dropout: float = 0.0, 39 | transformer_attention_heads: int = 8, 40 | transformer_activation: Type = act, 41 | segmentation_attention_heads: int = 8, 42 | segmentation_head_channels: Tuple[Tuple[int, int], ...]
= ( 43 | (128+8, 64), (64, 32),(32, 16)), 44 | segmentation_head_feature_channels: Tuple[int, ...] = (256, 128, 64), 45 | segmentation_head_block: Type = ResPACFeaturePyramidBlock, 46 | segmentation_head_convolution: Type = conv, 47 | segmentation_head_normalization: Type = nn.InstanceNorm2d, 48 | segmentation_head_activation: Type = act, 49 | segmentation_head_final_activation: Type = nn.Sigmoid) -> None: 50 | 51 | # Call super constructor 52 | super(CellDETR, self).__init__() 53 | # Init backbone 54 | self.backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution, 55 | normalization=backbone_normalization, activation=backbone_activation, 56 | pooling=backbone_pooling) 57 | # Init convolution mapping to match transformer dims 58 | self.convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features, 59 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True) 60 | # Init query positions 61 | self.query_positions = nn.Parameter( 62 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float), 63 | requires_grad=True) 64 | # Init embeddings 65 | self.row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 66 | requires_grad=True) 67 | self.column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 68 | requires_grad=True) 69 | # Init transformer 70 | self.transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads, 71 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, 72 | dropout=dropout, dim_feedforward=4 * hidden_features, 73 | activation=transformer_activation) 74 | 75 | # Init segmentation attention head 76 | self.segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features, 77 | hidden_features=hidden_features, 78 | number_of_heads=segmentation_attention_heads, 79 | dropout=dropout) 80 | # Init segmentation head 81 | self.segmentation_head = SegmentationHead(channels=segmentation_head_channels, 82 | feature_channels=segmentation_head_feature_channels, 83 | convolution=segmentation_head_convolution, 84 | normalization=segmentation_head_normalization, 85 | activation=segmentation_head_activation, 86 | block=segmentation_head_block, 87 | number_of_query_positions=number_of_query_positions, 88 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax)) 89 | # Init final segmentation activation 90 | self.segmentation_final_activation = segmentation_head_final_activation(dim=1) if isinstance( 91 | segmentation_head_final_activation(), nn.Softmax) else segmentation_head_final_activation() 92 | 93 | self.point_pre_layer = nn.Conv2d(hidden_features, 1, kernel_size = 1)  # 1x1 point-map head (unused here: forward() takes the point map as an input) 94 | 95 | def get_parameters(self, lr_main: float = 1e-04, lr_backbone: float = 1e-05) -> Iterable: 96 | return [{'params': self.backbone.parameters(), 'lr': lr_backbone}, 97 | {'params': self.convolution_mapping.parameters(), 'lr': lr_main}, 98 | {'params': [self.row_embedding], 'lr': lr_main}, 99 | {'params': [self.column_embedding], 'lr': lr_main}, 100 | {'params': self.transformer.parameters(), 'lr': lr_main}, 101 | # {'params': self.bounding_box_head.parameters(), 'lr': lr_main},  # no such module in this variant; left in would raise AttributeError 102 | # {'params': self.class_head.parameters(), 'lr': lr_main},  # no such module in this variant 103 | {'params': self.segmentation_attention_head.parameters(), 'lr': lr_main}, 104 | {'params': self.segmentation_head.parameters(), 'lr': lr_main}] 105 | 106 | def get_segmentation_head_parameters(self, lr: float = 1e-05) ->
Iterable: 107 | return [{'params': self.segmentation_attention_head.parameters(), 'lr': lr}, 108 | {'params': self.segmentation_head.parameters(), 'lr': lr}] 109 | 110 | def forward(self, input: torch.Tensor, point: torch.Tensor) -> torch.Tensor: 111 | features, feature_list = self.backbone(input) 112 | features = self.convolution_mapping(features) 113 | height, width = features.shape[2:] 114 | batch_size = features.shape[0] 115 | positional_embeddings = torch.cat([self.column_embedding[:height].unsqueeze(dim=0).repeat(height, 1, 1), 116 | self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)], 117 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1) 118 | latent_tensor, features_encoded = self.transformer(features, None, self.query_positions, positional_embeddings) 119 | latent_tensor = latent_tensor.permute(2, 0, 1) 120 | Feature = point*features_encoded + features_encoded  # gate the encoded features with the externally supplied point map 121 | bounding_box_attention_masks = self.segmentation_attention_head( 122 | latent_tensor, Feature.contiguous()) 123 | instance_segmentation_prediction = self.segmentation_head(features.contiguous(), 124 | bounding_box_attention_masks.contiguous(), 125 | feature_list[-2::-1]) 126 | return self.segmentation_final_activation(instance_segmentation_prediction).clone() 127 | 128 | 129 | if __name__ == '__main__': 130 | # Init model 131 | detr = CellDETR() 132 | # Print number of parameters 133 | print("DETR # parameters", sum([p.numel() for p in detr.parameters()])) 134 | # Model into eval mode 135 | # detr.eval() 136 | image = torch.randn(5,1,128,128) 137 | point = torch.randn(5,1,8,8)  # point map at the encoder's feature resolution (assumed 128/16 = 8 here) 138 | # Predict 139 | segmentation_prediction = detr(image, point)  # forward() requires the point map as a second input -------------------------------------------------------------------------------- /lib/Cell_DETR_master/botr2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Tuple, Type, Iterable 3 | sys.path.insert(0, '/home/wl/File/SegGTV/Ours/Cell_DETR_master/') 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | from torchvision import models 8 | import torch.nn.functional as F 9 | from modules.modulated_deform_conv import ModulatedDeformConvPack 10 | from pade_activation_unit.utils import PAU 11 | from torch.nn.modules import Conv2d,LeakyReLU 12 | 13 | # Aliases for the convolution and activation types used throughout this file 14 | conv = Conv2d 15 | act = LeakyReLU 16 | 17 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock 18 | from modules.modulated_deform_conv import ModulatedDeformConvPack 19 | from pade_activation_unit.utils import PAU 20 | from segmentation import MultiHeadAttention, SegmentationHead, SingleSegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock 21 | from transformer import Transformer,TransformerEncoder,TransformerEncoderLayer 22 | 23 | class CellDETR(nn.Module): 24 | def __init__(self, 25 | num_classes: int = 1, 26 | number_of_query_positions: int = 1, 27 | hidden_features=128, 28 | backbone_channels: Tuple[Tuple[int, int], ...] = ( 29 | (1, 64), (64, 128), (128, 256), (256, 256)), 30 | backbone_block: Type = ResNetBlock, backbone_convolution: Type = conv, 31 | backbone_normalization: Type = nn.BatchNorm2d, backbone_activation: Type = act, 32 | backbone_pooling: Type = nn.AvgPool2d, 33 | bounding_box_head_features: Tuple[Tuple[int, int], ...]
= ((128, 64), (64, 16), (16, 4)), 34 | bounding_box_head_activation: Type = act, 35 | classification_head_activation: Type = act, 36 | num_encoder_layers: int = 3, 37 | num_decoder_layers: int = 2, 38 | dropout: float = 0.0, 39 | transformer_attention_heads: int = 8, 40 | transformer_activation: Type = act, 41 | segmentation_attention_heads: int = 8, 42 | segmentation_head_channels: Tuple[Tuple[int, int], ...] = ( 43 | (128+8, 64), (64, 32),(32, 16)), 44 | segmentation_head_feature_channels: Tuple[int, ...] = (256, 128, 64), 45 | segmentation_head_block: Type = ResPACFeaturePyramidBlock, 46 | segmentation_head_convolution: Type = conv, 47 | segmentation_head_normalization: Type = nn.InstanceNorm2d, 48 | segmentation_head_activation: Type = act, 49 | segmentation_head_final_activation: Type = nn.Sigmoid) -> None: 50 | 51 | # Call super constructor 52 | super(CellDETR, self).__init__() 53 | # Init backbone 54 | self.backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution, 55 | normalization=backbone_normalization, activation=backbone_activation, 56 | pooling=backbone_pooling) 57 | # Init convolution mapping to match transformer dims 58 | self.convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features, 59 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True) 60 | # Init query positions 61 | self.query_positions = nn.Parameter( 62 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float), 63 | requires_grad=True) 64 | # Init embeddings 65 | self.row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 66 | requires_grad=True) 67 | self.column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 68 | requires_grad=True) 69 | # Init transformer 70 | self.transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads, 71 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, 72 | dropout=dropout, dim_feedforward=4 * hidden_features, 73 | activation=transformer_activation) 74 | 75 | # Init segmentation attention head 76 | self.segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features, 77 | hidden_features=hidden_features, 78 | number_of_heads=segmentation_attention_heads, 79 | dropout=dropout) 80 | # Init segmentation head 81 | self.segmentation_head = SegmentationHead(channels=segmentation_head_channels, 82 | feature_channels=segmentation_head_feature_channels, 83 | convolution=segmentation_head_convolution, 84 | normalization=segmentation_head_normalization, 85 | activation=segmentation_head_activation, 86 | block=segmentation_head_block, 87 | number_of_query_positions=number_of_query_positions, 88 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax)) 89 | # Init final segmentation activation 90 | self.segmentation_final_activation = segmentation_head_final_activation(dim=1) if isinstance( 91 | segmentation_head_final_activation(), nn.Softmax) else segmentation_head_final_activation() 92 | 93 | self.point_pre_layer = nn.Conv2d(hidden_features, 1, kernel_size = 1) 94 | 95 | def get_parameters(self, lr_main: float = 1e-04, lr_backbone: float = 1e-05) -> Iterable: 96 | return [{'params': self.backbone.parameters(), 'lr': lr_backbone}, 97 | {'params': self.convolution_mapping.parameters(), 'lr': lr_main}, 98 | {'params': [self.row_embedding], 'lr': lr_main}, 99 | {'params': [self.column_embedding], 'lr': lr_main}, 100 | 
{'params': self.transformer.parameters(), 'lr': lr_main}, 101 | # {'params': self.bounding_box_head.parameters(), 'lr': lr_main},  # no such module in this variant; left in would raise AttributeError 102 | # {'params': self.class_head.parameters(), 'lr': lr_main},  # no such module in this variant 103 | {'params': self.segmentation_attention_head.parameters(), 'lr': lr_main}, 104 | {'params': self.segmentation_head.parameters(), 'lr': lr_main}] 105 | 106 | def get_segmentation_head_parameters(self, lr: float = 1e-05) -> Iterable: 107 | return [{'params': self.segmentation_attention_head.parameters(), 'lr': lr}, 108 | {'params': self.segmentation_head.parameters(), 'lr': lr}] 109 | 110 | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 111 | features, feature_list = self.backbone(input) 112 | features = self.convolution_mapping(features) 113 | height, width = features.shape[2:] 114 | batch_size = features.shape[0] 115 | positional_embeddings = torch.cat([self.column_embedding[:height].unsqueeze(dim=0).repeat(height, 1, 1), 116 | self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)], 117 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1) 118 | latent_tensor, features_encoded = self.transformer(features, None, self.query_positions, positional_embeddings) 119 | latent_tensor = latent_tensor.permute(2, 0, 1) 120 | point = self.point_pre_layer(features_encoded)  # predict the point map from the encoded features 121 | point = torch.sigmoid(point) 122 | Feature = point*features_encoded + features_encoded 123 | bounding_box_attention_masks = self.segmentation_attention_head( 124 | latent_tensor, Feature.contiguous()) 125 | instance_segmentation_prediction = self.segmentation_head(features.contiguous(), 126 | bounding_box_attention_masks.contiguous(), 127 | feature_list[-2::-1]) 128 | return self.segmentation_final_activation(instance_segmentation_prediction).clone(),point 129 | 130 | if __name__ == '__main__': 131 | # Init model 132 | detr = CellDETR() 133 | # Print number of parameters 134 | print("DETR # parameters", sum([p.numel() for p in detr.parameters()])) 135 | # Model into eval mode 136 | # detr.eval() 137 | image = torch.randn(5,1,128,128) 138 | # point = torch.randn(5,1,128,128) 139 | # Predict 140 | segmentation_prediction, point_map = detr(image)  # this variant also returns the predicted point map 141 | 142 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/bounding_box_head.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Type 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class BoundingBoxHead(nn.Module): 8 | """ 9 | This class implements the feed forward bounding box head as proposed in: 10 | https://arxiv.org/abs/2005.12872 11 | """ 12 | 13 | def __init__(self, features: Tuple[Tuple[int, int]] = ((256, 64), (64, 16), (16, 4)), 14 | activation: Type = nn.PReLU) -> None: 15 | """ 16 | Constructor method 17 | :param features: (Tuple[Tuple[int, int]]) Number of input and output features in each layer 18 | :param activation: (Type) Activation function to be utilized 19 | """ 20 | # Call super constructor 21 | super(BoundingBoxHead, self).__init__() 22 | # Init layers 23 | self.layers = [] 24 | for index, feature in enumerate(features): 25 | if index < len(features) - 1: 26 | self.layers.extend([nn.Linear(in_features=feature[0], out_features=feature[1]), activation()]) 27 | else: 28 | self.layers.append(nn.Linear(in_features=feature[0], out_features=feature[1])) 29 | self.layers = nn.Sequential(*self.layers) 30 | 31 | def forward(self, input: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Forward pass 34 |
:param input: (torch.Tensor) Input tensor of shape (batch size, instances, features) 35 | :return: (torch.Tensor) Output tensor of shape (batch size, instances, 4 (bounding box coordinates)) 36 | """ 37 | return self.layers(input) 38 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/ccccode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torchvision import models 5 | import torch.nn.functional as F 6 | from Ours.Cell_DETR_master.transformer import Transformer,TransformerEncoder,TransformerEncoderLayer 7 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | import matplotlib.pyplot as plt  # used for the feature/embedding plots below 10 | x = torch.rand((5,1,128,128)).cpu() 11 | # resnet = models.resnet18().cuda() 12 | # cnn_encoder = nn.Sequential(*list(resnet.children())[:-4]) 13 | # features = cnn_encoder(x) 14 | 15 | from torch.nn.modules import Conv2d,LeakyReLU 16 | hidden_features=128 17 | query_dimension = 128 18 | dropout = 0.0 19 | num_classes = 3 20 | number_of_heads = 16 21 | num_encoder_layers = 3 22 | num_decoder_layers = 2 23 | number_of_query_positions=12 24 | transformer_attention_heads = 8 25 | segmentation_attention_heads = 8 26 | transformer_activation = nn.LeakyReLU 27 | classification_head_activation = nn.LeakyReLU 28 | normalize_before=False 29 | 30 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock 31 | from modules.modulated_deform_conv import ModulatedDeformConvPack 32 | from pade_activation_unit.utils import PAU 33 | backbone_channels = ((1, 64), (64, 128), (128, 256), (256, 256)) 34 | backbone_block = ResNetBlock 35 | backbone_convolution = nn.Conv2d 36 | backbone_normalization = nn.BatchNorm2d 37 | backbone_activation = nn.LeakyReLU 38 | backbone_pooling = nn.AvgPool2d 39 | backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution, 40 | normalization=backbone_normalization, activation=backbone_activation, 41 | pooling=backbone_pooling) 42 | convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features, 43 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True) 44 | 45 | 46 | features, feature_list = backbone(x) 47 | features = convolution_mapping(features) 48 | height, width = features.shape[2:] 49 | print("height width: ",height, width) 50 | # Get batch size 51 | batch_size = features.shape[0] 52 | # Make positional embeddings 53 | print("features size: ",features.shape,len(feature_list)) 54 | 55 | query_positions = nn.Parameter( 56 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float), 57 | requires_grad=True) 58 | 59 | row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 60 | requires_grad=True) 61 | column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float), 62 | requires_grad=True) 63 | positional_embeddings = torch.cat([column_embedding[:height].unsqueeze(dim=0).repeat(height, 1, 1), 64 | row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)], 65 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1) 66 | print("query_positions size: ", query_positions.shape) 67 | print("positional_embeddings size: ",positional_embeddings.shape) 68 | 69 | positional_embeddings.max(), positional_embeddings.min()  # REPL-style inspection; has no effect when run as a script 70 | Input = features+positional_embeddings 71 | print(Input.shape) 72 | plt.figure()  # compare one channel of the features, the positional embedding, and their sum 73 |
plt.subplot(1,3,1) 74 | plt.imshow(features.detach()[0,63,:]) 75 | plt.subplot(1,3,2) 76 | plt.imshow(positional_embeddings.detach()[0,63,:]) 77 | plt.subplot(1,3,3) 78 | plt.imshow(Input.detach()[0,63,:]) 79 | plt.figure() 80 | plt.subplot(1,2,1) 81 | plt.imshow(positional_embeddings.detach()[0,63,:]) 82 | plt.subplot(1,2,2) 83 | plt.imshow(positional_embeddings.detach()[0,64,:]) 84 | 85 | 86 | 87 | transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads, 88 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, 89 | dropout=dropout, dim_feedforward=4 * hidden_features, 90 | activation=transformer_activation) 91 | latent_tensor, features_encoded = transformer(features, None, query_positions, positional_embeddings) 92 | latent_tensor = latent_tensor.permute(2, 0, 1) 93 | print(latent_tensor.shape,features_encoded.shape) 94 | 95 | 96 | 97 | from Ours.Cell_DETR_master.segmentation import MultiHeadAttention, SegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock 98 | segmentation_head_channels = ((128 + 8, 128), (128, 64), (64, 32)) 99 | segmentation_head_feature_channels = (256, 128, 64) 100 | segmentation_head_convolution = Conv2d 101 | segmentation_head_normalization = nn.InstanceNorm2d 102 | segmentation_head_activation = nn.LeakyReLU 103 | segmentation_head_block = ResPACFeaturePyramidBlock 104 | segmentation_head_final_activation = nn.Sigmoid 105 | segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features, 106 | hidden_features=hidden_features, 107 | number_of_heads=segmentation_attention_heads, 108 | dropout=dropout) 109 | segmentation_head = SegmentationHead(channels=segmentation_head_channels, 110 | feature_channels=segmentation_head_feature_channels, 111 | convolution=segmentation_head_convolution, 112 | normalization=segmentation_head_normalization, 113 | activation=segmentation_head_activation, 114 | block=segmentation_head_block, 115 | number_of_query_positions=number_of_query_positions, 116 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax)) 117 | 118 | 119 | 120 | bounding_box_attention_masks = segmentation_attention_head( 121 | latent_tensor, features_encoded.contiguous()) 122 | instance_segmentation_prediction = segmentation_head(features.contiguous(), 123 | bounding_box_attention_masks.contiguous(), 124 | feature_list[-2::-1]) -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | import os 7 | import numpy as np 8 | 9 | import misc 10 | import augmentation 11 | 12 | 13 | class CellInstanceSegmentation(Dataset): 14 | """ 15 | This dataset implements the cell instance segmentation dataset for the DETR model. 16 | Dataset source: https://github.com/ChristophReich1996/BCS_Data/tree/master/Cell_Instance_Segmentation_Regular_Traps 17 | """ 18 | 19 | def __init__(self, path: str = "../../BCS_Data/Cell_Instance_Segmentation_Regular_Traps/train", 20 | normalize: bool = True, 21 | normalization_function: Callable[[torch.Tensor], torch.Tensor] = misc.normalize, 22 | augmentation: Tuple[augmentation.Augmentation, ...] 
= ( 23 | augmentation.VerticalFlip(), augmentation.NoiseInjection(), augmentation.ElasticDeformation()), 24 | augmentation_p: float = 0.5, return_absolute_bounding_box: bool = False, 25 | downscale: bool = True, downscale_shape: Tuple[int, int] = (128, 128), 26 | two_classes: bool = True) -> None: 27 | """ 28 | Constructor method 29 | :param path: (str) Path to dataset 30 | :param normalize: (bool) If true normalization_function is applied 31 | :param normalization_function: (Callable[[torch.Tensor], torch.Tensor]) Normalization function 32 | :param augmentation: (Tuple[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]) Tuple of 33 | augmentation functions to be applied 34 | :param augmentation_p: (float) Probability that an augmentation is utilized 35 | :param downscale: (bool) If true images and segmentation maps will be downscaled to downscale_shape 36 | :param downscale_shape: (Tuple[int, int]) Target shape if downscale is utilized 37 | :param return_absolute_bounding_box: (bool) If true the absolute bb is returned else the relative bb is returned 38 | :param two_classes: (bool) If true only two classes, trap and cell, will be utilized 39 | """ 40 | # Save parameters 41 | self.normalize = normalize 42 | self.normalization_function = normalization_function 43 | self.augmentation = augmentation 44 | self.augmentation_p = augmentation_p 45 | self.return_absolute_bounding_box = return_absolute_bounding_box 46 | self.downscale = downscale 47 | self.downscale_shape = downscale_shape 48 | self.two_class = two_classes 49 | # Get paths of input images 50 | self.inputs = [] 51 | for file in sorted(os.listdir(os.path.join(path, "inputs"))): 52 | self.inputs.append(os.path.join(path, "inputs", file)) 53 | # Get paths of instances 54 | self.instances = [] 55 | for file in sorted(os.listdir(os.path.join(path, "instances"))): 56 | self.instances.append(os.path.join(path, "instances", file)) 57 | # Get paths of class labels 58 | self.class_labels = [] 59 | for file in sorted(os.listdir(os.path.join(path, "classes"))): 60 | self.class_labels.append(os.path.join(path, "classes", file)) 61 | # Get paths of bounding boxes 62 | self.bounding_boxes = [] 63 | for file in sorted(os.listdir(os.path.join(path, "bounding_boxes"))): 64 | self.bounding_boxes.append(os.path.join(path, "bounding_boxes", file)) 65 | 66 | def __len__(self) -> int: 67 | """ 68 | Method returns the length of the dataset 69 | :return: (int) Length of the dataset 70 | """ 71 | return len(self.inputs) 72 | 73 | def __getitem__(self, item: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 74 | """ 75 | Get item method 76 | :param item: (int) Index of the dataset item to be returned 77 | :return: (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) Tuple including input image, 78 | bounding box, class label and instances.
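Bounding boxes are converted from (x0, y0, x1, y1) to the (xc, yc, w, h) format expected by DETR-style heads before being returned.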
79 | """ 80 | # Load data 81 | input = torch.load(self.inputs[item]).unsqueeze(dim=0) 82 | instances = torch.load(self.instances[item]) 83 | bounding_boxes = torch.load(self.bounding_boxes[item]) 84 | class_labels = torch.load(self.class_labels[item]) 85 | # Encode class labels as one-hot 86 | if self.two_class: 87 | class_labels = misc.to_one_hot(class_labels.clamp(max=2.0), num_classes=2 + 1) 88 | else: 89 | class_labels = misc.to_one_hot(class_labels, num_classes=3 + 1) 90 | # Normalize input if utilized 91 | if self.normalize: 92 | input = self.normalization_function(input) 93 | # Apply augmentation if needed 94 | if np.random.random() < self.augmentation_p and self.augmentation is not None: 95 | # Get augmentation 96 | augmentation_to_be_applied = np.random.choice(self.augmentation) 97 | # Apply augmentation 98 | if augmentation_to_be_applied.need_labels(): 99 | input, instances, bounding_boxes = augmentation_to_be_applied(input, instances, bounding_boxes) 100 | else: 101 | input = augmentation_to_be_applied(input) 102 | # Downscale data to 256 x 256 if utilized 103 | if self.downscale: 104 | # Apply height and width 105 | bounding_boxes[..., [0, 2]] = bounding_boxes[..., [0, 2]] * (self.downscale_shape[0] / input.shape[-1]) 106 | bounding_boxes[..., [1, 3]] = bounding_boxes[..., [1, 3]] * (self.downscale_shape[1] / input.shape[-2]) 107 | input = F.interpolate(input=input.unsqueeze(dim=0), 108 | size=self.downscale_shape, mode="bicubic", align_corners=False)[0] 109 | instances = (F.interpolate(input=instances.unsqueeze(dim=0), 110 | size=self.downscale_shape, mode="bilinear", align_corners=False)[ 111 | 0] > 0.75).float() 112 | # Convert absolute bounding box to relative bounding box of utilized 113 | if not self.return_absolute_bounding_box: 114 | bounding_boxes = misc.absolute_bounding_box_to_relative(bounding_boxes=bounding_boxes, 115 | height=input.shape[1], width=input.shape[2]) 116 | return input, instances, misc.bounding_box_x0y0x1y1_to_xcycwh(bounding_boxes), class_labels 117 | 118 | 119 | def collate_function_cell_instance_segmentation( 120 | batch: List[Tuple[torch.Tensor]]) -> \ 121 | Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: 122 | """ 123 | Collate function of instance segmentation dataset. 124 | :param batch: (Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor]]) 125 | Batch of input data, instances maps, bounding boxes and class labels 126 | :return: (Tuple[torch.Tensor, Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor]]) Batched input 127 | data, instances, bounding boxes and class labels are stored in a list due to the different instances. 
128 | """ 129 | return torch.stack([input_samples[0] for input_samples in batch], dim=0), \ 130 | [input_samples[1] for input_samples in batch], \ 131 | [input_samples[2] for input_samples in batch], \ 132 | [input_samples[3] for input_samples in batch] 133 | 134 | 135 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/.gitignore -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/LICENSE -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/README.md -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/build.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/build.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/build_modulated.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/build_modulated.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import DeformConvFunction, deform_conv_function 2 | from .modulated_dcn_func import DeformRoIPoolingFunction, ModulatedDeformConvFunction -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/functions/deform_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from _ext import deform_conv 6 | 7 | 8 | def deform_conv_function(input, 9 | offset, 10 | weight, 11 | stride=1, 12 | padding=0, 13 | dilation=1, 14 | deform_groups=1, 15 | im2col_step=64): 16 | 17 | if input is not None and input.dim() != 4: 18 | raise ValueError( 19 | "Expected 4D tensor as input, got {}D tensor instead.".format( 20 | input.dim())) 21 | 22 | f = DeformConvFunction( 23 | _pair(stride), _pair(padding), _pair(dilation), deform_groups, im2col_step) 24 | return f(input, offset, weight) 25 | 26 | 27 | class DeformConvFunction(Function): 28 | def __init__(self, stride, padding, dilation, deformable_groups=1, im2col_step=64): 29 | super(DeformConvFunction, self).__init__() 30 | self.stride = stride 31 | self.padding = padding 32 | self.dilation = dilation 33 | self.deformable_groups = deformable_groups 34 | self.im2col_step = im2col_step 35 | 36 | def forward(self, 
input, offset, weight): 37 | self.save_for_backward(input, offset, weight) 38 | 39 | output = input.new(*self._output_size(input, weight)) 40 | 41 | self.bufs_ = [input.new(), input.new()] # columns, ones 42 | 43 | if not input.is_cuda: 44 | raise NotImplementedError 45 | else: 46 | if isinstance(input, torch.autograd.Variable): 47 | if not isinstance(input.data, torch.cuda.FloatTensor): 48 | raise NotImplementedError 49 | else: 50 | if not isinstance(input, torch.cuda.FloatTensor): 51 | raise NotImplementedError 52 | 53 | cur_im2col_step = min(self.im2col_step, input.shape[0]) 54 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 55 | deform_conv.deform_conv_forward_cuda( 56 | input, weight, offset, output, self.bufs_[0], self.bufs_[1], 57 | weight.size(3), weight.size(2), self.stride[1], self.stride[0], 58 | self.padding[1], self.padding[0], self.dilation[1], 59 | self.dilation[0], self.deformable_groups, cur_im2col_step) 60 | return output 61 | 62 | def backward(self, grad_output): 63 | input, offset, weight = self.saved_tensors 64 | 65 | grad_input = grad_offset = grad_weight = None 66 | 67 | if not grad_output.is_cuda: 68 | raise NotImplementedError 69 | else: 70 | if isinstance(grad_output, torch.autograd.Variable): 71 | if not isinstance(grad_output.data, torch.cuda.FloatTensor): 72 | raise NotImplementedError 73 | else: 74 | if not isinstance(grad_output, torch.cuda.FloatTensor): 75 | raise NotImplementedError 76 | 77 | cur_im2col_step = min(self.im2col_step, input.shape[0]) 78 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 79 | 80 | if self.needs_input_grad[0] or self.needs_input_grad[1]: 81 | grad_input = input.new(*input.size()).zero_() 82 | grad_offset = offset.new(*offset.size()).zero_() 83 | deform_conv.deform_conv_backward_input_cuda( 84 | input, offset, grad_output, grad_input, 85 | grad_offset, weight, self.bufs_[0], weight.size(3), 86 | weight.size(2), self.stride[1], self.stride[0], 87 | self.padding[1], self.padding[0], self.dilation[1], 88 | self.dilation[0], self.deformable_groups, cur_im2col_step) 89 | 90 | 91 | if self.needs_input_grad[2]: 92 | grad_weight = weight.new(*weight.size()).zero_() 93 | deform_conv.deform_conv_backward_parameters_cuda( 94 | input, offset, grad_output, 95 | grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3), 96 | weight.size(2), self.stride[1], self.stride[0], 97 | self.padding[1], self.padding[0], self.dilation[1], 98 | self.dilation[0], self.deformable_groups, 1, cur_im2col_step) 99 | 100 | return grad_input, grad_offset, grad_weight 101 | 102 | def _output_size(self, input, weight): 103 | channels = weight.size(0) 104 | 105 | output_size = (input.size(0), channels) 106 | for d in range(input.dim() - 2): 107 | in_size = input.size(d + 2) 108 | pad = self.padding[d] 109 | kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1 110 | stride = self.stride[d] 111 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, ) 112 | if not all(map(lambda s: s > 0, output_size)): 113 | raise ValueError( 114 | "convolution input is too small (output would be {})".format( 115 | 'x'.join(map(str, output_size)))) 116 | return output_size 117 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/functions/modulated_dcn_func.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import 
print_function 4 | from __future__ import division 5 | 6 | import torch 7 | from torch.autograd import Function 8 | 9 | from _ext import modulated_dcn as _backend 10 | 11 | 12 | class ModulatedDeformConvFunction(Function): 13 | 14 | def __init__(self, stride, padding, dilation=1, deformable_groups=1): 15 | super(ModulatedDeformConvFunction, self).__init__() 16 | self.stride = stride 17 | self.padding = padding 18 | self.dilation = dilation 19 | self.deformable_groups = deformable_groups 20 | 21 | def forward(self, input, offset, mask, weight, bias): 22 | if not input.is_cuda: 23 | raise NotImplementedError 24 | if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: 25 | self.save_for_backward(input, offset, mask, weight, bias) 26 | output = input.new(*self._infer_shape(input, weight)) 27 | self._bufs = [input.new(), input.new()] 28 | _backend.modulated_deform_conv_cuda_forward(input, weight, 29 | bias, self._bufs[0], 30 | offset, mask, 31 | output, self._bufs[1], 32 | weight.shape[2], weight.shape[3], 33 | self.stride, self.stride, 34 | self.padding, self.padding, 35 | self.dilation, self.dilation, 36 | self.deformable_groups) 37 | return output 38 | 39 | def backward(self, grad_output): 40 | if not grad_output.is_cuda: 41 | raise NotImplementedError 42 | input, offset, mask, weight, bias = self.saved_tensors 43 | grad_input = input.new(*input.size()).zero_() 44 | grad_offset = offset.new(*offset.size()).zero_() 45 | grad_mask = mask.new(*mask.size()).zero_() 46 | grad_weight = weight.new(*weight.size()).zero_() 47 | grad_bias = bias.new(*bias.size()).zero_() 48 | _backend.modulated_deform_conv_cuda_backward(input, weight, 49 | bias, self._bufs[0], 50 | offset, mask, 51 | self._bufs[1], 52 | grad_input, grad_weight, 53 | grad_bias, grad_offset, 54 | grad_mask, grad_output, 55 | weight.shape[2], weight.shape[3], 56 | self.stride, self.stride, 57 | self.padding, self.padding, 58 | self.dilation, self.dilation, 59 | self.deformable_groups) 60 | 61 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias 62 | 63 | def _infer_shape(self, input, weight): 64 | n = input.size(0) 65 | channels_out = weight.size(0) 66 | height, width = input.shape[2:4] 67 | kernel_h, kernel_w = weight.shape[2:4] 68 | height_out = (height + 2 * self.padding - 69 | (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1 70 | width_out = (width + 2 * self.padding - (self.dilation * 71 | (kernel_w - 1) + 1)) // self.stride + 1 72 | return (n, channels_out, height_out, width_out) 73 | 74 | 75 | class DeformRoIPoolingFunction(Function): 76 | 77 | def __init__(self, 78 | spatial_scale, 79 | pooled_size, 80 | output_dim, 81 | no_trans, 82 | group_size=1, 83 | part_size=None, 84 | sample_per_part=4, 85 | trans_std=.0): 86 | super(DeformRoIPoolingFunction, self).__init__() 87 | self.spatial_scale = spatial_scale 88 | self.pooled_size = pooled_size 89 | self.output_dim = output_dim 90 | self.no_trans = no_trans 91 | self.group_size = group_size 92 | self.part_size = pooled_size if part_size is None else part_size 93 | self.sample_per_part = sample_per_part 94 | self.trans_std = trans_std 95 | 96 | assert self.trans_std >= 0.0 and self.trans_std <= 1.0 97 | 98 | def forward(self, data, rois, offset): 99 | if not data.is_cuda: 100 | raise NotImplementedError 101 | 102 | output = data.new(*self._infer_shape(data, rois)) 103 | output_count = data.new(*self._infer_shape(data, rois)) 104 | _backend.deform_psroi_pooling_cuda_forward(data, rois, offset, 105 | output, output_count, 106 
| self.no_trans, self.spatial_scale, 107 | self.output_dim, self.group_size, 108 | self.pooled_size, self.part_size, 109 | self.sample_per_part, self.trans_std) 110 | 111 | # if data.requires_grad or rois.requires_grad or offset.requires_grad: 112 | # self.save_for_backward(data, rois, offset, output_count) 113 | self.data = data 114 | self.rois = rois 115 | self.offset = offset 116 | self.output_count = output_count 117 | 118 | return output 119 | 120 | def backward(self, grad_output): 121 | if not grad_output.is_cuda: 122 | raise NotImplementedError 123 | 124 | # data, rois, offset, output_count = self.saved_tensors 125 | data = self.data 126 | rois = self.rois 127 | offset = self.offset 128 | output_count = self.output_count 129 | grad_input = data.new(*data.size()).zero_() 130 | grad_offset = offset.new(*offset.size()).zero_() 131 | 132 | _backend.deform_psroi_pooling_cuda_backward(grad_output, 133 | data, 134 | rois, 135 | offset, 136 | output_count, 137 | grad_input, 138 | grad_offset, 139 | self.no_trans, 140 | self.spatial_scale, 141 | self.output_dim, 142 | self.group_size, 143 | self.pooled_size, 144 | self.part_size, 145 | self.sample_per_part, 146 | self.trans_std) 147 | return grad_input, torch.zeros(rois.shape).cuda(), grad_offset 148 | 149 | def _infer_shape(self, data, rois): 150 | # _, c, h, w = data.shape[:4] 151 | c = data.shape[1] 152 | n = rois.shape[0] 153 | return (n, self.output_dim, self.pooled_size, self.pooled_size) 154 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import DeformConv 2 | from .modulated_dcn import DeformRoIPooling, ModulatedDeformConv, ModulatedDeformConvPack, ModulatedDeformRoIPoolingPack -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.module import Module 6 | from torch.nn.modules.utils import _pair 7 | from functions import deform_conv_function 8 | 9 | 10 | class DeformConv(Module): 11 | def __init__(self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | stride=1, 16 | padding=0, 17 | dilation=1, 18 | num_deformable_groups=1): 19 | super(DeformConv, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _pair(kernel_size) 23 | self.stride = _pair(stride) 24 | self.padding = _pair(padding) 25 | self.dilation = _pair(dilation) 26 | self.num_deformable_groups = num_deformable_groups 27 | 28 | self.weight = nn.Parameter( 29 | torch.Tensor(out_channels, in_channels, *self.kernel_size)) 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | n = self.in_channels 35 | for k in self.kernel_size: 36 | n *= k 37 | stdv = 1. 
/ math.sqrt(n) 38 | self.weight.data.uniform_(-stdv, stdv) 39 | 40 | def forward(self, input, offset): 41 | return deform_conv_function(input, offset, self.weight, self.stride, 42 | self.padding, self.dilation, 43 | self.num_deformable_groups) 44 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/modules/modulated_dcn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import math 8 | from torch import nn 9 | from torch.nn.modules.utils import _pair 10 | 11 | from functions.modulated_dcn_func import ModulatedDeformConvFunction 12 | from functions.modulated_dcn_func import DeformRoIPoolingFunction 13 | 14 | class ModulatedDeformConv(nn.Module): 15 | 16 | def __init__(self, in_channels, out_channels, 17 | kernel_size, stride, padding, dilation=1, deformable_groups=1, no_bias=True): 18 | super(ModulatedDeformConv, self).__init__() 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.kernel_size = _pair(kernel_size) 22 | self.stride = stride 23 | self.padding = padding 24 | self.dilation = dilation 25 | self.deformable_groups = deformable_groups 26 | self.no_bias = no_bias 27 | 28 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) 29 | self.bias = nn.Parameter(torch.zeros(out_channels)) 30 | self.reset_parameters() 31 | if self.no_bias: 32 | self.bias.requires_grad = False 33 | 34 | def reset_parameters(self): 35 | n = self.in_channels 36 | for k in self.kernel_size: 37 | n *= k 38 | stdv = 1. / math.sqrt(n) 39 | self.weight.data.uniform_(-stdv, stdv) 40 | self.bias.data.zero_() 41 | 42 | def forward(self, input, offset, mask): 43 | func = ModulatedDeformConvFunction(self.stride, self.padding, self.dilation, self.deformable_groups) 44 | return func(input, offset, mask, self.weight, self.bias) 45 | 46 | 47 | class ModulatedDeformConvPack(ModulatedDeformConv): 48 | 49 | def __init__(self, in_channels, out_channels, 50 | kernel_size, stride, padding, 51 | dilation=1, deformable_groups=1, no_bias=False): 52 | super(ModulatedDeformConvPack, self).__init__(in_channels, out_channels, 53 | kernel_size, stride, padding, dilation, deformable_groups, no_bias) 54 | 55 | self.conv_offset_mask = nn.Conv2d(self.in_channels, 56 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 57 | kernel_size=self.kernel_size, 58 | stride=(self.stride, self.stride), 59 | padding=(self.padding, self.padding), 60 | bias=True) 61 | self.init_offset() 62 | 63 | def init_offset(self): 64 | self.conv_offset_mask.weight.data.zero_() 65 | self.conv_offset_mask.bias.data.zero_() 66 | 67 | def forward(self, input): 68 | out = self.conv_offset_mask(input) 69 | o1, o2, mask = torch.chunk(out, 3, dim=1) 70 | offset = torch.cat((o1, o2), dim=1) 71 | mask = torch.sigmoid(mask) 72 | func = ModulatedDeformConvFunction(self.stride, self.padding, self.dilation, self.deformable_groups) 73 | return func(input, offset, mask, self.weight, self.bias) 74 | 75 | 76 | class DeformRoIPooling(nn.Module): 77 | 78 | def __init__(self, 79 | spatial_scale, 80 | pooled_size, 81 | output_dim, 82 | no_trans, 83 | group_size=1, 84 | part_size=None, 85 | sample_per_part=4, 86 | trans_std=.0): 87 | super(DeformRoIPooling, self).__init__() 88 | self.spatial_scale = spatial_scale 89 | self.pooled_size = pooled_size 
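        # spatial_scale maps RoI coordinates from input-image space to feature-map space;
        # no_trans disables the learned per-part offsets (plain position-sensitive RoI pooling)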
90 | self.output_dim = output_dim 91 | self.no_trans = no_trans 92 | self.group_size = group_size 93 | self.part_size = pooled_size if part_size is None else part_size 94 | self.sample_per_part = sample_per_part 95 | self.trans_std = trans_std 96 | self.func = DeformRoIPoolingFunction(self.spatial_scale, 97 | self.pooled_size, 98 | self.output_dim, 99 | self.no_trans, 100 | self.group_size, 101 | self.part_size, 102 | self.sample_per_part, 103 | self.trans_std) 104 | 105 | def forward(self, data, rois, offset): 106 | 107 | if self.no_trans: 108 | offset = data.new() 109 | return self.func(data, rois, offset) 110 | 111 | class ModulatedDeformRoIPoolingPack(DeformRoIPooling): 112 | 113 | def __init__(self, 114 | spatial_scale, 115 | pooled_size, 116 | output_dim, 117 | no_trans, 118 | group_size=1, 119 | part_size=None, 120 | sample_per_part=4, 121 | trans_std=.0, 122 | deform_fc_dim=1024): 123 | super(ModulatedDeformRoIPoolingPack, self).__init__(spatial_scale, 124 | pooled_size, 125 | output_dim, 126 | no_trans, 127 | group_size, 128 | part_size, 129 | sample_per_part, 130 | trans_std) 131 | 132 | self.deform_fc_dim = deform_fc_dim 133 | 134 | if not no_trans: 135 | self.func_offset = DeformRoIPoolingFunction(self.spatial_scale, 136 | self.pooled_size, 137 | self.output_dim, 138 | True, 139 | self.group_size, 140 | self.part_size, 141 | self.sample_per_part, 142 | self.trans_std) 143 | self.offset_fc = nn.Sequential( 144 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), 145 | nn.ReLU(inplace=True), 146 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim), 147 | nn.ReLU(inplace=True), 148 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 2) 149 | ) 150 | self.offset_fc[4].weight.data.zero_() 151 | self.offset_fc[4].bias.data.zero_() 152 | self.mask_fc = nn.Sequential( 153 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), 154 | nn.ReLU(inplace=True), 155 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 1), 156 | nn.Sigmoid() 157 | ) 158 | self.mask_fc[2].weight.data.zero_() 159 | self.mask_fc[2].bias.data.zero_() 160 | 161 | def forward(self, data, rois): 162 | if self.no_trans: 163 | offset = data.new() 164 | else: 165 | n = rois.shape[0] 166 | offset = data.new() 167 | x = self.func_offset(data, rois, offset) 168 | offset = self.offset_fc(x.view(n, -1)) 169 | offset = offset.view(n, 2, self.pooled_size, self.pooled_size) 170 | mask = self.mask_fc(x.view(n, -1)) 171 | mask = mask.view(n, 1, self.pooled_size, self.pooled_size) 172 | feat = self.func(data, rois, offset) * mask 173 | return feat 174 | return self.func(data, rois, offset) 175 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/cuda/deform_psroi_pooling_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | * Copyright (c) 2017 Microsoft 3 | * Licensed under The MIT License [see LICENSE for details] 4 | * \file deformable_psroi_pooling.cu 5 | * \brief 6 | * \author Yi Li, Guodong Zhang, Jifeng Dai 7 | */ 8 | /***************** Adapted by Charles Shang *********************/ 9 | 10 | #ifndef DCN_V2_PSROI_POOLING_CUDA 11 | #define DCN_V2_PSROI_POOLING_CUDA 12 | 13 | #ifdef __cplusplus 14 | extern "C" 15 | { 16 | #endif 17 | 18 | void DeformablePSROIPoolForward(cudaStream_t stream, 19 | const float *data, 20 | const float *bbox, 21 | const float *trans, 22 | float *out, 23 | float *top_count, 24 | const int batch, 25 | const int channels, 26 | const int height, 27 | const int width, 28 | const int num_bbox, 29 | const int channels_trans, 30 | const int no_trans, 31 | const float spatial_scale, 32 | const int output_dim, 33 | const int group_size, 34 | const int pooled_size, 35 | const int part_size, 36 | const int sample_per_part, 37 | const float trans_std); 38 | 39 | void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, 40 | const float *out_grad, 41 | const float *data, 42 | const float *bbox, 43 | const float *trans, 44 | const float *top_count, 45 | float *in_grad, 46 | float *trans_grad, 47 | const int batch, 48 | const int channels, 49 | const int height, 50 | const int width, 51 | const int num_bbox, 52 | const int channels_trans, 53 | const int no_trans, 54 | const float spatial_scale, 55 | const int output_dim, 56 | const int group_size, 57 | const int pooled_size, 58 | const int part_size, 59 | const int sample_per_part, 60 | const float trans_std); 61 | 62 | #ifdef __cplusplus 63 | } 64 | #endif 65 | 66 | #endif -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/cuda/modulated_deform_im2col_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** 3 | * 4 | * COPYRIGHT 5 | * 6 | * All contributions by the University of California: 7 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 8 | * All rights reserved. 9 | * 10 | * All other contributions: 11 | * Copyright (c) 2014-2017, the respective contributors 12 | * All rights reserved. 13 | * 14 | * Caffe uses a shared copyright model: each contributor holds copyright over 15 | * their contributions to Caffe. The project versioning records all such 16 | * contribution and copyright details. If a contributor wants to further mark 17 | * their specific copyright on a particular contribution, they should indicate 18 | * their copyright solely in the commit message of the change when it is 19 | * committed. 20 | * 21 | * LICENSE 22 | * 23 | * Redistribution and use in source and binary forms, with or without 24 | * modification, are permitted provided that the following conditions are met: 25 | * 26 | * 1. Redistributions of source code must retain the above copyright notice, this 27 | * list of conditions and the following disclaimer. 28 | * 2. Redistributions in binary form must reproduce the above copyright notice, 29 | * this list of conditions and the following disclaimer in the documentation 30 | * and/or other materials provided with the distribution. 
31 | * 32 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 33 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 34 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 35 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 36 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 37 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 38 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 39 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 40 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 41 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 42 | * 43 | * CONTRIBUTION AGREEMENT 44 | * 45 | * By contributing to the BVLC/caffe repository through pull-request, comment, 46 | * or otherwise, the contributor releases their content to the 47 | * license and copyright terms herein. 48 | * 49 | ***************** END Caffe Copyright Notice and Disclaimer ******************** 50 | * 51 | * Copyright (c) 2018 Microsoft 52 | * Licensed under The MIT License [see LICENSE for details] 53 | * \file modulated_deformable_im2col.h 54 | * \brief Function definitions of converting an image to 55 | * column matrix based on kernel, padding, dilation, and offset. 56 | * These functions are mainly used in deformable convolution operators. 57 | * \ref: https://arxiv.org/abs/1811.11168 58 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu 59 | */ 60 | 61 | /***************** Adapted by Charles Shang *********************/ 62 | 63 | #ifndef DCN_V2_IM2COL_CUDA 64 | #define DCN_V2_IM2COL_CUDA 65 | 66 | #ifdef __cplusplus 67 | extern "C" 68 | { 69 | #endif 70 | 71 | void modulated_deformable_im2col_cuda(cudaStream_t stream, 72 | const float *data_im, const float *data_offset, const float *data_mask, 73 | const int batch_size, const int channels, const int height_im, const int width_im, 74 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 75 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 76 | const int dilation_h, const int dilation_w, 77 | const int deformable_group, float *data_col); 78 | 79 | void modulated_deformable_col2im_cuda(cudaStream_t stream, 80 | const float *data_col, const float *data_offset, const float *data_mask, 81 | const int batch_size, const int channels, const int height_im, const int width_im, 82 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 84 | const int dilation_h, const int dilation_w, 85 | const int deformable_group, float *grad_im); 86 | 87 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, 88 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, 89 | const int batch_size, const int channels, const int height_im, const int width_im, 90 | const int height_col, const int width_col, const int kernel_h, const int kenerl_w, 91 | const int pad_h, const int pad_w, const int stride_h, const int stride_w, 92 | const int dilation_h, const int dilation_w, 93 | const int deformable_group, 94 | float *grad_offset, float *grad_mask); 95 | 96 | #ifdef __cplusplus 97 | } 98 | #endif 99 | 100 | #endif 
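The header above declares the im2col/col2im CUDA kernels that back the modulated deformable convolution (DCNv2) wrapped by dcn2/modules/modulated_dcn.py. As a minimal, hedged sketch of how the Python-side wrapper is exercised once the extension has been built (the module path, the legacy Function-style PyTorch API this code targets, and the need for a CUDA device are assumptions based on the sources shown above):

import torch
from modules.modulated_dcn import ModulatedDeformConvPack

# The "pack" variant predicts its own offsets and masks via conv_offset_mask, whose
# output has deformable_groups * 3 * kh * kw channels (2/3 offsets, 1/3 mask).
dcn = ModulatedDeformConvPack(in_channels=64, out_channels=64, kernel_size=3,
                              stride=1, padding=1, deformable_groups=1).cuda()
x = torch.randn(2, 64, 32, 32, device="cuda")  # CPU inputs raise NotImplementedError
y = dcn(x)
print(y.shape)  # expected: torch.Size([2, 64, 32, 32])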
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/deform_conv.c: --------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | 
3 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset,
4 | THFloatTensor *output)
5 | {
6 | // if (!THFloatTensor_isSameSizeAs(input1, input2))
7 | // return 0;
8 | // THFloatTensor_resizeAs(output, input);
9 | // THFloatTensor_cadd(output, input1, 1.0, input2);
10 | return 1;
11 | }
12 | 
13 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input,
14 | THFloatTensor *grad_offset)
15 | {
16 | // THFloatTensor_resizeAs(grad_input, grad_output);
17 | // THFloatTensor_fill(grad_input, 1);
18 | return 1;
19 | }
20 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/deform_conv.h: --------------------------------------------------------------------------------
1 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset,
2 | THFloatTensor *output);
3 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input,
4 | THFloatTensor *grad_offset);
5 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/deform_conv_cuda.h: --------------------------------------------------------------------------------
1 | int deform_conv_forward_cuda(THCudaTensor *input,
2 | THCudaTensor *weight, /*THCudaTensor * bias, */
3 | THCudaTensor *offset, THCudaTensor *output,
4 | THCudaTensor *columns, THCudaTensor *ones, int kW,
5 | int kH, int dW, int dH, int padW, int padH,
6 | int dilationW, int dilationH,
7 | int deformable_group, int im2col_step);
8 | 
9 | int deform_conv_backward_input_cuda(
10 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
11 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight,
12 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH,
13 | int dilationW, int dilationH, int deformable_group, int im2col_step);
14 | 
15 | int deform_conv_backward_parameters_cuda(
16 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
17 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */
18 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH,
19 | int padW, int padH, int dilationW, int dilationH, int deformable_group,
20 | float scale, int im2col_step);
21 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/deform_conv_cuda_kernel.h: --------------------------------------------------------------------------------
1 | template <typename DType>
2 | void deformable_im2col(cudaStream_t stream, const DType *data_im,
3 | const DType *data_offset, const int channels,
4 | const int height, const int width, const int ksize_h,
5 | const int ksize_w, const int pad_h, const int pad_w,
6 | const int stride_h, const int stride_w,
7 | const int dilation_h, const int dilation_w,
8 | const int parallel_imgs,
9 | const int deformable_group, DType *data_col);
10 | 
11 | template <typename DType>
12 | void deformable_col2im(cudaStream_t stream, const DType *data_col,
13 | const DType *data_offset, const int channels,
14 | const int height, const int width, const int ksize_h,
15 | const int ksize_w, const int pad_h, const int pad_w,
16 | const int stride_h, const int stride_w,
17 | const int dilation_h, const int dilation_w,
18 | const int parallel_imgs,
19 | const int deformable_group, DType *grad_im);
20 | 
21 | template <typename DType>
22 | void deformable_col2im_coord(cudaStream_t stream, const DType *data_col,
23 | const DType *data_im, const DType *data_offset,
24 | const int channels, const int height,
25 | const int width, const int ksize_h,
26 | const int ksize_w, const int pad_h,
27 | const int pad_w, const int stride_h,
28 | const int stride_w, const int dilation_h,
29 | const int dilation_w, const int parallel_imgs,
30 | const int deformable_group, DType *grad_offset);
31 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/modulated_dcn.c: --------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | #include <stdio.h>
3 | #include <math.h>
4 | 
5 | void modulated_deform_conv_forward(THFloatTensor *input, THFloatTensor *weight,
6 | THFloatTensor *bias, THFloatTensor *ones,
7 | THFloatTensor *offset, THFloatTensor *mask,
8 | THFloatTensor *output, THFloatTensor *columns,
9 | const int pad_h, const int pad_w,
10 | const int stride_h, const int stride_w,
11 | const int dilation_h, const int dilation_w,
12 | const int deformable_group)
13 | {
14 | printf("only implemented on GPU");
15 | }
16 | void modulated_deform_conv_backward(THFloatTensor *input, THFloatTensor *weight,
17 | THFloatTensor *bias, THFloatTensor *ones,
18 | THFloatTensor *offset, THFloatTensor *mask,
19 | THFloatTensor *output, THFloatTensor *columns,
20 | THFloatTensor *grad_input, THFloatTensor *grad_weight,
21 | THFloatTensor *grad_bias, THFloatTensor *grad_offset,
22 | THFloatTensor *grad_mask, THFloatTensor *grad_output,
23 | int kernel_h, int kernel_w,
24 | int stride_h, int stride_w,
25 | int pad_h, int pad_w,
26 | int dilation_h, int dilation_w,
27 | int deformable_group)
28 | {
29 | printf("only implemented on GPU");
30 | }
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/modulated_dcn.h: --------------------------------------------------------------------------------
1 | void modulated_deform_conv_forward(THFloatTensor *input, THFloatTensor *weight,
2 | THFloatTensor *bias, THFloatTensor *ones,
3 | THFloatTensor *offset, THFloatTensor *mask,
4 | THFloatTensor *output, THFloatTensor *columns,
5 | const int pad_h, const int pad_w,
6 | const int stride_h, const int stride_w,
7 | const int dilation_h, const int dilation_w,
8 | const int deformable_group);
9 | void modulated_deform_conv_backward(THFloatTensor *input, THFloatTensor *weight,
10 | THFloatTensor *bias, THFloatTensor *ones,
11 | THFloatTensor *offset, THFloatTensor *mask,
12 | THFloatTensor *output, THFloatTensor *columns,
13 | THFloatTensor *grad_input, THFloatTensor *grad_weight,
14 | THFloatTensor *grad_bias, THFloatTensor *grad_offset,
15 | THFloatTensor *grad_mask, THFloatTensor *grad_output,
16 | int kernel_h, int kernel_w,
17 | int stride_h, int stride_w,
18 | int pad_h, int pad_w,
19 | int dilation_h, int dilation_w,
20 | int deformable_group);
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/src/modulated_dcn_cuda.h: --------------------------------------------------------------------------------
1 | // #ifndef DCN_V2_CUDA
2 | // #define DCN_V2_CUDA
3 | 
4 | // #ifdef __cplusplus
5 | // extern "C"
6 | // {
7 | // #endif
8 | 
9 | void modulated_deform_conv_cuda_forward(THCudaTensor *input, THCudaTensor *weight,
10 | THCudaTensor *bias, THCudaTensor *ones,
11 | THCudaTensor *offset, THCudaTensor *mask,
12 | THCudaTensor *output, THCudaTensor
*columns, 13 | int kernel_h, int kernel_w, 14 | const int stride_h, const int stride_w, 15 | const int pad_h, const int pad_w, 16 | const int dilation_h, const int dilation_w, 17 | const int deformable_group); 18 | void modulated_deform_conv_cuda_backward(THCudaTensor *input, THCudaTensor *weight, 19 | THCudaTensor *bias, THCudaTensor *ones, 20 | THCudaTensor *offset, THCudaTensor *mask, 21 | THCudaTensor *columns, 22 | THCudaTensor *grad_input, THCudaTensor *grad_weight, 23 | THCudaTensor *grad_bias, THCudaTensor *grad_offset, 24 | THCudaTensor *grad_mask, THCudaTensor *grad_output, 25 | int kernel_h, int kernel_w, 26 | int stride_h, int stride_w, 27 | int pad_h, int pad_w, 28 | int dilation_h, int dilation_w, 29 | int deformable_group); 30 | 31 | void deform_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox, 32 | THCudaTensor * trans, 33 | THCudaTensor * out, THCudaTensor * top_count, 34 | const int no_trans, 35 | const float spatial_scale, 36 | const int output_dim, 37 | const int group_size, 38 | const int pooled_size, 39 | const int part_size, 40 | const int sample_per_part, 41 | const float trans_std); 42 | 43 | void deform_psroi_pooling_cuda_backward(THCudaTensor * out_grad, 44 | THCudaTensor * input, THCudaTensor * bbox, 45 | THCudaTensor * trans, THCudaTensor * top_count, 46 | THCudaTensor * input_grad, THCudaTensor * trans_grad, 47 | const int no_trans, 48 | const float spatial_scale, 49 | const int output_dim, 50 | const int group_size, 51 | const int pooled_size, 52 | const int part_size, 53 | const int sample_per_part, 54 | const float trans_std); 55 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/test.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/dcn2/test_modulated.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/test_modulated.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/images/CELL_DETR.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/images/CELL_DETR.PNG -------------------------------------------------------------------------------- /lib/Cell_DETR_master/main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import setproctitle 4 | 5 | # Manage command line arguments 6 | parser = ArgumentParser() 7 | parser.add_argument("--train", default=False, action="store_true", 8 | help="Binary flag. If set training will be performed.") 9 | parser.add_argument("--val", default=False, action="store_true", 10 | help="Binary flag. If set validation will be performed.") 11 | parser.add_argument("--test", default=False, action="store_true", 12 | help="Binary flag. If set testing will be performed.") 13 | parser.add_argument("--cuda_devices", default="0", type=str, 14 | help="String of cuda device indexes to be used. 
Indexes must be separated by a comma.")
15 | parser.add_argument("--data_parallel", default=False, action="store_true",
16 | help="Binary flag. If set, multi-GPU training will be utilized.")
17 | parser.add_argument("--cpu", default=False, action="store_true",
18 | help="Binary flag. If set, all operations are performed on the CPU.")
19 | parser.add_argument("--epochs", default=200, type=int,
20 | help="Number of epochs to perform while training.")
21 | parser.add_argument("--lr_schedule", default=False, action="store_true",
22 | help="Binary flag. If set, the learning rate will be reduced after epochs 50 and 100.")
23 | parser.add_argument("--ohem", default=False, action="store_true",
24 | help="Binary flag. If set, online hard example mining is utilized.")
25 | parser.add_argument("--ohem_fraction", default=0.75, type=float,
26 | help="OHEM fraction to be applied when performing OHEM.")
27 | parser.add_argument("--batch_size", default=4, type=int,
28 | help="Batch size to be utilized while training.")
29 | parser.add_argument("--path_to_data", default="../../BCS_Data/Cell_Instance_Segmentation_Regular_Traps", type=str,
30 | help="Path to dataset.")
31 | parser.add_argument("--augmentation_p", default=0.6, type=float,
32 | help="Probability that data augmentation is applied to a training data sample.")
33 | parser.add_argument("--lr_main", default=1e-04, type=float,
34 | help="Learning rate of the DETR model (excluding backbone).")
35 | parser.add_argument("--lr_backbone", default=1e-05, type=float,
36 | help="Learning rate of the backbone network.")
37 | parser.add_argument("--lr_segmentation_head", default=1e-06, type=float,
38 | help="Learning rate of the segmentation head, only applied when the seg head is trained exclusively.")
39 | parser.add_argument("--no_pac", default=False, action="store_true",
40 | help="Binary flag. If set, no pixel adaptive convolutions will be utilized in the segmentation head.")
41 | parser.add_argument("--load_model", default="", type=str,
42 | help="Path to model to be loaded.")
43 | parser.add_argument("--dropout", default=0.05, type=float,
44 | help="Dropout factor to be used in the model.")
45 | parser.add_argument("--three_classes", default=False, action="store_true",
46 | help="Binary flag. If set, three classes (trap, cell of interest and add. cells) will be utilized.")
47 | parser.add_argument("--softmax", default=False, action="store_true",
48 | help="Binary flag. If set, a softmax will be applied to the segmentation prediction instead of a sigmoid.")
49 | parser.add_argument("--only_train_segmentation_head_after_epoch", default=150, type=int,
50 | help="Epoch after which only the segmentation head is trained.")
51 | parser.add_argument("--no_deform_conv", default=False, action="store_true",
52 | help="Binary flag. If set, no deformable convolutions will be utilized.")
53 | parser.add_argument("--no_pau", default=False, action="store_true",
54 | help="Binary flag.
If set no pade activation unit is utilized, however, a leaky ReLU is utilized.") 55 | 56 | # Get arguments 57 | args = parser.parse_args() 58 | 59 | # Set device type 60 | device = "cpu" if args.cpu else "cuda" 61 | 62 | # Set cuda devices 63 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices 64 | 65 | setproctitle.setproctitle("Cell-DETR") 66 | 67 | import torch 68 | import torch.nn as nn 69 | from torch.utils.data import DataLoader 70 | from modules.modulated_deform_conv import ModulatedDeformConvPack 71 | from pade_activation_unit.utils import PAU 72 | 73 | # Avoid data loader bug 74 | import resource 75 | 76 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 77 | resource.setrlimit(resource.RLIMIT_NOFILE, (2 ** 12, rlimit[1])) 78 | 79 | from detr import CellDETR 80 | from dataset import CellInstanceSegmentation, collate_function_cell_instance_segmentation 81 | from lossfunction import InstanceSegmentationLoss, SegmentationLoss, MultiClassSegmentationLoss 82 | from model_wrapper import ModelWrapper 83 | from segmentation import ResFeaturePyramidBlock, ResPACFeaturePyramidBlock 84 | 85 | if __name__ == '__main__': 86 | # Init detr 87 | detr = CellDETR(num_classes=3 if args.three_classes else 2, 88 | segmentation_head_block=ResPACFeaturePyramidBlock if not args.no_pac else ResFeaturePyramidBlock, 89 | segmentation_head_final_activation=nn.Softmax if args.softmax else nn.Sigmoid, 90 | backbone_convolution=nn.Conv2d if args.no_deform_conv else ModulatedDeformConvPack, 91 | segmentation_head_convolution=nn.Conv2d if args.no_deform_conv else ModulatedDeformConvPack, 92 | transformer_activation=nn.LeakyReLU if args.no_pau else PAU, 93 | backbone_activation=nn.LeakyReLU if args.no_pau else PAU, 94 | bounding_box_head_activation=nn.LeakyReLU if args.no_pau else PAU, 95 | classification_head_activation=nn.LeakyReLU if args.no_pau else PAU, 96 | segmentation_head_activation=nn.LeakyReLU if args.no_pau else PAU) 97 | if args.load_model != "": 98 | detr.load_state_dict(torch.load(args.load_model)) 99 | # Print network 100 | print(detr) 101 | # Print number of parameters 102 | print("# DETR parameters", sum([p.numel() for p in detr.parameters()])) 103 | # Init optimizer 104 | detr_optimizer = torch.optim.AdamW(detr.get_parameters(lr_main=args.lr_main, lr_backbone=args.lr_backbone), 105 | weight_decay=1e-06) 106 | detr_segmentation_optimizer = torch.optim.AdamW(detr.get_segmentation_head_parameters(lr=args.lr_segmentation_head), 107 | weight_decay=1e-06) 108 | # Init data parallel if utilized 109 | if args.data_parallel: 110 | detr = torch.nn.DataParallel(detr) 111 | # Init learning rate schedule if utilized 112 | if args.lr_schedule: 113 | learning_rate_schedule = torch.optim.lr_scheduler.MultiStepLR(detr_optimizer, milestones=[50, 100], gamma=0.1) 114 | else: 115 | learning_rate_schedule = None 116 | # Init datasets 117 | training_dataset = DataLoader( 118 | CellInstanceSegmentation(path=os.path.join(args.path_to_data, "train"), 119 | augmentation_p=args.augmentation_p, two_classes=not args.three_classes), 120 | collate_fn=collate_function_cell_instance_segmentation, batch_size=args.batch_size, num_workers=20, 121 | shuffle=True) 122 | validation_dataset = DataLoader( 123 | CellInstanceSegmentation(path=os.path.join(args.path_to_data, "val"), 124 | augmentation_p=0.0, two_classes=not args.three_classes), 125 | collate_fn=collate_function_cell_instance_segmentation, batch_size=1, num_workers=1, shuffle=False) 126 | test_dataset = DataLoader( 127 | 
CellInstanceSegmentation(path=os.path.join(args.path_to_data, "test"),
128 | augmentation_p=0.0, two_classes=not args.three_classes),
129 | collate_fn=collate_function_cell_instance_segmentation, batch_size=1, num_workers=1, shuffle=False)
130 | # Model wrapper
131 | model_wrapper = ModelWrapper(detr=detr,
132 | detr_optimizer=detr_optimizer,
133 | detr_segmentation_optimizer=detr_segmentation_optimizer,
134 | training_dataset=training_dataset,
135 | validation_dataset=validation_dataset,
136 | test_dataset=test_dataset,
137 | loss_function=InstanceSegmentationLoss(
138 | segmentation_loss=SegmentationLoss(),
139 | ohem=args.ohem,
140 | ohem_faction=args.ohem_fraction),
141 | device=device)
142 | # Perform training
143 | if args.train:
144 | model_wrapper.train(epochs=args.epochs,
145 | optimize_only_segmentation_head_after_epoch=args.only_train_segmentation_head_after_epoch)
146 | # Perform validation
147 | if args.val:
148 | model_wrapper.validate(number_of_plots=30)
149 | # Perform testing
150 | if args.test:
151 | model_wrapper.test()
152 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/matcher.py: --------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from scipy.optimize import linear_sum_assignment
6 | 
7 | import misc
8 | 
9 | 
10 | class HungarianMatcher(nn.Module):
11 | """
12 | This class implements a Hungarian algorithm based matcher for DETR.
13 | """
14 | 
15 | def __init__(self, weight_classification: float = 1.0,
16 | weight_bb_l1: float = 1.0,
17 | weight_bb_giou: float = 1.0) -> None:
18 | # Call super constructor
19 | super(HungarianMatcher, self).__init__()
20 | # Save parameters
21 | self.weight_classification = weight_classification
22 | self.weight_bb_l1 = weight_bb_l1
23 | self.weight_bb_giou = weight_bb_giou
24 | 
25 | def __repr__(self):
26 | """
27 | Get representation of the matcher module
28 | :return: (str) String including information
29 | """
30 | return "{}, W classification:{}, W BB L1:{}, W BB gIoU:{}".format(self.__class__.__name__,
31 | self.weight_classification, self.weight_bb_l1,
32 | self.weight_bb_giou)
33 | 
34 | @torch.no_grad()
35 | def forward(self, prediction_classification: torch.Tensor,
36 | prediction_bounding_box: torch.Tensor,
37 | label_classification: Tuple[torch.Tensor],
38 | label_bounding_box: Tuple[torch.Tensor]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
39 | """
40 | Forward pass computes the matching produced by the Hungarian algorithm.
41 | :param prediction_classification: (torch.Tensor) Classification prediction (batch size, # queries, classes + 1)
42 | :param prediction_bounding_box: (torch.Tensor) BB predictions (batch size, # queries, 4)
43 | :param label_classification: (Tuple[torch.Tensor]) Classification label batched [(instances, classes + 1)]
44 | :param label_bounding_box: (Tuple[torch.Tensor]) BB label batched [(instances, 4)]
45 | :return: (List[Tuple[torch.Tensor, torch.Tensor]]) Matched (prediction index, label index) pairs, one tuple per batch element
46 | """
47 | # Save shapes
48 | batch_size, number_of_queries = prediction_classification.shape[:2]
49 | # Get number of instances in each training sample
50 | number_of_instances = [label_bounding_box_instance.shape[0] for label_bounding_box_instance in
51 | label_bounding_box]
52 | # Flatten to shape [batch size * # queries, classes + 1]
53 | prediction_classification = prediction_classification.flatten(start_dim=0, end_dim=1)
54 | # Flatten to shape [batch size * # queries, 4]
55 | prediction_bounding_box = prediction_bounding_box.flatten(start_dim=0, end_dim=1)
56 | # Class label to index
57 | # Concat labels
58 | label_classification = torch.cat([instance.argmax(dim=-1) for instance in label_classification], dim=0)
59 | label_bounding_box = torch.cat([instance for instance in label_bounding_box], dim=0)
60 | # Compute classification cost
61 | cost_classification = -prediction_classification[:, label_classification.long()]
62 | # Compute the L1 cost of bounding boxes
63 | cost_bounding_boxes_l1 = torch.cdist(prediction_bounding_box, label_bounding_box, p=1)
64 | # Compute gIoU cost of bounding boxes
65 | cost_bounding_boxes_giou = -misc.giou_for_matching(
66 | misc.bounding_box_xcycwh_to_x0y0x1y1(prediction_bounding_box),
67 | misc.bounding_box_xcycwh_to_x0y0x1y1(label_bounding_box))
68 | # Construct cost matrix
69 | cost_matrix = self.weight_classification * cost_classification \
70 | + self.weight_bb_l1 * cost_bounding_boxes_l1 \
71 | + self.weight_bb_giou * cost_bounding_boxes_giou
72 | cost_matrix = cost_matrix.view(batch_size, number_of_queries, -1).cpu().clamp(min=-1e20, max=1e20)
73 | # Get optimal indexes
74 | indexes = [linear_sum_assignment(cost_vector[index]) for index, cost_vector in
75 | enumerate(cost_matrix.split(number_of_instances, dim=-1))]
76 | # Convert indexes to list of prediction index and label index
77 | return [(torch.as_tensor(index_prediction, dtype=torch.int), torch.as_tensor(index_label, dtype=torch.int)) for
78 | index_prediction, index_label in indexes]
79 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/pade_activation_unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/__init__.py
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/pade_activation_unit/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/__pycache__/__init__.cpython-37.pyc
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/pade_activation_unit/__pycache__/utils.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/__pycache__/utils.cpython-37.pyc
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/pade_activation_unit/cuda/pau_cuda.cpp: --------------------------------------------------------------------------------
1 | 
2 | #include <torch/extension.h>
3 | #include <vector>
4 | #include <iostream>
5 | 
6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
9 | 
10 | 
11 | at::Tensor pau_cuda_forward_3_3(torch::Tensor x, torch::Tensor n, torch::Tensor d);
12 | std::vector<torch::Tensor> pau_cuda_backward_3_3(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
13 | 
14 | at::Tensor pau_cuda_forward_4_4(torch::Tensor x, torch::Tensor n, torch::Tensor d);
15 | std::vector<torch::Tensor> pau_cuda_backward_4_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
16 | 
17 | at::Tensor pau_cuda_forward_5_5(torch::Tensor x, torch::Tensor n, torch::Tensor d);
18 | std::vector<torch::Tensor> pau_cuda_backward_5_5(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
19 | 
20 | at::Tensor pau_cuda_forward_6_6(torch::Tensor x, torch::Tensor n, torch::Tensor d);
21 | std::vector<torch::Tensor> pau_cuda_backward_6_6(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
22 | 
23 | at::Tensor pau_cuda_forward_7_7(torch::Tensor x, torch::Tensor n, torch::Tensor d);
24 | std::vector<torch::Tensor> pau_cuda_backward_7_7(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
25 | 
26 | at::Tensor pau_cuda_forward_8_8(torch::Tensor x, torch::Tensor n, torch::Tensor d);
27 | std::vector<torch::Tensor> pau_cuda_backward_8_8(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
28 | 
29 | at::Tensor pau_cuda_forward_5_4(torch::Tensor x, torch::Tensor n, torch::Tensor d);
30 | std::vector<torch::Tensor> pau_cuda_backward_5_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
31 | 
32 | 
33 | at::Tensor pau_forward__3_3(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
34 | CHECK_INPUT(x);
35 | CHECK_INPUT(n);
36 | CHECK_INPUT(d);
37 | 
38 | return pau_cuda_forward_3_3(x, n, d);
39 | }
40 | std::vector<torch::Tensor> pau_backward__3_3(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
41 | CHECK_INPUT(grad_output);
42 | CHECK_INPUT(x);
43 | CHECK_INPUT(n);
44 | CHECK_INPUT(d);
45 | 
46 | return pau_cuda_backward_3_3(grad_output, x, n, d);
47 | }
48 | 
49 | at::Tensor pau_forward__4_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
50 | CHECK_INPUT(x);
51 | CHECK_INPUT(n);
52 | CHECK_INPUT(d);
53 | 
54 | return pau_cuda_forward_4_4(x, n, d);
55 | }
56 | std::vector<torch::Tensor> pau_backward__4_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
57 | CHECK_INPUT(grad_output);
58 | CHECK_INPUT(x);
59 | CHECK_INPUT(n);
60 | CHECK_INPUT(d);
61 | 
62 | return pau_cuda_backward_4_4(grad_output, x, n, d);
63 | }
64 | 
65 | at::Tensor pau_forward__5_5(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
66 | CHECK_INPUT(x);
67 | CHECK_INPUT(n);
68 | CHECK_INPUT(d);
69 | 
70 | return pau_cuda_forward_5_5(x, n, d);
71 | }
72 | std::vector<torch::Tensor> pau_backward__5_5(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
73 | CHECK_INPUT(grad_output);
74 | CHECK_INPUT(x);
75 | CHECK_INPUT(n);
76 | CHECK_INPUT(d);
77 | 
78 | return pau_cuda_backward_5_5(grad_output, x, n, d);
79 | }
80 | 
81 | at::Tensor pau_forward__6_6(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
82 | CHECK_INPUT(x);
83 | CHECK_INPUT(n);
84 | CHECK_INPUT(d);
85 | 
86 | return pau_cuda_forward_6_6(x, n, d);
87 | }
88 | std::vector<torch::Tensor> pau_backward__6_6(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
89 | CHECK_INPUT(grad_output);
90 | CHECK_INPUT(x);
91 | CHECK_INPUT(n);
92 | CHECK_INPUT(d);
93 | 
94 | return pau_cuda_backward_6_6(grad_output, x, n, d);
95 | }
96 | 
97 | at::Tensor pau_forward__7_7(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
98 | CHECK_INPUT(x);
99 | CHECK_INPUT(n);
100 | CHECK_INPUT(d);
101 | 
102 | return pau_cuda_forward_7_7(x, n, d);
103 | }
104 | std::vector<torch::Tensor> pau_backward__7_7(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
105 | CHECK_INPUT(grad_output);
106 | CHECK_INPUT(x);
107 | CHECK_INPUT(n);
108 | CHECK_INPUT(d);
109 | 
110 | return pau_cuda_backward_7_7(grad_output, x, n, d);
111 | }
112 | 
113 | at::Tensor pau_forward__8_8(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
114 | CHECK_INPUT(x);
115 | CHECK_INPUT(n);
116 | CHECK_INPUT(d);
117 | 
118 | return pau_cuda_forward_8_8(x, n, d);
119 | }
120 | std::vector<torch::Tensor> pau_backward__8_8(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
121 | CHECK_INPUT(grad_output);
122 | CHECK_INPUT(x);
123 | CHECK_INPUT(n);
124 | CHECK_INPUT(d);
125 | 
126 | return pau_cuda_backward_8_8(grad_output, x, n, d);
127 | }
128 | 
129 | at::Tensor pau_forward__5_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
130 | CHECK_INPUT(x);
131 | CHECK_INPUT(n);
132 | CHECK_INPUT(d);
133 | 
134 | return pau_cuda_forward_5_4(x, n, d);
135 | }
136 | std::vector<torch::Tensor> pau_backward__5_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
137 | CHECK_INPUT(grad_output);
138 | CHECK_INPUT(x);
139 | CHECK_INPUT(n);
140 | CHECK_INPUT(d);
141 | 
142 | return pau_cuda_backward_5_4(grad_output, x, n, d);
143 | }
144 | 
145 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
146 | 
147 | m.def("forward_3_3", &pau_forward__3_3, "PAU forward _3_3");
148 | m.def("backward_3_3", &pau_backward__3_3, "PAU backward _3_3");
149 | 
150 | m.def("forward_4_4", &pau_forward__4_4, "PAU forward _4_4");
151 | m.def("backward_4_4", &pau_backward__4_4, "PAU backward _4_4");
152 | 
153 | m.def("forward_5_5", &pau_forward__5_5, "PAU forward _5_5");
154 | m.def("backward_5_5", &pau_backward__5_5, "PAU backward _5_5");
155 | 
156 | m.def("forward_6_6", &pau_forward__6_6, "PAU forward _6_6");
157 | m.def("backward_6_6", &pau_backward__6_6, "PAU backward _6_6");
158 | 
159 | m.def("forward_7_7", &pau_forward__7_7, "PAU forward _7_7");
160 | m.def("backward_7_7", &pau_backward__7_7, "PAU backward _7_7");
161 | 
162 | m.def("forward_8_8", &pau_forward__8_8, "PAU forward _8_8");
163 | m.def("backward_8_8", &pau_backward__8_8, "PAU backward _8_8");
164 | 
165 | m.def("forward_5_4", &pau_forward__5_4, "PAU forward _5_4");
166 | m.def("backward_5_4", &pau_backward__5_4, "PAU backward _5_4");
167 | }
168 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/pade_activation_unit/cuda/python_imp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/cuda/python_imp/__init__.py
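Each forward_m_n binding above evaluates a rational (Pade) activation P(x)/Q(x) with numerator degree m and denominator degree n; the dedicated per-order kernels allow the fixed-degree polynomial evaluation to be fully unrolled on the GPU. As a shape-and-math reference, here is a minimal pure-PyTorch sketch of the safe PAU of order (5, 4); the safe-denominator convention Q(x) = 1 + |b1*x + ... + b4*x^4| follows the PAU paper and is an assumption here, not read off this repo's kernels.

import torch

def pau_reference_5_4(x, numerator_coeffs, denominator_coeffs):
    # Safe Pade Activation Unit of order (5, 4), per the PAU paper (assumed):
    #   P(x) = a0 + a1*x + ... + a5*x^5
    #   Q(x) = 1 + |b1*x + b2*x^2 + b3*x^3 + b4*x^4|
    p_powers = torch.stack([x ** k for k in range(6)], dim=-1)      # (..., 6)
    numerator = (p_powers * numerator_coeffs).sum(dim=-1)
    q_powers = torch.stack([x ** k for k in range(1, 5)], dim=-1)   # (..., 4)
    denominator = 1.0 + (q_powers * denominator_coeffs).sum(dim=-1).abs()
    return numerator / denominator

x = torch.linspace(-3.0, 3.0, steps=7)
out = pau_reference_5_4(x, torch.randn(6), torch.randn(4))
print(out.shape)  # torch.Size([7])

Unlike a fixed ReLU, the coefficients are learnable and shared across a layer, which is why both a forward and a coefficient-gradient backward kernel are exported per order.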
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/pade_activation_unit/torchsummary.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/torchsummary.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pade_activation_unit/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/utils.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/README.md -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/__pycache__/pac.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/__pycache__/pac.cpython-37.pyc -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/pac.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/pac.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/paccrf.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/paccrf.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/requirements.txt -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/test_pac.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/test_pac.py -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/tools/flowlib.py: -------------------------------------------------------------------------------- 1 | """ Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
2 | This file incorporates work covered by the following copyright and permission notice: 3 | 4 | Copyright (c) 2019 LI RUOTENG 5 | 6 | Permission to use, copy, modify, and/or distribute this software 7 | for any purpose with or without fee is hereby granted, provided 8 | that the above copyright notice and this permission notice appear 9 | in all copies. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL 12 | WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED 13 | WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE 14 | AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR 15 | CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS 16 | OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, 17 | NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 18 | CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 19 | """ 20 | from __future__ import division 21 | 22 | import numpy as np 23 | import matplotlib.pyplot as plt 24 | UNKNOWN_FLOW_THRESH = 1e7 25 | 26 | 27 | def evaluate_flow(gt, pred): 28 | """ 29 | evaluate the estimated optical flow end point error according to ground truth provided 30 | :param gt: ground truth file path 31 | :param pred: estimated optical flow file path 32 | :return: end point error, float32 33 | """ 34 | # Read flow files and calculate the errors 35 | gt_flow = read_flow(gt) # ground truth flow 36 | eva_flow = read_flow(pred) # predicted flow 37 | # Calculate errors 38 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1]) 39 | return average_pe 40 | 41 | 42 | def show_flow(filename): 43 | """ 44 | visualize optical flow map using matplotlib 45 | :param filename: optical flow file 46 | :return: None 47 | """ 48 | flow = read_flow(filename) 49 | img = flow_to_image(flow) 50 | plt.imshow(img) 51 | plt.show() 52 | 53 | 54 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 55 | def read_flow(filename): 56 | """ 57 | read optical flow in Middlebury .flo file format 58 | :param filename: 59 | :return: 60 | """ 61 | f = open(filename, 'rb') 62 | magic = np.fromfile(f, np.float32, count=1) 63 | data2d = None 64 | 65 | if 202021.25 != magic: 66 | print('Magic number incorrect. Invalid .flo file') 67 | else: 68 | w = np.fromfile(f, np.int32, count=1)[0] 69 | h = np.fromfile(f, np.int32, count=1)[0] 70 | data2d = np.fromfile(f, np.float32, count=2 * w * h) 71 | # reshape data into 3D array (columns, rows, channels) 72 | data2d = np.reshape(data2d, (h, w, 2)) 73 | f.close() 74 | return data2d 75 | 76 | 77 | # WARNING: this will work on little-endian architectures only! 
78 | def write_flow(flow, filename): 79 | """ 80 | write optical flow in Middlebury .flo format 81 | :param flow: optical flow map 82 | :param filename: optical flow file path to be saved 83 | :return: None 84 | """ 85 | f = open(filename, 'wb') 86 | magic = np.array([202021.25], dtype=np.float32) 87 | (height, width) = flow.shape 88 | w = np.array([width], dtype=np.int32) 89 | h = np.array([height], dtype=np.int32) 90 | empty_map = np.zeros((height, width), dtype=np.float32) 91 | data = np.dstack((flow, empty_map)) 92 | magic.tofile(f) 93 | w.tofile(f) 94 | h.tofile(f) 95 | data.tofile(f) 96 | f.close() 97 | 98 | 99 | def flow_error(tu, tv, u, v): 100 | """ 101 | Calculate average end point error 102 | :param tu: ground-truth horizontal flow map 103 | :param tv: ground-truth vertical flow map 104 | :param u: estimated horizontal flow map 105 | :param v: estimated vertical flow map 106 | :return: End point error of the estimated flow 107 | """ 108 | smallflow = 0.0 109 | ''' 110 | stu = tu[bord+1:end-bord,bord+1:end-bord] 111 | stv = tv[bord+1:end-bord,bord+1:end-bord] 112 | su = u[bord+1:end-bord,bord+1:end-bord] 113 | sv = v[bord+1:end-bord,bord+1:end-bord] 114 | ''' 115 | stu = tu[:] 116 | stv = tv[:] 117 | su = u[:] 118 | sv = v[:] 119 | 120 | idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH) 121 | stu[idxUnknow] = 0 122 | stv[idxUnknow] = 0 123 | su[idxUnknow] = 0 124 | sv[idxUnknow] = 0 125 | 126 | ind2 = [(np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)] 127 | index_su = su[ind2] 128 | index_sv = sv[ind2] 129 | an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1) 130 | un = index_su * an 131 | vn = index_sv * an 132 | 133 | index_stu = stu[ind2] 134 | index_stv = stv[ind2] 135 | tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1) 136 | tun = index_stu * tn 137 | tvn = index_stv * tn 138 | 139 | ''' 140 | angle = un * tun + vn * tvn + (an * tn) 141 | index = [angle == 1.0] 142 | angle[index] = 0.999 143 | ang = np.arccos(angle) 144 | mang = np.mean(ang) 145 | mang = mang * 180 / np.pi 146 | ''' 147 | 148 | epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2) 149 | epe = epe[ind2] 150 | mepe = np.mean(epe) 151 | return mepe 152 | 153 | 154 | def flow_to_image(flow): 155 | """ 156 | Convert flow into middlebury color code image 157 | :param flow: optical flow map 158 | :return: optical flow image in middlebury color 159 | """ 160 | u = flow[:, :, 0] 161 | v = flow[:, :, 1] 162 | 163 | maxu = -999. 164 | maxv = -999. 165 | minu = 999. 166 | minv = 999. 167 | 168 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 169 | u[idxUnknow] = 0 170 | v[idxUnknow] = 0 171 | 172 | maxu = max(maxu, np.max(u)) 173 | minu = min(minu, np.min(u)) 174 | 175 | maxv = max(maxv, np.max(v)) 176 | minv = min(minv, np.min(v)) 177 | 178 | rad = np.sqrt(u ** 2 + v ** 2) 179 | maxrad = max(-1, np.max(rad)) 180 | 181 | print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. 
%.3f" % (maxrad, minu,maxu, minv, maxv)) 182 | 183 | u = u/(maxrad + np.finfo(float).eps) 184 | v = v/(maxrad + np.finfo(float).eps) 185 | 186 | img = compute_color(u, v) 187 | 188 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 189 | img[idx] = 0 190 | 191 | return np.uint8(img) 192 | 193 | 194 | def compute_color(u, v): 195 | """ 196 | compute optical flow color map 197 | :param u: optical flow horizontal map 198 | :param v: optical flow vertical map 199 | :return: optical flow in color code 200 | """ 201 | [h, w] = u.shape 202 | img = np.zeros([h, w, 3]) 203 | nanIdx = np.isnan(u) | np.isnan(v) 204 | u[nanIdx] = 0 205 | v[nanIdx] = 0 206 | 207 | colorwheel = make_color_wheel() 208 | ncols = np.size(colorwheel, 0) 209 | 210 | rad = np.sqrt(u**2+v**2) 211 | 212 | a = np.arctan2(-v, -u) / np.pi 213 | 214 | fk = (a+1) / 2 * (ncols - 1) + 1 215 | 216 | k0 = np.floor(fk).astype(int) 217 | 218 | k1 = k0 + 1 219 | k1[k1 == ncols+1] = 1 220 | f = fk - k0 221 | 222 | for i in range(0, np.size(colorwheel,1)): 223 | tmp = colorwheel[:, i] 224 | col0 = tmp[k0-1] / 255 225 | col1 = tmp[k1-1] / 255 226 | col = (1-f) * col0 + f * col1 227 | 228 | idx = rad <= 1 229 | col[idx] = 1-rad[idx]*(1-col[idx]) 230 | notidx = np.logical_not(idx) 231 | 232 | col[notidx] *= 0.75 233 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 234 | 235 | return img 236 | 237 | 238 | def make_color_wheel(): 239 | """ 240 | Generate color wheel according Middlebury color code 241 | :return: Color wheel 242 | """ 243 | RY = 15 244 | YG = 6 245 | GC = 4 246 | CB = 11 247 | BM = 13 248 | MR = 6 249 | 250 | ncols = RY + YG + GC + CB + BM + MR 251 | 252 | colorwheel = np.zeros([ncols, 3]) 253 | 254 | col = 0 255 | 256 | # RY 257 | colorwheel[0:RY, 0] = 255 258 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 259 | col += RY 260 | 261 | # YG 262 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 263 | colorwheel[col:col+YG, 1] = 255 264 | col += YG 265 | 266 | # GC 267 | colorwheel[col:col+GC, 1] = 255 268 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 269 | col += GC 270 | 271 | # CB 272 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 273 | colorwheel[col:col+CB, 2] = 255 274 | col += CB 275 | 276 | # BM 277 | colorwheel[col:col+BM, 2] = 255 278 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 279 | col += + BM 280 | 281 | # MR 282 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 283 | colorwheel[col:col+MR, 0] = 255 284 | 285 | return colorwheel 286 | 287 | 288 | def scale_image(image, new_range): 289 | """ 290 | Linearly scale the image into desired range 291 | :param image: input image 292 | :param new_range: the new range to be aligned 293 | :return: image normalized in new range 294 | """ 295 | min_val = np.min(image).astype(np.float32) 296 | max_val = np.max(image).astype(np.float32) 297 | min_val_new = np.array(min(new_range), dtype=np.float32) 298 | max_val_new = np.array(max(new_range), dtype=np.float32) 299 | scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new 300 | return scaled_image.astype(np.uint8) 301 | 302 | -------------------------------------------------------------------------------- /lib/Cell_DETR_master/pixel_adaptive_convolution/tools/plot_log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 
(C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | import argparse 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | LINE_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color'] 12 | LINE_STYLES = ('solid', 'dashed', 'dashdot', 'dotted') 13 | 14 | 15 | def smooth_plot(xs, ys, smooth=5.0, axis=None, *args, **kwargs): 16 | min_window = 3 17 | if smooth > 0 and len(xs) > min_window * 100 / smooth: 18 | window = int(len(xs) * smooth // 100) 19 | f = np.repeat(1.0, window) / window 20 | ys = np.convolve(ys, f, 'valid') 21 | xs = xs[(window//2):(window//2)+len(ys)] 22 | # TODO: fix issue with empty strings and strings starting with '_' 23 | if 'label' in kwargs and (kwargs['label'].startswith('_') or kwargs['label'] == ''): 24 | kwargs['label'] = ' ' + kwargs['label'] 25 | if axis is None: 26 | plt.plot(xs, ys, *args, **kwargs) 27 | else: 28 | axis.plot(xs, ys, *args, **kwargs) 29 | 30 | 31 | def remove_common_prefix_suffix(array_of_str): 32 | if len(array_of_str) == 1: 33 | return array_of_str 34 | remove_head, remove_tail = 0, 0 35 | while True: 36 | v = array_of_str[0][remove_head] 37 | if all(s[remove_head] == v for s in array_of_str): 38 | remove_head += 1 39 | else: 40 | break 41 | while True: 42 | v = array_of_str[0][len(array_of_str[0]) - remove_tail - 1] 43 | if all(s[len(s) - remove_tail - 1] == v for s in array_of_str): 44 | remove_tail += 1 45 | else: 46 | break 47 | return [s[remove_head:len(s) - remove_tail] for s in array_of_str] 48 | 49 | 50 | def parse_and_plot(paths, output=None, plots=None, labels=None, reorder=None, trials=1, 51 | smooth=5.0, num_col=0, subplot_size=5, start_x=0.0, end_x=-1.0, legend='best'): 52 | with open(paths[0], 'r') as f: 53 | r = f.readline() 54 | xlabel = r.strip().split(',')[0] 55 | plots_all = r.strip().split(',')[1:] 56 | if plots: 57 | plot_dict = {p: i+1 for i, p in enumerate(plots_all)} 58 | plot_cols = tuple(-1 if p == '-' else plot_dict[p] for p in plots) 59 | else: 60 | plots = plots_all 61 | plot_cols = tuple(range(1, len(plots) + 1)) 62 | if not labels: 63 | labels = remove_common_prefix_suffix(paths)[::trials] 64 | assert(len(labels) == len(set(labels))) 65 | paths = [paths[i*trials:(i+1)*trials] for i in range(len(labels))] 66 | if reorder: 67 | if len(reorder) == 1 and reorder[0] == 'str': 68 | reorder = sorted(labels) 69 | paths = [paths[labels.index(l)] for l in reorder] 70 | labels = reorder 71 | 72 | runs_data = [] 73 | for ps in paths: 74 | tmp = [] 75 | for p in ps: 76 | data = np.genfromtxt(p, delimiter=',', usecols=(0,)+tuple(filter(lambda v: v != -1, plot_cols)), skip_header=1) 77 | data = data.reshape(-1, data.shape[-1]) 78 | start = np.where(data[:, 0] >= start_x)[0][0] 79 | end = None 80 | if end_x > 0: 81 | end_idxs = np.where(data[:, 0] > end_x)[0] 82 | if len(end_idxs) > 0: 83 | end = end_idxs[0] 84 | data = data[start:end] 85 | tmp.append(data) 86 | _len = min(len(d) for d in tmp) 87 | runs_data.append(np.mean([d[:_len] for d in tmp], axis=0)) 88 | 89 | if output: 90 | plt.switch_backend('agg') 91 | 92 | if num_col <= 0: 93 | num_col = int(np.ceil(np.sqrt(len(plots)))) 94 | num_row = (len(plots) - 1) // num_col + 1 95 | fig, axes = plt.subplots(num_row, num_col, squeeze=False, figsize=(num_col * subplot_size, num_row * subplot_size)) 96 | p_idx = 0 97 | for plabel, ax in zip(plots, axes.flat): 98 | if plabel == '-': 99 | continue 100 | p_idx += 1 101 | # 
plt.subplot(num_row, num_col, p + 1)
102 | for r, rdata in enumerate(runs_data):
103 | valid_mask = ~np.isnan(rdata[:, p_idx])
104 | if any(valid_mask):
105 | smooth_plot(rdata[valid_mask, 0], rdata[valid_mask, p_idx], smooth, ax,
106 | color=LINE_COLORS[r % len(LINE_COLORS)],
107 | linestyle=LINE_STYLES[r // len(LINE_COLORS)],
108 | linewidth=1.5, label=labels[r])
109 | ax.set_xlabel(xlabel)
110 | ax.set_title(plabel)
111 | ax.grid(True)
112 | if legend != 'off' and len(labels) > 1:
113 | ax.legend(loc=legend)
114 | for ax in axes.flat[len(plots):]:
115 | fig.delaxes(ax)
116 | fig.tight_layout()
117 | 
118 | if output:
119 | plt.savefig(output)
120 | else:
121 | plt.show()
122 | 
123 | 
124 | if __name__ == '__main__':
125 | parser = argparse.ArgumentParser(description='Parse log files (csv format) and make plots',
126 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
127 | parser.add_argument('paths', nargs='+', help='one or more log files (each subplot will have #paths curves)')
128 | parser.add_argument('--output', type=str, default='', help='place to save figure (show figure if blank)')
129 | parser.add_argument('--plots', nargs='+', default=None, help='use select plots instead of one for each column')
130 | parser.add_argument('--labels', nargs='+', default=None, help='use labels instead of file names')
131 | parser.add_argument('--average', type=int, default=1, help='plot average curves')
132 | parser.add_argument('--reorder', nargs='+', default=None, help='order the label legends')
133 | parser.add_argument('--smooth', type=float, default=5.0, help='smoothing level (0-100)')
134 | parser.add_argument('--num-col', type=int, default=0, help='number of columns (0 - auto)')
135 | parser.add_argument('--subplot-size', type=int, default=5, help='width and height of each subplot')
136 | parser.add_argument('--legend', type=str, default='best', help='place to put legend')
137 | parser.add_argument('--xlim', nargs='+', default=None, help='xmin (xmax)')
138 | 
139 | args = parser.parse_args()
140 | 
141 | start_x = 0.0 if not args.xlim else float(args.xlim[0])
142 | end_x = -1.0 if (not args.xlim or len(args.xlim) < 2) else float(args.xlim[1])
143 | 
144 | parse_and_plot(args.paths, output=args.output, plots=args.plots, labels=args.labels, reorder=args.reorder,
145 | trials=args.average, smooth=args.smooth, num_col=args.num_col, subplot_size=args.subplot_size,
146 | start_x=start_x, end_x=end_x, legend=args.legend.replace('_', ' '))
147 | 
-------------------------------------------------------------------------------- /lib/Cell_DETR_master/requirements.txt: --------------------------------------------------------------------------------
1 | torch==1.4.0
2 | torchvision==0.5.0
3 | numpy
4 | scikit-image
5 | scipy
6 | tqdm
7 | setproctitle
8 | tikzplotlib
-------------------------------------------------------------------------------- /lib/non_local/non_local_concatenation.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | 
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 | 
10 | assert dimension in [1, 2, 3]
11 | 
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 | 
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 | 
18 | if self.inter_channels is None:
19 | self.inter_channels =
in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | self.concat_project = nn.Sequential( 60 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 61 | nn.ReLU() 62 | ) 63 | 64 | if sub_sample: 65 | self.g = nn.Sequential(self.g, max_pool_layer) 66 | self.phi = nn.Sequential(self.phi, max_pool_layer) 67 | 68 | def forward(self, x, return_nl_map=False): 69 | ''' 70 | :param x: (b, c, t, h, w) 71 | :param return_nl_map: if True return z, nl_map, else only return z. 
72 | :return: 73 | ''' 74 | 75 | batch_size = x.size(0) 76 | 77 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 78 | g_x = g_x.permute(0, 2, 1) 79 | 80 | # (b, c, N, 1) 81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 82 | # (b, c, 1, N) 83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 84 | 85 | h = theta_x.size(2) 86 | w = phi_x.size(3) 87 | theta_x = theta_x.repeat(1, 1, 1, w) 88 | phi_x = phi_x.repeat(1, 1, h, 1) 89 | 90 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 91 | f = self.concat_project(concat_feature) 92 | b, _, h, w = f.size() 93 | f = f.view(b, h, w) 94 | 95 | N = f.size(-1) 96 | f_div_C = f / N 97 | 98 | y = torch.matmul(f_div_C, g_x) 99 | y = y.permute(0, 2, 1).contiguous() 100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 101 | W_y = self.W(y) 102 | z = W_y + x 103 | 104 | if return_nl_map: 105 | return z, f_div_C 106 | return z 107 | 108 | 109 | class NONLocalBlock1D(_NonLocalBlockND): 110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 111 | super(NONLocalBlock1D, self).__init__(in_channels, 112 | inter_channels=inter_channels, 113 | dimension=1, sub_sample=sub_sample, 114 | bn_layer=bn_layer) 115 | 116 | 117 | class NONLocalBlock2D(_NonLocalBlockND): 118 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 119 | super(NONLocalBlock2D, self).__init__(in_channels, 120 | inter_channels=inter_channels, 121 | dimension=2, sub_sample=sub_sample, 122 | bn_layer=bn_layer) 123 | 124 | 125 | class NONLocalBlock3D(_NonLocalBlockND): 126 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,): 127 | super(NONLocalBlock3D, self).__init__(in_channels, 128 | inter_channels=inter_channels, 129 | dimension=3, sub_sample=sub_sample, 130 | bn_layer=bn_layer) 131 | 132 | 133 | if __name__ == '__main__': 134 | import torch 135 | 136 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 137 | img = torch.zeros(2, 3, 20) 138 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | img = torch.zeros(2, 3, 20, 20) 143 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 144 | out = net(img) 145 | print(out.size()) 146 | 147 | img = torch.randn(2, 3, 8, 20, 20) 148 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 149 | out = net(img) 150 | print(out.size()) 151 | -------------------------------------------------------------------------------- /lib/non_local/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | 
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | if sub_sample: 60 | self.g = nn.Sequential(self.g, max_pool_layer) 61 | self.phi = nn.Sequential(self.phi, max_pool_layer) 62 | 63 | def forward(self, x, return_nl_map=False): 64 | """ 65 | :param x: (b, c, t, h, w) 66 | :param return_nl_map: if True return z, nl_map, else only return z. 67 | :return: 68 | """ 69 | 70 | batch_size = x.size(0) 71 | 72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 73 | g_x = g_x.permute(0, 2, 1) 74 | 75 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 76 | theta_x = theta_x.permute(0, 2, 1) 77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 78 | f = torch.matmul(theta_x, phi_x) 79 | N = f.size(-1) 80 | f_div_C = f / N 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | if return_nl_map: 89 | return z, f_div_C 90 | return z 91 | 92 | 93 | class NONLocalBlock1D(_NonLocalBlockND): 94 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 95 | super(NONLocalBlock1D, self).__init__(in_channels, 96 | inter_channels=inter_channels, 97 | dimension=1, sub_sample=sub_sample, 98 | bn_layer=bn_layer) 99 | 100 | 101 | class NONLocalBlock2D(_NonLocalBlockND): 102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 103 | super(NONLocalBlock2D, self).__init__(in_channels, 104 | inter_channels=inter_channels, 105 | dimension=2, sub_sample=sub_sample, 106 | bn_layer=bn_layer) 107 | 108 | 109 | class NONLocalBlock3D(_NonLocalBlockND): 110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 111 | super(NONLocalBlock3D, self).__init__(in_channels, 112 | inter_channels=inter_channels, 113 | dimension=3, sub_sample=sub_sample, 114 | bn_layer=bn_layer) 115 | 116 | 117 | if __name__ == '__main__': 118 | import torch 119 | 120 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 121 | img = torch.zeros(2, 3, 20) 122 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 123 | out = net(img) 124 | print(out.size()) 125 | 126 | img = torch.zeros(2, 3, 20, 20) 127 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 128 | out = net(img) 129 | print(out.size()) 130 | 131 | img = torch.randn(2, 3, 8, 20, 20) 132 | net = 
NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 133 | out = net(img) 134 | print(out.size()) 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /lib/non_local/non_local_embedded_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | """ 9 | :param in_channels: 10 | :param inter_channels: 11 | :param dimension: 12 | :param sub_sample: 13 | :param bn_layer: 14 | """ 15 | 16 | super(_NonLocalBlockND, self).__init__() 17 | 18 | assert dimension in [1, 2, 3] 19 | 20 | self.dimension = dimension 21 | self.sub_sample = sub_sample 22 | 23 | self.in_channels = in_channels 24 | self.inter_channels = inter_channels 25 | 26 | if self.inter_channels is None: 27 | self.inter_channels = in_channels // 2 28 | if self.inter_channels == 0: 29 | self.inter_channels = 1 30 | 31 | if dimension == 3: 32 | conv_nd = nn.Conv3d 33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 34 | bn = nn.BatchNorm3d 35 | elif dimension == 2: 36 | conv_nd = nn.Conv2d 37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 38 | bn = nn.BatchNorm2d 39 | else: 40 | conv_nd = nn.Conv1d 41 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 42 | bn = nn.BatchNorm1d 43 | 44 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | 47 | if bn_layer: 48 | self.W = nn.Sequential( 49 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 50 | kernel_size=1, stride=1, padding=0), 51 | bn(self.in_channels) 52 | ) 53 | nn.init.constant_(self.W[1].weight, 0) 54 | nn.init.constant_(self.W[1].bias, 0) 55 | else: 56 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | nn.init.constant_(self.W.weight, 0) 59 | nn.init.constant_(self.W.bias, 0) 60 | 61 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 62 | kernel_size=1, stride=1, padding=0) 63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 64 | kernel_size=1, stride=1, padding=0) 65 | 66 | if sub_sample: 67 | self.g = nn.Sequential(self.g, max_pool_layer) 68 | self.phi = nn.Sequential(self.phi, max_pool_layer) 69 | 70 | def forward(self, x, return_nl_map=False): 71 | """ 72 | :param x: (b, c, t, h, w) 73 | :param return_nl_map: if True return z, nl_map, else only return z. 
74 | :return: 75 | """ 76 | 77 | batch_size = x.size(0) 78 | 79 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 80 | g_x = g_x.permute(0, 2, 1) 81 | 82 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 83 | theta_x = theta_x.permute(0, 2, 1) 84 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 85 | f = torch.matmul(theta_x, phi_x) 86 | f_div_C = F.softmax(f, dim=-1) 87 | 88 | y = torch.matmul(f_div_C, g_x) 89 | y = y.permute(0, 2, 1).contiguous() 90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 91 | W_y = self.W(y) 92 | z = W_y + x 93 | 94 | if return_nl_map: 95 | return z, f_div_C 96 | return z 97 | 98 | 99 | class NONLocalBlock1D(_NonLocalBlockND): 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 101 | super(NONLocalBlock1D, self).__init__(in_channels, 102 | inter_channels=inter_channels, 103 | dimension=1, sub_sample=sub_sample, 104 | bn_layer=bn_layer) 105 | 106 | 107 | class NONLocalBlock2D(_NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NONLocalBlock2D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=2, sub_sample=sub_sample, 112 | bn_layer=bn_layer,) 113 | 114 | 115 | class NONLocalBlock3D(_NonLocalBlockND): 116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 117 | super(NONLocalBlock3D, self).__init__(in_channels, 118 | inter_channels=inter_channels, 119 | dimension=3, sub_sample=sub_sample, 120 | bn_layer=bn_layer,) 121 | 122 | 123 | if __name__ == '__main__': 124 | import torch 125 | 126 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 127 | img = torch.zeros(2, 3, 20) 128 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 129 | out = net(img) 130 | print(out.size()) 131 | 132 | img = torch.zeros(2, 3, 20, 20) 133 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 134 | out = net(img) 135 | print(out.size()) 136 | 137 | img = torch.randn(2, 3, 8, 20, 20) 138 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | 143 | -------------------------------------------------------------------------------- /lib/non_local/non_local_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | if sub_sample: 54 | self.g = nn.Sequential(self.g, max_pool_layer) 55 | self.phi = max_pool_layer 56 | 57 | def forward(self, x, return_nl_map=False): 58 | """ 59 | :param x: (b, c, t, h, w) 60 | :param return_nl_map: if True return z, nl_map, else only return z. 61 | :return: 62 | """ 63 | 64 | batch_size = x.size(0) 65 | 66 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 67 | 68 | g_x = g_x.permute(0, 2, 1) 69 | 70 | theta_x = x.view(batch_size, self.in_channels, -1) 71 | theta_x = theta_x.permute(0, 2, 1) 72 | 73 | if self.sub_sample: 74 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 75 | else: 76 | phi_x = x.view(batch_size, self.in_channels, -1) 77 | 78 | f = torch.matmul(theta_x, phi_x) 79 | f_div_C = F.softmax(f, dim=-1) 80 | 81 | # if self.store_last_batch_nl_map: 82 | # self.nl_map = f_div_C 83 | 84 | y = torch.matmul(f_div_C, g_x) 85 | y = y.permute(0, 2, 1).contiguous() 86 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 87 | W_y = self.W(y) 88 | z = W_y + x 89 | 90 | if return_nl_map: 91 | return z, f_div_C 92 | return z 93 | 94 | 95 | class NONLocalBlock1D(_NonLocalBlockND): 96 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 97 | super(NONLocalBlock1D, self).__init__(in_channels, 98 | inter_channels=inter_channels, 99 | dimension=1, sub_sample=sub_sample, 100 | bn_layer=bn_layer) 101 | 102 | 103 | class NONLocalBlock2D(_NonLocalBlockND): 104 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 105 | super(NONLocalBlock2D, self).__init__(in_channels, 106 | inter_channels=inter_channels, 107 | dimension=2, sub_sample=sub_sample, 108 | bn_layer=bn_layer) 109 | 110 | 111 | class NONLocalBlock3D(_NonLocalBlockND): 112 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 113 | super(NONLocalBlock3D, self).__init__(in_channels, 114 | inter_channels=inter_channels, 115 | dimension=3, sub_sample=sub_sample, 116 | bn_layer=bn_layer) 117 | 118 | 119 | if __name__ == '__main__': 120 | import torch 121 | 122 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 123 | img = torch.zeros(2, 3, 20) 124 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 125 | out = net(img) 126 | print(out.size()) 127 | 128 | img = torch.zeros(2, 3, 20, 20) 129 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 130 | out = net(img) 131 | print(out.size()) 132 | 133 | img = torch.randn(2, 3, 8, 20, 20) 134 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 135 | out = net(img) 136 | print(out.size()) 137 | 138 | 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /src/__pycache__/BAT_Modules.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/BAT_Modules.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /src/process_point.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | import torch 5 | import numpy as np 6 | import skimage.draw 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | import torch.nn.functional as F 10 | 11 | 12 | def create_circular_mask(h, w, center, radius): 13 | Y, X = np.ogrid[:h, :w] 14 | dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2) 15 | mask = dist_from_center <= radius 16 | return mask 17 | 18 | 19 | def NMS(heatmap, kernel=13): 20 | hmax = F.max_pool2d(heatmap, kernel, stride=1, padding=(kernel - 1) // 2) 21 | keep = (hmax == heatmap).float() 22 | return heatmap * keep, hmax, keep 23 | 24 | 25 | def draw_msra_gaussian(heatmap, center, sigma): 26 | tmp_size = sigma * 3 27 | mu_x = int(center[0] + 0.5) 28 | mu_y = int(center[1] + 0.5) 29 | w, h = heatmap.shape[0], heatmap.shape[1] 30 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] 31 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] 32 | if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0: 33 | return heatmap 34 | size = 2 * tmp_size + 1 35 | x = np.arange(0, size, 1, np.float32) 36 | y = x[:, np.newaxis] 37 | x0 = y0 = size // 2 38 | g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) 39 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0] 40 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1] 41 | img_x = max(0, ul[0]), min(br[0], h) 42 | img_y = max(0, ul[1]), min(br[1], w) 43 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum( 44 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], g[g_y[0]:g_y[1], 45 | g_x[0]:g_x[1]]) 46 | return heatmap 47 | 48 | 49 | def kpm_gen(label_path, R, N): 50 | label = np.load(label_path) 51 | # label = label[0] 52 | label_ori = label.copy() 53 | label = label[::4, ::4] 54 | label = np.uint8(label * 255) 55 | contours, hierarchy = cv2.findContours(label, cv2.RETR_LIST, 56 | cv2.CHAIN_APPROX_NONE) 57 | contour_len = len(contours) 58 | 59 | label = np.repeat(label[..., np.newaxis], 3, axis=-1) 60 | draw_label = cv2.drawContours(label.copy(), contours, -1, (0, 0, 255), 1) 61 | 62 | point_file = [] 63 | if contour_len == 0: 64 | point_heatmap = np.zeros((512, 512)) 65 | else: 66 | point_heatmap = np.zeros((512, 512)) 67 | for 
contour in contours: 68 | stds = [] 69 | points = contour[:, 0] # (N,2) 70 | points = points * 4 71 | points_number = contour.shape[0] 72 | if points_number < 30: 73 | continue 74 | 75 | if points_number < 100: 76 | radius = 6 77 | neighbor_points_n_oneside = 3 78 | elif points_number < 200: 79 | radius = 10 80 | neighbor_points_n_oneside = 15 81 | elif points_number < 300: 82 | radius = 10 83 | neighbor_points_n_oneside = 20 84 | elif points_number < 350: 85 | radius = 15 86 | neighbor_points_n_oneside = 30 87 | else: 88 | radius = 10 89 | neighbor_points_n_oneside = 40 90 | 91 | for i in range(points_number): 92 | current_point = points[i] 93 | mask = create_circular_mask(512, 512, points[i], radius) 94 | overlap_area = np.sum( 95 | mask * label_ori) / (np.pi * radius * radius) 96 | stds.append(overlap_area) 97 | print("stds len: ", len(stds)) 98 | 99 | # show 100 | selected_points = [] 101 | stds = np.array(stds) 102 | neighbor_points = [] 103 | for i in range(len(points)): 104 | current_point = points[i] 105 | neighbor_points_index = np.concatenate([ 106 | np.arange(-neighbor_points_n_oneside, 0), 107 | np.arange(1, neighbor_points_n_oneside + 1) 108 | ]) + i 109 | neighbor_points_index[np.where( 110 | neighbor_points_index < 0)[0]] += len(points) 111 | neighbor_points_index[np.where( 112 | neighbor_points_index > len(points) - 1)[0]] -= len(points) 113 | if stds[i] < np.min( 114 | stds[neighbor_points_index]) or stds[i] > np.max( 115 | stds[neighbor_points_index]): 116 | # print(points[i]) 117 | point_heatmap = draw_msra_gaussian( 118 | point_heatmap, (points[i, 0], points[i, 1]), 5) 119 | selected_points.append(points[i]) 120 | 121 | print("selected_points num: ", len(selected_points)) 122 | # print(selected_points) 123 | maskk = np.zeros((512, 512)) 124 | rr, cc = skimage.draw.polygon( 125 | np.array(selected_points)[:, 1], 126 | np.array(selected_points)[:, 0]) 127 | maskk[rr, cc] = 1 128 | intersection = np.logical_and(label_ori, maskk) 129 | union = np.logical_or(label_ori, maskk) 130 | iou_score = np.sum(intersection) / np.sum(union) 131 | print(iou_score) 132 | return label_ori, point_heatmap 133 | 134 | 135 | def point_gen_isic2018(): 136 | R = 10 137 | N = 25 138 | data_dir = '/raid/wjc/data/skin_lesion/isic2018/Label' 139 | 140 | save_dir = data_dir.replace('Label', 'Point') 141 | os.makedirs(save_dir, exist_ok=True) 142 | 143 | path_list = os.listdir(data_dir) 144 | path_list.sort() 145 | num = 0 146 | for path in tqdm(path_list): 147 | name = path[:-4] 148 | label_path = os.path.join(data_dir, path) 149 | print(label_path) 150 | label_ori, point_heatmap = kpm_gen(label_path, R, N) 151 | 152 | save_path = os.path.join(save_dir, name + '.npy') 153 | np.save(save_path, point_heatmap) 154 | num += 1 155 | 156 | 157 | def point_gen_isic2016(): 158 | R = 10 159 | N = 25 160 | for split in ['Train', 'Test', 'Validation']: 161 | data_dir = '/raid/wjc/data/skin_lesion/isic2016/{}/Label'.format(split) 162 | 163 | save_dir = data_dir.replace('Label', 'Point') 164 | os.makedirs(save_dir, exist_ok=True) 165 | 166 | path_list = os.listdir(data_dir) 167 | path_list.sort() 168 | num = 0 169 | for path in tqdm(path_list): 170 | name = path[:-4] 171 | label_path = os.path.join(data_dir, path) 172 | print(label_path) 173 | label_ori, point_heatmap = kpm_gen(label_path, R, N) 174 | save_path = os.path.join(save_dir, name + '.npy') 175 | np.save(save_path, point_heatmap) 176 | num += 1 177 | 178 | 179 | if __name__ == '__main__': 180 | # point_gen_isic2018() 181 | point_gen_isic2016() 
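The three helpers at the top of this file (create_circular_mask, draw_msra_gaussian, NMS) drive the key-point map generation in kpm_gen. Below is a minimal smoke test, not part of the repository, that splats two Gaussian peaks onto an empty heatmap and recovers them with the NMS helper; the coordinates, shapes, and the import path are illustrative assumptions.

# Hypothetical smoke test for the heatmap helpers above
# (assumes process_point.py is importable as src.process_point).
import numpy as np
import torch
from src.process_point import NMS, create_circular_mask, draw_msra_gaussian

heatmap = np.zeros((512, 512), dtype=np.float32)
heatmap = draw_msra_gaussian(heatmap, (100, 200), sigma=5)  # peak at x=100, y=200
heatmap = draw_msra_gaussian(heatmap, (300, 40), sigma=5)   # peak at x=300, y=40

# NMS expects a (b, c, h, w) tensor and keeps only local maxima.
peaks, _, _ = NMS(torch.from_numpy(heatmap)[None, None], kernel=13)
ys, xs = np.where(peaks[0, 0].numpy() > 0.9)
print(list(zip(xs, ys)))  # should recover (100, 200) and (300, 40)

# create_circular_mask gives the disk used to score boundary overlap in kpm_gen.
disk = create_circular_mask(512, 512, center=(100, 200), radius=10)
print(disk.sum())  # roughly pi * 10**2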
--------------------------------------------------------------------------------
/src/process_resize.py:
--------------------------------------------------------------------------------
import cv2
import os
import random
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


def process_isic2018(
        dim=(352, 352), save_dir='/raid/wjc/data/skin_lesion/isic2018/'):
    image_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1-2_Training_Input/'
    mask_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1_Training_GroundTruth/'

    image_path_list = os.listdir(image_dir_path)
    mask_path_list = os.listdir(mask_dir_path)

    image_path_list = list(filter(lambda x: x[-3:] == 'jpg', image_path_list))
    mask_path_list = list(filter(lambda x: x[-3:] == 'png', mask_path_list))

    image_path_list.sort()
    mask_path_list.sort()

    print(len(image_path_list), len(mask_path_list))

    # ISBI Dataset
    for image_path, mask_path in zip(image_path_list, mask_path_list):
        if image_path[-3:] == 'jpg':
            print(image_path)
            assert os.path.basename(image_path)[:-4].split(
                '_')[1] == os.path.basename(mask_path)[:-4].split('_')[1]
            _id = os.path.basename(image_path)[:-4].split('_')[1]
            image_path = os.path.join(image_dir_path, image_path)
            mask_path = os.path.join(mask_dir_path, mask_path)
            image = cv2.imread(image_path)
            mask = cv2.imread(mask_path)

            image_new = cv2.resize(image, dim, interpolation=cv2.INTER_CUBIC)
            image_new = np.array(image_new, dtype=np.uint8)
            mask_new = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
            # light box blur to smooth the boundary of the resized mask
            mask_new = cv2.blur(mask_new, (3, 3))
            mask_new = np.array(mask_new, dtype=np.uint8)

            save_dir_path = save_dir + '/Image'
            os.makedirs(save_dir_path, exist_ok=True)
            # np.save(os.path.join(save_dir_path, _id + '.npy'), image_new)
            print(image_new.shape)
            cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'),
                        image_new)

            save_dir_path = save_dir + '/Label'
            os.makedirs(save_dir_path, exist_ok=True)
            # np.save(os.path.join(save_dir_path, _id + '.npy'), mask_new)
            cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'),
                        mask_new)


def process_ph2():
    PH2_images_path = '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2_Dataset_images'

    path_list = os.listdir(PH2_images_path)
    path_list.sort()

    for path in path_list:
        image_path = os.path.join(PH2_images_path, path,
                                  path + '_Dermoscopic_Image', path + '.bmp')
        label_path = os.path.join(PH2_images_path, path, path + '_lesion',
                                  path + '_lesion.bmp')
        image = plt.imread(image_path)
        label = plt.imread(label_path)
        label = label[:, :, 0]

        dim = (512, 512)
        image_new = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
        label_new = cv2.resize(label, dim, interpolation=cv2.INTER_AREA)

        image_save_path = os.path.join(
            '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Image',
            path + '.npy')
        label_save_path = os.path.join(
            '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Label',
            path + '.npy')

        np.save(image_save_path, image_new)
        np.save(label_save_path, label_new)


if __name__ == '__main__':
    process_isic2018(dim=(352, 352),
                     save_dir='/raid/wjc/data/skin_lesion/isic2018_jpg_352_smooth/')
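For a single image/mask pair, the resizing convention in process_isic2018 reduces to the sketch below, which is not part of the repository (file names are placeholders): bicubic interpolation for the image, nearest-neighbour plus a light box blur for the mask.

# Single-pair sketch of the resize logic above (paths are hypothetical).
import cv2
import numpy as np

image = cv2.imread('ISIC_0000000.jpg')
mask = cv2.imread('ISIC_0000000_segmentation.png')

dim = (352, 352)
image_new = cv2.resize(image, dim, interpolation=cv2.INTER_CUBIC)
image_new = np.array(image_new, dtype=np.uint8)
# Nearest-neighbour keeps the mask binary; the 3x3 box blur then softens
# its boundary, matching the '..._smooth' save directory used above.
mask_new = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
mask_new = np.array(cv2.blur(mask_new, (3, 3)), dtype=np.uint8)

cv2.imwrite('ISIC_0000000_352.jpg', image_new)
cv2.imwrite('ISIC_0000000_label_352.jpg', mask_new)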
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from collections import OrderedDict


def load_model(model, pretrain_dir, log=True):
    state_dict_ = torch.load(pretrain_dir, map_location='cuda:0')
    print('loaded pretrained weights from %s !' % pretrain_dir)
    state_dict = OrderedDict()

    # convert data_parallel checkpoints to the plain model layout
    for key in state_dict_:
        if key.startswith('module') and not key.startswith('module_list'):
            state_dict[key[7:]] = state_dict_[key]
        else:
            state_dict[key] = state_dict_[key]

    # check loaded parameters against the created model's parameters
    model_state_dict = model.state_dict()
    for key in state_dict:
        if key in model_state_dict:
            # print(key, state_dict[key].shape, model_state_dict[key].shape)
            if state_dict[key].shape != model_state_dict[key].shape:
                if log:
                    print(
                        'Skip loading parameter {}, required shape {}, loaded shape {}.'
                        .format(key, model_state_dict[key].shape,
                                state_dict[key].shape))
                state_dict[key] = model_state_dict[key]
        else:
            if log:
                print('Drop parameter {}.'.format(key))
    for key in model_state_dict:
        if key not in state_dict:
            if log:
                print('No param {}.'.format(key))
            state_dict[key] = model_state_dict[key]
    model.load_state_dict(state_dict, strict=False)

    return model
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os, argparse, sys, tqdm, logging, cv2
import torch
import torch.nn as nn
import numpy as np
from glob import glob
import torch.nn.functional as F
from medpy.metric.binary import hd, hd95, dc, jc, assd

parser = argparse.ArgumentParser()
parser.add_argument('--log_name',
                    type=str,
                    default='bat_1_1_0_e6_loss_0_aug_1')
parser.add_argument('--gpu', type=str, default='1')
parser.add_argument('--fold', type=str, default='0')
parser.add_argument('--dataset', type=str, default='isic2016')

parser.add_argument('--arch', type=str, default='BAT')
parser.add_argument('--net_layer', type=int, default=50)
# pre-train
parser.add_argument('--pre', type=int, default=0)

# transformer
parser.add_argument('--trans', type=int, default=1)

# point constrain
parser.add_argument('--point_pred', type=int, default=1)
parser.add_argument('--ppl', type=int, default=6)

# cross-scale framework
parser.add_argument('--cross', type=int, default=0)

parse_config = parser.parse_args()
print(parse_config)
os.environ['CUDA_VISIBLE_DEVICES'] = parse_config.gpu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if parse_config.dataset == 'isic2018':
    from dataset.isic2018 import norm01, myDataset
    dataset = myDataset(parse_config.fold, 'valid', aug=False)
elif parse_config.dataset == 'isic2016':
    from dataset.isic2016 import norm01, myDataset
    dataset = myDataset('test', aug=False)

if parse_config.arch == 'BAT':
    if parse_config.trans == 1:
        from Ours.Base_transformer import BAT
        model = BAT(1, parse_config.net_layer, parse_config.point_pred,
                    parse_config.ppl).cuda()
    else:
        from Ours.base import DeepLabV3
        model = DeepLabV3(1, parse_config.net_layer).cuda()

dir_path = os.path.dirname(
    os.path.abspath(__file__)) + "/logs/{}/{}/fold_{}/".format(
        parse_config.dataset, parse_config.log_name, parse_config.fold)

from src.utils import load_model

model = load_model(model, dir_path + 'model/best.pkl')

# logging
txt_path = os.path.join(dir_path + 'parameter.txt')
logging.basicConfig(filename=txt_path,
                    level=logging.INFO,
                    format='[%(asctime)s.%(msecs)03d] %(message)s',
                    datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
test_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=8,
                                          pin_memory=True,
                                          drop_last=False,
                                          shuffle=False)


def test():
    model.eval()
    num = 0

    dice_value = 0
    jc_value = 0
    hd95_value = 0
    assd_value = 0

    from tqdm import tqdm
    labels = []
    pres = []
    for batch_idx, batch_data in tqdm(enumerate(test_loader)):
        data = batch_data['image'].to(device).float()
        label = batch_data['label'].to(device).float()
        with torch.no_grad():
            if parse_config.arch == 'transfuse':
                _, _, output = model(data)
            elif parse_config.point_pred == 0:
                output = model(data)
            elif parse_config.point_pred == 1:
                output, _ = model(data)
            output = torch.sigmoid(output)
            output = output.cpu().numpy() > 0.5
        label = label.cpu().numpy()
        assert (output.shape == label.shape)
        labels.append(label)
        pres.append(output)
    labels = np.concatenate(labels, axis=0)
    pres = np.concatenate(pres, axis=0)
    print(labels.shape, pres.shape)
    for _id in range(labels.shape[0]):
        dice_ave = dc(labels[_id], pres[_id])
        jc_ave = jc(labels[_id], pres[_id])
        try:
            hd95_ave = hd95(labels[_id], pres[_id])
            assd_ave = assd(labels[_id], pres[_id])
        except RuntimeError:
            num += 1
            hd95_ave = 0
            assd_ave = 0

        dice_value += dice_ave
        jc_value += jc_ave
        hd95_value += hd95_ave
        assd_value += assd_ave

    dice_average = dice_value / (labels.shape[0] - num)
    jc_average = jc_value / (labels.shape[0] - num)
    hd95_average = hd95_value / (labels.shape[0] - num)
    assd_average = assd_value / (labels.shape[0] - num)

    logging.info('Dice value of test dataset : %f' % (dice_average))
    logging.info('Jc value of test dataset : %f' % (jc_average))
    logging.info('Hd95 value of test dataset : %f' % (hd95_average))
    logging.info('Assd value of test dataset : %f' % (assd_average))

    print("Average dice value of evaluation dataset = ", dice_average)
    print("Average jc value of evaluation dataset = ", jc_average)
    print("Average hd95 value of evaluation dataset = ", hd95_average)
    print("Average assd value of evaluation dataset = ", assd_average)
    return dice_average


if __name__ == '__main__':
    test()
--------------------------------------------------------------------------------
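The per-sample metric loop in test() can be exercised without a model or dataset. Below is a self-contained sketch, not part of the repository, with random masks standing in for predictions and labels; the array shapes are illustrative.

# Stand-alone sketch of the metric computation in test(); dummy binary
# masks replace real predictions (pres) and ground truth (labels).
import numpy as np
from medpy.metric.binary import assd, dc, hd95, jc

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=(4, 1, 352, 352)).astype(bool)
pres = rng.integers(0, 2, size=(4, 1, 352, 352)).astype(bool)

num = 0
dice_value = jc_value = hd95_value = assd_value = 0.0
for _id in range(labels.shape[0]):
    dice_value += dc(labels[_id], pres[_id])
    jc_value += jc(labels[_id], pres[_id])
    try:
        hd95_value += hd95(labels[_id], pres[_id])
        assd_value += assd(labels[_id], pres[_id])
    except RuntimeError:  # raised by medpy when a mask has no foreground
        num += 1

n = labels.shape[0] - num
print(dice_value / n, jc_value / n, hd95_value / n, assd_value / n)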