├── README.md ├── config └── uavid │ └── ssfnet.py ├── geoseg ├── __init__.py ├── datasets │ ├── __init__.py │ ├── loveda_dataset.py │ ├── potsdam_dataset.py │ ├── transform.py │ ├── uavid_dataset.py │ └── vaihingen_dataset.py ├── losses │ ├── __init__.py │ ├── balanced_bce.py │ ├── bitempered_loss.py │ ├── cel1.py │ ├── dice.py │ ├── focal.py │ ├── focal_cosine.py │ ├── functional.py │ ├── jaccard.py │ ├── joint_loss.py │ ├── lovasz.py │ ├── soft_bce.py │ ├── soft_ce.py │ ├── soft_f1.py │ ├── useful_loss.py │ └── wing_loss.py └── models │ ├── SSFNet.py │ └── __init__.py ├── inference_uavid.py ├── modules ├── BFM.py ├── FAC.py └── FRFB.py ├── requirements.txt ├── tools ├── __init__.py ├── cfg.py ├── img.png ├── loveda_mask_convert.py ├── metric.py ├── uavid_patch_split.py └── vaihingen_patch_split.py └── train_supervision.py /README.md: -------------------------------------------------------------------------------- 1 | # SSFNet 2 | SSF 3 | 4 | 5 | SSFNet: A network for real-time processing of unmanned aerial vehicle (UAV) remote sensing images, aiming to balance the trade-off between accuracy and speed in semantic segmentation. 6 | 7 | 8 | **2024.4.12** 9 | 10 | SSFNet was updated. 11 | 12 | **2024.3.27** 13 | 14 | FRFB, BFM, and FAC were updated. 15 | 16 | SSFNet: Lightweight real-time network for semantic segmentation of UAV remote sensing images 17 | 18 | Our results have been submitted to the official dataset under the username wxy07496, and the specific scores are as follows. Due to the progress of other work, the official website at https://codalab.lisn.upsaclay.fr/competitions/7302#results now displays another project of ours. If interested, please check it out at https://github.com/wxy16/MLFMNet. 19 | 20 | 21 | ![图片1](https://github.com/wxy16/SSFNet/assets/128227957/c2fa5f37-51a1-4191-a8bd-f34d9f9b4b3a) 22 | 23 | 24 | 25 | **UAVid** 26 | ``` 27 | python tools/uavid_patch_split.py \ 28 | --input-dir "data/uavid/uavid_train_val" \ 29 | --output-img-dir "data/uavid/train_val/images" \ 30 | --output-mask-dir "data/uavid/train_val/masks" \ 31 | --mode 'train' --split-size-h 1024 --split-size-w 1024 \ 32 | --stride-h 1024 --stride-w 1024 33 | ``` 34 | 35 | ``` 36 | python tools/uavid_patch_split.py \ 37 | --input-dir "data/uavid/uavid_train" \ 38 | --output-img-dir "data/uavid/train/images" \ 39 | --output-mask-dir "data/uavid/train/masks" \ 40 | --mode 'train' --split-size-h 1024 --split-size-w 1024 \ 41 | --stride-h 1024 --stride-w 1024 42 | ``` 43 | 44 | ``` 45 | python tools/uavid_patch_split.py \ 46 | --input-dir "data/uavid/uavid_val" \ 47 | --output-img-dir "data/uavid/val/images" \ 48 | --output-mask-dir "data/uavid/val/masks" \ 49 | --mode 'val' --split-size-h 1024 --split-size-w 1024 \ 50 | --stride-h 1024 --stride-w 1024 51 | ``` 52 | 53 | ## Training 54 | 55 | "-c" means the path of the config, use different **config** to train different models. 56 | 57 | ``` 58 | python train_supervision.py -c config/uavid/ssfnet.py 59 | ``` 60 | 61 | ## Testing 62 | 63 | "-c" denotes the path of the config, Use different **config** to test different models. 64 | 65 | "-o" denotes the output path 66 | 67 | "-t" denotes the test time augmentation (TTA), can be [None, 'lr', 'd4'], default is None, 'lr' is flip TTA, 'd4' is multiscale TTA 68 | 69 | "--rgb" denotes whether to output masks in RGB format 70 | 71 | **UAVid** ([Online Testing](https://codalab.lisn.upsaclay.fr/competitions/7302)) 72 | ``` 73 | python inference_uavid.py \ 74 | -i 'data/uavid/uavid_test' \ 75 | -c config/uavid/SSFNet.py \ 76 | -o results/uavid/ \ 77 | -t 'lr' -ph 1152 -pw 1024 -b 2 -d "uavid" 78 | 79 | 80 | **Training Code Reference ** 81 | 82 | ([GeoSeg](https://github.com/WangLibo1995/GeoSeg)) 83 | ([mmsegmentation](https://github.com/open-mmlab/mmsegmentation)) 84 | 85 | -------------------------------------------------------------------------------- /config/uavid/ssfnet.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from geoseg.losses import * 3 | from geoseg.datasets.uavid_dataset import * 4 | from geoseg.models.SSFNet import SSFNet 5 | from catalyst.contrib.nn import Lookahead 6 | from catalyst import utils 7 | 8 | max_epoch = 60 9 | ignore_index = 255 10 | train_batch_size = 16 11 | val_batch_size = 4 12 | lr = 5e-4 13 | weight_decay = 0.01 14 | backbone_lr = 1e-4 15 | backbone_weight_decay = 0.01 16 | accumulate_n = 1 17 | num_classes = len(CLASSES) 18 | classes = CLASSES 19 | 20 | weights_name = "ssfnet-r18-1024" 21 | weights_path = "model_weights/uavid/{}".format(weights_name) 22 | test_weights_name = "last" 23 | log_name = 'uavid/{}'.format(weights_name) 24 | monitor = 'val_mIoU' 25 | monitor_mode = 'max' 26 | save_top_k = 3 27 | save_last = True 28 | check_val_every_n_epoch = 4 29 | gpus = [0] 30 | strategy = None 31 | pretrained_ckpt_path = None 32 | resume_ckpt_path = None 33 | # define the network 34 | net = SSFNet(num_classes=num_classes) 35 | # define the loss 36 | loss = JointLoss(SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index), 37 | DiceLoss(smooth=0.05, ignore_index=ignore_index), 1.0, 1.0) 38 | use_aux_loss = False 39 | 40 | train_dataset = UAVIDDataset(data_root='data/uavid/Train1', img_dir='images', mask_dir='masks', 41 | mode='train', mosaic_ratio=0.25, transform=train_aug, img_size=(1024, 1024)) 42 | 43 | val_dataset = UAVIDDataset(data_root='data/uavid/train_val', img_dir='images', mask_dir='masks', mode='val', 44 | mosaic_ratio=0.0, transform=val_aug, img_size=(1024, 1024)) 45 | 46 | 47 | train_loader = DataLoader(dataset=train_dataset, 48 | batch_size=train_batch_size, 49 | num_workers=4, 50 | pin_memory=True, 51 | shuffle=True, 52 | drop_last=True) 53 | 54 | val_loader = DataLoader(dataset=val_dataset, 55 | batch_size=val_batch_size, 56 | num_workers=4, 57 | shuffle=False, 58 | pin_memory=True, 59 | drop_last=False) 60 | 61 | # define the optimizer 62 | layerwise_params = {"backbone.*": dict(lr=backbone_lr, weight_decay=backbone_weight_decay)} 63 | net_params = utils.process_model_params(net, layerwise_params=layerwise_params) 64 | base_optimizer = torch.optim.AdamW(net_params, lr=lr, weight_decay=weight_decay) 65 | optimizer = Lookahead(base_optimizer) 66 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epoch) 67 | 68 | -------------------------------------------------------------------------------- /geoseg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wxy16/SSFNet/b424b22bbab4a142f150a417be19880102462618/geoseg/__init__.py -------------------------------------------------------------------------------- /geoseg/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wxy16/SSFNet/b424b22bbab4a142f150a417be19880102462618/geoseg/datasets/__init__.py -------------------------------------------------------------------------------- /geoseg/datasets/loveda_dataset.py: -------------------------------------------------------------------------------- 1 | from .transform import * 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | import albumentations as albu 10 | import matplotlib.patches as mpatches 11 | from PIL import Image, ImageOps 12 | import random 13 | 14 | 15 | CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest', 16 | 'agricultural') 17 | 18 | PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], 19 | [159, 129, 183], [0, 255, 0], [255, 195, 128]] 20 | 21 | 22 | ORIGIN_IMG_SIZE = (1024, 1024) 23 | INPUT_IMG_SIZE = (1024, 1024) 24 | TEST_IMG_SIZE = (1024, 1024) 25 | 26 | 27 | def get_training_transform(): 28 | train_transform = [ 29 | # albu.Resize(height=1024, width=1024), 30 | albu.HorizontalFlip(p=0.5), 31 | albu.VerticalFlip(p=0.5), 32 | albu.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.25), 33 | # albu.RandomRotate90(p=0.5), 34 | # albu.OneOf([ 35 | # albu.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25), 36 | # albu.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=35, val_shift_limit=25) 37 | # ], p=0.25), 38 | albu.Normalize() 39 | ] 40 | return albu.Compose(train_transform) 41 | 42 | 43 | def train_aug(img, mask): 44 | # multi-scale training and crop 45 | crop_aug = Compose([RandomScale(scale_list=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], mode='value'), 46 | SmartCropV1(crop_size=512, max_ratio=0.75, ignore_index=255, nopad=False)]) 47 | #crop_aug = SmartCropV1(crop_size=512, max_ratio=0.75, ignore_index=255, nopad=False) 48 | img, mask = crop_aug(img, mask) 49 | 50 | img, mask = np.array(img), np.array(mask) 51 | aug = get_training_transform()(image=img.copy(), mask=mask.copy()) 52 | img, mask = aug['image'], aug['mask'] 53 | return img, mask 54 | 55 | 56 | def get_val_transform(): 57 | val_transform = [ 58 | # albu.Resize(height=1024, width=1024, interpolation=cv2.INTER_CUBIC), 59 | albu.Normalize() 60 | ] 61 | return albu.Compose(val_transform) 62 | 63 | 64 | def val_aug(img, mask): 65 | img, mask = np.array(img), np.array(mask) 66 | aug = get_val_transform()(image=img.copy(), mask=mask.copy()) 67 | img, mask = aug['image'], aug['mask'] 68 | return img, mask 69 | 70 | 71 | class LoveDATrainDataset(Dataset): 72 | def __init__(self, data_root='data/LoveDA/Train', img_dir='images_png', mosaic_ratio=0.25, 73 | mask_dir='masks_png_convert', img_suffix='.png', mask_suffix='.png', 74 | transform=train_aug, img_size=ORIGIN_IMG_SIZE): 75 | self.data_root = data_root 76 | self.img_dir = img_dir 77 | self.mask_dir = mask_dir 78 | self.mosaic_ratio = mosaic_ratio 79 | 80 | self.img_suffix = img_suffix 81 | self.mask_suffix = mask_suffix 82 | self.transform = transform 83 | self.img_size = img_size 84 | self.img_ids = self.get_img_ids(self.data_root, self.img_dir, self.mask_dir) 85 | 86 | def __getitem__(self, index): 87 | p_ratio = random.random() 88 | img, mask = self.load_img_and_mask(index) 89 | if p_ratio < self.mosaic_ratio: 90 | img, mask = self.load_mosaic_img_and_mask(index) 91 | if self.transform: 92 | img, mask = self.transform(img, mask) 93 | img = torch.from_numpy(img).permute(2, 0, 1).float() 94 | mask = torch.from_numpy(mask).long() 95 | img_id, img_type = self.img_ids[index] 96 | results = {'img': img, 'gt_semantic_seg': mask, 'img_id': img_id, 'img_type': img_type} 97 | 98 | return results 99 | 100 | def __len__(self): 101 | length = len(self.img_ids) 102 | return length 103 | 104 | def get_img_ids(self, data_root, img_dir, mask_dir): 105 | urban_img_filename_list = os.listdir(osp.join(data_root, 'Urban', img_dir)) 106 | urban_mask_filename_list = os.listdir(osp.join(data_root, 'Urban', mask_dir)) 107 | assert len(urban_img_filename_list) == len(urban_mask_filename_list) 108 | urban_img_ids = [(str(id.split('.')[0]), 'Urban') for id in urban_img_filename_list] 109 | 110 | rural_img_filename_list = os.listdir(osp.join(data_root, 'Rural', img_dir)) 111 | rural_mask_filename_list = os.listdir(osp.join(data_root, 'Rural', mask_dir)) 112 | assert len(rural_img_filename_list) == len(rural_mask_filename_list) 113 | rural_img_ids = [(str(id.split('.')[0]), 'Rural') for id in rural_img_filename_list] 114 | img_ids = urban_img_ids + rural_img_ids 115 | 116 | return img_ids 117 | 118 | def load_img_and_mask(self, index): 119 | img_id, img_type = self.img_ids[index] 120 | img_name = osp.join(self.data_root, img_type, self.img_dir, img_id + self.img_suffix) 121 | mask_name = osp.join(self.data_root, img_type, self.mask_dir, img_id + self.mask_suffix) 122 | img = Image.open(img_name).convert('RGB') 123 | mask = Image.open(mask_name).convert('L') 124 | 125 | return img, mask 126 | 127 | def load_mosaic_img_and_mask(self, index): 128 | indexes = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)] 129 | img_a, mask_a = self.load_img_and_mask(indexes[0]) 130 | img_b, mask_b = self.load_img_and_mask(indexes[1]) 131 | img_c, mask_c = self.load_img_and_mask(indexes[2]) 132 | img_d, mask_d = self.load_img_and_mask(indexes[3]) 133 | 134 | img_a, mask_a = np.array(img_a), np.array(mask_a) 135 | img_b, mask_b = np.array(img_b), np.array(mask_b) 136 | img_c, mask_c = np.array(img_c), np.array(mask_c) 137 | img_d, mask_d = np.array(img_d), np.array(mask_d) 138 | 139 | w = self.img_size[1] 140 | h = self.img_size[0] 141 | 142 | start_x = w // 4 143 | strat_y = h // 4 144 | # The coordinates of the splice center 145 | offset_x = random.randint(start_x, (w - start_x)) 146 | offset_y = random.randint(strat_y, (h - strat_y)) 147 | 148 | crop_size_a = (offset_x, offset_y) 149 | crop_size_b = (w - offset_x, offset_y) 150 | crop_size_c = (offset_x, h - offset_y) 151 | crop_size_d = (w - offset_x, h - offset_y) 152 | 153 | random_crop_a = albu.RandomCrop(width=crop_size_a[0], height=crop_size_a[1]) 154 | random_crop_b = albu.RandomCrop(width=crop_size_b[0], height=crop_size_b[1]) 155 | random_crop_c = albu.RandomCrop(width=crop_size_c[0], height=crop_size_c[1]) 156 | random_crop_d = albu.RandomCrop(width=crop_size_d[0], height=crop_size_d[1]) 157 | 158 | croped_a = random_crop_a(image=img_a.copy(), mask=mask_a.copy()) 159 | croped_b = random_crop_b(image=img_b.copy(), mask=mask_b.copy()) 160 | croped_c = random_crop_c(image=img_c.copy(), mask=mask_c.copy()) 161 | croped_d = random_crop_d(image=img_d.copy(), mask=mask_d.copy()) 162 | 163 | img_crop_a, mask_crop_a = croped_a['image'], croped_a['mask'] 164 | img_crop_b, mask_crop_b = croped_b['image'], croped_b['mask'] 165 | img_crop_c, mask_crop_c = croped_c['image'], croped_c['mask'] 166 | img_crop_d, mask_crop_d = croped_d['image'], croped_d['mask'] 167 | 168 | top = np.concatenate((img_crop_a, img_crop_b), axis=1) 169 | bottom = np.concatenate((img_crop_c, img_crop_d), axis=1) 170 | img = np.concatenate((top, bottom), axis=0) 171 | 172 | top_mask = np.concatenate((mask_crop_a, mask_crop_b), axis=1) 173 | bottom_mask = np.concatenate((mask_crop_c, mask_crop_d), axis=1) 174 | mask = np.concatenate((top_mask, bottom_mask), axis=0) 175 | mask = np.ascontiguousarray(mask) 176 | img = np.ascontiguousarray(img) 177 | 178 | img = Image.fromarray(img) 179 | mask = Image.fromarray(mask) 180 | 181 | return img, mask 182 | 183 | 184 | loveda_val_dataset = LoveDATrainDataset(data_root='data/LoveDA/Val', mosaic_ratio=0.0, 185 | transform=val_aug) 186 | 187 | 188 | class LoveDATestDataset(Dataset): 189 | def __init__(self, data_root='data/LoveDA/Test', img_dir='images_png', 190 | img_suffix='.png', mosaic_ratio=0.0, 191 | img_size=ORIGIN_IMG_SIZE): 192 | self.data_root = data_root 193 | self.img_dir = img_dir 194 | 195 | self.img_suffix = img_suffix 196 | self.mosaic_ratio = mosaic_ratio 197 | self.img_size = img_size 198 | self.img_ids = self.get_img_ids(self.data_root, self.img_dir) 199 | 200 | def __getitem__(self, index): 201 | img = self.load_img(index) 202 | 203 | img = np.array(img) 204 | aug = albu.Normalize()(image=img) 205 | img = aug['image'] 206 | 207 | img = torch.from_numpy(img).permute(2, 0, 1).float() 208 | img_id, img_type = self.img_ids[index] 209 | 210 | results = {'img': img, 'img_id': img_id, 'img_type': img_type} 211 | 212 | return results 213 | 214 | def __len__(self): 215 | length = len(self.img_ids) 216 | 217 | return length 218 | 219 | def get_img_ids(self, data_root, img_dir): 220 | urban_img_filename_list = os.listdir(osp.join(data_root, 'Urban', img_dir)) 221 | urban_img_ids = [(str(id.split('.')[0]), 'Urban') for id in urban_img_filename_list] 222 | rural_img_filename_list = os.listdir(osp.join(data_root, 'Rural', img_dir)) 223 | rural_img_ids = [(str(id.split('.')[0]), 'Rural') for id in rural_img_filename_list] 224 | img_ids = urban_img_ids + rural_img_ids 225 | 226 | return img_ids 227 | 228 | def load_img(self, index): 229 | img_id, img_type = self.img_ids[index] 230 | img_name = osp.join(self.data_root, img_type, self.img_dir, img_id + self.img_suffix) 231 | img = Image.open(img_name).convert('RGB') 232 | 233 | return img 234 | 235 | 236 | def show_img_mask_seg(seg_path, img_path, mask_path, start_seg_index): 237 | seg_list = os.listdir(seg_path) 238 | fig, ax = plt.subplots(2, 3, figsize=(18, 12)) 239 | seg_list = seg_list[start_seg_index:start_seg_index+2] 240 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 241 | for i in range(len(seg_list)): 242 | seg_id = seg_list[i] 243 | img_seg = cv2.imread(f'{seg_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 244 | img_seg = img_seg.astype(np.uint8) 245 | img_seg = Image.fromarray(img_seg).convert('P') 246 | img_seg.putpalette(np.array(PALETTE, dtype=np.uint8)) 247 | img_seg = np.array(img_seg.convert('RGB')) 248 | mask = cv2.imread(f'{mask_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 249 | mask = mask.astype(np.uint8) 250 | mask = Image.fromarray(mask).convert('P') 251 | mask.putpalette(np.array(PALETTE, dtype=np.uint8)) 252 | mask = np.array(mask.convert('RGB')) 253 | img_id = str(seg_id.split('.')[0])+'.tif' 254 | img = cv2.imread(f'{img_path}/{img_id}', cv2.IMREAD_COLOR) 255 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 256 | ax[i, 0].set_axis_off() 257 | ax[i, 0].imshow(img) 258 | ax[i, 0].set_title('RS IMAGE ' + img_id) 259 | ax[i, 1].set_axis_off() 260 | ax[i, 1].imshow(mask) 261 | ax[i, 1].set_title('Mask True ' + seg_id) 262 | ax[i, 2].set_axis_off() 263 | ax[i, 2].imshow(img_seg) 264 | ax[i, 2].set_title('Mask Predict ' + seg_id) 265 | ax[i, 2].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 266 | 267 | 268 | def show_seg(seg_path, img_path, start_seg_index): 269 | seg_list = os.listdir(seg_path) 270 | fig, ax = plt.subplots(2, 2, figsize=(12, 12)) 271 | seg_list = seg_list[start_seg_index:start_seg_index+2] 272 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 273 | for i in range(len(seg_list)): 274 | seg_id = seg_list[i] 275 | img_seg = cv2.imread(f'{seg_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 276 | img_seg = img_seg.astype(np.uint8) 277 | img_seg = Image.fromarray(img_seg).convert('P') 278 | img_seg.putpalette(np.array(PALETTE, dtype=np.uint8)) 279 | img_seg = np.array(img_seg.convert('RGB')) 280 | img_id = str(seg_id.split('.')[0])+'.tif' 281 | img = cv2.imread(f'{img_path}/{img_id}', cv2.IMREAD_COLOR) 282 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 283 | ax[i, 0].set_axis_off() 284 | ax[i, 0].imshow(img) 285 | ax[i, 0].set_title('RS IMAGE '+img_id) 286 | ax[i, 1].set_axis_off() 287 | ax[i, 1].imshow(img_seg) 288 | ax[i, 1].set_title('Seg IMAGE '+seg_id) 289 | ax[i, 1].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 290 | 291 | 292 | def show_mask(img, mask, img_id): 293 | fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12)) 294 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 295 | mask = mask.astype(np.uint8) 296 | mask = Image.fromarray(mask).convert('P') 297 | mask.putpalette(np.array(PALETTE, dtype=np.uint8)) 298 | mask = np.array(mask.convert('RGB')) 299 | ax1.imshow(img) 300 | ax1.set_title('RS IMAGE ' + str(img_id)+'.png') 301 | ax2.imshow(mask) 302 | ax2.set_title('Mask ' + str(img_id)+'.png') 303 | ax2.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 304 | -------------------------------------------------------------------------------- /geoseg/datasets/potsdam_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import albumentations as albu 9 | from .transform import * 10 | import matplotlib.patches as mpatches 11 | from PIL import Image 12 | import random 13 | 14 | 15 | CLASSES = ('ImSurf', 'Building', 'LowVeg', 'Tree', 'Car', 'Clutter') 16 | PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 204, 0], [255, 0, 0]] 17 | 18 | ORIGIN_IMG_SIZE = (1024, 1024) 19 | INPUT_IMG_SIZE = (1024, 1024) 20 | TEST_IMG_SIZE = (1024, 1024) 21 | 22 | def get_training_transform(): 23 | train_transform = [ 24 | # albu.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.15), 25 | # albu.RandomRotate90(p=0.25), 26 | albu.Normalize() 27 | ] 28 | return albu.Compose(train_transform) 29 | 30 | 31 | def train_aug(img, mask): 32 | crop_aug = Compose([RandomScale(scale_list=[0.75, 1.0, 1.25, 1.5], mode='value'), 33 | SmartCropV1(crop_size=768, max_ratio=0.75, ignore_index=len(CLASSES), nopad=False)]) 34 | img, mask = crop_aug(img, mask) 35 | img, mask = np.array(img), np.array(mask) 36 | aug = get_training_transform()(image=img.copy(), mask=mask.copy()) 37 | img, mask = aug['image'], aug['mask'] 38 | return img, mask 39 | 40 | 41 | def get_val_transform(): 42 | val_transform = [ 43 | albu.Normalize() 44 | ] 45 | return albu.Compose(val_transform) 46 | 47 | 48 | def val_aug(img, mask): 49 | img, mask = np.array(img), np.array(mask) 50 | aug = get_val_transform()(image=img.copy(), mask=mask.copy()) 51 | img, mask = aug['image'], aug['mask'] 52 | return img, mask 53 | 54 | 55 | class PotsdamDataset(Dataset): 56 | def __init__(self, data_root='data/potsdam/test', mode='val', img_dir='images_1024', mask_dir='masks_1024', 57 | img_suffix='.tif', mask_suffix='.png', transform=val_aug, mosaic_ratio=0.0, 58 | img_size=ORIGIN_IMG_SIZE): 59 | self.data_root = data_root 60 | self.img_dir = img_dir 61 | self.mask_dir = mask_dir 62 | self.img_suffix = img_suffix 63 | self.mask_suffix = mask_suffix 64 | self.transform = transform 65 | self.mode = mode 66 | self.mosaic_ratio = mosaic_ratio 67 | self.img_size = img_size 68 | self.img_ids = self.get_img_ids(self.data_root, self.img_dir, self.mask_dir) 69 | 70 | def __getitem__(self, index): 71 | p_ratio = random.random() 72 | if p_ratio > self.mosaic_ratio or self.mode == 'val' or self.mode == 'test': 73 | img, mask = self.load_img_and_mask(index) 74 | if self.transform: 75 | img, mask = self.transform(img, mask) 76 | else: 77 | img, mask = self.load_mosaic_img_and_mask(index) 78 | if self.transform: 79 | img, mask = self.transform(img, mask) 80 | 81 | img = torch.from_numpy(img).permute(2, 0, 1).float() 82 | mask = torch.from_numpy(mask).long() 83 | img_id = self.img_ids[index] 84 | results = dict(img_id=img_id, img=img, gt_semantic_seg=mask) 85 | return results 86 | 87 | def __len__(self): 88 | return len(self.img_ids) 89 | 90 | def get_img_ids(self, data_root, img_dir, mask_dir): 91 | img_filename_list = os.listdir(osp.join(data_root, img_dir)) 92 | mask_filename_list = os.listdir(osp.join(data_root, mask_dir)) 93 | assert len(img_filename_list) == len(mask_filename_list) 94 | img_ids = [str(id.split('.')[0]) for id in mask_filename_list] 95 | return img_ids 96 | 97 | def load_img_and_mask(self, index): 98 | img_id = self.img_ids[index] 99 | img_name = osp.join(self.data_root, self.img_dir, img_id + self.img_suffix) 100 | mask_name = osp.join(self.data_root, self.mask_dir, img_id + self.mask_suffix) 101 | img = Image.open(img_name).convert('RGB') 102 | mask = Image.open(mask_name).convert('L') 103 | return img, mask 104 | 105 | def load_mosaic_img_and_mask(self, index): 106 | indexes = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)] 107 | img_a, mask_a = self.load_img_and_mask(indexes[0]) 108 | img_b, mask_b = self.load_img_and_mask(indexes[1]) 109 | img_c, mask_c = self.load_img_and_mask(indexes[2]) 110 | img_d, mask_d = self.load_img_and_mask(indexes[3]) 111 | 112 | img_a, mask_a = np.array(img_a), np.array(mask_a) 113 | img_b, mask_b = np.array(img_b), np.array(mask_b) 114 | img_c, mask_c = np.array(img_c), np.array(mask_c) 115 | img_d, mask_d = np.array(img_d), np.array(mask_d) 116 | 117 | w = self.img_size[1] 118 | h = self.img_size[0] 119 | 120 | start_x = w // 4 121 | strat_y = h // 4 122 | # The coordinates of the splice center 123 | offset_x = random.randint(start_x, (w - start_x)) 124 | offset_y = random.randint(strat_y, (h - strat_y)) 125 | 126 | crop_size_a = (offset_x, offset_y) 127 | crop_size_b = (w - offset_x, offset_y) 128 | crop_size_c = (offset_x, h - offset_y) 129 | crop_size_d = (w - offset_x, h - offset_y) 130 | 131 | random_crop_a = albu.RandomCrop(width=crop_size_a[0], height=crop_size_a[1]) 132 | random_crop_b = albu.RandomCrop(width=crop_size_b[0], height=crop_size_b[1]) 133 | random_crop_c = albu.RandomCrop(width=crop_size_c[0], height=crop_size_c[1]) 134 | random_crop_d = albu.RandomCrop(width=crop_size_d[0], height=crop_size_d[1]) 135 | 136 | croped_a = random_crop_a(image=img_a.copy(), mask=mask_a.copy()) 137 | croped_b = random_crop_b(image=img_b.copy(), mask=mask_b.copy()) 138 | croped_c = random_crop_c(image=img_c.copy(), mask=mask_c.copy()) 139 | croped_d = random_crop_d(image=img_d.copy(), mask=mask_d.copy()) 140 | 141 | img_crop_a, mask_crop_a = croped_a['image'], croped_a['mask'] 142 | img_crop_b, mask_crop_b = croped_b['image'], croped_b['mask'] 143 | img_crop_c, mask_crop_c = croped_c['image'], croped_c['mask'] 144 | img_crop_d, mask_crop_d = croped_d['image'], croped_d['mask'] 145 | 146 | top = np.concatenate((img_crop_a, img_crop_b), axis=1) 147 | bottom = np.concatenate((img_crop_c, img_crop_d), axis=1) 148 | img = np.concatenate((top, bottom), axis=0) 149 | 150 | top_mask = np.concatenate((mask_crop_a, mask_crop_b), axis=1) 151 | bottom_mask = np.concatenate((mask_crop_c, mask_crop_d), axis=1) 152 | mask = np.concatenate((top_mask, bottom_mask), axis=0) 153 | mask = np.ascontiguousarray(mask) 154 | img = np.ascontiguousarray(img) 155 | 156 | img = Image.fromarray(img) 157 | mask = Image.fromarray(mask) 158 | 159 | return img, mask 160 | -------------------------------------------------------------------------------- /geoseg/datasets/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | from PIL import Image, ImageOps, ImageEnhance 4 | import numpy as np 5 | import random 6 | from scipy.ndimage.morphology import generate_binary_structure, binary_erosion 7 | from scipy.ndimage import maximum_filter 8 | 9 | 10 | class Compose(object): 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, img, mask): 15 | assert img.size == mask.size 16 | for t in self.transforms: 17 | img, mask = t(img, mask) 18 | return img, mask 19 | 20 | 21 | class RandomCrop(object): 22 | """ 23 | Take a random crop from the image. 24 | First the image or crop size may need to be adjusted if the incoming image 25 | is too small... 26 | If the image is smaller than the crop, then: 27 | the image is padded up to the size of the crop 28 | unless 'nopad', in which case the crop size is shrunk to fit the image 29 | A random crop is taken such that the crop fits within the image. 30 | If a centroid is passed in, the crop must intersect the centroid. 31 | """ 32 | def __init__(self, size=512, ignore_index=12, nopad=True): 33 | 34 | if isinstance(size, numbers.Number): 35 | self.size = (int(size), int(size)) 36 | else: 37 | self.size = size 38 | self.ignore_index = ignore_index 39 | self.nopad = nopad 40 | self.pad_color = (0, 0, 0) 41 | 42 | def __call__(self, img, mask, centroid=None): 43 | assert img.size == mask.size 44 | w, h = img.size 45 | # ASSUME H, W 46 | th, tw = self.size 47 | if w == tw and h == th: 48 | return img, mask 49 | 50 | if self.nopad: 51 | if th > h or tw > w: 52 | # Instead of padding, adjust crop size to the shorter edge of image. 53 | shorter_side = min(w, h) 54 | th, tw = shorter_side, shorter_side 55 | else: 56 | # Check if we need to pad img to fit for crop_size. 57 | if th > h: 58 | pad_h = (th - h) // 2 + 1 59 | else: 60 | pad_h = 0 61 | if tw > w: 62 | pad_w = (tw - w) // 2 + 1 63 | else: 64 | pad_w = 0 65 | border = (pad_w, pad_h, pad_w, pad_h) 66 | if pad_h or pad_w: 67 | img = ImageOps.expand(img, border=border, fill=self.pad_color) 68 | mask = ImageOps.expand(mask, border=border, fill=self.ignore_index) 69 | w, h = img.size 70 | 71 | if centroid is not None: 72 | # Need to insure that centroid is covered by crop and that crop 73 | # sits fully within the image 74 | c_x, c_y = centroid 75 | max_x = w - tw 76 | max_y = h - th 77 | x1 = random.randint(c_x - tw, c_x) 78 | x1 = min(max_x, max(0, x1)) 79 | y1 = random.randint(c_y - th, c_y) 80 | y1 = min(max_y, max(0, y1)) 81 | else: 82 | if w == tw: 83 | x1 = 0 84 | else: 85 | x1 = random.randint(0, w - tw) 86 | if h == th: 87 | y1 = 0 88 | else: 89 | y1 = random.randint(0, h - th) 90 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 91 | 92 | 93 | class PadImage(object): 94 | def __init__(self, size=(512, 512), ignore_index=0): 95 | self.size = size 96 | self.ignore_index = ignore_index 97 | 98 | def __call__(self, img, mask): 99 | assert img.size == mask.size 100 | th, tw = self.size, self.size 101 | 102 | w, h = img.size 103 | 104 | if w > tw or h > th: 105 | wpercent = (tw / float(w)) 106 | target_h = int((float(img.size[1]) * float(wpercent))) 107 | img, mask = img.resize((tw, target_h), Image.BICUBIC), mask.resize((tw, target_h), Image.NEAREST) 108 | 109 | w, h = img.size 110 | img = ImageOps.expand(img, border=(0, 0, tw - w, th - h), fill=0) 111 | mask = ImageOps.expand(mask, border=(0, 0, tw - w, th - h), fill=self.ignore_index) 112 | 113 | return img, mask 114 | 115 | 116 | class RandomHorizontalFlip(object): 117 | 118 | def __init__(self, prob: float = 0.5): 119 | self.prob = prob 120 | 121 | def __call__(self, img, mask=None): 122 | if mask is not None: 123 | if random.random() < self.prob: 124 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose( 125 | Image.FLIP_LEFT_RIGHT) 126 | else: 127 | return img, mask 128 | else: 129 | if random.random() < self.prob: 130 | return img.transpose(Image.FLIP_LEFT_RIGHT) 131 | else: 132 | return img 133 | 134 | 135 | class RandomVerticalFlip(object): 136 | def __init__(self, prob: float = 0.5): 137 | self.prob = prob 138 | 139 | def __call__(self, img, mask=None): 140 | if mask is not None: 141 | if random.random() < self.prob: 142 | return img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose( 143 | Image.FLIP_TOP_BOTTOM) 144 | else: 145 | return img, mask 146 | else: 147 | if random.random() < self.prob: 148 | return img.transpose(Image.FLIP_TOP_BOTTOM) 149 | else: 150 | return img 151 | 152 | 153 | class Resize(object): 154 | def __init__(self, size: tuple = (512, 512)): 155 | self.size = size # size: (h, w) 156 | 157 | def __call__(self, img, mask): 158 | assert img.size == mask.size 159 | return img.resize(self.size, Image.BICUBIC), mask.resize(self.size, Image.NEAREST) 160 | 161 | 162 | class RandomScale(object): 163 | def __init__(self, scale_list=[0.75, 1.0, 1.25], mode='value'): 164 | self.scale_list = scale_list 165 | self.mode = mode 166 | 167 | def __call__(self, img, mask): 168 | oh, ow = img.size 169 | scale_amt = 1.0 170 | if self.mode == 'value': 171 | scale_amt = np.random.choice(self.scale_list, 1) 172 | elif self.mode == 'range': 173 | scale_amt = random.uniform(self.scale_list[0], self.scale_list[-1]) 174 | h = int(scale_amt * oh) 175 | w = int(scale_amt * ow) 176 | return img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST) 177 | 178 | 179 | class ColorJitter(object): 180 | def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5): 181 | if not brightness is None and brightness>0: 182 | self.brightness = [max(1-brightness, 0), 1+brightness] 183 | if not contrast is None and contrast>0: 184 | self.contrast = [max(1-contrast, 0), 1+contrast] 185 | if not saturation is None and saturation>0: 186 | self.saturation = [max(1-saturation, 0), 1+saturation] 187 | 188 | def __call__(self, img, mask=None): 189 | r_brightness = random.uniform(self.brightness[0], self.brightness[1]) 190 | r_contrast = random.uniform(self.contrast[0], self.contrast[1]) 191 | r_saturation = random.uniform(self.saturation[0], self.saturation[1]) 192 | img = ImageEnhance.Brightness(img).enhance(r_brightness) 193 | img = ImageEnhance.Contrast(img).enhance(r_contrast) 194 | img = ImageEnhance.Color(img).enhance(r_saturation) 195 | if mask is None: 196 | return img 197 | else: 198 | return img, mask 199 | 200 | 201 | class SmartCropV1(object): 202 | def __init__(self, crop_size=512, 203 | max_ratio=0.75, 204 | ignore_index=12, nopad=False): 205 | self.crop_size = crop_size 206 | self.max_ratio = max_ratio 207 | self.ignore_index = ignore_index 208 | self.crop = RandomCrop(crop_size, ignore_index=ignore_index, nopad=nopad) 209 | 210 | def __call__(self, img, mask): 211 | assert img.size == mask.size 212 | count = 0 213 | while True: 214 | img_crop, mask_crop = self.crop(img.copy(), mask.copy()) 215 | count += 1 216 | labels, cnt = np.unique(np.array(mask_crop), return_counts=True) 217 | cnt = cnt[labels != self.ignore_index] 218 | if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.max_ratio: 219 | break 220 | if count > 10: 221 | break 222 | 223 | return img_crop, mask_crop 224 | 225 | 226 | class SmartCropV2(object): 227 | def __init__(self, crop_size=512, num_classes=13, 228 | class_interest=[2, 3], 229 | class_ratio=[0.1, 0.25], 230 | max_ratio=0.75, 231 | ignore_index=12, nopad=True): 232 | self.crop_size = crop_size 233 | self.num_classes = num_classes 234 | self.class_interest = class_interest 235 | self.class_ratio = class_ratio 236 | self.max_ratio = max_ratio 237 | self.ignore_index = ignore_index 238 | self.crop = RandomCrop(crop_size, ignore_index=ignore_index, nopad=nopad) 239 | 240 | def __call__(self, img, mask): 241 | assert img.size == mask.size 242 | count = 0 243 | while True: 244 | img_crop, mask_crop = self.crop(img.copy(), mask.copy()) 245 | count += 1 246 | bins = np.array(range(self.num_classes + 1)) 247 | class_pixel_counts, _ = np.histogram(np.array(mask_crop), bins=bins) 248 | cf = class_pixel_counts / (self.crop_size * self.crop_size) 249 | cf = np.array(cf) 250 | for c, f in zip(self.class_interest, self.class_ratio): 251 | if cf[c] > f: 252 | break 253 | if np.max(cf) < 0.75 and np.argmax(cf) != self.ignore_index: 254 | break 255 | if count > 10: 256 | break 257 | 258 | return img_crop, mask_crop -------------------------------------------------------------------------------- /geoseg/datasets/uavid_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import albumentations as albu 9 | 10 | import matplotlib.patches as mpatches 11 | from PIL import Image 12 | import random 13 | from .transform import * 14 | 15 | CLASSES = ('Building', 'Road', 'Tree', 'LowVeg', 'Moving_Car', 'Static_Car', 'Human', 'Clutter') 16 | PALETTE = [[128, 0, 0], [128, 64, 128], [0, 128, 0], [128, 128, 0], [64, 0, 128], [192, 0, 192], [64, 64, 0], [0, 0, 0]] 17 | 18 | ORIGIN_IMG_SIZE = (1024, 1024) 19 | INPUT_IMG_SIZE = (1024, 1024) 20 | TEST_IMG_SIZE = (1024, 1024) 21 | 22 | 23 | def get_training_transform(): 24 | train_transform = [ 25 | albu.HorizontalFlip(p=0.5), 26 | albu.VerticalFlip(p=0.5), 27 | albu.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.25), 28 | albu.Normalize() 29 | ] 30 | return albu.Compose(train_transform) 31 | 32 | 33 | def train_aug(img, mask): 34 | crop_aug = SmartCropV1(crop_size=768, max_ratio=0.75, ignore_index=255, nopad=False) 35 | img, mask = crop_aug(img, mask) 36 | img, mask = np.array(img), np.array(mask) 37 | aug = get_training_transform()(image=img.copy(), mask=mask.copy()) 38 | img, mask = aug['image'], aug['mask'] 39 | return img, mask 40 | 41 | 42 | def get_val_transform(): 43 | val_transform = [ 44 | albu.Normalize() 45 | ] 46 | return albu.Compose(val_transform) 47 | 48 | 49 | def val_aug(img, mask): 50 | img, mask = np.array(img), np.array(mask) 51 | aug = get_val_transform()(image=img.copy(), mask=mask.copy()) 52 | img, mask = aug['image'], aug['mask'] 53 | return img, mask 54 | 55 | 56 | class UAVIDDataset(Dataset): 57 | def __init__(self, data_root='data/uavid/train', mode='val', img_dir='images', mask_dir='masks', 58 | img_suffix='.png', mask_suffix='.png', transform=val_aug, mosaic_ratio=0.0, 59 | img_size=ORIGIN_IMG_SIZE): 60 | self.data_root = data_root 61 | self.img_dir = img_dir 62 | self.mask_dir = mask_dir 63 | self.img_suffix = img_suffix 64 | self.mask_suffix = mask_suffix 65 | self.transform = transform 66 | self.mode = mode 67 | self.mosaic_ratio = mosaic_ratio 68 | self.img_size = img_size 69 | self.img_ids = self.get_img_ids(self.data_root, self.img_dir, self.mask_dir) 70 | 71 | def __getitem__(self, index): 72 | p_ratio = random.random() 73 | if p_ratio > self.mosaic_ratio or self.mode == 'val' or self.mode == 'test': 74 | img, mask = self.load_img_and_mask(index) 75 | if self.transform: 76 | img, mask = self.transform(img, mask) 77 | else: 78 | img, mask = np.array(img), np.array(mask) 79 | else: 80 | img, mask = self.load_mosaic_img_and_mask(index) 81 | if self.transform: 82 | img, mask = self.transform(img, mask) 83 | else: 84 | img, mask = np.array(img), np.array(mask) 85 | 86 | img = torch.from_numpy(img).permute(2, 0, 1).float() 87 | mask = torch.from_numpy(mask).long() 88 | img_id = self.img_ids[index] 89 | results = {'img': img, 'gt_semantic_seg': mask, 'img_id': img_id} 90 | return results 91 | 92 | def __len__(self): 93 | return len(self.img_ids) 94 | 95 | def get_img_ids(self, data_root, img_dir, mask_dir): 96 | img_filename_list = os.listdir(osp.join(data_root, img_dir)) 97 | mask_filename_list = os.listdir(osp.join(data_root, mask_dir)) 98 | assert len(img_filename_list) == len(mask_filename_list) 99 | img_ids = [str(id.split('.')[0]) for id in mask_filename_list] 100 | return img_ids 101 | 102 | def load_img_and_mask(self, index): 103 | img_id = self.img_ids[index] 104 | img_name = osp.join(self.data_root, self.img_dir, img_id + self.img_suffix) 105 | mask_name = osp.join(self.data_root, self.mask_dir, img_id + self.mask_suffix) 106 | img = Image.open(img_name).convert('RGB') 107 | mask = Image.open(mask_name).convert('L') 108 | return img, mask 109 | 110 | def load_mosaic_img_and_mask(self, index): 111 | indexes = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)] 112 | img_a, mask_a = self.load_img_and_mask(indexes[0]) 113 | img_b, mask_b = self.load_img_and_mask(indexes[1]) 114 | img_c, mask_c = self.load_img_and_mask(indexes[2]) 115 | img_d, mask_d = self.load_img_and_mask(indexes[3]) 116 | 117 | img_a, mask_a = np.array(img_a), np.array(mask_a) 118 | img_b, mask_b = np.array(img_b), np.array(mask_b) 119 | img_c, mask_c = np.array(img_c), np.array(mask_c) 120 | img_d, mask_d = np.array(img_d), np.array(mask_d) 121 | 122 | h = self.img_size[0] 123 | w = self.img_size[1] 124 | 125 | start_x = w // 4 126 | strat_y = h // 4 127 | # The coordinates of the splice center 128 | offset_x = random.randint(start_x, (w - start_x)) 129 | offset_y = random.randint(strat_y, (h - strat_y)) 130 | 131 | crop_size_a = (offset_x, offset_y) 132 | crop_size_b = (w - offset_x, offset_y) 133 | crop_size_c = (offset_x, h - offset_y) 134 | crop_size_d = (w - offset_x, h - offset_y) 135 | 136 | random_crop_a = albu.RandomCrop(width=crop_size_a[0], height=crop_size_a[1]) 137 | random_crop_b = albu.RandomCrop(width=crop_size_b[0], height=crop_size_b[1]) 138 | random_crop_c = albu.RandomCrop(width=crop_size_c[0], height=crop_size_c[1]) 139 | random_crop_d = albu.RandomCrop(width=crop_size_d[0], height=crop_size_d[1]) 140 | 141 | croped_a = random_crop_a(image=img_a.copy(), mask=mask_a.copy()) 142 | croped_b = random_crop_b(image=img_b.copy(), mask=mask_b.copy()) 143 | croped_c = random_crop_c(image=img_c.copy(), mask=mask_c.copy()) 144 | croped_d = random_crop_d(image=img_d.copy(), mask=mask_d.copy()) 145 | 146 | img_crop_a, mask_crop_a = croped_a['image'], croped_a['mask'] 147 | img_crop_b, mask_crop_b = croped_b['image'], croped_b['mask'] 148 | img_crop_c, mask_crop_c = croped_c['image'], croped_c['mask'] 149 | img_crop_d, mask_crop_d = croped_d['image'], croped_d['mask'] 150 | 151 | top = np.concatenate((img_crop_a, img_crop_b), axis=1) 152 | bottom = np.concatenate((img_crop_c, img_crop_d), axis=1) 153 | img = np.concatenate((top, bottom), axis=0) 154 | 155 | top_mask = np.concatenate((mask_crop_a, mask_crop_b), axis=1) 156 | bottom_mask = np.concatenate((mask_crop_c, mask_crop_d), axis=1) 157 | mask = np.concatenate((top_mask, bottom_mask), axis=0) 158 | mask = np.ascontiguousarray(mask) 159 | img = np.ascontiguousarray(img) 160 | img = Image.fromarray(img) 161 | mask = Image.fromarray(mask) 162 | # print(img.shape) 163 | 164 | return img, mask 165 | 166 | 167 | def show_img_mask_seg(seg_path, img_path, mask_path, start_seg_index): 168 | seg_list = os.listdir(seg_path) 169 | fig, ax = plt.subplots(2, 3, figsize=(18, 12)) 170 | seg_list = seg_list[start_seg_index:start_seg_index+2] 171 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 172 | for i in range(len(seg_list)): 173 | seg_id = seg_list[i] 174 | img_seg = cv2.imread(f'{seg_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 175 | img_seg = img_seg.astype(np.uint8) 176 | img_seg = Image.fromarray(img_seg).convert('P') 177 | img_seg.putpalette(np.array(PALETTE, dtype=np.uint8)) 178 | img_seg = np.array(img_seg.convert('RGB')) 179 | mask = cv2.imread(f'{mask_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 180 | mask = mask.astype(np.uint8) 181 | mask = Image.fromarray(mask).convert('P') 182 | mask.putpalette(np.array(PALETTE, dtype=np.uint8)) 183 | mask = np.array(mask.convert('RGB')) 184 | img_id = str(seg_id.split('.')[0])+'.tif' 185 | img = cv2.imread(f'{img_path}/{img_id}', cv2.IMREAD_COLOR) 186 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 187 | ax[i, 0].set_axis_off() 188 | ax[i, 0].imshow(img) 189 | ax[i, 0].set_title('RS IMAGE ' + img_id) 190 | ax[i, 1].set_axis_off() 191 | ax[i, 1].imshow(mask) 192 | ax[i, 1].set_title('Mask True ' + seg_id) 193 | ax[i, 2].set_axis_off() 194 | ax[i, 2].imshow(img_seg) 195 | ax[i, 2].set_title('Mask Predict ' + seg_id) 196 | ax[i, 2].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 197 | 198 | 199 | def show_seg(seg_path, img_path, start_seg_index): 200 | seg_list = os.listdir(seg_path) 201 | fig, ax = plt.subplots(2, 2, figsize=(12, 12)) 202 | seg_list = seg_list[start_seg_index:start_seg_index+2] 203 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 204 | for i in range(len(seg_list)): 205 | seg_id = seg_list[i] 206 | img_seg = cv2.imread(f'{seg_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 207 | img_seg = img_seg.astype(np.uint8) 208 | img_seg = Image.fromarray(img_seg).convert('P') 209 | img_seg.putpalette(np.array(PALETTE, dtype=np.uint8)) 210 | img_seg = np.array(img_seg.convert('RGB')) 211 | img_id = str(seg_id.split('.')[0])+'.tif' 212 | img = cv2.imread(f'{img_path}/{img_id}', cv2.IMREAD_COLOR) 213 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 214 | ax[i, 0].set_axis_off() 215 | ax[i, 0].imshow(img) 216 | ax[i, 0].set_title('RS IMAGE '+img_id) 217 | ax[i, 1].set_axis_off() 218 | ax[i, 1].imshow(img_seg) 219 | ax[i, 1].set_title('Seg IMAGE '+seg_id) 220 | ax[i, 1].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 221 | 222 | 223 | def show_mask(img, mask, img_id): 224 | fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12)) 225 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 226 | mask = mask.astype(np.uint8) 227 | mask = Image.fromarray(mask).convert('P') 228 | mask.putpalette(np.array(PALETTE, dtype=np.uint8)) 229 | mask = np.array(mask.convert('RGB')) 230 | ax1.imshow(img) 231 | ax1.set_title('RS IMAGE ' + str(img_id)+'.png') 232 | ax2.imshow(mask) 233 | ax2.set_title('Mask ' + str(img_id)+'.png') 234 | ax2.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 235 | -------------------------------------------------------------------------------- /geoseg/datasets/vaihingen_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import albumentations as albu 9 | 10 | import matplotlib.patches as mpatches 11 | from PIL import Image 12 | import random 13 | from .transform import * 14 | 15 | CLASSES = ('ImSurf', 'Building', 'LowVeg', 'Tree', 'Car', 'Clutter') 16 | PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 204, 0], [255, 0, 0]] 17 | 18 | ORIGIN_IMG_SIZE = (1024, 1024) 19 | INPUT_IMG_SIZE = (1024, 1024) 20 | TEST_IMG_SIZE = (1024, 1024) 21 | 22 | 23 | def get_training_transform(): 24 | train_transform = [ 25 | albu.RandomRotate90(p=0.5), 26 | albu.Normalize() 27 | ] 28 | return albu.Compose(train_transform) 29 | 30 | 31 | def train_aug(img, mask): 32 | crop_aug = Compose([RandomScale(scale_list=[0.5, 0.75, 1.0, 1.25, 1.5], mode='value'), 33 | SmartCropV1(crop_size=512, max_ratio=0.75, 34 | ignore_index=len(CLASSES), nopad=False)]) 35 | img, mask = crop_aug(img, mask) 36 | img, mask = np.array(img), np.array(mask) 37 | aug = get_training_transform()(image=img.copy(), mask=mask.copy()) 38 | img, mask = aug['image'], aug['mask'] 39 | return img, mask 40 | 41 | 42 | def get_val_transform(): 43 | val_transform = [ 44 | albu.Normalize() 45 | ] 46 | return albu.Compose(val_transform) 47 | 48 | 49 | def val_aug(img, mask): 50 | img, mask = np.array(img), np.array(mask) 51 | aug = get_val_transform()(image=img.copy(), mask=mask.copy()) 52 | img, mask = aug['image'], aug['mask'] 53 | return img, mask 54 | 55 | 56 | class VaihingenDataset(Dataset): 57 | def __init__(self, data_root='data/vaihingen/test', mode='val', img_dir='images_1024', mask_dir='masks_1024', 58 | img_suffix='.tif', mask_suffix='.png', transform=val_aug, mosaic_ratio=0.0, 59 | img_size=ORIGIN_IMG_SIZE): 60 | self.data_root = data_root 61 | self.img_dir = img_dir 62 | self.mask_dir = mask_dir 63 | self.img_suffix = img_suffix 64 | self.mask_suffix = mask_suffix 65 | self.transform = transform 66 | self.mode = mode 67 | self.mosaic_ratio = mosaic_ratio 68 | self.img_size = img_size 69 | self.img_ids = self.get_img_ids(self.data_root, self.img_dir, self.mask_dir) 70 | 71 | def __getitem__(self, index): 72 | p_ratio = random.random() 73 | if p_ratio > self.mosaic_ratio or self.mode == 'val' or self.mode == 'test': 74 | img, mask = self.load_img_and_mask(index) 75 | if self.transform: 76 | img, mask = self.transform(img, mask) 77 | else: 78 | img, mask = self.load_mosaic_img_and_mask(index) 79 | if self.transform: 80 | img, mask = self.transform(img, mask) 81 | 82 | img = torch.from_numpy(img).permute(2, 0, 1).float() 83 | mask = torch.from_numpy(mask).long() 84 | img_id = self.img_ids[index] 85 | results = dict(img_id=img_id, img=img, gt_semantic_seg=mask) 86 | return results 87 | 88 | def __len__(self): 89 | return len(self.img_ids) 90 | 91 | def get_img_ids(self, data_root, img_dir, mask_dir): 92 | img_filename_list = os.listdir(osp.join(data_root, img_dir)) 93 | mask_filename_list = os.listdir(osp.join(data_root, mask_dir)) 94 | assert len(img_filename_list) == len(mask_filename_list) 95 | img_ids = [str(id.split('.')[0]) for id in mask_filename_list] 96 | return img_ids 97 | 98 | def load_img_and_mask(self, index): 99 | img_id = self.img_ids[index] 100 | img_name = osp.join(self.data_root, self.img_dir, img_id + self.img_suffix) 101 | mask_name = osp.join(self.data_root, self.mask_dir, img_id + self.mask_suffix) 102 | img = Image.open(img_name).convert('RGB') 103 | mask = Image.open(mask_name).convert('L') 104 | return img, mask 105 | 106 | def load_mosaic_img_and_mask(self, index): 107 | indexes = [index] + [random.randint(0, len(self.img_ids) - 1) for _ in range(3)] 108 | img_a, mask_a = self.load_img_and_mask(indexes[0]) 109 | img_b, mask_b = self.load_img_and_mask(indexes[1]) 110 | img_c, mask_c = self.load_img_and_mask(indexes[2]) 111 | img_d, mask_d = self.load_img_and_mask(indexes[3]) 112 | 113 | img_a, mask_a = np.array(img_a), np.array(mask_a) 114 | img_b, mask_b = np.array(img_b), np.array(mask_b) 115 | img_c, mask_c = np.array(img_c), np.array(mask_c) 116 | img_d, mask_d = np.array(img_d), np.array(mask_d) 117 | 118 | h = self.img_size[0] 119 | w = self.img_size[1] 120 | 121 | start_x = w // 4 122 | strat_y = h // 4 123 | # The coordinates of the splice center 124 | offset_x = random.randint(start_x, (w - start_x)) 125 | offset_y = random.randint(strat_y, (h - strat_y)) 126 | 127 | crop_size_a = (offset_x, offset_y) 128 | crop_size_b = (w - offset_x, offset_y) 129 | crop_size_c = (offset_x, h - offset_y) 130 | crop_size_d = (w - offset_x, h - offset_y) 131 | 132 | random_crop_a = albu.RandomCrop(width=crop_size_a[0], height=crop_size_a[1]) 133 | random_crop_b = albu.RandomCrop(width=crop_size_b[0], height=crop_size_b[1]) 134 | random_crop_c = albu.RandomCrop(width=crop_size_c[0], height=crop_size_c[1]) 135 | random_crop_d = albu.RandomCrop(width=crop_size_d[0], height=crop_size_d[1]) 136 | 137 | croped_a = random_crop_a(image=img_a.copy(), mask=mask_a.copy()) 138 | croped_b = random_crop_b(image=img_b.copy(), mask=mask_b.copy()) 139 | croped_c = random_crop_c(image=img_c.copy(), mask=mask_c.copy()) 140 | croped_d = random_crop_d(image=img_d.copy(), mask=mask_d.copy()) 141 | 142 | img_crop_a, mask_crop_a = croped_a['image'], croped_a['mask'] 143 | img_crop_b, mask_crop_b = croped_b['image'], croped_b['mask'] 144 | img_crop_c, mask_crop_c = croped_c['image'], croped_c['mask'] 145 | img_crop_d, mask_crop_d = croped_d['image'], croped_d['mask'] 146 | 147 | top = np.concatenate((img_crop_a, img_crop_b), axis=1) 148 | bottom = np.concatenate((img_crop_c, img_crop_d), axis=1) 149 | img = np.concatenate((top, bottom), axis=0) 150 | 151 | top_mask = np.concatenate((mask_crop_a, mask_crop_b), axis=1) 152 | bottom_mask = np.concatenate((mask_crop_c, mask_crop_d), axis=1) 153 | mask = np.concatenate((top_mask, bottom_mask), axis=0) 154 | mask = np.ascontiguousarray(mask) 155 | img = np.ascontiguousarray(img) 156 | img = Image.fromarray(img) 157 | mask = Image.fromarray(mask) 158 | # print(img.shape) 159 | 160 | return img, mask 161 | 162 | 163 | def show_img_mask_seg(seg_path, img_path, mask_path, start_seg_index): 164 | seg_list = os.listdir(seg_path) 165 | seg_list = [f for f in seg_list if f.endswith('.png')] 166 | fig, ax = plt.subplots(2, 3, figsize=(18, 12)) 167 | seg_list = seg_list[start_seg_index:start_seg_index+2] 168 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 169 | for i in range(len(seg_list)): 170 | seg_id = seg_list[i] 171 | img_seg = cv2.imread(f'{seg_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 172 | img_seg = img_seg.astype(np.uint8) 173 | img_seg = Image.fromarray(img_seg).convert('P') 174 | img_seg.putpalette(np.array(PALETTE, dtype=np.uint8)) 175 | img_seg = np.array(img_seg.convert('RGB')) 176 | mask = cv2.imread(f'{mask_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 177 | mask = mask.astype(np.uint8) 178 | mask = Image.fromarray(mask).convert('P') 179 | mask.putpalette(np.array(PALETTE, dtype=np.uint8)) 180 | mask = np.array(mask.convert('RGB')) 181 | img_id = str(seg_id.split('.')[0])+'.tif' 182 | img = cv2.imread(f'{img_path}/{img_id}', cv2.IMREAD_COLOR) 183 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 184 | ax[i, 0].set_axis_off() 185 | ax[i, 0].imshow(img) 186 | ax[i, 0].set_title('RS IMAGE ' + img_id) 187 | ax[i, 1].set_axis_off() 188 | ax[i, 1].imshow(mask) 189 | ax[i, 1].set_title('Mask True ' + seg_id) 190 | ax[i, 2].set_axis_off() 191 | ax[i, 2].imshow(img_seg) 192 | ax[i, 2].set_title('Mask Predict ' + seg_id) 193 | ax[i, 2].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 194 | 195 | 196 | def show_seg(seg_path, img_path, start_seg_index): 197 | seg_list = os.listdir(seg_path) 198 | seg_list = [f for f in seg_list if f.endswith('.png')] 199 | fig, ax = plt.subplots(2, 2, figsize=(12, 12)) 200 | seg_list = seg_list[start_seg_index:start_seg_index+2] 201 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 202 | for i in range(len(seg_list)): 203 | seg_id = seg_list[i] 204 | img_seg = cv2.imread(f'{seg_path}/{seg_id}', cv2.IMREAD_UNCHANGED) 205 | img_seg = img_seg.astype(np.uint8) 206 | img_seg = Image.fromarray(img_seg).convert('P') 207 | img_seg.putpalette(np.array(PALETTE, dtype=np.uint8)) 208 | img_seg = np.array(img_seg.convert('RGB')) 209 | img_id = str(seg_id.split('.')[0])+'.tif' 210 | img = cv2.imread(f'{img_path}/{img_id}', cv2.IMREAD_COLOR) 211 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 212 | ax[i, 0].set_axis_off() 213 | ax[i, 0].imshow(img) 214 | ax[i, 0].set_title('RS IMAGE '+img_id) 215 | ax[i, 1].set_axis_off() 216 | ax[i, 1].imshow(img_seg) 217 | ax[i, 1].set_title('Seg IMAGE '+seg_id) 218 | ax[i, 1].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 219 | 220 | 221 | def show_mask(img, mask, img_id): 222 | fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12)) 223 | patches = [mpatches.Patch(color=np.array(PALETTE[i])/255., label=CLASSES[i]) for i in range(len(CLASSES))] 224 | mask = mask.astype(np.uint8) 225 | mask = Image.fromarray(mask).convert('P') 226 | mask.putpalette(np.array(PALETTE, dtype=np.uint8)) 227 | mask = np.array(mask.convert('RGB')) 228 | ax1.imshow(img) 229 | ax1.set_title('RS IMAGE ' + str(img_id)+'.tif') 230 | ax2.imshow(mask) 231 | ax2.set_title('Mask ' + str(img_id)+'.png') 232 | ax2.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large') 233 | -------------------------------------------------------------------------------- /geoseg/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .balanced_bce import * 4 | from .bitempered_loss import * 5 | from .dice import * 6 | from .focal import * 7 | from .focal_cosine import * 8 | from .functional import * 9 | from .jaccard import * 10 | from .joint_loss import * 11 | from .lovasz import * 12 | from .soft_bce import * 13 | from .soft_ce import * 14 | from .soft_f1 import * 15 | from .wing_loss import * 16 | from .useful_loss import * 17 | -------------------------------------------------------------------------------- /geoseg/losses/balanced_bce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | 7 | __all__ = ["BalancedBCEWithLogitsLoss", "balanced_binary_cross_entropy_with_logits"] 8 | 9 | 10 | def balanced_binary_cross_entropy_with_logits( 11 | logits: Tensor, targets: Tensor, gamma: float = 1.0, ignore_index: Optional[int] = None, reduction: str = "mean" 12 | ) -> Tensor: 13 | """ 14 | Balanced binary cross entropy loss. 15 | 16 | Args: 17 | logits: 18 | targets: This loss function expects target values to be hard targets 0/1. 19 | gamma: Power factor for balancing weights 20 | ignore_index: 21 | reduction: 22 | 23 | Returns: 24 | Zero-sized tensor with reduced loss if `reduction` is `sum` or `mean`; Otherwise returns loss of the 25 | shape of `logits` tensor. 26 | """ 27 | pos_targets: Tensor = targets.eq(1).sum() 28 | neg_targets: Tensor = targets.eq(0).sum() 29 | 30 | num_targets = pos_targets + neg_targets 31 | pos_weight = torch.pow(neg_targets / (num_targets + 1e-7), gamma) 32 | neg_weight = 1.0 - pos_weight 33 | 34 | pos_term = pos_weight.pow(gamma) * targets * torch.nn.functional.logsigmoid(logits) 35 | neg_term = neg_weight.pow(gamma) * (1 - targets) * torch.nn.functional.logsigmoid(-logits) 36 | 37 | loss = -(pos_term + neg_term) 38 | 39 | if ignore_index is not None: 40 | loss = torch.masked_fill(loss, targets.eq(ignore_index), 0) 41 | 42 | if reduction == "mean": 43 | loss = loss.mean() 44 | 45 | if reduction == "sum": 46 | loss = loss.sum() 47 | 48 | return loss 49 | 50 | 51 | class BalancedBCEWithLogitsLoss(nn.Module): 52 | """ 53 | Balanced binary cross-entropy loss. 54 | 55 | https://arxiv.org/pdf/1504.06375.pdf (Formula 2) 56 | """ 57 | 58 | __constants__ = ["gamma", "reduction", "ignore_index"] 59 | 60 | def __init__(self, gamma: float = 1.0, reduction="mean", ignore_index: Optional[int] = None): 61 | """ 62 | 63 | Args: 64 | gamma: 65 | ignore_index: 66 | reduction: 67 | """ 68 | super().__init__() 69 | self.gamma = gamma 70 | self.reduction = reduction 71 | self.ignore_index = ignore_index 72 | 73 | def forward(self, output: Tensor, target: Tensor) -> Tensor: 74 | return balanced_binary_cross_entropy_with_logits( 75 | output, target, gamma=self.gamma, ignore_index=self.ignore_index, reduction=self.reduction 76 | ) 77 | -------------------------------------------------------------------------------- /geoseg/losses/bitempered_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | 6 | __all__ = ["BiTemperedLogisticLoss", "BinaryBiTemperedLogisticLoss"] 7 | 8 | 9 | def log_t(u, t): 10 | """Compute log_t for `u'.""" 11 | if t == 1.0: 12 | return u.log() 13 | else: 14 | return (u.pow(1.0 - t) - 1.0) / (1.0 - t) 15 | 16 | 17 | def exp_t(u, t): 18 | """Compute exp_t for `u'.""" 19 | if t == 1: 20 | return u.exp() 21 | else: 22 | return (1.0 + (1.0 - t) * u).relu().pow(1.0 / (1.0 - t)) 23 | 24 | 25 | def compute_normalization_fixed_point(activations: Tensor, t: float, num_iters: int) -> Tensor: 26 | """Return the normalization value for each example (t > 1.0). 27 | Args: 28 | activations: A multi-dimensional tensor with last dimension `num_classes`. 29 | t: Temperature 2 (> 1.0 for tail heaviness). 30 | num_iters: Number of iterations to run the method. 31 | Return: A tensor of same shape as activation with the last dimension being 1. 32 | """ 33 | mu, _ = torch.max(activations, -1, keepdim=True) 34 | normalized_activations_step_0 = activations - mu 35 | 36 | normalized_activations = normalized_activations_step_0 37 | 38 | for _ in range(num_iters): 39 | logt_partition = torch.sum(exp_t(normalized_activations, t), -1, keepdim=True) 40 | normalized_activations = normalized_activations_step_0 * logt_partition.pow(1.0 - t) 41 | 42 | logt_partition = torch.sum(exp_t(normalized_activations, t), -1, keepdim=True) 43 | normalization_constants = -log_t(1.0 / logt_partition, t) + mu 44 | 45 | return normalization_constants 46 | 47 | 48 | def compute_normalization_binary_search(activations: Tensor, t: float, num_iters: int) -> Tensor: 49 | """Compute normalization value for each example (t < 1.0). 50 | Args: 51 | activations: A multi-dimensional tensor with last dimension `num_classes`. 52 | t: Temperature 2 (< 1.0 for finite support). 53 | num_iters: Number of iterations to run the method. 54 | Return: A tensor of same rank as activation with the last dimension being 1. 55 | """ 56 | mu, _ = torch.max(activations, -1, keepdim=True) 57 | normalized_activations = activations - mu 58 | 59 | effective_dim = torch.sum((normalized_activations > -1.0 / (1.0 - t)).to(torch.int32), dim=-1, keepdim=True).to( 60 | activations.dtype 61 | ) 62 | 63 | shape_partition = activations.shape[:-1] + (1,) 64 | lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device) 65 | upper = -log_t(1.0 / effective_dim, t) * torch.ones_like(lower) 66 | 67 | for _ in range(num_iters): 68 | logt_partition = (upper + lower) / 2.0 69 | sum_probs = torch.sum(exp_t(normalized_activations - logt_partition, t), dim=-1, keepdim=True) 70 | update = (sum_probs < 1.0).to(activations.dtype) 71 | lower = torch.reshape(lower * update + (1.0 - update) * logt_partition, shape_partition) 72 | upper = torch.reshape(upper * (1.0 - update) + update * logt_partition, shape_partition) 73 | 74 | logt_partition = (upper + lower) / 2.0 75 | return logt_partition + mu 76 | 77 | 78 | class ComputeNormalization(torch.autograd.Function): 79 | """ 80 | Class implementing custom backward pass for compute_normalization. See compute_normalization. 81 | """ 82 | 83 | @staticmethod 84 | def forward(ctx, activations, t, num_iters): 85 | if t < 1.0: 86 | normalization_constants = compute_normalization_binary_search(activations, t, num_iters) 87 | else: 88 | normalization_constants = compute_normalization_fixed_point(activations, t, num_iters) 89 | 90 | ctx.save_for_backward(activations, normalization_constants) 91 | ctx.t = t 92 | return normalization_constants 93 | 94 | @staticmethod 95 | def backward(ctx, grad_output): 96 | activations, normalization_constants = ctx.saved_tensors 97 | t = ctx.t 98 | normalized_activations = activations - normalization_constants 99 | probabilities = exp_t(normalized_activations, t) 100 | escorts = probabilities.pow(t) 101 | escorts = escorts / escorts.sum(dim=-1, keepdim=True) 102 | grad_input = escorts * grad_output 103 | 104 | return grad_input, None, None 105 | 106 | 107 | def compute_normalization(activations, t, num_iters=5): 108 | """Compute normalization value for each example. 109 | Backward pass is implemented. 110 | Args: 111 | activations: A multi-dimensional tensor with last dimension `num_classes`. 112 | t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 113 | num_iters: Number of iterations to run the method. 114 | Return: A tensor of same rank as activation with the last dimension being 1. 115 | """ 116 | return ComputeNormalization.apply(activations, t, num_iters) 117 | 118 | 119 | def tempered_softmax(activations, t, num_iters=5): 120 | """Tempered softmax function. 121 | Args: 122 | activations: A multi-dimensional tensor with last dimension `num_classes`. 123 | t: Temperature > 1.0. 124 | num_iters: Number of iterations to run the method. 125 | Returns: 126 | A probabilities tensor. 127 | """ 128 | if t == 1.0: 129 | return activations.softmax(dim=-1) 130 | 131 | normalization_constants = compute_normalization(activations, t, num_iters) 132 | return exp_t(activations - normalization_constants, t) 133 | 134 | 135 | def bi_tempered_logistic_loss(activations, labels, t1, t2, label_smoothing=0.0, num_iters=5, reduction="mean"): 136 | """Bi-Tempered Logistic Loss. 137 | Args: 138 | activations: A multi-dimensional tensor with last dimension `num_classes`. 139 | labels: A tensor with shape and dtype as activations (onehot), 140 | or a long tensor of one dimension less than activations (pytorch standard) 141 | t1: Temperature 1 (< 1.0 for boundedness). 142 | t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 143 | label_smoothing: Label smoothing parameter between [0, 1). Default 0.0. 144 | num_iters: Number of iterations to run the method. Default 5. 145 | reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``. 146 | ``'none'``: No reduction is applied, return shape is shape of 147 | activations without the last dimension. 148 | ``'mean'``: Loss is averaged over minibatch. Return shape (1,) 149 | ``'sum'``: Loss is summed over minibatch. Return shape (1,) 150 | Returns: 151 | A loss tensor. 152 | """ 153 | if len(labels.shape) < len(activations.shape): # not one-hot 154 | labels_onehot = torch.zeros_like(activations) 155 | labels_onehot.scatter_(1, labels[..., None], 1) 156 | else: 157 | labels_onehot = labels 158 | 159 | if label_smoothing > 0: 160 | num_classes = labels_onehot.shape[-1] 161 | labels_onehot = (1 - label_smoothing * num_classes / (num_classes - 1)) * labels_onehot + label_smoothing / ( 162 | num_classes - 1 163 | ) 164 | 165 | probabilities = tempered_softmax(activations, t2, num_iters) 166 | 167 | loss_values = ( 168 | labels_onehot * log_t(labels_onehot + 1e-10, t1) 169 | - labels_onehot * log_t(probabilities, t1) 170 | - labels_onehot.pow(2.0 - t1) / (2.0 - t1) 171 | + probabilities.pow(2.0 - t1) / (2.0 - t1) 172 | ) 173 | loss_values = loss_values.sum(dim=-1) # sum over classes 174 | 175 | if reduction == "none": 176 | return loss_values 177 | if reduction == "sum": 178 | return loss_values.sum() 179 | if reduction == "mean": 180 | return loss_values.mean() 181 | 182 | 183 | class BiTemperedLogisticLoss(nn.Module): 184 | """ 185 | 186 | https://ai.googleblog.com/2019/08/bi-tempered-logistic-loss-for-training.html 187 | https://arxiv.org/abs/1906.03361 188 | """ 189 | 190 | def __init__(self, t1: float, t2: float, smoothing=0.0, ignore_index=None, reduction: str = "mean"): 191 | """ 192 | 193 | Args: 194 | t1: 195 | t2: 196 | smoothing: 197 | ignore_index: 198 | reduction: 199 | """ 200 | super(BiTemperedLogisticLoss, self).__init__() 201 | self.t1 = t1 202 | self.t2 = t2 203 | self.smoothing = smoothing 204 | self.reduction = reduction 205 | self.ignore_index = ignore_index 206 | 207 | def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: 208 | loss = bi_tempered_logistic_loss( 209 | predictions, targets, t1=self.t1, t2=self.t2, label_smoothing=self.smoothing, reduction="none" 210 | ) 211 | 212 | if self.ignore_index is not None: 213 | mask = ~targets.eq(self.ignore_index) 214 | loss *= mask 215 | 216 | if self.reduction == "mean": 217 | loss = loss.mean() 218 | elif self.reduction == "sum": 219 | loss = loss.sum() 220 | return loss 221 | 222 | 223 | class BinaryBiTemperedLogisticLoss(nn.Module): 224 | """ 225 | Modification of BiTemperedLogisticLoss for binary classification case. 226 | It's signature matches nn.BCEWithLogitsLoss: Predictions and target tensors must have shape [B,1,...] 227 | 228 | References: 229 | https://ai.googleblog.com/2019/08/bi-tempered-logistic-loss-for-training.html 230 | https://arxiv.org/abs/1906.03361 231 | """ 232 | 233 | def __init__( 234 | self, t1: float, t2: float, smoothing: float = 0.0, ignore_index: Optional[int] = None, reduction: str = "mean" 235 | ): 236 | """ 237 | 238 | Args: 239 | t1: 240 | t2: 241 | smoothing: 242 | ignore_index: 243 | reduction: 244 | """ 245 | super().__init__() 246 | self.t1 = t1 247 | self.t2 = t2 248 | self.smoothing = smoothing 249 | self.reduction = reduction 250 | self.ignore_index = ignore_index 251 | 252 | def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: 253 | """ 254 | Forward method of the loss function 255 | 256 | Args: 257 | predictions: [B,1,...] 258 | targets: [B,1,...] 259 | 260 | Returns: 261 | Zero-sized tensor with reduced loss if self.reduction is `sum` or `mean`; Otherwise returns loss of the 262 | shape of `predictions` tensor. 263 | """ 264 | if predictions.size(1) != 1 or targets.size(1) != 1: 265 | raise ValueError("Channel dimension for predictions and targets must be equal to 1") 266 | 267 | loss = bi_tempered_logistic_loss( 268 | torch.cat([-predictions, predictions], dim=1).moveaxis(1, -1), 269 | torch.cat([1 - targets, targets], dim=1).moveaxis(1, -1), 270 | t1=self.t1, 271 | t2=self.t2, 272 | label_smoothing=self.smoothing, 273 | reduction="none", 274 | ).unsqueeze(dim=1) 275 | 276 | if self.ignore_index is not None: 277 | mask = targets.eq(self.ignore_index) 278 | loss = torch.masked_fill(loss, mask, 0) 279 | 280 | if self.reduction == "mean": 281 | loss = loss.mean() 282 | elif self.reduction == "sum": 283 | loss = loss.sum() 284 | return loss 285 | -------------------------------------------------------------------------------- /geoseg/losses/cel1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Optional 6 | 7 | BINARY_MODE: str = "binary" 8 | 9 | MULTICLASS_MODE: str = "multiclass" 10 | 11 | MULTILABEL_MODE: str = "multilabel" 12 | 13 | 14 | EPS = 1e-10 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def expand_onehot_labels(labels, target_shape, ignore_index): 21 | """Expand onehot labels to match the size of prediction.""" 22 | bin_labels = labels.new_zeros(target_shape) 23 | valid_mask = (labels >= 0) & (labels != ignore_index) 24 | inds = torch.nonzero(valid_mask, as_tuple=True) 25 | 26 | if inds[0].numel() > 0: 27 | if labels.dim() == 3: 28 | bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 29 | else: 30 | bin_labels[inds[0], labels[valid_mask]] = 1 31 | 32 | return bin_labels, valid_mask 33 | 34 | 35 | def get_region_proportion(x: torch.Tensor, valid_mask: torch.Tensor = None) -> torch.Tensor: 36 | """Get region proportion 37 | Args: 38 | x : one-hot label map/mask 39 | valid_mask : indicate the considered elements 40 | """ 41 | if valid_mask is not None: 42 | if valid_mask.dim() == 4: 43 | x = torch.einsum("bcwh, bcwh->bcwh", x, valid_mask) 44 | cardinality = torch.einsum("bcwh->bc", valid_mask) 45 | else: 46 | x = torch.einsum("bcwh,bwh->bcwh", x, valid_mask) 47 | cardinality = torch.einsum("bwh->b", valid_mask).unsqueeze(dim=1).repeat(1, x.shape[1]) 48 | else: 49 | cardinality = x.shape[2] * x.shape[3] 50 | 51 | region_proportion = (torch.einsum("bcwh->bc", x) + EPS) / (cardinality + EPS) 52 | 53 | return region_proportion 54 | 55 | 56 | class CompoundLoss(nn.Module): 57 | """ 58 | The base class for implementing a compound loss: 59 | l = l_1 + alpha * l_2 60 | """ 61 | def __init__(self, mode: str = MULTICLASS_MODE, 62 | alpha: float = 0.1, 63 | factor: float = 5., 64 | step_size: int = 0, 65 | max_alpha: float = 100., 66 | temp: float = 1., 67 | ignore_index: int = 255, 68 | background_index: int = -1, 69 | weight: Optional[torch.Tensor] = None) -> None: 70 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 71 | super().__init__() 72 | self.mode = mode 73 | self.alpha = alpha 74 | self.max_alpha = max_alpha 75 | self.factor = factor 76 | self.step_size = step_size 77 | self.temp = temp 78 | self.ignore_index = ignore_index 79 | self.background_index = background_index 80 | self.weight = weight 81 | 82 | def cross_entropy(self, inputs: torch.Tensor, labels: torch.Tensor): 83 | if self.mode == MULTICLASS_MODE: 84 | loss = F.cross_entropy( 85 | inputs, labels, weight=self.weight, ignore_index=self.ignore_index) 86 | else: 87 | if labels.dim() == 3: 88 | labels = labels.unsqueeze(dim=1) 89 | loss = F.binary_cross_entropy_with_logits(inputs, labels.type(torch.float32)) 90 | return loss 91 | 92 | def adjust_alpha(self, epoch: int) -> None: 93 | if self.step_size == 0: 94 | return 95 | if (epoch + 1) % self.step_size == 0: 96 | curr_alpha = self.alpha 97 | self.alpha = min(self.alpha * self.factor, self.max_alpha) 98 | logger.info( 99 | "CompoundLoss : Adjust the tradoff param alpha : {:.3g} -> {:.3g}".format(curr_alpha, self.alpha) 100 | ) 101 | 102 | def get_gt_proportion(self, mode: str, 103 | labels: torch.Tensor, 104 | target_shape, 105 | ignore_index: int = 255): 106 | if mode == MULTICLASS_MODE: 107 | bin_labels, valid_mask = expand_onehot_labels(labels, target_shape, ignore_index) 108 | else: 109 | valid_mask = (labels >= 0) & (labels != ignore_index) 110 | if labels.dim() == 3: 111 | labels = labels.unsqueeze(dim=1) 112 | bin_labels = labels 113 | gt_proportion = get_region_proportion(bin_labels, valid_mask) 114 | return gt_proportion, valid_mask 115 | 116 | def get_pred_proportion(self, mode: str, 117 | logits: torch.Tensor, 118 | temp: float = 1.0, 119 | valid_mask=None): 120 | if mode == MULTICLASS_MODE: 121 | preds = F.log_softmax(temp * logits, dim=1).exp() 122 | else: 123 | preds = F.logsigmoid(temp * logits).exp() 124 | pred_proportion = get_region_proportion(preds, valid_mask) 125 | return pred_proportion 126 | 127 | 128 | class CrossEntropyWithL1(CompoundLoss): 129 | """ 130 | Cross entropy loss with region size priors measured by l1. 131 | The loss can be described as: 132 | l = CE(X, Y) + alpha * |gt_region - prob_region| 133 | """ 134 | def forward(self, inputs: torch.Tensor, labels: torch.Tensor): 135 | # ce term 136 | loss_ce = self.cross_entropy(inputs, labels) 137 | # regularization 138 | gt_proportion, valid_mask = self.get_gt_proportion(self.mode, labels, inputs.shape) 139 | pred_proportion = self.get_pred_proportion(self.mode, inputs, temp=self.temp, valid_mask=valid_mask) 140 | loss_reg = (pred_proportion - gt_proportion).abs().mean() 141 | 142 | loss = loss_ce + self.alpha * loss_reg 143 | 144 | return loss 145 | 146 | 147 | class CrossEntropyWithKL(CompoundLoss): 148 | """ 149 | Cross entropy loss with region size priors measured by l1. 150 | The loss can be described as: 151 | l = CE(X, Y) + alpha * KL(gt_region || prob_region) 152 | """ 153 | def kl_div(self, p : torch.Tensor, q : torch.Tensor) -> torch.Tensor: 154 | x = p * torch.log(p / q) 155 | x = torch.einsum("ij->i", x) 156 | return x 157 | 158 | def forward(self, inputs: torch.Tensor, labels: torch.Tensor): 159 | # ce term 160 | loss_ce = self.cross_entropy(inputs, labels) 161 | # regularization 162 | gt_proportion, valid_mask = self.get_gt_proportion(self.mode, labels, inputs.shape) 163 | pred_proportion = self.get_pred_proportion(self.mode, inputs, temp=self.temp, valid_mask=valid_mask) 164 | 165 | if self.mode == BINARY_MODE: 166 | regularizer = ( 167 | self.kl_div(gt_proportion, pred_proportion) 168 | + self.kl_div(1 - gt_proportion, 1 - pred_proportion) 169 | ).mean() 170 | else: 171 | regularizer = self.kl_div(gt_proportion, pred_proportion).mean() 172 | 173 | loss = loss_ce + self.alpha * regularizer 174 | 175 | return loss -------------------------------------------------------------------------------- /geoseg/losses/dice.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from torch.nn.modules.loss import _Loss 7 | import numpy as np 8 | 9 | from .functional import soft_dice_score 10 | 11 | __all__ = ["DiceLoss"] 12 | 13 | BINARY_MODE = "binary" 14 | MULTICLASS_MODE = "multiclass" 15 | MULTILABEL_MODE = "multilabel" 16 | 17 | 18 | def to_tensor(x, dtype=None) -> torch.Tensor: 19 | if isinstance(x, torch.Tensor): 20 | if dtype is not None: 21 | x = x.type(dtype) 22 | return x 23 | if isinstance(x, np.ndarray) and x.dtype.kind not in {"O", "M", "U", "S"}: 24 | x = torch.from_numpy(x) 25 | if dtype is not None: 26 | x = x.type(dtype) 27 | return x 28 | if isinstance(x, (list, tuple)): 29 | x = np.ndarray(x) 30 | x = torch.from_numpy(x) 31 | if dtype is not None: 32 | x = x.type(dtype) 33 | return x 34 | 35 | raise ValueError("Unsupported input type" + str(type(x))) 36 | 37 | 38 | class DiceLoss(_Loss): 39 | """ 40 | Implementation of Dice loss for image segmentation task. 41 | It supports binary, multiclass and multilabel cases 42 | """ 43 | 44 | def __init__( 45 | self, 46 | mode: str = 'multiclass', 47 | classes: List[int] = None, 48 | log_loss=False, 49 | from_logits=True, 50 | smooth: float = 0.0, 51 | ignore_index=None, 52 | eps=1e-7, 53 | ): 54 | """ 55 | 56 | :param mode: Metric mode {'binary', 'multiclass', 'multilabel'} 57 | :param classes: Optional list of classes that contribute in loss computation; 58 | By default, all channels are included. 59 | :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard` 60 | :param from_logits: If True assumes input is raw logits 61 | :param smooth: 62 | :param ignore_index: Label that indicates ignored pixels (does not contribute to loss) 63 | :param eps: Small epsilon for numerical stability 64 | """ 65 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 66 | super(DiceLoss, self).__init__() 67 | self.mode = mode 68 | if classes is not None: 69 | assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 70 | classes = to_tensor(classes, dtype=torch.long) 71 | 72 | self.classes = classes 73 | self.from_logits = from_logits 74 | self.smooth = smooth 75 | self.eps = eps 76 | self.ignore_index = ignore_index 77 | self.log_loss = log_loss 78 | 79 | def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: 80 | """ 81 | 82 | :param y_pred: NxCxHxW 83 | :param y_true: NxHxW 84 | :return: scalar 85 | """ 86 | assert y_true.size(0) == y_pred.size(0) 87 | 88 | if self.from_logits: 89 | # Apply activations to get [0..1] class probabilities 90 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 91 | # extreme values 0 and 1 92 | if self.mode == MULTICLASS_MODE: 93 | y_pred = y_pred.log_softmax(dim=1).exp() 94 | else: 95 | y_pred = F.logsigmoid(y_pred).exp() 96 | 97 | bs = y_true.size(0) 98 | num_classes = y_pred.size(1) 99 | dims = (0, 2) 100 | 101 | if self.mode == BINARY_MODE: 102 | y_true = y_true.view(bs, 1, -1) 103 | y_pred = y_pred.view(bs, 1, -1) 104 | 105 | if self.ignore_index is not None: 106 | mask = y_true != self.ignore_index 107 | y_pred = y_pred * mask 108 | y_true = y_true * mask 109 | 110 | if self.mode == MULTICLASS_MODE: 111 | y_true = y_true.view(bs, -1) 112 | y_pred = y_pred.view(bs, num_classes, -1) 113 | 114 | if self.ignore_index is not None: 115 | mask = y_true != self.ignore_index 116 | y_pred = y_pred * mask.unsqueeze(1) 117 | 118 | y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C 119 | y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # H, C, H*W 120 | else: 121 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 122 | y_true = y_true.permute(0, 2, 1) # H, C, H*W 123 | 124 | if self.mode == MULTILABEL_MODE: 125 | y_true = y_true.view(bs, num_classes, -1) 126 | y_pred = y_pred.view(bs, num_classes, -1) 127 | 128 | if self.ignore_index is not None: 129 | mask = y_true != self.ignore_index 130 | y_pred = y_pred * mask 131 | y_true = y_true * mask 132 | 133 | scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims) 134 | 135 | if self.log_loss: 136 | loss = -torch.log(scores.clamp_min(self.eps)) 137 | else: 138 | loss = 1.0 - scores 139 | 140 | # Dice loss is undefined for non-empty classes 141 | # So we zero contribution of channel that does not have true pixels 142 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 143 | # for this case, however it will be a modified jaccard loss 144 | 145 | mask = y_true.sum(dims) > 0 146 | loss *= mask.to(loss.dtype) 147 | 148 | if self.classes is not None: 149 | loss = loss[self.classes] 150 | 151 | return loss.mean() 152 | -------------------------------------------------------------------------------- /geoseg/losses/focal.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch.nn.modules.loss import _Loss 5 | 6 | from .functional import focal_loss_with_logits 7 | 8 | __all__ = ["BinaryFocalLoss", "FocalLoss"] 9 | 10 | 11 | class BinaryFocalLoss(_Loss): 12 | def __init__( 13 | self, 14 | alpha=0.5, 15 | gamma: float = 2.0, 16 | ignore_index=None, 17 | reduction="mean", 18 | normalized=False, 19 | reduced_threshold=None, 20 | ): 21 | """ 22 | 23 | :param alpha: Prior probability of having positive value in target. 24 | :param gamma: Power factor for dampening weight (focal strenght). 25 | :param ignore_index: If not None, targets may contain values to be ignored. 26 | Target values equal to ignore_index will be ignored from loss computation. 27 | :param reduced: Switch to reduced focal loss. Note, when using this mode you should use `reduction="sum"`. 28 | :param threshold: 29 | """ 30 | super().__init__() 31 | self.ignore_index = ignore_index 32 | self.focal_loss_fn = partial( 33 | focal_loss_with_logits, 34 | alpha=alpha, 35 | gamma=gamma, 36 | reduced_threshold=reduced_threshold, 37 | reduction=reduction, 38 | normalized=normalized, 39 | ignore_index=ignore_index, 40 | ) 41 | 42 | def forward(self, label_input, label_target): 43 | """Compute focal loss for binary classification problem.""" 44 | loss = self.focal_loss_fn(label_input, label_target) 45 | return loss 46 | 47 | 48 | class FocalLoss(_Loss): 49 | def __init__(self, alpha=0.5, gamma=2, ignore_index=None, reduction="mean", normalized=False, reduced_threshold=None): 50 | """ 51 | Focal loss for multi-class problem. 52 | 53 | :param alpha: 54 | :param gamma: 55 | :param ignore_index: If not None, targets with given index are ignored 56 | :param reduced_threshold: A threshold factor for computing reduced focal loss 57 | """ 58 | super().__init__() 59 | self.ignore_index = ignore_index 60 | self.focal_loss_fn = partial( 61 | focal_loss_with_logits, 62 | alpha=alpha, 63 | gamma=gamma, 64 | reduced_threshold=reduced_threshold, 65 | reduction=reduction, 66 | normalized=normalized, 67 | ) 68 | 69 | def forward(self, label_input, label_target): 70 | num_classes = label_input.size(1) 71 | loss = 0 72 | 73 | # Filter anchors with -1 label from loss computation 74 | if self.ignore_index is not None: 75 | not_ignored = label_target != self.ignore_index 76 | 77 | for cls in range(num_classes): 78 | cls_label_target = (label_target == cls).long() 79 | cls_label_input = label_input[:, cls, ...] 80 | 81 | if self.ignore_index is not None: 82 | cls_label_target = cls_label_target[not_ignored] 83 | cls_label_input = cls_label_input[not_ignored] 84 | 85 | loss += self.focal_loss_fn(cls_label_input, cls_label_target) 86 | return loss 87 | -------------------------------------------------------------------------------- /geoseg/losses/focal_cosine.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch import nn, Tensor 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | __all__ = ["FocalCosineLoss"] 7 | 8 | 9 | class FocalCosineLoss(nn.Module): 10 | """ 11 | Implementation Focal cosine loss from the "Data-Efficient Deep Learning Method for Image Classification 12 | Using Data Augmentation, Focal Cosine Loss, and Ensemble" (https://arxiv.org/abs/2007.07805). 13 | 14 | Credit: https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/203271 15 | """ 16 | 17 | def __init__(self, alpha: float = 1, gamma: float = 2, xent: float = 0.1, reduction="mean"): 18 | super(FocalCosineLoss, self).__init__() 19 | self.alpha = alpha 20 | self.gamma = gamma 21 | self.xent = xent 22 | self.reduction = reduction 23 | 24 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 25 | cosine_loss = F.cosine_embedding_loss( 26 | input, 27 | torch.nn.functional.one_hot(target, num_classes=input.size(-1)), 28 | torch.tensor([1], device=target.device), 29 | reduction=self.reduction, 30 | ) 31 | 32 | cent_loss = F.cross_entropy(F.normalize(input), target, reduction="none") 33 | pt = torch.exp(-cent_loss) 34 | focal_loss = self.alpha * (1 - pt) ** self.gamma * cent_loss 35 | 36 | if self.reduction == "mean": 37 | focal_loss = torch.mean(focal_loss) 38 | 39 | return cosine_loss + self.xent * focal_loss 40 | -------------------------------------------------------------------------------- /geoseg/losses/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | __all__ = [ 8 | "focal_loss_with_logits", 9 | "softmax_focal_loss_with_logits", 10 | "soft_jaccard_score", 11 | "soft_dice_score", 12 | "wing_loss", 13 | ] 14 | 15 | 16 | def focal_loss_with_logits( 17 | output: torch.Tensor, 18 | target: torch.Tensor, 19 | gamma: float = 2.0, 20 | alpha: Optional[float] = 0.25, 21 | reduction: str = "mean", 22 | normalized: bool = False, 23 | reduced_threshold: Optional[float] = None, 24 | eps: float = 1e-6, 25 | ignore_index=None, 26 | ) -> torch.Tensor: 27 | """Compute binary focal loss between target and output logits. 28 | 29 | See :class:`~pytorch_toolbelt.losses.FocalLoss` for details. 30 | 31 | Args: 32 | output: Tensor of arbitrary shape (predictions of the models) 33 | target: Tensor of the same shape as input 34 | gamma: Focal loss power factor 35 | alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range, 36 | high values will give more weight to positive class. 37 | reduction (string, optional): Specifies the reduction to apply to the output: 38 | 'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied, 39 | 'mean': the sum of the output will be divided by the number of 40 | elements in the output, 'sum': the output will be summed. Note: :attr:`size_average` 41 | and :attr:`reduce` are in the process of being deprecated, and in the meantime, 42 | specifying either of those two args will override :attr:`reduction`. 43 | 'batchwise_mean' computes mean loss per sample in batch. Default: 'mean' 44 | normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). 45 | reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). 46 | 47 | References: 48 | https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py 49 | """ 50 | target = target.type_as(output) 51 | 52 | p = torch.sigmoid(output) 53 | ce_loss = F.binary_cross_entropy_with_logits(output, target, reduction="none") 54 | pt = p * target + (1 - p) * (1 - target) 55 | 56 | # compute the loss 57 | if reduced_threshold is None: 58 | focal_term = (1.0 - pt).pow(gamma) 59 | else: 60 | focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) 61 | focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1) 62 | 63 | loss = focal_term * ce_loss 64 | 65 | if alpha is not None: 66 | loss *= alpha * target + (1 - alpha) * (1 - target) 67 | 68 | if ignore_index is not None: 69 | ignore_mask = target.eq(ignore_index) 70 | loss = torch.masked_fill(loss, ignore_mask, 0) 71 | if normalized: 72 | focal_term = torch.masked_fill(focal_term, ignore_mask, 0) 73 | 74 | if normalized: 75 | norm_factor = focal_term.sum(dtype=torch.float32).clamp_min(eps) 76 | loss /= norm_factor 77 | 78 | if reduction == "mean": 79 | loss = loss.mean() 80 | if reduction == "sum": 81 | loss = loss.sum(dtype=torch.float32) 82 | if reduction == "batchwise_mean": 83 | loss = loss.sum(dim=0, dtype=torch.float32) 84 | 85 | return loss 86 | 87 | 88 | def softmax_focal_loss_with_logits( 89 | output: torch.Tensor, 90 | target: torch.Tensor, 91 | gamma: float = 2.0, 92 | reduction="mean", 93 | normalized=False, 94 | reduced_threshold: Optional[float] = None, 95 | eps: float = 1e-6, 96 | ) -> torch.Tensor: 97 | """ 98 | Softmax version of focal loss between target and output logits. 99 | See :class:`~pytorch_toolbelt.losses.FocalLoss` for details. 100 | 101 | Args: 102 | output: Tensor of shape [B, C, *] (Similar to nn.CrossEntropyLoss) 103 | target: Tensor of shape [B, *] (Similar to nn.CrossEntropyLoss) 104 | reduction (string, optional): Specifies the reduction to apply to the output: 105 | 'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied, 106 | 'mean': the sum of the output will be divided by the number of 107 | elements in the output, 'sum': the output will be summed. Note: :attr:`size_average` 108 | and :attr:`reduce` are in the process of being deprecated, and in the meantime, 109 | specifying either of those two args will override :attr:`reduction`. 110 | 'batchwise_mean' computes mean loss per sample in batch. Default: 'mean' 111 | normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). 112 | reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). 113 | """ 114 | log_softmax = F.log_softmax(output, dim=1) 115 | 116 | loss = F.nll_loss(log_softmax, target, reduction="none") 117 | pt = torch.exp(-loss) 118 | 119 | # compute the loss 120 | if reduced_threshold is None: 121 | focal_term = (1.0 - pt).pow(gamma) 122 | else: 123 | focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) 124 | focal_term[pt < reduced_threshold] = 1 125 | 126 | loss = focal_term * loss 127 | 128 | if normalized: 129 | norm_factor = focal_term.sum().clamp_min(eps) 130 | loss = loss / norm_factor 131 | 132 | if reduction == "mean": 133 | loss = loss.mean() 134 | if reduction == "sum": 135 | loss = loss.sum() 136 | if reduction == "batchwise_mean": 137 | loss = loss.sum(0) 138 | 139 | return loss 140 | 141 | 142 | def soft_jaccard_score( 143 | output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None 144 | ) -> torch.Tensor: 145 | """ 146 | 147 | :param output: 148 | :param target: 149 | :param smooth: 150 | :param eps: 151 | :param dims: 152 | :return: 153 | 154 | Shape: 155 | - Input: :math:`(N, NC, *)` where :math:`*` means 156 | any number of additional dimensions 157 | - Target: :math:`(N, NC, *)`, same shape as the input 158 | - Output: scalar. 159 | 160 | """ 161 | assert output.size() == target.size() 162 | 163 | if dims is not None: 164 | intersection = torch.sum(output * target, dim=dims) 165 | cardinality = torch.sum(output + target, dim=dims) 166 | else: 167 | intersection = torch.sum(output * target) 168 | cardinality = torch.sum(output + target) 169 | 170 | union = cardinality - intersection 171 | jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps) 172 | return jaccard_score 173 | 174 | 175 | def soft_dice_score( 176 | output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None 177 | ) -> torch.Tensor: 178 | """ 179 | 180 | :param output: 181 | :param target: 182 | :param smooth: 183 | :param eps: 184 | :return: 185 | 186 | Shape: 187 | - Input: :math:`(N, NC, *)` where :math:`*` means any number 188 | of additional dimensions 189 | - Target: :math:`(N, NC, *)`, same shape as the input 190 | - Output: scalar. 191 | 192 | """ 193 | assert output.size() == target.size() 194 | if dims is not None: 195 | intersection = torch.sum(output * target, dim=dims) 196 | cardinality = torch.sum(output + target, dim=dims) 197 | else: 198 | intersection = torch.sum(output * target) 199 | cardinality = torch.sum(output + target) 200 | dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps) 201 | return dice_score 202 | 203 | 204 | def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"): 205 | """ 206 | https://arxiv.org/pdf/1711.06753.pdf 207 | :param output: 208 | :param target: 209 | :param width: 210 | :param curvature: 211 | :param reduction: 212 | :return: 213 | """ 214 | diff_abs = (target - output).abs() 215 | loss = diff_abs.clone() 216 | 217 | idx_smaller = diff_abs < width 218 | idx_bigger = diff_abs >= width 219 | 220 | loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature) 221 | 222 | C = width - width * math.log(1 + width / curvature) 223 | loss[idx_bigger] = loss[idx_bigger] - C 224 | 225 | if reduction == "sum": 226 | loss = loss.sum() 227 | 228 | if reduction == "mean": 229 | loss = loss.mean() 230 | 231 | return loss 232 | 233 | 234 | def label_smoothed_nll_loss( 235 | lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1 236 | ) -> torch.Tensor: 237 | """ 238 | 239 | Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py 240 | 241 | :param lprobs: Log-probabilities of predictions (e.g after log_softmax) 242 | :param target: 243 | :param epsilon: 244 | :param ignore_index: 245 | :param reduction: 246 | :return: 247 | """ 248 | if target.dim() == lprobs.dim() - 1: 249 | target = target.unsqueeze(dim) 250 | 251 | if ignore_index is not None: 252 | pad_mask = target.eq(ignore_index) 253 | target = target.masked_fill(pad_mask, 0) 254 | nll_loss = -lprobs.gather(dim=dim, index=target) 255 | smooth_loss = -lprobs.sum(dim=dim, keepdim=True) 256 | 257 | # nll_loss.masked_fill_(pad_mask, 0.0) 258 | # smooth_loss.masked_fill_(pad_mask, 0.0) 259 | nll_loss = nll_loss.masked_fill(pad_mask, 0.0) 260 | smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0) 261 | else: 262 | nll_loss = -lprobs.gather(dim=dim, index=target) 263 | smooth_loss = -lprobs.sum(dim=dim, keepdim=True) 264 | 265 | nll_loss = nll_loss.squeeze(dim) 266 | smooth_loss = smooth_loss.squeeze(dim) 267 | 268 | if reduction == "sum": 269 | nll_loss = nll_loss.sum() 270 | smooth_loss = smooth_loss.sum() 271 | if reduction == "mean": 272 | nll_loss = nll_loss.mean() 273 | smooth_loss = smooth_loss.mean() 274 | 275 | eps_i = epsilon / lprobs.size(dim) 276 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 277 | return loss 278 | -------------------------------------------------------------------------------- /geoseg/losses/jaccard.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from .dice import to_tensor 6 | from torch import Tensor 7 | from torch.nn.modules.loss import _Loss 8 | 9 | from .functional import soft_jaccard_score 10 | 11 | __all__ = ["JaccardLoss", "BINARY_MODE", "MULTICLASS_MODE", "MULTILABEL_MODE"] 12 | 13 | BINARY_MODE = "binary" 14 | MULTICLASS_MODE = "multiclass" 15 | MULTILABEL_MODE = "multilabel" 16 | 17 | 18 | class JaccardLoss(_Loss): 19 | """ 20 | Implementation of Jaccard loss for image segmentation task. 21 | It supports binary, multi-class and multi-label cases. 22 | """ 23 | 24 | def __init__(self, mode: str, classes: List[int] = None, log_loss=False, from_logits=True, smooth=0, eps=1e-7): 25 | """ 26 | 27 | :param mode: Metric mode {'binary', 'multiclass', 'multilabel'} 28 | :param classes: Optional list of classes that contribute in loss computation; 29 | By default, all channels are included. 30 | :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard` 31 | :param from_logits: If True assumes input is raw logits 32 | :param smooth: 33 | :param eps: Small epsilon for numerical stability 34 | """ 35 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 36 | super(JaccardLoss, self).__init__() 37 | self.mode = mode 38 | if classes is not None: 39 | assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 40 | classes = to_tensor(classes, dtype=torch.long) 41 | 42 | self.classes = classes 43 | self.from_logits = from_logits 44 | self.smooth = smooth 45 | self.eps = eps 46 | self.log_loss = log_loss 47 | 48 | def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: 49 | """ 50 | 51 | :param y_pred: NxCxHxW 52 | :param y_true: NxHxW 53 | :return: scalar 54 | """ 55 | assert y_true.size(0) == y_pred.size(0) 56 | 57 | if self.from_logits: 58 | # Apply activations to get [0..1] class probabilities 59 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 60 | # extreme values 0 and 1 61 | if self.mode == MULTICLASS_MODE: 62 | y_pred = y_pred.log_softmax(dim=1).exp() 63 | else: 64 | y_pred = F.logsigmoid(y_pred).exp() 65 | 66 | bs = y_true.size(0) 67 | num_classes = y_pred.size(1) 68 | dims = (0, 2) 69 | 70 | if self.mode == BINARY_MODE: 71 | y_true = y_true.view(bs, 1, -1) 72 | y_pred = y_pred.view(bs, 1, -1) 73 | 74 | if self.mode == MULTICLASS_MODE: 75 | y_true = y_true.view(bs, -1) 76 | y_pred = y_pred.view(bs, num_classes, -1) 77 | 78 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 79 | y_true = y_true.permute(0, 2, 1) # H, C, H*W 80 | 81 | if self.mode == MULTILABEL_MODE: 82 | y_true = y_true.view(bs, num_classes, -1) 83 | y_pred = y_pred.view(bs, num_classes, -1) 84 | 85 | scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), smooth=self.smooth, eps=self.eps, dims=dims) 86 | 87 | if self.log_loss: 88 | loss = -torch.log(scores.clamp_min(self.eps)) 89 | else: 90 | loss = 1.0 - scores 91 | 92 | # IoU loss is defined for non-empty classes 93 | # So we zero contribution of channel that does not have true pixels 94 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 95 | # for this case, however it will be a modified jaccard loss 96 | 97 | mask = y_true.sum(dims) > 0 98 | loss *= mask.float() 99 | 100 | if self.classes is not None: 101 | loss = loss[self.classes] 102 | 103 | return loss.mean() 104 | -------------------------------------------------------------------------------- /geoseg/losses/joint_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.loss import _Loss 3 | 4 | __all__ = ["JointLoss", "WeightedLoss"] 5 | 6 | 7 | class WeightedLoss(_Loss): 8 | """Wrapper class around loss function that applies weighted with fixed factor. 9 | This class helps to balance multiple losses if they have different scales 10 | """ 11 | 12 | def __init__(self, loss, weight=1.0): 13 | super().__init__() 14 | self.loss = loss 15 | self.weight = weight 16 | 17 | def forward(self, *input): 18 | return self.loss(*input) * self.weight 19 | 20 | 21 | class JointLoss(_Loss): 22 | """ 23 | Wrap two loss functions into one. This class computes a weighted sum of two losses. 24 | """ 25 | 26 | def __init__(self, first: nn.Module, second: nn.Module, first_weight=1.0, second_weight=1.0): 27 | super().__init__() 28 | self.first = WeightedLoss(first, first_weight) 29 | self.second = WeightedLoss(second, second_weight) 30 | 31 | def forward(self, *input): 32 | return self.first(*input) + self.second(*input) 33 | -------------------------------------------------------------------------------- /geoseg/losses/lovasz.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | from typing import Optional, Union 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from torch.nn.modules.loss import _Loss 14 | 15 | try: 16 | from itertools import ifilterfalse 17 | except ImportError: # py3k 18 | from itertools import filterfalse as ifilterfalse 19 | 20 | __all__ = ["BinaryLovaszLoss", "LovaszLoss"] 21 | 22 | 23 | def _lovasz_grad(gt_sorted): 24 | """Compute gradient of the Lovasz extension w.r.t sorted errors 25 | See Alg. 1 in paper 26 | """ 27 | p = len(gt_sorted) 28 | gts = gt_sorted.sum() 29 | intersection = gts - gt_sorted.float().cumsum(0) 30 | union = gts + (1 - gt_sorted).float().cumsum(0) 31 | jaccard = 1.0 - intersection / union 32 | if p > 1: # cover 1-pixel case 33 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 34 | return jaccard 35 | 36 | 37 | def _lovasz_hinge(logits, labels, per_image=True, ignore_index=None): 38 | """ 39 | Binary Lovasz hinge loss 40 | logits: [B, H, W] Variable, logits at each pixel (between -infinity and +infinity) 41 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 42 | per_image: compute the loss per image instead of per batch 43 | ignore: void class id 44 | """ 45 | if per_image: 46 | loss = mean( 47 | _lovasz_hinge_flat(*_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore_index)) 48 | for log, lab in zip(logits, labels) 49 | ) 50 | else: 51 | loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore_index)) 52 | return loss 53 | 54 | 55 | def _lovasz_hinge_flat(logits, labels): 56 | """Binary Lovasz hinge loss 57 | Args: 58 | logits: [P] Variable, logits at each prediction (between -iinfinity and +iinfinity) 59 | labels: [P] Tensor, binary ground truth labels (0 or 1) 60 | ignore: label to ignore 61 | """ 62 | if len(labels) == 0: 63 | # only void pixels, the gradients should be 0 64 | return logits.sum() * 0.0 65 | signs = 2.0 * labels.float() - 1.0 66 | errors = 1.0 - logits * Variable(signs) 67 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 68 | perm = perm.data 69 | gt_sorted = labels[perm] 70 | grad = _lovasz_grad(gt_sorted) 71 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 72 | return loss 73 | 74 | 75 | def _flatten_binary_scores(scores, labels, ignore_index=None): 76 | """Flattens predictions in the batch (binary case) 77 | Remove labels equal to 'ignore' 78 | """ 79 | scores = scores.view(-1) 80 | labels = labels.view(-1) 81 | if ignore_index is None: 82 | return scores, labels 83 | valid = labels != ignore_index 84 | vscores = scores[valid] 85 | vlabels = labels[valid] 86 | return vscores, vlabels 87 | 88 | 89 | # --------------------------- MULTICLASS LOSSES --------------------------- 90 | 91 | 92 | def _lovasz_softmax(probas, labels, classes="present", per_image=False, ignore_index=None): 93 | """Multi-class Lovasz-Softmax loss 94 | Args: 95 | @param probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 96 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 97 | @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 98 | @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 99 | @param per_image: compute the loss per image instead of per batch 100 | @param ignore_index: void class labels 101 | """ 102 | if per_image: 103 | loss = mean( 104 | _lovasz_softmax_flat(*_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore_index), classes=classes) 105 | for prob, lab in zip(probas, labels) 106 | ) 107 | else: 108 | loss = _lovasz_softmax_flat(*_flatten_probas(probas, labels, ignore_index), classes=classes) 109 | return loss 110 | 111 | 112 | def _lovasz_softmax_flat(probas, labels, classes="present"): 113 | """Multi-class Lovasz-Softmax loss 114 | Args: 115 | @param probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 116 | @param labels: [P] Tensor, ground truth labels (between 0 and C - 1) 117 | @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 118 | """ 119 | if probas.numel() == 0: 120 | # only void pixels, the gradients should be 0 121 | return probas * 0.0 122 | C = probas.size(1) 123 | losses = [] 124 | class_to_sum = list(range(C)) if classes in ["all", "present"] else classes 125 | for c in class_to_sum: 126 | fg = (labels == c).type_as(probas) # foreground for class c 127 | if classes == "present" and fg.sum() == 0: 128 | continue 129 | if C == 1: 130 | if len(classes) > 1: 131 | raise ValueError("Sigmoid output possible only with 1 class") 132 | class_pred = probas[:, 0] 133 | else: 134 | class_pred = probas[:, c] 135 | errors = (fg - class_pred).abs() 136 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 137 | perm = perm.data 138 | fg_sorted = fg[perm] 139 | losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted))) 140 | return mean(losses) 141 | 142 | 143 | def _flatten_probas(probas, labels, ignore=None): 144 | """Flattens predictions in the batch""" 145 | if probas.dim() == 3: 146 | # assumes output of a sigmoid layer 147 | B, H, W = probas.size() 148 | probas = probas.view(B, 1, H, W) 149 | 150 | C = probas.size(1) 151 | probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C] 152 | probas = probas.contiguous().view(-1, C) # [P, C] 153 | 154 | labels = labels.view(-1) 155 | if ignore is None: 156 | return probas, labels 157 | valid = labels != ignore 158 | vprobas = probas[valid] 159 | vlabels = labels[valid] 160 | return vprobas, vlabels 161 | 162 | 163 | # --------------------------- HELPER FUNCTIONS --------------------------- 164 | def isnan(x): 165 | return x != x 166 | 167 | 168 | def mean(values, ignore_nan=False, empty=0): 169 | """Nanmean compatible with generators.""" 170 | values = iter(values) 171 | if ignore_nan: 172 | values = ifilterfalse(isnan, values) 173 | try: 174 | n = 1 175 | acc = next(values) 176 | except StopIteration: 177 | if empty == "raise": 178 | raise ValueError("Empty mean") 179 | return empty 180 | for n, v in enumerate(values, 2): 181 | acc += v 182 | if n == 1: 183 | return acc 184 | return acc / n 185 | 186 | 187 | class BinaryLovaszLoss(_Loss): 188 | def __init__(self, per_image: bool = False, ignore_index: Optional[Union[int, float]] = None): 189 | super().__init__() 190 | self.ignore_index = ignore_index 191 | self.per_image = per_image 192 | 193 | def forward(self, logits, target): 194 | return _lovasz_hinge(logits, target, per_image=self.per_image, ignore_index=self.ignore_index) 195 | 196 | 197 | class LovaszLoss(_Loss): 198 | def __init__(self, per_image=False, ignore=None): 199 | super().__init__() 200 | self.ignore = ignore 201 | self.per_image = per_image 202 | 203 | def forward(self, logits, target): 204 | return _lovasz_softmax(logits, target, per_image=self.per_image, ignore_index=self.ignore) 205 | -------------------------------------------------------------------------------- /geoseg/losses/soft_bce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch.nn.functional as F 4 | from torch import nn, Tensor 5 | 6 | __all__ = ["SoftBCEWithLogitsLoss"] 7 | 8 | 9 | class SoftBCEWithLogitsLoss(nn.Module): 10 | """ 11 | Drop-in replacement for nn.BCEWithLogitsLoss with few additions: 12 | - Support of ignore_index value 13 | - Support of label smoothing 14 | """ 15 | 16 | __constants__ = ["weight", "pos_weight", "reduction", "ignore_index", "smooth_factor"] 17 | 18 | def __init__( 19 | self, weight=None, ignore_index: Optional[int] = -100, reduction="mean", smooth_factor=None, pos_weight=None 20 | ): 21 | super().__init__() 22 | self.ignore_index = ignore_index 23 | self.reduction = reduction 24 | self.smooth_factor = smooth_factor 25 | self.register_buffer("weight", weight) 26 | self.register_buffer("pos_weight", pos_weight) 27 | 28 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 29 | if self.smooth_factor is not None: 30 | soft_targets = ((1 - target) * self.smooth_factor + target * (1 - self.smooth_factor)).type_as(input) 31 | else: 32 | soft_targets = target.type_as(input) 33 | 34 | loss = F.binary_cross_entropy_with_logits( 35 | input, soft_targets, self.weight, pos_weight=self.pos_weight, reduction="none" 36 | ) 37 | 38 | if self.ignore_index is not None: 39 | not_ignored_mask: Tensor = target != self.ignore_index 40 | loss *= not_ignored_mask.type_as(loss) 41 | 42 | if self.reduction == "mean": 43 | loss = loss.mean() 44 | 45 | if self.reduction == "sum": 46 | loss = loss.sum() 47 | 48 | return loss 49 | -------------------------------------------------------------------------------- /geoseg/losses/soft_ce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch import nn, Tensor 3 | import torch.nn.functional as F 4 | from .functional import label_smoothed_nll_loss 5 | 6 | __all__ = ["SoftCrossEntropyLoss"] 7 | 8 | 9 | class SoftCrossEntropyLoss(nn.Module): 10 | """ 11 | Drop-in replacement for nn.CrossEntropyLoss with few additions: 12 | - Support of label smoothing 13 | """ 14 | 15 | __constants__ = ["reduction", "ignore_index", "smooth_factor"] 16 | 17 | def __init__(self, reduction: str = "mean", smooth_factor: float = 0.0, ignore_index: Optional[int] = -100, dim=1): 18 | super().__init__() 19 | self.smooth_factor = smooth_factor 20 | self.ignore_index = ignore_index 21 | self.reduction = reduction 22 | self.dim = dim 23 | 24 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 25 | log_prob = F.log_softmax(input, dim=self.dim) 26 | return label_smoothed_nll_loss( 27 | log_prob, 28 | target, 29 | epsilon=self.smooth_factor, 30 | ignore_index=self.ignore_index, 31 | reduction=self.reduction, 32 | dim=self.dim, 33 | ) 34 | -------------------------------------------------------------------------------- /geoseg/losses/soft_f1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import Optional 4 | 5 | __all__ = ["soft_micro_f1", "BinarySoftF1Loss", "SoftF1Loss"] 6 | 7 | 8 | def soft_micro_f1(preds: Tensor, targets: Tensor, eps=1e-6) -> Tensor: 9 | """Compute the macro soft F1-score as a cost. 10 | Average (1 - soft-F1) across all labels. 11 | Use probability values instead of binary predictions. 12 | 13 | Args: 14 | targets (Tensor): targets array of shape (Num Samples, Num Classes) 15 | preds (Tensor): probability matrix of shape (Num Samples, Num Classes) 16 | 17 | Returns: 18 | cost (scalar Tensor): value of the cost function for the batch 19 | 20 | References: 21 | https://towardsdatascience.com/the-unknown-benefits-of-using-a-soft-f1-loss-in-classification-systems-753902c0105d 22 | """ 23 | tp = torch.sum(preds * targets, dim=0) 24 | fp = torch.sum(preds * (1 - targets), dim=0) 25 | fn = torch.sum((1 - preds) * targets, dim=0) 26 | soft_f1 = 2 * tp / (2 * tp + fn + fp + eps) 27 | loss = 1 - soft_f1 # reduce 1 - soft-f1 in order to increase soft-f1 28 | return loss.mean() 29 | 30 | 31 | # TODO: Test 32 | # def macro_double_soft_f1(y, y_hat): 33 | # """Compute the macro soft F1-score as a cost (average 1 - soft-F1 across all labels). 34 | # Use probability values instead of binary predictions. 35 | # This version uses the computation of soft-F1 for both positive and negative class for each label. 36 | # 37 | # Args: 38 | # y (int32 Tensor): targets array of shape (BATCH_SIZE, N_LABELS) 39 | # y_hat (float32 Tensor): probability matrix from forward propagation of shape (BATCH_SIZE, N_LABELS) 40 | # 41 | # Returns: 42 | # cost (scalar Tensor): value of the cost function for the batch 43 | # """ 44 | # tp = tf.reduce_sum(y_hat * y, axis=0) 45 | # fp = tf.reduce_sum(y_hat * (1 - y), axis=0) 46 | # fn = tf.reduce_sum((1 - y_hat) * y, axis=0) 47 | # tn = tf.reduce_sum((1 - y_hat) * (1 - y), axis=0) 48 | # soft_f1_class1 = 2 * tp / (2 * tp + fn + fp + 1e-16) 49 | # soft_f1_class0 = 2 * tn / (2 * tn + fn + fp + 1e-16) 50 | # cost_class1 = 1 - soft_f1_class1 # reduce 1 - soft-f1_class1 in order to increase soft-f1 on class 1 51 | # cost_class0 = 1 - soft_f1_class0 # reduce 1 - soft-f1_class0 in order to increase soft-f1 on class 0 52 | # cost = 0.5 * (cost_class1 + cost_class0) # take into account both class 1 and class 0 53 | # macro_cost = tf.reduce_mean(cost) # average on all labels 54 | # return macro_cost 55 | 56 | 57 | class BinarySoftF1Loss(nn.Module): 58 | def __init__(self, ignore_index: Optional[int] = None, eps=1e-6): 59 | super().__init__() 60 | self.ignore_index = ignore_index 61 | self.eps = eps 62 | 63 | def forward(self, preds: Tensor, targets: Tensor) -> Tensor: 64 | targets = targets.view(-1) 65 | preds = preds.view(-1) 66 | 67 | if self.ignore_index is not None: 68 | # Filter predictions with ignore label from loss computation 69 | not_ignored = targets != self.ignore_index 70 | preds = preds[not_ignored] 71 | targets = targets[not_ignored] 72 | 73 | if targets.numel() == 0: 74 | return torch.tensor(0, dtype=preds.dtype, device=preds.device) 75 | 76 | preds = preds.sigmoid().clamp(self.eps, 1 - self.eps) 77 | return soft_micro_f1(preds.view(-1, 1), targets.view(-1, 1)) 78 | 79 | 80 | class SoftF1Loss(nn.Module): 81 | def __init__(self, ignore_index: Optional[int] = None, eps=1e-6): 82 | super().__init__() 83 | self.ignore_index = ignore_index 84 | self.eps = eps 85 | 86 | def forward(self, preds: Tensor, targets: Tensor) -> Tensor: 87 | preds = preds.softmax(dim=1).clamp(self.eps, 1 - self.eps) 88 | targets = torch.nn.functional.one_hot(targets, preds.size(1)) 89 | 90 | if self.ignore_index is not None: 91 | # Filter predictions with ignore label from loss computation 92 | not_ignored = targets != self.ignore_index 93 | preds = preds[not_ignored] 94 | targets = targets[not_ignored] 95 | 96 | if targets.numel() == 0: 97 | return torch.tensor(0, dtype=preds.dtype, device=preds.device) 98 | 99 | return soft_micro_f1(preds, targets) 100 | -------------------------------------------------------------------------------- /geoseg/losses/useful_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from .soft_ce import SoftCrossEntropyLoss 7 | from .joint_loss import JointLoss 8 | from .dice import DiceLoss 9 | 10 | 11 | class EdgeLoss(nn.Module): 12 | def __init__(self, ignore_index=255, edge_factor=1.0): 13 | super(EdgeLoss, self).__init__() 14 | self.main_loss = JointLoss(SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index), 15 | DiceLoss(smooth=0.05, ignore_index=ignore_index), 1.0, 1.0) 16 | self.edge_factor = edge_factor 17 | 18 | def get_boundary(self, x): 19 | laplacian_kernel_target = torch.tensor( 20 | [-1, -1, -1, -1, 8, -1, -1, -1, -1], 21 | dtype=torch.float32).reshape(1, 1, 3, 3).requires_grad_(False).cuda(device=x.device) 22 | x = x.unsqueeze(1).float() 23 | x = F.conv2d(x, laplacian_kernel_target, padding=1) 24 | x = x.clamp(min=0) 25 | x[x >= 0.1] = 1 26 | x[x < 0.1] = 0 27 | 28 | return x 29 | 30 | def compute_edge_loss(self, logits, targets): 31 | bs = logits.size()[0] 32 | boundary_targets = self.get_boundary(targets) 33 | boundary_targets = boundary_targets.view(bs, 1, -1) 34 | # print(boundary_targets.shape) 35 | logits = F.softmax(logits, dim=1).argmax(dim=1).squeeze(dim=1) 36 | boundary_pre = self.get_boundary(logits) 37 | boundary_pre = boundary_pre / (boundary_pre + 0.01) 38 | # print(boundary_pre) 39 | boundary_pre = boundary_pre.view(bs, 1, -1) 40 | # print(boundary_pre) 41 | # dice_loss = 1 - ((2. * (boundary_pre * boundary_targets).sum(1) + 1.0) / 42 | # (boundary_pre.sum(1) + boundary_targets.sum(1) + 1.0)) 43 | # dice_loss = dice_loss.mean() 44 | edge_loss = F.binary_cross_entropy_with_logits(boundary_pre, boundary_targets) 45 | 46 | return edge_loss 47 | 48 | def forward(self, logits, targets): 49 | loss = self.main_loss(logits, targets) + self.compute_edge_loss(logits, targets) * self.edge_factor 50 | return loss 51 | 52 | 53 | class OHEM_CELoss(nn.Module): 54 | 55 | def __init__(self, thresh=0.7, ignore_index=255): 56 | super(OHEM_CELoss, self).__init__() 57 | self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda() 58 | self.ignore_index = ignore_index 59 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none') 60 | 61 | def forward(self, logits, labels): 62 | n_min = labels[labels != self.ignore_index].numel() // 16 63 | loss = self.criteria(logits, labels).view(-1) 64 | loss_hard = loss[loss > self.thresh] 65 | if loss_hard.numel() < n_min: 66 | loss_hard, _ = loss.topk(n_min) 67 | return torch.mean(loss_hard) 68 | 69 | 70 | class UnetFormerLoss(nn.Module): 71 | 72 | def __init__(self, ignore_index=255): 73 | super().__init__() 74 | self.main_loss = JointLoss(SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index), 75 | DiceLoss(smooth=0.05, ignore_index=ignore_index), 1.0, 1.0) 76 | self.aux_loss = SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index) 77 | 78 | def forward(self, logits, labels): 79 | if self.training and len(logits) == 2: 80 | logit_main, logit_aux = logits 81 | loss = self.main_loss(logit_main, labels) + 0.4 * self.aux_loss(logit_aux, labels) 82 | else: 83 | loss = self.main_loss(logits, labels) 84 | 85 | return loss 86 | 87 | 88 | if __name__ == '__main__': 89 | targets = torch.randint(low=0, high=2, size=(2, 16, 16)) 90 | logits = torch.randn((2, 2, 16, 16)) 91 | # print(targets) 92 | model = EdgeLoss() 93 | loss = model.compute_edge_loss(logits, targets) 94 | 95 | print(loss) -------------------------------------------------------------------------------- /geoseg/losses/wing_loss.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.loss import _Loss 2 | 3 | from . import functional as F 4 | 5 | __all__ = ["WingLoss"] 6 | 7 | 8 | class WingLoss(_Loss): 9 | def __init__(self, width=5, curvature=0.5, reduction="mean"): 10 | super(WingLoss, self).__init__(reduction=reduction) 11 | self.width = width 12 | self.curvature = curvature 13 | 14 | def forward(self, prediction, target): 15 | return F.wing_loss(prediction, target, self.width, self.curvature, self.reduction) 16 | -------------------------------------------------------------------------------- /geoseg/models/SSFNet.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as modelzoo 7 | from torch.nn import BatchNorm2d 8 | 9 | class SSFNet(nn.Module): 10 | 11 | def __init__(self, num_classes, *args, **kwargs): 12 | super(SSFNet, self).__init__() 13 | self.cp = Semanticbranch() 14 | self.sp = Detailbrabch1() 15 | 16 | self.sp1 = nn.Sequential( 17 | nn.Conv2d( 18 | 256, 128, kernel_size=3, stride=1, 19 | padding=1, bias=False), 20 | nn.BatchNorm2d(128), 21 | ) 22 | self.eca = ECA(128) 23 | self.bfm = BFM() 24 | self.conv_out = SegHead(128, 128, num_classes, up_factor=8) 25 | 26 | 27 | def forward(self, x): 28 | a, feat_res8, feat_cp8, feat_cp16 = self.cp(x) 29 | a = self.sp(a) 30 | feat_sp = torch.concat([a, feat_res8], 1) 31 | feat_sp = self.sp1(feat_sp) 32 | feat_sp = self.eca(feat_sp) 33 | feat_fuse = self.bfm(feat_sp, feat_cp8) 34 | feat_out = self.conv_out(feat_fuse) 35 | return feat_out 36 | 37 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 38 | 39 | 40 | def conv3x3(in_planes, out_planes, stride=1): 41 | """3x3 convolution with padding""" 42 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 43 | padding=1, bias=False) 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | def __init__(self, in_chan, out_chan, stride=1): 48 | super(BasicBlock, self).__init__() 49 | self.conv1 = conv3x3(in_chan, out_chan, stride) 50 | self.bn1 = BatchNorm2d(out_chan) 51 | self.conv2 = conv3x3(out_chan, out_chan) 52 | self.bn2 = BatchNorm2d(out_chan) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = None 55 | if in_chan != out_chan or stride != 1: 56 | self.downsample = nn.Sequential( 57 | nn.Conv2d(in_chan, out_chan, 58 | kernel_size=1, stride=stride, bias=False), 59 | BatchNorm2d(out_chan), 60 | ) 61 | 62 | def forward(self, x): 63 | residual = self.conv1(x) 64 | residual = self.bn1(residual) 65 | residual = self.relu(residual) 66 | residual = self.conv2(residual) 67 | residual = self.bn2(residual) 68 | 69 | shortcut = x 70 | if self.downsample is not None: 71 | shortcut = self.downsample(x) 72 | 73 | out = shortcut + residual 74 | out = self.relu(out) 75 | return out 76 | 77 | 78 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 79 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 80 | for i in range(bnum-1): 81 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 82 | return nn.Sequential(*layers) 83 | 84 | 85 | class Resnet18(nn.Module): 86 | def __init__(self): 87 | super(Resnet18, self).__init__() 88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = BatchNorm2d(64) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 93 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 94 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 95 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 96 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 97 | self.init_weight() 98 | 99 | def forward(self, x): 100 | x = self.conv1(x) 101 | x = self.bn1(x) 102 | x = self.relu(x) 103 | x = self.maxpool(x) 104 | x = self.layer1(x) 105 | feat8 = self.layer2(x) # 1/8 106 | feat16 = self.layer3(feat8) # 1/16 107 | feat32 = self.layer4(feat16) # 1/32 108 | return x, feat8, feat16, feat32 109 | 110 | def init_weight(self): 111 | state_dict = modelzoo.load_url(resnet18_url) 112 | self_state_dict = self.state_dict() 113 | for k, v in state_dict.items(): 114 | if 'fc' in k: continue 115 | self_state_dict.update({k: v}) 116 | self.load_state_dict(self_state_dict) 117 | 118 | def get_params(self): 119 | wd_params, nowd_params = [], [] 120 | for name, module in self.named_modules(): 121 | if isinstance(module, (nn.Linear, nn.Conv2d)): 122 | wd_params.append(module.weight) 123 | if not module.bias is None: 124 | nowd_params.append(module.bias) 125 | elif isinstance(module, nn.modules.batchnorm._BatchNorm): 126 | nowd_params += list(module.parameters()) 127 | return wd_params, nowd_params 128 | 129 | class ConvBNReLU(nn.Module): 130 | 131 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 132 | super(ConvBNReLU, self).__init__() 133 | self.conv = nn.Conv2d(in_chan, 134 | out_chan, 135 | kernel_size = ks, 136 | stride = stride, 137 | padding = padding, 138 | bias = False) 139 | self.bn = BatchNorm2d(out_chan) 140 | self.relu = nn.ReLU(inplace=True) 141 | 142 | def forward(self, x): 143 | x = self.conv(x) 144 | x = self.bn(x) 145 | x = self.relu(x) 146 | return x 147 | class ConvBNGELU(nn.Module): 148 | 149 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 150 | super(ConvBNGELU, self).__init__() 151 | self.conv = nn.Conv2d(in_chan, 152 | out_chan, 153 | kernel_size = ks, 154 | stride = stride, 155 | padding = padding, 156 | bias = False) 157 | self.bn = BatchNorm2d(out_chan) 158 | self.relu = nn.GELU() 159 | 160 | def forward(self, x): 161 | x = self.conv(x) 162 | x = self.bn(x) 163 | x = self.relu(x) 164 | return x 165 | 166 | 167 | class UpSample(nn.Module): 168 | 169 | def __init__(self, n_chan, factor=2): 170 | super(UpSample, self).__init__() 171 | out_chan = n_chan * factor * factor 172 | self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0) 173 | self.up = nn.PixelShuffle(factor) 174 | self.init_weight() 175 | 176 | def forward(self, x): 177 | feat = self.proj(x) 178 | feat = self.up(feat) 179 | return feat 180 | 181 | def init_weight(self): 182 | nn.init.xavier_normal_(self.proj.weight, gain=1.) 183 | 184 | 185 | class SegHead(nn.Module): 186 | 187 | def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs): 188 | super(SegHead, self).__init__() 189 | self.up_factor = up_factor 190 | out_chan = n_classes 191 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 192 | self.drop = nn.Dropout(0.1) 193 | self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True) 194 | self.up = nn.Upsample(scale_factor=up_factor, 195 | mode='bilinear', align_corners=False) 196 | 197 | def forward(self, x): 198 | x = self.conv(x) 199 | x = self.drop(x) 200 | x = self.conv_out(x) 201 | x = self.up(x) 202 | return x 203 | 204 | class FAC(nn.Module): 205 | def __init__(self, in_chan, out_chan, *args, **kwargs): 206 | super(FAC, self).__init__() 207 | self.conv = ConvBNGELU(in_chan, out_chan, ks=3, stride=1, padding=1) 208 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 209 | self.bn_atten = BatchNorm2d(out_chan) 210 | 211 | def forward(self, x): 212 | x = self.conv(x) 213 | feat = torch.mean(x, dim=(2, 3), keepdim=True) 214 | feat = self.conv_atten(feat) 215 | feat = self.bn_atten(feat) 216 | feat = feat.sigmoid() 217 | out = torch.mul(x, feat) 218 | return out 219 | 220 | class BasicConv(nn.Module): 221 | 222 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 223 | super(BasicConv, self).__init__() 224 | self.out_channels = out_planes 225 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 226 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 227 | self.relu = nn.GELU() if relu else None 228 | 229 | def forward(self, x): 230 | x = self.conv(x) 231 | if self.bn is not None: 232 | x = self.bn(x) 233 | if self.relu is not None: 234 | x = self.relu(x) 235 | return x 236 | 237 | 238 | class ECA(nn.Module): 239 | def __init__(self, in_channel, gamma=2, b=1): 240 | super(ECA, self).__init__() 241 | k = int(abs((math.log(in_channel, 2)+b)/gamma)) 242 | kernel_size = k if k % 2 else k+1 243 | padding = kernel_size//2 244 | self.pool = nn.AdaptiveAvgPool2d(output_size=1) 245 | self.conv = nn.Sequential( 246 | nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, padding=padding, bias=False), 247 | nn.Sigmoid() 248 | ) 249 | 250 | def forward(self, x): 251 | out = self.pool(x) 252 | out = out.view(x.size(0), 1, x.size(1)) 253 | out = self.conv(out) 254 | out = out.view(x.size(0), x.size(1), 1, 1) 255 | return out*x 256 | 257 | class FRFB(nn.Module): 258 | 259 | def __init__(self, in_planes, out_planes, stride=1, scale=0.1): 260 | super(FRFB, self).__init__() 261 | self.scale = scale 262 | self.out_channels = out_planes 263 | inter_planes = in_planes // 8 264 | 265 | self.branch0 = nn.Sequential( 266 | BasicConv(in_planes, inter_planes // 2, kernel_size=1, stride=1), 267 | BasicConv(inter_planes // 2, inter_planes, kernel_size=3, stride=1, padding=1, relu=False) 268 | ) 269 | self.branch1 = nn.Sequential( 270 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 271 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1), 272 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False) 273 | ) 274 | 275 | self.branch2 = nn.Sequential( 276 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 277 | BasicConv(inter_planes, inter_planes, kernel_size=5, stride=1, padding=2), 278 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=5, dilation=5, relu=False) 279 | ) 280 | 281 | self.branch3 = nn.Sequential( 282 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 283 | BasicConv(inter_planes, inter_planes, kernel_size=(7, 7), stride=stride, padding=3), 284 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=7, dilation=7, relu=False) 285 | ) 286 | self.ConvLinear = BasicConv(4 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False) 287 | self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False) 288 | self.relu = nn.GELU() 289 | 290 | def forward(self, x): 291 | x0 = self.branch0(x) 292 | x1 = self.branch1(x) 293 | x2 = self.branch2(x) 294 | x3 = self.branch3(x) 295 | out = torch.cat((x0, x1, x2, x3), 1) 296 | out = self.ConvLinear(out) 297 | short = self.shortcut(x) 298 | out = torch.add(out * self.scale, short) 299 | out = self.relu(out) 300 | return out 301 | 302 | class Semanticbranch(nn.Module): 303 | def __init__(self, *args, **kwargs): 304 | super(Semanticbranch, self).__init__() 305 | self.resnet = Resnet18() 306 | self.Norm = FRFB(512, 512, stride=1, scale=1.0) 307 | self.arm16 = FAC(256, 128) 308 | self.arm32 = FAC(512, 128) 309 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 310 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 311 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 312 | self.up32 = nn.Upsample(scale_factor=2.) 313 | self.up16 = nn.Upsample(scale_factor=2.) 314 | def forward(self, x): 315 | x, feat8, feat16, feat32 = self.resnet(x)#[16, 512, 20, 15] 316 | feat32 = self.Norm(feat32)#16, 512, 20, 15 317 | avg = torch.mean(feat32, dim=(2, 3), keepdim=True)#16, 512, 1, 1 318 | avg = self.conv_avg(avg)#16, 128, 1, 1 319 | feat32_arm = self.arm32(feat32)#16, 128, 20, 15 320 | feat32_sum = torch.add(feat32_arm, avg)#16, 128, 20, 15 321 | feat32_up = self.up32(feat32_sum)#16, 128, 40, 30 322 | feat32_up = self.conv_head32(feat32_up)#16, 128, 40, 30 323 | feat16_arm = self.arm16(feat16)#16, 128, 40, 30 324 | feat16_sum = torch.add(feat16_arm, feat32_up)#t[16, 128, 40, 30]) 325 | feat16_up = self.up16(feat16_sum) 326 | feat16_up = self.conv_head16(feat16_up)#16, 128, 80, 60] 327 | return x, feat8, feat16_up, feat32_up # x8, x16 328 | 329 | class Detailbrabch1(nn.Module): 330 | def __init__(self, *args, **kwargs): 331 | super(Detailbrabch1, self).__init__() 332 | self.conv1 = ConvBNReLU(64, 128, ks=3, stride=2, padding=1) 333 | def forward(self, x): 334 | feat = self.conv1(x) 335 | return feat 336 | 337 | class BFM(nn.Module): 338 | 339 | def __init__(self): 340 | super(BFM, self).__init__() 341 | self.left1 = nn.Sequential( 342 | nn.Conv2d( 343 | 128, 128, kernel_size=3, stride=1, 344 | padding=1, groups=128, bias=False), 345 | nn.BatchNorm2d(128), 346 | nn.Conv2d( 347 | 128, 128, kernel_size=1, stride=1, 348 | padding=0, bias=False), 349 | ) 350 | self.left2 = nn.Sequential( 351 | nn.Conv2d( 352 | 128, 128, kernel_size=3, stride=1, 353 | padding=1, bias=False), 354 | nn.BatchNorm2d(128), 355 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False) 356 | ) 357 | self.right1 = nn.Sequential( 358 | nn.Conv2d( 359 | 128, 128, kernel_size=3, stride=1, 360 | padding=1, bias=False), 361 | nn.BatchNorm2d(128), 362 | ) 363 | self.right2 = nn.Sequential( 364 | nn.Conv2d( 365 | 128, 128, kernel_size=3, stride=1, 366 | padding=1, groups=128, bias=False), 367 | nn.BatchNorm2d(128), 368 | nn.Conv2d( 369 | 128, 128, kernel_size=1, stride=1, 370 | padding=0, bias=False), 371 | ) 372 | 373 | self.conv = nn.Sequential( 374 | nn.Conv2d( 375 | 128, 128, kernel_size=3, stride=1, 376 | padding=1, bias=False), 377 | nn.BatchNorm2d(128), 378 | nn.ReLU(inplace=True), 379 | ) 380 | 381 | def forward(self, x_d, x_s): 382 | left1 = self.left1(x_d) 383 | left2 = self.left2(x_d) 384 | right1 = self.right1(x_s) 385 | right2 = self.right2(x_s) 386 | left = left1 * torch.sigmoid(right1) 387 | right = left2 * torch.sigmoid(right2) 388 | out = self.conv(left + right) 389 | return out 390 | 391 | 392 | -------------------------------------------------------------------------------- /geoseg/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wxy16/SSFNet/b424b22bbab4a142f150a417be19880102462618/geoseg/models/__init__.py -------------------------------------------------------------------------------- /inference_uavid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import glob 4 | from PIL import Image 5 | import ttach as tta 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import albumentations as albu 10 | from catalyst.dl import SupervisedRunner 11 | from skimage.morphology import remove_small_holes, remove_small_objects 12 | from tools.cfg import py2cfg 13 | from torch import nn 14 | from torch.utils.data import Dataset, DataLoader 15 | from tqdm import tqdm 16 | from train_supervision import * 17 | import random 18 | import os 19 | 20 | 21 | def seed_everything(seed): 22 | random.seed(seed) 23 | os.environ['PYTHONHASHSEED'] = str(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = True 29 | 30 | 31 | def pv2rgb(mask): 32 | h, w = mask.shape[0], mask.shape[1] 33 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 34 | mask_convert = mask[np.newaxis, :, :] 35 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [0, 255, 0] 36 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [255, 255, 255] 37 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [255, 0, 0] 38 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [255, 255, 0] 39 | mask_rgb[np.all(mask_convert == 4, axis=0)] = [0, 204, 255] 40 | mask_rgb[np.all(mask_convert == 5, axis=0)] = [0, 0, 255] 41 | mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR) 42 | return mask_rgb 43 | 44 | 45 | def landcoverai_to_rgb(mask): 46 | w, h = mask.shape[0], mask.shape[1] 47 | mask_rgb = np.zeros(shape=(w, h, 3), dtype=np.uint8) 48 | mask_convert = mask[np.newaxis, :, :] 49 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [255, 255, 255] 50 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [233, 193, 133] 51 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [255, 0, 0] 52 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [0, 255, 0] 53 | mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR) 54 | return mask_rgb 55 | 56 | 57 | def uavid2rgb(mask): 58 | h, w = mask.shape[0], mask.shape[1] 59 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 60 | mask_convert = mask[np.newaxis, :, :] 61 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [128, 0, 0] 62 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [128, 64, 128] 63 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [0, 128, 0] 64 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [128, 128, 0] 65 | mask_rgb[np.all(mask_convert == 4, axis=0)] = [64, 0, 128] 66 | mask_rgb[np.all(mask_convert == 5, axis=0)] = [192, 0, 192] 67 | mask_rgb[np.all(mask_convert == 6, axis=0)] = [64, 64, 0] 68 | mask_rgb[np.all(mask_convert == 7, axis=0)] = [0, 0, 0] 69 | mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR) 70 | return mask_rgb 71 | 72 | 73 | def get_args(): 74 | parser = argparse.ArgumentParser() 75 | arg = parser.add_argument 76 | arg("-i", "--image_path", type=str, default='data/uavid/uavid_test', help="Path to huge image") 77 | arg("-c", "--config_path", type=Path, required=True, help="Path to config") 78 | arg("-o", "--output_path", type=Path, help="Path to save resulting masks.", required=True) 79 | arg("-t", "--tta", help="Test time augmentation.", default="lr", choices=[None, "d4", "lr"]) 80 | arg("-ph", "--patch-height", help="height of patch size", type=int, default=1152) 81 | arg("-pw", "--patch-width", help="width of patch size", type=int, default=1024) 82 | arg("-b", "--batch-size", help="batch size", type=int, default=2) 83 | arg("-d", "--dataset", help="dataset", default="uavid", choices=["pv", "landcoverai", "uavid"]) 84 | return parser.parse_args() 85 | 86 | 87 | def load_checkpoint(checkpoint_path, model): 88 | pretrained_dict = torch.load(checkpoint_path)['model_state_dict'] 89 | model_dict = model.state_dict() 90 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 91 | model_dict.update(pretrained_dict) 92 | model.load_state_dict(model_dict) 93 | 94 | return model 95 | 96 | 97 | def get_img_padded(image, patch_size): 98 | oh, ow = image.shape[0], image.shape[1] 99 | rh, rw = oh % patch_size[0], ow % patch_size[1] 100 | 101 | width_pad = 0 if rw == 0 else patch_size[1] - rw 102 | height_pad = 0 if rh == 0 else patch_size[0] - rh 103 | # print(oh, ow, rh, rw, height_pad, width_pad) 104 | h, w = oh + height_pad, ow + width_pad 105 | 106 | pad = albu.PadIfNeeded(min_height=h, min_width=w, border_mode=0, 107 | position='bottom_right', value=[0, 0, 0])(image=image) 108 | img_pad = pad['image'] 109 | return img_pad, height_pad, width_pad 110 | 111 | 112 | class InferenceDataset(Dataset): 113 | def __init__(self, tile_list=None, transform=albu.Normalize()): 114 | self.tile_list = tile_list 115 | self.transform = transform 116 | 117 | def __getitem__(self, index): 118 | img = self.tile_list[index] 119 | img_id = index 120 | aug = self.transform(image=img) 121 | img = aug['image'] 122 | img = torch.from_numpy(img).permute(2, 0, 1).float() 123 | results = dict(img_id=img_id, img=img) 124 | return results 125 | 126 | def __len__(self): 127 | return len(self.tile_list) 128 | 129 | 130 | def make_dataset_for_one_huge_image(img_path, patch_size): 131 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 132 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 133 | tile_list = [] 134 | image_pad, height_pad, width_pad = get_img_padded(img.copy(), patch_size) 135 | 136 | output_height, output_width = image_pad.shape[0], image_pad.shape[1] 137 | 138 | for x in range(0, output_height, patch_size[0]): 139 | for y in range(0, output_width, patch_size[1]): 140 | image_tile = image_pad[x:x+patch_size[0], y:y+patch_size[1]] 141 | tile_list.append(image_tile) 142 | 143 | dataset = InferenceDataset(tile_list=tile_list) 144 | return dataset, width_pad, height_pad, output_width, output_height, image_pad, img.shape 145 | 146 | 147 | def main(): 148 | args = get_args() 149 | seed_everything(42) 150 | seqs = os.listdir(args.image_path) 151 | 152 | # print(img_paths) 153 | patch_size = (args.patch_height, args.patch_width) 154 | config = py2cfg(args.config_path) 155 | model = Supervision_Train.load_from_checkpoint(os.path.join(config.weights_path, config.test_weights_name+'.ckpt'), config=config) 156 | 157 | model.cuda(config.gpus[0]) 158 | model.eval() 159 | 160 | if args.tta == "lr": 161 | transforms = tta.Compose( 162 | [ 163 | tta.HorizontalFlip(), 164 | tta.VerticalFlip() 165 | ] 166 | ) 167 | model = tta.SegmentationTTAWrapper(model, transforms) 168 | elif args.tta == "d4": 169 | transforms = tta.Compose( 170 | [ 171 | tta.HorizontalFlip(), 172 | # tta.VerticalFlip(), 173 | # tta.Rotate90(angles=[0, 90, 180, 270]), 174 | tta.Scale(scales=[0.75, 1, 1.25, 1.5, 1.75]), 175 | # tta.Multiply(factors=[0.8, 1, 1.2]) 176 | ] 177 | ) 178 | model = tta.SegmentationTTAWrapper(model, transforms) 179 | 180 | for seq in seqs: 181 | img_paths = [] 182 | output_path = os.path.join(args.output_path, str(seq), 'Labels') 183 | if not os.path.exists(output_path): 184 | os.makedirs(output_path) 185 | for ext in ('*.tif', '*.png', '*.jpg'): 186 | img_paths.extend(glob.glob(os.path.join(args.image_path, str(seq), 'Images', ext))) 187 | img_paths.sort() 188 | # print(img_paths) 189 | for img_path in img_paths: 190 | img_name = img_path.split('/')[-1] 191 | # print('origin mask', original_mask.shape) 192 | dataset, width_pad, height_pad, output_width, output_height, img_pad, img_shape = \ 193 | make_dataset_for_one_huge_image(img_path, patch_size) 194 | # print('img_padded', img_pad.shape) 195 | output_mask = np.zeros(shape=(output_height, output_width), dtype=np.uint8) 196 | output_tiles = [] 197 | k = 0 198 | with torch.no_grad(): 199 | dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, 200 | drop_last=False, shuffle=False) 201 | for input in tqdm(dataloader): 202 | # raw_prediction NxCxHxW 203 | raw_predictions = model(input['img'].cuda(config.gpus[0])) 204 | # print('raw_pred shape:', raw_predictions.shape) 205 | raw_predictions = nn.Softmax(dim=1)(raw_predictions) 206 | # input_images['features'] NxCxHxW C=3 207 | predictions = raw_predictions.argmax(dim=1) 208 | image_ids = input['img_id'] 209 | # print('prediction', predictions.shape) 210 | # print(np.unique(predictions)) 211 | 212 | for i in range(predictions.shape[0]): 213 | raw_mask = predictions[i].cpu().numpy() 214 | mask = raw_mask 215 | output_tiles.append((mask, image_ids[i].cpu().numpy())) 216 | 217 | for m in range(0, output_height, patch_size[0]): 218 | for n in range(0, output_width, patch_size[1]): 219 | output_mask[m:m + patch_size[0], n:n + patch_size[1]] = output_tiles[k][0] 220 | k = k + 1 221 | 222 | output_mask = output_mask[-img_shape[0]:, -img_shape[1]:] 223 | 224 | # print('mask', output_mask.shape) 225 | if args.dataset == 'landcoverai': 226 | output_mask = landcoverai_to_rgb(output_mask) 227 | elif args.dataset == 'pv': 228 | output_mask = pv2rgb(output_mask) 229 | elif args.dataset == 'uavid': 230 | output_mask = uavid2rgb(output_mask) 231 | else: 232 | output_mask = output_mask 233 | assert img_shape == output_mask.shape 234 | cv2.imwrite(os.path.join(output_path, img_name), output_mask) 235 | 236 | 237 | if __name__ == "__main__": 238 | main() 239 | -------------------------------------------------------------------------------- /modules/BFM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class BFM(nn.Module): 5 | def __init__(self): 6 | super(BFM, self).__init__() 7 | self.left1 = nn.Sequential( 8 | nn.Conv2d( 9 | 128, 128, kernel_size=3, stride=1, # groups=128,DWConv 10 | padding=1, groups=128, bias=False), 11 | nn.BatchNorm2d(128), 12 | nn.Conv2d( 13 | 128, 128, kernel_size=1, stride=1, 14 | padding=0, bias=False), 15 | ) 16 | self.left2 = nn.Sequential( 17 | nn.Conv2d( 18 | 128, 128, kernel_size=3, stride=1, 19 | padding=1, bias=False), 20 | nn.BatchNorm2d(128), 21 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False) 22 | ) 23 | self.right1 = nn.Sequential( 24 | nn.Conv2d( 25 | 128, 128, kernel_size=3, stride=1, 26 | padding=1, bias=False), 27 | nn.BatchNorm2d(128), 28 | ) 29 | self.right2 = nn.Sequential( 30 | nn.Conv2d( 31 | 128, 128, kernel_size=3, stride=1, 32 | padding=1, groups=128, bias=False), 33 | nn.BatchNorm2d(128), 34 | nn.Conv2d( 35 | 128, 128, kernel_size=1, stride=1, 36 | padding=0, bias=False), 37 | ) 38 | 39 | self.conv = nn.Sequential( 40 | nn.Conv2d( 41 | 128, 128, kernel_size=3, stride=1, 42 | padding=1, bias=False), 43 | nn.BatchNorm2d(128), 44 | nn.ReLU(inplace=True), 45 | 46 | ) 47 | 48 | def forward(self, x_d, x_s): 49 | left1 = self.left1(x_d) 50 | left2 = self.left2(x_d) 51 | right1 = self.right1(x_s) 52 | right2 = self.right2(x_s) 53 | left = left1 * torch.sigmoid(right1) 54 | right = left2 * torch.sigmoid(right2) 55 | out = self.conv(left + right) 56 | return out -------------------------------------------------------------------------------- /modules/FAC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import BatchNorm2d 4 | 5 | class FAC(nn.Module): 6 | def __init__(self, in_chan, out_chan, *args, **kwargs): 7 | super(FAC, self).__init__() 8 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 9 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 10 | self.bn_atten = BatchNorm2d(out_chan) 11 | 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | feat = torch.mean(x, dim=(2, 3), keepdim=True) 16 | feat = self.conv_atten(feat) 17 | feat = self.bn_atten(feat) 18 | feat = feat.sigmoid() 19 | out = torch.mul(x, feat) 20 | return out 21 | 22 | class ConvBNReLU(nn.Module): 23 | 24 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 25 | super(ConvBNReLU, self).__init__() 26 | self.conv = nn.Conv2d(in_chan, 27 | out_chan, 28 | kernel_size = ks, 29 | stride = stride, 30 | padding = padding, 31 | bias = False) 32 | self.bn = BatchNorm2d(out_chan) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.init_weight() 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | x = self.bn(x) 39 | x = self.relu(x) 40 | return x -------------------------------------------------------------------------------- /modules/FRFB.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | class FRFB(nn.Module): 6 | 7 | def __init__(self, in_planes, out_planes, stride=1, scale=0.1): 8 | super(FRFB, self).__init__() 9 | self.scale = scale 10 | self.out_channels = out_planes 11 | inter_planes = in_planes // 8 12 | 13 | self.branch0 = nn.Sequential( 14 | BasicConv(in_planes, inter_planes // 2, kernel_size=1, stride=1), 15 | BasicConv(inter_planes // 2, inter_planes, kernel_size=3, stride=1, padding=1, relu=False) 16 | ) 17 | self.branch1 = nn.Sequential( 18 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 19 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1), 20 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False) 21 | ) 22 | 23 | self.branch2 = nn.Sequential( 24 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 25 | BasicConv(inter_planes, inter_planes, kernel_size=5, stride=1, padding=2), 26 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=5, dilation=5, relu=False) 27 | ) 28 | 29 | self.branch3 = nn.Sequential( 30 | BasicConv(in_planes, inter_planes, kernel_size=1, stride=1), 31 | BasicConv(inter_planes, inter_planes, kernel_size=7, stride=1, padding=3), 32 | BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=7, dilation=7, relu=False) 33 | ) 34 | self.ConvLinear = BasicConv(4 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False) 35 | self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False) 36 | self.relu = nn.ReLU(inplace=False) 37 | 38 | def forward(self, x): 39 | x0 = self.branch0(x) 40 | x1 = self.branch1(x) 41 | x2 = self.branch3(x) 42 | x3 = self.branch5(x) 43 | out = torch.cat((x0, x1, x2, x3), 1) 44 | out = self.ConvLinear(out) 45 | short = self.shortcut(x) 46 | out = torch.add(out * self.scale, short) 47 | out = self.relu(out) 48 | return out 49 | 50 | class BasicConv(nn.Module): 51 | 52 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 53 | super(BasicConv, self).__init__() 54 | self.out_channels = out_planes 55 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 56 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 57 | self.relu = nn.ReLU(inplace=True) if relu else None 58 | 59 | def forward(self, x): 60 | x = self.conv(x) 61 | if self.bn is not None: 62 | x = self.bn(x) 63 | if self.relu is not None: 64 | x = self.relu(x) 65 | return x -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | catalyst==20.09 3 | pytorch-lightning==1.5.9 4 | albumentations==1.1.0 5 | ttach 6 | numpy 7 | tqdm 8 | opencv-python 9 | scipy 10 | matplotlib 11 | einops 12 | addict -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wxy16/SSFNet/b424b22bbab4a142f150a417be19880102462618/tools/__init__.py -------------------------------------------------------------------------------- /tools/cfg.py: -------------------------------------------------------------------------------- 1 | import pydoc 2 | import sys 3 | from importlib import import_module 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | from addict import Dict 8 | 9 | 10 | class ConfigDict(Dict): 11 | def __missing__(self, name): 12 | raise KeyError(name) 13 | 14 | def __getattr__(self, name): 15 | try: 16 | value = super().__getattr__(name) 17 | except KeyError: 18 | ex = AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") 19 | else: 20 | return value 21 | raise ex 22 | 23 | 24 | def py2dict(file_path: Union[str, Path]) -> dict: 25 | """Convert python file to dictionary. 26 | The main use - config parser. 27 | file: 28 | ``` 29 | a = 1 30 | b = 3 31 | c = range(10) 32 | ``` 33 | will be converted to 34 | {'a':1, 35 | 'b':3, 36 | 'c': range(10) 37 | } 38 | Args: 39 | file_path: path to the original python file. 40 | Returns: {key: value}, where key - all variables defined in the file and value is their value. 41 | """ 42 | file_path = Path(file_path).absolute() 43 | 44 | if file_path.suffix != ".py": 45 | raise TypeError(f"Only Py file can be parsed, but got {file_path.name} instead.") 46 | 47 | if not file_path.exists(): 48 | raise FileExistsError(f"There is no file at the path {file_path}") 49 | 50 | module_name = file_path.stem 51 | 52 | if "." in module_name: 53 | raise ValueError("Dots are not allowed in config file path.") 54 | 55 | config_dir = str(file_path.parent) 56 | 57 | sys.path.insert(0, config_dir) 58 | 59 | mod = import_module(module_name) 60 | sys.path.pop(0) 61 | cfg_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")} 62 | 63 | return cfg_dict 64 | 65 | 66 | def py2cfg(file_path: Union[str, Path]) -> ConfigDict: 67 | cfg_dict = py2dict(file_path) 68 | 69 | return ConfigDict(cfg_dict) 70 | 71 | 72 | def object_from_dict(d, parent=None, **default_kwargs): 73 | kwargs = d.copy() 74 | object_type = kwargs.pop("type") 75 | for name, value in default_kwargs.items(): 76 | kwargs.setdefault(name, value) 77 | 78 | if parent is not None: 79 | return getattr(parent, object_type)(**kwargs) # skipcq PTC-W0034 80 | 81 | return pydoc.locate(object_type)(**kwargs) -------------------------------------------------------------------------------- /tools/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wxy16/SSFNet/b424b22bbab4a142f150a417be19880102462618/tools/img.png -------------------------------------------------------------------------------- /tools/loveda_mask_convert.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import cv2 5 | import multiprocessing.pool as mpp 6 | import multiprocessing as mp 7 | import time 8 | import argparse 9 | import torch 10 | import random 11 | 12 | SEED = 42 13 | 14 | CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest', 15 | 'agricultural') 16 | 17 | PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], 18 | [159, 129, 183], [0, 255, 0], [255, 195, 128]] 19 | 20 | 21 | def seed_everything(seed): 22 | random.seed(seed) 23 | os.environ['PYTHONHASHSEED'] = str(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = True 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--mask-dir", default="data/LoveDA/Train/Rural/masks_png") 34 | parser.add_argument("--output-mask-dir", default="data/LoveDA/Train/Rural/masks_png_convert") 35 | return parser.parse_args() 36 | 37 | 38 | def convert_label(mask): 39 | mask[mask == 0] = 8 40 | mask -= 1 41 | 42 | return mask 43 | 44 | 45 | def label2rgb(mask): 46 | h, w = mask.shape[0], mask.shape[1] 47 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 48 | mask_convert = mask[np.newaxis, :, :] 49 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [255, 0, 0] 50 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [255, 255, 0] 51 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [0, 0, 255] 52 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [255, 255, 255] 53 | mask_rgb[np.all(mask_convert == 4, axis=0)] = [0, 255, 0] 54 | mask_rgb[np.all(mask_convert == 5, axis=0)] = [255, 195, 128] 55 | mask_rgb[np.all(mask_convert == 6, axis=0)] = [0, 255, 0] 56 | mask_rgb[np.all(mask_convert == 7, axis=0)] = [255, 255, 255] 57 | return mask_rgb 58 | 59 | 60 | def patch_format(inp): 61 | (mask_path, masks_output_dir) = inp 62 | # print(mask_path, masks_output_dir) 63 | mask_filename = os.path.splitext(os.path.basename(mask_path))[0] 64 | mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED) 65 | label = convert_label(mask) 66 | rgb_label = label2rgb(label.copy()) 67 | rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_RGB2BGR) 68 | out_mask_path_rgb = os.path.join(masks_output_dir + '_rgb', "{}.png".format(mask_filename)) 69 | cv2.imwrite(out_mask_path_rgb, rgb_label) 70 | 71 | out_mask_path = os.path.join(masks_output_dir, "{}.png".format(mask_filename)) 72 | cv2.imwrite(out_mask_path, label) 73 | 74 | 75 | if __name__ == "__main__": 76 | seed_everything(SEED) 77 | args = parse_args() 78 | masks_dir = args.mask_dir 79 | masks_output_dir = args.output_mask_dir 80 | mask_paths = glob.glob(os.path.join(masks_dir, "*.png")) 81 | 82 | if not os.path.exists(masks_output_dir): 83 | os.makedirs(masks_output_dir) 84 | os.makedirs(masks_output_dir + '_rgb') 85 | 86 | inp = [(mask_path, masks_output_dir) for mask_path in mask_paths] 87 | 88 | t0 = time.time() 89 | mpp.Pool(processes=mp.cpu_count()).map(patch_format, inp) 90 | t1 = time.time() 91 | split_time = t1 - t0 92 | print('images spliting spends: {} s'.format(split_time)) 93 | 94 | 95 | -------------------------------------------------------------------------------- /tools/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Evaluator(object): 5 | def __init__(self, num_class): 6 | self.num_class = num_class 7 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 8 | self.eps = 1e-8 9 | 10 | def get_tp_fp_tn_fn(self): 11 | tp = np.diag(self.confusion_matrix) 12 | fp = self.confusion_matrix.sum(axis=0) - np.diag(self.confusion_matrix) 13 | fn = self.confusion_matrix.sum(axis=1) - np.diag(self.confusion_matrix) 14 | tn = np.diag(self.confusion_matrix).sum() - np.diag(self.confusion_matrix) 15 | return tp, fp, tn, fn 16 | 17 | def Precision(self): 18 | tp, fp, tn, fn = self.get_tp_fp_tn_fn() 19 | precision = tp / (tp + fp) 20 | return precision 21 | 22 | def Recall(self): 23 | tp, fp, tn, fn = self.get_tp_fp_tn_fn() 24 | recall = tp / (tp + fn) 25 | return recall 26 | 27 | def F1(self): 28 | tp, fp, tn, fn = self.get_tp_fp_tn_fn() 29 | Precision = tp / (tp + fp) 30 | Recall = tp / (tp + fn) 31 | F1 = (2.0 * Precision * Recall) / (Precision + Recall) 32 | return F1 33 | 34 | def OA(self): 35 | OA = np.diag(self.confusion_matrix).sum() / (self.confusion_matrix.sum() + self.eps) 36 | return OA 37 | 38 | def Intersection_over_Union(self): 39 | tp, fp, tn, fn = self.get_tp_fp_tn_fn() 40 | IoU = tp / (tp + fn + fp) 41 | return IoU 42 | 43 | def Dice(self): 44 | tp, fp, tn, fn = self.get_tp_fp_tn_fn() 45 | Dice = 2 * tp / ((tp + fp) + (tp + fn)) 46 | return Dice 47 | 48 | def Pixel_Accuracy_Class(self): 49 | # TP TP+FP 50 | Acc = np.diag(self.confusion_matrix) / (self.confusion_matrix.sum(axis=0) + self.eps) 51 | return Acc 52 | 53 | def Frequency_Weighted_Intersection_over_Union(self): 54 | freq = np.sum(self.confusion_matrix, axis=1) / (np.sum(self.confusion_matrix) + self.eps) 55 | iou = self.Intersection_over_Union() 56 | FWIoU = (freq[freq > 0] * iou[freq > 0]).sum() 57 | return FWIoU 58 | 59 | def _generate_matrix(self, gt_image, pre_image): 60 | mask = (gt_image >= 0) & (gt_image < self.num_class) 61 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 62 | count = np.bincount(label, minlength=self.num_class ** 2) 63 | confusion_matrix = count.reshape(self.num_class, self.num_class) 64 | return confusion_matrix 65 | 66 | def add_batch(self, gt_image, pre_image): 67 | assert gt_image.shape == pre_image.shape, 'pre_image shape {}, gt_image shape {}'.format(pre_image.shape, 68 | gt_image.shape) 69 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 70 | 71 | def reset(self): 72 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | gt = np.array([[0, 2, 1], 78 | [1, 2, 1], 79 | [1, 0, 1]]) 80 | 81 | pre = np.array([[0, 1, 1], 82 | [2, 0, 1], 83 | [1, 1, 1]]) 84 | 85 | eval = Evaluator(num_class=3) 86 | eval.add_batch(gt, pre) 87 | print(eval.confusion_matrix) 88 | print(eval.get_tp_fp_tn_fn()) 89 | print(eval.Precision()) 90 | print(eval.Recall()) 91 | print(eval.Intersection_over_Union()) 92 | print(eval.OA()) 93 | print(eval.F1()) 94 | print(eval.Frequency_Weighted_Intersection_over_Union()) 95 | -------------------------------------------------------------------------------- /tools/uavid_patch_split.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import cv2 5 | import multiprocessing.pool as mpp 6 | import multiprocessing as mp 7 | import time 8 | import argparse 9 | import torch 10 | import albumentations as albu 11 | 12 | import random 13 | 14 | def seed_everything(seed): 15 | random.seed(seed) 16 | os.environ['PYTHONHASHSEED'] = str(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | Building = np.array([128, 0, 0]) # label 0 25 | Road = np.array([128, 64, 128]) # label 1 26 | Tree = np.array([0, 128, 0]) # label 2 27 | LowVeg = np.array([128, 128, 0]) # label 3 28 | Moving_Car = np.array([64, 0, 128]) # label 4 29 | Static_Car = np.array([192, 0, 192]) # label 5 30 | Human = np.array([64, 64, 0]) # label 6 31 | Clutter = np.array([0, 0, 0]) # label 7 32 | Boundary = np.array([255, 255, 255]) # label 255 33 | 34 | num_classes = 8 35 | 36 | 37 | # split huge RS image to small patches 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--input-dir", default="data/uavid/uavid_train_val") 41 | parser.add_argument("--output-img-dir", default="data/uavid/train_val/images") 42 | parser.add_argument("--output-mask-dir", default="data/uavid/train_val/masks") 43 | parser.add_argument("--mode", type=str, default='train') 44 | parser.add_argument("--split-size-h", type=int, default=1024) 45 | parser.add_argument("--split-size-w", type=int, default=1024) 46 | parser.add_argument("--stride-h", type=int, default=1024) 47 | parser.add_argument("--stride-w", type=int, default=1024) 48 | return parser.parse_args() 49 | 50 | 51 | def label2rgb(mask): 52 | h, w = mask.shape[0], mask.shape[1] 53 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 54 | mask_convert = mask[np.newaxis, :, :] 55 | mask_rgb[np.all(mask_convert == 0, axis=0)] = Building 56 | mask_rgb[np.all(mask_convert == 1, axis=0)] = Road 57 | mask_rgb[np.all(mask_convert == 2, axis=0)] = Tree 58 | mask_rgb[np.all(mask_convert == 3, axis=0)] = LowVeg 59 | mask_rgb[np.all(mask_convert == 4, axis=0)] = Moving_Car 60 | mask_rgb[np.all(mask_convert == 5, axis=0)] = Static_Car 61 | mask_rgb[np.all(mask_convert == 6, axis=0)] = Human 62 | mask_rgb[np.all(mask_convert == 7, axis=0)] = Clutter 63 | mask_rgb[np.all(mask_convert == 255, axis=0)] = Boundary 64 | return mask_rgb 65 | 66 | 67 | def rgb2label(label): 68 | label_seg = np.zeros(label.shape[:2], dtype=np.uint8) 69 | label_seg[np.all(label == Building, axis=-1)] = 0 70 | label_seg[np.all(label == Road, axis=-1)] = 1 71 | label_seg[np.all(label == Tree, axis=-1)] = 2 72 | label_seg[np.all(label == LowVeg, axis=-1)] = 3 73 | label_seg[np.all(label == Moving_Car, axis=-1)] = 4 74 | label_seg[np.all(label == Static_Car, axis=-1)] = 5 75 | label_seg[np.all(label == Human, axis=-1)] = 6 76 | label_seg[np.all(label == Clutter, axis=-1)] = 7 77 | label_seg[np.all(label == Boundary, axis=-1)] = 255 78 | return label_seg 79 | 80 | 81 | def image_augment(image, mask, mode='train'): 82 | image_list = [] 83 | mask_list = [] 84 | image_width, image_height = image.shape[1], image.shape[0] 85 | mask_width, mask_height = mask.shape[1], mask.shape[0] 86 | assert image_height == mask_height and image_width == mask_width 87 | if mode == 'train': 88 | image_list_train = [image] 89 | mask_list_train = [mask] 90 | for i in range(len(image_list_train)): 91 | mask_tmp = rgb2label(mask_list_train[i]) 92 | image_list.append(image_list_train[i]) 93 | mask_list.append(mask_tmp) 94 | else: 95 | mask = rgb2label(mask.copy()) 96 | image_list.append(image) 97 | mask_list.append(mask) 98 | return image_list, mask_list 99 | 100 | 101 | def padifneeded(image, mask): 102 | pad = albu.PadIfNeeded(min_height=2160, min_width=4096, position='bottom_right', 103 | border_mode=0, value=[0, 0, 0], mask_value=[255, 255, 255])(image=image, mask=mask) 104 | # pad = albu.PadIfNeeded(min_height=h, min_width=w)(image=image, mask=mask) 105 | img_pad, mask_pad = pad['image'], pad['mask'] 106 | assert img_pad.shape[0] == 2048 or img_pad.shape[1] == 4096, print(img_pad.shape) 107 | # print(img_pad.shape) 108 | return img_pad, mask_pad 109 | 110 | 111 | def patch_format(inp): 112 | (input_dir, seq, imgs_output_dir, masks_output_dir, mode, split_size, stride) = inp 113 | img_paths = glob.glob(os.path.join(input_dir, str(seq), 'Images', "*.png")) 114 | mask_paths = glob.glob(os.path.join(input_dir, str(seq), 'Labels', "*.png")) 115 | for img_path, mask_path in zip(img_paths, mask_paths): 116 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 117 | mask = cv2.imread(mask_path, cv2.IMREAD_COLOR) 118 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 119 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB) 120 | id = os.path.splitext(os.path.basename(img_path))[0] 121 | assert img.shape == mask.shape and img.shape[0] == 2160, print(img.shape) 122 | assert img.shape[1] == 3840 or img.shape[1] == 4096, print(img.shape) 123 | img, mask = padifneeded(img.copy(), mask.copy()) 124 | 125 | # print(img_path) 126 | # print(img.size, mask.size) 127 | # img and mask shape: WxHxC 128 | image_list, mask_list = image_augment(image=img.copy(), mask=mask.copy(), mode=mode) 129 | assert len(image_list) == len(mask_list) 130 | for m in range(len(image_list)): 131 | k = 0 132 | img = image_list[m] 133 | mask = mask_list[m] 134 | img, mask = img[-2048:, -4096:, :], mask[-2048:, -4096:] 135 | assert img.shape[0] == mask.shape[0] and img.shape[1] == mask.shape[1] 136 | for y in range(0, img.shape[0], stride[0]): 137 | for x in range(0, img.shape[1], stride[1]): 138 | img_tile_cut = img[y:y + split_size[0], x:x + split_size[1]] 139 | mask_tile_cut = mask[y:y + split_size[0], x:x + split_size[1]] 140 | img_tile, mask_tile = img_tile_cut, mask_tile_cut 141 | 142 | if img_tile.shape[0] == split_size[0] and img_tile.shape[1] == split_size[1] \ 143 | and mask_tile.shape[0] == split_size[0] and mask_tile.shape[1] == split_size[1]: 144 | if mode == 'train': 145 | out_img_path = os.path.join(imgs_output_dir, "{}_{}_{}_{}.png".format(seq, id, m, k)) 146 | img_tile = cv2.cvtColor(img_tile, cv2.COLOR_RGB2BGR) 147 | cv2.imwrite(out_img_path, img_tile) 148 | # print(img_tile.shape) 149 | 150 | out_mask_path = os.path.join(masks_output_dir, 151 | "{}_{}_{}_{}.png".format(seq, id, m, k)) 152 | cv2.imwrite(out_mask_path, mask_tile) 153 | else: 154 | img_tile = cv2.cvtColor(img_tile, cv2.COLOR_RGB2BGR) 155 | out_img_path = os.path.join(imgs_output_dir, "{}_{}_{}_{}.png".format(seq, id, m, k)) 156 | cv2.imwrite(out_img_path, img_tile) 157 | 158 | out_mask_path = os.path.join(masks_output_dir, "{}_{}_{}_{}.png".format(seq, id, m, k)) 159 | cv2.imwrite(out_mask_path, mask_tile) 160 | 161 | k += 1 162 | 163 | 164 | if __name__ == "__main__": 165 | seed_everything(42) 166 | args = parse_args() 167 | input_dir = args.input_dir 168 | imgs_output_dir = args.output_img_dir 169 | masks_output_dir = args.output_mask_dir 170 | mode = args.mode 171 | split_size_h = args.split_size_h 172 | split_size_w = args.split_size_w 173 | split_size = (split_size_h, split_size_w) 174 | stride_h = args.stride_h 175 | stride_w = args.stride_w 176 | stride = (stride_h, stride_w) 177 | seqs = os.listdir(input_dir) 178 | # print(seqs) 179 | 180 | if not os.path.exists(imgs_output_dir): 181 | os.makedirs(imgs_output_dir) 182 | if not os.path.exists(masks_output_dir): 183 | os.makedirs(masks_output_dir) 184 | 185 | inp = [(input_dir, seq, imgs_output_dir, masks_output_dir, mode, split_size, stride) 186 | for seq in seqs] 187 | 188 | t0 = time.time() 189 | mpp.Pool(processes=mp.cpu_count()).map(patch_format, inp) 190 | t1 = time.time() 191 | split_time = t1 - t0 192 | print('images spliting spends: {} s'.format(split_time)) 193 | 194 | 195 | -------------------------------------------------------------------------------- /tools/vaihingen_patch_split.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import cv2 5 | from PIL import Image 6 | import multiprocessing.pool as mpp 7 | import multiprocessing as mp 8 | import time 9 | import argparse 10 | import torch 11 | import albumentations as albu 12 | from torchvision.transforms import (Pad, ColorJitter, Resize, FiveCrop, RandomCrop, 13 | RandomHorizontalFlip, RandomRotation, RandomVerticalFlip) 14 | import random 15 | 16 | SEED = 42 17 | 18 | 19 | def seed_everything(seed): 20 | random.seed(seed) 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = True 27 | 28 | 29 | ImSurf = np.array([255, 255, 255]) # label 0 30 | Building = np.array([255, 0, 0]) # label 1 31 | LowVeg = np.array([255, 255, 0]) # label 2 32 | Tree = np.array([0, 255, 0]) # label 3 33 | Car = np.array([0, 255, 255]) # label 4 34 | Clutter = np.array([0, 0, 255]) # label 5 35 | Boundary = np.array([0, 0, 0]) # label 6 36 | num_classes = 6 37 | 38 | 39 | # split huge RS image to small patches 40 | def parse_args(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--img-dir", default="data/vaihingen/train_images") 43 | parser.add_argument("--mask-dir", default="data/vaihingen/train_masks") 44 | parser.add_argument("--output-img-dir", default="data/vaihingen/train/images_1024") 45 | parser.add_argument("--output-mask-dir", default="data/vaihingen/train/masks_1024") 46 | parser.add_argument("--eroded", action='store_true') 47 | parser.add_argument("--gt", action='store_true') 48 | parser.add_argument("--mode", type=str, default='train') 49 | parser.add_argument("--val-scale", type=float, default=1.0) 50 | parser.add_argument("--split-size", type=int, default=1024) 51 | parser.add_argument("--stride", type=int, default=512) 52 | return parser.parse_args() 53 | 54 | 55 | def get_img_mask_padded(image, mask, patch_size, mode): 56 | img, mask = np.array(image), np.array(mask) 57 | oh, ow = img.shape[0], img.shape[1] 58 | rh, rw = oh % patch_size, ow % patch_size 59 | width_pad = 0 if rw == 0 else patch_size - rw 60 | height_pad = 0 if rh == 0 else patch_size - rh 61 | 62 | h, w = oh + height_pad, ow + width_pad 63 | pad_img = albu.PadIfNeeded(min_height=h, min_width=w, position='bottom_right')(image=img) 64 | if mode == 'train': 65 | pad_img = albu.PadIfNeeded(min_height=h, min_width=w, position='bottom_right')(image=img) 66 | 67 | pad_mask = albu.PadIfNeeded(min_height=h, min_width=w, position='bottom_right')(image=mask) 68 | img_pad, mask_pad = pad_img['image'], pad_mask['image'] 69 | img_pad = cv2.cvtColor(np.array(img_pad), cv2.COLOR_RGB2BGR) 70 | mask_pad = cv2.cvtColor(np.array(mask_pad), cv2.COLOR_RGB2BGR) 71 | return img_pad, mask_pad 72 | 73 | 74 | def pv2rgb(mask): 75 | h, w = mask.shape[0], mask.shape[1] 76 | mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8) 77 | mask_convert = mask[np.newaxis, :, :] 78 | mask_rgb[np.all(mask_convert == 3, axis=0)] = [0, 255, 0] 79 | mask_rgb[np.all(mask_convert == 0, axis=0)] = [255, 255, 255] 80 | mask_rgb[np.all(mask_convert == 1, axis=0)] = [255, 0, 0] 81 | mask_rgb[np.all(mask_convert == 2, axis=0)] = [255, 255, 0] 82 | mask_rgb[np.all(mask_convert == 4, axis=0)] = [0, 204, 255] 83 | mask_rgb[np.all(mask_convert == 5, axis=0)] = [0, 0, 255] 84 | return mask_rgb 85 | 86 | 87 | def car_color_replace(mask): 88 | mask = cv2.cvtColor(np.array(mask.copy()), cv2.COLOR_RGB2BGR) 89 | mask[np.all(mask == [0, 255, 255], axis=-1)] = [0, 204, 255] 90 | 91 | return mask 92 | 93 | 94 | def rgb_to_2D_label(_label): 95 | _label = _label.transpose(2, 0, 1) 96 | label_seg = np.zeros(_label.shape[1:], dtype=np.uint8) 97 | label_seg[np.all(_label.transpose([1, 2, 0]) == ImSurf, axis=-1)] = 0 98 | label_seg[np.all(_label.transpose([1, 2, 0]) == Building, axis=-1)] = 1 99 | label_seg[np.all(_label.transpose([1, 2, 0]) == LowVeg, axis=-1)] = 2 100 | label_seg[np.all(_label.transpose([1, 2, 0]) == Tree, axis=-1)] = 3 101 | label_seg[np.all(_label.transpose([1, 2, 0]) == Car, axis=-1)] = 4 102 | label_seg[np.all(_label.transpose([1, 2, 0]) == Clutter, axis=-1)] = 5 103 | label_seg[np.all(_label.transpose([1, 2, 0]) == Boundary, axis=-1)] = 6 104 | return label_seg 105 | 106 | 107 | def image_augment(image, mask, patch_size, mode='train', val_scale=1.0): 108 | image_list = [] 109 | mask_list = [] 110 | image_width, image_height = image.size[1], image.size[0] 111 | mask_width, mask_height = mask.size[1], mask.size[0] 112 | 113 | assert image_height == mask_height and image_width == mask_width 114 | if mode == 'train': 115 | # resize_0 = Resize(size=(int(image_width * 0.25), int(image_height * 0.25))) 116 | # resize_1 = Resize(size=(int(image_width * 0.5), int(image_height * 0.5))) 117 | # resize_2 = Resize(size=(int(image_width * 0.75), int(image_height * 0.75))) 118 | # resize_3 = Resize(size=(int(image_width * 1.25), int(image_height * 1.25))) 119 | # resize_4 = Resize(size=(int(image_width * 1.5), int(image_height * 1.5))) 120 | # resize_5 = Resize(size=(int(image_width * 1.75), int(image_height * 1.75))) 121 | # resize_6 = Resize(size=(int(image_width * 2.0), int(image_height * 2.0))) 122 | # image_resize_0, mask_resize_0 = resize_0(image.copy()), resize_0(mask.copy()) 123 | # image_resize_1, mask_resize_1 = resize_1(image.copy()), resize_1(mask.copy()) 124 | # image_resize_2, mask_resize_2 = resize_2(image.copy()), resize_2(mask.copy()) 125 | # image_resize_3, mask_resize_3 = resize_3(image.copy()), resize_3(mask.copy()) 126 | # image_resize_4, mask_resize_4 = resize_4(image.copy()), resize_4(mask.copy()) 127 | # image_resize_5, mask_resize_5 = resize_5(image.copy()), resize_5(mask.copy()) 128 | # image_resize_6, mask_resize_6 = resize_6(image.copy()), resize_6(mask.copy()) 129 | h_vlip = RandomHorizontalFlip(p=1.0) 130 | v_vlip = RandomVerticalFlip(p=1.0) 131 | # crop_1 = RandomCrop(size=(int(image_width*0.75), int(image_height*0.75))) 132 | # crop_2 = RandomCrop(size=(int(image_width * 0.5), int(image_height * 0.5))) 133 | # color = torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2) 134 | image_h_vlip, mask_h_vlip = h_vlip(image.copy()), h_vlip(mask.copy()) 135 | image_v_vlip, mask_v_vlip = v_vlip(image.copy()), v_vlip(mask.copy()) 136 | # image_crop_1, mask_crop_1 = crop_1(image.copy()), crop_1(mask.copy()) 137 | # image_crop_2, mask_crop_2 = crop_2(image.copy()), crop_2(mask.copy()) 138 | # image_color = color(image.copy()) 139 | 140 | image_list_train = [image, image_h_vlip, image_v_vlip] 141 | mask_list_train = [mask, mask_h_vlip, mask_v_vlip] 142 | # image_list_train = [image] 143 | # mask_list_train = [mask] 144 | for i in range(len(image_list_train)): 145 | image_tmp, mask_tmp = get_img_mask_padded(image_list_train[i], mask_list_train[i], patch_size, mode) 146 | mask_tmp = rgb_to_2D_label(mask_tmp.copy()) 147 | image_list.append(image_tmp) 148 | mask_list.append(mask_tmp) 149 | else: 150 | rescale = Resize(size=(int(image_width * val_scale), int(image_height * val_scale))) 151 | image, mask = rescale(image.copy()), rescale(mask.copy()) 152 | image, mask = get_img_mask_padded(image.copy(), mask.copy(), patch_size, mode) 153 | mask = rgb_to_2D_label(mask.copy()) 154 | image_list.append(image) 155 | mask_list.append(mask) 156 | return image_list, mask_list 157 | 158 | 159 | def randomsizedcrop(image, mask): 160 | # assert image.shape[:2] == mask.shape 161 | h, w = image.shape[0], image.shape[1] 162 | crop = albu.RandomSizedCrop(min_max_height=(int(3*h//8), int(h//2)), width=h, height=w)(image=image.copy(), mask=mask.copy()) 163 | img_crop, mask_crop = crop['image'], crop['mask'] 164 | return img_crop, mask_crop 165 | 166 | 167 | def car_aug(image, mask): 168 | assert image.shape[:2] == mask.shape 169 | v_flip = albu.VerticalFlip(p=1.0)(image=image.copy(), mask=mask.copy()) 170 | h_flip = albu.HorizontalFlip(p=1.0)(image=image.copy(), mask=mask.copy()) 171 | rotate_90 = albu.RandomRotate90(p=1.0)(image=image.copy(), mask=mask.copy()) 172 | # blur = albu.GaussianBlur(p=1.0)(image=image.copy()) 173 | image_vflip, mask_vflip = v_flip['image'], v_flip['mask'] 174 | image_hflip, mask_hflip = h_flip['image'], h_flip['mask'] 175 | image_rotate, mask_rotate = rotate_90['image'], rotate_90['mask'] 176 | # blur_image = blur['image'] 177 | image_list = [image, image_vflip, image_hflip, image_rotate] 178 | mask_list = [mask, mask_vflip, mask_hflip, mask_rotate] 179 | 180 | return image_list, mask_list 181 | 182 | 183 | def vaihingen_format(inp): 184 | (img_path, mask_path, imgs_output_dir, masks_output_dir, eroded, gt, mode, val_scale, split_size, stride) = inp 185 | img_filename = os.path.splitext(os.path.basename(img_path))[0] 186 | mask_filename = os.path.splitext(os.path.basename(mask_path))[0] 187 | if eroded: 188 | mask_path = mask_path[:-4] + '_noBoundary.tif' 189 | img = Image.open(img_path).convert('RGB') 190 | mask = Image.open(mask_path).convert('RGB') 191 | if gt: 192 | mask_ = car_color_replace(mask) 193 | out_origin_mask_path = os.path.join(masks_output_dir + '/origin/', "{}.tif".format(mask_filename)) 194 | cv2.imwrite(out_origin_mask_path, mask_) 195 | # print(img_path) 196 | # print(img.size, mask.size) 197 | # img and mask shape: WxHxC 198 | image_list, mask_list = image_augment(image=img.copy(), mask=mask.copy(), patch_size=split_size, 199 | mode=mode, val_scale=val_scale) 200 | assert img_filename == mask_filename and len(image_list) == len(mask_list) 201 | for m in range(len(image_list)): 202 | k = 0 203 | img = image_list[m] 204 | mask = mask_list[m] 205 | assert img.shape[0] == mask.shape[0] and img.shape[1] == mask.shape[1] 206 | if gt: 207 | mask = pv2rgb(mask) 208 | 209 | for y in range(0, img.shape[0], stride): 210 | for x in range(0, img.shape[1], stride): 211 | img_tile = img[y:y + split_size, x:x + split_size] 212 | mask_tile = mask[y:y + split_size, x:x + split_size] 213 | 214 | if img_tile.shape[0] == split_size and img_tile.shape[1] == split_size \ 215 | and mask_tile.shape[0] == split_size and mask_tile.shape[1] == split_size: 216 | image_crop, mask_crop = randomsizedcrop(img_tile, mask_tile) 217 | bins = np.array(range(num_classes + 1)) 218 | class_pixel_counts, _ = np.histogram(mask_crop, bins=bins) 219 | cf = class_pixel_counts / (mask_crop.shape[0] * mask_crop.shape[1]) 220 | if cf[4] > 0.1 and mode == 'train': 221 | car_imgs, car_masks = car_aug(image_crop, mask_crop) 222 | for i in range(len(car_imgs)): 223 | out_img_path = os.path.join(imgs_output_dir, 224 | "{}_{}_{}_{}.tif".format(img_filename, m, k, i)) 225 | cv2.imwrite(out_img_path, car_imgs[i]) 226 | 227 | out_mask_path = os.path.join(masks_output_dir, 228 | "{}_{}_{}_{}.png".format(mask_filename, m, k, i)) 229 | cv2.imwrite(out_mask_path, car_masks[i]) 230 | else: 231 | out_img_path = os.path.join(imgs_output_dir, "{}_{}_{}.tif".format(img_filename, m, k)) 232 | cv2.imwrite(out_img_path, img_tile) 233 | 234 | out_mask_path = os.path.join(masks_output_dir, "{}_{}_{}.png".format(mask_filename, m, k)) 235 | cv2.imwrite(out_mask_path, mask_tile) 236 | 237 | k += 1 238 | 239 | 240 | if __name__ == "__main__": 241 | seed_everything(SEED) 242 | args = parse_args() 243 | imgs_dir = args.img_dir 244 | masks_dir = args.mask_dir 245 | imgs_output_dir = args.output_img_dir 246 | masks_output_dir = args.output_mask_dir 247 | gt = args.gt 248 | eroded = args.eroded 249 | mode = args.mode 250 | val_scale = args.val_scale 251 | split_size = args.split_size 252 | stride = args.stride 253 | img_paths = glob.glob(os.path.join(imgs_dir, "*.tif")) 254 | mask_paths_raw = glob.glob(os.path.join(masks_dir, "*.tif")) 255 | if eroded: 256 | mask_paths = [(p[:-15] + '.tif') for p in mask_paths_raw] 257 | else: 258 | mask_paths = mask_paths_raw 259 | img_paths.sort() 260 | mask_paths.sort() 261 | 262 | if not os.path.exists(imgs_output_dir): 263 | os.makedirs(imgs_output_dir) 264 | if not os.path.exists(masks_output_dir): 265 | os.makedirs(masks_output_dir) 266 | if gt: 267 | os.makedirs(masks_output_dir+'/origin') 268 | 269 | inp = [(img_path, mask_path, imgs_output_dir, masks_output_dir, eroded, gt, mode, val_scale, split_size, stride) 270 | for img_path, mask_path in zip(img_paths, mask_paths)] 271 | 272 | t0 = time.time() 273 | mpp.Pool(processes=mp.cpu_count()).map(vaihingen_format, inp) 274 | t1 = time.time() 275 | split_time = t1 - t0 276 | print('images spliting spends: {} s'.format(split_time)) 277 | 278 | 279 | -------------------------------------------------------------------------------- /train_supervision.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.callbacks import ModelCheckpoint 3 | from tools.cfg import py2cfg 4 | import os 5 | import torch 6 | from torch import nn 7 | import cv2 8 | import numpy as np 9 | import argparse 10 | from pathlib import Path 11 | from tools.metric import Evaluator 12 | from pytorch_lightning.loggers import CSVLogger 13 | import random 14 | 15 | 16 | def seed_everything(seed): 17 | random.seed(seed) 18 | os.environ['PYTHONHASHSEED'] = str(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = True 24 | 25 | 26 | def get_args(): 27 | parser = argparse.ArgumentParser() 28 | arg = parser.add_argument 29 | arg("-c", "--config_path", type=Path, help="Path to the config.", required=True) 30 | return parser.parse_args() 31 | 32 | 33 | class Supervision_Train(pl.LightningModule): 34 | def __init__(self, config): 35 | super().__init__() 36 | self.config = config 37 | self.net = config.net 38 | self.automatic_optimization = False 39 | 40 | self.loss = config.loss 41 | 42 | self.metrics_train = Evaluator(num_class=config.num_classes) 43 | self.metrics_val = Evaluator(num_class=config.num_classes) 44 | 45 | def forward(self, x): 46 | # only net is used in the prediction/inference 47 | seg_pre = self.net(x) 48 | return seg_pre 49 | 50 | def training_step(self, batch, batch_idx): 51 | img, mask = batch['img'], batch['gt_semantic_seg'] 52 | 53 | prediction = self.net(img) 54 | loss = self.loss(prediction, mask) 55 | 56 | if self.config.use_aux_loss: 57 | pre_mask = nn.Softmax(dim=1)(prediction[0]) 58 | else: 59 | pre_mask = nn.Softmax(dim=1)(prediction) 60 | 61 | pre_mask = pre_mask.argmax(dim=1) 62 | for i in range(mask.shape[0]): 63 | self.metrics_train.add_batch(mask[i].cpu().numpy(), pre_mask[i].cpu().numpy()) 64 | 65 | # supervision stage 66 | opt = self.optimizers(use_pl_optimizer=False) 67 | self.manual_backward(loss) 68 | if (batch_idx + 1) % self.config.accumulate_n == 0: 69 | opt.step() 70 | opt.zero_grad() 71 | 72 | sch = self.lr_schedulers() 73 | if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % 1 == 0: 74 | sch.step() 75 | 76 | return {"loss": loss} 77 | 78 | def training_epoch_end(self, outputs): 79 | if 'vaihingen' in self.config.log_name: 80 | mIoU = np.nanmean(self.metrics_train.Intersection_over_Union()[:-1]) 81 | F1 = np.nanmean(self.metrics_train.F1()[:-1]) 82 | elif 'potsdam' in self.config.log_name: 83 | mIoU = np.nanmean(self.metrics_train.Intersection_over_Union()[:-1]) 84 | F1 = np.nanmean(self.metrics_train.F1()[:-1]) 85 | elif 'whubuilding' in self.config.log_name: 86 | mIoU = np.nanmean(self.metrics_train.Intersection_over_Union()[:-1]) 87 | F1 = np.nanmean(self.metrics_train.F1()[:-1]) 88 | elif 'massbuilding' in self.config.log_name: 89 | mIoU = np.nanmean(self.metrics_train.Intersection_over_Union()[:-1]) 90 | F1 = np.nanmean(self.metrics_train.F1()[:-1]) 91 | elif 'inriabuilding' in self.config.log_name: 92 | mIoU = np.nanmean(self.metrics_train.Intersection_over_Union()[:-1]) 93 | F1 = np.nanmean(self.metrics_train.F1()[:-1]) 94 | else: 95 | mIoU = np.nanmean(self.metrics_train.Intersection_over_Union()) 96 | F1 = np.nanmean(self.metrics_train.F1()) 97 | 98 | OA = np.nanmean(self.metrics_train.OA()) 99 | iou_per_class = self.metrics_train.Intersection_over_Union() 100 | eval_value = {'mIoU': mIoU, 101 | 'F1': F1, 102 | 'OA': OA} 103 | print('train:', eval_value) 104 | 105 | iou_value = {} 106 | for class_name, iou in zip(self.config.classes, iou_per_class): 107 | iou_value[class_name] = iou 108 | print(iou_value) 109 | self.metrics_train.reset() 110 | loss = torch.stack([x["loss"] for x in outputs]).mean() 111 | log_dict = {"train_loss": loss, 'train_mIoU': mIoU, 'train_F1': F1, 'train_OA': OA} 112 | self.log_dict(log_dict, prog_bar=True) 113 | 114 | def validation_step(self, batch, batch_idx): 115 | img, mask = batch['img'], batch['gt_semantic_seg'] 116 | prediction = self.forward(img) 117 | pre_mask = nn.Softmax(dim=1)(prediction) 118 | pre_mask = pre_mask.argmax(dim=1) 119 | for i in range(mask.shape[0]): 120 | self.metrics_val.add_batch(mask[i].cpu().numpy(), pre_mask[i].cpu().numpy()) 121 | 122 | loss_val = self.loss(prediction, mask) 123 | return {"loss_val": loss_val} 124 | 125 | def validation_epoch_end(self, outputs): 126 | if 'vaihingen' in self.config.log_name: 127 | mIoU = np.nanmean(self.metrics_val.Intersection_over_Union()[:-1]) 128 | F1 = np.nanmean(self.metrics_val.F1()[:-1]) 129 | elif 'potsdam' in self.config.log_name: 130 | mIoU = np.nanmean(self.metrics_val.Intersection_over_Union()[:-1]) 131 | F1 = np.nanmean(self.metrics_val.F1()[:-1]) 132 | elif 'whubuilding' in self.config.log_name: 133 | mIoU = np.nanmean(self.metrics_val.Intersection_over_Union()[:-1]) 134 | F1 = np.nanmean(self.metrics_val.F1()[:-1]) 135 | elif 'massbuilding' in self.config.log_name: 136 | mIoU = np.nanmean(self.metrics_val.Intersection_over_Union()[:-1]) 137 | F1 = np.nanmean(self.metrics_val.F1()[:-1]) 138 | elif 'inriabuilding' in self.config.log_name: 139 | mIoU = np.nanmean(self.metrics_val.Intersection_over_Union()[:-1]) 140 | F1 = np.nanmean(self.metrics_val.F1()[:-1]) 141 | else: 142 | mIoU = np.nanmean(self.metrics_val.Intersection_over_Union()) 143 | F1 = np.nanmean(self.metrics_val.F1()) 144 | 145 | OA = np.nanmean(self.metrics_val.OA()) 146 | iou_per_class = self.metrics_val.Intersection_over_Union() 147 | 148 | eval_value = {'mIoU': mIoU, 149 | 'F1': F1, 150 | 'OA': OA} 151 | print('val:', eval_value) 152 | iou_value = {} 153 | for class_name, iou in zip(self.config.classes, iou_per_class): 154 | iou_value[class_name] = iou 155 | print(iou_value) 156 | 157 | self.metrics_val.reset() 158 | loss = torch.stack([x["loss_val"] for x in outputs]).mean() 159 | log_dict = {"val_loss": loss, 'val_mIoU': mIoU, 'val_F1': F1, 'val_OA': OA} 160 | self.log_dict(log_dict, prog_bar=True) 161 | 162 | def configure_optimizers(self): 163 | optimizer = self.config.optimizer 164 | lr_scheduler = self.config.lr_scheduler 165 | 166 | return [optimizer], [lr_scheduler] 167 | 168 | def train_dataloader(self): 169 | 170 | return self.config.train_loader 171 | 172 | def val_dataloader(self): 173 | 174 | return self.config.val_loader 175 | 176 | 177 | # training 178 | def main(): 179 | args = get_args() 180 | config = py2cfg(args.config_path) 181 | seed_everything(42) 182 | 183 | checkpoint_callback = ModelCheckpoint(save_top_k=config.save_top_k, monitor=config.monitor, 184 | save_last=config.save_last, mode=config.monitor_mode, 185 | dirpath=config.weights_path, 186 | filename=config.weights_name) 187 | logger = CSVLogger('lightning_logs', name=config.log_name) 188 | 189 | model = Supervision_Train(config) 190 | if config.pretrained_ckpt_path: 191 | model = Supervision_Train.load_from_checkpoint(config.pretrained_ckpt_path, config=config) 192 | 193 | trainer = pl.Trainer(devices=config.gpus, max_epochs=config.max_epoch, accelerator='gpu', 194 | check_val_every_n_epoch=config.check_val_every_n_epoch, 195 | callbacks=[checkpoint_callback], strategy=config.strategy, 196 | resume_from_checkpoint=config.resume_ckpt_path, logger=logger) 197 | trainer.fit(model=model) 198 | 199 | 200 | if __name__ == "__main__": 201 | main() 202 | --------------------------------------------------------------------------------