├── Eval_PD.py
├── Overview.png
├── README.md
├── datasets
│   ├── PD_random.py
│   └── foo.txt
├── models
│   ├── FCN_16s.py
│   ├── FCN_32s.py
│   └── LANet.py
├── train_PD.py
└── utils
    ├── __init__.py
    ├── crf.py
    ├── data_vis.py
    ├── eval.py
    ├── foo.txt
    ├── load.py
    ├── loss.py
    ├── misc.py
    ├── transform.py
    └── utils.py

/Eval_PD.py:
--------------------------------------------------------------------------------
import os
import cv2
import time
import numpy as np
import torch.autograd
from skimage import io
from scipy import stats  # needed for stats.hmean() used in predict()
import torch.nn.functional as F
from torch.utils.data import DataLoader

#################################
from models.LANet import LANet as Net
NET_NAME = 'LANet'
DATA_NAME = 'PD'
from datasets import PD_random as PD
#################################

from utils.loss import CrossEntropyLoss2d
from utils.utils import accuracy, intersectionAndUnion, AverageMeter, CaclTP

working_path = os.path.abspath('.')
args = {
    'gpu': True,
    's_class': 0,
    'val_batch_size': 1,
    'val_crop_size': 1024,
    'data_dir': 'YOUR_DATA_DIR',
    'load_path': os.path.join(working_path, 'checkpoints', DATA_NAME, 'LANet_0e_OA80.21.pth')
}

def norm_gray(x, out_range=(0, 255)):
    #x=x*(x>0)
    domain = np.min(x), np.max(x)
    #print(np.min(x))
    #print(np.max(x))
    y = (x - (domain[1] + domain[0]) / 2) / (domain[1] - domain[0] + 1e-10)
    y = y * (out_range[1] - out_range[0]) + (out_range[1] + out_range[0]) / 2
    return y.astype('uint8')

def main():
    net = Net(5, num_classes=PD.num_classes+1)
    net.load_state_dict(torch.load(args['load_path']))  # , strict=False
    net = net.cuda()
    net.eval()
    print('Model loaded.')
    pred_path = os.path.join(args['data_dir'], 'Eval', NET_NAME)
    if not os.path.exists(pred_path): os.makedirs(pred_path)
    info_txt_path = os.path.join(pred_path, 'info.txt')
    f = open(info_txt_path, 'w+')

    val_set = PD.Loader(args['data_dir'], 'val', sliding_crop=True, crop_size=args['val_crop_size'], padding=False)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=4, shuffle=False)
    predict(net, val_loader, pred_path, args, f)
    f.close()

def predict(net, pred_loader, pred_path, args, f_out=None):
    acc_meter = AverageMeter()
    TP_meter = AverageMeter()
    pred_meter = AverageMeter()
    label_meter = AverageMeter()
    Union_meter = AverageMeter()
    output_info = f_out is not None

    for vi, data in enumerate(pred_loader):
        with torch.no_grad():
            img, label = data
            if args['gpu']:
                img = img.cuda().float()
                label = label.cuda().float()
            output, _ = net(img)

        output = output.detach().cpu()
        pred = torch.argmax(output, dim=1)
        pred = pred.squeeze(0).numpy()

        label = label.detach().cpu().numpy()
        acc, _ = accuracy(pred, label)
        acc_meter.update(acc)
        pred_color = PD.Index2Color(pred)
        img = img.detach().cpu().numpy().squeeze().transpose((1, 2, 0))[:,:,:3]
        img = norm_gray(img)
        pred_name = os.path.join(pred_path, '%d.png'%vi)
        io.imsave(pred_name, pred_color)
        TP, pred_hist, label_hist, union_hist = CaclTP(pred, label, PD.num_classes)
        TP_meter.update(TP)
        pred_meter.update(pred_hist)
        label_meter.update(label_hist)
        Union_meter.update(union_hist)
        print('Eval num %d/%d, Acc %.2f'%(vi, len(pred_loader), acc*100))
        if output_info:
            f_out.write('Eval num %d/%d, Acc %.2f\n'%(vi, len(pred_loader), acc*100))

    precision = TP_meter.sum / (label_meter.sum + 1e-10) + 1e-10
    recall = TP_meter.sum / (pred_meter.sum + 1e-10) + 1e-10
    F1 = [stats.hmean([pre, rec]) for pre, rec in zip(precision, recall)]
    F1 = np.array(F1)
    IoU = TP_meter.sum / Union_meter.sum
    IoU = np.array(IoU)

    print(output.shape)
    print('Acc %.2f'%(acc_meter.avg*100))
    avg_F = F1[:-1].mean()
    mIoU = IoU[:-1].mean()
    print('Avg F1 %.2f'%(avg_F*100))
    print(np.array2string(F1 * 100, precision=4, separator=', ', formatter={'float_kind': lambda x: "%.2f" % x}))
    print('mIoU %.2f'%(mIoU*100))
    print(np.array2string(IoU * 100, precision=4, separator=', ', formatter={'float_kind': lambda x: "%.2f" % x}))
    if output_info:
        f_out.write('Acc %.2f\n'%(acc_meter.avg*100))
        f_out.write('Avg F1 %.2f\n'%(avg_F*100))
        f_out.write(np.array2string(F1 * 100, precision=4, separator=', ', formatter={'float_kind': lambda x: "%.2f" % x}))
        f_out.write('\nmIoU %.2f\n'%(mIoU*100))
        f_out.write(np.array2string(IoU * 100, precision=4, separator=', ', formatter={'float_kind': lambda x: "%.2f" % x}))
    return avg_F


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DingLei14/LANet/24d21d39979e9549f5fa35cf2dcbc54973f5b079/Overview.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LANet
PyTorch code for ['LANet: Local Attention Embedding to Improve the Semantic Segmentation of Remote Sensing Images'](https://ieeexplore.ieee.org/document/9102424)

![alt text](https://github.com/ggsDing/LANet/blob/master/Overview.png)

**How to Use**
1. Split the data into training, validation and test sets and organize them as follows:

>YOUR_DATA_DIR
> - Train
>   - image
>   - label
> - Val
>   - image
>   - label
> - Test
>   - image
>   - label

2. Change the training parameters in *train_PD.py*, especially the data directory.

3. To evaluate, also change the parameters in *Eval_PD.py*, especially the data directory and the checkpoint path (an example is given at the end of this README).


If you find this work useful, please consider citing:

>'Ding L, Tang H, Bruzzone L. LANet: Local Attention Embedding to Improve the Semantic Segmentation of Remote Sensing Images[J]. IEEE Transactions on Geoscience and Remote Sensing, 2020.'
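
**Example**

For reference, the entries that usually need editing look like this. A minimal sketch: the path and the checkpoint file name below are placeholders for illustration, not values shipped with this repository (checkpoints are written by *train_PD.py* with the name pattern `NET_NAME_<epoch>e_OA<accuracy>.pth`):

```python
import os

working_path = os.path.abspath('.')
args = {
    # dataset root prepared in step 1 above (placeholder path)
    'data_dir': '/path/to/YOUR_DATA_DIR',
    # Eval_PD.py only: a checkpoint previously saved by train_PD.py (placeholder name)
    'load_path': os.path.join(working_path, 'checkpoints', 'PD', 'LANet_50e_OA90.00.pth'),
}
```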
28 | -------------------------------------------------------------------------------- /datasets/PD_random.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from skimage import io 6 | from torch.utils import data 7 | import matplotlib.pyplot as plt 8 | from torchvision import transforms 9 | import utils.transform as transform 10 | from skimage.transform import rescale 11 | from torchvision.transforms import functional as F 12 | 13 | num_classes = 6 14 | PD_COLORMAP = [[0, 0, 0], [255, 255, 255], [0, 0, 255], [0, 255, 255], 15 | [0, 255, 0], [255, 255, 0], [255, 0, 0] ] 16 | PD_CLASSES = ['Invalid', 'Impervious surfaces','Building', 'Low vegetation', 17 | 'Tree', 'Car', 'Clutter/background'] 18 | # PD_MEAN = np.array([0.33885107, 0.36215387, 0.33536868, 0.38485747]) 19 | # PD_STD = np.array([0.14027526, 0.13798502, 0.14333207, 0.14513438]) 20 | PD_MEAN = np.array([85.8, 91.7, 84.9, 96.6, 47]) 21 | PD_STD = np.array([35.8, 35.2, 36.5, 37, 55]) 22 | 23 | def BGRI2RGB(img): 24 | r = img[0, :, :] 25 | g = img[1, :, :] 26 | b = img[2, :, :] 27 | i = img[3, :, :] 28 | img = cv2.merge([r, g, b, i]) 29 | return img 30 | 31 | def showIMG(img): 32 | plt.imshow(img) 33 | plt.show() 34 | return 0 35 | 36 | def normalize_image(im): 37 | return (im - PD_MEAN) / PD_STD 38 | 39 | def normalize_images(imgs): 40 | for i, im in enumerate(imgs): 41 | imgs[i] = normalize_image(im) 42 | return imgs 43 | 44 | colormap2label = np.zeros(256 ** 3) 45 | for i, cm in enumerate(PD_COLORMAP): 46 | colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i 47 | 48 | def Index2Color(pred): 49 | colormap = np.asarray(PD_COLORMAP, dtype='uint8') 50 | x = np.asarray(pred, dtype='int32') 51 | return colormap[x, :] 52 | 53 | def Colorls2Index(ColorLabels): 54 | for i, data in enumerate(ColorLabels): 55 | ColorLabels[i] = Color2Index(data) 56 | return ColorLabels 57 | 58 | def Color2Index(ColorLabel): 59 | data = ColorLabel.astype(np.int32) 60 | idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2] 61 | IndexMap = colormap2label[idx] 62 | #IndexMap = 2*(IndexMap > 1) + 1 * (IndexMap <= 1) 63 | IndexMap = IndexMap * (IndexMap <= num_classes) 64 | return IndexMap.astype(np.uint8) 65 | 66 | def get_file_name(mode='train'): 67 | assert mode in ['train', 'val'] 68 | if mode == 'train': 69 | img_path = os.path.join(data_dir, 'train') 70 | pred_path = os.path.join(data_dir, 'numpy', 'train') 71 | else: 72 | img_path = os.path.join(data_dir, 'val') 73 | pred_path = os.path.join(data_dir, 'numpy', 'val') 74 | 75 | data_list = os.listdir(img_path) 76 | numpy_path_list = [os.path.join(pred_path, it) for it in data_list] 77 | return numpy_path_list 78 | 79 | def read_RSimages(data_dir, mode): 80 | assert mode in ['train', 'val', 'test'] 81 | if mode == 'test': 82 | img_path = os.path.join(data_dir, 'test') 83 | data_list = os.listdir(img_path) 84 | imgs = [] 85 | for it in data_list: 86 | im = io.imread(os.path.join(img_path, it)) 87 | imgs.append(im) 88 | return imgs 89 | 90 | img_path = os.path.join(data_dir, mode) 91 | dsm_path = os.path.join(data_dir, mode, 'dsm') 92 | mask_path = os.path.join(data_dir, 'groundtruth_noBoundary') #'groundtruth' 93 | data_list = os.listdir(img_path) 94 | data, labels = [], [] 95 | count=0 96 | for it in data_list: 97 | # print(it) 98 | if (it[-4:]=='.tif'): 99 | dsm_name = 'dsm' + it[3:-10] + '.jpg' 100 | mask_name = it[:-10] + '_label_noBoundary.tif' #'_label.tif' 101 | fpath = 
os.path.join(img_path, it) 102 | dsm_fpath = os.path.join(dsm_path, dsm_name) 103 | mask_fpath = os.path.join(mask_path, mask_name) 104 | print(dsm_fpath) 105 | ext = os.path.splitext(it)[-1] 106 | if(ext == '.tif'): 107 | img = io.imread(fpath) 108 | dsm = io.imread(dsm_fpath) 109 | img = np.concatenate((img, np.expand_dims(dsm, axis=2)), axis=2) 110 | label = io.imread(mask_fpath) 111 | data.append(img) 112 | labels.append(label) 113 | count+=1 114 | #if count>1: break 115 | print(data[0].shape) 116 | print(str(len(data)) + ' ' + mode + ' images' + ' loaded.') 117 | return data, labels 118 | 119 | def rescale_images(imgs, scale, order=0): 120 | for i, im in enumerate(imgs): 121 | imgs[i] = rescale_image(im, scale, order) 122 | return imgs 123 | 124 | def rescale_image(img, scale=1/8, order=0): 125 | flag = cv2.INTER_NEAREST 126 | if order==1: flag = cv2.INTER_LINEAR 127 | elif order==2: flag = cv2.INTER_AREA 128 | elif order>2: flag = cv2.INTER_CUBIC 129 | im_rescaled = cv2.resize(img, (int(img.shape[0]*scale), int(img.shape[1]*scale)), 130 | interpolation=flag) 131 | return im_rescaled 132 | 133 | class Loader(data.Dataset): 134 | def __init__(self, data_dir, mode, random_crop=False, crop_nums=40, random_flip = False, sliding_crop=False, crop_size=640/8, padding=False): 135 | self.crop_size = crop_size 136 | self.crop_nums = crop_nums 137 | self.random_flip = random_flip 138 | self.random_crop = random_crop 139 | data, labels = read_RSimages(data_dir, mode) 140 | if sliding_crop: 141 | data, labels = transform.create_crops(data, labels, [self.crop_size, self.crop_size]) 142 | if padding: 143 | data, labels = transform.data_padding(data, labels, scale=16) 144 | self.data = data 145 | self.labels = Colorls2Index(labels) 146 | 147 | if self.random_crop: 148 | self.len = crop_nums*len(self.data) 149 | else: 150 | self.len = len(self.data) 151 | 152 | def __getitem__(self, idx): 153 | if self.random_crop: 154 | idx = int(idx/self.crop_nums) 155 | data, label = transform.random_crop(self.data[idx], self.labels[idx], size=[self.crop_size, self.crop_size]) 156 | else: 157 | data = self.data[idx] 158 | label = self.labels[idx] 159 | if self.random_flip: 160 | data, label = transform.rand_flip(data, label) 161 | 162 | data = normalize_image(data) 163 | data = torch.from_numpy(data.transpose((2, 0, 1))) 164 | return data, label 165 | 166 | def __len__(self): 167 | return self.len 168 | -------------------------------------------------------------------------------- /datasets/foo.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/FCN_16s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | from torch.nn import functional as F 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | class FCN_res101(nn.Module): 10 | def __init__(self, in_channels=3, num_classes=7, pretrained=True): 11 | super(FCN_res101, self).__init__() 12 | resnet = models.resnet101(pretrained) 13 | newconv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 14 | if in_channels>3: 15 | newconv1.weight.data[:, 3:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels-3, :, :]) 16 | self.layer0 = nn.Sequential(newconv1, resnet.bn1, resnet.relu) 17 | 
self.maxpool = resnet.maxpool 18 | self.layer1 = resnet.layer1 19 | self.layer2 = resnet.layer2 20 | self.layer3 = resnet.layer3 21 | self.layer4 = resnet.layer4 22 | for n, m in self.layer4.named_modules(): 23 | if 'conv2' in n or 'downsample.0' in n: 24 | m.stride = (1, 1) 25 | 26 | self.head = nn.Sequential(nn.Conv2d(2048, 128, kernel_size=1, stride=1, padding=0, bias=False), 27 | nn.BatchNorm2d(128, momentum=0.95), 28 | nn.ReLU()) 29 | 30 | self.classifier = nn.Conv2d(128, num_classes, kernel_size=1) 31 | 32 | def forward(self, x): 33 | x_size = x.size() 34 | 35 | x = self.layer0(x) 36 | x = self.maxpool(x) 37 | x = self.layer1(x) 38 | x = self.layer2(x) 39 | x = self.layer3(x) 40 | x = self.layer4(x) 41 | x = self.head(x) 42 | x = self.classifier(x) 43 | 44 | return F.upsample(x, x_size[2:], mode='bilinear') 45 | 46 | class FCN_res50(nn.Module): 47 | def __init__(self, in_channels=3, num_classes=7, pretrained=True): 48 | super(FCN_res50, self).__init__() 49 | resnet = models.resnet50(pretrained) 50 | newconv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 51 | if in_channels>3: 52 | newconv1.weight.data[:, 3:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels-3, :, :]) 53 | self.layer0 = nn.Sequential(newconv1, resnet.bn1, resnet.relu) 54 | self.maxpool = resnet.maxpool 55 | self.layer1 = resnet.layer1 56 | self.layer2 = resnet.layer2 57 | self.layer3 = resnet.layer3 58 | self.layer4 = resnet.layer4 59 | for n, m in self.layer4.named_modules(): 60 | if 'conv2' in n or 'downsample.0' in n: 61 | m.stride = (1, 1) 62 | 63 | self.head = nn.Sequential(nn.Conv2d(2048, 128, kernel_size=1, stride=1, padding=0, bias=False), 64 | nn.BatchNorm2d(128, momentum=0.95), 65 | nn.ReLU()) 66 | 67 | self.classifier = nn.Sequential( 68 | nn.Conv2d(128, 128, kernel_size=1), 69 | nn.BatchNorm2d(128, momentum=0.95), 70 | nn.ReLU(), 71 | #nn.Dropout(0.1), 72 | nn.Conv2d(128, num_classes, kernel_size=1) 73 | ) 74 | 75 | def forward(self, x): 76 | x_size = x.size() 77 | 78 | x = self.layer0(x) 79 | x = self.maxpool(x) 80 | x = self.layer1(x) 81 | x = self.layer2(x) 82 | x = self.layer3(x) 83 | x = self.layer4(x) 84 | x = self.head(x) 85 | x = self.classifier(x) 86 | 87 | return F.upsample(x, x_size[2:], mode='bilinear') 88 | 89 | class FCN_res18(nn.Module): 90 | def __init__(self, in_channels=3, num_classes=7, pretrained=True): 91 | super(FCN_res18, self).__init__() 92 | resnet = models.resnet18(pretrained) 93 | newconv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 94 | newconv1.weight.data[:, 0:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels, :, :]) 95 | if in_channels>3: 96 | newconv1.weight.data[:, 3:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels-3, :, :]) 97 | 98 | self.layer0 = nn.Sequential(newconv1, resnet.bn1, resnet.relu) 99 | self.maxpool = resnet.maxpool 100 | self.layer1 = resnet.layer1 101 | self.layer2 = resnet.layer2 102 | self.layer3 = resnet.layer3 103 | self.layer4 = resnet.layer4 104 | for n, m in self.layer4.named_modules(): 105 | if 'conv2' in n or 'downsample.0' in n: 106 | m.stride = (1, 1) 107 | 108 | self.head = nn.Sequential(nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0, bias=False), 109 | nn.BatchNorm2d(128, momentum=0.95), 110 | nn.ReLU()) 111 | 112 | self.classifier = nn.Sequential( 113 | nn.Conv2d(128, 128, kernel_size=1), 114 | nn.BatchNorm2d(128, momentum=0.95), 115 | nn.ReLU(), 116 | nn.Conv2d(128, num_classes, kernel_size=1) 117 | ) 118 | 119 | def 
forward(self, x): 120 | x_size = x.size() 121 | 122 | x0 = self.layer0(x) #size:1/2 123 | x = self.maxpool(x0) #size:1/4 124 | x = self.layer1(x) #size:1/4 125 | x = self.layer2(x) #size:1/8 126 | x = self.layer3(x) #size:1/16 127 | x = self.layer4(x) 128 | x = self.head(x) 129 | x = self.classifier(x) 130 | 131 | return F.upsample(x, x_size[2:], mode='bilinear') 132 | -------------------------------------------------------------------------------- /models/FCN_32s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | from torch.nn import functional as F 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | class FCN_res101(nn.Module): 10 | def __init__(self, in_channels=3, num_classes=7, pretrained=True): 11 | super(FCN_res101, self).__init__() 12 | resnet = models.resnet101(pretrained) 13 | newconv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 14 | if in_channels>3: 15 | newconv1.weight.data[:, 3:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels-3, :, :]) 16 | self.layer0 = nn.Sequential(newconv1, resnet.bn1, resnet.relu) 17 | self.maxpool = resnet.maxpool 18 | self.layer1 = resnet.layer1 19 | self.layer2 = resnet.layer2 20 | self.layer3 = resnet.layer3 21 | self.layer4 = resnet.layer4 22 | 23 | self.head = nn.Sequential(nn.Conv2d(2048, 128, kernel_size=1, stride=1, padding=0, bias=False), 24 | nn.BatchNorm2d(128, momentum=0.95), 25 | nn.ReLU()) 26 | 27 | self.classifier = nn.Conv2d(128, num_classes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | x_size = x.size() 31 | 32 | x = self.layer0(x) 33 | x = self.maxpool(x) 34 | x = self.layer1(x) 35 | x = self.layer2(x) 36 | x = self.layer3(x) 37 | x = self.layer4(x) 38 | x = self.head(x) 39 | x = self.classifier(x) 40 | 41 | return F.upsample(x, x_size[2:], mode='bilinear') 42 | 43 | class FCN_res50(nn.Module): 44 | def __init__(self, in_channels=3, num_classes=7, pretrained=True): 45 | super(FCN_res50, self).__init__() 46 | resnet = models.resnet50(pretrained) 47 | newconv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 48 | if in_channels>3: 49 | newconv1.weight.data[:, 3:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels-3, :, :]) 50 | self.layer0 = nn.Sequential(newconv1, resnet.bn1, resnet.relu) 51 | self.maxpool = resnet.maxpool 52 | self.layer1 = resnet.layer1 53 | self.layer2 = resnet.layer2 54 | self.layer3 = resnet.layer3 55 | self.layer4 = resnet.layer4 56 | 57 | self.head = nn.Sequential(nn.Conv2d(2048, 128, kernel_size=1, stride=1, padding=0, bias=False), 58 | nn.BatchNorm2d(128, momentum=0.95), 59 | nn.ReLU()) 60 | 61 | self.classifier = nn.Sequential( 62 | nn.Conv2d(128, 128, kernel_size=1), 63 | nn.BatchNorm2d(128, momentum=0.95), 64 | nn.ReLU(), 65 | #nn.Dropout(0.1), 66 | nn.Conv2d(128, num_classes, kernel_size=1) 67 | ) 68 | 69 | def forward(self, x): 70 | x_size = x.size() 71 | 72 | x = self.layer0(x) 73 | x = self.maxpool(x) 74 | x = self.layer1(x) 75 | x = self.layer2(x) 76 | x = self.layer3(x) 77 | x = self.layer4(x) 78 | x = self.head(x) 79 | x = self.classifier(x) 80 | 81 | return F.upsample(x, x_size[2:], mode='bilinear') 82 | 83 | class FCN_res18(nn.Module): 84 | def __init__(self, in_channels=3, num_classes=7, pretrained=True): 85 | super(FCN_res18, self).__init__() 86 | resnet = 
models.resnet18(pretrained) 87 | newconv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 88 | newconv1.weight.data[:, 0:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels, :, :]) 89 | if in_channels>3: 90 | newconv1.weight.data[:, 3:in_channels, :, :].copy_(resnet.conv1.weight.data[:, 0:in_channels-3, :, :]) 91 | 92 | self.layer0 = nn.Sequential(newconv1, resnet.bn1, resnet.relu) 93 | self.maxpool = resnet.maxpool 94 | self.layer1 = resnet.layer1 95 | self.layer2 = resnet.layer2 96 | self.layer3 = resnet.layer3 97 | self.layer4 = resnet.layer4 98 | 99 | self.head = nn.Sequential(nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0, bias=False), 100 | nn.BatchNorm2d(128, momentum=0.95), 101 | nn.ReLU()) 102 | 103 | self.classifier = nn.Sequential( 104 | nn.Conv2d(128, 128, kernel_size=1), 105 | nn.BatchNorm2d(128, momentum=0.95), 106 | nn.ReLU(), 107 | nn.Conv2d(128, num_classes, kernel_size=1) 108 | ) 109 | 110 | def forward(self, x): 111 | x_size = x.size() 112 | 113 | x0 = self.layer0(x) #size:1/2 114 | x = self.maxpool(x0) #size:1/4 115 | x = self.layer1(x) #size:1/4 116 | x = self.layer2(x) #size:1/8 117 | x = self.layer3(x) #size:1/16 118 | x = self.layer4(x) 119 | x = self.head(x) 120 | x = self.classifier(x) 121 | 122 | return F.upsample(x, x_size[2:], mode='bilinear') 123 | -------------------------------------------------------------------------------- /models/LANet.py: -------------------------------------------------------------------------------- 1 | # model codes for 'LANet: Local Attention Embedding to Improve the Semantic Segmentation of Remote Sensing Images[J]. IEEE Transactions on Geoscience and Remote Sensing, 2020.' 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import torch 5 | #from models.FCN_32s import FCN_res50 as FCN 6 | from models.FCN_16s import FCN_res50 as FCN 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 10 | 11 | # ASP module from: 'Multi-scale context aggregation for semantic segmentation of remote sensing images[J]. Remote Sensing.' 12 | class ASP(nn.Module): 13 | def __init__(self, in_channels, in_stride, reduction=4, RF=(320, 160, 80, 40)): 14 | super(ASP, self).__init__() 15 | self.strides = [R // in_stride for R in RF] 16 | out_channels = in_channels // reduction 17 | 18 | self.stages = [] 19 | self.stages = nn.ModuleList([self._make_stage(in_channels, out_channels) for i in range(4)]) 20 | self.bottleneck = nn.Sequential( 21 | nn.Conv2d(in_channels+4*out_channels, in_channels, kernel_size=3, padding=1, dilation=1, bias=False), 22 | nn.BatchNorm2d(in_channels), nn.ReLU() 23 | ) 24 | 25 | def _make_stage(self, in_channels, out_channels): 26 | #prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 27 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 28 | bn = nn.BatchNorm2d(out_channels) 29 | relu = nn.ReLU() 30 | return nn.Sequential(conv, bn, relu) 31 | 32 | def forward(self, feats): 33 | h, w = feats.size()[2:] 34 | 35 | priors = [feats] 36 | for idx, stage in enumerate(self.stages): 37 | h_out = h // self.strides[idx] 38 | w_out = w // self.strides[idx] 39 | feats_avg = F.adaptive_avg_pool2d(feats, [h_out, w_out]) 40 | feats_avg = stage(feats_avg) 41 | priors.append(F.upsample(input=feats_avg, size=(h, w), mode='bilinear', align_corners=True)) 42 | 43 | bottle = self.bottleneck(torch.cat(priors, 1)) 44 | return bottle 45 | 46 | # Patch attention module. 
Parameters: reduction is the rate of channel reduction. pool_window should be set according to the scaling rate. 47 | class Patch_Attention(nn.Module): 48 | def __init__(self, in_channels, reduction=8, pool_window=10, add_input=False): 49 | super(Patch_Attention, self).__init__() 50 | self.pool_window = pool_window 51 | self.add_input = add_input 52 | self.SA = nn.Sequential( 53 | nn.Conv2d(in_channels, in_channels // reduction, 1), 54 | nn.BatchNorm2d(in_channels // reduction, momentum=0.95), 55 | nn.ReLU(inplace=False), 56 | nn.Conv2d(in_channels // reduction, in_channels, 1), 57 | nn.Sigmoid() 58 | ) 59 | 60 | def forward(self, x): 61 | b, c, h, w = x.size() 62 | pool_h = h//self.pool_window 63 | pool_w = w//self.pool_window 64 | 65 | A = F.adaptive_avg_pool2d(x, (pool_h, pool_w)) 66 | A = self.SA(A) 67 | 68 | A = F.upsample(A, (h,w), mode='bilinear') 69 | output = x*A 70 | if self.add_input: 71 | output += x 72 | 73 | return output 74 | 75 | # Calculate pixel-wise local attention. Costs more computations. 76 | class Patch_AttentionV2(nn.Module): 77 | def __init__(self, in_channels, reduction=16, pool_window=10, add_input=False): 78 | super(Patch_AttentionV2, self).__init__() 79 | self.pool_window = pool_window 80 | self.add_input = add_input 81 | self.SA = nn.Sequential( 82 | nn.AvgPool2d(kernel_size=pool_window+1, stride=1, padding = pool_window//2), 83 | nn.Conv2d(in_channels, in_channels // reduction, 1), 84 | nn.BatchNorm2d(in_channels // reduction, momentum=0.95), 85 | nn.ReLU(inplace=False), 86 | nn.Conv2d(in_channels // reduction, in_channels, 1), 87 | nn.Sigmoid() 88 | ) 89 | 90 | def forward(self, x): 91 | b, c, h, w = x.size() 92 | A = self.SA(x) 93 | 94 | A = F.upsample(A, (h,w), mode='bilinear') 95 | output = x*A 96 | if self.add_input: 97 | output += x 98 | 99 | return output 100 | 101 | # Attention embedding module. Parameters: reduction is the rate of channel reduction. pool_window should be set according to the scaling rate. 
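# Illustrative usage note (an assumption added for clarity, not from the original code): pool_window
# is given in feature-map pixels, so features at a coarser stride need a smaller value to cover a
# comparable image area; e.g. LANet below uses pool_window=20 on its 1/4-resolution branch and
# pool_window=4 on its 1/16-resolution branch. A minimal shape sketch for the module defined next,
# assuming a 512x512 input and the channel sizes used in LANet:
#   ae = Attention_Embedding(in_channels=128, out_channels=64)
#   high = torch.randn(1, 128, 32, 32)   # high-level features at 1/16 resolution
#   low = torch.randn(1, 64, 128, 128)   # low-level features at 1/4 resolution
#   out = ae(high, low)                  # attention from 'high' re-weights 'low': (1, 64, 128, 128)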
102 | class Attention_Embedding(nn.Module): 103 | def __init__(self, in_channels, out_channels, reduction=16, pool_window=6, add_input=False): 104 | super(Attention_Embedding, self).__init__() 105 | self.add_input = add_input 106 | self.SE = nn.Sequential( 107 | nn.AvgPool2d(kernel_size=pool_window+1, stride=1, padding = pool_window//2), 108 | nn.Conv2d(in_channels, in_channels//reduction, 1), 109 | nn.BatchNorm2d(in_channels//reduction, momentum=0.95), 110 | nn.ReLU(inplace=False), 111 | nn.Conv2d(in_channels//reduction, out_channels, 1), 112 | nn.Sigmoid()) 113 | 114 | def forward(self, high_feat, low_feat): 115 | b, c, h, w = low_feat.size() 116 | A = self.SE(high_feat) 117 | A = F.upsample(A, (h,w), mode='bilinear') 118 | 119 | output = low_feat*A 120 | if self.add_input: 121 | output += low_feat 122 | 123 | return output 124 | 125 | class LANet(nn.Module): 126 | def __init__(self, in_channels=3, num_classes=7, pretrained=True): 127 | super(LANet, self).__init__() 128 | self.FCN = FCN(in_channels, num_classes) 129 | 130 | self.PA0 = nn.Sequential(nn.Conv2d(256, 64, kernel_size=1), 131 | nn.BatchNorm2d(64, momentum=0.95), nn.ReLU(inplace=False), 132 | Patch_Attention(64, reduction=8, pool_window=20, add_input=True)) 133 | 134 | self.PA2 = Patch_Attention(128, reduction=16, pool_window=4, add_input=True) 135 | self.AE = Attention_Embedding(128, 64) 136 | 137 | self.classifier0 = nn.Conv2d(64, num_classes, kernel_size=1) 138 | self.classifier1 = nn.Conv2d(128, num_classes, kernel_size=1) 139 | 140 | def forward(self, x): 141 | x_size = x.size() 142 | 143 | x = self.FCN.layer0(x) #size:1/2 144 | x = self.FCN.maxpool(x) #size:1/4 145 | x0 = self.FCN.layer1(x) #size:1/4, C256 146 | x = self.FCN.layer2(x0) #size:1/8, C512 147 | x = self.FCN.layer3(x) #size:1/16, C1024 148 | x = self.FCN.layer4(x) #size:1/16 or 1/32, C2048 149 | x2 = self.FCN.head(x) #size:1/16 or 1/32, C128 150 | 151 | x2 = self.PA2(x2) 152 | x0 = self.PA0(x0) 153 | x0 = self.AE(x2.detach(), x0) 154 | 155 | low = self.classifier0(x0) 156 | low = F.upsample(low, x_size[2:], mode='bilinear') 157 | 158 | high = self.classifier1(x2) 159 | high = F.upsample(high, x_size[2:], mode='bilinear') 160 | 161 | # high-level and low-level features are auxiliary outputs. 
Recommended loss: main_loss + 0.3*aux_loss1 + 0.3*aux_loss2 162 | return high+low, high 163 | -------------------------------------------------------------------------------- /train_PD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import numpy as np 5 | import torch.autograd 6 | from skimage import io 7 | from torch import optim 8 | import torch.nn.functional as F 9 | from tensorboardX import SummaryWriter 10 | from torch.utils.data import DataLoader 11 | working_path = os.path.dirname(os.path.abspath(__file__)) 12 | 13 | ############################################### 14 | from datasets import PD_random as PD 15 | from models.LANet import LANet as Net 16 | NET_NAME = 'LANet' 17 | DATA_NAME = 'PD' 18 | ############################################### 19 | 20 | from utils.loss import CrossEntropyLoss2d 21 | from utils.utils import accuracy, intersectionAndUnion, AverageMeter 22 | 23 | args = { 24 | 'train_batch_size': 8, 25 | 'val_batch_size': 8, 26 | 'lr': 0.1, 27 | 'epochs': 50, 28 | 'gpu': True, 29 | 'crop_nums': 1000, 30 | 'lr_decay_power': 1.5, 31 | 'train_crop_size': 512, 32 | 'val_crop_size': 512, 33 | 'weight_decay': 5e-4, 34 | 'momentum': 0.9, 35 | 'print_freq': 100, 36 | 'predict_step': 5, 37 | 'pred_dir': os.path.join(working_path, 'results', DATA_NAME), 38 | 'chkpt_dir': os.path.join(working_path, 'checkpoints', DATA_NAME), 39 | 'log_dir': os.path.join(working_path, 'logs', DATA_NAME, NET_NAME), 40 | 'data_dir': 'YOUR_DATA_DIR' 41 | } 42 | 43 | if not os.path.exists(args['chkpt_dir']): os.makedirs(args['chkpt_dir']) 44 | if not os.path.exists(args['pred_dir']): os.makedirs(args['pred_dir']) 45 | if not os.path.exists(args['log_dir']): os.makedirs(args['log_dir']) 46 | writer = SummaryWriter(args['log_dir']) 47 | 48 | def main(): 49 | net = Net(5, num_classes=PD.num_classes+1).cuda() 50 | 51 | train_set = PD.Loader(args['data_dir'], 'train', random_crop=True, crop_nums=args['crop_nums'], random_flip=True, crop_size=args['train_crop_size'], padding=True) 52 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=4, shuffle=True) 53 | val_set = PD.Loader(args['data_dir'], 'val', sliding_crop=True, crop_size=args['val_crop_size']) 54 | val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=4, shuffle=False) 55 | 56 | criterion = CrossEntropyLoss2d(ignore_index=0).cuda() 57 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args['lr'], weight_decay=args['weight_decay'], momentum=args['momentum'], nesterov=True) 58 | scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.95, last_epoch=-1) 59 | 60 | train(train_loader, net, criterion, optimizer, scheduler, args, val_loader) 61 | writer.close() 62 | print('Training finished.') 63 | 64 | def train(train_loader, net, criterion, optimizer, scheduler, train_args, val_loader): 65 | bestaccT=0 66 | bestaccV=0.5 67 | bestloss=1 68 | begin_time = time.time() 69 | all_iters = float(len(train_loader)*args['epochs']) 70 | curr_epoch=0 71 | while True: 72 | torch.cuda.empty_cache() 73 | net.train() 74 | start = time.time() 75 | acc_meter = AverageMeter() 76 | train_loss = AverageMeter() 77 | 78 | curr_iter = curr_epoch*len(train_loader) 79 | for i, data in enumerate(train_loader): 80 | running_iter = curr_iter+i+1 81 | adjust_learning_rate(optimizer, running_iter, all_iters) 82 | imgs, labels = data 83 | if args['gpu']: 84 | imgs = imgs.cuda().float() 85 | labels = 
labels.cuda().long() 86 | 87 | optimizer.zero_grad() 88 | outputs, aux = net(imgs) 89 | 90 | alpha = calc_alpha(running_iter, all_iters) 91 | main_loss = criterion(outputs, labels) 92 | aux_loss = criterion(aux, labels) 93 | loss = main_loss + alpha*aux_loss 94 | loss.backward() 95 | optimizer.step() 96 | 97 | labels = labels.cpu().detach().numpy() 98 | outputs = outputs.cpu().detach() 99 | preds = torch.argmax(outputs, dim=1) 100 | preds = preds.numpy() 101 | # batch_valid_sum = 0 102 | acc_curr_meter = AverageMeter() 103 | for (pred, label) in zip(preds, labels): 104 | acc, valid_sum = accuracy(pred, label) 105 | # print(valid_sum) 106 | acc_curr_meter.update(acc) 107 | acc_meter.update(acc_curr_meter.avg) 108 | train_loss.update(loss.cpu().detach().numpy()) 109 | 110 | curr_time = time.time() - start 111 | 112 | if (i + 1) % train_args['print_freq'] == 0: 113 | print('[epoch %d] [iter %d / %d %.1fs] [lr %f] [train loss %.4f acc %.2f]' % ( 114 | curr_epoch, i + 1, len(train_loader), curr_time, optimizer.param_groups[0]['lr'], 115 | train_loss.val, acc_meter.val*100)) 116 | writer.add_scalar('train loss', train_loss.val, running_iter) 117 | loss_rec = train_loss.val 118 | writer.add_scalar('train accuracy', acc_meter.val, running_iter) 119 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], running_iter) 120 | 121 | acc_v, loss_v = validate(val_loader, net, criterion, curr_epoch, train_args) 122 | if acc_meter.avg>bestaccT: bestaccT=acc_meter.avg 123 | if acc_v>bestaccV: 124 | bestaccV=acc_v 125 | bestloss=loss_v 126 | save_path = os.path.join(args['chkpt_dir'], NET_NAME+'_%de_OA%.2f.pth'%(curr_epoch, acc_v*100)) 127 | torch.save(net.state_dict(), save_path) 128 | print('Total time: %.1fs Best rec: Train %.2f, Val %.2f, Val_loss %.4f' %(time.time()-begin_time, bestaccT*100, bestaccV*100, bestloss)) 129 | curr_epoch += 1 130 | #scheduler.step() 131 | if curr_epoch >= train_args['epochs']: 132 | return 133 | 134 | def validate(val_loader, net, criterion, curr_epoch, train_args): 135 | # the following code is written assuming that batch size is 1 136 | net.eval() 137 | torch.cuda.empty_cache() 138 | start = time.time() 139 | 140 | val_loss = AverageMeter() 141 | acc_meter = AverageMeter() 142 | 143 | for vi, data in enumerate(val_loader): 144 | imgs, labels = data 145 | 146 | if train_args['gpu']: 147 | imgs = imgs.cuda().float() 148 | labels = labels.cuda().long() 149 | 150 | with torch.no_grad(): 151 | outputs, _ = net(imgs) 152 | loss = criterion(outputs, labels) 153 | val_loss.update(loss.cpu().detach().numpy()) 154 | 155 | outputs = outputs.cpu().detach() 156 | labels = labels.cpu().detach().numpy() 157 | preds = torch.argmax(outputs, dim=1) 158 | preds = preds.numpy() 159 | for (pred, label) in zip(preds, labels): 160 | acc, valid_sum = accuracy(pred, label) 161 | acc_meter.update(acc) 162 | 163 | if curr_epoch%args['predict_step']==0 and vi==0: 164 | pred_color = PD.Index2Color(preds[0]) 165 | pred_path = os.path.join(args['pred_dir'], NET_NAME+'.png') 166 | io.imsave(pred_path, pred_color) 167 | print('Prediction saved!') 168 | 169 | curr_time = time.time() - start 170 | print('%.1fs Val loss: %.2f Accuracy: %.2f'%(curr_time, val_loss.average(), acc_meter.average()*100)) 171 | 172 | writer.add_scalar('val_loss', val_loss.average(), curr_epoch) 173 | writer.add_scalar('val_Accuracy', acc_meter.average(), curr_epoch) 174 | 175 | return acc_meter.avg, val_loss.avg 176 | 177 | def calc_alpha(curr_iter, all_iters, weight=1.0): 178 | r = (1.0-float(curr_iter)/all_iters)** 2.0 179 | 
return weight*r 180 | 181 | def adjust_learning_rate(optimizer, curr_iter, all_iter): 182 | scale_running_lr = ((1. - float(curr_iter) / all_iter) ** args['lr_decay_power']) 183 | running_lr = args['lr'] * scale_running_lr 184 | 185 | for param_group in optimizer.param_groups: 186 | param_group['lr'] = running_lr 187 | 188 | if __name__ == '__main__': 189 | main() 190 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | -------------------------------------------------------------------------------- /utils/crf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydensecrf.densecrf as dcrf 3 | 4 | def dense_crf(img, output_probs): 5 | h = output_probs.shape[0] 6 | w = output_probs.shape[1] 7 | 8 | output_probs = np.expand_dims(output_probs, 0) 9 | output_probs = np.append(1 - output_probs, output_probs, axis=0) 10 | 11 | d = dcrf.DenseCRF2D(w, h, 2) 12 | U = -np.log(output_probs) 13 | U = U.reshape((2, -1)) 14 | U = np.ascontiguousarray(U) 15 | img = np.ascontiguousarray(img) 16 | 17 | d.setUnaryEnergy(U) 18 | 19 | d.addPairwiseGaussian(sxy=20, compat=3) 20 | d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10) 21 | 22 | Q = d.inference(5) 23 | Q = np.argmax(np.array(Q), axis=0).reshape((h, w)) 24 | 25 | return Q 26 | -------------------------------------------------------------------------------- /utils/data_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | def plot_img_and_mask(img, mask): 4 | fig = plt.figure() 5 | a = fig.add_subplot(1, 2, 1) 6 | a.set_title('Input image') 7 | plt.imshow(img) 8 | 9 | b = fig.add_subplot(1, 2, 2) 10 | b.set_title('Output mask') 11 | plt.imshow(mask) 12 | plt.show() -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | from dice_loss import dice_coeff 6 | 7 | 8 | def eval_net(net, dataset, gpu=True): 9 | """Evaluation without the densecrf with the dice coefficient""" 10 | net.eval() 11 | tot = 0 12 | n=len(dataset) 13 | for i, b in enumerate(dataset): 14 | # img = b[0] 15 | # true_mask = b[1] 16 | # img = torch.from_numpy(img).unsqueeze(0) 17 | # true_mask = torch.from_numpy(true_mask).unsqueeze(0) 18 | # 19 | # if gpu: 20 | # img = img.cuda() 21 | # true_mask = true_mask.cuda() 22 | 23 | img = torch.from_numpy(b[0]).unsqueeze(0).float() 24 | label = torch.from_numpy(b[1]).unsqueeze(0).long() 25 | if gpu: 26 | img = img.cuda() 27 | label = label.cuda() 28 | 29 | pred = net(img) 30 | loss = nn.CrossEntropyLoss() 31 | loss = loss(pred, label) 32 | 33 | tot += loss.item() 34 | return tot / n 35 | 36 | def eval_net_BCE(net, dataset, gpu=True): 37 | """Evaluation without the densecrf with the dice coefficient""" 38 | net.eval() 39 | tot = 0 40 | n=len(dataset) 41 | for i, b in enumerate(dataset): 42 | # img = b[0] 43 | # true_mask = b[1] 44 | # img = torch.from_numpy(img).unsqueeze(0) 45 | # true_mask = torch.from_numpy(true_mask).unsqueeze(0) 46 | # 47 | # if gpu: 48 | # img = img.cuda() 49 | # true_mask = true_mask.cuda() 50 | 51 | img = torch.from_numpy(b[0]).unsqueeze(0).float() 52 | label = torch.from_numpy(b[1]).unsqueeze(0).float() 53 
| if gpu: 54 | img = img.cuda() 55 | label = label.cuda() 56 | 57 | pred = net(img) 58 | pred_flat = pred.view(-1) 59 | labels_flat = labels.view(-1) 60 | loss = nn.BCEWithLogitsLoss() 61 | loss = loss(pred_flat, labels_flat) 62 | 63 | tot += loss.item() 64 | return tot / n 65 | -------------------------------------------------------------------------------- /utils/foo.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/load.py: -------------------------------------------------------------------------------- 1 | # 2 | # load.py : utils on generators / lists of ids to transform from strings to 3 | # cropped images and masks 4 | 5 | import os 6 | 7 | import numpy as np 8 | import math 9 | from PIL import Image 10 | 11 | from .utils import resize_and_crop, get_square, normalize, hwc_to_chw 12 | import utils.joint_transforms as joint_transforms 13 | 14 | import cv2 15 | 16 | ZUR_COLORMAP = [[0,0,0],[255,255,255],[0,255,0],[150,80,0],[0,125,0],[150,150,255],[100,100,100],[255,255,0],[0,0,150]] 17 | ZUR_CLASSES = ['road','background','grass','ground','tree','water','building','raiway','river'] 18 | 19 | colormap2label = np.zeros(256 ** 3) 20 | for i, cm in enumerate(ZUR_COLORMAP): 21 | colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i 22 | 23 | def get_ids(dir): 24 | """Returns a list of the ids in the directory""" 25 | return (f[:-4] for f in os.listdir(dir)) 26 | 27 | 28 | def split_ids(ids, n=2): 29 | """Split each id in n, creating n tuples (id, k) for each id""" 30 | return ((id, i) for i in range(n) for id in ids) 31 | 32 | 33 | def to_cropped_imgs(ids, dir, suffix, scale): 34 | """From a list of tuples, returns the correct cropped img""" 35 | for id, pos in ids: 36 | im = resize_and_crop(Image.open(dir + id + suffix), scale=scale) 37 | yield get_square(im, pos) 38 | 39 | def get_imgs_and_masks(ids, dir_img, dir_label, crop_size): 40 | """Return all the couples (img, mask)""" 41 | data, labels = [], [] 42 | for id, pos in ids: 43 | fPath = dir_img + id + '.png' 44 | lPath = dir_label + id + '.png' 45 | data.append(np.array(Image.open(fPath))) 46 | labels.append(np.array(Image.open(lPath))) 47 | print('%d images loaded.'%len(data)) 48 | 49 | data, labels = DataAug(data, labels, crop_size) 50 | print('Image augment done. %d Images created.'%len(data)) 51 | # need to transform from HWC to CHW 52 | # imgs_switched = map(hwc_to_chw, data) 53 | imgs_switched = [] 54 | for im in data: 55 | imgs_switched.append(hwc_to_chw(im)) 56 | # for i in range(data.shape[0]): 57 | # print(data[i].shape) 58 | # imgs_normalized = map(normalize, imgs_switched) 59 | 60 | labels_index = [] 61 | for label in labels: 62 | labels_index.append(Color2Index0(label, colormap2label)) 63 | 64 | return list(zip(imgs_switched, labels_index)) 65 | 66 | def get_binary_imgs_and_masks(ids, dir_img, dir_label, crop_size): 67 | """Return all the couples (img, mask)""" 68 | data, labels = [], [] 69 | for id, pos in ids: 70 | fPath = dir_img + id + '.png' 71 | lPath = dir_label + id + '.png' 72 | data.append(np.array(Image.open(fPath).convert("L"))) 73 | labels.append(np.asarray(Image.open(lPath).convert("L"))) 74 | print('%d images loaded.'%len(data)) 75 | 76 | data, labels = DataAug_1C(data, labels, crop_size) 77 | print('Image augment done. 
%d Images created.'%len(data)) 78 | 79 | imgs_switched = [] 80 | for im in data: 81 | imgs_switched.append(np.expand_dims(im, 0)) 82 | 83 | labels_index = [] 84 | for label in labels: 85 | labels_index.append(label/255) 86 | 87 | return list(zip(imgs_switched, labels_index)) 88 | 89 | def read_images(ids, dir_img, dir_label): 90 | n = len(Img_fileList) 91 | data, label = [None] * n, [None] * n 92 | for id, pos in ids: 93 | fPath = dir_img + id + suffix 94 | lPath = dir_label + id + suffix 95 | data[i] = Image.open(fPath) 96 | label[i] = Image.open(lPath) 97 | return data, label 98 | 99 | def DataAug(data, labels, size): 100 | crop_imgs = create_crops(data[0], size) 101 | crop_labels = create_crops(labels[0], size) 102 | for i in range(1, len(data)): 103 | crop_imgs = np.concatenate((crop_imgs, create_crops(data[i], size)), axis=0) 104 | crop_labels = np.concatenate((crop_labels, create_crops(labels[i], size)), axis=0) 105 | # crop_imgs = [] 106 | # crop_labels = [] 107 | # aug_times = [] 108 | # ten_crop_imgs = [] 109 | # ten_crop_labels = [] 110 | # for i in range(len(data)): 111 | # h_rate = data[i].shape[0]/size[0] 112 | # w_rate = data[i].shape[1]/size[1] 113 | # aug_time = min(h_rate,w_rate)*2 114 | # print(aug_time) 115 | # if (aug_time<1.5): aug_time=8 116 | # elif (aug_time<2): aug_time=10 117 | # else: aug_time=18 118 | # aug_times.append(aug_time) 119 | # ten_crop_imgs.append(ten_crop(data[i], size)) 120 | # ten_crop_labels.append(ten_crop(labels[i], size)) 121 | # for t in range(max(aug_times)): 122 | # for i in range(len(data)): 123 | # if(aug_times[i]>t): 124 | # crop_imgs.append(ten_crop_imgs[i][t]) 125 | # crop_labels.append(ten_crop_labels[i][t]) 126 | return crop_imgs, crop_labels 127 | 128 | def DataAug_1C(data, labels, size): 129 | crop_imgs = create_crops_1C(data[0], size) 130 | crop_labels = create_crops_1C(labels[0], size) 131 | for i in range(1, len(data)): 132 | crop_imgs = np.concatenate((crop_imgs, create_crops_1C(data[i], size)), axis=0) 133 | crop_labels = np.concatenate((crop_labels, create_crops_1C(labels[i], size)), axis=0) 134 | # crop_imgs = [] 135 | # crop_labels = [] 136 | # aug_times = [] 137 | # ten_crop_imgs = [] 138 | # ten_crop_labels = [] 139 | # for i in range(len(data)): 140 | # h_rate = data[i].shape[0]/size[0] 141 | # w_rate = data[i].shape[1]/size[1] 142 | # aug_time = min(h_rate,w_rate)*2 143 | # print(aug_time) 144 | # if (aug_time<1.5): aug_time=8 145 | # elif (aug_time<2): aug_time=10 146 | # else: aug_time=18 147 | # aug_times.append(aug_time) 148 | # ten_crop_imgs.append(ten_crop(data[i], size)) 149 | # ten_crop_labels.append(ten_crop(labels[i], size)) 150 | # for t in range(max(aug_times)): 151 | # for i in range(len(data)): 152 | # if(aug_times[i]>t): 153 | # crop_imgs.append(ten_crop_imgs[i][t]) 154 | # crop_labels.append(ten_crop_labels[i][t]) 155 | return crop_imgs, crop_labels 156 | 157 | def Color2Index(ColorLabels, colormap2label): 158 | IndexLabels = np.zeros(ColorLabels.shape[0],ColorLabels.shape[1], 159 | ColorLabels.shape[2], 1) 160 | for i, data in enumerate(ColorLabels): 161 | data = data.astype('int32') 162 | idx = (data[:,:,0] * 256 + data[:,:,1]) * 256 + data[:,:,2] 163 | IndexLabels[i] = colormap2label[idx] 164 | return IndexLabels 165 | 166 | def Color2Index0(ColorLabel, colormap2label): 167 | data = ColorLabel.astype('int32') 168 | idx = (data[:,:,0] * 256 + data[:,:,1]) * 256 + data[:,:,2] 169 | return colormap2label[idx] 170 | 171 | def Index2Color(pred, colormap2label): 172 | x = pred.astype('int32') 173 | return 
colormap2label[x, :] 174 | 175 | def ten_crop(src, size): 176 | """Crop 10 regions from an array. 177 | This is performed same as: 178 | http://chainercv.readthedocs.io/en/stable/reference/transforms.html#ten-crop 179 | 180 | This method crops 10 regions. All regions will be in shape 181 | :obj`size`. These regions consist of 1 center crop and 4 corner 182 | crops and horizontal flips of them. 183 | The crops are ordered in this order. 184 | * center crop 185 | * top-left crop 186 | * bottom-right crop 187 | * top-right crop 188 | * bottom-left crop 189 | * center crop (flipped horizontally) 190 | * top-left crop (flipped horizontally) 191 | * bottom-left crop (flipped horizontally) 192 | * top-right crop (flipped horizontally) 193 | * bottom-right crop (flipped horizontally) 194 | 195 | Parameters 196 | ---------- 197 | src : Numpy array 198 | Input image. 199 | size : tuple 200 | Tuple of length 2, as (width, height) of the cropped areas. 201 | 202 | Returns 203 | ------- 204 | mxnet.nd.NDArray 205 | The cropped images with shape (10, size[1], size[0], C) 206 | 207 | """ 208 | h, w, _ = src.shape 209 | ow, oh = size 210 | 211 | if h < oh or w < ow: 212 | raise ValueError( 213 | "Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 214 | 215 | # h=int(h) 216 | # w = int(w) 217 | # ow = int(ow) 218 | # oh = int(oh) 219 | 220 | tl = src[0:oh, 0:ow, :] 221 | bl = src[h - oh:h, 0:ow, :] 222 | tr = src[0:oh, w - ow:w, :] 223 | br = src[h - oh:h, w - ow:w, :] 224 | center = src[(h - oh) // 2:(h + oh) // 2, (w - ow) // 2:(w + ow) // 2, :] 225 | 226 | tl_f = cv2.flip(tl, -1) 227 | bl_f = cv2.flip(bl, -1) 228 | tr_f = cv2.flip(tr, -1) 229 | br_f = cv2.flip(br, -1) 230 | center_f = cv2.flip(center, -1) 231 | print(center_f.shape) 232 | print(center_rf.shape) 233 | print(center_tf.shape) 234 | print(center_bf.shape) 235 | crops = np.stack([tl, br, tr, bl, tl_f, br_f, tr_f, bl_f, center, center_f], axis=0) 236 | return crops 237 | 238 | def create_crops(img, size): 239 | # print(img.shape) 240 | h = img.shape[0] 241 | w = img.shape[1] 242 | c_h = size[0] 243 | c_w = size[1] 244 | if h < c_h or w < c_w: 245 | raise ValueError( 246 | "Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 247 | 248 | h_rate = h/c_h 249 | w_rate = w/c_w 250 | h_times = math.ceil(h_rate) 251 | w_times = math.ceil(w_rate) 252 | stride_h = math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 253 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 254 | crop_imgs = [] 255 | for j in range(h_times): 256 | for i in range(w_times): 257 | s_h = int(j*c_h - j*stride_h) 258 | if(j==(h_times-1)): s_h = h - c_h 259 | e_h = s_h + c_h 260 | s_w = int(i*c_w - i*stride_w) 261 | if(i==(w_times-1)): s_w = w - c_w 262 | e_w = s_w + c_w 263 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 264 | crop_im = img[s_h:e_h, s_w:e_w, :] 265 | crop_imgs.append(crop_im) 266 | 267 | crop_imgs_f = [] 268 | for im in crop_imgs: 269 | crop_imgs_f.append(cv2.flip(im, -1)) 270 | 271 | crops = np.concatenate((np.array(crop_imgs), np.array(crop_imgs_f)), axis=0) 272 | # print(crops.shape) 273 | return crops 274 | 275 | def create_crops_1C(img, size): 276 | # print(img.shape) 277 | h = img.shape[0] 278 | w = img.shape[1] 279 | c_h = size[0] 280 | c_w = size[1] 281 | if h < c_h or w < c_w: 282 | raise ValueError( 283 | "Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 284 | 285 | h_rate = h/c_h 286 | w_rate = w/c_w 287 | h_times = math.ceil(h_rate) 288 | w_times = math.ceil(w_rate) 289 | stride_h 
= math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 290 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 291 | crop_imgs = [] 292 | for j in range(h_times): 293 | for i in range(w_times): 294 | s_h = int(j*c_h - j*stride_h) 295 | if(j==(h_times-1)): s_h = h - c_h 296 | e_h = s_h + c_h 297 | s_w = int(i*c_w - i*stride_w) 298 | if(i==(w_times-1)): s_w = w - c_w 299 | e_w = s_w + c_w 300 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 301 | crop_im = img[s_h:e_h, s_w:e_w] 302 | crop_imgs.append(crop_im) 303 | 304 | crop_imgs_f = [] 305 | for im in crop_imgs: 306 | crop_imgs_f.append(cv2.flip(im, -1)) 307 | 308 | crops = np.concatenate((np.array(crop_imgs), np.array(crop_imgs_f)), axis=0) 309 | # print(crops.shape) 310 | return crops 311 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | # Recommend 6 | class CrossEntropyLoss2d(nn.Module): 7 | def __init__(self, weight=None, ignore_index=-1): 8 | super(CrossEntropyLoss2d, self).__init__() 9 | self.nll_loss = nn.NLLLoss(weight=weight, ignore_index=ignore_index, 10 | reduction='elementwise_mean') 11 | 12 | def forward(self, inputs, targets): 13 | return self.nll_loss(F.log_softmax(inputs, dim=1), targets) 14 | 15 | 16 | # this may be unstable sometimes.Notice set the size_average 17 | def CrossEntropy2d(input, target, weight=None, size_average=False): 18 | # input:(n, c, h, w) target:(n, h, w) 19 | n, c, h, w = input.size() 20 | 21 | input = input.transpose(1, 2).transpose(2, 3).contiguous() 22 | input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0].view(-1, c) 23 | 24 | target_mask = target >= 0 25 | target = target[target_mask] 26 | #loss = F.nll_loss(F.log_softmax(input), target, weight=weight, size_average=False) 27 | loss = F.cross_entropy(input, target, weight=weight, size_average=False) 28 | if size_average: 29 | loss /= target_mask.sum().data[0] 30 | 31 | return loss 32 | 33 | def weighted_BCE(output, target, weight_pos=None, weight_neg=None): 34 | output = torch.clamp(output,min=1e-8,max=1-1e-8) 35 | 36 | if weight_pos is not None: 37 | loss = weight_pos * (target * torch.log(output)) + \ 38 | weight_neg * ((1 - target) * torch.log(1 - output)) 39 | else: 40 | loss = target * torch.log(output) + (1 - target) * torch.log(1 - output) 41 | 42 | return torch.neg(torch.mean(loss)) 43 | 44 | def weighted_BCE_logits(logit_pixel, truth_pixel, weight_pos=0.25, weight_neg=0.75): 45 | logit = logit_pixel.view(-1) 46 | truth = truth_pixel.view(-1) 47 | assert(logit.shape==truth.shape) 48 | 49 | loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none') 50 | 51 | pos = (truth>0.5).float() 52 | neg = (truth<0.5).float() 53 | pos_num = pos.sum().item() + 1e-12 54 | neg_num = neg.sum().item() + 1e-12 55 | loss = (weight_pos*pos*loss/pos_num + weight_neg*neg*loss/neg_num).sum() 56 | 57 | return loss 58 | 59 | class FocalLoss(nn.Module): 60 | def __init__(self, alpha=0.5, gamma=2, weight=None, ignore_index=255): 61 | super().__init__() 62 | self.alpha = alpha 63 | self.gamma = gamma 64 | self.weight = weight 65 | self.ignore_index = ignore_index 66 | self.ce_fn = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index) 67 | 68 | def forward(self, preds, labels): 69 | logpt = -self.ce_fn(preds, labels) 70 | pt = torch.exp(logpt) 71 | loss = -((1 - pt) ** self.gamma) * 
self.alpha * logpt 72 | return loss 73 | 74 | class FocalLoss2d(nn.Module): 75 | def __init__(self, gamma=0, weight=None, size_average=True, ignore_index=-1): 76 | super(FocalLoss2d, self).__init__() 77 | self.gamma = gamma 78 | self.weight = weight 79 | self.size_average = size_average 80 | self.ignore_index = ignore_index 81 | 82 | def forward(self, input, target): 83 | if input.dim()>2: 84 | input = input.contiguous().view(input.size(0), input.size(1), -1) 85 | input = input.transpose(1,2) 86 | input = input.contiguous().view(-1, input.size(2)).squeeze() 87 | if target.dim()==4: 88 | target = target.contiguous().view(target.size(0), target.size(1), -1) 89 | target = target.transpose(1,2) 90 | target = target.contiguous().view(-1, target.size(2)).squeeze() 91 | elif target.dim()==3: 92 | target = target.view(-1) 93 | else: 94 | target = target.view(-1, 1) 95 | 96 | # compute the negative likelyhood 97 | weight = Variable(self.weight) 98 | logpt = -F.cross_entropy(input, target, ignore_index=self.ignore_index) 99 | pt = torch.exp(logpt) 100 | 101 | # compute the loss 102 | loss = -((1-pt)**self.gamma) * logpt 103 | 104 | # averaging (or not) loss 105 | if self.size_average: 106 | return loss.mean() 107 | else: 108 | return loss.sum() 109 | 110 | class ChangeSimilarity(nn.Module): 111 | """input: x1, x2 multi-class predictions, c = class_num 112 | label_change: changed part 113 | """ 114 | def __init__(self, reduction='mean'): 115 | super(ChangeSimilarity, self).__init__() 116 | self.loss_f = nn.CosineEmbeddingLoss(margin=0.1, reduction=reduction) 117 | #self.loss_show = nn.CosineEmbeddingLoss(margin=0., reduction='none') 118 | 119 | def forward(self, x1, x2, label_change): 120 | b,c,h,w = x1.size() 121 | #x1 = F.softmax(x1, dim=1) 122 | #x2 = F.softmax(x2, dim=1) 123 | x1 = x1.permute(0,2,3,1) 124 | x2 = x2.permute(0,2,3,1) 125 | x1 = torch.reshape(x1,[b*h*w,c]) 126 | x2 = torch.reshape(x2,[b*h*w,c]) 127 | 128 | label_unchange = ~label_change.bool() 129 | target = label_unchange.float() 130 | target = target - label_change.float() 131 | target = torch.reshape(target,[b*h*w]) 132 | 133 | loss = self.loss_f(x1, x2, target) 134 | #loss_show = self.loss_show(x1, x2, target) 135 | return loss 136 | 137 | class ChangeSalience(nn.Module): 138 | """input: x1, x2 multi-class predictions, c = class_num 139 | label_change: changed part 140 | """ 141 | def __init__(self, reduction='mean'): 142 | super(ChangeSimilarity, self).__init__() 143 | self.loss_f = nn.MSELoss(reduction=reduction) 144 | 145 | def forward(self, x1, x2, label_change): 146 | b,c,h,w = x1.size() 147 | x1 = F.softmax(x1, dim=1)[:,0,:,:] 148 | x2 = F.softmax(x2, dim=1)[:,0,:,:] 149 | 150 | loss = self.loss_f(x1, x2.detach()) + self.loss_f(x2, x1.detach()) 151 | return loss*0.5 152 | 153 | 154 | def pix_loss(output, target, pix_weight, ignore_index=None): 155 | # Calculate log probabilities 156 | if ignore_index is not None: 157 | active_pos = 1-(target==ignore_index).unsqueeze(1).cuda().float() 158 | pix_weight *= active_pos 159 | 160 | batch_size, _, H, W = output.size() 161 | logp = F.log_softmax(output, dim=1) 162 | # Gather log probabilities with respect to target 163 | logp = logp.gather(1, target.view(batch_size, 1, H, W)) 164 | # Multiply with weights 165 | weighted_logp = (logp * pix_weight).view(batch_size, -1) 166 | # Rescale so that loss is in approx. 
same interval 167 | weighted_loss = weighted_logp.sum(1) / pix_weight.view(batch_size, -1).sum(1) 168 | # Average over mini-batch 169 | weighted_loss = -1.0 * weighted_loss.mean() 170 | return weighted_loss 171 | 172 | def make_one_hot(input, num_classes): 173 | """Convert class index tensor to one hot encoding tensor. 174 | Args: 175 | input: A tensor of shape [N, 1, *] 176 | num_classes: An int of number of class 177 | Returns: 178 | A tensor of shape [N, num_classes, *] 179 | """ 180 | shape = np.array(input.shape) 181 | shape[1] = num_classes 182 | shape = tuple(shape) 183 | result = torch.zeros(shape) 184 | result = result.scatter_(1, input.cpu(), 1) 185 | 186 | return result 187 | 188 | 189 | class BinaryDiceLoss(nn.Module): 190 | """Dice loss of binary class 191 | Args: 192 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 193 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 194 | predict: A tensor of shape [N, *] 195 | target: A tensor of shape same with predict 196 | reduction: Reduction method to apply, return mean over batch if 'mean', 197 | return sum if 'sum', return a tensor of shape [N,] if 'none' 198 | Returns: 199 | Loss tensor according to arg reduction 200 | Raise: 201 | Exception if unexpected reduction 202 | """ 203 | def __init__(self, smooth=1, p=2, reduction='mean'): 204 | super(BinaryDiceLoss, self).__init__() 205 | self.smooth = smooth 206 | self.p = p 207 | self.reduction = reduction 208 | 209 | def forward(self, predict, target): 210 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 211 | predict = predict.contiguous().view(predict.shape[0], -1) 212 | target = target.contiguous().view(target.shape[0], -1) 213 | 214 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth 215 | den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth 216 | 217 | loss = 1 - num / den 218 | 219 | if self.reduction == 'mean': 220 | return loss.mean() 221 | elif self.reduction == 'sum': 222 | return loss.sum() 223 | elif self.reduction == 'none': 224 | return loss 225 | else: 226 | raise Exception('Unexpected reduction {}'.format(self.reduction)) 227 | 228 | 229 | class DiceLoss(nn.Module): 230 | """Dice loss, need one hot encode input 231 | Args: 232 | weight: An array of shape [num_classes,] 233 | ignore_index: class index to ignore 234 | predict: A tensor of shape [N, C, *] 235 | target: A tensor of same shape with predict 236 | other args pass to BinaryDiceLoss 237 | Return: 238 | same as BinaryDiceLoss 239 | """ 240 | def __init__(self, weight=None, ignore_index=None, **kwargs): 241 | super(DiceLoss, self).__init__() 242 | self.kwargs = kwargs 243 | self.weight = weight 244 | self.ignore_index = ignore_index 245 | 246 | def forward(self, predict, target): 247 | assert predict.shape == target.shape, 'predict & target shape do not match' 248 | dice = BinaryDiceLoss(**self.kwargs) 249 | total_loss = 0 250 | predict = F.softmax(predict, dim=1) 251 | 252 | for i in range(target.shape[1]): 253 | if i != self.ignore_index: 254 | dice_loss = dice(predict[:, i], target[:, i]) 255 | if self.weight is not None: 256 | assert self.weight.shape[0] == target.shape[1], \ 257 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 258 | dice_loss *= self.weights[i] 259 | total_loss += dice_loss 260 | 261 | return total_loss/target.shape[1] -------------------------------------------------------------------------------- /utils/misc.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from math import ceil 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torch.autograd import Variable 9 | 10 | 11 | def check_mkdir(dir_name): 12 | if not os.path.exists(dir_name): 13 | os.mkdir(dir_name) 14 | 15 | 16 | def initialize_weights(*models): 17 | for model in models: 18 | for module in model.modules(): 19 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 20 | nn.init.kaiming_normal_(module.weight) 21 | if module.bias is not None: 22 | module.bias.data.zero_() 23 | elif isinstance(module, nn.BatchNorm2d): 24 | module.weight.data.fill_(1) 25 | module.bias.data.zero_() 26 | 27 | 28 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 29 | factor = (kernel_size + 1) // 2 30 | if kernel_size % 2 == 1: 31 | center = factor - 1 32 | else: 33 | center = factor - 0.5 34 | og = np.ogrid[:kernel_size, :kernel_size] 35 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 36 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) 37 | weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt 38 | return torch.from_numpy(weight).float() 39 | 40 | 41 | 42 | def _fast_hist(label_pred, label_true, num_classes): 43 | mask = (label_true >= 0) & (label_true < num_classes) 44 | hist = np.bincount( 45 | num_classes * label_true[mask].astype(int) + 46 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 47 | return hist 48 | 49 | 50 | def evaluate(predictions, gts, num_classes): 51 | hist = np.zeros((num_classes, num_classes)) 52 | for lp, lt in zip(predictions, gts): 53 | hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes) 54 | # axis 0: gt, axis 1: prediction 55 | acc = np.diag(hist).sum() / hist.sum() 56 | acc_cls = np.diag(hist) / hist.sum(axis=1) 57 | acc_cls = np.nanmean(acc_cls) 58 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 59 | mean_iu = np.nanmean(iu) 60 | freq = hist.sum(axis=1) / hist.sum() 61 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 62 | return acc, acc_cls, mean_iu, fwavacc 63 | 64 | 65 | class PolyLR(object): 66 | def __init__(self, optimizer, curr_iter, max_iter, lr_decay): 67 | self.max_iter = float(max_iter) 68 | self.init_lr_groups = [] 69 | for p in optimizer.param_groups: 70 | self.init_lr_groups.append(p['lr']) 71 | self.param_groups = optimizer.param_groups 72 | self.curr_iter = curr_iter 73 | self.lr_decay = lr_decay 74 | 75 | def step(self): 76 | for idx, p in enumerate(self.param_groups): 77 | p['lr'] = self.init_lr_groups[idx] * (1 - self.curr_iter / self.max_iter) ** self.lr_decay 78 | 79 | 80 | # just a try, not recommend to use 81 | class Conv2dDeformable(nn.Module): 82 | def __init__(self, regular_filter, cuda=True): 83 | super(Conv2dDeformable, self).__init__() 84 | assert isinstance(regular_filter, nn.Conv2d) 85 | self.regular_filter = regular_filter 86 | self.offset_filter = nn.Conv2d(regular_filter.in_channels, 2 * regular_filter.in_channels, kernel_size=3, 87 | padding=1, bias=False) 88 | self.offset_filter.weight.data.normal_(0, 0.0005) 89 | self.input_shape = None 90 | self.grid_w = None 91 | self.grid_h = None 92 | self.cuda = cuda 93 | 94 | def forward(self, x): 95 | x_shape = x.size() # (b, c, h, w) 96 | offset = self.offset_filter(x) # (b, 2*c, h, w) 97 | offset_w, offset_h = torch.split(offset, 
self.regular_filter.in_channels, 1) # (b, c, h, w) 98 | offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w) 99 | offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w) 100 | if not self.input_shape or self.input_shape != x_shape: 101 | self.input_shape = x_shape 102 | grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, x_shape[3]), np.linspace(-1, 1, x_shape[2])) # (h, w) 103 | grid_w = torch.Tensor(grid_w) 104 | grid_h = torch.Tensor(grid_h) 105 | if self.cuda: 106 | grid_w = grid_w.cuda() 107 | grid_h = grid_h.cuda() 108 | self.grid_w = nn.Parameter(grid_w) 109 | self.grid_h = nn.Parameter(grid_h) 110 | offset_w = offset_w + self.grid_w # (b*c, h, w) 111 | offset_h = offset_h + self.grid_h # (b*c, h, w) 112 | x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])).unsqueeze(1) # (b*c, 1, h, w) 113 | x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3)) # (b*c, h, w) 114 | x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])) # (b, c, h, w) 115 | x = self.regular_filter(x) 116 | return x 117 | 118 | 119 | def sliced_forward(single_forward): 120 | def _pad(x, crop_size): 121 | h, w = x.size()[2:] 122 | pad_h = max(crop_size - h, 0) 123 | pad_w = max(crop_size - w, 0) 124 | x = F.pad(x, (0, pad_w, 0, pad_h)) 125 | return x, pad_h, pad_w 126 | 127 | def wrapper(self, x): 128 | batch_size, _, ori_h, ori_w = x.size() 129 | if self.training and self.use_aux: 130 | outputs_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 131 | aux_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 132 | for s in self.scales: 133 | new_size = (int(ori_h * s), int(ori_w * s)) 134 | scaled_x = F.upsample(x, size=new_size, mode='bilinear') 135 | scaled_x = Variable(scaled_x).cuda() 136 | scaled_h, scaled_w = scaled_x.size()[2:] 137 | long_size = max(scaled_h, scaled_w) 138 | print(scaled_x.size()) 139 | 140 | if long_size > self.crop_size: 141 | count = torch.zeros((scaled_h, scaled_w)) 142 | outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 143 | aux_outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 144 | stride = int(ceil(self.crop_size * self.stride_rate)) 145 | h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1 146 | w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1 147 | for yy in range(h_step_num): 148 | for xx in range(w_step_num): 149 | sy, sx = yy * stride, xx * stride 150 | ey, ex = sy + self.crop_size, sx + self.crop_size 151 | x_sub = scaled_x[:, :, sy: ey, sx: ex] 152 | x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size) 153 | print(x_sub.size()) 154 | outputs_sub, aux_sub = single_forward(self, x_sub) 155 | 156 | if sy + self.crop_size > scaled_h: 157 | outputs_sub = outputs_sub[:, :, : -pad_h, :] 158 | aux_sub = aux_sub[:, :, : -pad_h, :] 159 | 160 | if sx + self.crop_size > scaled_w: 161 | outputs_sub = outputs_sub[:, :, :, : -pad_w] 162 | aux_sub = aux_sub[:, :, :, : -pad_w] 163 | 164 | outputs[:, :, sy: ey, sx: ex] = outputs_sub 165 | aux_outputs[:, :, sy: ey, sx: ex] = aux_sub 166 | 167 | count[sy: ey, sx: ex] += 1 168 | count = Variable(count).cuda() 169 | outputs = (outputs / count) 170 | aux_outputs = (outputs / count) 171 | else: 172 | scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size) 173 | outputs, aux_outputs = single_forward(self, scaled_x) 174 | outputs = outputs[:, :, : -pad_h, : -pad_w] 175 | 
aux_outputs = aux_outputs[:, :, : -pad_h, : -pad_w] 176 | outputs_all_scales += outputs 177 | aux_all_scales += aux_outputs 178 | return outputs_all_scales / len(self.scales), aux_all_scales 179 | else: 180 | outputs_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 181 | for s in self.scales: 182 | new_size = (int(ori_h * s), int(ori_w * s)) 183 | scaled_x = F.upsample(x, size=new_size, mode='bilinear') 184 | scaled_h, scaled_w = scaled_x.size()[2:] 185 | long_size = max(scaled_h, scaled_w) 186 | 187 | if long_size > self.crop_size: 188 | count = torch.zeros((scaled_h, scaled_w)) 189 | outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 190 | stride = int(ceil(self.crop_size * self.stride_rate)) 191 | h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1 192 | w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1 193 | for yy in range(h_step_num): 194 | for xx in range(w_step_num): 195 | sy, sx = yy * stride, xx * stride 196 | ey, ex = sy + self.crop_size, sx + self.crop_size 197 | x_sub = scaled_x[:, :, sy: ey, sx: ex] 198 | x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size) 199 | 200 | outputs_sub = single_forward(self, x_sub) 201 | 202 | if sy + self.crop_size > scaled_h: 203 | outputs_sub = outputs_sub[:, :, : -pad_h, :] 204 | 205 | if sx + self.crop_size > scaled_w: 206 | outputs_sub = outputs_sub[:, :, :, : -pad_w] 207 | 208 | outputs[:, :, sy: ey, sx: ex] = outputs_sub 209 | 210 | count[sy: ey, sx: ex] += 1 211 | count = Variable(count).cuda() 212 | outputs = (outputs / count) 213 | else: 214 | scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size) 215 | outputs = single_forward(self, scaled_x) 216 | outputs = outputs[:, :, : -pad_h, : -pad_w] 217 | outputs_all_scales += outputs 218 | return outputs_all_scales 219 | 220 | return wrapper 221 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | #random.seed(0) 5 | #np.random.seed(0) 6 | import cv2 7 | import skimage 8 | from skimage import transform as sktransf 9 | import matplotlib.pyplot as plt 10 | 11 | def showIMG(img): 12 | plt.imshow(img) 13 | plt.show() 14 | return 0 15 | 16 | def rand_flip(img, label): 17 | r = random.random() 18 | # showIMG(img.transpose((1, 2, 0))) 19 | if r < 0.25: 20 | return img, label 21 | elif r < 0.5: 22 | return np.flip(img, axis=0).copy(), np.flip(label, axis=0).copy() 23 | elif r < 0.75: 24 | return np.flip(img, axis=1).copy(), np.flip(label, axis=1).copy() 25 | else: 26 | return img[::-1, ::-1, :].copy(), label[::-1, ::-1].copy() 27 | 28 | def rand_flip_2s(img_s, label_s, img, label): 29 | r = random.random() 30 | # showIMG(img.transpose((1, 2, 0))) 31 | if r < 0.25: 32 | return img_s, label_s, img, label 33 | elif r < 0.5: 34 | return np.flip(img_s, axis=0).copy(), np.flip(label_s, axis=0).copy(), np.flip(img, axis=0).copy(), np.flip(label, axis=0).copy() 35 | elif r < 0.75: 36 | return np.flip(img_s, axis=1).copy(), np.flip(label_s, axis=1).copy(), np.flip(img, axis=1).copy(), np.flip(label, axis=1).copy() 37 | else: 38 | return img_s[::-1, ::-1, :].copy(), label_s[::-1, ::-1].copy(), img[::-1, ::-1, :].copy(), label[::-1, ::-1].copy() 39 | 40 | def rand_flip_mix(img, label, x_s): 41 | r = random.random() 42 | # showIMG(img.transpose((1, 2, 0))) 43 | if r < 0.25: 44 | return img, label, x_s 45 | elif r < 0.5: 
46 | return np.flip(img, axis=0).copy(), np.flip(label, axis=0).copy(), np.flip(x_s, axis=0).copy() 47 | elif r < 0.75: 48 | return np.flip(img, axis=1).copy(), np.flip(label, axis=1).copy(), np.flip(x_s, axis=1).copy() 49 | else: 50 | return img[::-1, ::-1, :].copy(), label[::-1, ::-1].copy(), x_s[::-1, ::-1, :].copy() 51 | 52 | def rand_rotate(img, label): 53 | r = random.randint(0,179) 54 | # print(r) 55 | # showIMG(img.transpose((1, 2, 0))) 56 | img_rotate = np.asarray(sktransf.rotate(img, r, order=1, mode='symmetric', 57 | preserve_range=True), np.float) 58 | label_rotate = np.asarray(sktransf.rotate(label, r, order=0, mode='constant', 59 | cval=0, preserve_range=True), np.uint8) 60 | # print(img_rotate[0:10, 0:10, :]) 61 | # print(label_rotate[0:10, 0:10]) 62 | # h_s = image 63 | return img_rotate, label_rotate 64 | 65 | def rand_rotate_crop(img, label): 66 | r = random.randint(0,179) 67 | image_height, image_width = img.shape[0:2] 68 | im_rotated = rotate_image(img, r, order=1) 69 | l_rotated = rotate_image(label, r, order=0) 70 | crop_w, crop_h = largest_rotated_rect(image_width, image_height, math.radians(r)) 71 | im_rotated_cropped = crop_around_center(im_rotated, crop_w, crop_h) 72 | l_rotated_cropped = crop_around_center(l_rotated, crop_w, crop_h) 73 | # print(img_rotate[0:10, 0:10, :]) 74 | # print(label_rotate[0:10, 0:10]) 75 | # h_s = image 76 | return im_rotated_cropped, l_rotated_cropped 77 | 78 | def rotate_image(image, angle, order=0): 79 | """ 80 | Rotates an OpenCV 2 / NumPy image about it's centre by the given angle 81 | (in degrees). The returned image will be large enough to hold the entire 82 | new image, with a black background 83 | """ 84 | 85 | # Get the image size 86 | # No that's not an error - NumPy stores image matricies backwards 87 | image_size = (image.shape[1], image.shape[0]) 88 | image_center = tuple(np.array(image_size) / 2) 89 | 90 | # Convert the OpenCV 3x2 rotation matrix to 3x3 91 | rot_mat = np.vstack( 92 | [cv2.getRotationMatrix2D(image_center, angle, 1.0), [0, 0, 1]] 93 | ) 94 | 95 | rot_mat_notranslate = np.matrix(rot_mat[0:2, 0:2]) 96 | 97 | # Shorthand for below calcs 98 | image_w2 = image_size[0] * 0.5 99 | image_h2 = image_size[1] * 0.5 100 | 101 | # Obtain the rotated coordinates of the image corners 102 | rotated_coords = [ 103 | (np.array([-image_w2, image_h2]) * rot_mat_notranslate).A[0], 104 | (np.array([ image_w2, image_h2]) * rot_mat_notranslate).A[0], 105 | (np.array([-image_w2, -image_h2]) * rot_mat_notranslate).A[0], 106 | (np.array([ image_w2, -image_h2]) * rot_mat_notranslate).A[0] 107 | ] 108 | 109 | # Find the size of the new image 110 | x_coords = [pt[0] for pt in rotated_coords] 111 | x_pos = [x for x in x_coords if x > 0] 112 | x_neg = [x for x in x_coords if x < 0] 113 | 114 | y_coords = [pt[1] for pt in rotated_coords] 115 | y_pos = [y for y in y_coords if y > 0] 116 | y_neg = [y for y in y_coords if y < 0] 117 | 118 | right_bound = max(x_pos) 119 | left_bound = min(x_neg) 120 | top_bound = max(y_pos) 121 | bot_bound = min(y_neg) 122 | 123 | new_w = int(abs(right_bound - left_bound)) 124 | new_h = int(abs(top_bound - bot_bound)) 125 | 126 | # We require a translation matrix to keep the image centred 127 | trans_mat = np.matrix([ 128 | [1, 0, int(new_w * 0.5 - image_w2)], 129 | [0, 1, int(new_h * 0.5 - image_h2)], 130 | [0, 0, 1] 131 | ]) 132 | 133 | # Compute the tranform for the combined rotation and translation 134 | affine_mat = (np.matrix(trans_mat) * np.matrix(rot_mat))[0:2, :] 135 | 136 | # Apply the transform 137 
| flag = cv2.INTER_NEAREST 138 | if order == 1: flag = cv2.INTER_LINEAR 139 | elif order == 2: flag = cv2.INTER_AREA 140 | elif order > 2: flag = cv2.INTER_CUBIC 141 | 142 | result = cv2.warpAffine( 143 | image, 144 | affine_mat, 145 | (new_w, new_h), 146 | flags=flag 147 | ) 148 | 149 | return result 150 | 151 | def rand_rotate_mix(img, label, x_s): 152 | r = random.randint(0,179) 153 | # print(r) 154 | # showIMG(img.transpose((1, 2, 0))) 155 | img_rotate = np.asarray(sktransf.rotate(img, r, order=1, mode='symmetric', 156 | preserve_range=True), np.float) 157 | label_rotate = np.asarray(sktransf.rotate(label, r, order=0, mode='constant', 158 | cval=0, preserve_range=True), np.uint8) 159 | x_s_rotate = np.asarray(sktransf.rotate(x_s, r, order=0, mode='symmetric', 160 | cval=0, preserve_range=True), np.uint8) 161 | # print(img_rotate[0:10, 0:10, :]) 162 | # print(label_rotate[0:10, 0:10]) 163 | # h_s = image 164 | return img_rotate, label_rotate, x_s_rotate 165 | 166 | def create_crops(ims, labels, size): 167 | crop_imgs = [] 168 | crop_labels = [] 169 | label_dims = len(labels[0].shape) 170 | for img, label, in zip(ims, labels): 171 | h = img.shape[0] 172 | w = img.shape[1] 173 | c_h = size[0] 174 | c_w = size[1] 175 | if h < c_h or w < c_w: 176 | print("Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 177 | crop_imgs.append(img) 178 | crop_labels.append(label) 179 | continue 180 | h_rate = h/c_h 181 | w_rate = w/c_w 182 | h_times = math.ceil(h_rate) 183 | w_times = math.ceil(w_rate) 184 | if h_times==1: stride_h=0 185 | else: 186 | stride_h = math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 187 | if w_times==1: stride_w=0 188 | else: 189 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 190 | for j in range(h_times): 191 | for i in range(w_times): 192 | s_h = int(j*c_h - j*stride_h) 193 | if(j==(h_times-1)): s_h = h - c_h 194 | e_h = s_h + c_h 195 | s_w = int(i*c_w - i*stride_w) 196 | if(i==(w_times-1)): s_w = w - c_w 197 | e_w = s_w + c_w 198 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 199 | # print('%d %d %d %d'%(s_h_s, e_h_s, s_w_s, e_w_s)) 200 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 201 | if label_dims==2: 202 | crop_labels.append(label[s_h:e_h, s_w:e_w]) 203 | else: 204 | crop_labels.append(label[s_h:e_h, s_w:e_w, :]) 205 | 206 | print('Sliding crop finished. %d images created.' %len(crop_imgs)) 207 | return crop_imgs, crop_labels 208 | 209 | def create_crops_onlyimgs(ims, size): 210 | crop_imgs = [] 211 | for img in ims: 212 | h = img.shape[0] 213 | w = img.shape[1] 214 | c_h = size[0] 215 | c_w = size[1] 216 | if h < c_h or w < c_w: 217 | print("Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 218 | continue 219 | h_rate = h/c_h 220 | w_rate = w/c_w 221 | h_times = math.ceil(h_rate) 222 | w_times = math.ceil(w_rate) 223 | stride_h = math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 224 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 225 | for j in range(h_times): 226 | for i in range(w_times): 227 | s_h = int(j*c_h - j*stride_h) 228 | if(j==(h_times-1)): s_h = h - c_h 229 | e_h = s_h + c_h 230 | s_w = int(i*c_w - i*stride_w) 231 | if(i==(w_times-1)): s_w = w - c_w 232 | e_w = s_w + c_w 233 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 234 | # print('%d %d %d %d'%(s_h_s, e_h_s, s_w_s, e_w_s)) 235 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 236 | 237 | print('Sliding crop finished. %d images created.' 
%len(crop_imgs)) 238 | return crop_imgs 239 | 240 | def sliding_crop_single_img(img, size): 241 | crop_imgs = [] 242 | h = img.shape[0] 243 | w = img.shape[1] 244 | c_h = size[0] 245 | c_w = size[1] 246 | assert h >= c_h and w >= c_w, "Cannot crop area from image." 247 | h_rate = h/c_h 248 | w_rate = w/c_w 249 | h_times = math.ceil(h_rate) 250 | w_times = math.ceil(w_rate) 251 | stride_h = math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 252 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 253 | for j in range(h_times): 254 | for i in range(w_times): 255 | s_h = int(j*c_h - j*stride_h) 256 | if(j==(h_times-1)): s_h = h - c_h 257 | e_h = s_h + c_h 258 | s_w = int(i*c_w - i*stride_w) 259 | if(i==(w_times-1)): s_w = w - c_w 260 | e_w = s_w + c_w 261 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 262 | # print('%d %d %d %d'%(s_h_s, e_h_s, s_w_s, e_w_s)) 263 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 264 | 265 | #print('Sliding crop finished. %d images created.' %len(crop_imgs)) 266 | return crop_imgs 267 | 268 | def slidding_crop_WC(imgs_s, labels_s, ims, labels, crop_size_global, crop_size_local, scale=8): 269 | crop_imgs_s = [] 270 | crop_labels_s = [] 271 | crop_imgs = [] 272 | crop_labels = [] 273 | c_h = crop_size_local 274 | c_w = crop_size_local 275 | label_dims = len(labels[0].shape) 276 | for img_s, label_s, img, label in zip(imgs_s, labels_s, ims, labels): 277 | h = img.shape[0] 278 | w = img.shape[1] 279 | offset = int((crop_size_global-crop_size_local)/2) 280 | if h < crop_size_local or w < crop_size_local: 281 | print("Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 282 | crop_imgs.append(img) 283 | crop_labels.append(label) 284 | continue 285 | h_rate = h/crop_size_local 286 | w_rate = w/crop_size_local 287 | h_times = math.ceil(h_rate) 288 | w_times = math.ceil(w_rate) 289 | if h_times==1: stride_h=0 290 | else: 291 | stride_h = math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 292 | if w_times==1: stride_w=0 293 | else: 294 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 295 | for j in range(h_times): 296 | for i in range(w_times): 297 | s_h = int(j*c_h - j*stride_h) 298 | if(j==(h_times-1)): s_h = h - c_h 299 | e_h = s_h + c_h 300 | s_w = int(i*c_w - i*stride_w) 301 | if(i==(w_times-1)): s_w = w - c_w 302 | e_w = s_w + c_w 303 | 304 | s_h_s = int(s_h/scale) 305 | s_w_s = int(s_w/scale) 306 | e_h_s = int((e_h+2*offset)/scale) 307 | e_w_s = int((e_w+2*offset)/scale) 308 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 309 | # print('%d %d %d %d'%(s_h_s, e_h_s, s_w_s, e_w_s)) 310 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 311 | crop_imgs_s.append(img_s[s_h_s:e_h_s, s_w_s:e_w_s, :]) 312 | if label_dims==2: 313 | crop_labels.append(label[s_h:e_h, s_w:e_w]) 314 | crop_labels_s.append(label_s[s_h_s:e_h_s, s_w_s:e_w_s]) 315 | else: 316 | crop_labels.append(label[s_h:e_h, s_w:e_w, :]) 317 | crop_labels_s.append(label_s[s_h_s:e_h_s, s_w_s:e_w_s, :]) 318 | 319 | print('Sliding crop finished. %d images created.' 
%len(crop_imgs)) 320 | return crop_imgs_s, crop_labels_s, crop_imgs, crop_labels 321 | 322 | def center_crop(ims, labels, size): 323 | crop_imgs = [] 324 | crop_labels = [] 325 | for img, label in zip(ims, labels): 326 | h = img.shape[0] 327 | w = img.shape[1] 328 | c_h = size[0] 329 | c_w = size[1] 330 | if h < c_h or w < c_w: 331 | print("Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 332 | continue 333 | s_h = int(h/2 - c_h/2) 334 | e_h = s_h + c_h 335 | s_w = int(w/2 - c_w/2) 336 | e_w = s_w + c_w 337 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 338 | crop_labels.append(label[s_h:e_h, s_w:e_w, :]) 339 | 340 | print('Center crop finished. %d images created.' %len(crop_imgs)) 341 | return crop_imgs, crop_labels 342 | 343 | def five_crop(ims, labels, size): 344 | crop_imgs = [] 345 | crop_labels = [] 346 | for img, label in zip(ims, labels): 347 | h = img.shape[0] 348 | w = img.shape[1] 349 | c_h = size[0] 350 | c_w = size[1] 351 | if h < c_h or w < c_w: 352 | print("Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 353 | continue 354 | s_h = int(h/2 - c_h/2) 355 | e_h = s_h + c_h 356 | s_w = int(w/2 - c_w/2) 357 | e_w = s_w + c_w 358 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 359 | crop_labels.append(label[s_h:e_h, s_w:e_w, :]) 360 | 361 | crop_imgs.append(img[0:c_h, 0:c_w, :]) 362 | crop_labels.append(label[0:c_h, 0:c_w, :]) 363 | crop_imgs.append(img[h-c_h:h, w-c_w:w, :]) 364 | crop_labels.append(label[h-c_h:h, w-c_w:w, :]) 365 | crop_imgs.append(img[0:c_h, w-c_w:w, :]) 366 | crop_labels.append(label[0:c_h, w-c_w:w, :]) 367 | crop_imgs.append(img[h-c_h:h, 0:c_w, :]) 368 | crop_labels.append(label[h-c_h:h, 0:c_w, :]) 369 | 370 | print('Five crop finished. %d images created.' %len(crop_imgs)) 371 | return crop_imgs, crop_labels 372 | 373 | def data_padding(imgs, labels, scale=32): 374 | for idx, img in enumerate(imgs): 375 | label = labels[idx] 376 | shape_before = img.shape 377 | h, w = img.shape[:2] 378 | h_padding = h%scale 379 | w_padding = w%scale 380 | need_padding = h_padding>0 and w_padding>0 381 | if need_padding: 382 | h_padding = (scale-h_padding)/2 383 | h_padding1 = math.ceil(h_padding) 384 | h_padding2 = math.floor(h_padding) 385 | 386 | w_padding = (scale-w_padding)/2 387 | w_padding1 = math.ceil(w_padding) 388 | w_padding2 = math.floor(w_padding) 389 | img = np.pad(img, ((h_padding1, h_padding2), (w_padding1, w_padding2), (0,0)), 'symmetric') 390 | label = np.pad(label, ((h_padding1, h_padding2), (w_padding1, w_padding2), (0,0)), 'constant') 391 | shape_after = img.shape 392 | print('img padding: [%d, %d]->[%d, %d]'%(shape_before[0],shape_before[1],shape_after[0],shape_after[1])) 393 | imgs[idx] = img 394 | labels[idx] = label 395 | return imgs, labels 396 | 397 | def data_padding_fixsize(imgs, labels, size): 398 | for idx, img in enumerate(imgs): 399 | label = labels[idx] 400 | h, w = img.shape[:2] 401 | h_padding = size[0] 402 | w_padding = size[1] 403 | 404 | h_padding1 = math.ceil(h_padding) 405 | h_padding2 = math.floor(h_padding) 406 | 407 | w_padding1 = math.ceil(w_padding) 408 | w_padding2 = math.floor(w_padding) 409 | 410 | img = np.pad(img, ((h_padding1, h_padding2), (w_padding1, w_padding2), (0,0)), 'symmetric') 411 | label = np.pad(label, ((h_padding1, h_padding2), (w_padding1, w_padding2)), 'constant') 412 | imgs[idx] = img 413 | labels[idx] = label 414 | return imgs, labels 415 | 416 | def five_crop_mix(ims, labels, x_s, size, scale=8): 417 | crop_imgs = [] 418 | crop_labels = [] 419 | crop_xs = [] 420 
| for img, label, x_s in zip(ims, labels, x_s): 421 | h = img.shape[0] 422 | w = img.shape[1] 423 | h_s = int(h/scale) 424 | w_s = int(w/scale) 425 | c_h = size[0] 426 | c_w = size[1] 427 | c_h_s = int(c_h/scale) 428 | c_w_s = int(c_w/scale) 429 | if h < c_h or w < c_w: 430 | print("Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 431 | continue 432 | s_h_s = int(h_s/2 - c_h_s/2) 433 | e_h_s = s_h_s + c_h_s 434 | s_w_s = int(w_s/2 - c_w_s/2) 435 | e_w_s = s_w_s + c_w_s 436 | s_h = s_h_s*scale 437 | s_w = s_w_s*scale 438 | e_h = s_h+c_h 439 | e_w = s_w+c_w 440 | 441 | crop_xs.append(x_s[:, s_h_s:e_h_s, s_w_s:e_w_s]) 442 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 443 | crop_labels.append(label[s_h:e_h, s_w:e_w, :]) 444 | 445 | crop_xs.append(x_s[:, :c_h_s, :c_w_s]) 446 | crop_imgs.append(img[:c_h, :c_w, :]) 447 | crop_labels.append(label[:c_h, :c_w, :]) 448 | 449 | crop_xs.append(x_s[:, -c_h_s:, -c_w_s:]) 450 | crop_imgs.append(img[-c_h:, -c_w:, :]) 451 | crop_labels.append(label[-c_h:, -c_w:, :]) 452 | 453 | crop_xs.append(x_s[:, :c_h_s, -c_w_s:]) 454 | crop_imgs.append(img[:c_h, -c_w:, :]) 455 | crop_labels.append(label[:c_h, -c_w:, :]) 456 | 457 | crop_xs.append(x_s[:, -c_h_s:, :c_w_s]) 458 | crop_imgs.append(img[-c_h:, :c_w, :]) 459 | crop_labels.append(label[-c_h:, :c_w, :]) 460 | 461 | print('Five crop finished. %d images created.' %len(crop_imgs)) 462 | return crop_imgs, crop_labels, crop_xs 463 | 464 | def sliding_crop(img, size): 465 | # print(img.shape) 466 | h = img.shape[0] 467 | w = img.shape[1] 468 | c_h = size[0] 469 | c_w = size[1] 470 | if h < c_h or w < c_w: 471 | print("Cannot crop area {} from image with size ({}, {})" 472 | .format(str(size), h, w)) 473 | else: 474 | h_rate = h/c_h 475 | w_rate = w/c_w 476 | h_times = math.ceil(h_rate) 477 | w_times = math.ceil(w_rate) 478 | stride_h = math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 479 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 480 | crop_imgs = [] 481 | for j in range(h_times): 482 | for i in range(w_times): 483 | s_h = int(j*c_h - j*stride_h) 484 | if(j==(h_times-1)): s_h = h - c_h 485 | e_h = s_h + c_h 486 | s_w = int(i*c_w - i*stride_w) 487 | if(i==(w_times-1)): s_w = w - c_w 488 | e_w = s_w + c_w 489 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 490 | crop_im = img[s_h:e_h, s_w:e_w, :] 491 | crop_imgs.append(crop_im) 492 | 493 | # crop_imgs_f = [] 494 | # for im in crop_imgs: 495 | # crop_imgs_f.append(cv2.flip(im, -1)) 496 | 497 | # crops = np.concatenate((np.array(crop_imgs)), axis=0) 498 | # print(crops.shape) 499 | return crop_imgs 500 | 501 | def random_crop(img, label, size): 502 | # print(img.shape) 503 | h = img.shape[0] 504 | w = img.shape[1] 505 | c_h = size[0] 506 | c_w = size[1] 507 | if h < c_h or w < c_w: 508 | print("Cannot crop area {} from image with size ({}, {})" 509 | .format(str(size), h, w)) 510 | else: 511 | s_h = random.randint(0, h-c_h) 512 | e_h = s_h + c_h 513 | s_w = random.randint(0, w-c_w) 514 | e_w = s_w + c_w 515 | 516 | crop_im = img[s_h:e_h, s_w:e_w, :] 517 | crop_label = label[s_h:e_h, s_w:e_w] 518 | # print('%d %d %d %d'%(s_h, e_h, s_w, e_w)) 519 | return crop_im, crop_label 520 | 521 | def random_crop_2s(img_s, label_s, img, label, crop_size_global, crop_size_local, scale): 522 | # print(img.shape) 523 | h_s, w_s = img_s.shape[:2] 524 | h, w = img.shape[:2] 525 | padding_size = int((crop_size_global-crop_size_local)/scale) 526 | crop_size_s = int(crop_size_global/scale) 527 | 528 | if h_s < crop_size_s or w_s < crop_size_s or h < 
crop_size_local or w < crop_size_local: 529 | print('Crop failed. Size error.') 530 | else: 531 | h_seed = random.randint(0, h_s-crop_size_s) 532 | w_seed = random.randint(0, w_s-crop_size_s) 533 | 534 | start_h_s = h_seed 535 | end_h_s = start_h_s+crop_size_s 536 | start_w_s = w_seed 537 | end_w_s = start_w_s+crop_size_s 538 | crop_im_s = img_s[start_h_s:end_h_s, start_w_s:end_w_s, :] 539 | crop_label_s = label_s[start_h_s:end_h_s, start_w_s:end_w_s] 540 | #print('start_h_s%d, end_h_s%d, start_w_s%d, end_w_s%d'%(start_h_s,end_h_s,start_w_s,end_w_s)) 541 | 542 | start_h = h_seed*scale 543 | end_h = start_h+crop_size_local 544 | start_w = w_seed*scale 545 | end_w = start_w+crop_size_local 546 | #print('start_h%d, end_h%d, start_w%d, end_w%d'%(start_h,end_h,start_w,end_w)) 547 | crop_im = img[start_h:end_h, start_w:end_w, :] 548 | crop_label = label[start_h:end_h, start_w:end_w] 549 | 550 | return crop_im_s, crop_label_s, crop_im, crop_label 551 | 552 | def random_crop_mix(img, label, x_s, size, scale=8): 553 | # print(img.shape) 554 | h = img.shape[0] 555 | w = img.shape[1] 556 | c_h = size[0] 557 | c_w = size[1] 558 | c_h_s = int(c_h/scale) 559 | c_w_s = int(c_w/scale) 560 | h_times = int(h/scale - c_h_s) 561 | w_times = int(w/scale - c_w_s) 562 | if h < c_h or w < c_w: 563 | print("Cannot crop area {} from image with size ({}, {})" 564 | .format(str(size), h, w)) 565 | else: 566 | s_h_s = random.randint(0, h_times) 567 | s_h = s_h_s * scale 568 | s_w_s = random.randint(0, w_times) 569 | s_w = s_w_s * scale 570 | e_h_s = s_h_s + c_h_s 571 | e_w_s = s_w_s + c_w_s 572 | e_h = s_h + c_h 573 | e_w = s_w + c_w 574 | 575 | crop_im = img[s_h:e_h, s_w:e_w, :] 576 | crop_label = label[s_h:e_h, s_w:e_w] 577 | crop_xs = x_s[:, s_h_s:e_h_s, s_w_s:e_w_s] 578 | # print('%d %d %d %d' % (s_h, e_h, s_w, e_w)) 579 | # print('%d %d %d %d' % (s_h_s, e_h_s, s_w_s, e_w_s)) 580 | return crop_im, crop_label, crop_xs 581 | 582 | def create_crops_mix(ims, labels, x_s, size, scale=1/8): 583 | crop_imgs = [] 584 | crop_labels = [] 585 | crop_x_s = [] 586 | for img, label, x in zip(ims, labels, x_s): 587 | h = img.shape[0] 588 | w = img.shape[1] 589 | c_h = size[0] 590 | c_w = size[1] 591 | c_h_s = int(c_h*scale) 592 | c_w_s = int(c_w*scale) 593 | if h < c_h or w < c_w: 594 | print("Cannot crop area {} from image with size ({}, {})".format(str(size), h, w)) 595 | continue 596 | h_rate = h/c_h 597 | w_rate = w/c_w 598 | h_times = math.ceil(h_rate) 599 | w_times = math.ceil(w_rate) 600 | stride_h = math.ceil(c_h*(h_times-h_rate)/(h_times-1)) 601 | stride_w = math.ceil(c_w*(w_times-w_rate)/(w_times-1)) 602 | for j in range(h_times): 603 | for i in range(w_times): 604 | s_h = int(j*c_h - j*stride_h) 605 | s_h_s = int(s_h*scale) 606 | if(j==(h_times-1)): s_h = h - c_h 607 | e_h = s_h + c_h 608 | e_h_s = s_h_s + c_h_s 609 | s_w = int(i*c_w - i*stride_w) 610 | s_w_s = int(s_w*scale) 611 | if(i==(w_times-1)): s_w = w - c_w 612 | e_w = s_w + c_w 613 | e_w_s = s_w_s + c_w_s 614 | crop_imgs.append(img[s_h:e_h, s_w:e_w, :]) 615 | crop_labels.append(label[s_h:e_h, s_w:e_w, :]) 616 | crop_x_s.append(x[:, s_h_s:e_h_s, s_w_s:e_w_s]) 617 | 618 | print('Sliding crop finished. %d images created.' 
%len(crop_imgs)) 619 | return crop_imgs, crop_labels, crop_x_s 620 | 621 | def crop_around_center(image, width, height): 622 | """ 623 | Given a NumPy / OpenCV 2 image, crops it to the given width and height, 624 | around it's centre point 625 | """ 626 | 627 | image_size = (image.shape[1], image.shape[0]) 628 | image_center = (int(image_size[0] * 0.5), int(image_size[1] * 0.5)) 629 | 630 | if(width > image_size[0]): 631 | width = image_size[0] 632 | 633 | if(height > image_size[1]): 634 | height = image_size[1] 635 | 636 | x1 = int(image_center[0] - width * 0.5) 637 | x2 = int(image_center[0] + width * 0.5) 638 | y1 = int(image_center[1] - height * 0.5) 639 | y2 = int(image_center[1] + height * 0.5) 640 | 641 | return image[y1:y2, x1:x2] 642 | 643 | def largest_rotated_rect(w, h, angle): 644 | """ 645 | Given a rectangle of size wxh that has been rotated by 'angle' (in 646 | radians), computes the width and height of the largest possible 647 | axis-aligned rectangle within the rotated rectangle. 648 | 649 | Original JS code by 'Andri' and Magnus Hoff from Stack Overflow 650 | 651 | Converted to Python by Aaron Snoswell 652 | """ 653 | 654 | quadrant = int(math.floor(angle / (math.pi / 2))) & 3 655 | sign_alpha = angle if ((quadrant & 1) == 0) else math.pi - angle 656 | alpha = (sign_alpha % math.pi + math.pi) % math.pi 657 | 658 | bb_w = w * math.cos(alpha) + h * math.sin(alpha) 659 | bb_h = w * math.sin(alpha) + h * math.cos(alpha) 660 | 661 | gamma = math.atan2(bb_w, bb_w) if (w < h) else math.atan2(bb_w, bb_w) 662 | 663 | delta = math.pi - alpha - gamma 664 | 665 | length = h if (w < h) else w 666 | 667 | d = length * math.cos(alpha) 668 | a = d * math.sin(alpha) / math.sin(delta) 669 | 670 | y = a * math.cos(gamma) 671 | x = y * math.tan(gamma) 672 | 673 | return ( 674 | bb_w - 2 * x, 675 | bb_h - 2 * y 676 | ) 677 | 678 | def Rotate_Aug(imgs, labels, step=20, start_angle=20, max_angle=179): 679 | for idx in range(len(imgs)): 680 | im = imgs[idx] 681 | l = labels[idx] 682 | image_height, image_width = im.shape[0:2] 683 | for i in range(start_angle, max_angle, step): 684 | im_rotated = rotate_image(im, i, order=3) 685 | l_rotated = rotate_image(l, i, order=0) 686 | crop_w, crop_h = largest_rotated_rect(image_width, image_height, math.radians(i)) 687 | im_rotated_cropped = crop_around_center(im_rotated, crop_w, crop_h) 688 | l_rotated_cropped = crop_around_center(l_rotated, crop_w, crop_h) 689 | imgs.append(im_rotated_cropped) 690 | labels.append(l_rotated_cropped) 691 | print('Img %d rotated.'%idx) 692 | print('Rotation finished. %d images in total.'%len(imgs)) 693 | return imgs, labels 694 | 695 | def Rotate_Aug_S(im, l, step=20, start_angle=15, max_angle=89): 696 | imgs = [] 697 | labels = [] 698 | image_height, image_width = im.shape[0:2] 699 | for i in range(start_angle, max_angle, step): 700 | im_rotated = rotate_image(im, i, order=1) 701 | l_rotated = rotate_image(l, i, order=0) 702 | crop_w, crop_h = largest_rotated_rect(image_width, image_height, math.radians(i)) 703 | im_rotated_cropped = crop_around_center(im_rotated, crop_w, crop_h) 704 | l_rotated_cropped = crop_around_center(l_rotated, crop_w, crop_h) 705 | imgs.append(im_rotated_cropped) 706 | labels.append(l_rotated_cropped) 707 | print('Rotation finished. 
%d images added.'%len(imgs)) 708 | return imgs, labels 709 | 710 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from scipy import stats 5 | 6 | def read_idtxt(path): 7 | id_list = [] 8 | #print('start reading') 9 | f = open(path, 'r') 10 | curr_str = '' 11 | while True: 12 | ch = f.read(1) 13 | if is_number(ch): 14 | curr_str+=ch 15 | else: 16 | id_list.append(curr_str) 17 | #print(curr_str) 18 | curr_str = '' 19 | if not ch: 20 | #print('end reading') 21 | break 22 | f.close() 23 | return id_list 24 | 25 | def get_square(img, pos): 26 | """Extract a left or a right square from ndarray shape : (H, W, C))""" 27 | h = img.shape[0] 28 | if pos == 0: 29 | return img[:, :h] 30 | else: 31 | return img[:, -h:] 32 | 33 | def split_img_into_squares(img): 34 | return get_square(img, 0), get_square(img, 1) 35 | 36 | def hwc_to_chw(img): 37 | return np.transpose(img, axes=[2, 0, 1]) 38 | 39 | def resize_and_crop(pilimg, scale=0.5, final_height=None): 40 | w = pilimg.size[0] 41 | h = pilimg.size[1] 42 | newW = int(w * scale) 43 | newH = int(h * scale) 44 | 45 | if not final_height: 46 | diff = 0 47 | else: 48 | diff = newH - final_height 49 | 50 | img = pilimg.resize((newW, newH)) 51 | img = img.crop((0, diff // 2, newW, newH - diff // 2)) 52 | return np.array(img, dtype=np.float32) 53 | 54 | def batch(iterable, batch_size): 55 | """Yields lists by batch""" 56 | b = [] 57 | for i, t in enumerate(iterable): 58 | b.append(t) 59 | if (i + 1) % batch_size == 0: 60 | yield b 61 | b = [] 62 | 63 | if len(b) > 0: 64 | yield b 65 | 66 | def seprate_batch(dataset, batch_size): 67 | """Yields lists by batch""" 68 | num_batch = len(dataset)//batch_size+1 69 | batch_len = batch_size 70 | # print (len(data)) 71 | # print (num_batch) 72 | batches = [] 73 | for i in range(num_batch): 74 | batches.append([dataset[j] for j in range(batch_len)]) 75 | # print('current data index: %d' %(i*batch_size+batch_len)) 76 | if (i+2==num_batch): batch_len = len(dataset)-(num_batch-1)*batch_size 77 | return(batches) 78 | 79 | def split_train_val(dataset, val_percent=0.05): 80 | dataset = list(dataset) 81 | length = len(dataset) 82 | n = int(length * val_percent) 83 | random.shuffle(dataset) 84 | return {'train': dataset[:-n], 'val': dataset[-n:]} 85 | 86 | 87 | def normalize(x): 88 | return x / 255 89 | 90 | def merge_masks(img1, img2, full_w): 91 | h = img1.shape[0] 92 | 93 | new = np.zeros((h, full_w), np.float32) 94 | new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1] 95 | new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):] 96 | 97 | return new 98 | 99 | 100 | # credits to https://stackoverflow.com/users/6076729/manuel-lagunas 101 | def rle_encode(mask_image): 102 | pixels = mask_image.flatten() 103 | # We avoid issues with '1' at the start or end (at the corners of 104 | # the original image) by setting those pixels to '0' explicitly. 105 | # We do not expect these to be non-zero for an accurate mask, 106 | # so this should not harm the score. 
107 | pixels[0] = 0 108 | pixels[-1] = 0 109 | runs = np.where(pixels[1:] != pixels[:-1])[0] + 2 110 | runs[1::2] = runs[1::2] - runs[:-1:2] 111 | return runs 112 | 113 | 114 | class AverageMeter(object): 115 | """Computes and stores the average and current value""" 116 | def __init__(self): 117 | self.initialized = False 118 | self.val = None 119 | self.avg = None 120 | self.sum = None 121 | self.count = None 122 | 123 | def initialize(self, val, count, weight): 124 | self.val = val 125 | self.avg = val 126 | self.count = count 127 | self.sum = val * weight 128 | self.initialized = True 129 | 130 | def update(self, val, count=1, weight=1): 131 | if not self.initialized: 132 | self.initialize(val, count, weight) 133 | else: 134 | self.add(val, count, weight) 135 | 136 | def add(self, val, count, weight): 137 | self.val = val 138 | self.count += count 139 | self.sum += val * weight 140 | self.avg = self.sum / self.count 141 | 142 | def value(self): 143 | return self.val 144 | 145 | def average(self): 146 | return self.avg 147 | 148 | def ImageValStretch2D(img): 149 | img = img*255 150 | #maxval = img.max(axis=0).max(axis=0) 151 | #minval = img.min(axis=0).min(axis=0) 152 | #img = (img-minval)*255/(maxval-minval) 153 | return img.astype(int) 154 | 155 | def ConfMap(output, pred): 156 | # print(output.shape) 157 | n, h, w = output.shape 158 | conf = np.zeros(pred.shape, float) 159 | for h_idx in range(h): 160 | for w_idx in range(w): 161 | n_idx = int(pred[h_idx, w_idx]) 162 | sum = 0 163 | for i in range(n): 164 | val=output[i, h_idx, w_idx] 165 | if val>0: sum+=val 166 | conf[h_idx, w_idx] = output[n_idx, h_idx, w_idx]/sum 167 | if conf[h_idx, w_idx]<0: conf[h_idx, w_idx]=0 168 | # print(conf) 169 | return conf 170 | 171 | def accuracy(pred, label): 172 | valid = (label > 0) 173 | acc_sum = (valid * (pred == label)).sum() 174 | valid_sum = valid.sum() 175 | acc = float(acc_sum) / (valid_sum + 1e-10) 176 | return acc, valid_sum 177 | 178 | def binary_accuracy(pred, label): 179 | valid = (label < 2) 180 | acc_sum = (valid * (pred == label)).sum() 181 | valid_sum = valid.sum() 182 | acc = float(acc_sum) / (valid_sum + 1e-10) 183 | return acc, valid_sum 184 | 185 | def intersectionAndUnion(imPred, imLab, numClass): 186 | imPred = np.asarray(imPred).copy() 187 | imLab = np.asarray(imLab).copy() 188 | 189 | # imPred += 1 190 | # imLab += 1 191 | # Remove classes from unlabeled pixels in gt image. 192 | # We should not penalize detections in unlabeled portions of the image. 193 | imPred = imPred * (imLab > 0) 194 | 195 | # Compute area intersection: 196 | intersection = imPred * (imPred == imLab) 197 | (area_intersection, _) = np.histogram( 198 | intersection, bins=numClass, range=(1, numClass+1)) 199 | # print(area_intersection) 200 | 201 | # Compute area union: 202 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass+1)) 203 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass+1)) 204 | area_union = area_pred + area_lab - area_intersection 205 | # print(area_pred) 206 | # print(area_lab) 207 | return (area_intersection, area_union) 208 | 209 | def CaclTP(imPred, imLab, numClass): 210 | imPred = np.asarray(imPred).copy() 211 | imLab = np.asarray(imLab).copy() 212 | 213 | # imPred += 1 214 | # imLab += 1 215 | # # Remove classes from unlabeled pixels in gt image. 216 | # # We should not penalize detections in unlabeled portions of the image. 
217 | imPred = imPred * (imLab > 0) 218 | 219 | # Compute area intersection: 220 | TP = imPred * (imPred == imLab) 221 | (TP_hist, _) = np.histogram( 222 | TP, bins=numClass, range=(1, numClass+1)) 223 | # print(TP.shape) 224 | # print(TP_hist) 225 | 226 | # Compute area union: 227 | (pred_hist, _) = np.histogram(imPred, bins=numClass, range=(1, numClass+1)) 228 | (lab_hist, _) = np.histogram(imLab, bins=numClass, range=(1, numClass+1)) 229 | 230 | union_hist = pred_hist + lab_hist - TP_hist 231 | # print(pred_hist) 232 | # print(lab_hist) 233 | # precision = TP_hist / (lab_hist + 1e-10) + 1e-10 234 | # recall = TP_hist / (pred_hist + 1e-10) + 1e-10 235 | # # print(precision) 236 | # # print(recall) 237 | # F1 = [stats.hmean([pre, rec]) for pre, rec in zip(precision, recall)] 238 | # print(F1) 239 | 240 | # print(area_pred) 241 | # print(area_lab) 242 | 243 | return (TP_hist, pred_hist, lab_hist, union_hist) --------------------------------------------------------------------------------
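
A minimal, hypothetical usage sketch for utils/loss.py: it shows how make_one_hot converts an integer label map into the one-hot target that DiceLoss expects. The batch size, spatial size and num_classes below are illustrative assumptions, not values taken from the repository.

import torch
from utils.loss import DiceLoss, make_one_hot

num_classes = 6                                          # assumption: foreground classes + background
logits = torch.randn(2, num_classes, 64, 64)             # raw network output, shape [N, C, H, W]
labels = torch.randint(0, num_classes, (2, 1, 64, 64))   # integer class labels, shape [N, 1, H, W]

target = make_one_hot(labels, num_classes)               # one-hot tensor, shape [N, C, H, W]
criterion = DiceLoss()                                   # softmax is applied inside forward()
loss = criterion(logits, target)
print(loss.item())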
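
A second hypothetical sketch, this time for utils/transform.py: random_crop plus rand_flip as an on-the-fly augmentation, and Rotate_Aug_S for offline rotation augmentation. The image size, number of classes and random inputs are assumptions chosen only for illustration.

import numpy as np
from utils.transform import random_crop, rand_flip, Rotate_Aug_S

img = np.random.rand(512, 512, 3).astype(np.float32)            # H x W x C image
label = np.random.randint(0, 6, (512, 512)).astype(np.uint8)    # H x W index map

# On-the-fly: crop a 256x256 patch, then flip it at random.
crop_im, crop_label = random_crop(img, label, (256, 256))
crop_im, crop_label = rand_flip(crop_im, crop_label)

# Offline: create rotated-and-cropped copies (15, 35, 55, 75 degrees with the default start angle).
rot_imgs, rot_labels = Rotate_Aug_S(img, label, step=20)
print(len(rot_imgs), crop_im.shape, crop_label.shape)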
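
A final hypothetical sketch for utils/utils.py, following the commented-out hints inside CaclTP: per-image histograms are accumulated with AverageMeter and then turned into precision, recall, F1 and IoU. The class count, tile size and random predictions are illustrative assumptions; label value 0 is treated as unlabeled, as in accuracy() and CaclTP().

import numpy as np
from scipy import stats
from utils.utils import CaclTP, AverageMeter, accuracy

num_classes = 6
TP_meter, pred_meter, label_meter, union_meter = (AverageMeter() for _ in range(4))

for _ in range(3):                                               # pretend three tiles were evaluated
    pred = np.random.randint(1, num_classes + 1, (256, 256))     # class indices start at 1
    label = np.random.randint(0, num_classes + 1, (256, 256))    # 0 = unlabeled
    acc, _ = accuracy(pred, label)
    TP, pred_hist, label_hist, union_hist = CaclTP(pred, label, num_classes)
    TP_meter.update(TP)
    pred_meter.update(pred_hist)
    label_meter.update(label_hist)
    union_meter.update(union_hist)

precision = TP_meter.sum / (label_meter.sum + 1e-10) + 1e-10
recall = TP_meter.sum / (pred_meter.sum + 1e-10) + 1e-10
F1 = np.array([stats.hmean([p, r]) for p, r in zip(precision, recall)])
IoU = TP_meter.sum / (union_meter.sum + 1e-10)
print('mean F1 %.2f, mIoU %.2f' % (F1.mean() * 100, IoU.mean() * 100))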