├── CircConv.py ├── README.md ├── dataset.py ├── generate_data.py ├── generate_edge_img.py ├── loss_function.py ├── models.py ├── test.py ├── train.py └── utils.py /CircConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class point_att(nn.Module): 5 | 6 | def __init__(self): 7 | super(point_att, self).__init__() 8 | 9 | self.gamma = nn.Parameter(torch.zeros(1)) 10 | self.softmax = nn.Softmax(dim=-1) 11 | 12 | def forward(self, snake_feature): 13 | 14 | proj_query = snake_feature.permute(0, 2, 1) 15 | energy = torch.bmm(proj_query, snake_feature) 16 | n_energy = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) 17 | energy = self.softmax(energy - n_energy) 18 | 19 | proj_value = snake_feature.permute(0, 2, 1) 20 | 21 | out = torch.bmm(energy, proj_value) 22 | out = out.permute(0, 2, 1) 23 | 24 | out = self.gamma*out + snake_feature 25 | return out 26 | 27 | class DilatedCirConv(nn.Module): 28 | 29 | def __init__(self, state_dim, out_state_dim, n_adj=2, dilation=1): 30 | super(DilatedCirConv, self).__init__() 31 | 32 | self.n_adj = n_adj 33 | self.dilation = dilation 34 | self.circconv = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj * 2 + 1, dilation=self.dilation) 35 | 36 | def forward(self, x): 37 | x = torch.cat([x[..., -self.n_adj * self.dilation:], x, x[..., :self.n_adj * self.dilation]], dim=2) 38 | return self.circconv(x) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | 43 | def __init__(self, state_dim, out_state_dim, n_adj=2, dilation=1): 44 | super(BasicBlock, self).__init__() 45 | 46 | self.circblock = DilatedCirConv(state_dim, out_state_dim, n_adj, dilation) 47 | self.circrelu = nn.ReLU(inplace=True) 48 | self.circnorm = nn.BatchNorm1d(out_state_dim) 49 | 50 | def forward(self, x): 51 | x = self.circblock(x) 52 | x = self.circrelu(x) 53 | x = self.circnorm(x) 54 | 55 | return x 56 | 57 | class AttSnake(nn.Module): 58 | 59 | def __init__(self, n_adj): 60 | super(AttSnake, self).__init__() 61 | 62 | self.head = BasicBlock(130, 64, n_adj=n_adj) 63 | self.res_layer_num = 7 64 | dilation = [1, 1, 1, 2, 2, 4, 4] 65 | for i in range(self.res_layer_num): 66 | circconv = BasicBlock(64, 64, n_adj=n_adj, dilation=dilation[i]) 67 | self.__setattr__('circconv'+str(i), circconv) 68 | 69 | self.fusion = nn.Conv1d(512, 128, 1) 70 | 71 | self.att_point = point_att() 72 | 73 | def forward(self, app_features): 74 | 75 | states = [] 76 | 77 | x = self.head(app_features) 78 | states.append(x) 79 | for i in range(self.res_layer_num): 80 | x = self.__getattr__('circconv'+str(i))(x) + x 81 | states.append(x) 82 | 83 | state = torch.cat(states, dim=1) 84 | snake_feature = self.att_point(self.fusion(state)) 85 | #snake_feature = self.fusion(state) 86 | 87 | return snake_feature -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automated Segmentation of Prohibited Items in X-ray Baggage Images Using Dense De-overlap Attention Snake 2 | 3 | Dataset and Code are updated! 
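A quick shape check for the contour blocks in `CircConv.py` — a minimal sketch, not part of the training pipeline (the 130 input channels are the repo's 128 appearance features plus 2 normalized coordinates; contours have 60 points):

```python
import torch
from CircConv import DilatedCirConv, AttSnake

x = torch.randn(2, 130, 60)                    # (batch, channels, contour points)

conv = DilatedCirConv(130, 64, n_adj=2, dilation=2)
print(conv(x).shape)                           # torch.Size([2, 64, 60]) -- circular padding keeps the length

snake = AttSnake(n_adj=2)
print(snake(x).shape)                          # torch.Size([2, 128, 60]) -- fused, attention-weighted features
```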
4 | 
5 | [[Paper]](https://ieeexplore.ieee.org/document/9772992) [[Dataset]](https://pan.baidu.com/s/11jMmECsjvW49N1NwLb8iIg?pwd=vnyw)
6 | 
7 | :smile: If you find this dataset useful for your research, please cite
8 | 
9 | ```bibtex
10 | @ARTICLE{9772992,
11 |   author={Ma, Bowen and Jia, Tong and Su, Min and Jia, Xiaodong and Chen, Dongyue and Zhang, Yichun},
12 |   journal={IEEE Transactions on Multimedia},
13 |   title={Automated Segmentation of Prohibited Items in X-ray Baggage Images Using Dense De-overlap Attention Snake},
14 |   year={2022},
15 |   volume={},
16 |   number={},
17 |   pages={1-1},
18 |   doi={10.1109/TMM.2022.3174339}}
19 | ```
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | from PIL import Image
4 | import torch
5 | from torch.utils.data import Dataset
6 | from utils import get_ic
7 | import torchvision.transforms as transforms
8 | 
9 | 
10 | prohibited_item_classes = {'Gun': 0, 'Knife': 1, 'Wrench': 2, 'Pliers': 3, 'Scissors': 4, 'Lighter': 5, 'Battery': 6,
11 |                            'Bat': 7, 'Razor_blade': 8, 'Saw_blade': 9, 'Fireworks': 10, 'Hammer': 11,
12 |                            'Screwdriver': 12, 'Dart': 13, 'Pressure_vessel': 14}
13 | #cityscapes_classes = {'person': 0, 'car': 1, 'truck': 2, 'bicycle': 3, 'motorcycle': 4, 'rider': 5,
14 | #                      'bus': 6, 'train': 7}
15 | 
16 | trans = transforms.ToTensor()
17 | normalize = transforms.Normalize(mean=[0.838, 0.855, 0.784],
18 |                                  std=[0.268, 0.225, 0.291])
19 | 
20 | class data_loader(Dataset):
21 |     def __init__(self, split):
22 | 
23 |         self.split = split
24 |         assert self.split in {'trainset', 'valset', 'testset'}
25 | 
26 |         self.dataset_size = len(glob.glob('{}/*.json'.format(split)))
27 | 
28 |     def __getitem__(self, ind):
29 |         img_path = '{}/{}.png'.format(self.split, ind)
30 |         edge_path = '{}/{}b.png'.format(self.split, ind)
31 |         gt_path = '{}/{}.json'.format(self.split, ind)
32 | 
33 |         # Open the JSON file where the ground truth is stored
34 |         with open(gt_path, 'r', encoding='utf8', errors='ignore') as j:
35 |             gt = json.load(j)
36 |         label_ind = torch.LongTensor([prohibited_item_classes[gt['label']]])
37 |         gt_ellipse = torch.FloatTensor(gt['polygon'])
38 | 
39 |         img = Image.open(img_path)
40 |         img_edge = Image.open(edge_path)
41 | 
42 |         # Get the initial contour
43 |         input_ellipse = get_ic(scale=4.)
44 | 
45 |         # PyTorch transformation pipeline for the image (totensor, normalizing, etc.)
46 |         img = normalize(trans(img))
47 |         img_edge = trans(img_edge)
48 | 
49 |         return img, img_edge, input_ellipse, label_ind, gt_ellipse
50 | 
51 |     def __len__(self):
52 | 
53 |         return self.dataset_size
54 | 
55 | class data_loader_test(data_loader):
56 |     def __init__(self, split):
57 | 
58 |         self.split = split
59 |         self.dataset_size = len(glob.glob('{}/*.json'.format(split)))
60 | 
61 |     def __getitem__(self, ind):
62 |         img_path = '{}/{}.png'.format(self.split, ind)
63 |         gt_path = '{}/{}.json'.format(self.split, ind)
64 | 
65 |         # Open the JSON file where the ground truth is stored
66 |         with open(gt_path, 'r', encoding='utf8', errors='ignore') as j:
67 |             gt = json.load(j)
68 |         label_ind = torch.LongTensor([prohibited_item_classes[gt['label']]])
69 |         gt_ellipse = torch.FloatTensor(gt['polygon'])
70 | 
71 |         img = Image.open(img_path)
72 | 
73 |         # Get the initial contour
74 |         input_ellipse = get_ic(scale=4.)
75 | 
76 |         # PyTorch transformation pipeline for the image (totensor, normalizing, etc.)
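# Shape check (a sketch; assumes the '{split}/{i}.png', '{split}/{i}b.png' and
# '{split}/{i}.json' files produced by generate_data.py / generate_edge_img.py):
#   loader = torch.utils.data.DataLoader(data_loader('trainset'), batch_size=4)
#   img, img_edge, input_ellipse, label_ind, gt_ellipse = next(iter(loader))
#   # img: (4, 3, 224, 224), img_edge: (4, 1, 56, 56),
#   # input_ellipse / gt_ellipse: (4, 60, 2), label_ind: (4, 1)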
77 | img = normalize(trans(img)) 78 | 79 | return img, input_ellipse, label_ind, gt_ellipse 80 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | def main(): 8 | cityscapes_classes = ['person', 'car', 'truck', 'bicycle', 'motorcycle', 'rider', 9 | 'bus', 'train'] 10 | prohibited_item_classes = ['Gun', 'Knife', 'Wrench', 'Pliers', 'Scissors', 'Lighter', 'Battery', 'Bat', 'Razor_blade', 11 | 'Saw_blade', 'Fireworks', 'Hammer', 'Screwdriver', 'Dart', 'Pressure_vessel'] 12 | files = glob.glob('./train_set/*.png') 13 | count = 0 14 | for file in files: 15 | #label_path = file[:-15] + 'gtFine_polygons.json' 16 | label_path = file[0:-3] + 'json' 17 | with open(label_path, 'r', encoding='utf8', errors='ignore') as j: 18 | label = json.load(j) 19 | Img = Image.open(file) 20 | w = Img.size[0] 21 | h = Img.size[1] 22 | items = label['objects'] 23 | for single_prohibited_item in items: 24 | if single_prohibited_item['label'] in prohibited_item_classes: 25 | # image process 26 | min_b = np.min(np.array(single_prohibited_item['polygon']), axis=0) 27 | max_b = np.max(np.array(single_prohibited_item['polygon']), axis=0) 28 | single_prohibited_item_h = max_b[1] - min_b[1] 29 | single_prohibited_item_w = max_b[0] - min_b[0] 30 | h_extend = int(round(0.2 * single_prohibited_item_h)) 31 | w_extend = int(round(0.2 * single_prohibited_item_w)) 32 | min_w = np.maximum(0, min_b[0] - w_extend) 33 | min_h = np.maximum(0, min_b[1] - h_extend) 34 | max_w = np.minimum(w, max_b[0] + w_extend) 35 | max_h = np.minimum(h, max_b[1] + h_extend) 36 | single_prohibited_item_new_w = max_w - min_w 37 | single_prohibited_item_new_h = max_h - min_h 38 | 39 | scale_w = 224.0 / single_prohibited_item_new_w 40 | scale_h = 224.0 / single_prohibited_item_new_h 41 | new_single_prohibited_item = Img.crop(box=(min_w, min_h, max_w, max_h)) 42 | new_single_prohibited_item = new_single_prohibited_item.resize((224, 224), Image.BILINEAR) 43 | new_single_prohibited_item.save('./trainset/' + str(count) + '.png') 44 | 45 | # label process 46 | dict = {} 47 | dict['label'] = single_prohibited_item['label'] 48 | dict['polygon'] = [] 49 | polygon_list = single_prohibited_item['polygon'][:] 50 | 51 | expend = True 52 | while(expend): 53 | if len(polygon_list) < 60: 54 | if len(polygon_list) < 31: 55 | n = len(polygon_list)-1 56 | for i in range(n): 57 | new_w = (polygon_list[i][0] + polygon_list[i + 1][0]) / 2 58 | new_h = (polygon_list[i][1] + polygon_list[i + 1][1]) / 2 59 | single_prohibited_item['polygon'].insert((2 * i) + 1, [new_w, new_h]) 60 | polygon_list = single_prohibited_item['polygon'][:] 61 | else: 62 | n = 60 - len(polygon_list) 63 | for i in range(n): 64 | new_w = (polygon_list[i][0] + polygon_list[i+1][0]) / 2 65 | new_h = (polygon_list[i][1] + polygon_list[i+1][1]) / 2 66 | single_prohibited_item['polygon'].insert((2*i)+1, [new_w, new_h]) 67 | polygon_list = single_prohibited_item['polygon'][:] 68 | else: 69 | if len(polygon_list) == 60: 70 | for point in single_prohibited_item['polygon']: 71 | index_w = (point[0] - min_w) * scale_w 72 | index_h = (point[1] - min_h) * scale_h 73 | index_w = np.maximum(0, np.minimum(223, index_w)) 74 | index_h = np.maximum(0, np.minimum(223, index_h)) 75 | dict['polygon'].append([index_w, index_h]) 76 | expend = False 77 | else: 78 | scale = len(polygon_list) * 1.0 / 60 79 | 
index_list = (np.arange(0, 60) * scale).astype(int) 80 | for point in np.array(single_prohibited_item['polygon'])[index_list]: 81 | index_w = (point[0] - min_w) * scale_w 82 | index_h = (point[1] - min_h) * scale_h 83 | index_w = np.maximum(0, np.minimum(223, index_w)) 84 | index_h = np.maximum(0, np.minimum(223, index_h)) 85 | dict['polygon'].append([index_w, index_h]) 86 | expend = False 87 | 88 | with open('./trainset/' + str(count) + '.json', 'w') as j: 89 | json.dump(dict, j) 90 | count += 1 91 | print(count) 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /generate_edge_img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import glob 4 | import json 5 | 6 | gt_files = glob.glob('./trainset/*.json') 7 | for file in gt_files: 8 | name = file[11:-5] 9 | #print(name) 10 | with open(file, 'r', encoding='utf8', errors='ignore') as j: 11 | gt = json.load(j) 12 | polygon = np.array(gt['polygon']) 13 | polygon /= 4.0 14 | bg = np.zeros((56, 56), dtype=np.uint8) 15 | 16 | mask = cv2.fillPoly(bg, np.int32([polygon]), 255) 17 | 18 | mask = mask > 128 19 | mask = np.asarray(mask, dtype=np.double) 20 | gx, gy = np.gradient(mask) 21 | 22 | boundary = gy * gy + gx * gx 23 | 24 | boundary[boundary != 0.0] = 255.0 25 | 26 | boundary = np.asarray(boundary, dtype=np.uint8) 27 | 28 | cv2.imwrite('./trainset/' + str(name) + 'b.png', boundary) -------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class loss(nn.Module): 9 | def __init__(self): 10 | super(loss, self).__init__() 11 | 12 | def smooth_l1(self, x): 13 | 14 | if torch.abs(x) < 1: 15 | y = 0.5 * (x ** 2) 16 | else: 17 | y = torch.abs(x) - 0.5 18 | 19 | return y 20 | 21 | def forward(self, cls_feature, label_ind, new_ellipse, gt_ellipse, edge_feature, img_edge): 22 | 23 | bs = new_ellipse.size(0) 24 | # loss_1 for classification 25 | metric_1 = nn.CrossEntropyLoss() 26 | l1 = metric_1(cls_feature, label_ind) 27 | 28 | # loss_2 for edge supervision 29 | l2 = F.binary_cross_entropy_with_logits(edge_feature, img_edge, reduction='sum') / bs 30 | 31 | # loss_3 for distance regression 32 | s = torch.pow(gt_ellipse - new_ellipse, 2) 33 | dist = torch.sqrt(s[:, :, 0] + s[:, :, 1] + 1e-10) 34 | l3 = 0. 
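# A vectorized equivalent of the per-vertex loops below (a sketch; dist is (bs, 60)
# and non-negative, so smooth-L1 reduces to where(dist < 1, 0.5 * dist**2, dist - 0.5)):
#   l3 = torch.where(dist < 1, 0.5 * dist ** 2, dist - 0.5).sum() / bs
# The explicit loops are kept for readability; both forms average the smooth-L1
# vertex distances over the batch.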
35 | for i in range(bs): 36 | p = dist[i] 37 | for j in range(60): 38 | y = self.smooth_l1(p[j]) 39 | l3 += y 40 | l3 = l3 / bs 41 | 42 | return l1, l2, l3 43 | 44 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | from CircConv import AttSnake 5 | from thop import profile 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | class DDoM(nn.Module): 10 | 11 | def __init__(self): 12 | 13 | super(DDoM, self).__init__() 14 | 15 | self.fc0 = nn.Linear(15, 1024) 16 | 17 | self.sig = nn.Sigmoid() 18 | 19 | self.fc1 = nn.Linear(1024, 2048) 20 | 21 | self.avg_pool2 = nn.AvgPool2d(7, stride=1) 22 | self.fc2 = nn.Linear(3072, 1024) 23 | 24 | self.avg_pool3 = nn.AvgPool2d(14, stride=1) 25 | self.fc3 = nn.Linear(2048, 512) 26 | 27 | self.avg_pool4 = nn.AvgPool2d(28, stride=1) 28 | self.fc4 = nn.Linear(1536, 256) 29 | 30 | def init_weights(self): 31 | self.fc0.weight.data.normal_(0, 0.01) 32 | self.fc0.bias.data.zero_() 33 | self.fc1.weight.data.normal_(0, 0.01) 34 | self.fc1.bias.data.zero_() 35 | self.fc2.weight.data.normal_(0, 0.01) 36 | self.fc2.bias.data.zero_() 37 | self.fc3.weight.data.normal_(0, 0.01) 38 | self.fc3.bias.data.zero_() 39 | self.fc4.weight.data.normal_(0, 0.01) 40 | self.fc4.bias.data.zero_() 41 | 42 | def forward(self, cls_feature, o4, o3, o2, o1): 43 | 44 | bs = o4.size(0) 45 | 46 | gs = self.fc0(cls_feature) 47 | 48 | gs1 = torch.bernoulli(self.sig(self.fc1(gs))) 49 | o4 = o4.view(bs, 2048, -1) 50 | gs1 = gs1.unsqueeze(2).expand_as(o4) 51 | doo4 = o4.mul(gs1).view(bs, 2048, 7, 7) 52 | 53 | gs2 = self.avg_pool2(doo4).view(bs, -1) 54 | gs2 = torch.bernoulli(self.sig(self.fc2(torch.cat([gs, gs2], dim=1)))) 55 | o3 = o3.view(bs, 1024, -1) 56 | gs2 = gs2.unsqueeze(2).expand_as(o3) 57 | doo3 = o3.mul(gs2).view(bs, 1024, 14, 14) 58 | 59 | gs3 = self.avg_pool3(doo3).view(bs, -1) 60 | gs3 = torch.bernoulli(self.sig(self.fc3(torch.cat([gs, gs3], dim=1)))) 61 | o2 = o2.view(bs, 512, -1) 62 | gs3 = gs3.unsqueeze(2).expand_as(o2) 63 | doo2 = o2.mul(gs3).view(bs, 512, 28, 28) 64 | 65 | gs4 = self.avg_pool4(doo2).view(bs, -1) 66 | gs4 = torch.bernoulli(self.sig(self.fc4(torch.cat([gs, gs4], dim=1)))) 67 | o1 = o1.view(bs, 256, -1) 68 | gs4 = gs4.unsqueeze(2).expand_as(o1) 69 | doo1 = o1.mul(gs4).view(bs, 256, 56, 56) 70 | 71 | return doo4, doo3, doo2, doo1 72 | 73 | 74 | class DDoAS(nn.Module): 75 | 76 | def __init__(self, num_classes): 77 | 78 | ''' 79 | Dense De-overlap Attention Snake for real-time prohibited item segmentation 80 | ''' 81 | 82 | super(DDoAS, self).__init__() 83 | 84 | # load pretrained model (we take ResNet-50 as an example) 85 | resnet = torchvision.models.resnet50(pretrained=False) 86 | resnet.load_state_dict(torch.load('./checkpoint/resnet50-19c8e357.pth')) 87 | 88 | for item in resnet.children(): 89 | if isinstance(item, nn.BatchNorm2d): 90 | item.affine = False 91 | 92 | self.base_features = nn.Sequential(resnet.conv1, 93 | resnet.bn1, 94 | resnet.relu, 95 | resnet.maxpool) 96 | self.res1 = resnet.layer1 97 | self.res2 = resnet.layer2 98 | self.res3 = resnet.layer3 99 | self.res4 = resnet.layer4 100 | 101 | # set Dense De-overlap Module 102 | 103 | self.ddom = DDoM() 104 | 105 | # conv layer for enhance features 106 | self.enconv4 = nn.Sequential(nn.Conv2d(2048, 256, 7, 1, 3), nn.ReLU(inplace=True), 107 | nn.Conv2d(256, 256, 7, 1, 3), 
nn.ReLU(inplace=True),
108 |                                      nn.Conv2d(256, 256, 7, 1, 3), nn.ReLU(inplace=True),
109 |                                      nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True))
110 | 
111 |         self.enconv3 = nn.Sequential(nn.Conv2d(1024, 256, 5, 1, 2), nn.ReLU(inplace=True),
112 |                                      nn.Conv2d(256, 256, 5, 1, 2), nn.ReLU(inplace=True),
113 |                                      nn.Conv2d(256, 256, 5, 1, 2), nn.ReLU(inplace=True),
114 |                                      nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True))
115 | 
116 |         self.enconv2 = nn.Sequential(nn.Conv2d(512, 256, 5, 1, 2), nn.ReLU(inplace=True),
117 |                                      nn.Conv2d(256, 256, 5, 1, 2), nn.ReLU(inplace=True),
118 |                                      nn.Conv2d(256, 256, 5, 1, 2), nn.ReLU(inplace=True),
119 |                                      nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
120 | 
121 |         self.enconv1 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
122 |                                      nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
123 |                                      nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True))
124 | 
125 |         self.enff = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), nn.ReLU(inplace=True),
126 |                                   nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
127 |                                   nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True))
128 | 
129 |         # linear layer for classification
130 |         self.avg_pool = nn.AvgPool2d(7, stride=1)
131 |         self.cls_spv = nn.Linear(2048, num_classes)
132 | 
133 |         # edge supervision
134 |         self.edge_spv = nn.Conv2d(128, 1, 3, 1, 1)
135 | 
136 |         # generate weight map
137 |         self.att_F = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
138 |                                    nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU(inplace=True),
139 |                                    nn.Conv2d(64, 1, 3, 1, 1), nn.ReLU(inplace=True),
140 |                                    nn.Sigmoid())
141 | 
142 |         # set Attention Deforming Module
143 |         self.attsnake = AttSnake(n_adj=2)
144 | 
145 |         self.prediction = nn.Sequential(nn.Conv1d(128, 64, 1),
146 |                                         nn.ReLU(inplace=True),
147 |                                         nn.Conv1d(64, 64, 1),
148 |                                         nn.ReLU(inplace=True),
149 |                                         nn.Conv1d(64, 2, 1))
150 | 
151 |         # Initializes some parameters for easier convergence
152 |         self.init_weights()
153 | 
154 |     def get_bilinear_interpolation(self, ff, Fatt, img_poly):
155 | 
156 |         '''
157 |         Extract per-vertex features using bilinear interpolation (grid_sample)
158 |         '''
159 | 
160 |         img_poly = img_poly.clone()
161 |         img_poly[..., 0] = img_poly[..., 0] / 28. - 1  # map the 56x56 feature grid onto grid_sample's [-1, 1] range
162 |         img_poly[..., 1] = img_poly[..., 1] / 28. - 1
163 | 
164 |         bs = ff.size(0)
165 |         gcn_feature = torch.zeros([bs, ff.size(1), img_poly.size(1)]).to(device)
166 |         pw = torch.zeros([bs, 1, img_poly.size(1)]).to(device)
167 | 
168 |         for i in range(bs):
169 |             grid = img_poly[i:i + 1].unsqueeze(1)
170 |             bilinear_feature = torch.nn.functional.grid_sample(ff[i:i + 1], grid=grid, align_corners=True)[0].permute(1, 0, 2)
171 |             gcn_feature[i] = bilinear_feature
172 | 
173 |             fatt_feature = torch.nn.functional.grid_sample(Fatt[i:i + 1], grid=grid, align_corners=True)[0].permute(1, 0, 2)
174 |             pw[i] = fatt_feature
175 | 
176 | 
177 |         #point_center = (torch.min(img_poly, dim=1)[0] + torch.max(img_poly, dim=1)[0]) * 0.5
178 |         #point_center = point_center[:, None]
179 |         #ct_feature = torch.zeros([batch_size, concat_feature.size(1), point_center.size(1)]).to(device)
180 | 
181 |         #for j in range(batch_size):
182 |         #    grid = point_center[j:j + 1].unsqueeze(1)
183 |         #    ct_bilinear_feature = torch.nn.functional.grid_sample(concat_feature[j:j + 1], grid=grid)[0].permute(1, 0, 2)
184 |         #    ct_feature[j] = ct_bilinear_feature
185 | 
186 |         #fuse_feature = torch.cat([gcn_feature, ct_feature.expand_as(gcn_feature)], dim=1)
187 |         #fuse_feature = self.fuse(fuse_feature)
188 | 
189 |         return gcn_feature, pw
190 | 
191 |     def normalize_poly(self, img_poly):
192 | 
193 |         mi = torch.min(img_poly, dim=1, keepdim=True)[0].expand_as(img_poly)
194 |         ma = torch.max(img_poly, dim=1, keepdim=True)[0].expand_as(img_poly)
195 | 
196 |         new_poly = (img_poly - mi) / (ma - mi)
197 | 
198 |         return new_poly, mi, ma
199 | 
200 | 
201 |     def init_weights(self):
202 |         self.cls_spv.weight.data.normal_(0, 0.01)
203 |         self.cls_spv.bias.data.zero_()
204 |         for m in self.modules():
205 |             if isinstance(m, nn.Conv1d):
206 |                 m.weight.data.normal_(0.0, 0.01)
207 |                 if m.bias is not None:
208 |                     nn.init.constant_(m.bias, 0)
209 |         for name, param in self.named_parameters():
210 |             if 'fc' in name and 'weight' in name:
211 |                 nn.init.normal_(param, 0.0, 0.01)
212 |             if 'fc' in name and 'bias' in name:
213 |                 nn.init.constant_(param, 0.0)
214 |             if any(k in name for k in ('enconv', 'enff', 'edge_spv')) and 'weight' in name:
215 |                 nn.init.normal_(param, 0.0, 0.01)
216 |             if any(k in name for k in ('enconv', 'enff', 'edge_spv')) and 'bias' in name:
217 |                 nn.init.constant_(param, 0.0)
218 | 
219 | 
220 |     def forward(self, image, img_poly):
221 | 
222 |         image = self.base_features(image)
223 |         o1 = self.res1(image)
224 |         o2 = self.res2(o1)
225 |         o3 = self.res3(o2)
226 |         o4 = self.res4(o3)
227 | 
228 |         # classification branches
229 |         cls_feature = self.avg_pool(o4).view(o4.size(0), -1)
230 |         cls_feature = self.cls_spv(cls_feature)
231 | 
232 |         # DDoM
233 |         doo4, doo3, doo2, doo1 = self.ddom(cls_feature, o4, o3, o2, o1)
234 | 
235 |         # O2OFM
236 |         doo1 = self.enconv1(doo1)
237 |         fo4 = self.enconv4(doo4) + doo1
238 |         fo3 = self.enconv3(doo3) + doo1
239 |         fo2 = self.enconv2(doo2) + doo1
240 |         ff = self.enff(fo2 + fo3 + fo4)
241 | 
242 |         # edge supervision branches
243 |         edge_feature = self.edge_spv(ff)
244 | 
245 |         Fatt = self.att_F(ff)
246 | 
247 |         # Deformation branches
248 |         bilinear_feature, pw = self.get_bilinear_interpolation(ff, Fatt, img_poly)  # (bs, 128, point_num) (bs, 1, point_num)
249 |         Fatt_bf = bilinear_feature.mul(pw)
250 |         img_poly, mi, ma = self.normalize_poly(img_poly)
251 |         app_features = torch.cat([Fatt_bf, img_poly.permute(0, 2, 1)], dim=1)
252 |         #app_features = torch.cat([bilinear_feature, img_poly.permute(0, 2, 1)], dim=1)
253 | 
254 | 
255 |         snake_feature = self.attsnake(app_features)
256 | 
257 |         predict_offset =
self.prediction(snake_feature).permute(0, 2, 1) 258 | 259 | predict_offset = predict_offset * (ma - mi) 260 | 261 | return cls_feature, edge_feature, predict_offset 262 | 263 | if __name__ == '__main__': 264 | 265 | from utils import * 266 | net = DDoAS(15) 267 | input_ellipse = get_ic(scale=4.) 268 | flops, params = profile(net, (torch.randn(1, 3, 224, 224), input_ellipse.unsqueeze(0),)) 269 | print(flops) 270 | print(params) 271 | 272 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | import torch.utils 3 | import torchvision.transforms as transforms 4 | from dataset import * 5 | from utils import * 6 | from loss_function import * 7 | 8 | prohibited_item_classes = {'Gun': 0, 'Knife': 1, 'Wrench': 2, 'Pliers': 3, 'Scissors': 4, 'Lighter': 5, 'Battery': 6, 9 | 'Bat': 7, 'Razor_blade': 8, 'Saw_blade': 9, 'Fireworks': 10, 'Hammer': 11, 10 | 'Screwdriver': 12, 'Dart': 13, 'Pressure_vessel': 14} 11 | cityscapes_classes = {'person': 0, 'car': 1, 'truck': 2, 'bicycle': 3, 'motorcycle': 4, 'rider': 5, 12 | 'bus': 6, 'train': 7} 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | cudnn.benchmark = True 16 | checkpoint = './checkpoint/BEST_checkpoint.pth.tar' 17 | batch_size = 1 18 | 19 | # load model 20 | checkpoint = torch.load(checkpoint) 21 | model = checkpoint['model'] 22 | model = model.to(device) 23 | model.eval() 24 | 25 | # normalization transform 26 | trans = transforms.ToTensor() 27 | normalize = transforms.Normalize(mean=[0.838, 0.855, 0.784], 28 | std=[0.268, 0.225, 0.291]) 29 | 30 | def test(): 31 | 32 | test_loader = torch.utils.data.DataLoader(data_loader_test('testset'), batch_size=batch_size, shuffle=True) 33 | gun_iou = list() 34 | knife_iou = list() 35 | wrench_iou = list() 36 | pliers_iou = list() 37 | scissors_iou = list() 38 | lighter_iou = list() 39 | battery_iou = list() 40 | bat_iou = list() 41 | razor_blade_iou = list() 42 | saw_blade_iou = list() 43 | fireworks_iou = list() 44 | hammer_iou = list() 45 | screwdriver_iou = list() 46 | dart_iou = list() 47 | pressure_vessel_iou = list() 48 | #person_iou = list() 49 | #car_iou = list() 50 | #truck_iou = list() 51 | #bicycle_iou = list() 52 | #motorcycle_iou = list() 53 | #rider_iou = list() 54 | #bus_iou = list() 55 | #train_iou = list() 56 | 57 | 58 | for i, (img, input_ellipse, label_ind, gt_ellipse) in enumerate(test_loader): 59 | 60 | # move to GPU, if available 61 | img = img.to(device) 62 | input_ellipse = input_ellipse.to(device) 63 | label_ind = label_ind.to(device) 64 | gt_ellipse = gt_ellipse.to(device) 65 | 66 | cls_feature, edge_feature, predict_offset = model(img, input_ellipse) 67 | new_ellipse = input_ellipse + predict_offset 68 | new_ellipse[..., 0] = new_ellipse[..., 0] * 4. 69 | new_ellipse[..., 1] = new_ellipse[..., 1] * 4. 
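# The snake regresses offsets on the 56x56 feature grid (224 / 4, matching
# get_ic(scale=4.)), so the refined contour is scaled by 4 here to return to
# 224x224 image coordinates before the polygon IoU is computed.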
70 | 71 | # calculate IoU 72 | IoU = accuracy(new_ellipse, gt_ellipse) 73 | if label_ind == 0: 74 | gun_iou.append(IoU) 75 | elif label_ind == 1: 76 | knife_iou.append(IoU) 77 | elif label_ind == 2: 78 | wrench_iou.append(IoU) 79 | elif label_ind == 3: 80 | pliers_iou.append(IoU) 81 | elif label_ind == 4: 82 | scissors_iou.append(IoU) 83 | elif label_ind == 5: 84 | lighter_iou.append(IoU) 85 | elif label_ind == 6: 86 | battery_iou.append(IoU) 87 | elif label_ind == 7: 88 | bat_iou.append(IoU) 89 | elif label_ind == 8: 90 | razor_blade_iou.append(IoU) 91 | elif label_ind == 9: 92 | saw_blade_iou.append(IoU) 93 | elif label_ind == 10: 94 | fireworks_iou.append(IoU) 95 | elif label_ind == 11: 96 | hammer_iou.append(IoU) 97 | elif label_ind == 12: 98 | screwdriver_iou.append(IoU) 99 | elif label_ind == 13: 100 | dart_iou.append(IoU) 101 | elif label_ind == 14: 102 | pressure_vessel_iou.append(IoU) 103 | #if label_ind == 0: 104 | # person_iou.append(IoU) 105 | #elif label_ind == 1: 106 | # car_iou.append(IoU) 107 | #elif label_ind == 2: 108 | # truck_iou.append(IoU) 109 | #elif label_ind == 3: 110 | # bicycle_iou.append(IoU) 111 | #elif label_ind == 4: 112 | # motorcycle_iou.append(IoU) 113 | #elif label_ind == 5: 114 | # rider_iou.append(IoU) 115 | #elif label_ind == 6: 116 | # bus_iou.append(IoU) 117 | #elif label_ind == 7: 118 | # train_iou.append(IoU) 119 | 120 | 121 | print('Gun_IoU: {:.3f}\t' 122 | 'Knife_IoU: {:.3f}\t' 123 | 'Wrench_IoU: {:.3f}\t' 124 | 'Pliers_IoU: {:.3f}\t' 125 | 'Scissors_IoU: {:.3f}\t' 126 | 'Lighter_IoU: {:.3f}\t' 127 | 'Battery_IoU: {:.3f}\t' 128 | 'Bat_IoU: {:.3f}\t' 129 | 'Razor_blade_IoU: {:.3f}\t' 130 | 'Saw_blade_IoU: {:.3f}\t' 131 | 'Fireworks_IoU: {:.3f}\t' 132 | 'Hammer_IoU: {:.3f}\t' 133 | 'Screwdriver_IoU: {:.3f}\t' 134 | 'Dart_IoU: {:.3f}\t' 135 | 'Pressure_vessel_IoU: {:.3f}\t'.format(sum(gun_iou) / len(gun_iou), 136 | sum(knife_iou) / len(knife_iou), 137 | sum(wrench_iou) / len(wrench_iou), 138 | sum(pliers_iou) / len(pliers_iou), 139 | sum(scissors_iou) / len(scissors_iou), 140 | sum(lighter_iou) / len(lighter_iou), 141 | sum(battery_iou) / len(battery_iou), 142 | sum(bat_iou) / len(bat_iou), 143 | sum(razor_blade_iou) / len(razor_blade_iou), 144 | sum(saw_blade_iou) / len(saw_blade_iou), 145 | sum(fireworks_iou) / len(fireworks_iou), 146 | sum(hammer_iou) / len(hammer_iou), 147 | sum(screwdriver_iou) / len(screwdriver_iou), 148 | sum(dart_iou) / len(dart_iou), 149 | sum(pressure_vessel_iou) / len(pressure_vessel_iou))) 150 | #print('Person_IoU: {:.3f}\t' 151 | # 'Car_IoU: {:.3f}\t' 152 | # 'Truck_IoU: {:.3f}\t' 153 | # 'Bicycle_IoU: {:.3f}\t' 154 | # 'Motorcycle_IoU: {:.3f}\t' 155 | # 'Rider_IoU: {:.3f}\t' 156 | # 'Bus_IoU: {:.3f}\t' 157 | # 'Train_IoU: {:.3f}\t'.format(sum(person_iou) / len(person_iou), 158 | # sum(car_iou) / len(car_iou), 159 | # sum(truck_iou) / len(truck_iou), 160 | # sum(bicycle_iou) / len(bicycle_iou), 161 | # sum(motorcycle_iou) / len(motorcycle_iou), 162 | # sum(rider_iou) / len(rider_iou), 163 | # sum(bus_iou) / len(bus_iou), 164 | # sum(train_iou) / len(train_iou))) 165 | if __name__ == '__main__': 166 | test() 167 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch.utils 2 | import torchvision.transforms as transforms 3 | from models import DDoAS 4 | from dataset import * 5 | from utils import * 6 | from loss_function import * 7 | import torch 8 | 9 | device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu")
10 | 
11 | # Training parameters
12 | start_epoch = 0
13 | epochs = 1000
14 | grad_clip = 5.
15 | learning_rate = 0.0001
16 | epochs_since_improvement = 0
17 | print_freq = 50  # 300
18 | best_accuracy = 0.
19 | checkpoint = None  # './checkpoint/checkpoint.pth.tar' # path to checkpoint
20 | batch_size = 32
21 | val_batch_size = 16
22 | 
23 | 
24 | def main():
25 | 
26 |     '''
27 |     Training and Validation
28 |     '''
29 | 
30 |     global best_accuracy, epochs_since_improvement, checkpoint, start_epoch
31 | 
32 |     # initialize / load checkpoint
33 |     if checkpoint is None:
34 |         model = DDoAS(num_classes=15)
35 |         optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
36 |                                      lr=learning_rate)
37 |     else:
38 |         checkpoint = torch.load(checkpoint)
39 |         start_epoch = checkpoint['epoch'] + 1
40 |         epochs_since_improvement = checkpoint['epochs_since_improvement']
41 |         model = checkpoint['model']
42 |         optimizer = checkpoint['optimizer']
43 |         best_accuracy = checkpoint['val_accuracy']
44 | 
45 | 
46 |     # move to GPU, if available
47 |     model = model.to(device)
48 | 
49 |     # Loss function
50 |     criterion = loss().to(device)
51 | 
52 |     # custom dataloaders
53 |     train_loader = torch.utils.data.DataLoader(data_loader('trainset'), batch_size=batch_size, shuffle=True)
54 | 
55 |     val_loader = torch.utils.data.DataLoader(data_loader('valset'), batch_size=val_batch_size, shuffle=True)
56 | 
57 |     # Epochs
58 |     for epoch in range(start_epoch, epochs):
59 | 
60 |         if epochs_since_improvement == 30:
61 |             break
62 |         if epochs_since_improvement > 0 and epochs_since_improvement % 10 == 0:
63 |             adjust_learning_rate(optimizer, 0.8)
64 | 
65 |         # One epoch's training
66 |         train(train_loader, model, criterion, optimizer, epoch)
67 | 
68 |         # One epoch's validation
69 |         val_accuracy = validate(val_loader, model, criterion)
70 | 
71 |         # Check if there was an improvement
72 |         is_best = val_accuracy > best_accuracy
73 |         best_accuracy = max(val_accuracy, best_accuracy)
74 |         if not is_best:
75 |             epochs_since_improvement += 1
76 |             print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement))
77 |         else:
78 |             epochs_since_improvement = 0
79 | 
80 |         # save checkpoint
81 |         save_checkpoint(epoch, epochs_since_improvement, model, optimizer, val_accuracy, is_best)
82 | 
83 | def train(train_loader, model, criterion, optimizer, epoch):
84 | 
85 |     '''
86 |     Performs one epoch's training
87 |     '''
88 | 
89 |     model.train()
90 | 
91 |     train_loss = list()
92 |     closs = list()
93 |     eloss = list()
94 |     ploss = list()
95 |     train_acc = list()
96 | 
97 |     for i, (img, img_edge, input_ellipse, label_ind, gt_ellipse) in enumerate(train_loader):
98 | 
99 |         # move to GPU, if available
100 |         img = img.to(device)
101 |         img_edge = img_edge.to(device)
102 |         input_ellipse = input_ellipse.to(device)
103 |         label_ind = label_ind.to(device)
104 |         gt_ellipse = gt_ellipse.to(device)
105 | 
106 |         # forward prop.
107 |         cls_feature, edge_feature, predict_offset = model(img, input_ellipse)  # (bs, class_num), (bs, point_num, 2)
108 | 
109 |         label_ind = label_ind.view(-1)  # (bs,)
110 | 
111 |         # update contour
112 |         new_ellipse = input_ellipse + predict_offset
113 |         new_ellipse[..., 0] = new_ellipse[..., 0] * 4.
114 |         new_ellipse[..., 1] = new_ellipse[..., 1] * 4.
115 | 
116 |         # calculate loss
117 |         cls_loss, edge_loss, point_reg_loss = criterion(cls_feature, label_ind, new_ellipse, gt_ellipse, edge_feature, img_edge)
118 | 
119 |         loss_value = cls_loss + 0.01 * edge_loss + point_reg_loss * 0.01
120 | 
121 |         # Back prop.
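# (zero_grad -> backward -> clip -> step; gradients are clamped element-wise
# to [-grad_clip, grad_clip] by clip_gradient() before the Adam update)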
122 |         optimizer.zero_grad()
123 |         loss_value.backward()
124 | 
125 |         train_loss.append(loss_value.item())
126 |         closs.append(cls_loss.item())
127 |         eloss.append(edge_loss.item() * 0.01)
128 |         ploss.append(point_reg_loss.item() * 0.01)
129 | 
130 |         # clip gradients
131 |         if grad_clip is not None:
132 |             clip_gradient(optimizer, grad_clip)
133 | 
134 |         # update weights
135 |         optimizer.step()
136 | 
137 |         # calculate accuracy
138 |         acc = accuracy(new_ellipse, gt_ellipse)
139 |         train_acc.append(acc)
140 | 
141 |         # print status
142 |         if i % print_freq == 0:
143 |             print('Epoch: [{}]/[{}/{}]\t'
144 |                   'Loss: {:.3f}\t'
145 |                   'cls_loss: {:.3f} - edge_loss: {:.3f} - point_reg_loss: {:.3f}\t'
146 |                   'Accuracy: {:.3f}'.format(epoch, i, len(train_loader),
147 |                                             sum(train_loss)/len(train_loss),
148 |                                             sum(closs)/len(closs),
149 |                                             sum(eloss)/len(eloss),
150 |                                             sum(ploss)/len(ploss),
151 |                                             sum(train_acc)/len(train_acc)))
152 | 
153 | def validate(val_loader, model, criterion):
154 | 
155 |     model.eval()
156 | 
157 |     val_closs = list()
158 |     val_eloss = list()
159 |     val_ploss = list()
160 |     val_loss = list()
161 |     val_acc = list()
162 |     with torch.no_grad():
163 |         for i, (img, img_edge, input_ellipse, label_ind, gt_ellipse) in enumerate(val_loader):
164 | 
165 |             # move to GPU, if available
166 |             img = img.to(device)
167 |             img_edge = img_edge.to(device)
168 |             input_ellipse = input_ellipse.to(device)
169 |             label_ind = label_ind.to(device)
170 |             gt_ellipse = gt_ellipse.to(device)
171 | 
172 |             # forward prop.
173 |             cls_feature, edge_feature, predict_offset = model(img, input_ellipse)
174 |             label_ind = label_ind.view(-1)  # (bs,)
175 | 
176 |             # update contour
177 |             new_ellipse = input_ellipse + predict_offset
178 |             new_ellipse[..., 0] = new_ellipse[..., 0] * 4.
179 |             new_ellipse[..., 1] = new_ellipse[..., 1] * 4.
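# Validation uses the same weighted objective as training
# (cls_loss + 0.01 * edge_loss + 0.01 * point_reg_loss), so the printed
# losses are directly comparable across the two phases.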
180 | 181 | # calculate loss 182 | cls_loss, edge_loss, point_reg_loss = criterion(cls_feature, label_ind, new_ellipse, gt_ellipse, edge_feature, img_edge) 183 | 184 | loss_value = cls_loss + 0.01 * edge_loss + point_reg_loss * 0.01 185 | 186 | # calculate accuracy 187 | acc = accuracy(new_ellipse, gt_ellipse) 188 | val_acc.append(acc) 189 | 190 | val_loss.append(loss_value.item()) 191 | val_closs.append(cls_loss.item()) 192 | val_eloss.append(edge_loss.item() * 0.01) 193 | val_ploss.append(point_reg_loss.item() * 0.01) 194 | 195 | val_accuracy = sum(val_acc) / len(val_acc) 196 | 197 | # print status 198 | print('Loss: {:.3f}\t' 199 | 'cls_loss: {:.3f} - edge_loss: {:.3f} - point_reg_loss: {:.3f}\t' 200 | 'Accuracy: {:.3f}'.format(sum(val_loss) / len(val_loss), 201 | sum(val_closs) / len(val_closs), 202 | sum(val_eloss) / len(val_eloss), 203 | sum(val_ploss) / len(val_ploss), 204 | val_accuracy)) 205 | print('--------------------------------------------------------') 206 | return val_accuracy 207 | 208 | if __name__ == '__main__': 209 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | import cv2 5 | 6 | def get_ellipse(box, point_num): 7 | point_array = np.zeros(shape=(point_num, 2), dtype=np.float32) 8 | for i in range(point_num): 9 | theta = 1.0 * i / point_num * 2 * np.pi 10 | x = np.cos(theta) 11 | y = -np.sin(theta) 12 | point_array[i, 0] = x 13 | point_array[i, 1] = y 14 | point_array /= 2 15 | point_array += 0.5 16 | w, h = box[2] - box[0], box[3] - box[1] 17 | point_array *= np.array([w, h]) 18 | point_array = point_array + np.array([box[0], box[1]]) 19 | return point_array 20 | 21 | def get_ic(scale=4.): 22 | 23 | point_array = get_ellipse((20, 40, 204, 184), 60) 24 | point_array = point_array.tolist() 25 | points = [point_array[0]] + list(reversed(point_array[1:])) 26 | points = torch.FloatTensor(points) 27 | points[..., 0] = points[..., 0] / scale 28 | points[..., 1] = points[..., 1] / scale 29 | 30 | return points 31 | 32 | def adjust_learning_rate(optimizer, shrink_factor): 33 | 34 | print("\nDecaying learning rate.") 35 | for param_group in optimizer.param_groups: 36 | param_group['lr'] = param_group['lr'] * shrink_factor 37 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 38 | 39 | def clip_gradient(optimizer, grad_clip): 40 | 41 | for group in optimizer.param_groups: 42 | for param in group['params']: 43 | if param.grad is not None: 44 | param.grad.data.clamp_(-grad_clip, grad_clip) 45 | 46 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, val_accuracy, is_best): 47 | 48 | state = {'epoch': epoch, 49 | 'epochs_since_improvement': epochs_since_improvement, 50 | 'model': model, 51 | 'optimizer': optimizer, 52 | 'val_accuracy': val_accuracy} 53 | 54 | filename = 'HR_checkpoint.pth.tar' 55 | torch.save(state, './checkpoint/' + filename) 56 | if is_best: 57 | torch.save(state, './checkpoint/BEST_' + filename) 58 | 59 | 60 | def accuracy(new_ellipse, gt_ellipse): 61 | 62 | IoU = 0. 
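# accuracy() rasterizes the predicted and ground-truth 60-point polygons onto
# 224x224 masks (iou_from_poly / draw_poly below) and reports the mean mask IoU
# over the batch.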
63 |     batch_size = new_ellipse.size(0)
64 |     for i in range(batch_size):
65 |         new_ei = new_ellipse[i].tolist()
66 |         gt_ei = gt_ellipse[i].tolist()
67 |         iou, _ = iou_from_poly(new_ei, gt_ei, 224, 224)
68 |         IoU += iou
69 |     acc = IoU / batch_size
70 | 
71 |     return acc
72 | 
73 | def iou_from_mask(pred, gt):
74 | 
75 |     pred = pred.astype(bool)  # np.bool was removed in NumPy 1.24; use the builtin
76 |     gt = gt.astype(bool)
77 | 
78 |     false_negatives = np.count_nonzero(np.logical_and(gt, np.logical_not(pred)))
79 |     false_positives = np.count_nonzero(np.logical_and(np.logical_not(gt), pred))
80 |     true_positives = np.count_nonzero(np.logical_and(gt, pred))
81 | 
82 |     union = float(true_positives + false_positives + false_negatives)
83 |     intersection = float(true_positives)
84 | 
85 |     iou = intersection / union
86 | 
87 |     return iou, union
88 | 
89 | def iou_from_poly(pred, gt, height, width):
90 | 
91 |     masks = np.zeros((2, height, width), dtype=np.uint8)
92 | 
93 |     if not isinstance(pred, list):
94 |         pred = [pred]
95 |     if not isinstance(gt, list):
96 |         gt = [gt]
97 | 
98 |     masks[0] = draw_poly(masks[0], pred)
99 | 
100 |     masks[1] = draw_poly(masks[1], gt)
101 | 
102 |     return iou_from_mask(masks[0], masks[1])
103 | 
104 | def draw_poly(mask, poly):
105 |     if not isinstance(poly, np.ndarray):
106 |         poly = np.array(poly)
107 | 
108 |     cv2.fillPoly(mask, np.int32([poly]), 255)
109 | 
110 |     return mask
111 | 
112 | 
--------------------------------------------------------------------------------
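For reference, a minimal end-to-end inference sketch (assumes the ImageNet weights `./checkpoint/resnet50-19c8e357.pth` expected by `models.py` are present; the random tensor stands in for a normalized 224×224 crop):

```python
import torch
from models import DDoAS
from utils import get_ic

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DDoAS(num_classes=15).to(device).eval()   # loads ./checkpoint/resnet50-19c8e357.pth internally

img = torch.randn(1, 3, 224, 224, device=device)            # stand-in for a normalized crop
input_ellipse = get_ic(scale=4.).unsqueeze(0).to(device)    # (1, 60, 2) initial contour on the 56x56 grid

with torch.no_grad():
    cls_feature, edge_feature, predict_offset = model(img, input_ellipse)

new_ellipse = (input_ellipse + predict_offset) * 4.         # back to 224x224 image coordinates
print(cls_feature.argmax(dim=1).item(), new_ellipse.shape)  # predicted class index, torch.Size([1, 60, 2])
```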