├── README.md
├── assert
│   ├── overview.png
│   └── visual_result.png
├── main.py
├── model
│   ├── MSHNet.py
│   └── loss.py
└── utils
    ├── data.py
    └── metric.py


/README.md:
--------------------------------------------------------------------------------
# Infrared Small Target Detection with Scale and Location Sensitivity

## Notice! 📰
Thank you all for your interest in this work. Several users recently pointed out obvious errors in the code, so we re-checked, fixed, and debugged it. After these fixes we unexpectedly obtained a strong result on the IRSTD-1k dataset, which is reported below for reference.

| Dataset | mIoU (×10⁻²) | Pd (×10⁻²) | Fa (×10⁻⁶) | Weights |
| ------------- |:-------------:|:-----:|:-----:|:-----:|
| IRSTD-1k | 67.87 | 92.86 | 8.88 | [new_weights](https://drive.google.com/file/d/1CSDwQG8xg7hv0_oGKa4NCEWUiMRU7eIs/view?usp=sharing) |

## Overview
![](assert/overview.png)

## Introduction
This repository is the official implementation of our CVPR 2024 paper [Infrared Small Target Detection with Scale and Location Sensitivity](https://arxiv.org/abs/2403.19366).

In this paper, we first propose a novel Scale and Location Sensitive (SLS) loss to address the limitations of existing losses: 1) for scale sensitivity, we compute a weight for the IoU loss based on target scales to help the detector distinguish targets of different scales; 2) for location sensitivity, we introduce a penalty term based on the center points of targets to help the detector localize targets more precisely. We then add a simple multi-scale head to a plain U-Net, giving MSHNet. By applying the SLS loss to the prediction at each scale, MSHNet outperforms existing state-of-the-art methods by a large margin. In addition, existing detectors also improve when trained with our SLS loss, which demonstrates the effectiveness and generality of the loss. The contributions of this paper are as follows:

1. We propose a novel scale and location sensitive loss for infrared small target detection, which helps detectors distinguish objects with different scales and locations.

2. We propose a simple but effective detector that achieves SOTA performance without bells and whistles.

3. We apply our loss to existing detectors and show that their detection performance can be further boosted.

## Training
Training is launched with a single command of the form:
```
python main.py --dataset-dir <dataset_dir> --batch-size <batch_size> --epochs <epochs> --lr <lr> --mode 'train'
```

For example:
```
python main.py --dataset-dir '/dataset/IRSTD-1k' --batch-size 4 --epochs 400 --lr 0.05 --mode 'train'
```

## Testing
You can test the model with the following command:
```
python main.py --dataset-dir '/dataset/IRSTD-1k' --batch-size 4 --mode 'test' --weight-path '/weight/MSHNet_weight.tar'
```

## Visual Results
![](assert/visual_result.png)

## Quantitative Results
| Dataset | mIoU (×10⁻²) | Pd (×10⁻²) | Fa (×10⁻⁶) | Weights |
| ------------- |:-------------:|:-----:|:-----:|:-----:|
| IRSTD-1k | 67.16 | 93.88 | 15.03 | [IRSTD-1k_weights](https://drive.google.com/file/d/1q3zfzJRczodGQb0dZ3y3KmLn0zz4F8ra/view?usp=drive_link) |
| NUDT-SIRST | 80.55 | 97.99 | 11.77 | [NUDT-SIRST_weights](https://drive.google.com/file/d/1uczanUIHePZqJA79RZu25fv9FNSHSDQZ/view?usp=drive_link) |


## Citation
**Please kindly cite the paper if this code is useful for your research.**

    @inproceedings{liu2024infrared,
      title={Infrared Small Target Detection with Scale and Location Sensitivity},
      author={Liu, Qiankun and Liu, Rui and Zheng, Bolun and Wang, Hongkui and Fu, Ying},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2024}
    }

--------------------------------------------------------------------------------
/assert/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ying-fu/MSHNet/46cdfd46802629da51f70124662af7335be74b56/assert/overview.png
--------------------------------------------------------------------------------
/assert/visual_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ying-fu/MSHNet/46cdfd46802629da51f70124662af7335be74b56/assert/visual_result.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from utils.data import *
from utils.metric import *
from argparse import ArgumentParser
import torch
import torch.nn as nn  # used directly below (nn.DataParallel, nn.MaxPool2d)
import torch.utils.data as Data
from model.MSHNet import *
from model.loss import *
from torch.optim import Adagrad
from tqdm import tqdm
import os.path as osp
import os
import time

os.environ['CUDA_VISIBLE_DEVICES'] = "0"

def parse_args():

    #
    # Setting parameters
    #
    parser = ArgumentParser(description='Implementation of MSHNet')

    parser.add_argument('--dataset-dir', type=str, default='/dataset/IRSTD-1k')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=400)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--warm-epoch', type=int, default=5)

    parser.add_argument('--base-size', type=int, default=256)
    parser.add_argument('--crop-size', type=int, default=256)
    parser.add_argument('--multi-gpus', type=bool, default=False)
    parser.add_argument('--if-checkpoint', type=bool, default=False)

    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--weight-path', type=str, default='/MSHNet/weight/IRSTD-1k_weight.tar')

    args = parser.parse_args()
    return args

class Trainer(object):
    def __init__(self, args):
        assert args.mode == 'train' or args.mode == 'test'

        self.args = args
        self.start_epoch = 0
        self.mode = args.mode

        trainset = IRSTD_Dataset(args, mode='train')
        valset = IRSTD_Dataset(args, mode='val')

        self.train_loader = Data.DataLoader(trainset, args.batch_size, shuffle=True, drop_last=True)
        self.val_loader = Data.DataLoader(valset, 1, drop_last=False)

        device = torch.device('cuda')
        self.device = device

        model = MSHNet(3)

        if args.multi_gpus:
            if torch.cuda.device_count() > 1:
                print('use '+str(torch.cuda.device_count())+' gpus')
                model = nn.DataParallel(model, device_ids=[0, 1])
        model.to(device)
        self.model = model

        self.optimizer = Adagrad(filter(lambda p: p.requires_grad, self.model.parameters()), lr=args.lr)

        self.down = nn.MaxPool2d(2, 2)
        self.loss_fun = SLSIoULoss()
        self.PD_FA = PD_FA(1, 10, args.base_size)
        self.mIoU = mIoU(1)
        self.ROC = ROCMetric(1, 10)
        self.best_iou = 0
        self.warm_epoch = args.warm_epoch

        if args.mode=='train':
            if args.if_checkpoint:
                check_folder = ''
                checkpoint = torch.load(check_folder+'/checkpoint.pkl')
                self.model.load_state_dict(checkpoint['net'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.start_epoch = checkpoint['epoch']+1
                self.best_iou = checkpoint['iou']
                self.save_folder = check_folder
            else:
                self.save_folder = '/MSHNet/weight/MSHNet-%s'%(time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time())))
                if not osp.exists(self.save_folder):
                    # makedirs also creates missing parent directories such as /MSHNet/weight
                    os.makedirs(self.save_folder)
        if args.mode=='test':

            weight = torch.load(args.weight_path)
            self.model.load_state_dict(weight['state_dict'])
            '''
            # iou_67.87_weight
            weight = torch.load(args.weight_path)
            self.model.load_state_dict(weight)
            '''
            self.warm_epoch = -1


    def train(self, epoch):
        self.model.train()
        tbar = tqdm(self.train_loader)
        losses = AverageMeter()
        tag = False
        for i, (data, mask) in enumerate(tbar):

            data = data.to(self.device)
            labels = mask.to(self.device)

            # after the warm-up epochs, the model also returns the multi-scale side outputs
            if epoch>self.warm_epoch:
                tag = True

            masks, pred = self.model(data, tag)
            loss = 0

            loss = loss + self.loss_fun(pred, labels, self.warm_epoch, epoch)
            # side output j has 1/2**j of the input resolution, so the label is max-pooled to match
            for j in range(len(masks)):
                if j>0:
                    labels = self.down(labels)
                loss = loss + self.loss_fun(masks[j], labels, self.warm_epoch, epoch)

            loss = loss / (len(masks)+1)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            losses.update(loss.item(), pred.size(0))
            tbar.set_description('Epoch %d, loss %.4f' % (epoch, losses.avg))

    def test(self, epoch):
        self.model.eval()
        self.mIoU.reset()
        self.PD_FA.reset()
        self.ROC.reset()
        tbar = tqdm(self.val_loader)
        tag = False
        with torch.no_grad():
            for i, (data, mask) in enumerate(tbar):

                data = data.to(self.device)
                mask = mask.to(self.device)

                if epoch>self.warm_epoch:
                    tag = True

                loss = 0
                _, pred = self.model(data, tag)
                # loss += self.loss_fun(pred, mask,self.warm_epoch, epoch)

                self.mIoU.update(pred, mask)
                self.PD_FA.update(pred, mask)
                self.ROC.update(pred, mask)
                _, mean_IoU = self.mIoU.get()

                tbar.set_description('Epoch %d, IoU %.4f' % (epoch, mean_IoU))
        FA, PD = self.PD_FA.get(len(self.val_loader))
        _, mean_IoU = self.mIoU.get()
        true_positive_rate, false_positive_rate, _, _ = self.ROC.get()


        if self.mode == 'train':
            if mean_IoU > self.best_iou:
                self.best_iou = mean_IoU

                torch.save(self.model.state_dict(), self.save_folder+'/weight.pkl')
                with open(osp.join(self.save_folder, 'metric.log'), 'a') as f:
                    f.write('{} - {:04d}\t - IoU {:.4f}\t - PD {:.4f}\t - FA {:.4f}\n'.format(
                        time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())),
                        epoch, self.best_iou, PD[0], FA[0] * 1000000))

            all_states = {"net":self.model.state_dict(), "optimizer":self.optimizer.state_dict(), "epoch": epoch, "iou":self.best_iou}
            torch.save(all_states, self.save_folder+'/checkpoint.pkl')
        elif self.mode == 'test':
            print('mIoU: '+str(mean_IoU)+'\n')
            print('Pd: '+str(PD[0])+'\n')
            print('Fa: '+str(FA[0]*1000000)+'\n')



if __name__ == '__main__':
    args = parse_args()

    trainer = Trainer(args)

    if trainer.mode=='train':
        for epoch in range(trainer.start_epoch, args.epochs):
            trainer.train(epoch)
            trainer.test(epoch)
    else:
        trainer.test(1)

--------------------------------------------------------------------------------
/model/MSHNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # use the `ratio` argument (default 16, so existing weights still load)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class ResNet(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace = True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride != 1 or out_channels != in_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = None

        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        out += residual
        out = self.relu(out)
        return out

class MSHNet(nn.Module):
    def __init__(self, input_channels, block=ResNet):
        super().__init__()
        param_channels = [16, 32, 64, 128, 256]
        param_blocks = [2, 2, 2, 2]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        # scale factor now matches the name; this layer is not used in forward()
        self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)

        self.conv_init = nn.Conv2d(input_channels, param_channels[0], 1, 1)

        self.encoder_0 = self._make_layer(param_channels[0], param_channels[0], block)
        self.encoder_1 = self._make_layer(param_channels[0], param_channels[1], block, param_blocks[0])
        self.encoder_2 = self._make_layer(param_channels[1], param_channels[2], block, param_blocks[1])
        self.encoder_3 = self._make_layer(param_channels[2], param_channels[3], block, param_blocks[2])

        self.middle_layer = self._make_layer(param_channels[3], param_channels[4], block, param_blocks[3])

        self.decoder_3 = self._make_layer(param_channels[3]+param_channels[4], param_channels[3], block, param_blocks[2])
        self.decoder_2 = self._make_layer(param_channels[2]+param_channels[3], param_channels[2], block, param_blocks[1])
        self.decoder_1 = self._make_layer(param_channels[1]+param_channels[2], param_channels[1], block, param_blocks[0])
        self.decoder_0 = self._make_layer(param_channels[0]+param_channels[1], param_channels[0], block)

        self.output_0 = nn.Conv2d(param_channels[0], 1, 1)
        self.output_1 = nn.Conv2d(param_channels[1], 1, 1)
        self.output_2 = nn.Conv2d(param_channels[2], 1, 1)
        self.output_3 = nn.Conv2d(param_channels[3], 1, 1)

        self.final = nn.Conv2d(4, 1, 3, 1, 1)


    def _make_layer(self, in_channels, out_channels, block, block_num=1):
        layer = []
        layer.append(block(in_channels, out_channels))
        for _ in range(block_num-1):
            layer.append(block(out_channels, out_channels))
        return nn.Sequential(*layer)

    def forward(self, x, warm_flag):
        x_e0 = self.encoder_0(self.conv_init(x))
        x_e1 = self.encoder_1(self.pool(x_e0))
        x_e2 = self.encoder_2(self.pool(x_e1))
        x_e3 = self.encoder_3(self.pool(x_e2))

        x_m = self.middle_layer(self.pool(x_e3))

        x_d3 = self.decoder_3(torch.cat([x_e3, self.up(x_m)], 1))
        x_d2 = self.decoder_2(torch.cat([x_e2, self.up(x_d3)], 1))
        x_d1 = self.decoder_1(torch.cat([x_e1, self.up(x_d2)], 1))
        x_d0 = self.decoder_0(torch.cat([x_e0, self.up(x_d1)], 1))


        if warm_flag:
            # multi-scale head: one prediction per decoder level, upsampled and fused
            mask0 = self.output_0(x_d0)
            mask1 = self.output_1(x_d1)
            mask2 = self.output_2(x_d2)
            mask3 = self.output_3(x_d3)
            output = self.final(torch.cat([mask0, self.up(mask1), self.up_4(mask2), self.up_8(mask3)], dim=1))
            return [mask0, mask1, mask2, mask3], output

        else:
            # warm-up: only the full-resolution head is used
            output = self.output_0(x_d0)
            return [], output


--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import numpy as np
import torch
import torch.nn.functional as F
from skimage import measure


def SoftIoULoss( pred, target):
    pred = torch.sigmoid(pred)

    smooth = 1

    intersection = pred * target
    intersection_sum = torch.sum(intersection, dim=(1,2,3))
    pred_sum = torch.sum(pred, dim=(1,2,3))
    target_sum = torch.sum(target, dim=(1,2,3))

    loss = (intersection_sum + smooth) / \
           (pred_sum + target_sum - intersection_sum + smooth)

    loss = 1 - loss.mean()

    return loss

def Dice( pred, target,warm_epoch=1, epoch=1, layer=0):
    pred = torch.sigmoid(pred)

    smooth = 1

    intersection = pred * target
    intersection_sum = torch.sum(intersection, dim=(1,2,3))
    pred_sum = torch.sum(pred, dim=(1,2,3))
    target_sum = torch.sum(target, dim=(1,2,3))

    # Dice coefficient: 2|P∩T| / (|P| + |T|)
    loss = (2*intersection_sum + smooth) / \
           (pred_sum + target_sum + smooth)

    loss = 1 - loss.mean()

    return loss

class SLSIoULoss(nn.Module):
    def __init__(self):
        super(SLSIoULoss, self).__init__()


    def forward(self, pred_log, target,warm_epoch, epoch, with_shape=True):
        pred = torch.sigmoid(pred_log)
        smooth = 0.0

        intersection = pred * target

        intersection_sum = torch.sum(intersection, dim=(1,2,3))
        pred_sum = torch.sum(pred, dim=(1,2,3))
        target_sum = torch.sum(target, dim=(1,2,3))

        # scale-sensitive weight alpha, computed from the predicted and target areas
        dis = torch.pow((pred_sum-target_sum)/2, 2)


        alpha = (torch.min(pred_sum, target_sum) + dis + smooth) / (torch.max(pred_sum, target_sum) + dis + smooth)

        loss = (intersection_sum + smooth) / \
               (pred_sum + target_sum - intersection_sum + smooth)
        # location-sensitive penalty
        lloss = LLoss(pred, target)

        if epoch>warm_epoch:
            siou_loss = alpha * loss
            if with_shape:
                loss = 1 - siou_loss.mean() + lloss
            else:
                loss = 1 -siou_loss.mean()
        else:
            loss = 1 - loss.mean()
        return loss



def LLoss(pred, target):
    loss = torch.tensor(0.0, requires_grad=True).to(pred)

    patch_size = pred.shape[0]  # batch size
    h = pred.shape[2]
    w = pred.shape[3]
    # normalised pixel coordinates, shape (1, h, w)
    x_index = torch.arange(0,w,1).view(1, 1, w).repeat((1,h,1)).to(pred) / w
    y_index = torch.arange(0,h,1).view(1, h, 1).repeat((1,1,w)).to(pred) / h
    smooth = 1e-8
    for i in range(patch_size):

        # first-moment "centre": sum of coordinate * value, divided by h * w
        pred_centerx = (x_index*pred[i]).mean()
        pred_centery = (y_index*pred[i]).mean()

        target_centerx = (x_index*target[i]).mean()
        target_centery = (y_index*target[i]).mean()

        # penalise the difference in polar angle and in distance from the origin
        angle_loss = (4 / (torch.pi**2) ) * (torch.square(torch.arctan((pred_centery) / (pred_centerx + smooth))
                                                        - torch.arctan((target_centery) / (target_centerx + smooth))))

        pred_length = torch.sqrt(pred_centerx*pred_centerx + pred_centery*pred_centery + smooth)
        target_length = torch.sqrt(target_centerx*target_centerx + target_centery*target_centery + smooth)

        length_loss = (torch.min(pred_length, target_length)) / (torch.max(pred_length, target_length) + smooth)

        loss = loss + (1 - length_loss + angle_loss) / patch_size

    return loss


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision.transforms as transforms

import os
from PIL import Image, ImageOps, ImageFilter
import os.path as osp
import sys
import random
import shutil


class IRSTD_Dataset(Data.Dataset):
    def __init__(self, args, mode='train'):

        dataset_dir = args.dataset_dir

        if mode == 'train':
            txtfile = 'trainval.txt'
        elif mode == 'val':
            txtfile = 'test.txt'
        else:
            raise ValueError("Unknown mode: %s" % mode)

        self.list_dir = osp.join(dataset_dir, txtfile)
        self.imgs_dir = osp.join(dataset_dir, 'images')
        self.label_dir = osp.join(dataset_dir, 'masks')

        self.names = []
        with open(self.list_dir, 'r') as f:
            self.names += [line.strip() for line in f.readlines()]


        self.mode = mode
        self.crop_size = args.crop_size
        self.base_size = args.base_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

    def __getitem__(self, i):
        name = self.names[i]
        img_path = osp.join(self.imgs_dir, name+'.png')
        label_path = osp.join(self.label_dir, name+'.png')

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(label_path)

        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._testval_sync_transform(img, mask)
        else:
            raise ValueError("Unknown self.mode")


        img, mask = self.transform(img), transforms.ToTensor()(mask)
        return img, mask

    def __len__(self):
        return len(self.names)

    def _sync_transform(self, img, mask):
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale (long edge)
        long_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # gaussian blur as in PSP
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))
        return img, mask


    def _testval_sync_transform(self, img, mask):
        base_size = self.base_size
        img = img.resize((base_size, base_size), Image.BILINEAR)
        mask = mask.resize((base_size, base_size), Image.NEAREST)

        return img, mask

--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch
from skimage import measure
import numpy

class ROCMetric():
    """Accumulates TP/FP counts over score thresholds to compute TPR/FPR (ROC) and recall/precision
    """
    def __init__(self, nclass, bins):  # bins sets how many discrete thresholds are sampled along the ROC curve
        super(ROCMetric, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.tp_arr = np.zeros(self.bins+1)
        self.pos_arr = np.zeros(self.bins+1)
        self.fp_arr = np.zeros(self.bins+1)
        self.neg_arr = np.zeros(self.bins+1)
        self.class_pos=np.zeros(self.bins+1)
        # self.reset()

    def update(self, preds, labels):
        for iBin in range(self.bins+1):
            score_thresh = (iBin + 0.0) / self.bins
            # print(iBin, "-th, score_thresh: ", score_thresh)
            i_tp, i_pos, i_fp, i_neg,i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass,score_thresh)
            self.tp_arr[iBin] += i_tp
            self.pos_arr[iBin] += i_pos
            self.fp_arr[iBin] += i_fp
            self.neg_arr[iBin] += i_neg
            self.class_pos[iBin]+=i_class_pos

    def get(self):

        tp_rates = self.tp_arr / (self.pos_arr + 0.001)
        fp_rates = self.fp_arr / (self.neg_arr + 0.001)

        recall = self.tp_arr / (self.pos_arr + 0.001)
        precision = self.tp_arr / (self.class_pos + 0.001)


        return tp_rates, fp_rates, recall, precision

    def reset(self):

        # size the arrays from self.bins instead of a hard-coded 11
        self.tp_arr = np.zeros([self.bins+1])
        self.pos_arr = np.zeros([self.bins+1])
        self.fp_arr = np.zeros([self.bins+1])
        self.neg_arr = np.zeros([self.bins+1])
        self.class_pos= np.zeros([self.bins+1])



class PD_FA():
    def __init__(self, nclass, bins, size):
        super(PD_FA, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.image_area_total = []
        self.image_area_match = []
        self.FA = np.zeros(self.bins+1)
        self.PD = np.zeros(self.bins + 1)
        self.target= np.zeros(self.bins + 1)
        self.size = size
    def update(self, preds, labels):

        for iBin in range(self.bins+1):
            # preds are raw logits; main.py only reports index 0 (logit > 0, i.e. sigmoid > 0.5)
            score_thresh = iBin * (255/self.bins)
            predits = np.array((preds > score_thresh).cpu()).astype('int64')

            predits = np.reshape(predits, (self.size, self.size))
            labelss = np.array((labels).cpu()).astype('int64')
            labelss = np.reshape(labelss, (self.size, self.size))

            image = measure.label(predits, connectivity=2)
            coord_image = measure.regionprops(image)
            label = measure.label(labelss , connectivity=2)
            coord_label = measure.regionprops(label)

            self.target[iBin] += len(coord_label)
            self.image_area_total = []
            self.image_area_match = []
            self.distance_match = []
            self.dismatch = []

            for K in range(len(coord_image)):
                area_image = np.array(coord_image[K].area)
                self.image_area_total.append(area_image)

            # a target counts as detected if a predicted region's centroid lies within 3 pixels
            for i in range(len(coord_label)):
                centroid_label = np.array(list(coord_label[i].centroid))
                for m in range(len(coord_image)):
                    centroid_image = np.array(list(coord_image[m].centroid))
                    distance = np.linalg.norm(centroid_image - centroid_label)
                    area_image = np.array(coord_image[m].area)
                    if distance < 3:
                        self.distance_match.append(distance)
                        self.image_area_match.append(area_image)

                        del coord_image[m]
                        break

            self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match]
            self.FA[iBin]+=np.sum(self.dismatch)
            self.PD[iBin]+=len(self.distance_match)

    def get(self,img_num):

        Final_FA = self.FA / ((self.size*self.size) * img_num)
        Final_PD = self.PD /self.target

        return Final_FA,Final_PD


    def reset(self):
        self.FA = np.zeros([self.bins+1])
        self.PD = np.zeros([self.bins+1])
        # also reset the target count, otherwise Pd is divided by targets accumulated over all epochs
        self.target = np.zeros([self.bins+1])

class mIoU():

    def __init__(self, nclass):
        super(mIoU, self).__init__()
        self.nclass = nclass
        self.reset()

    def update(self, preds, labels):

        correct, labeled = batch_pix_accuracy(preds, labels)
        inter, union = batch_intersection_union(preds, labels, self.nclass)
        self.total_correct += correct
        self.total_label += labeled
        self.total_inter += inter
        self.total_union += union


    def get(self):

        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return pixAcc, mIoU

    def reset(self):

        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0




def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):

    predict = (torch.sigmoid(output) > score_thresh).float()
    if len(target.shape) == 3:
        # add a channel dimension (np.expand_dims does not return a torch tensor)
        target = target.float().unsqueeze(1)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")

    intersection = predict * ((predict == target).float())

    tp = intersection.sum()
    fp = (predict * ((predict != target).float())).sum()
    tn = ((1 - predict) * ((predict == target).float())).sum()
    fn = (((predict != target).float()) * (1 - predict)).sum()
    pos = tp + fn
    neg = fp + tn
    class_pos= tp+fp

    return tp, pos, fp, neg, class_pos

def batch_pix_accuracy(output, target):

    if len(target.shape) == 3:
        target = target.float().unsqueeze(1)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")

    assert output.shape == target.shape, "Predict and Label Shape Don't Match"
    predict = (output > 0).float()
    pixel_labeled = (target > 0).float().sum()
    pixel_correct = (((predict == target).float())*((target > 0)).float()).sum()



    assert pixel_correct <= pixel_labeled, "Correct area should not exceed the labeled area"
    return pixel_correct, pixel_labeled


def batch_intersection_union(output, target, nclass):

    # np.histogram widens the degenerate range (1, 1) to (0.5, 1.5), so each call counts pixels equal to 1
    mini = 1
    maxi = 1
    nbins = 1
    predict = (output > 0).float()
    if len(target.shape) == 3:
        target = target.float().unsqueeze(1)
    elif len(target.shape) == 4:
        target = target.float()
    else:
        raise ValueError("Unknown target dimension")
    intersection = predict * ((predict == target).float())

    area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Error: Intersection area should be smaller than Union area"
    return area_inter, area_union

--------------------------------------------------------------------------------
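The SLS loss in `model/loss.py` combines a soft IoU weighted by a scale factor `alpha` with the location penalty from `LLoss`. The sketch below restates that per-sample arithmetic (the post-warm-up branch of `SLSIoULoss.forward` with `with_shape=True`, plus `LLoss`) in pure Python on toy 8×8 masks, so each term can be checked by hand. The helpers `block`, `first_moments`, `sls_terms` and `sls_loss` are hypothetical illustration names, not part of the repository; the sketch handles a single sample, whereas the real loss works on PyTorch tensors and averages over the batch.

```python
import math


def block(h, w, rows, cols, value=1.0):
    """Hypothetical helper: an h x w mask holding `value` inside the given row/col ranges."""
    return [[value if (y in rows and x in cols) else 0.0 for x in range(w)]
            for y in range(h)]


def first_moments(m):
    """Like LLoss: mean over all pixels of (x / w) * m and (y / h) * m."""
    h, w = len(m), len(m[0])
    cx = sum(m[y][x] * x / w for y in range(h) for x in range(w)) / (h * w)
    cy = sum(m[y][x] * y / h for y in range(h) for x in range(w)) / (h * w)
    return cx, cy


def sls_terms(pred, target, eps=1e-8):
    """Return (iou, alpha, location_penalty) for one sample.

    pred holds probabilities (already passed through a sigmoid), target holds 0/1 labels.
    """
    inter = sum(p * t for pr, tr in zip(pred, target) for p, t in zip(pr, tr))
    p_sum = sum(map(sum, pred))
    t_sum = sum(map(sum, target))
    iou = inter / (p_sum + t_sum - inter)          # soft IoU with smooth = 0
    dis = ((p_sum - t_sum) / 2) ** 2               # squared half-difference of the areas
    alpha = (min(p_sum, t_sum) + dis) / (max(p_sum, t_sum) + dis)

    pcx, pcy = first_moments(pred)
    tcx, tcy = first_moments(target)
    # angle term: squared difference of the polar angles, scaled by 4 / pi^2
    angle = (4 / math.pi ** 2) * (math.atan(pcy / (pcx + eps))
                                  - math.atan(tcy / (tcx + eps))) ** 2
    # length term: ratio of the distances from the origin (1.0 when equal)
    p_len = math.sqrt(pcx * pcx + pcy * pcy + eps)
    t_len = math.sqrt(tcx * tcx + tcy * tcy + eps)
    length = min(p_len, t_len) / (max(p_len, t_len) + eps)
    return iou, alpha, 1 - length + angle


def sls_loss(pred, target):
    """Post-warm-up SLS loss for one sample: 1 - alpha * IoU + location penalty."""
    iou, alpha, loc = sls_terms(pred, target)
    return 1 - alpha * iou + loc


if __name__ == "__main__":
    target = block(8, 8, range(2, 4), range(2, 4))     # 2x2 target
    shifted = block(8, 8, range(4, 6), range(5, 7))    # same size, different place
    bigger = block(8, 8, range(1, 5), range(1, 5))     # 4x4 box covering the target
    print(sls_loss(target, target))
    print(sls_terms(shifted, target))
    print(sls_terms(bigger, target))
```

With a perfect prediction the loss is essentially zero. The shifted box has the same area as the target, so `alpha` stays at 1.0 and IoU is 0.0, leaving only the location penalty to distinguish it. The oversized box covers the target (IoU 0.25) but is down-weighted by `alpha = (4 + 36) / (16 + 36)`, which is how the loss stays sensitive to scale.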