├── .gitignore
├── Ours
│   ├── ASPP.py
│   ├── Base_transformer.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── ASPP.cpython-37.pyc
│   │   ├── Base_transformer.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── base.cpython-37.pyc
│   │   ├── cell_DETR.cpython-37.pyc
│   │   └── resnet.cpython-37.pyc
│   ├── base.py
│   ├── cell_DETR.py
│   ├── non_local.py
│   └── resnet.py
├── README.md
├── dataset
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── isbi2016.cpython-37.pyc
│   │   └── isic2016.cpython-37.pyc
│   ├── isic2016.py
│   └── isic2018.py
├── framework.jpg
├── infer.py
├── lib
│   ├── Cell_DETR_master
│   │   ├── .gitignore
│   │   ├── DETR_Test.py
│   │   ├── DeepDetr.py
│   │   ├── LICENSE
│   │   ├── __pycache__
│   │   │   ├── backbone.cpython-37.pyc
│   │   │   ├── bounding_box_head.cpython-37.pyc
│   │   │   ├── detr_new.cpython-37.pyc
│   │   │   ├── segmentation.cpython-37.pyc
│   │   │   └── transformer.cpython-37.pyc
│   │   ├── augmentation.py
│   │   ├── backbone.py
│   │   ├── botr.py
│   │   ├── botr2.py
│   │   ├── bounding_box_head.py
│   │   ├── ccccode.py
│   │   ├── dataset.py
│   │   ├── dcn2
│   │   │   ├── .gitignore
│   │   │   ├── LICENSE
│   │   │   ├── README.md
│   │   │   ├── build.py
│   │   │   ├── build_modulated.py
│   │   │   ├── functions
│   │   │   │   ├── __init__.py
│   │   │   │   ├── deform_conv.py
│   │   │   │   └── modulated_dcn_func.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── deform_conv.py
│   │   │   │   └── modulated_dcn.py
│   │   │   ├── src
│   │   │   │   ├── cuda
│   │   │   │   │   ├── deform_psroi_pooling_cuda.cu
│   │   │   │   │   ├── deform_psroi_pooling_cuda.h
│   │   │   │   │   ├── modulated_deform_im2col_cuda.cu
│   │   │   │   │   └── modulated_deform_im2col_cuda.h
│   │   │   │   ├── deform_conv.c
│   │   │   │   ├── deform_conv.h
│   │   │   │   ├── deform_conv_cuda.c
│   │   │   │   ├── deform_conv_cuda.h
│   │   │   │   ├── deform_conv_cuda_kernel.cu
│   │   │   │   ├── deform_conv_cuda_kernel.h
│   │   │   │   ├── modulated_dcn.c
│   │   │   │   ├── modulated_dcn.h
│   │   │   │   ├── modulated_dcn_cuda.c
│   │   │   │   └── modulated_dcn_cuda.h
│   │   │   ├── test.py
│   │   │   └── test_modulated.py
│   │   ├── detr.py
│   │   ├── detr_new.py
│   │   ├── images
│   │   │   └── CELL_DETR.PNG
│   │   ├── lossfunction.py
│   │   ├── main.py
│   │   ├── matcher.py
│   │   ├── misc.py
│   │   ├── model_wrapper.py
│   │   ├── pade_activation_unit
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   └── utils.cpython-37.pyc
│   │   │   ├── cuda
│   │   │   │   ├── pau_cuda.cpp
│   │   │   │   ├── pau_cuda_kernels.cu
│   │   │   │   ├── python_imp
│   │   │   │   │   ├── Pade.py
│   │   │   │   │   └── __init__.py
│   │   │   │   └── setup.py
│   │   │   ├── torchsummary.py
│   │   │   └── utils.py
│   │   ├── pixel_adaptive_convolution
│   │   │   ├── README.md
│   │   │   ├── __pycache__
│   │   │   │   └── pac.cpython-37.pyc
│   │   │   ├── pac.py
│   │   │   ├── paccrf.py
│   │   │   ├── requirements.txt
│   │   │   ├── test_pac.py
│   │   │   └── tools
│   │   │       ├── flowlib.py
│   │   │       └── plot_log.py
│   │   ├── requirements.txt
│   │   ├── segmentation.py
│   │   ├── sgtr.py
│   │   ├── transformer.py
│   │   └── validation_metric.py
│   └── non_local
│       ├── non_local_concatenation.py
│       ├── non_local_dot_product.py
│       ├── non_local_embedded_gaussian.py
│       └── non_local_gaussian.py
├── src
│   ├── BAT_Modules.py
│   ├── __pycache__
│   │   ├── BAT_Modules.cpython-37.pyc
│   │   ├── losses.cpython-37.pyc
│   │   ├── transformer.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── losses.py
│   ├── process_point.py
│   ├── process_resize.py
│   ├── transformer.py
│   └── utils.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # test module (personal habit: exclude the test directory)
2 | Cell_DETR
3 | weights
4 | logs/
5 | **/__pycache__
6 | **/*.pyc
7 | build/
8 |
--------------------------------------------------------------------------------
/Ours/ASPP.py:
--------------------------------------------------------------------------------
1 | # camera-ready
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | class ASPP(nn.Module):
8 | def __init__(self, num_classes, head = True):
9 | super(ASPP, self).__init__()
10 |
11 | self.conv_1x1_1 = nn.Conv2d(512, 256, kernel_size=1)
12 | self.bn_conv_1x1_1 = nn.BatchNorm2d(256)
13 |
14 | self.conv_3x3_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=6, dilation=6)
15 | self.bn_conv_3x3_1 = nn.BatchNorm2d(256)
16 |
17 | self.conv_3x3_2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=12, dilation=12)
18 | self.bn_conv_3x3_2 = nn.BatchNorm2d(256)
19 |
20 | self.conv_3x3_3 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=18, dilation=18)
21 | self.bn_conv_3x3_3 = nn.BatchNorm2d(256)
22 |
23 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
24 |
25 | self.conv_1x1_2 = nn.Conv2d(512, 256, kernel_size=1)
26 | self.bn_conv_1x1_2 = nn.BatchNorm2d(256)
27 |
28 | self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256)
29 | self.bn_conv_1x1_3 = nn.BatchNorm2d(256)
30 |
31 | if head:
32 | self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)
33 | self.head = head
34 |
35 | def forward(self, feature_map):
36 | # (feature_map has shape (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. If self.resnet instead is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8))
37 |
38 | feature_map_h = feature_map.size()[2] # (== h/16)
39 | feature_map_w = feature_map.size()[3] # (== w/16)
40 |
41 | out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
42 | out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
43 | out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
44 | out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
45 |
46 | out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1))
47 | out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1))
48 | out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/16, w/16))
49 |
50 | out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16))
51 | out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16))
52 | if self.head:
53 | out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16))
54 |
55 | return out
56 |
57 | class ASPP_Bottleneck(nn.Module):
58 | def __init__(self, num_classes):
59 | super(ASPP_Bottleneck, self).__init__()
60 |
61 | self.conv_1x1_1 = nn.Conv2d(4*512, 256, kernel_size=1)
62 | self.bn_conv_1x1_1 = nn.BatchNorm2d(256)
63 |
64 | self.conv_3x3_1 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=6, dilation=6)
65 | self.bn_conv_3x3_1 = nn.BatchNorm2d(256)
66 |
67 | self.conv_3x3_2 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=12, dilation=12)
68 | self.bn_conv_3x3_2 = nn.BatchNorm2d(256)
69 |
70 | self.conv_3x3_3 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=18, dilation=18)
71 | self.bn_conv_3x3_3 = nn.BatchNorm2d(256)
72 |
73 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
74 |
75 | self.conv_1x1_2 = nn.Conv2d(4*512, 256, kernel_size=1)
76 | self.bn_conv_1x1_2 = nn.BatchNorm2d(256)
77 |
78 | self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256)
79 | self.bn_conv_1x1_3 = nn.BatchNorm2d(256)
80 |
81 | self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)
82 |
83 | def forward(self, feature_map):
84 | # (feature_map has shape (batch_size, 4*512, h/16, w/16))
85 |
86 | feature_map_h = feature_map.size()[2] # (== h/16)
87 | feature_map_w = feature_map.size()[3] # (== w/16)
88 |
89 | out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
90 | out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
91 | out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
92 | out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
93 |
94 |         out_img = self.avg_pool(feature_map) # (shape: (batch_size, 4*512, 1, 1))
95 | out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1))
96 | out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/16, w/16))
97 |
98 | out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16))
99 | out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16))
100 | out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16))
101 |
102 | return out
103 |
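104 |
105 | if __name__ == '__main__':
106 |     # Minimal shape check, assuming stride-16 backbone features for a
107 |     # 384x384 input (so h/16 = w/16 = 24); not part of the training pipeline.
108 |     x18 = torch.randn(2, 512, 24, 24)      # ResNet18/34_OS16-style features
109 |     x50 = torch.randn(2, 4 * 512, 24, 24)  # ResNet50-152_OS16-style features
110 |     print(ASPP(num_classes=2)(x18).shape)             # torch.Size([2, 2, 24, 24])
111 |     print(ASPP_Bottleneck(num_classes=2)(x50).shape)  # torch.Size([2, 2, 24, 24])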
--------------------------------------------------------------------------------
/Ours/Base_transformer.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
4 | sys.path.insert(0, os.path.join(root_path))
5 | sys.path.insert(0, os.path.join(root_path, 'lib'))
6 | sys.path.insert(0, os.path.join(root_path, 'lib/Cell_DETR_master'))
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from Ours.base import DeepLabV3 as base
13 |
14 | from src.BAT_Modules import BoundaryCrossAttention, CrossAttention
15 | from src.BAT_Modules import MultiHeadAttention as Attention_head
16 | from src.transformer import BoundaryAwareTransformer, Transformer
17 |
18 |
19 | class BAT(nn.Module):
20 | def __init__(
21 | self,
22 | num_classes,
23 | num_layers,
24 | point_pred,
25 | decoder=False,
26 | transformer_type_index=0,
27 | hidden_features=128, # 256
28 | number_of_query_positions=1,
29 | segmentation_attention_heads=8):
30 |
31 | super(BAT, self).__init__()
32 |
33 | self.num_classes = num_classes
34 | self.point_pred = point_pred
35 | self.transformer_type = "BoundaryAwareTransformer" if transformer_type_index == 0 else "Transformer"
36 | self.use_decoder = decoder
37 |
38 | self.deeplab = base(num_classes, num_layers)
39 |
40 | in_channels = 2048 if num_layers == 50 else 512
41 |
42 | self.convolution_mapping = nn.Conv2d(in_channels=in_channels,
43 | out_channels=hidden_features,
44 | kernel_size=(1, 1),
45 | stride=(1, 1),
46 | padding=(0, 0),
47 | bias=True)
48 |
49 | self.query_positions = nn.Parameter(data=torch.randn(
50 | number_of_query_positions, hidden_features, dtype=torch.float),
51 | requires_grad=True)
52 |
53 | self.row_embedding = nn.Parameter(data=torch.randn(100,
54 | hidden_features //
55 | 2,
56 | dtype=torch.float),
57 | requires_grad=True)
58 | self.column_embedding = nn.Parameter(data=torch.randn(
59 | 100, hidden_features // 2, dtype=torch.float),
60 | requires_grad=True)
61 |
62 | self.transformer = [
63 | Transformer(d_model=hidden_features),
64 | BoundaryAwareTransformer(d_model=hidden_features)
65 | ][point_pred]
66 |
67 | if self.use_decoder:
68 | self.BCA = BoundaryCrossAttention(hidden_features, 8)
69 |
70 | self.trans_out_conv = nn.Conv2d(in_channels=hidden_features,
71 | out_channels=in_channels,
72 | kernel_size=(1, 1),
73 | stride=(1, 1),
74 | padding=(0, 0),
75 | bias=True)
76 |
77 | def forward(self, x):
78 | h = x.size()[2]
79 | w = x.size()[3]
80 | feature_map = self.deeplab.resnet(x)
81 |
82 | features = self.convolution_mapping(feature_map)
83 | height, width = features.shape[2:]
84 | batch_size = features.shape[0]
85 |         # (indexing columns by width and rows by height keeps this correct
86 |         # for non-square feature maps)
87 |         positional_embeddings = torch.cat([
88 |             self.column_embedding[:width].unsqueeze(dim=0).repeat(
89 |                 height, 1, 1),
90 |             self.row_embedding[:height].unsqueeze(dim=1).repeat(1, width, 1)
91 |         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(
92 |             batch_size, 1, 1, 1)
93 |
94 |         if self.point_pred == 1:  # the BoundaryAwareTransformer also returns point maps
95 | latent_tensor, features_encoded, point_maps = self.transformer(
96 | features, None, self.query_positions, positional_embeddings)
97 | else:
98 | latent_tensor, features_encoded = self.transformer(
99 | features, None, self.query_positions, positional_embeddings)
100 | point_maps = []
101 |
102 |         latent_tensor = latent_tensor.permute(2, 0, 1)
103 |         # shape: (batch_size, number_of_query_positions, hidden_features)
104 |
105 | if self.use_decoder:
106 | features_encoded, point_dec = self.BCA(features_encoded,
107 | latent_tensor)
108 | point_maps.append(point_dec)
109 |
110 |         trans_feature_maps = self.trans_out_conv(
111 |             features_encoded.contiguous())
112 |
113 | trans_feature_maps = trans_feature_maps + feature_map
114 |
115 | output = self.deeplab.aspp(
116 | trans_feature_maps
117 | ) # (shape: (batch_size, num_classes, h/16, w/16))
118 | output = F.interpolate(
119 | output, size=(h, w),
120 | mode="bilinear") # (shape: (batch_size, num_classes, h, w))
121 |
122 | if self.point_pred == 1:
123 | return output, point_maps
124 |
125 | return output
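126 |
127 |
128 | if __name__ == '__main__':
129 |     # Minimal smoke test, assuming a GPU and the ResNet18 backbone
130 |     # (num_layers=18 -> 512 backbone channels; point_pred=1 also returns
131 |     # the boundary point maps).
132 |     model = BAT(num_classes=2, num_layers=18, point_pred=1).cuda()
133 |     x = torch.rand((2, 3, 384, 384)).cuda()
134 |     output, point_maps = model(x)
135 |     print(output.size(), len(point_maps))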
--------------------------------------------------------------------------------
/Ours/__init__.py:
--------------------------------------------------------------------------------
1 | # from .Cell_DETR_master.sgtr import CellDETR
2 |
--------------------------------------------------------------------------------
/Ours/__pycache__/ASPP.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/ASPP.cpython-37.pyc
--------------------------------------------------------------------------------
/Ours/__pycache__/Base_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/Base_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/Ours/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/Ours/__pycache__/base.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/base.cpython-37.pyc
--------------------------------------------------------------------------------
/Ours/__pycache__/cell_DETR.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/cell_DETR.cpython-37.pyc
--------------------------------------------------------------------------------
/Ours/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/Ours/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/Ours/base.py:
--------------------------------------------------------------------------------
1 | # camera-ready
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import os
7 | import sys
8 | # sys.path.insert(0, '../')
9 |
10 | from Ours.resnet import ResNet18_OS16, ResNet34_OS16, ResNet50_OS16, ResNet101_OS16, ResNet152_OS16, ResNet18_OS8, ResNet34_OS8
11 | from Ours.ASPP import ASPP, ASPP_Bottleneck
12 |
13 |
14 | class DeepLabV3(nn.Module):
15 | def __init__(self, num_classes, num_layers):
16 | super(DeepLabV3, self).__init__()
17 |
18 | self.num_classes = num_classes
19 | layers = num_layers
20 | # NOTE! specify the type of ResNet here
21 | # NOTE! if you use ResNet50-152, set self.aspp = ASPP_Bottleneck(num_classes=self.num_classes) instead
22 | if layers == 18:
23 | self.resnet = ResNet18_OS16()
24 | self.aspp = ASPP(num_classes=self.num_classes)
25 | elif layers == 50:
26 | self.resnet = ResNet50_OS16()
27 | self.aspp = ASPP_Bottleneck(num_classes=self.num_classes)
28 |
29 | def forward(self, x):
30 | # (x has shape (batch_size, 3, h, w))
31 | h = x.size()[2]
32 | w = x.size()[3]
33 | feature_map = self.resnet(x)
34 |
35 | # (shape: (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16.
36 | # If self.resnet is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8).
37 | # If self.resnet is ResNet50-152, it will be (batch_size, 4*512, h/16, w/16))
38 | output = self.aspp(
39 | feature_map) # (shape: (batch_size, num_classes, h/16, w/16))
40 |         output = F.interpolate(
41 | output, size=(h, w),
42 | mode="bilinear") # (shape: (batch_size, num_classes, h, w))
43 | return output
44 |
45 |
46 | if __name__ == '__main__':
47 |     model = DeepLabV3(num_classes=12, num_layers=18).cuda()
48 | d = torch.rand((2, 3, 384, 384)).cuda()
49 | o = model(d)
50 | print(o.size())
--------------------------------------------------------------------------------
/Ours/cell_DETR.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
4 | sys.path.insert(0, root_path)
5 | sys.path.insert(0, os.path.join(root_path, 'lib/Cell_DETR_master'))
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from detr_new import CellDETR
11 | from segmentation import ResFeaturePyramidBlock, ResPACFeaturePyramidBlock
12 | #from transformer import BoundaryAwareTransformer, Transformer
13 | from src.transformer import BoundaryAwareTransformer, Transformer
14 |
15 | type_of_transformer = [BoundaryAwareTransformer, Transformer]
16 |
17 |
18 | def cell_detr_128(pretrained=False, type_index=0):
19 |
20 | detr = CellDETR(num_classes=2,
21 | transformer_type=type_of_transformer[type_index],
22 | segmentation_head_block=ResFeaturePyramidBlock,
23 | segmentation_head_final_activation=nn.Softmax,
24 | backbone_convolution=nn.Conv2d,
25 | segmentation_head_convolution=nn.Conv2d,
26 | transformer_activation=nn.LeakyReLU,
27 | backbone_activation=nn.LeakyReLU,
28 | bounding_box_head_activation=nn.LeakyReLU,
29 | classification_head_activation=nn.LeakyReLU,
30 | segmentation_head_activation=nn.LeakyReLU)
31 | if pretrained:
32 | ckpt = torch.load(
33 | "/home/chenfei/my_codes/TransformerCode-master/lib/Cell_DETR_master/trained_models/Cell_DETR_A/detr_99.pt"
34 | )
35 | detr.load_state_dict(ckpt, strict=False)
36 | print("pretrained cell_detr's transformer!")
37 | return detr.transformer
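38 |
39 |
40 | if __name__ == '__main__':
41 |     # Minimal sketch: build the transformer module from scratch; the
42 |     # pretrained checkpoint path above only exists on the authors' machine.
43 |     transformer = cell_detr_128(pretrained=False, type_index=0)
44 |     print(type(transformer).__name__)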
--------------------------------------------------------------------------------
/Ours/non_local.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class _NonLocalBlockND(nn.Module):
6 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
7 | """
8 |         :param in_channels: number of channels of the input feature map x
9 |         :param inter_channels: number of embedding channels (defaults to in_channels // 2)
10 |         :param dimension: 1, 2 or 3, selecting Conv1d/2d/3d and the matching pooling
11 |         :param sub_sample: if True, max-pool after g and phi (use False when phi
12 |             receives a (b, c, 1) latent vector, as in forward below)
13 | """
14 |
15 | super(_NonLocalBlockND, self).__init__()
16 |
17 | assert dimension in [1, 2, 3]
18 |
19 | self.dimension = dimension
20 | self.sub_sample = sub_sample
21 |
22 | self.in_channels = in_channels
23 | self.inter_channels = inter_channels
24 |
25 | if self.inter_channels is None:
26 | self.inter_channels = in_channels // 2
27 | if self.inter_channels == 0:
28 | self.inter_channels = 1
29 |
30 | if dimension == 3:
31 | conv_nd = nn.Conv3d
32 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
33 | bn = nn.BatchNorm3d
34 | elif dimension == 2:
35 | conv_nd = nn.Conv2d
36 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
37 | bn = nn.BatchNorm2d
38 | else:
39 | conv_nd = nn.Conv1d
40 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
41 | bn = nn.BatchNorm1d
42 |
43 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
44 | kernel_size=1, stride=1, padding=0)
45 |
46 | if bn_layer:
47 | self.W = nn.Sequential(
48 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0),
50 | bn(self.in_channels)
51 | )
52 | nn.init.constant_(self.W[1].weight, 0)
53 | nn.init.constant_(self.W[1].bias, 0)
54 | else:
55 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
56 | kernel_size=1, stride=1, padding=0)
57 | nn.init.constant_(self.W.weight, 0)
58 | nn.init.constant_(self.W.bias, 0)
59 |
60 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
61 | kernel_size=1, stride=1, padding=0)
62 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
63 | kernel_size=1, stride=1, padding=0)
64 |
65 | if sub_sample:
66 | self.g = nn.Sequential(self.g, max_pool_layer)
67 | self.phi = nn.Sequential(self.phi, max_pool_layer)
68 |
69 | def forward(self, x, y, return_nl_map=False):
70 | """
71 | :param x: (b, c, h, w)
72 | :param y: (b, c, 1)
73 | :param return_nl_map: if True return z, nl_map, else only return z.
74 |         :return: z of shape (b, c, h, w); with return_nl_map, also the gating map of shape (b, 1, h, w)
75 | """
76 |
77 | batch_size = x.size(0)
78 | h, w = x.shape[2:]
79 |
80 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
81 | #g_x = g_x.permute(0, 2, 1)
82 |
83 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
84 | theta_x = theta_x.permute(0, 2, 1)
85 |
86 | #phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
87 | phi_x = self.phi(y.unsqueeze(-1)).view(batch_size, self.inter_channels, -1)
88 | f = torch.matmul(theta_x, phi_x)
89 | #f_div_C = F.softmax(f, dim=-1)
90 | f_div_C = torch.sigmoid(f)
91 | f_div_C = f_div_C.permute(0,2,1).contiguous()
92 |
93 | #y = torch.matmul(f_div_C, g_x)
94 | #y = y.permute(0, 2, 1).contiguous()
95 | y = g_x * f_div_C
96 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
97 | W_y = self.W(y)
98 | z = W_y + x
99 |
100 | if return_nl_map:
101 | return z, f_div_C.view(batch_size, 1, h, w)
102 | return z
103 |
104 | class NONLocalBlock2D(_NonLocalBlockND):
105 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
106 | super(NONLocalBlock2D, self).__init__(in_channels,
107 | inter_channels=inter_channels,
108 | dimension=2, sub_sample=sub_sample,
109 | bn_layer=bn_layer,)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Boundary-aware Transformers for Skin Lesion Segmentation
2 |
3 | ## Introduction
4 |
5 | This is an official release of the paper **Boundary-aware Transformers for Skin Lesion Segmentation**.
6 |
7 | > [**Boundary-aware Transformers for Skin Lesion Segmentation**](https://arxiv.org/abs/2110.03864),
8 | > **Jiacheng Wang**, Lan Wei, Liansheng Wang, Qichao Zhou, Lei Zhu, Jing Qin
9 | > In: Medical Image Computing and Computer Assisted Intervention (MICCAI), 2021
10 | > [[arXiv](https://arxiv.org/abs/2110.03864)][[Bibtex](https://github.com/jcwang123/BA-Transformer#citation)]
11 |
12 |
13 |
14 | ## News
15 | - **[5/27 2022] We have released a more powerful [XBound-Former](https://github.com/jcwang123/xboundformer) with a clearer concept and cleaner code.**
16 | - **[11/15 2021] We have released the point map data.**
17 | - **[11/08 2021] We have released the training / testing codes.**
18 |
19 | ## Code List
20 |
21 | - [x] Network
22 | - [x] Pre-processing
23 | - [x] Training Codes
24 | - [ ] MS
25 |
26 | For more details or any questions, please feel free to contact us by email (jiachengw@stu.xmu.edu.cn).
27 |
28 |
29 | ## Usage
30 |
31 | ### Dataset
32 |
33 | Please download the datasets from the [ISIC](https://www.isic-archive.com/) challenge and the [PH2](https://www.fc.up.pt/addi/ph2%20database.html) website.
34 |
35 | ### Pre-processing
36 |
37 | Please run:
38 |
39 | ```bash
40 | $ python src/process_resize.py
41 | $ python src/process_point.py
42 | ```
43 |
44 | You need to change the **File Path** in these scripts to your own.
45 |
46 | ### Point Maps
47 |
48 | For your convenience, we release the processed maps and the dataset division.
49 |
50 | Please download them from [Baidu Disk](https://pan.baidu.com/s/1pNbH5zUI8Dw_ZAC8Iq9f7w) (code: **kmqr**) or [Google Drive](https://drive.google.com/file/d/1mSLt-ipLM9CxrfvwgjJr5V9NKrpnQaQ5/view?usp=sharing).
51 |
52 | The file names are equal to the original image names.
53 |
54 | ### Training
55 |
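56 | Please run (a typical invocation; the full argument list is defined in `train.py`, and dataset selection is assumed to follow the same `--dataset` convention as `test.py`):
57 |
58 | ```bash
59 | $ python train.py --dataset isic2016
60 | ```
61 |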
56 | ### Testing
57 |
58 | Download the pretrained weight for PH2 dataset from [Google Drive](https://drive.google.com/file/d/1-eMHYX1fr-QvI3n50S0xqWcxc3FGsMgE/view?usp=sharing).
59 |
60 | ```bash
61 | $ python test.py --dataset isic2016
62 | ```
63 |
64 | ### Result
65 |
66 | | Method | Dice | IoU | HD95 | ASSD |
67 | | ------ | ------ | ------ | ------ | ------ |
68 | | Lee *et al.* | 0.918 | 0.843 | - | - |
69 | | BAT (paper) | 0.921 | 0.858 | - | - |
70 |
71 |
72 |
73 |
74 | ## Citation
75 |
76 | If you find BAT useful in your research, please consider citing:
77 |
78 | ```
79 | @inproceedings{wang2021boundary,
80 | title={Boundary-Aware Transformers for Skin Lesion Segmentation},
81 | author={Wang, Jiacheng and Wei, Lan and Wang, Liansheng and Zhou, Qichao and Zhu, Lei and Qin, Jing},
82 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
83 | pages={206--216},
84 | year={2021},
85 | organization={Springer}
86 | }
87 | ```
88 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/isbi2016.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__pycache__/isbi2016.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/isic2016.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/dataset/__pycache__/isic2016.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/isic2016.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | import torch.utils.data
8 | from torchvision import transforms
9 | import torch.utils.data as data
10 | import torch.nn.functional as F
11 |
12 | import albumentations as A
13 |
14 |
15 | def norm01(x):
16 | return np.clip(x, 0, 255) / 255
17 |
18 |
19 | def filter_image(p):
20 | label_data = np.load(p.replace('image', 'label'))
21 | return np.max(label_data) == 1
22 |
23 |
24 | class myDataset(data.Dataset):
25 | def __init__(self, split, aug=False):
26 | super(myDataset, self).__init__()
27 |
28 | self.image_paths = []
29 | self.label_paths = []
30 | self.point_paths = []
31 | self.dist_paths = []
32 |
33 | root_dir = '/raid/wjc/data/skin_lesion/isic2016/'
34 | if split == 'train':
35 | self.image_paths = glob.glob(root_dir + '/Train/Image/*.npy')
36 | self.label_paths = glob.glob(root_dir + '/Train/Label/*.npy')
37 | self.point_paths = glob.glob(root_dir + '/Train/Point/*.npy')
38 | elif split == 'valid':
39 | self.image_paths = glob.glob(root_dir + '/Validation/Image/*.npy')
40 | self.label_paths = glob.glob(root_dir + '/Validation/Label/*.npy')
41 | self.point_paths = glob.glob(root_dir + '/Validation/Point/*.npy')
42 | elif split == 'test':
43 | self.image_paths = glob.glob(root_dir + '/Test/Image/*.npy')
44 | self.label_paths = glob.glob(root_dir + '/Test/Label/*.npy')
45 | self.point_paths = glob.glob(root_dir + '/Test/Point/*.npy')
46 | self.image_paths.sort()
47 | self.label_paths.sort()
48 | self.point_paths.sort()
49 |
50 | print('Loaded {} frames'.format(len(self.image_paths)))
51 | self.num_samples = len(self.image_paths)
52 | self.aug = aug
53 |
54 | self.transf = A.Compose([
55 | A.HorizontalFlip(p=0.5),
56 | A.VerticalFlip(p=0.5),
57 | A.RandomBrightnessContrast(p=0.2),
58 | A.Rotate()
59 | ])
60 |
61 | def __getitem__(self, index):
62 |
63 | image_data = np.load(self.image_paths[index])
64 | label_data = np.load(self.label_paths[index]) > 0.5
65 | point_data = np.load(self.point_paths[index])
66 |
67 | if self.aug:
68 | mask = np.concatenate([
69 | label_data[..., np.newaxis].astype('uint8'),
70 | point_data[..., np.newaxis]
71 | ],
72 | axis=-1)
73 | # print(mask.shape)
74 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask)
75 | image_data, mask_aug = tsf['image'], tsf['mask']
76 | label_data = mask_aug[:, :, 0]
77 | point_data = mask_aug[:, :, 1]
78 |
79 | image_data = norm01(image_data)
80 | label_data = np.expand_dims(label_data, 0)
81 | point_data = np.expand_dims(point_data, 0)
82 | image_data = torch.from_numpy(image_data).float()
83 | label_data = torch.from_numpy(label_data).float()
84 | point_data = torch.from_numpy(point_data).float()
85 |
86 | image_data = image_data.permute(2, 0, 1)
87 | return {
88 | 'image_path': self.image_paths[index],
89 | 'label_path': self.label_paths[index],
90 | 'point_path': self.point_paths[index],
91 | 'image': image_data,
92 | 'label': label_data,
93 | 'point': point_data
94 | }
95 |
96 | def __len__(self):
97 | return self.num_samples
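98 |
99 |
100 | if __name__ == '__main__':
101 |     # Minimal sketch, assuming the pre-processed .npy files exist under
102 |     # root_dir (change the hard-coded path above to your own data first).
103 |     ds = myDataset(split='train', aug=True)
104 |     sample = ds[0]
105 |     print(sample['image'].shape, sample['label'].shape, sample['point'].shape)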
--------------------------------------------------------------------------------
/dataset/isic2018.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import torch
5 | import random
6 | import torch.nn as nn
7 | import numpy as np
8 | import torch.utils.data
9 | from torchvision import transforms
10 | import torch.utils.data as data
11 | import torch.nn.functional as F
12 |
13 | import albumentations as A
14 | from sklearn.model_selection import KFold
15 |
16 |
17 | def norm01(x):
18 | return np.clip(x, 0, 255) / 255
19 |
20 |
21 | seperable_indexes = json.load(open('dataset/data_split.json', 'r'))
22 |
23 |
24 | # cross validation
25 | class myDataset(data.Dataset):
26 | def __init__(self, fold, split, aug=False):
27 | super(myDataset, self).__init__()
28 | self.split = split
29 | root_data_dir = '/raid/wjc/data/skin_lesion/isic2018/'
30 |
31 | # load images, label, point
32 | self.image_paths = []
33 | self.label_paths = []
34 | self.point_paths = []
35 | self.dist_paths = []
36 |
37 | indexes = [l[:-4] for l in os.listdir(root_data_dir + 'Image/')]
38 | valid_indexes = seperable_indexes[fold]
39 |
40 | train_indexes = list(filter(lambda x: x not in valid_indexes, indexes))
41 | print('Fold {}: train: {} valid: {}'.format(fold, len(train_indexes),
42 | len(valid_indexes)))
43 |
44 | indexes = train_indexes if split == 'train' else valid_indexes
45 |
46 | self.image_paths = [
47 | root_data_dir + '/Image/{}.npy'.format(_id) for _id in indexes
48 | ]
49 | self.label_paths = [
50 | root_data_dir + '/Label/{}.npy'.format(_id) for _id in indexes
51 | ]
52 | self.point_paths = [
53 | root_data_dir + '/Point/{}.npy'.format(_id) for _id in indexes
54 | ]
55 | # self.point_All_paths = [
56 | # '/data2/cf_data/skinlesion_segment/ISIC2018_rawdata/ISBI_2018/Train/Point_All/{}.npy'
57 | # .format(_id) for _id in indexes
58 | # ]
59 |
60 | print('Loaded {} frames'.format(len(self.image_paths)))
61 | self.num_samples = len(self.image_paths)
62 | self.aug = aug
63 |
64 | p = 0.5
65 | self.transf = A.Compose([
66 | A.GaussNoise(p=p),
67 | A.HorizontalFlip(p=p),
68 | A.VerticalFlip(p=p),
69 | A.ShiftScaleRotate(p=p),
70 | # A.RandomBrightnessContrast(p=p),
71 | ])
72 |
73 | def __getitem__(self, index):
74 |
75 | image_data = np.load(self.image_paths[index])
76 | label_data = np.load(self.label_paths[index]) > 0.5
77 | point_data = np.load(self.point_paths[index]) > 0.5
78 |
79 | if self.aug and self.split == 'train':
80 | mask = np.concatenate([
81 | label_data[..., np.newaxis].astype('uint8'),
82 | point_data[..., np.newaxis]
83 | ],
84 | axis=-1)
85 | # print(mask.shape)
86 | tsf = self.transf(image=image_data.astype('uint8'), mask=mask)
87 | image_data, mask_aug = tsf['image'], tsf['mask']
88 | label_data = mask_aug[:, :, 0]
89 | point_data = mask_aug[:, :, 1]
90 |
91 | image_data = norm01(image_data)
92 | label_data = np.expand_dims(label_data, 0)
93 | point_data = np.expand_dims(point_data, 0)
94 |
95 | image_data = torch.from_numpy(image_data).float()
96 | label_data = torch.from_numpy(label_data).float()
97 | point_data = torch.from_numpy(point_data).float()
98 |
99 | image_data = image_data.permute(2, 0, 1)
100 | return {
101 | 'image_path': self.image_paths[index],
102 | 'label_path': self.label_paths[index],
103 | 'point_path': self.point_paths[index],
104 | 'image': image_data,
105 | 'label': label_data,
106 | 'point': point_data
107 | }
108 |
109 | def __len__(self):
110 | return self.num_samples
111 |
112 |
113 | def dataset_kfold(dataset_dir, save_path, k=5):
114 | indexes = [l[:-4] for l in os.listdir(dataset_dir)]
115 |
116 |     kf = KFold(k, shuffle=True)  # k-fold cross-validation
117 | val_index = dict()
118 | for i in range(k):
119 | val_index[str(i)] = []
120 |
121 | for i, (tr, val) in enumerate(kf.split(indexes)):
122 | for item in val:
123 | val_index[str(i)].append(indexes[item])
124 | print('fold:{},train_len:{},val_len:{}'.format(i, len(tr), len(val)))
125 |
126 | with open(save_path, 'w') as f:
127 | json.dump(val_index, f)
128 |
129 |
130 | def random_seperate_dataset():
131 | # indexes = [l[:-4] for l in os.listdir('/raid/wl/ISBI_2018/Train/Image/')]
132 | # random.shuffle(indexes)
133 | # names = {'0':indexes[:500], '1':indexes[500:1000], '2':indexes[1000:1500], '3':indexes[1500:2000],'4':indexes[2000:]}
134 | # names = [indexes[:500], indexes[500:1000], indexes[1000:1500], indexes[1500:2000], indexes[2000:]]
135 | # with open('/raid/wl/ISBI_2018/data_split.json','w') as f:
136 | # json.dump(names, f)
137 | return
138 |
139 |
140 | if __name__ == '__main__':
141 | from tqdm import tqdm
142 | dataset = myDataset(fold='0', split='train', aug=True)
143 |
144 | train_loader = torch.utils.data.DataLoader(dataset,
145 | batch_size=8,
146 | shuffle=False,
147 | num_workers=2,
148 | pin_memory=True,
149 | drop_last=True)
150 | for d in train_loader:
151 | pass
152 |
--------------------------------------------------------------------------------
/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/framework.jpg
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def main():
5 | print('hello')
6 |
7 |
8 | if __name__ == '__main__':
9 | main()
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/.gitignore:
--------------------------------------------------------------------------------
1 | # Source: https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore
2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
4 |
5 | # User-specific stuff
6 | .idea
7 | .idea/**/workspace.xml
8 | .idea/**/tasks.xml
9 | .idea/**/usage.statistics.xml
10 | .idea/**/dictionaries
11 | .idea/**/shelf
12 |
13 | # Generated files
14 | .idea/**/contentModel.xml
15 |
16 | # Sensitive or high-churn files
17 | .idea/**/dataSources/
18 | .idea/**/dataSources.ids
19 | .idea/**/dataSources.local.xml
20 | .idea/**/sqlDataSources.xml
21 | .idea/**/dynamic.xml
22 | .idea/**/uiDesigner.xml
23 | .idea/**/dbnavigator.xml
24 |
25 | # Gradle
26 | .idea/**/gradle.xml
27 | .idea/**/libraries
28 |
29 | # Gradle and Maven with auto-import
30 | # When using Gradle or Maven with auto-import, you should exclude module files,
31 | # since they will be recreated, and may cause churn. Uncomment if using
32 | # auto-import.
33 | # .idea/artifacts
34 | # .idea/compiler.xml
35 | # .idea/jarRepositories.xml
36 | # .idea/modules.xml
37 | # .idea/*.iml
38 | # .idea/modules
39 | # *.iml
40 | # *.ipr
41 |
42 | # CMake
43 | cmake-build-*/
44 |
45 | # Mongo Explorer plugin
46 | .idea/**/mongoSettings.xml
47 |
48 | # File-based project format
49 | *.iws
50 |
51 | # IntelliJ
52 | out/
53 |
54 | # mpeltonen/sbt-idea plugin
55 | .idea_modules/
56 |
57 | # JIRA plugin
58 | atlassian-ide-plugin.xml
59 |
60 | # Cursive Clojure plugin
61 | .idea/replstate.xml
62 |
63 | # Crashlytics plugin (for Android Studio and IntelliJ)
64 | com_crashlytics_export_strings.xml
65 | crashlytics.properties
66 | crashlytics-build.properties
67 | fabric.properties
68 |
69 | # Editor-based Rest Client
70 | .idea/httpRequests
71 |
72 | # Android studio 3.1+ serialized cache file
73 | .idea/caches/build_file_checksums.ser
74 |
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/DETR_Test.py:
--------------------------------------------------------------------------------
1 | import os, argparse, math
2 | import sys
3 | import torch
4 | import torch.nn
5 | import tqdm
6 | import logging
7 | import numpy as np
8 | from glob import glob
9 | from HNgtv_dataset import norm01,myDataset,crop_array
10 | from medpy.metric.binary import hd, hd95, dc, jc, assd
11 | sys.path.append('Ours/Cell_DETR_master/')
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--crop_size', type=int,default=128)
15 | # parser.add_argument('--with_BPB', type=int, default=0,choices = [0,1])
16 | parse_config = parser.parse_args()
17 | print(parse_config)
18 |
19 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 | # model = torch.load('./logs/test5_aaai_sdm_loss/model/best.pkl')
22 | dir_path = "./logs/test_transformer_loss_0_ver_1/"
23 | model = torch.load(dir_path + 'model/best.pkl')
24 | # model = torch.load(dir_path + 'model/latest.pkl')
25 | txt_path = os.path.join(dir_path + 'parameter.txt')
26 |
27 | def test():
28 | dice_value = 0
29 | hd_value = 0
30 | hd95_value = 0
31 | jc_value = 0
32 | assd_value = 0
33 | numm = 0
34 |
35 | logging.basicConfig(filename = txt_path, level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
36 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
37 |
38 | for num in range(41,51):
39 | prediction = []
40 | labels = []
41 | img_dir = os.path.join( '/home/wl/gtv_data/Patient' + str(num).zfill(3)+ '/')
42 | n = len(glob(img_dir + 'image_*.npy'))
43 | for i in range(n):
44 | img = np.load(img_dir + 'image_{:03d}.npy'.format(i))
45 | label = np.load(img_dir + 'label_{:03d}.npy'.format(i))
46 | point = np.load(img_dir + 'point_{:03d}.npy'.format(i))
47 | point = np.expand_dims(point,0)
48 | label_sum = np.sum(label)
49 | img = norm01(img)
50 | img = crop_array(img, parse_config.crop_size)
51 | # img = np.repeat(img,3,0)
52 | label = crop_array(label, parse_config.crop_size)
53 | labels.append(label)
54 | img = torch.from_numpy(img).unsqueeze(0).float().cuda()
55 |
56 | with torch.no_grad():
57 | # if parse_config.with_BPB == 0:
58 | # output = model(img)
59 | # if parse_config.with_BPB == 1:
60 | # output,maps = model(img)
61 | # output = torch.sigmoid(output)[0]
62 |
63 | output = model(img)
64 | output = torch.max(output, dim=0, keepdim=True).values
65 | output = output.cpu().numpy()>0.5
66 |
67 | # writer.add_image('val_label', label, val_num)
68 | # writer.add_image('val_point', point, val_num)
69 | # writer.add_image('val_output', output, val_num)
70 |
71 | # if parse_config.with_BPB == 1:
72 | # for j, m in enumerate(maps):
73 | # if m is not None:
74 | # writer.add_image('val_m{}_img'.format(j+1), maps[j][0,...], val_num)
75 | # val_num += 1
76 | prediction.append(output)
77 |
78 | prediction = np.array(prediction)
79 | # prediction = prediction.squeeze(1)
80 | labels = np.array(labels)
81 |
82 | assert(prediction.shape==labels.shape)
83 | # calculate metric
84 | dice_ave = dc(prediction, labels)
85 | hd_ave = hd(prediction, labels)
86 | hd95_ave = hd95(prediction, labels)
87 | jc_ave = jc(prediction, labels)
88 | assd_ave = assd(prediction, labels)
89 |
90 | logging.info('patient %d : dice_value : %f' % (num, dice_ave))
91 | logging.info('patient %d : hd_value : %f' % (num, hd_ave))
92 | logging.info('patient %d : hd95_value : %f' % (num, hd95_ave))
93 | logging.info('patient %d : jc_value : %f' % (num, jc_ave))
94 | logging.info('patient %d : assd_value : %f' % (num, assd_ave))
95 |
96 | # print("Dice value for patient{} = ".format(num),dice_ave)
97 | # print("HD value for patient{} = ".format(num),hd_ave)
98 | dice_value += dice_ave
99 | hd_value += hd_ave
100 | hd95_value += hd95_ave
101 | jc_value += jc_ave
102 | assd_value += assd_ave
103 | numm += 1
104 |
105 | dice_average = dice_value / numm
106 | hd_average = hd_value / numm
107 | hd95_average = hd95_value / numm
108 | jc_average = jc_value / numm
109 | assd_average = assd_value / numm
110 |
111 | logging.info('Dice value of test dataset : %f' % (dice_average))
112 | logging.info('HD value of test dataset : %f' % (hd_average))
113 | logging.info('HD95 value of test dataset : %f' % (hd95_average))
114 | logging.info('JC value of test dataset : %f' % (jc_average))
115 | logging.info('ASSD value of test dataset : %f' % (assd_average))
116 | # print("Average dice value of evaluation dataset = ",dice_average)
117 | return dice_average
118 |
119 | if __name__ == '__main__':
120 | test()
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/DeepDetr.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Tuple, Type, Iterable
3 | sys.path.insert(0, '/home/chenfei/my_codes/TransformerCode-master/lib/Cell_DETR_master/')
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | from torchvision import models
8 | import torch.nn.functional as F
9 | # from modules.modulated_deform_conv import ModulatedDeformConvPack
10 | # from pade_activation_unit.utils import PAU
11 | from torch.nn.modules import Conv2d,LeakyReLU
12 |
13 | # A
14 | conv = Conv2d
15 | act = LeakyReLU
16 |
17 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock
18 | from segmentation import MultiHeadAttention, SegmentationHead, SingleSegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock
19 | from transformer import Transformer,TransformerEncoder,TransformerEncoderLayer
20 |
21 | class CellDETR(nn.Module):
22 | def __init__(self,
23 | num_classes: int = 1,
24 | number_of_query_positions: int = 1,
25 | hidden_features=128,
26 | backbone_channels: Tuple[Tuple[int, int], ...] = (
27 | (3, 64), (64, 128), (128, 256), (256, 256)),
28 | backbone_block: Type = ResNetBlock, backbone_convolution: Type = conv,
29 | backbone_normalization: Type = nn.BatchNorm2d, backbone_activation: Type = act,
30 | backbone_pooling: Type = nn.AvgPool2d,
31 | bounding_box_head_features: Tuple[Tuple[int, int], ...] = ((128, 64), (64, 16), (16, 4)),
32 | bounding_box_head_activation: Type = act,
33 | classification_head_activation: Type = act,
34 | num_encoder_layers: int = 3,
35 | num_decoder_layers: int = 2,
36 | dropout: float = 0.0,
37 | transformer_attention_heads: int = 8,
38 | transformer_activation: Type = act,
39 | segmentation_attention_heads: int = 8,
40 | segmentation_head_channels: Tuple[Tuple[int, int], ...] = (
41 | (128+8, 64), (64, 32),(32, 16)),
42 | segmentation_head_feature_channels: Tuple[int, ...] = (256, 128, 64),
43 | segmentation_head_block: Type = ResPACFeaturePyramidBlock,
44 | segmentation_head_convolution: Type = conv,
45 | segmentation_head_normalization: Type = nn.InstanceNorm2d,
46 | segmentation_head_activation: Type = act,
47 | segmentation_head_final_activation: Type = nn.Sigmoid) -> None:
48 |
49 | # Call super constructor
50 | super(CellDETR, self).__init__()
51 | # Init backbone
52 | self.backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution,
53 | normalization=backbone_normalization, activation=backbone_activation,
54 | pooling=backbone_pooling)
55 | # Init convolution mapping to match transformer dims
56 | self.convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features,
57 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
58 | # Init query positions
59 | self.query_positions = nn.Parameter(
60 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float),
61 | requires_grad=True)
62 | # Init embeddings
63 | self.row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
64 | requires_grad=True)
65 | self.column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
66 | requires_grad=True)
67 | # Init transformer
68 | self.transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads,
69 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
70 | dropout=dropout, dim_feedforward=4 * hidden_features,
71 | activation=transformer_activation)
72 |
73 | # Init segmentation attention head
74 | self.segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features,
75 | hidden_features=hidden_features,
76 | number_of_heads=segmentation_attention_heads,
77 | dropout=dropout)
78 | # Init segmentation head
79 | self.segmentation_head = SegmentationHead(channels=segmentation_head_channels,
80 | feature_channels=segmentation_head_feature_channels,
81 | convolution=segmentation_head_convolution,
82 | normalization=segmentation_head_normalization,
83 | activation=segmentation_head_activation,
84 | block=segmentation_head_block,
85 | number_of_query_positions=number_of_query_positions,
86 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax))
87 | # Init final segmentation activation
88 | self.segmentation_final_activation = segmentation_head_final_activation(dim=1) if isinstance(
89 | segmentation_head_final_activation(), nn.Softmax) else segmentation_head_final_activation()
90 |
91 | self.point_pre_layer = nn.Conv2d(hidden_features, 1, kernel_size = 1)
92 |
93 |     def get_parameters(self, lr_main: float = 1e-04, lr_backbone: float = 1e-05) -> Iterable:
94 |         # Note: this variant has no bounding box / classification heads, so only
95 |         # the remaining modules are exposed here.
96 |         return [{'params': self.backbone.parameters(), 'lr': lr_backbone},
97 |                 {'params': self.convolution_mapping.parameters(), 'lr': lr_main},
98 |                 {'params': [self.row_embedding], 'lr': lr_main},
99 |                 {'params': [self.column_embedding], 'lr': lr_main},
100 |                 {'params': self.transformer.parameters(), 'lr': lr_main},
101 |                 {'params': self.segmentation_attention_head.parameters(), 'lr': lr_main},
102 |                 {'params': self.segmentation_head.parameters(), 'lr': lr_main}]
103 |
104 | def get_segmentation_head_parameters(self, lr: float = 1e-05) -> Iterable:
105 | return [{'params': self.segmentation_attention_head.parameters(), 'lr': lr},
106 | {'params': self.segmentation_head.parameters(), 'lr': lr}]
107 |
108 | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109 | features, feature_list = self.backbone(input)
110 | features = self.convolution_mapping(features)
111 | height, width = features.shape[2:]
112 | batch_size = features.shape[0]
113 |         positional_embeddings = torch.cat([self.column_embedding[:width].unsqueeze(dim=0).repeat(height, 1, 1),
114 |                                            self.row_embedding[:height].unsqueeze(dim=1).repeat(1, width, 1)],
115 |                                           dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)
116 | latent_tensor, features_encoded = self.transformer(features, None, self.query_positions, positional_embeddings)
117 | latent_tensor = latent_tensor.permute(2, 0, 1)
118 | # point = self.point_pre_layer(features_encoded)
119 | # point = torch.sigmoid(point)
120 | # Feature = point*features_encoded + features_encoded
121 | bounding_box_attention_masks = self.segmentation_attention_head(
122 | latent_tensor, features_encoded.contiguous())
123 | instance_segmentation_prediction = self.segmentation_head(features.contiguous(),
124 | bounding_box_attention_masks.contiguous(),
125 | feature_list[-2::-1])
126 | return self.segmentation_final_activation(instance_segmentation_prediction).clone()
127 |
128 | if __name__ == '__main__':
129 | # Init model
130 | detr = CellDETR()
131 | # Print number of parameters
132 | print("DETR # parameters", sum([p.numel() for p in detr.parameters()]))
133 | # Model into eval mode
134 | # detr.eval()
135 | image = torch.randn(5,3,512,512)
136 | # point = torch.randn(5,1,128,128)
137 | # Predict
138 | segmentation_prediction = detr(image)
139 |
140 | # Print shapes
141 | # print(segmentation_prediction.shape)
142 | # print(segmentation_prediction.max(), segmentation_prediction.min())
143 | # loss = segmentation_prediction.sum()
144 | # loss.backward()
145 |
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Christoph Reich & Tim Prangemeier (Bioinspired Communication Systems Lab, TU Darmstadt)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/__pycache__/backbone.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/backbone.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/__pycache__/bounding_box_head.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/bounding_box_head.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/__pycache__/detr_new.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/detr_new.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/__pycache__/segmentation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/segmentation.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/augmentation.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | from scipy.ndimage import map_coordinates
5 | from scipy.ndimage import gaussian_filter
6 | import numpy as np
7 |
8 |
9 | class Augmentation(object):
10 | """
11 | Super class for all augmentations.
12 | """
13 |
14 | def __init__(self) -> None:
15 | """
16 | Constructor method
17 | """
18 | pass
19 |
20 | def need_labels(self) -> None:
21 | """
22 | Method should return if the labels are needed for the augmentation
23 | """
24 | raise NotImplementedError()
25 |
26 | def __call__(self, *args, **kwargs) -> None:
27 | """
28 | Call method is used to apply the augmentation
29 | :param args: Will be ignored
30 | :param kwargs: Will be ignored
31 | """
32 | raise NotImplementedError()
33 |
34 |
35 | class VerticalFlip(Augmentation):
36 | """
37 |     This class implements flipping for instance segmentation (despite the class
38 |     name, the flip is applied along the width axis, i.e. horizontally).
38 | """
39 |
40 | def __init__(self) -> None:
41 | """
42 | Constructor method
43 | """
44 | # Call super constructor
45 | super(VerticalFlip, self).__init__()
46 |
47 | def need_labels(self) -> bool:
48 | """
49 | Method returns that the labels are needed for the augmentation
50 | :return: (Bool) True will be returned
51 | """
52 | return True
53 |
54 |     def __call__(self, input: torch.Tensor, instances: torch.Tensor,
55 |                  bounding_boxes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56 | """
57 |         Flipping augmentation (horizontal only)
58 |         :param input: (torch.Tensor) Input image of shape [channels, height, width]
59 |         :param instances: (torch.Tensor) Instance segmentation maps of shape [instances, height, width]
60 |         :param bounding_boxes: (torch.Tensor) Bounding boxes of shape [instances, 4 (x1, y1, x2, y2)]
61 | :return: (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) Input flipped, instances flipped & BBs flipped
62 | """
63 | # Flip input
64 | input_flipped = input.flip(dims=(2,))
65 | # Flip instances
66 | instances_flipped = instances.flip(dims=(2,))
67 | # Flip bounding boxes
68 | image_center = torch.tensor((input.shape[2] // 2, input.shape[1] // 2))
69 | bounding_boxes[:, [0, 2]] += 2 * (image_center - bounding_boxes[:, [0, 2]])
70 | bounding_boxes_w = torch.abs(bounding_boxes[:, 0] - bounding_boxes[:, 2])
71 | bounding_boxes[:, 0] -= bounding_boxes_w
72 | bounding_boxes[:, 2] += bounding_boxes_w
73 | return input_flipped, instances_flipped, bounding_boxes
74 |
75 |
76 | class ElasticDeformation(Augmentation):
77 | """
78 | This class implement random elastic deformation of a given input image
79 | """
80 |
81 | def __init__(self, alpha: float = 125, sigma: float = 20) -> None:
82 | """
83 | Constructor method
84 | :param alpha: (float) Alpha coefficient which represents the scaling
85 | :param sigma: (float) Sigma coefficient which represents the elastic factor
86 | """
87 | # Call super constructor
88 | super(ElasticDeformation, self).__init__()
89 | # Save parameters
90 | self.alpha = alpha
91 | self.sigma = sigma
92 |
93 | def need_labels(self) -> bool:
94 | """
95 |         Method returns that the labels are not needed for the augmentation
96 |         :return: (Bool) False will be returned
97 | """
98 | return False
99 |
100 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
101 | """
102 | Method applies the random elastic deformation
103 | :param image: (torch.Tensor) Input image
104 | :return: (torch.Tensor) Transformed input image
105 | """
106 | # Convert torch tensor to numpy array for scipy
107 | image = image.numpy()
108 | # Save basic shape
109 | shape = image.shape[1:]
110 | # Sample offsets
111 | dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
112 | dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
113 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
114 | indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
115 | # Perform deformation
116 | for index in range(image.shape[0]):
117 | image[index] = map_coordinates(image[index], indices, order=1).reshape(shape)
118 | return torch.from_numpy(image)
119 |
120 |
121 | class NoiseInjection(Augmentation):
122 | """
123 | This class implements gaussian noise injection for images.
124 | """
125 |
126 | def __init__(self, mean: float = 0.0, std: float = 0.25) -> None:
127 | """
128 | Constructor method
129 | :param mean: (Optional[float]) Mean of the gaussian noise
130 | :param std: (Optional[float]) Standard deviation of the gaussian noise
131 | """
132 | # Call super constructor
133 | super(NoiseInjection, self).__init__()
134 | # Save parameter
135 | self.mean = mean
136 | self.std = std
137 |
138 | def need_labels(self) -> bool:
139 | """
140 | Method returns that the labels are not needed for this augmentation
141 | :return: (bool) False will be returned
142 | """
143 | return False
144 |
145 | def __call__(self, input: torch.Tensor) -> torch.Tensor:
146 | """
147 | Method injects gaussian noise to the given input image
148 | :param input: (torch.Tensor) Input image
149 | :return: (torch.Tensor) Transformed input image
150 | """
151 | # Get noise
152 | noise = self.mean + torch.randn_like(input) * self.std
153 | # Apply noise to the image
154 | input = input + noise
155 | return input
156 |
--------------------------------------------------------------------------------
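A minimal usage sketch of the augmentation classes above, mirroring the random-choice logic that dataset.py (further below) applies at load time. The tensor shapes and box values are illustrative assumptions, not taken from the repository.

import numpy as np
import torch

from augmentation import VerticalFlip, NoiseInjection, ElasticDeformation  # the module above

# Illustrative shapes: one-channel image, three instance masks, three boxes.
image = torch.rand(1, 128, 128)                            # [channels, height, width]
instances = torch.rand(3, 128, 128)                        # [instances, height, width]
boxes = torch.tensor([[10., 10., 40., 40.]]).repeat(3, 1)  # [instances, 4 (x1, y1, x2, y2)]

# Pick one augmentation at random and pass the labels only if it needs them.
augmentations = (VerticalFlip(), NoiseInjection(), ElasticDeformation())
aug = np.random.choice(augmentations)
if aug.need_labels():
    image, instances, boxes = aug(image, instances, boxes)
else:
    image = aug(image)
print(image.shape, instances.shape, boxes.shape)
--------------------------------------------------------------------------------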
/lib/Cell_DETR_master/botr.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Tuple, Type, Iterable
3 | sys.path.insert(0, '/home/wl/File/SegGTV/Ours/Cell_DETR_master/')
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | from torchvision import models
8 | import torch.nn.functional as F
9 | from modules.modulated_deform_conv import ModulatedDeformConvPack
10 | from pade_activation_unit.utils import PAU
11 | from torch.nn.modules import Conv2d,LeakyReLU
12 |
13 | # Default convolution and activation types used by the model below
14 | conv = Conv2d
15 | act = LeakyReLU
16 |
17 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock
18 | from modules.modulated_deform_conv import ModulatedDeformConvPack
19 | from pade_activation_unit.utils import PAU
20 | from segmentation import MultiHeadAttention, SegmentationHead, SingleSegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock
21 | from transformer import Transformer,TransformerEncoder,TransformerEncoderLayer
22 |
23 | class CellDETR(nn.Module):
24 | def __init__(self,
25 | num_classes: int = 1,
26 | number_of_query_positions: int = 1,
27 | hidden_features=128,
28 | backbone_channels: Tuple[Tuple[int, int], ...] = (
29 | (1, 64), (64, 128), (128, 256), (256, 256)),
30 | backbone_block: Type = ResNetBlock, backbone_convolution: Type = conv,
31 | backbone_normalization: Type = nn.BatchNorm2d, backbone_activation: Type = act,
32 | backbone_pooling: Type = nn.AvgPool2d,
33 | bounding_box_head_features: Tuple[Tuple[int, int], ...] = ((128, 64), (64, 16), (16, 4)),
34 | bounding_box_head_activation: Type = act,
35 | classification_head_activation: Type = act,
36 | num_encoder_layers: int = 3,
37 | num_decoder_layers: int = 2,
38 | dropout: float = 0.0,
39 | transformer_attention_heads: int = 8,
40 | transformer_activation: Type = act,
41 | segmentation_attention_heads: int = 8,
42 | segmentation_head_channels: Tuple[Tuple[int, int], ...] = (
43 | (128+8, 64), (64, 32),(32, 16)),
44 | segmentation_head_feature_channels: Tuple[int, ...] = (256, 128, 64),
45 | segmentation_head_block: Type = ResPACFeaturePyramidBlock,
46 | segmentation_head_convolution: Type = conv,
47 | segmentation_head_normalization: Type = nn.InstanceNorm2d,
48 | segmentation_head_activation: Type = act,
49 | segmentation_head_final_activation: Type = nn.Sigmoid) -> None:
50 |
51 | # Call super constructor
52 | super(CellDETR, self).__init__()
53 | # Init backbone
54 | self.backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution,
55 | normalization=backbone_normalization, activation=backbone_activation,
56 | pooling=backbone_pooling)
57 | # Init convolution mapping to match transformer dims
58 | self.convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features,
59 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
60 | # Init query positions
61 | self.query_positions = nn.Parameter(
62 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float),
63 | requires_grad=True)
64 | # Init embeddings
65 | self.row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
66 | requires_grad=True)
67 | self.column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
68 | requires_grad=True)
69 | # Init transformer
70 | self.transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads,
71 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
72 | dropout=dropout, dim_feedforward=4 * hidden_features,
73 | activation=transformer_activation)
74 |
75 | # Init segmentation attention head
76 | self.segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features,
77 | hidden_features=hidden_features,
78 | number_of_heads=segmentation_attention_heads,
79 | dropout=dropout)
80 | # Init segmentation head
81 | self.segmentation_head = SegmentationHead(channels=segmentation_head_channels,
82 | feature_channels=segmentation_head_feature_channels,
83 | convolution=segmentation_head_convolution,
84 | normalization=segmentation_head_normalization,
85 | activation=segmentation_head_activation,
86 | block=segmentation_head_block,
87 | number_of_query_positions=number_of_query_positions,
88 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax))
89 | # Init final segmentation activation
90 | self.segmentation_final_activation = segmentation_head_final_activation(dim=1) if isinstance(
91 | segmentation_head_final_activation(), nn.Softmax) else segmentation_head_final_activation()
92 |
93 | self.point_pre_layer = nn.Conv2d(hidden_features, 1, kernel_size = 1)
94 |
95 | def get_parameters(self, lr_main: float = 1e-04, lr_backbone: float = 1e-05) -> Iterable:
96 | return [{'params': self.backbone.parameters(), 'lr': lr_backbone},
97 | {'params': self.convolution_mapping.parameters(), 'lr': lr_main},
98 | {'params': [self.row_embedding], 'lr': lr_main},
99 | {'params': [self.column_embedding], 'lr': lr_main},
100 | {'params': self.transformer.parameters(), 'lr': lr_main},
101 | # Note: this variant defines no bounding box or classification head,
102 | # so no parameter groups are registered for them here.
103 | {'params': self.segmentation_attention_head.parameters(), 'lr': lr_main},
104 | {'params': self.segmentation_head.parameters(), 'lr': lr_main}]
105 |
106 | def get_segmentation_head_parameters(self, lr: float = 1e-05) -> Iterable:
107 | return [{'params': self.segmentation_attention_head.parameters(), 'lr': lr},
108 | {'params': self.segmentation_head.parameters(), 'lr': lr}]
109 |
110 | def forward(self, input: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
111 | features, feature_list = self.backbone(input)
112 | features = self.convolution_mapping(features)
113 | height, width = features.shape[2:]
114 | batch_size = features.shape[0]
115 | positional_embeddings = torch.cat([self.column_embedding[:height].unsqueeze(dim=0).repeat(height, 1, 1),
116 | self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)],
117 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)  # note: valid only for square feature maps (height == width)
118 | latent_tensor, features_encoded = self.transformer(features, None, self.query_positions, positional_embeddings)
119 | latent_tensor = latent_tensor.permute(2, 0, 1)
120 | Feature = point*features_encoded + features_encoded
121 | bounding_box_attention_masks = self.segmentation_attention_head(
122 | latent_tensor, Feature.contiguous())
123 | instance_segmentation_prediction = self.segmentation_head(features.contiguous(),
124 | bounding_box_attention_masks.contiguous(),
125 | feature_list[-2::-1])
126 | return self.segmentation_final_activation(instance_segmentation_prediction).clone()
127 |
128 |
129 | if __name__ == '__main__':
130 | # Init model
131 | detr = CellDETR()
132 | # Print number of parameters
133 | print("DETR # parameters", sum([p.numel() for p in detr.parameters()]))
134 | # Model into eval mode
135 | # detr.eval()
136 | image = torch.randn(5, 1, 128, 128)
137 | # The point map must broadcast against the encoded feature map; an 8 x 8 map
138 | # is assumed here (128 x 128 input reduced by four 2x pooling stages)
139 | point = torch.randn(5, 1, 8, 8)
140 | # Predict
141 | segmentation_prediction = detr(image, point)
--------------------------------------------------------------------------------
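The spatial encoding fed to the transformer above is built from two learned 1-D embedding tables, one per axis, concatenated per location. A standalone sketch of that construction for a general H x W map follows (sizes are illustrative); note that the in-class version slices and tiles each half with a single size, so it only lines up when the feature map is square.

import torch

hidden_features = 128
H, W = 8, 8  # illustrative feature-map size
row_embedding = torch.randn(50, hidden_features // 2)
column_embedding = torch.randn(50, hidden_features // 2)

positional_embeddings = torch.cat([
    column_embedding[:W].unsqueeze(0).repeat(H, 1, 1),  # (H, W, d/2), varies with the column index
    row_embedding[:H].unsqueeze(1).repeat(1, W, 1),     # (H, W, d/2), varies with the row index
], dim=-1).permute(2, 0, 1).unsqueeze(0)                # (1, d, H, W), added to the feature map
print(positional_embeddings.shape)  # torch.Size([1, 128, 8, 8])
--------------------------------------------------------------------------------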
/lib/Cell_DETR_master/botr2.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Tuple, Type, Iterable
3 | sys.path.insert(0, '/home/wl/File/SegGTV/Ours/Cell_DETR_master/')
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | from torchvision import models
8 | import torch.nn.functional as F
9 | from modules.modulated_deform_conv import ModulatedDeformConvPack
10 | from pade_activation_unit.utils import PAU
11 | from torch.nn.modules import Conv2d,LeakyReLU
12 |
13 | # Default convolution and activation types used by the model below
14 | conv = Conv2d
15 | act = LeakyReLU
16 |
17 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock
18 | from modules.modulated_deform_conv import ModulatedDeformConvPack
19 | from pade_activation_unit.utils import PAU
20 | from segmentation import MultiHeadAttention, SegmentationHead, SingleSegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock
21 | from transformer import Transformer,TransformerEncoder,TransformerEncoderLayer
22 |
23 | class CellDETR(nn.Module):
24 | def __init__(self,
25 | num_classes: int = 1,
26 | number_of_query_positions: int = 1,
27 | hidden_features=128,
28 | backbone_channels: Tuple[Tuple[int, int], ...] = (
29 | (1, 64), (64, 128), (128, 256), (256, 256)),
30 | backbone_block: Type = ResNetBlock, backbone_convolution: Type = conv,
31 | backbone_normalization: Type = nn.BatchNorm2d, backbone_activation: Type = act,
32 | backbone_pooling: Type = nn.AvgPool2d,
33 | bounding_box_head_features: Tuple[Tuple[int, int], ...] = ((128, 64), (64, 16), (16, 4)),
34 | bounding_box_head_activation: Type = act,
35 | classification_head_activation: Type = act,
36 | num_encoder_layers: int = 3,
37 | num_decoder_layers: int = 2,
38 | dropout: float = 0.0,
39 | transformer_attention_heads: int = 8,
40 | transformer_activation: Type = act,
41 | segmentation_attention_heads: int = 8,
42 | segmentation_head_channels: Tuple[Tuple[int, int], ...] = (
43 | (128+8, 64), (64, 32),(32, 16)),
44 | segmentation_head_feature_channels: Tuple[int, ...] = (256, 128, 64),
45 | segmentation_head_block: Type = ResPACFeaturePyramidBlock,
46 | segmentation_head_convolution: Type = conv,
47 | segmentation_head_normalization: Type = nn.InstanceNorm2d,
48 | segmentation_head_activation: Type = act,
49 | segmentation_head_final_activation: Type = nn.Sigmoid) -> None:
50 |
51 | # Call super constructor
52 | super(CellDETR, self).__init__()
53 | # Init backbone
54 | self.backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution,
55 | normalization=backbone_normalization, activation=backbone_activation,
56 | pooling=backbone_pooling)
57 | # Init convolution mapping to match transformer dims
58 | self.convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features,
59 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
60 | # Init query positions
61 | self.query_positions = nn.Parameter(
62 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float),
63 | requires_grad=True)
64 | # Init embeddings
65 | self.row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
66 | requires_grad=True)
67 | self.column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
68 | requires_grad=True)
69 | # Init transformer
70 | self.transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads,
71 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
72 | dropout=dropout, dim_feedforward=4 * hidden_features,
73 | activation=transformer_activation)
74 |
75 | # Init segmentation attention head
76 | self.segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features,
77 | hidden_features=hidden_features,
78 | number_of_heads=segmentation_attention_heads,
79 | dropout=dropout)
80 | # Init segmentation head
81 | self.segmentation_head = SegmentationHead(channels=segmentation_head_channels,
82 | feature_channels=segmentation_head_feature_channels,
83 | convolution=segmentation_head_convolution,
84 | normalization=segmentation_head_normalization,
85 | activation=segmentation_head_activation,
86 | block=segmentation_head_block,
87 | number_of_query_positions=number_of_query_positions,
88 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax))
89 | # Init final segmentation activation
90 | self.segmentation_final_activation = segmentation_head_final_activation(dim=1) if isinstance(
91 | segmentation_head_final_activation(), nn.Softmax) else segmentation_head_final_activation()
92 |
93 | self.point_pre_layer = nn.Conv2d(hidden_features, 1, kernel_size = 1)
94 |
95 | def get_parameters(self, lr_main: float = 1e-04, lr_backbone: float = 1e-05) -> Iterable:
96 | return [{'params': self.backbone.parameters(), 'lr': lr_backbone},
97 | {'params': self.convolution_mapping.parameters(), 'lr': lr_main},
98 | {'params': [self.row_embedding], 'lr': lr_main},
99 | {'params': [self.column_embedding], 'lr': lr_main},
100 | {'params': self.transformer.parameters(), 'lr': lr_main},
101 | # Note: this variant defines no bounding box or classification head,
102 | # so no parameter groups are registered for them here.
103 | {'params': self.segmentation_attention_head.parameters(), 'lr': lr_main},
104 | {'params': self.segmentation_head.parameters(), 'lr': lr_main}]
105 |
106 | def get_segmentation_head_parameters(self, lr: float = 1e-05) -> Iterable:
107 | return [{'params': self.segmentation_attention_head.parameters(), 'lr': lr},
108 | {'params': self.segmentation_head.parameters(), 'lr': lr}]
109 |
110 | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
111 | features, feature_list = self.backbone(input)
112 | features = self.convolution_mapping(features)
113 | height, width = features.shape[2:]
114 | batch_size = features.shape[0]
115 | positional_embeddings = torch.cat([self.column_embedding[:height].unsqueeze(dim=0).repeat(height, 1, 1),
116 | self.row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)],
117 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)  # note: valid only for square feature maps (height == width)
118 | latent_tensor, features_encoded = self.transformer(features, None, self.query_positions, positional_embeddings)
119 | latent_tensor = latent_tensor.permute(2, 0, 1)
120 | point = self.point_pre_layer(features_encoded)
121 | point = torch.sigmoid(point)
122 | Feature = point*features_encoded + features_encoded
123 | bounding_box_attention_masks = self.segmentation_attention_head(
124 | latent_tensor, Feature.contiguous())
125 | instance_segmentation_prediction = self.segmentation_head(features.contiguous(),
126 | bounding_box_attention_masks.contiguous(),
127 | feature_list[-2::-1])
128 | return self.segmentation_final_activation(instance_segmentation_prediction).clone(), point
129 |
130 | if __name__ == '__main__':
131 | # Init model
132 | detr = CellDETR()
133 | # Print number of parameters
134 | print("DETR # parameters", sum([p.numel() for p in detr.parameters()]))
135 | # Model into eval mode
136 | # detr.eval()
137 | image = torch.randn(5,1,128,128)
138 | # Predict
139 | # (forward returns both the segmentation prediction and the point map)
140 | segmentation_prediction, point = detr(image)
141 |
142 |
--------------------------------------------------------------------------------
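The difference from botr.py is that botr2.py predicts the point map itself: a 1 x 1 convolution plus sigmoid over the encoded features, used as a residual gate before the attention head. A self-contained sketch of just that branch, with assumed shapes:

import torch
import torch.nn as nn

hidden_features = 128
features_encoded = torch.randn(5, hidden_features, 8, 8)  # assumed encoder output

point_pre_layer = nn.Conv2d(hidden_features, 1, kernel_size=1)
point = torch.sigmoid(point_pre_layer(features_encoded))  # (5, 1, 8, 8), values in [0, 1]

# Residual gating: features at high-scoring (boundary) locations are amplified,
# everything else passes through unchanged, since point * F + F == (1 + point) * F.
gated = point * features_encoded + features_encoded
print(gated.shape)  # torch.Size([5, 128, 8, 8])
--------------------------------------------------------------------------------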
/lib/Cell_DETR_master/bounding_box_head.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Type
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class BoundingBoxHead(nn.Module):
8 | """
9 | This class implements the feed forward bounding box head as proposed in:
10 | https://arxiv.org/abs/2005.12872
11 | """
12 |
13 | def __init__(self, features: Tuple[Tuple[int, int]] = ((256, 64), (64, 16), (16, 4)),
14 | activation: Type = nn.PReLU) -> None:
15 | """
16 | Constructor method
17 | :param features: (Tuple[Tuple[int, int]]) Number of input and output features in each layer
18 | :param activation: (Type) Activation function to be utilized
19 | """
20 | # Call super constructor
21 | super(BoundingBoxHead, self).__init__()
22 | # Init layers
23 | self.layers = []
24 | for index, feature in enumerate(features):
25 | if index < len(features) - 1:
26 | self.layers.extend([nn.Linear(in_features=feature[0], out_features=feature[1]), activation()])
27 | else:
28 | self.layers.append(nn.Linear(in_features=feature[0], out_features=feature[1]))
29 | self.layers = nn.Sequential(*self.layers)
30 |
31 | def forward(self, input: torch.Tensor) -> torch.Tensor:
32 | """
33 | Forward pass
34 | :param input: (torch.Tensor) Input tensor of shape (batch size, instances, features)
35 | :return: (torch.Tensor) Output tensor of shape (batch size, instances, classes + 1 (no object))
36 | """
37 | return self.layers(input)
38 |
--------------------------------------------------------------------------------
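A short usage sketch for the head above. The layer widths come straight from the features tuple; squashing the four raw outputs with a sigmoid to obtain normalized box coordinates is an assumption borrowed from common DETR practice, not something this file does itself.

import torch
from bounding_box_head import BoundingBoxHead  # the module above

# Illustrative input: batch of 2, 12 queries, 256-dim decoder features.
head = BoundingBoxHead(features=((256, 64), (64, 16), (16, 4)))
queries = torch.randn(2, 12, 256)
raw = head(queries)         # (2, 12, 4) unbounded regression outputs
boxes = torch.sigmoid(raw)  # assumed squashing to normalized [0, 1] coordinates
print(boxes.shape)  # torch.Size([2, 12, 4])
--------------------------------------------------------------------------------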
/lib/Cell_DETR_master/ccccode.py:
--------------------------------------------------------------------------------
1 | import os; import matplotlib.pyplot as plt  # plt is required for the plots below
2 | import torch
3 | import torch.nn as nn
4 | from torchvision import models
5 | import torch.nn.functional as F
6 | from Ours.Cell_DETR_master.transformer import Transformer,TransformerEncoder,TransformerEncoderLayer
7 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 | x = torch.rand((5,1,128,128)).cpu()
11 | # resnet = models.resnet18().cuda()
12 | # cnn_encoder = nn.Sequential(*list(resnet.children())[:-4])
13 | # features = cnn_encoder(x)
14 |
15 | from torch.nn.modules import Conv2d,LeakyReLU
16 | hidden_features=128
17 | query_dimension = 128
18 | dropout = 0.0
19 | num_classes = 3
20 | number_of_heads = 16
21 | num_encoder_layers = 3
22 | num_decoder_layers = 2
23 | number_of_query_positions=12
24 | transformer_attention_heads = 8
25 | segmentation_attention_heads = 8
26 | transformer_activation = nn.LeakyReLU
27 | classification_head_activation = nn.LeakyReLU
28 | normalize_before=False
29 |
30 | from backbone import Backbone, DenseNetBlock, StandardBlock, ResNetBlock
31 | from modules.modulated_deform_conv import ModulatedDeformConvPack
32 | from pade_activation_unit.utils import PAU
33 | backbone_channels = ((1, 64), (64, 128), (128, 256), (256, 256))
34 | backbone_block = ResNetBlock
35 | backbone_convolution = nn.Conv2d
36 | backbone_normalization = nn.BatchNorm2d
37 | backbone_activation = nn.LeakyReLU
38 | backbone_pooling = nn.AvgPool2d
39 | backbone = Backbone(channels=backbone_channels, block=backbone_block, convolution=backbone_convolution,
40 | normalization=backbone_normalization, activation=backbone_activation,
41 | pooling=backbone_pooling)
42 | convolution_mapping = nn.Conv2d(in_channels=backbone_channels[-1][-1], out_channels=hidden_features,
43 | kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
44 |
45 |
46 | features, feature_list = backbone(x)
47 | features = convolution_mapping(features)
48 | height, width = features.shape[2:]
49 | print("height width: ",height, width)
50 | # Get batch size
51 | batch_size = features.shape[0]
52 | # Make positional embeddings
53 | print("features size: ",features.shape,len(feature_list))
54 |
55 | query_positions = nn.Parameter(
56 | data=torch.randn(number_of_query_positions, hidden_features, dtype=torch.float),
57 | requires_grad=True)
58 |
59 | row_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
60 | requires_grad=True)
61 | column_embedding = nn.Parameter(data=torch.randn(50, hidden_features // 2, dtype=torch.float),
62 | requires_grad=True)
63 | positional_embeddings = torch.cat([column_embedding[:height].unsqueeze(dim=0).repeat(height, 1, 1),
64 | row_embedding[:width].unsqueeze(dim=1).repeat(1, width, 1)],
65 | dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)
66 | print("query_positions size: ", query_positions.shape)
67 | print("positional_embeddings size: ",positional_embeddings.shape)
68 |
69 | print(positional_embeddings.max(), positional_embeddings.min())
70 | Input = features+positional_embeddings
71 | print(Input.shape)
72 | plt.figure()
73 | plt.subplot(1,3,1)
74 | plt.imshow(features.detach()[0,63,:])
75 | plt.subplot(1,3,2)
76 | plt.imshow(positional_embeddings.detach()[0,63,:])
77 | plt.subplot(1,3,3)
78 | plt.imshow(Input.detach()[0,63,:])
79 | plt.figure()
80 | plt.subplot(1,2,1)
81 | plt.imshow(positional_embeddings.detach()[0,63,:])
82 | plt.subplot(1,2,2)
83 | plt.imshow(positional_embeddings.detach()[0,64,:])
84 |
85 |
86 |
87 | transformer = Transformer(d_model=hidden_features, nhead=transformer_attention_heads,
88 | num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
89 | dropout=dropout, dim_feedforward=4 * hidden_features,
90 | activation=transformer_activation)
91 | latent_tensor, features_encoded = transformer(features, None, query_positions, positional_embeddings)
92 | latent_tensor = latent_tensor.permute(2, 0, 1)
93 | print(latent_tensor.shape,features_encoded.shape)
94 |
95 |
96 |
97 | from Ours.Cell_DETR_master.segmentation import MultiHeadAttention, SegmentationHead, ResFeaturePyramidBlock, ResPACFeaturePyramidBlock
98 | segmentation_head_channels = ((128 + 8, 128), (128, 64), (64, 32))
99 | segmentation_head_feature_channels = (256, 128, 64)
100 | segmentation_head_convolution = Conv2d
101 | segmentation_head_normalization = nn.InstanceNorm2d
102 | segmentation_head_activation = nn.LeakyReLU
103 | segmentation_head_block = ResPACFeaturePyramidBlock
104 | segmentation_head_final_activation = nn.Sigmoid
105 | segmentation_attention_head = MultiHeadAttention(query_dimension=hidden_features,
106 | hidden_features=hidden_features,
107 | number_of_heads=segmentation_attention_heads,
108 | dropout=dropout)
109 | segmentation_head = SegmentationHead(channels=segmentation_head_channels,
110 | feature_channels=segmentation_head_feature_channels,
111 | convolution=segmentation_head_convolution,
112 | normalization=segmentation_head_normalization,
113 | activation=segmentation_head_activation,
114 | block=segmentation_head_block,
115 | number_of_query_positions=number_of_query_positions,
116 | softmax=isinstance(segmentation_head_final_activation(), nn.Softmax))
117 |
118 |
119 |
120 | bounding_box_attention_masks = segmentation_attention_head(
121 | latent_tensor, features_encoded.contiguous())
122 | instance_segmentation_prediction = segmentation_head(features.contiguous(),
123 | bounding_box_attention_masks.contiguous(),
124 | feature_list[-2::-1])
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple, List
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.utils.data import Dataset
6 | import os
7 | import numpy as np
8 |
9 | import misc
10 | import augmentation
11 |
12 |
13 | class CellInstanceSegmentation(Dataset):
14 | """
15 | This dataset implements the cell instance segmentation dataset for the DETR model.
16 | Dataset source: https://github.com/ChristophReich1996/BCS_Data/tree/master/Cell_Instance_Segmentation_Regular_Traps
17 | """
18 |
19 | def __init__(self, path: str = "../../BCS_Data/Cell_Instance_Segmentation_Regular_Traps/train",
20 | normalize: bool = True,
21 | normalization_function: Callable[[torch.Tensor], torch.Tensor] = misc.normalize,
22 | augmentation: Tuple[augmentation.Augmentation, ...] = (
23 | augmentation.VerticalFlip(), augmentation.NoiseInjection(), augmentation.ElasticDeformation()),
24 | augmentation_p: float = 0.5, return_absolute_bounding_box: bool = False,
25 | downscale: bool = True, downscale_shape: Tuple[int, int] = (128, 128),
26 | two_classes: bool = True) -> None:
27 | """
28 | Constructor method
29 | :param path: (str) Path to dataset
30 | :param normalize: (bool) If true normalization_function is applied
31 | :param normalization_function: (Callable[[torch.Tensor], torch.Tensor]) Normalization function
32 | :param augmentation: (Tuple[augmentation.Augmentation, ...]) Tuple of
33 | augmentation functions to be applied
34 | :param augmentation_p: (float) Probability that an augmentation is utilized
35 | :param downscale: (bool) If true images and segmentation maps will be downscaled to downscale_shape
36 | :param downscale_shape: (Tuple[int, int]) Target shape if downscale is utilized
37 | :param return_absolute_bounding_box: (Bool) If true the absolute bb is returned else the relative bb is returned
38 | :param two_classes: (bool) If true only two classes, trap and cell, will be utilized
39 | """
40 | # Save parameters
41 | self.normalize = normalize
42 | self.normalization_function = normalization_function
43 | self.augmentation = augmentation
44 | self.augmentation_p = augmentation_p
45 | self.return_absolute_bounding_box = return_absolute_bounding_box
46 | self.downscale = downscale
47 | self.downscale_shape = downscale_shape
48 | self.two_class = two_classes
49 | # Get paths of input images
50 | self.inputs = []
51 | for file in sorted(os.listdir(os.path.join(path, "inputs"))):
52 | self.inputs.append(os.path.join(path, "inputs", file))
53 | # Get paths of instances
54 | self.instances = []
55 | for file in sorted(os.listdir(os.path.join(path, "instances"))):
56 | self.instances.append(os.path.join(path, "instances", file))
57 | # Get paths of class labels
58 | self.class_labels = []
59 | for file in sorted(os.listdir(os.path.join(path, "classes"))):
60 | self.class_labels.append(os.path.join(path, "classes", file))
61 | # Get paths of bounding boxes
62 | self.bounding_boxes = []
63 | for file in sorted(os.listdir(os.path.join(path, "bounding_boxes"))):
64 | self.bounding_boxes.append(os.path.join(path, "bounding_boxes", file))
65 |
66 | def __len__(self) -> int:
67 | """
68 | Method returns the length of the dataset
69 | :return: (int) Length of the dataset
70 | """
71 | return len(self.inputs)
72 |
73 | def __getitem__(self, item: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74 | """
75 | Get item method
76 | :param item: (int) Item to be returned of the dataset
77 | :return: (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) Tuple including input image,
78 | bounding box, class label and instances.
79 | """
80 | # Load data
81 | input = torch.load(self.inputs[item]).unsqueeze(dim=0)
82 | instances = torch.load(self.instances[item])
83 | bounding_boxes = torch.load(self.bounding_boxes[item])
84 | class_labels = torch.load(self.class_labels[item])
85 | # Encode class labels as one-hot
86 | if self.two_class:
87 | class_labels = misc.to_one_hot(class_labels.clamp(max=2.0), num_classes=2 + 1)
88 | else:
89 | class_labels = misc.to_one_hot(class_labels, num_classes=3 + 1)
90 | # Normalize input if utilized
91 | if self.normalize:
92 | input = self.normalization_function(input)
93 | # Apply augmentation if needed
94 | if np.random.random() < self.augmentation_p and self.augmentation is not None:
95 | # Get augmentation
96 | augmentation_to_be_applied = np.random.choice(self.augmentation)
97 | # Apply augmentation
98 | if augmentation_to_be_applied.need_labels():
99 | input, instances, bounding_boxes = augmentation_to_be_applied(input, instances, bounding_boxes)
100 | else:
101 | input = augmentation_to_be_applied(input)
102 | # Downscale data to the target shape if utilized
103 | if self.downscale:
104 | # Rescale bounding boxes to match the downscaled height and width
105 | bounding_boxes[..., [0, 2]] = bounding_boxes[..., [0, 2]] * (self.downscale_shape[0] / input.shape[-1])
106 | bounding_boxes[..., [1, 3]] = bounding_boxes[..., [1, 3]] * (self.downscale_shape[1] / input.shape[-2])
107 | input = F.interpolate(input=input.unsqueeze(dim=0),
108 | size=self.downscale_shape, mode="bicubic", align_corners=False)[0]
109 | instances = (F.interpolate(input=instances.unsqueeze(dim=0),
110 | size=self.downscale_shape, mode="bilinear", align_corners=False)[
111 | 0] > 0.75).float()
112 | # Convert absolute bounding box to relative bounding box if utilized
113 | if not self.return_absolute_bounding_box:
114 | bounding_boxes = misc.absolute_bounding_box_to_relative(bounding_boxes=bounding_boxes,
115 | height=input.shape[1], width=input.shape[2])
116 | return input, instances, misc.bounding_box_x0y0x1y1_to_xcycwh(bounding_boxes), class_labels
117 |
118 |
119 | def collate_function_cell_instance_segmentation(
120 | batch: List[Tuple[torch.Tensor]]) -> \
121 | Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
122 | """
123 | Collate function of instance segmentation dataset.
124 | :param batch: (Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor]])
125 | Batch of input data, instances maps, bounding boxes and class labels
126 | :return: (Tuple[torch.Tensor, Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor]]) Batched input
127 | data, instances, bounding boxes and class labels are stored in a list due to the different instances.
128 | """
129 | return torch.stack([input_samples[0] for input_samples in batch], dim=0), \
130 | [input_samples[1] for input_samples in batch], \
131 | [input_samples[2] for input_samples in batch], \
132 | [input_samples[3] for input_samples in batch]
133 |
134 |
135 |
--------------------------------------------------------------------------------
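Because every sample can carry a different number of instances, only the images can be stacked into a dense batch; the collate function above therefore returns the per-sample targets as lists. A minimal self-contained sketch of the same pattern with a toy dataset (all names below are illustrative):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyInstanceDataset(Dataset):
    """Toy stand-in: fixed-size images, variable instance counts per sample."""
    def __len__(self):
        return 8

    def __getitem__(self, item):
        num_instances = item % 3 + 1
        image = torch.rand(1, 64, 64)
        instances = torch.rand(num_instances, 64, 64)
        boxes = torch.rand(num_instances, 4)
        labels = torch.randint(0, 3, (num_instances,))
        return image, instances, boxes, labels

def collate(batch):
    # Stack the fixed-size images; keep the variable-size targets as lists.
    return (torch.stack([sample[0] for sample in batch], dim=0),
            [sample[1] for sample in batch],
            [sample[2] for sample in batch],
            [sample[3] for sample in batch])

loader = DataLoader(ToyInstanceDataset(), batch_size=4, collate_fn=collate)
images, instances, boxes, labels = next(iter(loader))
print(images.shape, [m.shape[0] for m in instances])  # varying instance counts
--------------------------------------------------------------------------------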
/lib/Cell_DETR_master/dcn2/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/.gitignore
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/LICENSE
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/README.md
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/build.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/build.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/build_modulated.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/build_modulated.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import DeformConvFunction, deform_conv_function
2 | from .modulated_dcn_func import DeformRoIPoolingFunction, ModulatedDeformConvFunction
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/functions/deform_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.nn.modules.utils import _pair
4 |
5 | from _ext import deform_conv
6 |
7 |
8 | def deform_conv_function(input,
9 | offset,
10 | weight,
11 | stride=1,
12 | padding=0,
13 | dilation=1,
14 | deform_groups=1,
15 | im2col_step=64):
16 |
17 | if input is not None and input.dim() != 4:
18 | raise ValueError(
19 | "Expected 4D tensor as input, got {}D tensor instead.".format(
20 | input.dim()))
21 |
22 | f = DeformConvFunction(
23 | _pair(stride), _pair(padding), _pair(dilation), deform_groups, im2col_step)
24 | return f(input, offset, weight)
25 |
26 |
27 | class DeformConvFunction(Function):
28 | def __init__(self, stride, padding, dilation, deformable_groups=1, im2col_step=64):
29 | super(DeformConvFunction, self).__init__()
30 | self.stride = stride
31 | self.padding = padding
32 | self.dilation = dilation
33 | self.deformable_groups = deformable_groups
34 | self.im2col_step = im2col_step
35 |
36 | def forward(self, input, offset, weight):
37 | self.save_for_backward(input, offset, weight)
38 |
39 | output = input.new(*self._output_size(input, weight))
40 |
41 | self.bufs_ = [input.new(), input.new()] # columns, ones
42 |
43 | if not input.is_cuda:
44 | raise NotImplementedError
45 | else:
46 | if isinstance(input, torch.autograd.Variable):
47 | if not isinstance(input.data, torch.cuda.FloatTensor):
48 | raise NotImplementedError
49 | else:
50 | if not isinstance(input, torch.cuda.FloatTensor):
51 | raise NotImplementedError
52 |
53 | cur_im2col_step = min(self.im2col_step, input.shape[0])
54 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
55 | deform_conv.deform_conv_forward_cuda(
56 | input, weight, offset, output, self.bufs_[0], self.bufs_[1],
57 | weight.size(3), weight.size(2), self.stride[1], self.stride[0],
58 | self.padding[1], self.padding[0], self.dilation[1],
59 | self.dilation[0], self.deformable_groups, cur_im2col_step)
60 | return output
61 |
62 | def backward(self, grad_output):
63 | input, offset, weight = self.saved_tensors
64 |
65 | grad_input = grad_offset = grad_weight = None
66 |
67 | if not grad_output.is_cuda:
68 | raise NotImplementedError
69 | else:
70 | if isinstance(grad_output, torch.autograd.Variable):
71 | if not isinstance(grad_output.data, torch.cuda.FloatTensor):
72 | raise NotImplementedError
73 | else:
74 | if not isinstance(grad_output, torch.cuda.FloatTensor):
75 | raise NotImplementedError
76 |
77 | cur_im2col_step = min(self.im2col_step, input.shape[0])
78 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
79 |
80 | if self.needs_input_grad[0] or self.needs_input_grad[1]:
81 | grad_input = input.new(*input.size()).zero_()
82 | grad_offset = offset.new(*offset.size()).zero_()
83 | deform_conv.deform_conv_backward_input_cuda(
84 | input, offset, grad_output, grad_input,
85 | grad_offset, weight, self.bufs_[0], weight.size(3),
86 | weight.size(2), self.stride[1], self.stride[0],
87 | self.padding[1], self.padding[0], self.dilation[1],
88 | self.dilation[0], self.deformable_groups, cur_im2col_step)
89 |
90 |
91 | if self.needs_input_grad[2]:
92 | grad_weight = weight.new(*weight.size()).zero_()
93 | deform_conv.deform_conv_backward_parameters_cuda(
94 | input, offset, grad_output,
95 | grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3),
96 | weight.size(2), self.stride[1], self.stride[0],
97 | self.padding[1], self.padding[0], self.dilation[1],
98 | self.dilation[0], self.deformable_groups, 1, cur_im2col_step)
99 |
100 | return grad_input, grad_offset, grad_weight
101 |
102 | def _output_size(self, input, weight):
103 | channels = weight.size(0)
104 |
105 | output_size = (input.size(0), channels)
106 | for d in range(input.dim() - 2):
107 | in_size = input.size(d + 2)
108 | pad = self.padding[d]
109 | kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1
110 | stride = self.stride[d]
111 | output_size += ((in_size + (2 * pad) - kernel) // stride + 1, )
112 | if not all(map(lambda s: s > 0, output_size)):
113 | raise ValueError(
114 | "convolution input is too small (output would be {})".format(
115 | 'x'.join(map(str, output_size))))
116 | return output_size
117 |
--------------------------------------------------------------------------------
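The _output_size helper above is standard convolution arithmetic applied per spatial dimension. Restated as a standalone function with a couple of worked values:

# Per spatial dimension: out = (in + 2*pad - (dilation*(kernel - 1) + 1)) // stride + 1
def conv_output_size(in_size, kernel, stride=1, pad=0, dilation=1):
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + 2 * pad - effective_kernel) // stride + 1

# A 3x3 deformable conv with stride 1 and padding 1 preserves the size...
print(conv_output_size(128, kernel=3, stride=1, pad=1))  # 128
# ...while stride 2 halves it.
print(conv_output_size(128, kernel=3, stride=2, pad=1))  # 64
--------------------------------------------------------------------------------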
/lib/Cell_DETR_master/dcn2/functions/modulated_dcn_func.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import torch
7 | from torch.autograd import Function
8 |
9 | from _ext import modulated_dcn as _backend
10 |
11 |
12 | class ModulatedDeformConvFunction(Function):
13 |
14 | def __init__(self, stride, padding, dilation=1, deformable_groups=1):
15 | super(ModulatedDeformConvFunction, self).__init__()
16 | self.stride = stride
17 | self.padding = padding
18 | self.dilation = dilation
19 | self.deformable_groups = deformable_groups
20 |
21 | def forward(self, input, offset, mask, weight, bias):
22 | if not input.is_cuda:
23 | raise NotImplementedError
24 | if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
25 | self.save_for_backward(input, offset, mask, weight, bias)
26 | output = input.new(*self._infer_shape(input, weight))
27 | self._bufs = [input.new(), input.new()]
28 | _backend.modulated_deform_conv_cuda_forward(input, weight,
29 | bias, self._bufs[0],
30 | offset, mask,
31 | output, self._bufs[1],
32 | weight.shape[2], weight.shape[3],
33 | self.stride, self.stride,
34 | self.padding, self.padding,
35 | self.dilation, self.dilation,
36 | self.deformable_groups)
37 | return output
38 |
39 | def backward(self, grad_output):
40 | if not grad_output.is_cuda:
41 | raise NotImplementedError
42 | input, offset, mask, weight, bias = self.saved_tensors
43 | grad_input = input.new(*input.size()).zero_()
44 | grad_offset = offset.new(*offset.size()).zero_()
45 | grad_mask = mask.new(*mask.size()).zero_()
46 | grad_weight = weight.new(*weight.size()).zero_()
47 | grad_bias = bias.new(*bias.size()).zero_()
48 | _backend.modulated_deform_conv_cuda_backward(input, weight,
49 | bias, self._bufs[0],
50 | offset, mask,
51 | self._bufs[1],
52 | grad_input, grad_weight,
53 | grad_bias, grad_offset,
54 | grad_mask, grad_output,
55 | weight.shape[2], weight.shape[3],
56 | self.stride, self.stride,
57 | self.padding, self.padding,
58 | self.dilation, self.dilation,
59 | self.deformable_groups)
60 |
61 | return grad_input, grad_offset, grad_mask, grad_weight, grad_bias
62 |
63 | def _infer_shape(self, input, weight):
64 | n = input.size(0)
65 | channels_out = weight.size(0)
66 | height, width = input.shape[2:4]
67 | kernel_h, kernel_w = weight.shape[2:4]
68 | height_out = (height + 2 * self.padding -
69 | (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1
70 | width_out = (width + 2 * self.padding - (self.dilation *
71 | (kernel_w - 1) + 1)) // self.stride + 1
72 | return (n, channels_out, height_out, width_out)
73 |
74 |
75 | class DeformRoIPoolingFunction(Function):
76 |
77 | def __init__(self,
78 | spatial_scale,
79 | pooled_size,
80 | output_dim,
81 | no_trans,
82 | group_size=1,
83 | part_size=None,
84 | sample_per_part=4,
85 | trans_std=.0):
86 | super(DeformRoIPoolingFunction, self).__init__()
87 | self.spatial_scale = spatial_scale
88 | self.pooled_size = pooled_size
89 | self.output_dim = output_dim
90 | self.no_trans = no_trans
91 | self.group_size = group_size
92 | self.part_size = pooled_size if part_size is None else part_size
93 | self.sample_per_part = sample_per_part
94 | self.trans_std = trans_std
95 |
96 | assert self.trans_std >= 0.0 and self.trans_std <= 1.0
97 |
98 | def forward(self, data, rois, offset):
99 | if not data.is_cuda:
100 | raise NotImplementedError
101 |
102 | output = data.new(*self._infer_shape(data, rois))
103 | output_count = data.new(*self._infer_shape(data, rois))
104 | _backend.deform_psroi_pooling_cuda_forward(data, rois, offset,
105 | output, output_count,
106 | self.no_trans, self.spatial_scale,
107 | self.output_dim, self.group_size,
108 | self.pooled_size, self.part_size,
109 | self.sample_per_part, self.trans_std)
110 |
111 | # if data.requires_grad or rois.requires_grad or offset.requires_grad:
112 | # self.save_for_backward(data, rois, offset, output_count)
113 | self.data = data
114 | self.rois = rois
115 | self.offset = offset
116 | self.output_count = output_count
117 |
118 | return output
119 |
120 | def backward(self, grad_output):
121 | if not grad_output.is_cuda:
122 | raise NotImplementedError
123 |
124 | # data, rois, offset, output_count = self.saved_tensors
125 | data = self.data
126 | rois = self.rois
127 | offset = self.offset
128 | output_count = self.output_count
129 | grad_input = data.new(*data.size()).zero_()
130 | grad_offset = offset.new(*offset.size()).zero_()
131 |
132 | _backend.deform_psroi_pooling_cuda_backward(grad_output,
133 | data,
134 | rois,
135 | offset,
136 | output_count,
137 | grad_input,
138 | grad_offset,
139 | self.no_trans,
140 | self.spatial_scale,
141 | self.output_dim,
142 | self.group_size,
143 | self.pooled_size,
144 | self.part_size,
145 | self.sample_per_part,
146 | self.trans_std)
147 | return grad_input, torch.zeros(rois.shape).cuda(), grad_offset
148 |
149 | def _infer_shape(self, data, rois):
150 | # _, c, h, w = data.shape[:4]
151 | c = data.shape[1]
152 | n = rois.shape[0]
153 | return (n, self.output_dim, self.pooled_size, self.pooled_size)
154 |
--------------------------------------------------------------------------------
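These bindings target the legacy _ext FFI build. Recent torchvision releases ship an equivalent operator, torchvision.ops.deform_conv2d, whose mask argument covers the modulated (DCNv2) case. A sketch of the expected tensor layout, assuming one deformable group and a 3 x 3 kernel (the torchvision dependency is an assumption about the environment, not part of this repository):

import torch
from torchvision.ops import deform_conv2d

N, C_in, C_out, H, W, k = 2, 8, 16, 32, 32, 3
x = torch.randn(N, C_in, H, W)
weight = torch.randn(C_out, C_in, k, k)
# One (dy, dx) pair per kernel tap and output location:
offset = torch.zeros(N, 2 * k * k, H, W)
# One modulation scalar per kernel tap and output location (the "mask"):
mask = torch.sigmoid(torch.randn(N, k * k, H, W))

# Zero offsets plus a mask reduce to a plain convolution with per-tap reweighting.
out = deform_conv2d(x, offset, weight, padding=1, mask=mask)
print(out.shape)  # torch.Size([2, 16, 32, 32])
--------------------------------------------------------------------------------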
/lib/Cell_DETR_master/dcn2/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import DeformConv
2 | from .modulated_dcn import DeformRoIPooling, ModulatedDeformConv, ModulatedDeformConvPack, ModulatedDeformRoIPoolingPack
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/modules/deform_conv.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules.module import Module
6 | from torch.nn.modules.utils import _pair
7 | from functions import deform_conv_function
8 |
9 |
10 | class DeformConv(Module):
11 | def __init__(self,
12 | in_channels,
13 | out_channels,
14 | kernel_size,
15 | stride=1,
16 | padding=0,
17 | dilation=1,
18 | num_deformable_groups=1):
19 | super(DeformConv, self).__init__()
20 | self.in_channels = in_channels
21 | self.out_channels = out_channels
22 | self.kernel_size = _pair(kernel_size)
23 | self.stride = _pair(stride)
24 | self.padding = _pair(padding)
25 | self.dilation = _pair(dilation)
26 | self.num_deformable_groups = num_deformable_groups
27 |
28 | self.weight = nn.Parameter(
29 | torch.Tensor(out_channels, in_channels, *self.kernel_size))
30 |
31 | self.reset_parameters()
32 |
33 | def reset_parameters(self):
34 | n = self.in_channels
35 | for k in self.kernel_size:
36 | n *= k
37 | stdv = 1. / math.sqrt(n)
38 | self.weight.data.uniform_(-stdv, stdv)
39 |
40 | def forward(self, input, offset):
41 | return deform_conv_function(input, offset, self.weight, self.stride,
42 | self.padding, self.dilation,
43 | self.num_deformable_groups)
44 |
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/modules/modulated_dcn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import torch
7 | import math
8 | from torch import nn
9 | from torch.nn.modules.utils import _pair
10 |
11 | from functions.modulated_dcn_func import ModulatedDeformConvFunction
12 | from functions.modulated_dcn_func import DeformRoIPoolingFunction
13 |
14 | class ModulatedDeformConv(nn.Module):
15 |
16 | def __init__(self, in_channels, out_channels,
17 | kernel_size, stride, padding, dilation=1, deformable_groups=1, no_bias=True):
18 | super(ModulatedDeformConv, self).__init__()
19 | self.in_channels = in_channels
20 | self.out_channels = out_channels
21 | self.kernel_size = _pair(kernel_size)
22 | self.stride = stride
23 | self.padding = padding
24 | self.dilation = dilation
25 | self.deformable_groups = deformable_groups
26 | self.no_bias = no_bias
27 |
28 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
29 | self.bias = nn.Parameter(torch.zeros(out_channels))
30 | self.reset_parameters()
31 | if self.no_bias:
32 | self.bias.requires_grad = False
33 |
34 | def reset_parameters(self):
35 | n = self.in_channels
36 | for k in self.kernel_size:
37 | n *= k
38 | stdv = 1. / math.sqrt(n)
39 | self.weight.data.uniform_(-stdv, stdv)
40 | self.bias.data.zero_()
41 |
42 | def forward(self, input, offset, mask):
43 | func = ModulatedDeformConvFunction(self.stride, self.padding, self.dilation, self.deformable_groups)
44 | return func(input, offset, mask, self.weight, self.bias)
45 |
46 |
47 | class ModulatedDeformConvPack(ModulatedDeformConv):
48 |
49 | def __init__(self, in_channels, out_channels,
50 | kernel_size, stride, padding,
51 | dilation=1, deformable_groups=1, no_bias=False):
52 | super(ModulatedDeformConvPack, self).__init__(in_channels, out_channels,
53 | kernel_size, stride, padding, dilation, deformable_groups, no_bias)
54 |
55 | self.conv_offset_mask = nn.Conv2d(self.in_channels,
56 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
57 | kernel_size=self.kernel_size,
58 | stride=(self.stride, self.stride),
59 | padding=(self.padding, self.padding),
60 | bias=True)
61 | self.init_offset()
62 |
63 | def init_offset(self):
64 | self.conv_offset_mask.weight.data.zero_()
65 | self.conv_offset_mask.bias.data.zero_()
66 |
67 | def forward(self, input):
68 | out = self.conv_offset_mask(input)
69 | o1, o2, mask = torch.chunk(out, 3, dim=1)
70 | offset = torch.cat((o1, o2), dim=1)
71 | mask = torch.sigmoid(mask)
72 | func = ModulatedDeformConvFunction(self.stride, self.padding, self.dilation, self.deformable_groups)
73 | return func(input, offset, mask, self.weight, self.bias)
74 |
75 |
76 | class DeformRoIPooling(nn.Module):
77 |
78 | def __init__(self,
79 | spatial_scale,
80 | pooled_size,
81 | output_dim,
82 | no_trans,
83 | group_size=1,
84 | part_size=None,
85 | sample_per_part=4,
86 | trans_std=.0):
87 | super(DeformRoIPooling, self).__init__()
88 | self.spatial_scale = spatial_scale
89 | self.pooled_size = pooled_size
90 | self.output_dim = output_dim
91 | self.no_trans = no_trans
92 | self.group_size = group_size
93 | self.part_size = pooled_size if part_size is None else part_size
94 | self.sample_per_part = sample_per_part
95 | self.trans_std = trans_std
96 | self.func = DeformRoIPoolingFunction(self.spatial_scale,
97 | self.pooled_size,
98 | self.output_dim,
99 | self.no_trans,
100 | self.group_size,
101 | self.part_size,
102 | self.sample_per_part,
103 | self.trans_std)
104 |
105 | def forward(self, data, rois, offset):
106 |
107 | if self.no_trans:
108 | offset = data.new()
109 | return self.func(data, rois, offset)
110 |
111 | class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
112 |
113 | def __init__(self,
114 | spatial_scale,
115 | pooled_size,
116 | output_dim,
117 | no_trans,
118 | group_size=1,
119 | part_size=None,
120 | sample_per_part=4,
121 | trans_std=.0,
122 | deform_fc_dim=1024):
123 | super(ModulatedDeformRoIPoolingPack, self).__init__(spatial_scale,
124 | pooled_size,
125 | output_dim,
126 | no_trans,
127 | group_size,
128 | part_size,
129 | sample_per_part,
130 | trans_std)
131 |
132 | self.deform_fc_dim = deform_fc_dim
133 |
134 | if not no_trans:
135 | self.func_offset = DeformRoIPoolingFunction(self.spatial_scale,
136 | self.pooled_size,
137 | self.output_dim,
138 | True,
139 | self.group_size,
140 | self.part_size,
141 | self.sample_per_part,
142 | self.trans_std)
143 | self.offset_fc = nn.Sequential(
144 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim),
145 | nn.ReLU(inplace=True),
146 | nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
147 | nn.ReLU(inplace=True),
148 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 2)
149 | )
150 | self.offset_fc[4].weight.data.zero_()
151 | self.offset_fc[4].bias.data.zero_()
152 | self.mask_fc = nn.Sequential(
153 | nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim),
154 | nn.ReLU(inplace=True),
155 | nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 1),
156 | nn.Sigmoid()
157 | )
158 | self.mask_fc[2].weight.data.zero_()
159 | self.mask_fc[2].bias.data.zero_()
160 |
161 | def forward(self, data, rois):
162 | if self.no_trans:
163 | offset = data.new()
164 | else:
165 | n = rois.shape[0]
166 | offset = data.new()
167 | x = self.func_offset(data, rois, offset)
168 | offset = self.offset_fc(x.view(n, -1))
169 | offset = offset.view(n, 2, self.pooled_size, self.pooled_size)
170 | mask = self.mask_fc(x.view(n, -1))
171 | mask = mask.view(n, 1, self.pooled_size, self.pooled_size)
172 | feat = self.func(data, rois, offset) * mask
173 | return feat
174 | return self.func(data, rois, offset)
175 |
--------------------------------------------------------------------------------
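The Pack variant above is what makes DCNv2 drop-in: a plain convolution predicts 3*k*k channels per location, chunked into two offset halves and a sigmoid mask, and is zero-initialized so offsets start at zero and the mask at sigmoid(0) = 0.5, i.e. training starts from an ordinary convolution scaled by 0.5. A sketch of the same pattern on top of torchvision's operator; the class name and the torchvision dependency are assumptions:

import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d

class TinyModulatedDeformConvPack(nn.Module):
    """Sketch of the ModulatedDeformConvPack pattern: offsets and mask are
    predicted from the input itself by a zero-initialized convolution."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride, self.padding = stride, padding
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.01)
        self.conv_offset_mask = nn.Conv2d(
            in_channels, 3 * kernel_size * kernel_size, kernel_size,
            stride=stride, padding=padding)
        nn.init.zeros_(self.conv_offset_mask.weight)
        nn.init.zeros_(self.conv_offset_mask.bias)

    def forward(self, x):
        o1, o2, mask = torch.chunk(self.conv_offset_mask(x), 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return deform_conv2d(x, offset, self.weight, stride=self.stride,
                             padding=self.padding, mask=mask)

layer = TinyModulatedDeformConvPack(8, 16)
print(layer(torch.randn(2, 8, 32, 32)).shape)  # torch.Size([2, 16, 32, 32])
--------------------------------------------------------------------------------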
/lib/Cell_DETR_master/dcn2/src/cuda/deform_psroi_pooling_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | * Copyright (c) 2017 Microsoft
3 | * Licensed under The MIT License [see LICENSE for details]
4 | * \file deformable_psroi_pooling.cu
5 | * \brief
6 | * \author Yi Li, Guodong Zhang, Jifeng Dai
7 | */
8 | /***************** Adapted by Charles Shang *********************/
9 |
10 | #ifndef DCN_V2_PSROI_POOLING_CUDA
11 | #define DCN_V2_PSROI_POOLING_CUDA
12 |
13 | #ifdef __cplusplus
14 | extern "C"
15 | {
16 | #endif
17 |
18 | void DeformablePSROIPoolForward(cudaStream_t stream,
19 | const float *data,
20 | const float *bbox,
21 | const float *trans,
22 | float *out,
23 | float *top_count,
24 | const int batch,
25 | const int channels,
26 | const int height,
27 | const int width,
28 | const int num_bbox,
29 | const int channels_trans,
30 | const int no_trans,
31 | const float spatial_scale,
32 | const int output_dim,
33 | const int group_size,
34 | const int pooled_size,
35 | const int part_size,
36 | const int sample_per_part,
37 | const float trans_std);
38 |
39 | void DeformablePSROIPoolBackwardAcc(cudaStream_t stream,
40 | const float *out_grad,
41 | const float *data,
42 | const float *bbox,
43 | const float *trans,
44 | const float *top_count,
45 | float *in_grad,
46 | float *trans_grad,
47 | const int batch,
48 | const int channels,
49 | const int height,
50 | const int width,
51 | const int num_bbox,
52 | const int channels_trans,
53 | const int no_trans,
54 | const float spatial_scale,
55 | const int output_dim,
56 | const int group_size,
57 | const int pooled_size,
58 | const int part_size,
59 | const int sample_per_part,
60 | const float trans_std);
61 |
62 | #ifdef __cplusplus
63 | }
64 | #endif
65 |
66 | #endif
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/src/cuda/modulated_deform_im2col_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
3 | *
4 | * COPYRIGHT
5 | *
6 | * All contributions by the University of California:
7 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
8 | * All rights reserved.
9 | *
10 | * All other contributions:
11 | * Copyright (c) 2014-2017, the respective contributors
12 | * All rights reserved.
13 | *
14 | * Caffe uses a shared copyright model: each contributor holds copyright over
15 | * their contributions to Caffe. The project versioning records all such
16 | * contribution and copyright details. If a contributor wants to further mark
17 | * their specific copyright on a particular contribution, they should indicate
18 | * their copyright solely in the commit message of the change when it is
19 | * committed.
20 | *
21 | * LICENSE
22 | *
23 | * Redistribution and use in source and binary forms, with or without
24 | * modification, are permitted provided that the following conditions are met:
25 | *
26 | * 1. Redistributions of source code must retain the above copyright notice, this
27 | * list of conditions and the following disclaimer.
28 | * 2. Redistributions in binary form must reproduce the above copyright notice,
29 | * this list of conditions and the following disclaimer in the documentation
30 | * and/or other materials provided with the distribution.
31 | *
32 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
33 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
34 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
35 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
36 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
39 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
41 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42 | *
43 | * CONTRIBUTION AGREEMENT
44 | *
45 | * By contributing to the BVLC/caffe repository through pull-request, comment,
46 | * or otherwise, the contributor releases their content to the
47 | * license and copyright terms herein.
48 | *
49 | ***************** END Caffe Copyright Notice and Disclaimer ********************
50 | *
51 | * Copyright (c) 2018 Microsoft
52 | * Licensed under The MIT License [see LICENSE for details]
53 | * \file modulated_deformable_im2col.h
54 | * \brief Function definitions of converting an image to
55 | * column matrix based on kernel, padding, dilation, and offset.
56 | * These functions are mainly used in deformable convolution operators.
57 | * \ref: https://arxiv.org/abs/1811.11168
58 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
59 | */
60 |
61 | /***************** Adapted by Charles Shang *********************/
62 |
63 | #ifndef DCN_V2_IM2COL_CUDA
64 | #define DCN_V2_IM2COL_CUDA
65 |
66 | #ifdef __cplusplus
67 | extern "C"
68 | {
69 | #endif
70 |
71 | void modulated_deformable_im2col_cuda(cudaStream_t stream,
72 | const float *data_im, const float *data_offset, const float *data_mask,
73 | const int batch_size, const int channels, const int height_im, const int width_im,
74 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
75 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
76 | const int dilation_h, const int dilation_w,
77 | const int deformable_group, float *data_col);
78 |
79 | void modulated_deformable_col2im_cuda(cudaStream_t stream,
80 | const float *data_col, const float *data_offset, const float *data_mask,
81 | const int batch_size, const int channels, const int height_im, const int width_im,
82 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
83 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
84 | const int dilation_h, const int dilation_w,
85 | const int deformable_group, float *grad_im);
86 |
87 | void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
88 | const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
89 | const int batch_size, const int channels, const int height_im, const int width_im,
90 | const int height_col, const int width_col, const int kernel_h, const int kernel_w,
91 | const int pad_h, const int pad_w, const int stride_h, const int stride_w,
92 | const int dilation_h, const int dilation_w,
93 | const int deformable_group,
94 | float *grad_offset, float *grad_mask);
95 |
96 | #ifdef __cplusplus
97 | }
98 | #endif
99 |
100 | #endif
--------------------------------------------------------------------------------
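The header above declares the im2col/col2im primitives behind modulated deformable convolution (DCNv2): convolution is lowered to a matrix multiply by gathering every kernel tap into a column, and the deformable variant reads each tap at a learned offset scaled by a mask. A minimal sketch of the plain im2col identity in PyTorch (shapes are illustrative, not taken from this repo):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)   # (batch, in_ch, H, W)
w = torch.randn(5, 3, 3, 3)   # (out_ch, in_ch, kH, kW)

# im2col: every 3x3 patch becomes a column of length in_ch * kH * kW
cols = F.unfold(x, kernel_size=3, padding=1)      # (1, 27, 64)
out = (w.view(5, -1) @ cols).view(1, 5, 8, 8)     # convolution as a matmul

assert torch.allclose(out, F.conv2d(x, w, padding=1), atol=1e-5)

modulated_deformable_im2col_cuda plays the role of F.unfold here, except each read happens at a fractional, per-location offset and is weighted by the mask.
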
/lib/Cell_DETR_master/dcn2/src/deform_conv.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 |
3 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset,
4 | THFloatTensor *output)
5 | {
6 | // if (!THFloatTensor_isSameSizeAs(input1, input2))
7 | // return 0;
8 | // THFloatTensor_resizeAs(output, input);
9 | // THFloatTensor_cadd(output, input1, 1.0, input2);
10 | return 1;
11 | }
12 |
13 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input,
14 | THFloatTensor *grad_offset)
15 | {
16 | // THFloatTensor_resizeAs(grad_input, grad_output);
17 | // THFloatTensor_fill(grad_input, 1);
18 | return 1;
19 | }
20 |
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/src/deform_conv.h:
--------------------------------------------------------------------------------
1 | int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset,
2 | THFloatTensor *output);
3 | int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input,
4 | THFloatTensor *grad_offset);
5 |
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/src/deform_conv_cuda.h:
--------------------------------------------------------------------------------
1 | int deform_conv_forward_cuda(THCudaTensor *input,
2 | THCudaTensor *weight, /*THCudaTensor * bias, */
3 | THCudaTensor *offset, THCudaTensor *output,
4 | THCudaTensor *columns, THCudaTensor *ones, int kW,
5 | int kH, int dW, int dH, int padW, int padH,
6 | int dilationW, int dilationH,
7 | int deformable_group, int im2col_step);
8 |
9 | int deform_conv_backward_input_cuda(
10 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
11 | THCudaTensor *gradInput, THCudaTensor *gradOffset, THCudaTensor *weight,
12 | THCudaTensor *columns, int kW, int kH, int dW, int dH, int padW, int padH,
13 | int dilationW, int dilationH, int deformable_group, int im2col_step);
14 |
15 | int deform_conv_backward_parameters_cuda(
16 | THCudaTensor *input, THCudaTensor *offset, THCudaTensor *gradOutput,
17 | THCudaTensor *gradWeight, /*THCudaTensor *gradBias, */
18 | THCudaTensor *columns, THCudaTensor *ones, int kW, int kH, int dW, int dH,
19 | int padW, int padH, int dilationW, int dilationH, int deformable_group,
20 | float scale, int im2col_step);
21 |
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/src/deform_conv_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | template <typename DType>
2 | void deformable_im2col(cudaStream_t stream, const DType *data_im,
3 | const DType *data_offset, const int channels,
4 | const int height, const int width, const int ksize_h,
5 | const int ksize_w, const int pad_h, const int pad_w,
6 | const int stride_h, const int stride_w,
7 | const int dilation_h, const int dilation_w,
8 | const int parallel_imgs,
9 | const int deformable_group, DType *data_col);
10 |
11 | template <typename DType>
12 | void deformable_col2im(cudaStream_t stream, const DType *data_col,
13 | const DType *data_offset, const int channels,
14 | const int height, const int width, const int ksize_h,
15 | const int ksize_w, const int pad_h, const int pad_w,
16 | const int stride_h, const int stride_w,
17 | const int dilation_h, const int dilation_w,
18 | const int parallel_imgs,
19 | const int deformable_group, DType *grad_im);
20 |
21 | template <typename DType>
22 | void deformable_col2im_coord(cudaStream_t stream, const DType *data_col,
23 | const DType *data_im, const DType *data_offset,
24 | const int channels, const int height,
25 | const int width, const int ksize_h,
26 | const int ksize_w, const int pad_h,
27 | const int pad_w, const int stride_h,
28 | const int stride_w, const int dilation_h,
29 | const int dilation_w, const int parallel_imgs,
30 | const int deformable_group, DType *grad_offset);
31 |
--------------------------------------------------------------------------------
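Because the offsets are real-valued, deformable_im2col cannot simply index the input; the CUDA kernels bilinearly interpolate each read. PyTorch exposes the same sampling primitive as F.grid_sample, illustrated here with made-up values:

import torch
import torch.nn.functional as F

x = torch.arange(16.0).view(1, 1, 4, 4)            # a 4x4 single-channel image
# one fractional sampling location, normalized to [-1, 1] as (x, y)
grid = torch.tensor([[[[0.25, -0.5]]]])            # shape (N, H_out, W_out, 2)
print(F.grid_sample(x, grid, align_corners=True))  # bilinear blend of 4 neighbours
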
/lib/Cell_DETR_master/dcn2/src/modulated_dcn.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | #include <stdio.h>
3 | #include <math.h>
4 |
5 | void modulated_deform_conv_forward(THFloatTensor *input, THFloatTensor *weight,
6 | THFloatTensor *bias, THFloatTensor *ones,
7 | THFloatTensor *offset, THFloatTensor *mask,
8 | THFloatTensor *output, THFloatTensor *columns,
9 | const int pad_h, const int pad_w,
10 | const int stride_h, const int stride_w,
11 | const int dilation_h, const int dilation_w,
12 | const int deformable_group)
13 | {
14 | printf("only implemented in GPU");
15 | }
16 | void modulated_deform_conv_backward(THFloatTensor *input, THFloatTensor *weight,
17 | THFloatTensor *bias, THFloatTensor *ones,
18 | THFloatTensor *offset, THFloatTensor *mask,
19 | THFloatTensor *output, THFloatTensor *columns,
20 | THFloatTensor *grad_input, THFloatTensor *grad_weight,
21 | THFloatTensor *grad_bias, THFloatTensor *grad_offset,
22 | THFloatTensor *grad_mask, THFloatTensor *grad_output,
23 | int kernel_h, int kernel_w,
24 | int stride_h, int stride_w,
25 | int pad_h, int pad_w,
26 | int dilation_h, int dilation_w,
27 | int deformable_group)
28 | {
29 | printf("only implemented in GPU");
30 | }
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/src/modulated_dcn.h:
--------------------------------------------------------------------------------
1 | void modulated_deform_conv_forward(THFloatTensor *input, THFloatTensor *weight,
2 | THFloatTensor *bias, THFloatTensor *ones,
3 | THFloatTensor *offset, THFloatTensor *mask,
4 | THFloatTensor *output, THFloatTensor *columns,
5 | const int pad_h, const int pad_w,
6 | const int stride_h, const int stride_w,
7 | const int dilation_h, const int dilation_w,
8 | const int deformable_group);
9 | void modulated_deform_conv_backward(THFloatTensor *input, THFloatTensor *weight,
10 | THFloatTensor *bias, THFloatTensor *ones,
11 | THFloatTensor *offset, THFloatTensor *mask,
12 | THFloatTensor *output, THFloatTensor *columns,
13 | THFloatTensor *grad_input, THFloatTensor *grad_weight,
14 | THFloatTensor *grad_bias, THFloatTensor *grad_offset,
15 | THFloatTensor *grad_mask, THFloatTensor *grad_output,
16 | int kernel_h, int kernel_w,
17 | int stride_h, int stride_w,
18 | int pad_h, int pad_w,
19 | int dilation_h, int dilation_w,
20 | int deformable_group);
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/src/modulated_dcn_cuda.h:
--------------------------------------------------------------------------------
1 | // #ifndef DCN_V2_CUDA
2 | // #define DCN_V2_CUDA
3 |
4 | // #ifdef __cplusplus
5 | // extern "C"
6 | // {
7 | // #endif
8 |
9 | void modulated_deform_conv_cuda_forward(THCudaTensor *input, THCudaTensor *weight,
10 | THCudaTensor *bias, THCudaTensor *ones,
11 | THCudaTensor *offset, THCudaTensor *mask,
12 | THCudaTensor *output, THCudaTensor *columns,
13 | int kernel_h, int kernel_w,
14 | const int stride_h, const int stride_w,
15 | const int pad_h, const int pad_w,
16 | const int dilation_h, const int dilation_w,
17 | const int deformable_group);
18 | void modulated_deform_conv_cuda_backward(THCudaTensor *input, THCudaTensor *weight,
19 | THCudaTensor *bias, THCudaTensor *ones,
20 | THCudaTensor *offset, THCudaTensor *mask,
21 | THCudaTensor *columns,
22 | THCudaTensor *grad_input, THCudaTensor *grad_weight,
23 | THCudaTensor *grad_bias, THCudaTensor *grad_offset,
24 | THCudaTensor *grad_mask, THCudaTensor *grad_output,
25 | int kernel_h, int kernel_w,
26 | int stride_h, int stride_w,
27 | int pad_h, int pad_w,
28 | int dilation_h, int dilation_w,
29 | int deformable_group);
30 |
31 | void deform_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox,
32 | THCudaTensor * trans,
33 | THCudaTensor * out, THCudaTensor * top_count,
34 | const int no_trans,
35 | const float spatial_scale,
36 | const int output_dim,
37 | const int group_size,
38 | const int pooled_size,
39 | const int part_size,
40 | const int sample_per_part,
41 | const float trans_std);
42 |
43 | void deform_psroi_pooling_cuda_backward(THCudaTensor * out_grad,
44 | THCudaTensor * input, THCudaTensor * bbox,
45 | THCudaTensor * trans, THCudaTensor * top_count,
46 | THCudaTensor * input_grad, THCudaTensor * trans_grad,
47 | const int no_trans,
48 | const float spatial_scale,
49 | const int output_dim,
50 | const int group_size,
51 | const int pooled_size,
52 | const int part_size,
53 | const int sample_per_part,
54 | const float trans_std);
55 |
--------------------------------------------------------------------------------
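These hand-written THC bindings predate built-in support in torchvision; since torchvision 0.9, the same modulated deformable convolution is available as torchvision.ops.deform_conv2d with an optional mask argument. A rough functional equivalent of the forward declared above (shapes illustrative):

import torch
from torchvision.ops import deform_conv2d

x = torch.randn(1, 3, 16, 16)
weight = torch.randn(8, 3, 3, 3)
offset = torch.zeros(1, 2 * 3 * 3, 16, 16)            # (dy, dx) per tap and location
mask = torch.sigmoid(torch.randn(1, 3 * 3, 16, 16))   # DCNv2 modulation scalars

out = deform_conv2d(x, offset, weight, padding=1, mask=mask)
print(out.shape)   # torch.Size([1, 8, 16, 16])
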
/lib/Cell_DETR_master/dcn2/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/test.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/dcn2/test_modulated.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/dcn2/test_modulated.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/images/CELL_DETR.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/images/CELL_DETR.PNG
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | import setproctitle
4 |
5 | # Manage command line arguments
6 | parser = ArgumentParser()
7 | parser.add_argument("--train", default=False, action="store_true",
8 | help="Binary flag. If set training will be performed.")
9 | parser.add_argument("--val", default=False, action="store_true",
10 | help="Binary flag. If set validation will be performed.")
11 | parser.add_argument("--test", default=False, action="store_true",
12 | help="Binary flag. If set testing will be performed.")
13 | parser.add_argument("--cuda_devices", default="0", type=str,
14 | help="String of cuda device indexes to be used. Indexes must be separated by a comma.")
15 | parser.add_argument("--data_parallel", default=False, action="store_true",
16 | help="Binary flag. If multi GPU training should be utilized set flag.")
17 | parser.add_argument("--cpu", default=False, action="store_true",
18 | help="Binary flag. If set all operations are performed on the CPU.")
19 | parser.add_argument("--epochs", default=200, type=int,
20 | help="Number of epochs to perform while training.")
21 | parser.add_argument("--lr_schedule", default=False, action="store_true",
22 | help="Binary flag. If set the learning rate will be reduced after epoch 50 and 100.")
23 | parser.add_argument("--ohem", default=False, action="store_true",
24 | help="Binary flag. If set online heard example mining is utilized.")
25 | parser.add_argument("--ohem_fraction", default=0.75, type=float,
26 | help="Ohem fraction to be applied when performing ohem.")
27 | parser.add_argument("--batch_size", default=4, type=int,
28 | help="Batch size to be utilized while training.")
29 | parser.add_argument("--path_to_data", default="../../BCS_Data/Cell_Instance_Segmentation_Regular_Traps", type=str,
30 | help="Path to dataset.")
31 | parser.add_argument("--augmentation_p", default=0.6, type=float,
32 | help="Probability that data augmentation is applied on training data sample.")
33 | parser.add_argument("--lr_main", default=1e-04, type=float,
34 | help="Learning rate of the detr model (excluding backbone).")
35 | parser.add_argument("--lr_backbone", default=1e-05, type=float,
36 | help="Learning rate of the backbone network.")
37 | parser.add_argument("--lr_segmentation_head", default=1e-06, type=float,
38 | help="Learning rate of the segmentation head, only applied when seg head is trained exclusively.")
39 | parser.add_argument("--no_pac", default=False, action="store_true",
40 | help="Binary flag. If set no pixel adaptive convolutions will be utilized in the segmentation head.")
41 | parser.add_argument("--load_model", default="", type=str,
42 | help="Path to model to be loaded.")
43 | parser.add_argument("--dropout", default=0.05, type=float,
44 | help="Dropout factor to be used in model.")
45 | parser.add_argument("--three_classes", default=False, action="store_true",
46 | help="Binary flag, If set three classes (trap, cell of interest and add. cells) will be utilized.")
47 | parser.add_argument("--softmax", default=False, action="store_true",
48 | help="Binary flag, If set a softmax will be applied to the segmentation prediction instead sigmoid.")
49 | parser.add_argument("--only_train_segmentation_head_after_epoch", default=150, type=int,
50 | help="Number of epoch where only the segmentation head is trained.")
51 | parser.add_argument("--no_deform_conv", default=False, action="store_true",
52 | help="Binary flag. If set no deformable convolutions will be utilized.")
53 | parser.add_argument("--no_pau", default=False, action="store_true",
54 | help="Binary flag. If set no pade activation unit is utilized, however, a leaky ReLU is utilized.")
55 |
56 | # Get arguments
57 | args = parser.parse_args()
58 |
59 | # Set device type
60 | device = "cpu" if args.cpu else "cuda"
61 |
62 | # Set cuda devices
63 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices
64 |
65 | setproctitle.setproctitle("Cell-DETR")
66 |
67 | import torch
68 | import torch.nn as nn
69 | from torch.utils.data import DataLoader
70 | from modules.modulated_deform_conv import ModulatedDeformConvPack
71 | from pade_activation_unit.utils import PAU
72 |
73 | # Avoid data loader bug
74 | import resource
75 |
76 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
77 | resource.setrlimit(resource.RLIMIT_NOFILE, (2 ** 12, rlimit[1]))
78 |
79 | from detr import CellDETR
80 | from dataset import CellInstanceSegmentation, collate_function_cell_instance_segmentation
81 | from lossfunction import InstanceSegmentationLoss, SegmentationLoss, MultiClassSegmentationLoss
82 | from model_wrapper import ModelWrapper
83 | from segmentation import ResFeaturePyramidBlock, ResPACFeaturePyramidBlock
84 |
85 | if __name__ == '__main__':
86 | # Init detr
87 | detr = CellDETR(num_classes=3 if args.three_classes else 2,
88 | segmentation_head_block=ResPACFeaturePyramidBlock if not args.no_pac else ResFeaturePyramidBlock,
89 | segmentation_head_final_activation=nn.Softmax if args.softmax else nn.Sigmoid,
90 | backbone_convolution=nn.Conv2d if args.no_deform_conv else ModulatedDeformConvPack,
91 | segmentation_head_convolution=nn.Conv2d if args.no_deform_conv else ModulatedDeformConvPack,
92 | transformer_activation=nn.LeakyReLU if args.no_pau else PAU,
93 | backbone_activation=nn.LeakyReLU if args.no_pau else PAU,
94 | bounding_box_head_activation=nn.LeakyReLU if args.no_pau else PAU,
95 | classification_head_activation=nn.LeakyReLU if args.no_pau else PAU,
96 | segmentation_head_activation=nn.LeakyReLU if args.no_pau else PAU)
97 | if args.load_model != "":
98 | detr.load_state_dict(torch.load(args.load_model))
99 | # Print network
100 | print(detr)
101 | # Print number of parameters
102 | print("# DETR parameters", sum([p.numel() for p in detr.parameters()]))
103 | # Init optimizer
104 | detr_optimizer = torch.optim.AdamW(detr.get_parameters(lr_main=args.lr_main, lr_backbone=args.lr_backbone),
105 | weight_decay=1e-06)
106 | detr_segmentation_optimizer = torch.optim.AdamW(detr.get_segmentation_head_parameters(lr=args.lr_segmentation_head),
107 | weight_decay=1e-06)
108 | # Init data parallel if utilized
109 | if args.data_parallel:
110 | detr = torch.nn.DataParallel(detr)
111 | # Init learning rate schedule if utilized
112 | if args.lr_schedule:
113 | learning_rate_schedule = torch.optim.lr_scheduler.MultiStepLR(detr_optimizer, milestones=[50, 100], gamma=0.1)
114 | else:
115 | learning_rate_schedule = None
116 | # Init datasets
117 | training_dataset = DataLoader(
118 | CellInstanceSegmentation(path=os.path.join(args.path_to_data, "train"),
119 | augmentation_p=args.augmentation_p, two_classes=not args.three_classes),
120 | collate_fn=collate_function_cell_instance_segmentation, batch_size=args.batch_size, num_workers=20,
121 | shuffle=True)
122 | validation_dataset = DataLoader(
123 | CellInstanceSegmentation(path=os.path.join(args.path_to_data, "val"),
124 | augmentation_p=0.0, two_classes=not args.three_classes),
125 | collate_fn=collate_function_cell_instance_segmentation, batch_size=1, num_workers=1, shuffle=False)
126 | test_dataset = DataLoader(
127 | CellInstanceSegmentation(path=os.path.join(args.path_to_data, "test"),
128 | augmentation_p=0.0, two_classes=not args.three_classes),
129 | collate_fn=collate_function_cell_instance_segmentation, batch_size=1, num_workers=1, shuffle=False)
130 | # Model wrapper
131 | model_wrapper = ModelWrapper(detr=detr,
132 | detr_optimizer=detr_optimizer,
133 | detr_segmentation_optimizer=detr_segmentation_optimizer,
134 | training_dataset=training_dataset,
135 | validation_dataset=validation_dataset,
136 | test_dataset=test_dataset,
137 | loss_function=InstanceSegmentationLoss(
138 | segmentation_loss=SegmentationLoss(),
139 | ohem=args.ohem,
140 | ohem_faction=args.ohem_fraction),
141 | device=device)
142 | # Perform training
143 | if args.train:
144 | model_wrapper.train(epochs=args.epochs,
145 | optimize_only_segmentation_head_after_epoch=args.only_train_segmentation_head_after_epoch)
146 | # Perform validation
147 | if args.val:
148 | model_wrapper.validate(number_of_plots=30)
149 | # Perform testing
150 | if args.test:
151 | model_wrapper.test()
152 |
--------------------------------------------------------------------------------
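Given the flags defined at the top of main.py, a typical training run looks like the following (device index is illustrative; the data path shown is the script's default):

python main.py --train --cuda_devices 0 --batch_size 4 --lr_schedule --path_to_data ../../BCS_Data/Cell_Instance_Segmentation_Regular_Traps

Validation and testing reuse the same entry point via --val and --test.
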
/lib/Cell_DETR_master/matcher.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 |
3 | import torch
4 | import torch.nn as nn
5 | from scipy.optimize import linear_sum_assignment
6 |
7 | import misc
8 |
9 |
10 | class HungarianMatcher(nn.Module):
11 | """
12 | This class implements a hungarian algorithm based matcher for DETR.
13 | """
14 |
15 | def __init__(self, weight_classification: float = 1.0,
16 | weight_bb_l1: float = 1.0,
17 | weight_bb_giou: float = 1.0) -> None:
18 | # Call super constructor
19 | super(HungarianMatcher, self).__init__()
20 | # Save parameters
21 | self.weight_classification = weight_classification
22 | self.weight_bb_l1 = weight_bb_l1
23 | self.weight_bb_giou = weight_bb_giou
24 |
25 | def __repr__(self):
26 | """
27 | Get representation of the matcher module
28 | :return: (str) String including information
29 | """
30 | return "{}, W classification:{}, W BB L1:{}, W BB gIoU".format(self.__class__.__name__,
31 | self.weight_classification, self.weight_bb_l1,
32 | self.weight_bb_giou)
33 |
34 | @torch.no_grad()
35 | def forward(self, prediction_classification: torch.Tensor,
36 | prediction_bounding_box: torch.Tensor,
37 | label_classification: Tuple[torch.Tensor],
38 | label_bounding_box: Tuple[torch.Tensor]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
39 | """
40 | Forward pass computes the permutation produced by the hungarian algorithm.
41 | :param prediction_classification: (torch.Tensor) Classification prediction (batch size, # queries, classes + 1)
42 | :param prediction_bounding_box: (torch.Tensor) BB predictions (batch size, # queries, 4)
43 | :param label_classification: (Tuple[torch.Tensor]) Classification label batched [(instances, classes + 1)]
44 | :param label_bounding_box: (Tuple[torch.Tensor]) BB label batched [(instances, 4)]
45 | :return: (List[Tuple[torch.Tensor, torch.Tensor]]) Matched (prediction index, label index) pairs per sample
46 | """
47 | # Save shapes
48 | batch_size, number_of_queries = prediction_classification.shape[:2]
49 | # Get number of instances in each training sample
50 | number_of_instances = [label_bounding_box_instance.shape[0] for label_bounding_box_instance in
51 | label_bounding_box]
52 | # Flatten to shape [batch size * # queries, classes + 1]
53 | prediction_classification = prediction_classification.flatten(start_dim=0, end_dim=1)
54 | # Flatten to shape [batch size * # queries, 4]
55 | prediction_bounding_box = prediction_bounding_box.flatten(start_dim=0, end_dim=1)
56 | # Class label to index
57 | # Concat labels
58 | label_classification = torch.cat([instance.argmax(dim=-1) for instance in label_classification], dim=0)
59 | label_bounding_box = torch.cat([instance for instance in label_bounding_box], dim=0)
60 | # Compute classification cost
61 | cost_classification = -prediction_classification[:, label_classification.long()]
62 | # Compute the L1 cost of bounding boxes
63 | cost_bounding_boxes_l1 = torch.cdist(prediction_bounding_box, label_bounding_box, p=1)
64 | # Compute gIoU cost of bounding boxes
65 | cost_bounding_boxes_giou = -misc.giou_for_matching(
66 | misc.bounding_box_xcycwh_to_x0y0x1y1(prediction_bounding_box),
67 | misc.bounding_box_xcycwh_to_x0y0x1y1(label_bounding_box))
68 | # Construct cost matrix
69 | cost_matrix = self.weight_classification * cost_classification \
70 | + self.weight_bb_l1 * cost_bounding_boxes_l1 \
71 | + self.weight_bb_giou * cost_bounding_boxes_giou
72 | cost_matrix = cost_matrix.view(batch_size, number_of_queries, -1).cpu().clamp(min=-1e20, max=1e20)
73 | # Get optimal indexes
74 | indexes = [linear_sum_assignment(cost_vector[index]) for index, cost_vector in
75 | enumerate(cost_matrix.split(number_of_instances, dim=-1))]
76 | # Convert indexes to list of prediction index and label index
77 | return [(torch.as_tensor(index_prediction, dtype=torch.int), torch.as_tensor(index_label, dtype=torch.int)) for
78 | index_prediction, index_label in indexes]
79 |
--------------------------------------------------------------------------------
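HungarianMatcher reduces matching to a rectangular assignment problem: build a (queries x instances) cost matrix, then let scipy's linear_sum_assignment pick the cost-minimizing one-to-one pairing. A self-contained toy run with made-up costs:

import torch
from scipy.optimize import linear_sum_assignment

# 4 queries (rows) vs. 2 ground-truth instances (columns); lower cost = better fit
cost = torch.tensor([[0.9, 0.1],
                     [0.2, 0.8],
                     [0.7, 0.6],
                     [0.4, 0.3]])

rows, cols = linear_sum_assignment(cost.numpy())
print([(int(r), int(c)) for r, c in zip(rows, cols)])
# [(0, 1), (1, 0)]: each instance is assigned exactly one query
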
/lib/Cell_DETR_master/pade_activation_unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/__init__.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pade_activation_unit/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pade_activation_unit/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pade_activation_unit/cuda/pau_cuda.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <torch/extension.h>
3 | #include <ATen/ATen.h>
4 | #include <vector>
5 |
6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
9 |
10 |
11 | at::Tensor pau_cuda_forward_3_3(torch::Tensor x, torch::Tensor n, torch::Tensor d);
12 | std::vector<torch::Tensor> pau_cuda_backward_3_3(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
13 |
14 | at::Tensor pau_cuda_forward_4_4(torch::Tensor x, torch::Tensor n, torch::Tensor d);
15 | std::vector<torch::Tensor> pau_cuda_backward_4_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
16 |
17 | at::Tensor pau_cuda_forward_5_5(torch::Tensor x, torch::Tensor n, torch::Tensor d);
18 | std::vector<torch::Tensor> pau_cuda_backward_5_5(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
19 |
20 | at::Tensor pau_cuda_forward_6_6(torch::Tensor x, torch::Tensor n, torch::Tensor d);
21 | std::vector<torch::Tensor> pau_cuda_backward_6_6(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
22 |
23 | at::Tensor pau_cuda_forward_7_7(torch::Tensor x, torch::Tensor n, torch::Tensor d);
24 | std::vector<torch::Tensor> pau_cuda_backward_7_7(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
25 |
26 | at::Tensor pau_cuda_forward_8_8(torch::Tensor x, torch::Tensor n, torch::Tensor d);
27 | std::vector<torch::Tensor> pau_cuda_backward_8_8(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
28 |
29 | at::Tensor pau_cuda_forward_5_4(torch::Tensor x, torch::Tensor n, torch::Tensor d);
30 | std::vector<torch::Tensor> pau_cuda_backward_5_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d);
31 |
32 |
33 | at::Tensor pau_forward__3_3(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
34 | CHECK_INPUT(x);
35 | CHECK_INPUT(n);
36 | CHECK_INPUT(d);
37 |
38 | return pau_cuda_forward_3_3(x, n, d);
39 | }
40 | std::vector<torch::Tensor> pau_backward__3_3(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
41 | CHECK_INPUT(grad_output);
42 | CHECK_INPUT(x);
43 | CHECK_INPUT(n);
44 | CHECK_INPUT(d);
45 |
46 | return pau_cuda_backward_3_3(grad_output, x, n, d);
47 | }
48 |
49 | at::Tensor pau_forward__4_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
50 | CHECK_INPUT(x);
51 | CHECK_INPUT(n);
52 | CHECK_INPUT(d);
53 |
54 | return pau_cuda_forward_4_4(x, n, d);
55 | }
56 | std::vector<torch::Tensor> pau_backward__4_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
57 | CHECK_INPUT(grad_output);
58 | CHECK_INPUT(x);
59 | CHECK_INPUT(n);
60 | CHECK_INPUT(d);
61 |
62 | return pau_cuda_backward_4_4(grad_output, x, n, d);
63 | }
64 |
65 | at::Tensor pau_forward__5_5(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
66 | CHECK_INPUT(x);
67 | CHECK_INPUT(n);
68 | CHECK_INPUT(d);
69 |
70 | return pau_cuda_forward_5_5(x, n, d);
71 | }
72 | std::vector<torch::Tensor> pau_backward__5_5(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
73 | CHECK_INPUT(grad_output);
74 | CHECK_INPUT(x);
75 | CHECK_INPUT(n);
76 | CHECK_INPUT(d);
77 |
78 | return pau_cuda_backward_5_5(grad_output, x, n, d);
79 | }
80 |
81 | at::Tensor pau_forward__6_6(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
82 | CHECK_INPUT(x);
83 | CHECK_INPUT(n);
84 | CHECK_INPUT(d);
85 |
86 | return pau_cuda_forward_6_6(x, n, d);
87 | }
88 | std::vector<torch::Tensor> pau_backward__6_6(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
89 | CHECK_INPUT(grad_output);
90 | CHECK_INPUT(x);
91 | CHECK_INPUT(n);
92 | CHECK_INPUT(d);
93 |
94 | return pau_cuda_backward_6_6(grad_output, x, n, d);
95 | }
96 |
97 | at::Tensor pau_forward__7_7(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
98 | CHECK_INPUT(x);
99 | CHECK_INPUT(n);
100 | CHECK_INPUT(d);
101 |
102 | return pau_cuda_forward_7_7(x, n, d);
103 | }
104 | std::vector<torch::Tensor> pau_backward__7_7(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
105 | CHECK_INPUT(grad_output);
106 | CHECK_INPUT(x);
107 | CHECK_INPUT(n);
108 | CHECK_INPUT(d);
109 |
110 | return pau_cuda_backward_7_7(grad_output, x, n, d);
111 | }
112 |
113 | at::Tensor pau_forward__8_8(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
114 | CHECK_INPUT(x);
115 | CHECK_INPUT(n);
116 | CHECK_INPUT(d);
117 |
118 | return pau_cuda_forward_8_8(x, n, d);
119 | }
120 | std::vector<torch::Tensor> pau_backward__8_8(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
121 | CHECK_INPUT(grad_output);
122 | CHECK_INPUT(x);
123 | CHECK_INPUT(n);
124 | CHECK_INPUT(d);
125 |
126 | return pau_cuda_backward_8_8(grad_output, x, n, d);
127 | }
128 |
129 | at::Tensor pau_forward__5_4(torch::Tensor x, torch::Tensor n, torch::Tensor d) {
130 | CHECK_INPUT(x);
131 | CHECK_INPUT(n);
132 | CHECK_INPUT(d);
133 |
134 | return pau_cuda_forward_5_4(x, n, d);
135 | }
136 | std::vector<torch::Tensor> pau_backward__5_4(torch::Tensor grad_output, torch::Tensor x, torch::Tensor n, torch::Tensor d) {
137 | CHECK_INPUT(grad_output);
138 | CHECK_INPUT(x);
139 | CHECK_INPUT(n);
140 | CHECK_INPUT(d);
141 |
142 | return pau_cuda_backward_5_4(grad_output, x, n, d);
143 | }
144 |
145 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
146 |
147 | m.def("forward_3_3", &pau_forward__3_3, "PAU forward _3_3");
148 | m.def("backward_3_3", &pau_backward__3_3, "PAU backward _3_3");
149 |
150 | m.def("forward_4_4", &pau_forward__4_4, "PAU forward _4_4");
151 | m.def("backward_4_4", &pau_backward__4_4, "PAU backward _4_4");
152 |
153 | m.def("forward_5_5", &pau_forward__5_5, "PAU forward _5_5");
154 | m.def("backward_5_5", &pau_backward__5_5, "PAU backward _5_5");
155 |
156 | m.def("forward_6_6", &pau_forward__6_6, "PAU forward _6_6");
157 | m.def("backward_6_6", &pau_backward__6_6, "PAU backward _6_6");
158 |
159 | m.def("forward_7_7", &pau_forward__7_7, "PAU forward _7_7");
160 | m.def("backward_7_7", &pau_backward__7_7, "PAU backward _7_7");
161 |
162 | m.def("forward_8_8", &pau_forward__8_8, "PAU forward _8_8");
163 | m.def("backward_8_8", &pau_backward__8_8, "PAU backward _8_8");
164 |
165 | m.def("forward_5_4", &pau_forward__5_4, "PAU forward _5_4");
166 | m.def("backward_5_4", &pau_backward__5_4, "PAU backward _5_4");
167 | }
168 |
--------------------------------------------------------------------------------
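The kernels dispatched above all evaluate the same rational function: the safe Padé activation unit of Molina et al. (2019), y = P(x) / (1 + |Q(x)|), with a degree-m numerator and degree-n denominator (the _5_4 variant is the paper's default). A pure-PyTorch sketch of that forward, with illustrative coefficient shapes; the CUDA code exists to fuse this computation with its gradients:

import torch

def pau_safe(x, a, b):
    # P(x) = a_0 + a_1 x + ... + a_m x^m
    num = sum(a_j * x ** j for j, a_j in enumerate(a))
    # 1 + |b_1 x + ... + b_n x^n| keeps the denominator away from zero
    den = 1 + torch.abs(sum(b_k * x ** (k + 1) for k, b_k in enumerate(b)))
    return num / den

x = torch.linspace(-3, 3, 7)
print(pau_safe(x, torch.randn(6), torch.randn(4)))   # the m=5, n=4 case
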
/lib/Cell_DETR_master/pade_activation_unit/cuda/python_imp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/cuda/python_imp/__init__.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pade_activation_unit/torchsummary.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/torchsummary.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pade_activation_unit/utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pade_activation_unit/utils.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pixel_adaptive_convolution/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/README.md
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pixel_adaptive_convolution/__pycache__/pac.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/__pycache__/pac.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pixel_adaptive_convolution/pac.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/pac.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pixel_adaptive_convolution/paccrf.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/paccrf.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pixel_adaptive_convolution/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/requirements.txt
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pixel_adaptive_convolution/test_pac.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/lib/Cell_DETR_master/pixel_adaptive_convolution/test_pac.py
--------------------------------------------------------------------------------
/lib/Cell_DETR_master/pixel_adaptive_convolution/tools/flowlib.py:
--------------------------------------------------------------------------------
1 | """ Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
2 | This file incorporates work covered by the following copyright and permission notice:
3 |
4 | Copyright (c) 2019 LI RUOTENG
5 |
6 | Permission to use, copy, modify, and/or distribute this software
7 | for any purpose with or without fee is hereby granted, provided
8 | that the above copyright notice and this permission notice appear
9 | in all copies.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
12 | WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
13 | WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
14 | AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
15 | CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
16 | OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
17 | NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
18 | CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
19 | """
20 | from __future__ import division
21 |
22 | import numpy as np
23 | import matplotlib.pyplot as plt
24 | UNKNOWN_FLOW_THRESH = 1e7
25 |
26 |
27 | def evaluate_flow(gt, pred):
28 | """
29 | evaluate the estimated optical flow end point error according to ground truth provided
30 | :param gt: ground truth file path
31 | :param pred: estimated optical flow file path
32 | :return: end point error, float32
33 | """
34 | # Read flow files and calculate the errors
35 | gt_flow = read_flow(gt) # ground truth flow
36 | eva_flow = read_flow(pred) # predicted flow
37 | # Calculate errors
38 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1])
39 | return average_pe
40 |
41 |
42 | def show_flow(filename):
43 | """
44 | visualize optical flow map using matplotlib
45 | :param filename: optical flow file
46 | :return: None
47 | """
48 | flow = read_flow(filename)
49 | img = flow_to_image(flow)
50 | plt.imshow(img)
51 | plt.show()
52 |
53 |
54 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
55 | def read_flow(filename):
56 | """
57 | read optical flow in Middlebury .flo file format
58 | :param filename:
59 | :return:
60 | """
61 | f = open(filename, 'rb')
62 | magic = np.fromfile(f, np.float32, count=1)
63 | data2d = None
64 |
65 | if 202021.25 != magic:
66 | print('Magic number incorrect. Invalid .flo file')
67 | else:
68 | w = np.fromfile(f, np.int32, count=1)[0]
69 | h = np.fromfile(f, np.int32, count=1)[0]
70 | data2d = np.fromfile(f, np.float32, count=2 * w * h)
71 | # reshape data into 3D array (columns, rows, channels)
72 | data2d = np.reshape(data2d, (h, w, 2))
73 | f.close()
74 | return data2d
75 |
76 |
77 | # WARNING: this will work on little-endian architectures only!
78 | def write_flow(flow, filename):
79 | """
80 | write optical flow in Middlebury .flo format
81 | :param flow: optical flow map
82 | :param filename: optical flow file path to be saved
83 | :return: None
84 | """
85 | f = open(filename, 'wb')
86 | magic = np.array([202021.25], dtype=np.float32)
87 | (height, width) = flow.shape
88 | w = np.array([width], dtype=np.int32)
89 | h = np.array([height], dtype=np.int32)
90 | empty_map = np.zeros((height, width), dtype=np.float32)
91 | data = np.dstack((flow, empty_map))
92 | magic.tofile(f)
93 | w.tofile(f)
94 | h.tofile(f)
95 | data.tofile(f)
96 | f.close()
97 |
98 |
99 | def flow_error(tu, tv, u, v):
100 | """
101 | Calculate average end point error
102 | :param tu: ground-truth horizontal flow map
103 | :param tv: ground-truth vertical flow map
104 | :param u: estimated horizontal flow map
105 | :param v: estimated vertical flow map
106 | :return: End point error of the estimated flow
107 | """
108 | smallflow = 0.0
109 | '''
110 | stu = tu[bord+1:end-bord,bord+1:end-bord]
111 | stv = tv[bord+1:end-bord,bord+1:end-bord]
112 | su = u[bord+1:end-bord,bord+1:end-bord]
113 | sv = v[bord+1:end-bord,bord+1:end-bord]
114 | '''
115 | stu = tu[:]
116 | stv = tv[:]
117 | su = u[:]
118 | sv = v[:]
119 |
120 | idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH)
121 | stu[idxUnknow] = 0
122 | stv[idxUnknow] = 0
123 | su[idxUnknow] = 0
124 | sv[idxUnknow] = 0
125 |
126 | ind2 = (np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)
127 | index_su = su[ind2]
128 | index_sv = sv[ind2]
129 | an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1)
130 | un = index_su * an
131 | vn = index_sv * an
132 |
133 | index_stu = stu[ind2]
134 | index_stv = stv[ind2]
135 | tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1)
136 | tun = index_stu * tn
137 | tvn = index_stv * tn
138 |
139 | '''
140 | angle = un * tun + vn * tvn + (an * tn)
141 | index = [angle == 1.0]
142 | angle[index] = 0.999
143 | ang = np.arccos(angle)
144 | mang = np.mean(ang)
145 | mang = mang * 180 / np.pi
146 | '''
147 |
148 | epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2)
149 | epe = epe[ind2]
150 | mepe = np.mean(epe)
151 | return mepe
152 |
153 |
154 | def flow_to_image(flow):
155 | """
156 | Convert flow into middlebury color code image
157 | :param flow: optical flow map
158 | :return: optical flow image in middlebury color
159 | """
160 | u = flow[:, :, 0]
161 | v = flow[:, :, 1]
162 |
163 | maxu = -999.
164 | maxv = -999.
165 | minu = 999.
166 | minv = 999.
167 |
168 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
169 | u[idxUnknow] = 0
170 | v[idxUnknow] = 0
171 |
172 | maxu = max(maxu, np.max(u))
173 | minu = min(minu, np.min(u))
174 |
175 | maxv = max(maxv, np.max(v))
176 | minv = min(minv, np.min(v))
177 |
178 | rad = np.sqrt(u ** 2 + v ** 2)
179 | maxrad = max(-1, np.max(rad))
180 |
181 | print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv))
182 |
183 | u = u/(maxrad + np.finfo(float).eps)
184 | v = v/(maxrad + np.finfo(float).eps)
185 |
186 | img = compute_color(u, v)
187 |
188 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
189 | img[idx] = 0
190 |
191 | return np.uint8(img)
192 |
193 |
194 | def compute_color(u, v):
195 | """
196 | compute optical flow color map
197 | :param u: optical flow horizontal map
198 | :param v: optical flow vertical map
199 | :return: optical flow in color code
200 | """
201 | [h, w] = u.shape
202 | img = np.zeros([h, w, 3])
203 | nanIdx = np.isnan(u) | np.isnan(v)
204 | u[nanIdx] = 0
205 | v[nanIdx] = 0
206 |
207 | colorwheel = make_color_wheel()
208 | ncols = np.size(colorwheel, 0)
209 |
210 | rad = np.sqrt(u**2+v**2)
211 |
212 | a = np.arctan2(-v, -u) / np.pi
213 |
214 | fk = (a+1) / 2 * (ncols - 1) + 1
215 |
216 | k0 = np.floor(fk).astype(int)
217 |
218 | k1 = k0 + 1
219 | k1[k1 == ncols+1] = 1
220 | f = fk - k0
221 |
222 | for i in range(0, np.size(colorwheel,1)):
223 | tmp = colorwheel[:, i]
224 | col0 = tmp[k0-1] / 255
225 | col1 = tmp[k1-1] / 255
226 | col = (1-f) * col0 + f * col1
227 |
228 | idx = rad <= 1
229 | col[idx] = 1-rad[idx]*(1-col[idx])
230 | notidx = np.logical_not(idx)
231 |
232 | col[notidx] *= 0.75
233 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
234 |
235 | return img
236 |
237 |
238 | def make_color_wheel():
239 | """
240 | Generate color wheel according Middlebury color code
241 | :return: Color wheel
242 | """
243 | RY = 15
244 | YG = 6
245 | GC = 4
246 | CB = 11
247 | BM = 13
248 | MR = 6
249 |
250 | ncols = RY + YG + GC + CB + BM + MR
251 |
252 | colorwheel = np.zeros([ncols, 3])
253 |
254 | col = 0
255 |
256 | # RY
257 | colorwheel[0:RY, 0] = 255
258 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
259 | col += RY
260 |
261 | # YG
262 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
263 | colorwheel[col:col+YG, 1] = 255
264 | col += YG
265 |
266 | # GC
267 | colorwheel[col:col+GC, 1] = 255
268 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
269 | col += GC
270 |
271 | # CB
272 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
273 | colorwheel[col:col+CB, 2] = 255
274 | col += CB
275 |
276 | # BM
277 | colorwheel[col:col+BM, 2] = 255
278 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
279 | col += BM
280 |
281 | # MR
282 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
283 | colorwheel[col:col+MR, 0] = 255
284 |
285 | return colorwheel
286 |
287 |
288 | def scale_image(image, new_range):
289 | """
290 | Linearly scale the image into desired range
291 | :param image: input image
292 | :param new_range: the new range to be aligned
293 | :return: image normalized in new range
294 | """
295 | min_val = np.min(image).astype(np.float32)
296 | max_val = np.max(image).astype(np.float32)
297 | min_val_new = np.array(min(new_range), dtype=np.float32)
298 | max_val_new = np.array(max(new_range), dtype=np.float32)
299 | scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new
300 | return scaled_image.astype(np.uint8)
301 |
302 |
--------------------------------------------------------------------------------
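read_flow and write_flow above implement the Middlebury .flo container: a float32 magic number 202021.25, int32 width and height, then row-major interleaved (u, v) float32 pairs. Since the byte layout is the whole format, here is a self-contained round trip using only numpy (little-endian assumed, per the warnings in the file; file name illustrative):

import numpy as np

flow = np.random.rand(4, 6, 2).astype(np.float32)    # (h, w, 2) toy flow
with open('tiny.flo', 'wb') as f:
    np.array([202021.25], dtype=np.float32).tofile(f)                   # magic
    np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)  # w, h
    flow.tofile(f)

with open('tiny.flo', 'rb') as f:
    assert np.fromfile(f, np.float32, 1)[0] == np.float32(202021.25)
    w, h = np.fromfile(f, np.int32, 2)
    back = np.fromfile(f, np.float32, 2 * w * h).reshape(h, w, 2)

assert np.array_equal(back, flow)
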
/lib/Cell_DETR_master/pixel_adaptive_convolution/tools/plot_log.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | import argparse
6 |
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | LINE_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color']
12 | LINE_STYLES = ('solid', 'dashed', 'dashdot', 'dotted')
13 |
14 |
15 | def smooth_plot(xs, ys, smooth=5.0, axis=None, *args, **kwargs):
16 | min_window = 3
17 | if smooth > 0 and len(xs) > min_window * 100 / smooth:
18 | window = int(len(xs) * smooth // 100)
19 | f = np.repeat(1.0, window) / window
20 | ys = np.convolve(ys, f, 'valid')
21 | xs = xs[(window//2):(window//2)+len(ys)]
22 | # TODO: fix issue with empty strings and strings starting with '_'
23 | if 'label' in kwargs and (kwargs['label'].startswith('_') or kwargs['label'] == ''):
24 | kwargs['label'] = ' ' + kwargs['label']
25 | if axis is None:
26 | plt.plot(xs, ys, *args, **kwargs)
27 | else:
28 | axis.plot(xs, ys, *args, **kwargs)
29 |
30 |
31 | def remove_common_prefix_suffix(array_of_str):
32 | if len(array_of_str) == 1:
33 | return array_of_str
34 | remove_head, remove_tail = 0, 0
35 | while True:
36 | v = array_of_str[0][remove_head]
37 | if all(s[remove_head] == v for s in array_of_str):
38 | remove_head += 1
39 | else:
40 | break
41 | while True:
42 | v = array_of_str[0][len(array_of_str[0]) - remove_tail - 1]
43 | if all(s[len(s) - remove_tail - 1] == v for s in array_of_str):
44 | remove_tail += 1
45 | else:
46 | break
47 | return [s[remove_head:len(s) - remove_tail] for s in array_of_str]
48 |
49 |
50 | def parse_and_plot(paths, output=None, plots=None, labels=None, reorder=None, trials=1,
51 | smooth=5.0, num_col=0, subplot_size=5, start_x=0.0, end_x=-1.0, legend='best'):
52 | with open(paths[0], 'r') as f:
53 | r = f.readline()
54 | xlabel = r.strip().split(',')[0]
55 | plots_all = r.strip().split(',')[1:]
56 | if plots:
57 | plot_dict = {p: i+1 for i, p in enumerate(plots_all)}
58 | plot_cols = tuple(-1 if p == '-' else plot_dict[p] for p in plots)
59 | else:
60 | plots = plots_all
61 | plot_cols = tuple(range(1, len(plots) + 1))
62 | if not labels:
63 | labels = remove_common_prefix_suffix(paths)[::trials]
64 | assert(len(labels) == len(set(labels)))
65 | paths = [paths[i*trials:(i+1)*trials] for i in range(len(labels))]
66 | if reorder:
67 | if len(reorder) == 1 and reorder[0] == 'str':
68 | reorder = sorted(labels)
69 | paths = [paths[labels.index(l)] for l in reorder]
70 | labels = reorder
71 |
72 | runs_data = []
73 | for ps in paths:
74 | tmp = []
75 | for p in ps:
76 | data = np.genfromtxt(p, delimiter=',', usecols=(0,)+tuple(filter(lambda v: v != -1, plot_cols)), skip_header=1)
77 | data = data.reshape(-1, data.shape[-1])
78 | start = np.where(data[:, 0] >= start_x)[0][0]
79 | end = None
80 | if end_x > 0:
81 | end_idxs = np.where(data[:, 0] > end_x)[0]
82 | if len(end_idxs) > 0:
83 | end = end_idxs[0]
84 | data = data[start:end]
85 | tmp.append(data)
86 | _len = min(len(d) for d in tmp)
87 | runs_data.append(np.mean([d[:_len] for d in tmp], axis=0))
88 |
89 | if output:
90 | plt.switch_backend('agg')
91 |
92 | if num_col <= 0:
93 | num_col = int(np.ceil(np.sqrt(len(plots))))
94 | num_row = (len(plots) - 1) // num_col + 1
95 | fig, axes = plt.subplots(num_row, num_col, squeeze=False, figsize=(num_col * subplot_size, num_row * subplot_size))
96 | p_idx = 0
97 | for plabel, ax in zip(plots, axes.flat):
98 | if plabel == '-':
99 | continue
100 | p_idx += 1
101 | # plt.subplot(num_row, num_col, p + 1)
102 | for r, rdata in enumerate(runs_data):
103 | valid_mask = ~np.isnan(rdata[:, p_idx])
104 | if any(valid_mask):
105 | smooth_plot(rdata[valid_mask, 0], rdata[valid_mask, p_idx], smooth, ax,
106 | color=LINE_COLORS[r % len(LINE_COLORS)],
107 | linestyle=LINE_STYLES[r // len(LINE_COLORS)],
108 | linewidth=1.5, label=labels[r])
109 | ax.set_xlabel(xlabel)
110 | ax.set_title(plabel)
111 | ax.grid(True)
112 | if legend != 'off' and len(labels) > 1:
113 | ax.legend(loc=legend)
114 | for ax in axes.flat[len(plots):]:
115 | fig.delaxes(ax)
116 | fig.tight_layout()
117 |
118 | if output:
119 | plt.savefig(output)
120 | else:
121 | plt.show()
122 |
123 |
124 | if __name__ == '__main__':
125 | parser = argparse.ArgumentParser(description='Parse log files (csv format) and make plots',
126 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
127 | parser.add_argument('paths', nargs='+', help='one or more log files (each subplot will have #paths curves)')
128 | parser.add_argument('--output', type=str, default='', help='place to save figure (show figure if blank)')
129 | parser.add_argument('--plots', nargs='+', default=None, help='use select plots instead of one for each column')
130 | parser.add_argument('--labels', nargs='+', default=None, help='use labels instead of file names')
131 | parser.add_argument('--average', type=int, default=1, help='plot average curves')
132 | parser.add_argument('--reorder', nargs='+', default=None, help='order the label legends')
133 | parser.add_argument('--smooth', type=float, default=5.0, help='smoothing level (0-100)')
134 | parser.add_argument('--num-col', type=int, default=0, help='number of columns (0 - auto)')
135 | parser.add_argument('--subplot-size', type=int, default=5, help='width and height of each subplot')
136 | parser.add_argument('--legend', type=str, default='best', help='place to put legend')
137 | parser.add_argument('--xlim', nargs='+', default=None, help='xmin (xmax)')
138 |
139 | args = parser.parse_args()
140 |
141 | start_x = 0.0 if not args.xlim else float(args.xlim[0])
142 | end_x = -1.0 if (not args.xlim or len(args.xlim) < 2) else float(args.xlim[1])
143 |
144 | parse_and_plot(args.paths, output=args.output, plots=args.plots, labels=args.labels, reorder=args.reorder,
145 | trials=args.average, smooth=args.smooth, num_col=args.num_col, subplot_size=args.subplot_size,
146 | start_x=start_x, end_x=end_x, legend=args.legend.replace('_', ' '))
147 |
--------------------------------------------------------------------------------
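With the argparse options above, a typical invocation plots selected columns from one or more CSV logs, e.g. (file names illustrative):

python plot_log.py logs/run_a.csv logs/run_b.csv --plots loss --smooth 10 --output curves.png
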
/lib/Cell_DETR_master/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | torchvision==0.5.0
3 | numpy
4 | scikit-image
5 | scipy
6 | tqdm
7 | setproctitle
8 | tikzplotlib
--------------------------------------------------------------------------------
/lib/non_local/non_local_concatenation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 |
10 | assert dimension in [1, 2, 3]
11 |
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 |
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 |
18 | if self.inter_channels is None:
19 | self.inter_channels = in_channels // 2
20 | if self.inter_channels == 0:
21 | self.inter_channels = 1
22 |
23 | if dimension == 3:
24 | conv_nd = nn.Conv3d
25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
26 | bn = nn.BatchNorm3d
27 | elif dimension == 2:
28 | conv_nd = nn.Conv2d
29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
30 | bn = nn.BatchNorm2d
31 | else:
32 | conv_nd = nn.Conv1d
33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
34 | bn = nn.BatchNorm1d
35 |
36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
37 | kernel_size=1, stride=1, padding=0)
38 |
39 | if bn_layer:
40 | self.W = nn.Sequential(
41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
42 | kernel_size=1, stride=1, padding=0),
43 | bn(self.in_channels)
44 | )
45 | nn.init.constant_(self.W[1].weight, 0)
46 | nn.init.constant_(self.W[1].bias, 0)
47 | else:
48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0)
50 | nn.init.constant_(self.W.weight, 0)
51 | nn.init.constant_(self.W.bias, 0)
52 |
53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
54 | kernel_size=1, stride=1, padding=0)
55 |
56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
57 | kernel_size=1, stride=1, padding=0)
58 |
59 | self.concat_project = nn.Sequential(
60 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
61 | nn.ReLU()
62 | )
63 |
64 | if sub_sample:
65 | self.g = nn.Sequential(self.g, max_pool_layer)
66 | self.phi = nn.Sequential(self.phi, max_pool_layer)
67 |
68 | def forward(self, x, return_nl_map=False):
69 | '''
70 | :param x: (b, c, t, h, w)
71 | :param return_nl_map: if True return z, nl_map, else only return z.
72 | :return:
73 | '''
74 |
75 | batch_size = x.size(0)
76 |
77 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
78 | g_x = g_x.permute(0, 2, 1)
79 |
80 | # (b, c, N, 1)
81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
82 | # (b, c, 1, N)
83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
84 |
85 | h = theta_x.size(2)
86 | w = phi_x.size(3)
87 | theta_x = theta_x.repeat(1, 1, 1, w)
88 | phi_x = phi_x.repeat(1, 1, h, 1)
89 |
90 | concat_feature = torch.cat([theta_x, phi_x], dim=1)
91 | f = self.concat_project(concat_feature)
92 | b, _, h, w = f.size()
93 | f = f.view(b, h, w)
94 |
95 | N = f.size(-1)
96 | f_div_C = f / N
97 |
98 | y = torch.matmul(f_div_C, g_x)
99 | y = y.permute(0, 2, 1).contiguous()
100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
101 | W_y = self.W(y)
102 | z = W_y + x
103 |
104 | if return_nl_map:
105 | return z, f_div_C
106 | return z
107 |
108 |
109 | class NONLocalBlock1D(_NonLocalBlockND):
110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
111 | super(NONLocalBlock1D, self).__init__(in_channels,
112 | inter_channels=inter_channels,
113 | dimension=1, sub_sample=sub_sample,
114 | bn_layer=bn_layer)
115 |
116 |
117 | class NONLocalBlock2D(_NonLocalBlockND):
118 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
119 | super(NONLocalBlock2D, self).__init__(in_channels,
120 | inter_channels=inter_channels,
121 | dimension=2, sub_sample=sub_sample,
122 | bn_layer=bn_layer)
123 |
124 |
125 | class NONLocalBlock3D(_NonLocalBlockND):
126 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,):
127 | super(NONLocalBlock3D, self).__init__(in_channels,
128 | inter_channels=inter_channels,
129 | dimension=3, sub_sample=sub_sample,
130 | bn_layer=bn_layer)
131 |
132 |
133 | if __name__ == '__main__':
134 | import torch
135 |
136 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
137 | img = torch.zeros(2, 3, 20)
138 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
139 | out = net(img)
140 | print(out.size())
141 |
142 | img = torch.zeros(2, 3, 20, 20)
143 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
144 | out = net(img)
145 | print(out.size())
146 |
147 | img = torch.randn(2, 3, 8, 20, 20)
148 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
149 | out = net(img)
150 | print(out.size())
151 |
--------------------------------------------------------------------------------
/lib/non_local/non_local_dot_product.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 |
10 | assert dimension in [1, 2, 3]
11 |
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 |
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 |
18 | if self.inter_channels is None:
19 | self.inter_channels = in_channels // 2
20 | if self.inter_channels == 0:
21 | self.inter_channels = 1
22 |
23 | if dimension == 3:
24 | conv_nd = nn.Conv3d
25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
26 | bn = nn.BatchNorm3d
27 | elif dimension == 2:
28 | conv_nd = nn.Conv2d
29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
30 | bn = nn.BatchNorm2d
31 | else:
32 | conv_nd = nn.Conv1d
33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
34 | bn = nn.BatchNorm1d
35 |
36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
37 | kernel_size=1, stride=1, padding=0)
38 |
39 | if bn_layer:
40 | self.W = nn.Sequential(
41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
42 | kernel_size=1, stride=1, padding=0),
43 | bn(self.in_channels)
44 | )
45 | nn.init.constant_(self.W[1].weight, 0)
46 | nn.init.constant_(self.W[1].bias, 0)
47 | else:
48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0)
50 | nn.init.constant_(self.W.weight, 0)
51 | nn.init.constant_(self.W.bias, 0)
52 |
53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
54 | kernel_size=1, stride=1, padding=0)
55 |
56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
57 | kernel_size=1, stride=1, padding=0)
58 |
59 | if sub_sample:
60 | self.g = nn.Sequential(self.g, max_pool_layer)
61 | self.phi = nn.Sequential(self.phi, max_pool_layer)
62 |
63 | def forward(self, x, return_nl_map=False):
64 | """
65 | :param x: (b, c, t, h, w)
66 | :param return_nl_map: if True return z, nl_map, else only return z.
67 |         :return: z with the same shape as x (and the affinity map if return_nl_map)
68 | """
69 |
70 | batch_size = x.size(0)
71 |
72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
73 | g_x = g_x.permute(0, 2, 1)
74 |
75 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
76 | theta_x = theta_x.permute(0, 2, 1)
77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
78 | f = torch.matmul(theta_x, phi_x)
79 | N = f.size(-1)
80 | f_div_C = f / N
81 |
82 | y = torch.matmul(f_div_C, g_x)
83 | y = y.permute(0, 2, 1).contiguous()
84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
85 | W_y = self.W(y)
86 | z = W_y + x
87 |
88 | if return_nl_map:
89 | return z, f_div_C
90 | return z
91 |
92 |
93 | class NONLocalBlock1D(_NonLocalBlockND):
94 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
95 | super(NONLocalBlock1D, self).__init__(in_channels,
96 | inter_channels=inter_channels,
97 | dimension=1, sub_sample=sub_sample,
98 | bn_layer=bn_layer)
99 |
100 |
101 | class NONLocalBlock2D(_NonLocalBlockND):
102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
103 | super(NONLocalBlock2D, self).__init__(in_channels,
104 | inter_channels=inter_channels,
105 | dimension=2, sub_sample=sub_sample,
106 | bn_layer=bn_layer)
107 |
108 |
109 | class NONLocalBlock3D(_NonLocalBlockND):
110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
111 | super(NONLocalBlock3D, self).__init__(in_channels,
112 | inter_channels=inter_channels,
113 | dimension=3, sub_sample=sub_sample,
114 | bn_layer=bn_layer)
115 |
116 |
117 | if __name__ == '__main__':
118 | import torch
119 |
120 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
121 | img = torch.zeros(2, 3, 20)
122 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
123 | out = net(img)
124 | print(out.size())
125 |
126 | img = torch.zeros(2, 3, 20, 20)
127 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
128 | out = net(img)
129 | print(out.size())
130 |
131 | img = torch.randn(2, 3, 8, 20, 20)
132 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
133 | out = net(img)
134 | print(out.size())
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
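The dot-product variant above scales the affinity matrix by the number of positions N instead of applying a softmax, so the attention weights do not form a distribution. Stripped of the 1x1 convolutions, the forward pass is three batched matmuls; a minimal stand-alone sketch with toy tensors (the names mirror forward, the shapes are illustrative only):

    import torch

    b, c_inter, N = 2, 4, 10                 # batch, inter_channels, flattened positions
    theta_x = torch.randn(b, N, c_inter)     # "queries": theta(x), flattened and permuted
    phi_x = torch.randn(b, c_inter, N)       # "keys":    phi(x), flattened
    g_x = torch.randn(b, N, c_inter)         # "values":  g(x), flattened and permuted

    f = torch.matmul(theta_x, phi_x)         # (b, N, N) pairwise dot products
    f_div_C = f / f.size(-1)                 # divide by N -- no softmax in this variant
    y = torch.matmul(f_div_C, g_x)           # (b, N, c_inter) aggregated features
    print(y.shape)                           # torch.Size([2, 10, 4])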
/lib/non_local/non_local_embedded_gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | """
9 |         :param in_channels: channels of the input feature map
10 |         :param inter_channels: channels of the embedding space; defaults to in_channels // 2
11 |         :param dimension: 1, 2, or 3 (1D, 2D, or spatio-temporal input)
12 |         :param sub_sample: if True, max-pool phi and g to shrink the affinity matrix
13 |         :param bn_layer: if True, W ends with a zero-initialized BatchNorm
14 | """
15 |
16 | super(_NonLocalBlockND, self).__init__()
17 |
18 | assert dimension in [1, 2, 3]
19 |
20 | self.dimension = dimension
21 | self.sub_sample = sub_sample
22 |
23 | self.in_channels = in_channels
24 | self.inter_channels = inter_channels
25 |
26 | if self.inter_channels is None:
27 | self.inter_channels = in_channels // 2
28 | if self.inter_channels == 0:
29 | self.inter_channels = 1
30 |
31 | if dimension == 3:
32 | conv_nd = nn.Conv3d
33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
34 | bn = nn.BatchNorm3d
35 | elif dimension == 2:
36 | conv_nd = nn.Conv2d
37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
38 | bn = nn.BatchNorm2d
39 | else:
40 | conv_nd = nn.Conv1d
41 |             max_pool_layer = nn.MaxPool1d(kernel_size=2)
42 | bn = nn.BatchNorm1d
43 |
44 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
45 | kernel_size=1, stride=1, padding=0)
46 |
47 | if bn_layer:
48 | self.W = nn.Sequential(
49 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
50 | kernel_size=1, stride=1, padding=0),
51 | bn(self.in_channels)
52 | )
53 | nn.init.constant_(self.W[1].weight, 0)
54 | nn.init.constant_(self.W[1].bias, 0)
55 | else:
56 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
57 | kernel_size=1, stride=1, padding=0)
58 | nn.init.constant_(self.W.weight, 0)
59 | nn.init.constant_(self.W.bias, 0)
60 |
61 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
62 | kernel_size=1, stride=1, padding=0)
63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
64 | kernel_size=1, stride=1, padding=0)
65 |
66 | if sub_sample:
67 | self.g = nn.Sequential(self.g, max_pool_layer)
68 | self.phi = nn.Sequential(self.phi, max_pool_layer)
69 |
70 | def forward(self, x, return_nl_map=False):
71 | """
72 | :param x: (b, c, t, h, w)
73 | :param return_nl_map: if True return z, nl_map, else only return z.
74 |         :return: z with the same shape as x (and the affinity map if return_nl_map)
75 | """
76 |
77 | batch_size = x.size(0)
78 |
79 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
80 | g_x = g_x.permute(0, 2, 1)
81 |
82 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
83 | theta_x = theta_x.permute(0, 2, 1)
84 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
85 | f = torch.matmul(theta_x, phi_x)
86 | f_div_C = F.softmax(f, dim=-1)
87 |
88 | y = torch.matmul(f_div_C, g_x)
89 | y = y.permute(0, 2, 1).contiguous()
90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
91 | W_y = self.W(y)
92 | z = W_y + x
93 |
94 | if return_nl_map:
95 | return z, f_div_C
96 | return z
97 |
98 |
99 | class NONLocalBlock1D(_NonLocalBlockND):
100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
101 | super(NONLocalBlock1D, self).__init__(in_channels,
102 | inter_channels=inter_channels,
103 | dimension=1, sub_sample=sub_sample,
104 | bn_layer=bn_layer)
105 |
106 |
107 | class NONLocalBlock2D(_NonLocalBlockND):
108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
109 | super(NONLocalBlock2D, self).__init__(in_channels,
110 | inter_channels=inter_channels,
111 | dimension=2, sub_sample=sub_sample,
112 |                                               bn_layer=bn_layer)
113 |
114 |
115 | class NONLocalBlock3D(_NonLocalBlockND):
116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
117 | super(NONLocalBlock3D, self).__init__(in_channels,
118 | inter_channels=inter_channels,
119 | dimension=3, sub_sample=sub_sample,
120 |                                               bn_layer=bn_layer)
121 |
122 |
123 | if __name__ == '__main__':
124 | import torch
125 |
126 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
127 | img = torch.zeros(2, 3, 20)
128 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
129 | out = net(img)
130 | print(out.size())
131 |
132 | img = torch.zeros(2, 3, 20, 20)
133 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
134 | out = net(img)
135 | print(out.size())
136 |
137 | img = torch.randn(2, 3, 8, 20, 20)
138 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
139 | out = net(img)
140 | print(out.size())
141 |
142 |
143 |
--------------------------------------------------------------------------------
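A detail shared by all of these blocks: W is zero-initialized (the BatchNorm scale and shift when bn_layer=True, the 1x1 convolution otherwise), so W(y) is zero at construction and the residual z = W(y) + x starts out as an exact identity. That lets a block be inserted into a pretrained backbone without disturbing its behavior. A quick check, assuming the class definitions above are in scope:

    import torch

    block = NONLocalBlock2D(in_channels=3)   # defaults: bn_layer=True, sub_sample=True
    x = torch.randn(2, 3, 20, 20)
    with torch.no_grad():
        out = block(x)
    print(torch.allclose(out, x))            # True: W(y) is all zeros at initialization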
/lib/non_local/non_local_gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _NonLocalBlockND(nn.Module):
7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
8 | super(_NonLocalBlockND, self).__init__()
9 |
10 | assert dimension in [1, 2, 3]
11 |
12 | self.dimension = dimension
13 | self.sub_sample = sub_sample
14 |
15 | self.in_channels = in_channels
16 | self.inter_channels = inter_channels
17 |
18 | if self.inter_channels is None:
19 | self.inter_channels = in_channels // 2
20 | if self.inter_channels == 0:
21 | self.inter_channels = 1
22 |
23 | if dimension == 3:
24 | conv_nd = nn.Conv3d
25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
26 | bn = nn.BatchNorm3d
27 | elif dimension == 2:
28 | conv_nd = nn.Conv2d
29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
30 | bn = nn.BatchNorm2d
31 | else:
32 | conv_nd = nn.Conv1d
33 |             max_pool_layer = nn.MaxPool1d(kernel_size=2)
34 | bn = nn.BatchNorm1d
35 |
36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
37 | kernel_size=1, stride=1, padding=0)
38 |
39 | if bn_layer:
40 | self.W = nn.Sequential(
41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
42 | kernel_size=1, stride=1, padding=0),
43 | bn(self.in_channels)
44 | )
45 | nn.init.constant_(self.W[1].weight, 0)
46 | nn.init.constant_(self.W[1].bias, 0)
47 | else:
48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0)
50 | nn.init.constant_(self.W.weight, 0)
51 | nn.init.constant_(self.W.bias, 0)
52 |
53 | if sub_sample:
54 | self.g = nn.Sequential(self.g, max_pool_layer)
55 | self.phi = max_pool_layer
56 |
57 | def forward(self, x, return_nl_map=False):
58 | """
59 | :param x: (b, c, t, h, w)
60 | :param return_nl_map: if True return z, nl_map, else only return z.
61 |         :return: z with the same shape as x (and the affinity map if return_nl_map)
62 | """
63 |
64 | batch_size = x.size(0)
65 |
66 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
67 |
68 | g_x = g_x.permute(0, 2, 1)
69 |
70 | theta_x = x.view(batch_size, self.in_channels, -1)
71 | theta_x = theta_x.permute(0, 2, 1)
72 |
73 | if self.sub_sample:
74 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
75 | else:
76 | phi_x = x.view(batch_size, self.in_channels, -1)
77 |
78 | f = torch.matmul(theta_x, phi_x)
79 | f_div_C = F.softmax(f, dim=-1)
80 |
81 | # if self.store_last_batch_nl_map:
82 | # self.nl_map = f_div_C
83 |
84 | y = torch.matmul(f_div_C, g_x)
85 | y = y.permute(0, 2, 1).contiguous()
86 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
87 | W_y = self.W(y)
88 | z = W_y + x
89 |
90 | if return_nl_map:
91 | return z, f_div_C
92 | return z
93 |
94 |
95 | class NONLocalBlock1D(_NonLocalBlockND):
96 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
97 | super(NONLocalBlock1D, self).__init__(in_channels,
98 | inter_channels=inter_channels,
99 | dimension=1, sub_sample=sub_sample,
100 | bn_layer=bn_layer)
101 |
102 |
103 | class NONLocalBlock2D(_NonLocalBlockND):
104 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
105 | super(NONLocalBlock2D, self).__init__(in_channels,
106 | inter_channels=inter_channels,
107 | dimension=2, sub_sample=sub_sample,
108 | bn_layer=bn_layer)
109 |
110 |
111 | class NONLocalBlock3D(_NonLocalBlockND):
112 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
113 | super(NONLocalBlock3D, self).__init__(in_channels,
114 | inter_channels=inter_channels,
115 | dimension=3, sub_sample=sub_sample,
116 | bn_layer=bn_layer)
117 |
118 |
119 | if __name__ == '__main__':
120 | import torch
121 |
122 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
123 | img = torch.zeros(2, 3, 20)
124 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
125 | out = net(img)
126 | print(out.size())
127 |
128 | img = torch.zeros(2, 3, 20, 20)
129 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
130 | out = net(img)
131 | print(out.size())
132 |
133 | img = torch.randn(2, 3, 8, 20, 20)
134 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
135 | out = net(img)
136 | print(out.size())
137 |
138 |
139 |
140 |
141 |
142 |
143 |
--------------------------------------------------------------------------------
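This Gaussian variant is the plainest of the four: there are no theta/phi projections at all, the affinity is a softmax over raw-feature inner products, and phi degenerates to a bare max-pooling when sub_sample=True. The normalization itself, in isolation (toy shapes, nothing from the file required):

    import torch
    from torch.nn import functional as F

    b, c, N = 2, 3, 400                                # e.g. a 20x20 map, flattened
    x_flat = torch.randn(b, c, N)
    f = torch.matmul(x_flat.permute(0, 2, 1), x_flat)  # (b, N, N) raw-feature affinities
    f_div_C = F.softmax(f, dim=-1)
    print(f_div_C.sum(dim=-1))                         # every row sums to 1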
/src/__pycache__/BAT_Modules.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/BAT_Modules.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/BA-Transformer/d1ccdff68beac82d18c2cd64bb800e51d7afc5e1/src/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/src/process_point.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import random
4 | import torch
5 | import numpy as np
6 | import skimage.draw
7 | from tqdm import tqdm
8 | import matplotlib.pyplot as plt
9 | import torch.nn.functional as F
10 |
11 |
12 | def create_circular_mask(h, w, center, radius):
13 | Y, X = np.ogrid[:h, :w]
14 | dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
15 | mask = dist_from_center <= radius
16 | return mask
17 |
18 |
19 | def NMS(heatmap, kernel=13):
20 | hmax = F.max_pool2d(heatmap, kernel, stride=1, padding=(kernel - 1) // 2)
21 | keep = (hmax == heatmap).float()
22 | return heatmap * keep, hmax, keep
23 |
24 |
25 | def draw_msra_gaussian(heatmap, center, sigma):
26 | tmp_size = sigma * 3
27 | mu_x = int(center[0] + 0.5)
28 | mu_y = int(center[1] + 0.5)
29 | w, h = heatmap.shape[0], heatmap.shape[1]
30 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
31 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
32 | if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
33 | return heatmap
34 | size = 2 * tmp_size + 1
35 | x = np.arange(0, size, 1, np.float32)
36 | y = x[:, np.newaxis]
37 | x0 = y0 = size // 2
38 | g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
39 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
40 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
41 | img_x = max(0, ul[0]), min(br[0], h)
42 | img_y = max(0, ul[1]), min(br[1], w)
43 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum(
44 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], g[g_y[0]:g_y[1],
45 | g_x[0]:g_x[1]])
46 | return heatmap
47 |
48 |
49 | def kpm_gen(label_path, R, N):
50 | label = np.load(label_path)
51 | # label = label[0]
52 | label_ori = label.copy()
53 | label = label[::4, ::4]
54 | label = np.uint8(label * 255)
55 | contours, hierarchy = cv2.findContours(label, cv2.RETR_LIST,
56 | cv2.CHAIN_APPROX_NONE)
57 | contour_len = len(contours)
58 |
59 | label = np.repeat(label[..., np.newaxis], 3, axis=-1)
60 | draw_label = cv2.drawContours(label.copy(), contours, -1, (0, 0, 255), 1)
61 |
62 | point_file = []
63 | if contour_len == 0:
64 | point_heatmap = np.zeros((512, 512))
65 | else:
66 | point_heatmap = np.zeros((512, 512))
67 | for contour in contours:
68 | stds = []
69 | points = contour[:, 0] # (N,2)
70 | points = points * 4
71 | points_number = contour.shape[0]
72 | if points_number < 30:
73 | continue
74 |
75 | if points_number < 100:
76 | radius = 6
77 | neighbor_points_n_oneside = 3
78 | elif points_number < 200:
79 | radius = 10
80 | neighbor_points_n_oneside = 15
81 | elif points_number < 300:
82 | radius = 10
83 | neighbor_points_n_oneside = 20
84 | elif points_number < 350:
85 | radius = 15
86 | neighbor_points_n_oneside = 30
87 | else:
88 | radius = 10
89 | neighbor_points_n_oneside = 40
90 |
91 | for i in range(points_number):
92 | current_point = points[i]
93 | mask = create_circular_mask(512, 512, points[i], radius)
94 | overlap_area = np.sum(
95 | mask * label_ori) / (np.pi * radius * radius)
96 | stds.append(overlap_area)
97 | print("stds len: ", len(stds))
98 |
99 | # show
100 | selected_points = []
101 | stds = np.array(stds)
102 | neighbor_points = []
103 | for i in range(len(points)):
104 | current_point = points[i]
105 | neighbor_points_index = np.concatenate([
106 | np.arange(-neighbor_points_n_oneside, 0),
107 | np.arange(1, neighbor_points_n_oneside + 1)
108 | ]) + i
109 | neighbor_points_index[np.where(
110 | neighbor_points_index < 0)[0]] += len(points)
111 | neighbor_points_index[np.where(
112 | neighbor_points_index > len(points) - 1)[0]] -= len(points)
113 | if stds[i] < np.min(
114 | stds[neighbor_points_index]) or stds[i] > np.max(
115 | stds[neighbor_points_index]):
116 | # print(points[i])
117 | point_heatmap = draw_msra_gaussian(
118 | point_heatmap, (points[i, 0], points[i, 1]), 5)
119 | selected_points.append(points[i])
120 |
121 | print("selected_points num: ", len(selected_points))
122 | # print(selected_points)
123 | maskk = np.zeros((512, 512))
124 | rr, cc = skimage.draw.polygon(
125 | np.array(selected_points)[:, 1],
126 | np.array(selected_points)[:, 0])
127 | maskk[rr, cc] = 1
128 | intersection = np.logical_and(label_ori, maskk)
129 | union = np.logical_or(label_ori, maskk)
130 | iou_score = np.sum(intersection) / np.sum(union)
131 | print(iou_score)
132 | return label_ori, point_heatmap
133 |
134 |
135 | def point_gen_isic2018():
136 | R = 10
137 | N = 25
138 | data_dir = '/raid/wjc/data/skin_lesion/isic2018/Label'
139 |
140 | save_dir = data_dir.replace('Label', 'Point')
141 | os.makedirs(save_dir, exist_ok=True)
142 |
143 | path_list = os.listdir(data_dir)
144 | path_list.sort()
145 | num = 0
146 | for path in tqdm(path_list):
147 | name = path[:-4]
148 | label_path = os.path.join(data_dir, path)
149 | print(label_path)
150 | label_ori, point_heatmap = kpm_gen(label_path, R, N)
151 |
152 | save_path = os.path.join(save_dir, name + '.npy')
153 | np.save(save_path, point_heatmap)
154 | num += 1
155 |
156 |
157 | def point_gen_isic2016():
158 | R = 10
159 | N = 25
160 | for split in ['Train', 'Test', 'Validation']:
161 | data_dir = '/raid/wjc/data/skin_lesion/isic2016/{}/Label'.format(split)
162 |
163 | save_dir = data_dir.replace('Label', 'Point')
164 | os.makedirs(save_dir, exist_ok=True)
165 |
166 | path_list = os.listdir(data_dir)
167 | path_list.sort()
168 | num = 0
169 | for path in tqdm(path_list):
170 | name = path[:-4]
171 | label_path = os.path.join(data_dir, path)
172 | print(label_path)
173 | label_ori, point_heatmap = kpm_gen(label_path, R, N)
174 | save_path = os.path.join(save_dir, name + '.npy')
175 | np.save(save_path, point_heatmap)
176 | num += 1
177 |
178 |
179 | if __name__ == '__main__':
180 | # point_gen_isic2018()
181 | point_gen_isic2016()
--------------------------------------------------------------------------------
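kpm_gen saves each boundary key-point map as a 512x512 heatmap of Gaussians whose peaks are exactly 1.0 (draw_msra_gaussian puts exp(0) = 1 at each center). The NMS helper at the top of the file is the natural decoder; a minimal sketch, with the path and the 0.9 threshold as illustrative placeholders rather than repo values:

    import numpy as np
    import torch

    heatmap = np.load('/path/to/Point/ISIC_0000000.npy')       # (512, 512) float map
    hm = torch.from_numpy(heatmap).float()[None, None]         # (1, 1, 512, 512)
    peaks, _, _ = NMS(hm, kernel=13)                           # keep only local maxima
    ys, xs = torch.nonzero(peaks[0, 0] > 0.9, as_tuple=True)   # peak centers are 1.0
    points = list(zip(xs.tolist(), ys.tolist()))               # (x, y) key points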
/src/process_resize.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import random
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def process_isic2018(
11 | dim=(352, 352), save_dir='/raid/wjc/data/skin_lesion/isic2018/'):
12 | image_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1-2_Training_Input/'
13 | mask_dir_path = '/raid/wl/2018_raw_data/ISIC2018_Task1_Training_GroundTruth/'
14 |
15 | image_path_list = os.listdir(image_dir_path)
16 | mask_path_list = os.listdir(mask_dir_path)
17 |
18 | image_path_list = list(filter(lambda x: x[-3:] == 'jpg', image_path_list))
19 | mask_path_list = list(filter(lambda x: x[-3:] == 'png', mask_path_list))
20 |
21 | image_path_list.sort()
22 | mask_path_list.sort()
23 |
24 | print(len(image_path_list), len(mask_path_list))
25 |
26 | # ISBI Dataset
27 | for image_path, mask_path in zip(image_path_list, mask_path_list):
28 | if image_path[-3:] == 'jpg':
29 | print(image_path)
30 | assert os.path.basename(image_path)[:-4].split(
31 | '_')[1] == os.path.basename(mask_path)[:-4].split('_')[1]
32 | _id = os.path.basename(image_path)[:-4].split('_')[1]
33 | image_path = os.path.join(image_dir_path, image_path)
34 | mask_path = os.path.join(mask_dir_path, mask_path)
35 | image = cv2.imread(image_path)
36 | mask = cv2.imread(mask_path)
37 |
38 | image_new = cv2.resize(image, dim, interpolation=cv2.INTER_CUBIC)
39 | image_new = np.array(image_new, dtype=np.uint8)
40 | mask_new = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
41 |             mask_new = cv2.blur(mask_new, (3, 3))  # smooth the resized mask boundary
42 | mask_new = np.array(mask_new, dtype=np.uint8)
43 |
44 | save_dir_path = save_dir + '/Image'
45 | os.makedirs(save_dir_path, exist_ok=True)
46 | # np.save(os.path.join(save_dir_path, _id + '.npy'), image_new)
47 | print(image_new.shape)
48 | cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'),
49 | image_new)
50 |
51 | save_dir_path = save_dir + '/Label'
52 | os.makedirs(save_dir_path, exist_ok=True)
53 | # np.save(os.path.join(save_dir_path, _id + '.npy'), mask_new)
54 | cv2.imwrite(os.path.join(save_dir_path, 'ISIC_' + _id + '.jpg'),
55 | mask_new)
56 |
57 |
58 | def process_ph2():
59 | PH2_images_path = '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2_Dataset_images'
60 |
61 | path_list = os.listdir(PH2_images_path)
62 | path_list.sort()
63 |
64 | for path in path_list:
65 | image_path = os.path.join(PH2_images_path, path,
66 | path + '_Dermoscopic_Image', path + '.bmp')
67 | label_path = os.path.join(PH2_images_path, path, path + '_lesion',
68 | path + '_lesion.bmp')
69 | image = plt.imread(image_path)
70 | label = plt.imread(label_path)
71 | label = label[:, :, 0]
72 |
73 | dim = (512, 512)
74 | image_new = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
75 | label_new = cv2.resize(label, dim, interpolation=cv2.INTER_AREA)
76 |
77 | image_save_path = os.path.join(
78 | '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Image',
79 | path + '.npy')
80 | label_save_path = os.path.join(
81 | '/data2/cf_data/skinlesion_segment/PH2_rawdata/PH2/Label',
82 | path + '.npy')
83 |
84 | np.save(image_save_path, image_new)
85 | np.save(label_save_path, label_new)
86 |
87 |
88 | if __name__ == '__main__':
89 | process_isic2018(dim=(352,352), save_dir='/raid/wjc/data/skin_lesion/isic2018_jpg_352_smooth/')
90 |
--------------------------------------------------------------------------------
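Two deliberate choices in the mask branch above: INTER_NEAREST keeps the resized labels binary, and the 3x3 box blur then softens the boundary, which is what the `_smooth` suffix on the default save_dir refers to. The same pipeline on synthetic data, with no dataset needed:

    import cv2
    import numpy as np

    mask = np.zeros((512, 512), dtype=np.uint8)
    cv2.circle(mask, (256, 256), 120, 255, -1)     # synthetic circular "lesion"
    mask_new = cv2.resize(mask, (352, 352), interpolation=cv2.INTER_NEAREST)
    mask_new = cv2.blur(mask_new, (3, 3))          # graded values appear at the edge
    print(np.unique(mask_new))                     # 0, intermediate values, 255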
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 |
5 |
6 | def load_model(model, pretrain_dir, log=True):
7 | state_dict_ = torch.load(pretrain_dir, map_location='cuda:0')
8 |     print('loaded pretrained weights from %s!' % pretrain_dir)
9 | state_dict = OrderedDict()
10 |
11 |     # convert DataParallel checkpoints ('module.'-prefixed keys) to plain model keys
12 | for key in state_dict_:
13 | if key.startswith('module') and not key.startswith('module_list'):
14 | state_dict[key[7:]] = state_dict_[key]
15 | else:
16 | state_dict[key] = state_dict_[key]
17 |
18 | # check loaded parameters and created model parameters
19 | model_state_dict = model.state_dict()
20 | for key in state_dict:
21 | if key in model_state_dict:
22 | # print(key,state_dict[key].shape,model_state_dict[key].shape)
23 | if state_dict[key].shape != model_state_dict[key].shape:
24 | if log:
25 | print(
26 |                         'Skip loading parameter {}, required shape {}, loaded shape {}.'
27 | .format(key, model_state_dict[key].shape,
28 | state_dict[key].shape))
29 | state_dict[key] = model_state_dict[key]
30 | else:
31 | if log:
32 | print('Drop parameter {}.'.format(key))
33 | for key in model_state_dict:
34 | if key not in state_dict:
35 | if log:
36 | print('No param {}.'.format(key))
37 | state_dict[key] = model_state_dict[key]
38 | model.load_state_dict(state_dict, strict=False)
39 |
40 | return model
--------------------------------------------------------------------------------
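load_model is deliberately forgiving: it strips the 'module.' prefix that DataParallel adds, falls back to the model's own tensor whenever a checkpoint shape disagrees, and finishes with strict=False. Note that map_location is pinned to 'cuda:0', so a GPU is assumed. A minimal usage sketch (class and arguments as in test.py; the checkpoint path is a placeholder):

    from Ours.base import DeepLabV3
    from src.utils import load_model

    model = DeepLabV3(1, 50).cuda()        # args mirror test.py: one channel, net_layer=50
    model = load_model(model, 'logs/isic2016/my_run/fold_0/model/best.pkl')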
/test.py:
--------------------------------------------------------------------------------
1 | import os, argparse, sys, tqdm, logging, cv2
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from glob import glob
6 | import torch.nn.functional as F
7 | from medpy.metric.binary import hd, hd95, dc, jc, assd
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--log_name',
11 | type=str,
12 | default='bat_1_1_0_e6_loss_0_aug_1')
13 | parser.add_argument('--gpu', type=str, default='1')
14 | parser.add_argument('--fold', type=str, default='0')
15 | parser.add_argument('--dataset', type=str, default='isic2016')
16 |
17 | parser.add_argument('--arch', type=str, default='BAT')
18 | parser.add_argument('--net_layer', type=int, default=50)
19 | # pre-train
20 | parser.add_argument('--pre', type=int, default=0)
21 |
22 | # transformer
23 | parser.add_argument('--trans', type=int, default=1)
24 |
25 | # point constrain
26 | parser.add_argument('--point_pred', type=int, default=1)
27 | parser.add_argument('--ppl', type=int, default=6)
28 |
29 | # cross-scale framework
30 | parser.add_argument('--cross', type=int, default=0)
31 |
32 | parse_config = parser.parse_args()
33 | print(parse_config)
34 | os.environ['CUDA_VISIBLE_DEVICES'] = parse_config.gpu
35 |
36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37 |
38 | if parse_config.dataset == 'isic2018':
39 | from dataset.isic2018 import norm01, myDataset
40 | dataset = myDataset(parse_config.fold, 'valid', aug=False)
41 | elif parse_config.dataset == 'isic2016':
42 | from dataset.isic2016 import norm01, myDataset
43 | dataset = myDataset('test', aug=False)
44 |
45 | if parse_config.arch == 'BAT':
46 | if parse_config.trans == 1:
47 | from Ours.Base_transformer import BAT
48 | model = BAT(1, parse_config.net_layer, parse_config.point_pred,
49 | parse_config.ppl).cuda()
50 | else:
51 | from Ours.base import DeepLabV3
52 | model = DeepLabV3(1, parse_config.net_layer).cuda()
53 |
54 | dir_path = os.path.dirname(
55 | os.path.abspath(__file__)) + "/logs/{}/{}/fold_{}/".format(
56 | parse_config.dataset, parse_config.log_name, parse_config.fold)
57 |
58 | from src.utils import load_model
59 |
60 | model = load_model(model, dir_path + 'model/best.pkl')
61 |
62 | # logging
63 | txt_path = os.path.join(dir_path, 'parameter.txt')
64 | logging.basicConfig(filename=txt_path,
65 | level=logging.INFO,
66 | format='[%(asctime)s.%(msecs)03d] %(message)s',
67 | datefmt='%H:%M:%S')
68 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
69 | test_loader = torch.utils.data.DataLoader(dataset,
70 | batch_size=8,
71 | pin_memory=True,
72 | drop_last=False,
73 | shuffle=False)
74 |
75 |
76 | def test():
77 | model.eval()
78 | num = 0
79 |
80 | dice_value = 0
81 | jc_value = 0
82 | hd95_value = 0
83 | assd_value = 0
84 |
85 | from tqdm import tqdm
86 | labels = []
87 | pres = []
88 | for batch_idx, batch_data in tqdm(enumerate(test_loader)):
89 | data = batch_data['image'].to(device).float()
90 | label = batch_data['label'].to(device).float()
91 | with torch.no_grad():
92 | if parse_config.arch == 'transfuse':
93 | _, _, output = model(data)
94 | elif parse_config.point_pred == 0:
95 | output = model(data)
96 | elif parse_config.point_pred == 1:
97 | output, _ = model(data)
98 | output = torch.sigmoid(output)
99 | output = output.cpu().numpy() > 0.5
100 | label = label.cpu().numpy()
101 | assert (output.shape == label.shape)
102 | labels.append(label)
103 | pres.append(output)
104 | labels = np.concatenate(labels, axis=0)
105 | pres = np.concatenate(pres, axis=0)
106 | print(labels.shape, pres.shape)
107 | for _id in range(labels.shape[0]):
108 | dice_ave = dc(labels[_id], pres[_id])
109 | jc_ave = jc(labels[_id], pres[_id])
110 | try:
111 | hd95_ave = hd95(labels[_id], pres[_id])
112 | assd_ave = assd(labels[_id], pres[_id])
113 | except RuntimeError:
114 | num += 1
115 | hd95_ave = 0
116 | assd_ave = 0
117 |
118 | dice_value += dice_ave
119 | jc_value += jc_ave
120 | hd95_value += hd95_ave
121 | assd_value += assd_ave
122 |
123 | dice_average = dice_value / (labels.shape[0] - num)
124 | jc_average = jc_value / (labels.shape[0] - num)
125 | hd95_average = hd95_value / (labels.shape[0] - num)
126 | assd_average = assd_value / (labels.shape[0] - num)
127 |
128 | logging.info('Dice value of test dataset : %f' % (dice_average))
129 | logging.info('Jc value of test dataset : %f' % (jc_average))
130 | logging.info('Hd95 value of test dataset : %f' % (hd95_average))
131 | logging.info('Assd value of test dataset : %f' % (assd_average))
132 |
133 | print("Average dice value of evaluation dataset = ", dice_average)
134 | print("Average jc value of evaluation dataset = ", jc_average)
135 | print("Average hd95 value of evaluation dataset = ", hd95_average)
136 | print("Average assd value of evaluation dataset = ", assd_average)
137 | return dice_average
138 |
139 |
140 | if __name__ == '__main__':
141 | test()
--------------------------------------------------------------------------------
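The per-sample scores above come straight from medpy.metric.binary, applied to one (H, W) boolean slice at a time. A toy check of the two overlap metrics on synthetic masks:

    import numpy as np
    from medpy.metric.binary import dc, jc

    pred = np.zeros((64, 64), dtype=bool)
    pred[16:48, 16:48] = True
    gt = np.zeros((64, 64), dtype=bool)
    gt[20:52, 20:52] = True
    print(dc(pred, gt), jc(pred, gt))      # Dice ~0.766, Jaccard ~0.620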