├── .gitignore ├── README.md ├── checkpoints └── unet-augment-final.pth ├── data_utils ├── __init__.py ├── augment.py ├── colors.py ├── common.py ├── create_dataset.py └── preprocessing.py ├── docs ├── README.md └── images │ ├── 0_0-0_0-blur.png │ ├── 0_0-0_0-bright.png │ ├── 0_0-0_0-crop-resize.png │ ├── 0_0-0_0-distort-rt.png │ ├── 0_0-0_0-distort.png │ ├── 0_0-0_0-gauss-rot.png │ ├── 0_0-0_0-gauss.png │ ├── 0_0-0_0-hsv.png │ ├── 0_0-0_0-med-blur.png │ ├── 0_0-0_0-mirror.png │ ├── 0_0-0_0-rt-inv.png │ └── 0_0-0_0.png ├── example.ipynb ├── models ├── __init__.py └── unet.py ├── predict.py ├── requirements.txt ├── run_training.py └── training ├── __init__.py ├── dataset.py ├── loss.py └── training_loop.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Node artifact files 2 | node_modules/ 3 | dist/ 4 | 5 | # Compiled Java class files 6 | *.class 7 | 8 | # Compiled Python bytecode 9 | *.py[cod] 10 | 11 | # Log files 12 | *.log 13 | 14 | # Package files 15 | *.jar 16 | 17 | # Maven 18 | target/ 19 | dist/ 20 | 21 | # JetBrains IDE 22 | .idea/ 23 | 24 | # Unit test reports 25 | TEST*.xml 26 | 27 | # Generated by MacOS 28 | .DS_Store 29 | 30 | # Generated by Windows 31 | Thumbs.db 32 | 33 | # Applications 34 | *.app 35 | *.exe 36 | *.war 37 | 38 | # Large media files 39 | *.mp4 40 | *.tiff 41 | *.avi 42 | *.flv 43 | *.mov 44 | *.wmv 45 | 46 | # generated files 47 | data 48 | *.pth 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Satellite Image Segmentation Using PyTorch # 2 | 3 | This repo contains a U-Net implementation for satellite segmentation 4 | 5 | 6 | ## Data preparation ## 7 | 8 | Due to a severe lack of training data, several pre-processing steps are taken to try and 9 | alleviate this. First, the annotations are converted from their json files to image masks. Then, the satellite images are further tiled 10 | down, to tiles of size `(512x512)`, so that the images can be fed into a fully convolutional network (**FCN**) for semantic segmentation. 
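The tiling itself is handled by `data_utils.common.tile_image`, which walks the image in 512-pixel strides and shifts the last row/column back to the image border so every tile stays exactly `512x512`. A minimal sketch of just that origin computation is shown below; the `tile_positions` helper is hypothetical and only illustrates the idea, whereas the repo function also crops each tile and writes it to disk as a PNG.

```python
# Sketch (not part of the repo): compute the top-left origins of 512x512 tiles covering an image,
# anchoring the last row/column to the image border so edge tiles stay full-sized.
# Assumes the image is at least `size` pixels along each axis.
def tile_positions(height: int, width: int, size: int = 512):
    ys = list(range(0, height - size + 1, size))
    xs = list(range(0, width - size + 1, size))
    if height % size:  # remainder -> add a final tile flush with the bottom edge
        ys.append(height - size)
    if width % size:   # remainder -> add a final tile flush with the right edge
        xs.append(width - size)
    return [(y, x) for y in ys for x in xs]


# e.g. a 1300x1300 image yields origins 0, 512 and 788 along each axis,
# so the edge tiles overlap their neighbours slightly instead of being padded.
```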
11 | The augmentations are: 12 | - `blur`: combination of median and bilateral blur 13 | - `bright increase`: increase brightness artificially 14 | - `distort`: elastic deformation of image 15 | - `gaussian blur`: gaussian blurring 16 | - `HSV`: convert channels to HSV 17 | - `median blur`: median blur 18 | - `mirror`: mirror image 19 | - `rotation invariance`: apply rotation invariance (see `data_utils.augment.rotation_invariance()` for details) 20 | - `crop + resize`: crop image randomly and resize to size expected by network (crop is applied to both image and mask) 21 | 22 | ### sample of applied augmentations ### 23 | 24 | blur | bright increase | distort | gaussian blur | 25 | :-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: 26 | ![](docs/images/0_0-0_0-blur.png) | ![](docs/images/0_0-0_0-bright.png) | ![](docs/images/0_0-0_0-distort.png) | ![](docs/images/0_0-0_0-gauss.png) | 27 | 28 | 29 | | HSV shift | median blur | mirror | gaussian blur + rotation 30 | :-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: 31 | ![](docs/images/0_0-0_0-hsv.png) | ![](docs/images/0_0-0_0-med-blur.png) | ![](docs/images/0_0-0_0-mirror.png) | ![](docs/images/0_0-0_0-gauss-rot.png) 32 | 33 | | distort + rot | rotation invariance | crop + resize 34 | :-------------------------:|:-------------------------:|:-------------------------: 35 | ![](docs/images/0_0-0_0-distort-rt.png) | ![](docs/images/0_0-0_0-rt-inv.png) | ![](docs/images/0_0-0_0-crop-resize.png) 36 | 37 | 38 | ## Network & Training ## 39 | 40 | [U-Net](https://arxiv.org/abs/1505.04597) is used as a base model for segmentation. The original intention was to use U-Net to show base results, 41 | and then train **PSPNet** ([Pyramid Scene Parsing Network](https://arxiv.org/abs/1612.01105)) using a pretrained satellite segmentation model and show comparisons, but time did not allow for this. 42 | 43 | ### training parameters ### 44 | 45 | U-Net was trained for ~~50~~ 43 epochs, with a batch size of 4. 46 | 47 | The base parameters for training can be seen, and adjusted, in `run_training.py`: 48 | 49 | 50 | ```python 51 | def setup_run_arguments(): 52 | args = EasyDict() 53 | args.epochs = 50 54 | args.batch = 4 55 | args.val_percent = 0.2 56 | args.n_classes = 4 57 | args.n_channels = 3 58 | args.num_workers = 8 59 | 60 | args.learning_rate = 0.001 61 | args.weight_decay = 1e-8 62 | args.momentum = 0.9 63 | args.save_cp = True 64 | args.loss = "CrossEntropy" 65 | ``` 66 | 67 | The trained model is provided in `checkpoints/unet-augment-final.pth`. 68 | 69 | 70 | ## Generating json annotations from U-Net predictions ## 71 | 72 | The expected output is an annotated json file containing the vector points corresponding to the classes. 73 | A function for generating such a file is found in `predict.prediction_to_json(...)`. A Python notebook 74 | is provided showing how to generate the json file, as well as how to generate a color mask file from the json file. 75 | 76 | The example: [example.ipynb](https://github.com/obravo7/satellite-segmentation-pytorch/blob/master/example.ipynb) shows how 77 | to load the trained model and use it to create the annotation file. It also shows how to create a colored image mask directly from the 78 | annotation. (The example notebook was created with U-Net after 35 epochs.) 79 | 80 | 81 | # Discussion # 82 | 83 | There are several improvements that can be made.
For starters, some of the augmentation methods used could be replaced 84 | or left out entirely, such as HSV; a simple color shift could have been used instead. Another major issue that should 85 | have been addressed at the beginning was class imbalance. It would have been better to apply augmentations with respect to 86 | the class frequency, in order to reduce the imbalance between classes. 87 | 88 | Another obvious issue is that U-Net is a network originally catered to medical image segmentation, but it is often used as a baseline model 89 | because it is small and easy to implement. A more suitable network would have been PSPNet, as mentioned above. Similarly, 90 | there exist several pretrained models that could have been used with transfer learning. This, coupled with more meaningful augmentations, 91 | would likely have yielded a better model and better results. 92 | -------------------------------------------------------------------------------- /checkpoints/unet-augment-final.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/checkpoints/unet-augment-final.pth -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/data_utils/__init__.py -------------------------------------------------------------------------------- /data_utils/augment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | from typing import Tuple 6 | 7 | 8 | def gaussBlur(image, filter_size=15): 9 | """ 10 | gaussian blur 11 | """ 12 | blur = cv2.GaussianBlur(image, (filter_size, filter_size), 0) 13 | return blur 14 | 15 | 16 | def medBlur(image, filter_size=5): 17 | blur = cv2.medianBlur(image, filter_size) 18 | return blur 19 | 20 | 21 | def bilateralBlur(image): 22 | blur = cv2.bilateralFilter(image, 9, 75, 75) 23 | return blur 24 | 25 | 26 | def totalBlur(image, filter_size=11): 27 | mb = medBlur(image, filter_size) 28 | total = bilateralBlur(mb) 29 | return total 30 | 31 | 32 | def distort_elastic_cv2(image, alpha=40, sigma=3, random_state=None): 33 | """ 34 | Elastic deformation of images as described in [Simard2003]_ (with modifications). 35 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 36 | Convolutional Neural Networks applied to Visual Document Analysis", in 37 | Proc. of the International Conference on Document Analysis and 38 | Recognition, 2003.
39 | """ 40 | 41 | if random_state is None: 42 | random_state = np.random.RandomState(None) 43 | 44 | shape_size = image.shape[:2] 45 | 46 | # Downscale the random grid and then upsizing post filter 47 | # improve performance 48 | 49 | grid_scale = 4 50 | alpha //= grid_scale 51 | sigma //= grid_scale 52 | grid_shape = (shape_size[0] // grid_scale, shape_size[1] // grid_scale) 53 | 54 | blur_size = int(4 * sigma) | 1 55 | rand_x = cv2.GaussianBlur( 56 | (random_state.rand(*grid_shape) * 2 - 1).astype(np.float32), 57 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 58 | 59 | rand_y = cv2.GaussianBlur( 60 | (random_state.rand(*grid_shape) * 2 - 1).astype(np.float32), 61 | ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 62 | 63 | if grid_scale > 1: 64 | rand_x = cv2.resize(rand_x, shape_size[::-1]) 65 | rand_y = cv2.resize(rand_y, shape_size[::-1]) 66 | 67 | grid_x, grid_y = np.meshgrid(np.arange(shape_size[1]), np.arange(shape_size[0])) 68 | grid_x = (grid_x + rand_x).astype(np.float32) 69 | grid_y = (grid_y + rand_y).astype(np.float32) 70 | 71 | distorted_img = cv2.remap(image, grid_x, grid_y, 72 | borderMode=cv2.BORDER_REFLECT_101, interpolation=cv2.INTER_LINEAR) 73 | 74 | return distorted_img 75 | 76 | 77 | def rotation_invariance(img): 78 | """ 79 | rotate and shift images randomly 80 | """ 81 | pts1 = np.float32([[50, 50], [200, 50], [50, 200]]) 82 | pts2 = np.float32([[10, 100], [200, 50], [100, 250]]) 83 | rows = img.shape[0] 84 | cols = img.shape[1] 85 | M = cv2.getAffineTransform(pts1, pts2) 86 | dst = cv2.warpAffine(img, M, (rows, cols)) 87 | return dst 88 | 89 | 90 | def speckle_noise(img): 91 | """ 92 | add multiplicative speckle noise 93 | used for radar images 94 | """ 95 | row, col, ch = img.shape 96 | gauss = np.random.randn(row, col, ch) 97 | gauss = gauss.reshape(row, col, ch) 98 | noisy = img * (gauss / (len(gauss) - 0.50 * len(gauss))) 99 | 100 | return noisy 101 | 102 | 103 | def salt_pepper_noise(img, prob): 104 | """ 105 | salt and pepper noise 106 | prob: probability of noise 107 | """ 108 | output = np.zeros(img.shape, np.uint8) 109 | thresh = 1 - prob 110 | 111 | for i in range(img.shape[0]): 112 | for j in range(img.shape[1]): 113 | rdn = random.random() 114 | if rdn < prob: 115 | output[i][j] = 0 116 | elif rdn > thresh: 117 | output[i][j] = 255 118 | else: 119 | output[i][j] = img[i][j] 120 | return output 121 | 122 | 123 | def mirror(img): 124 | """ 125 | horizontal mirror of image 126 | """ 127 | mirror = cv2.flip(img, +1) 128 | return mirror 129 | 130 | 131 | def rotate(img, angle): 132 | """ 133 | rotate image 90 degrees 134 | """ 135 | (h, w) = img.shape[:2] 136 | center = (w / 2, h / 2) 137 | scale = 1.0 138 | 139 | M = cv2.getRotationMatrix2D(center, angle, scale) 140 | rotated = cv2.warpAffine(img, M, (h, w)) 141 | return rotated 142 | 143 | 144 | def to_hsv(image): 145 | hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 146 | return hsv_image 147 | 148 | 149 | def increase_brightness(img, value=30): 150 | hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 151 | h, s, v = cv2.split(hsv) 152 | 153 | lim = 255 - value 154 | v[v > lim] = 255 155 | v[v <= lim] += value 156 | 157 | final_hsv = cv2.merge((h, s, v)) 158 | img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR) 159 | return img 160 | 161 | 162 | def random_crop_and_resize(img, mask, size=(512, 512)) -> Tuple[Image.Image, Image.Image]: 163 | height, width = 256, 256 164 | x = random.randint(0, img.shape[1] - width) 165 | y = random.randint(0, img.shape[0] - height) 166 | img = img[y:y+height, 
x:x+width] 167 | mask = mask[y:y+height, x:x+width] 168 | img = Image.fromarray(img, "RGB").resize(size, Image.ANTIALIAS) 169 | mask = Image.fromarray(mask, 'L').resize(size, Image.NEAREST) # nearest-neighbor resampling keeps the label values intact 170 | return img, mask 171 | 172 | -------------------------------------------------------------------------------- /data_utils/colors.py: -------------------------------------------------------------------------------- 1 | 2 | def hex_to_rgb(h: str) -> tuple: 3 | 4 | h = h.lstrip('#') if '#' in h else h 5 | 6 | return tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) 7 | -------------------------------------------------------------------------------- /data_utils/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from PIL import Image 4 | import json 5 | from typing import Any 6 | 7 | 8 | def create_mask(mask: np.ndarray, label_data: dict, pixel_labels: dict) -> np.ndarray: 9 | 10 | for label in label_data: 11 | 12 | hex_label = label['color'] 13 | # category = label['name'] 14 | color = pixel_labels[hex_label] 15 | 16 | for vector_points in label['annotations']: 17 | 18 | x_values = [i for i in vector_points['segmentation'][::2]] 19 | y_values = [i for i in vector_points['segmentation'][1::2]] 20 | contours = np.array([(x, y) for x, y in zip(x_values, y_values)]) 21 | mask = cv2.drawContours(mask, [contours.astype(int)], -1, color, -1) 22 | 23 | return mask 24 | 25 | 26 | def tile_image(image: np.ndarray, save_path, size=512) -> None: 27 | if len(image.shape) == 2: 28 | image = image[..., np.newaxis] # (h, w) -> (h, w, 1) 29 | height, width, channels = image.shape 30 | 31 | h_stride, h_diff = divmod(height, size) 32 | w_stride, w_diff = divmod(width, size) 33 | 34 | h_stride = h_stride + 1 if h_diff > 0 else h_stride 35 | w_stride = w_stride + 1 if w_diff > 0 else w_stride 36 | 37 | for h in range(h_stride): 38 | for w in range(w_stride): 39 | tile = image[ 40 | size * h: size + (size * h), # tile height of image 41 | size * w: size + (size * w), # tile width of image 42 | :] 43 | 44 | if tile.shape[0] != size and tile.shape[1] != size: 45 | # adjust both height and width 46 | # shift according to height and width difference 47 | tile = image[ 48 | (size * h) - (size - h_diff): (size + (size * h)) - (size - h_diff), 49 | (size * w) - (size - w_diff):(size + (size * w)) - (size - w_diff), 50 | :] 51 | 52 | elif tile.shape[1] != size: 53 | # adjust width 54 | tile = image[ 55 | size * h: size + (size * h), 56 | # shift back according to w_diff 57 | (size * w) - (size - w_diff):(size + (size * w)) - (size - w_diff), 58 | :] 59 | 60 | elif tile.shape[0] != size: # height 61 | 62 | # adjust height 63 | tile = image[ 64 | # shift back according to height difference 65 | (size * h) - (size - h_diff): (size + (size * h)) - (size - h_diff), 66 | size * w: size + (size * w), # stride width of image 67 | :] 68 | 69 | else: 70 | # normal sequence; do nothing 71 | pass 72 | 73 | if channels == 1: 74 | Image.fromarray(tile.reshape((size, size)), mode='L').save(f'{save_path}-{h}_{w}.png') 75 | else: 76 | Image.fromarray(tile, mode='RGB').save(f'{save_path}-{h}_{w}.png') 77 | 78 | 79 | def load_json(json_path: str): 80 | with open(json_path, 'rb') as f: 81 | data = json.load(f) 82 | return data 83 | 84 | 85 | class EasyDict(dict): 86 | # adapted from https://github.com/NVlabs/stylegan2-ada/blob/main/dnnlib/util.py 87 | 88 | def __getattr__(self, name: str) -> Any: 89 | try: 90 | return self[name] 91 | except KeyError: 92 | raise
AttributeError(name) 93 | 94 | def __setattr__(self, name: str, value: Any) -> None: 95 | self[name] = value 96 | 97 | def __delattr__(self, name: str) -> None: 98 | del self[name] 99 | -------------------------------------------------------------------------------- /data_utils/create_dataset.py: -------------------------------------------------------------------------------- 1 | """script used to create base datset""" 2 | 3 | import os 4 | import glob 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from data_utils import common, colors, augment 9 | 10 | # data paths 11 | data_path = 'data/colors' 12 | mask_path = 'data/masks' 13 | raw_path = 'data/raw' 14 | color_mask_path = 'data/colors' 15 | mask_save_path = 'data/train/masks' 16 | image_save_path = 'data/train/images' 17 | color_save_path = 'data/train/color' 18 | 19 | 20 | colors_from_hex = { 21 | "#ff0000": colors.hex_to_rgb('#ff0000'), 22 | "#0037ff": colors.hex_to_rgb('#0037ff'), 23 | '#f900ff': colors.hex_to_rgb('#f900ff') 24 | } 25 | 26 | 27 | train_labels = { 28 | "#ff0000": (1, 1, 1), # houses 29 | "#0037ff": (2, 2, 2), # buildings 30 | '#f900ff': (3, 3, 3) # Sheds/Garages 31 | } 32 | 33 | 34 | def create_base_dataset(): 35 | 36 | os.makedirs(data_path, exist_ok=True) 37 | os.makedirs(mask_path, exist_ok=True) 38 | json_paths = glob.glob('data/annotations/*.json') 39 | for json_path in json_paths: 40 | 41 | annotation_data = common.load_json(json_path) 42 | label_data = annotation_data['labels'] 43 | 44 | # create mask data 45 | height, width = annotation_data['height'], annotation_data['width'] 46 | mask = np.ones((height, width, 3), dtype=np.uint8) * 255 47 | train_mask = np.zeros((height, width), dtype=np.uint8) # background is a class, category 0 48 | 49 | mask = common.create_mask(mask, label_data=label_data, pixel_labels=colors_from_hex) 50 | train_mask = common.create_mask(train_mask, label_data=label_data, pixel_labels=train_labels) 51 | 52 | file_name = os.path.basename(json_path).split('.')[0] 53 | Image.fromarray(mask, mode="RGB").save(os.path.join(data_path, f"{file_name}.png")) 54 | 55 | print(f"{train_mask.shape}") 56 | Image.fromarray(train_mask, mode="L").save(os.path.join(mask_path, f"{file_name}.png")) 57 | 58 | 59 | def create_train_dataset(): 60 | 61 | mask_image_list = glob.glob(os.path.join(mask_path, '*.png')) 62 | 63 | # There are 9 tiles that are left out as evaluation (testing) 64 | # diff = list(set(raw_base_names) - set(mask_base_names)) 65 | os.makedirs(image_save_path, exist_ok=True) 66 | os.makedirs(mask_save_path, exist_ok=True) 67 | os.makedirs(color_save_path, exist_ok=True) 68 | 69 | i = 1 70 | for mask_file_path in mask_image_list: 71 | file_name = os.path.basename(mask_file_path) 72 | img_file_path = os.path.join(raw_path, file_name) 73 | color_msk_file_path = os.path.join(color_mask_path, file_name) 74 | 75 | img = np.array(Image.open(img_file_path)) 76 | mask = np.array(Image.open(mask_file_path)) 77 | color = np.array(Image.open(color_msk_file_path)) 78 | 79 | common.tile_image(image=img, 80 | save_path=os.path.join(image_save_path, file_name.split('.')[0]) 81 | ) 82 | common.tile_image(image=mask, 83 | save_path=os.path.join(mask_save_path, file_name.split('.')[0]) 84 | ) 85 | common.tile_image(image=color, 86 | save_path=os.path.join(color_save_path, file_name.split('.')[0]) 87 | ) 88 | print(f'files completed: \t{i}/{len(mask_image_list)}...', end='\r') 89 | i += 1 90 | 91 | 92 | def augment_train_data(): 93 | i = 1 94 | base_names = [os.path.basename(fn) for fn in 
glob.glob(os.path.join(mask_save_path, "*.png"))] 95 | for fn in base_names: 96 | msk_fp = os.path.join(mask_save_path, fn) 97 | img_fp = os.path.join(image_save_path, fn) 98 | color_fp = os.path.join(color_save_path, fn) 99 | 100 | msk = np.array(Image.open(msk_fp)) 101 | img = np.array(Image.open(img_fp)) 102 | color = np.array(Image.open(color_fp)) 103 | 104 | # total blur 105 | Image.fromarray(augment.totalBlur(img), mode='RGB').save( 106 | os.path.join(image_save_path, f'{fn.split(".")[0]}-blur.png')) 107 | Image.fromarray(msk, mode='L').save( 108 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-blur.png')) 109 | Image.fromarray(color, mode='RGB').save( 110 | os.path.join(color_save_path, f'{fn.split(".")[0]}-blur.png')) 111 | 112 | # distort elastic 113 | Image.fromarray(augment.distort_elastic_cv2(img), mode='RGB').save( 114 | os.path.join(image_save_path, f'{fn.split(".")[0]}-distort.png')) 115 | Image.fromarray(msk, mode='L').save( 116 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-distort.png')) 117 | Image.fromarray(color, mode='RGB').save( 118 | os.path.join(color_save_path, f'{fn.split(".")[0]}-distort.png')) 119 | 120 | # mirror 121 | Image.fromarray(augment.mirror(img), mode='RGB').save( 122 | os.path.join(image_save_path, f'{fn.split(".")[0]}-mirror.png')) 123 | Image.fromarray(augment.mirror(msk), mode='L').save( 124 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-mirror.png')) 125 | Image.fromarray(augment.mirror(color), mode='RGB').save( 126 | os.path.join(color_save_path, f'{fn.split(".")[0]}-mirror.png')) 127 | 128 | # rotation invariance 129 | Image.fromarray(augment.rotation_invariance(img), mode='RGB').save( 130 | os.path.join(image_save_path, f'{fn.split(".")[0]}-rt-inv.png')) 131 | Image.fromarray(augment.rotation_invariance(msk), mode='L').save( 132 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-rt-inv.png')) 133 | Image.fromarray(augment.rotation_invariance(color), mode='RGB').save( 134 | os.path.join(color_save_path, f'{fn.split(".")[0]}-rt-inv.png')) 135 | 136 | # gaussian blur 137 | Image.fromarray(augment.gaussBlur(img), mode='RGB').save( 138 | os.path.join(image_save_path, f'{fn.split(".")[0]}-gauss.png')) 139 | Image.fromarray(msk, mode='L').save( 140 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-gauss.png')) 141 | Image.fromarray(color, mode='RGB').save( 142 | os.path.join(color_save_path, f'{fn.split(".")[0]}-gauss.png')) 143 | 144 | # distort elastic + rotation 145 | rot = 90 146 | Image.fromarray(augment.distort_elastic_cv2(augment.rotate(img, rot)), mode='RGB').save( 147 | os.path.join(image_save_path, f'{fn.split(".")[0]}-distort-rt.png')) 148 | Image.fromarray(augment.rotate(msk, rot), mode='L').save( 149 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-distort-rt.png')) 150 | Image.fromarray(augment.rotate(color, rot), mode='RGB').save( 151 | os.path.join(color_save_path, f'{fn.split(".")[0]}-distort-rt.png')) 152 | 153 | # gaussian blur + rot 154 | rot = 270 155 | Image.fromarray(augment.gaussBlur(augment.rotate(img, rot)), mode='RGB').save( 156 | os.path.join(image_save_path, f'{fn.split(".")[0]}-gauss-rot.png')) 157 | Image.fromarray(augment.rotate(msk, rot), mode='L').save( 158 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-gauss-rot.png')) 159 | Image.fromarray(augment.rotate(color, rot), mode='RGB').save( 160 | os.path.join(color_save_path, f'{fn.split(".")[0]}-gauss-rot.png')) 161 | 162 | # HSV 163 | Image.fromarray(augment.to_hsv(img), mode='RGB').save( 164 | os.path.join(image_save_path, 
f'{fn.split(".")[0]}-hsv.png')) 165 | Image.fromarray(msk, mode='L').save( 166 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-hsv.png')) 167 | Image.fromarray(color, mode='RGB').save( 168 | os.path.join(color_save_path, f'{fn.split(".")[0]}-hsv.png')) 169 | 170 | # increase brightness 171 | Image.fromarray(augment.increase_brightness(img), mode='RGB').save( 172 | os.path.join(image_save_path, f'{fn.split(".")[0]}-bright.png')) 173 | Image.fromarray(msk, mode='L').save( 174 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-bright.png')) 175 | Image.fromarray(color, mode='RGB').save( 176 | os.path.join(color_save_path, f'{fn.split(".")[0]}-bright.png')) 177 | 178 | # median blur 179 | Image.fromarray(augment.medBlur(img), mode='RGB').save( 180 | os.path.join(image_save_path, f'{fn.split(".")[0]}-med-blur.png')) 181 | Image.fromarray(msk, mode='L').save( 182 | os.path.join(mask_save_path, f'{fn.split(".")[0]}-med-blur.png')) 183 | Image.fromarray(color, mode='RGB').save( 184 | os.path.join(color_save_path, f'{fn.split(".")[0]}-med-blur.png')) 185 | 186 | # crop and resize 187 | # img_cr, mask_cr = augment.random_crop_and_resize(img, msk) 188 | # img_cr.save(os.path.join(image_save_path, f'{fn.split(".")[0]}-crop-resize.png')) 189 | # mask_cr.save(os.path.join(mask_save_path, f'{fn.split(".")[0]}-crop-resize.png')) 190 | # Image.fromarray(color, mode='RGB').save( 191 | # os.path.join(color_save_path, f'{fn.split(".")[0]}-crop-resize.png')) 192 | 193 | -------------------------------------------------------------------------------- /data_utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def minmax_normalize(img, norm_range=(0, 1), orig_range=(0, 255)): 5 | # range(0, 1) 6 | norm_img = (img - orig_range[0]) / (orig_range[1] - orig_range[0]) 7 | # range(min_value, max_value) 8 | norm_img = norm_img * (norm_range[1] - norm_range[0]) + norm_range[0] 9 | return norm_img 10 | 11 | 12 | def meanstd_normalize(img, mean, std): 13 | mean = np.asarray(mean) 14 | std = np.asarray(std) 15 | norm_img = (img - mean) / std 16 | return norm_img 17 | 18 | 19 | def preprocess(img): 20 | 21 | image_array = np.array(img) 22 | if len(image_array.shape) == 2: 23 | image_array = image_array[..., np.newaxis] # (height, width) -> (height, width, 1) 24 | 25 | # HWC -> CHW 26 | image_trans = image_array.transpose((2, 0, 1)) 27 | 28 | if image_trans.max() > 1: 29 | image_trans = image_trans / 255 30 | 31 | return image_trans 32 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Introduction 4 | 5 | The following task seeks to emulate some real world scenario that you might encounter at Incubit. The goal is to create a model that takes an image as input and outputs vectorized segmentations of instances of objects inside the image. 6 | 7 | ## Problem description 8 | 9 | You are given a small slice of satellite data of a city in Japan by a fantasy land surveying company. They want to evaluate the feasibility of using satellite data to augment some downstream tasks. To that end, they need to segment out individual buildings, by category. They also need to return the segmentation in vector format so that they can import it into their CAD software. 10 | 11 | They had some of their interns annotate some of the data and would like you to have a go at it. 
They're not data scientists, so the data they provided might not be optimal and their annotations not entirely consistent. But it is what it is, and you have to make do with it. 12 | 13 | 14 | ## The Data: 15 | The data consists of one single satellite picture of Tokyo split into a 9x9 grid (81 non-overlapping PNG images total). The naming convention reflects this. You can put it back together if you want to. 16 | 17 | In addition, you receive 72 annotation data files containing target labels for individual images. The remaining 9 images, forming the bottom right corner of the grid, are held back for evaluation. So no need to panic if you can't find them. 18 | 19 | The data is picked in a way that will keep your model training time manageable, while still being consistent enough to have a good chance of yielding reasonable output on the test data. 20 | 21 | ## The Labels: 22 | The target data consists of three labels that are of interest in the image: 23 | 1. Houses 24 | 2. Buildings 25 | 3. Sheds/Garages 26 | 27 | The labels come from our internal annotation tools. The format doesn't follow any established academic data format, but it's pretty straightforward. 28 | 29 | * The label data is provided in json format, one json annotation file per image, named after the image it represents. 30 | 31 | * There is plenty of verbose metadata in the annotation files, produced by our annotation tools, that is not relevant to the task. Feel free to navigate around it. 32 | 33 | * The labels are provided as polygons under an [x,y,x,y,x,y....] format. Once the sequence is finished, the last point connects to the first point. There is no distinction between clockwise and counterclockwise. (A minimal rasterization sketch is included at the end of this README.) 34 | 35 | * One polygon defines the perimeter of a building unit. 36 | 37 | 38 | ## The Task 39 | 40 | Your mission is to create a model that can take an image as input and output the individual vector polygon detections of each building according to its class, in [x,y,x,y,x,y,....] format. 41 | 42 | You are free to use any pretrained models, any external data sources, any programming language, DL framework, or other resources that you see fit. 43 | 44 | What we'd like to see: 45 | * The code, preferably in github form 46 | * A report describing your problem analysis, approach, results, conclusions, hurdles, ideas. It can be in pdf, readme markdown, ipynb form, a mix thereof, or whatever you feel is good at conveying information. 47 | 48 | 49 | ## The output 50 | 51 | We would like your code to be able to output the results in the following json format: 52 | 53 | ``` 54 | {'filename': file_name, 55 | 'labels': [{'name': label_name, 'annotations': [{'id': some_unique_integer_id, 'segmentation': [x,y,x,y,x,y....]} 56 | ....] } 57 | ....] 58 | } 59 | ``` 60 | 61 | with .... standing for an indeterminate number of elements of the same type in the json structure. 62 | 63 | ## The Evaluation 64 | 65 | What we're looking for is: 66 | - Your analysis and understanding of the provided problem, challenge, data and results 67 | - Quality and creativity of the solution 68 | - Thought process and ability to convey it to us in words, tables, plots and/or visualizations 69 | - Readability and usability of the code 70 | 71 | ## The Timeline 72 | 73 | There is no hard deadline for the task, but you should not spend more than 2 weeks of occasionally working on it, fitting it around your schedule. Less if you have more free time and put in more concerted effort.
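For reference, turning one `[x, y, x, y, ...]` polygon into pixels only requires pairing up the coordinates and letting OpenCV fill the contour, which is essentially what `data_utils.common.create_mask` does for every label in an annotation file. A minimal sketch with a made-up polygon:

```python
import numpy as np
import cv2

# Made-up polygon in the flat [x, y, x, y, ...] format described above (illustration only).
segmentation = [10.0, 10.0, 60.0, 12.0, 58.0, 55.0, 12.0, 50.0]

# Pair the flat list into an (N, 2) integer contour; the last point implicitly closes back to the first.
points = np.array(segmentation, dtype=np.float32).reshape(-1, 2).astype(int)

# Fill the polygon on a blank single-channel mask
# (contourIdx=-1 draws all supplied contours, thickness=-1 fills them).
mask = np.zeros((64, 64), dtype=np.uint8)
cv2.drawContours(mask, [points], -1, 1, -1)
```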
74 | 75 | 76 | -------------------------------------------------------------------------------- /docs/images/0_0-0_0-blur.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-blur.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-bright.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-bright.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-crop-resize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-crop-resize.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-distort-rt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-distort-rt.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-distort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-distort.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-gauss-rot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-gauss-rot.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-gauss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-gauss.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-hsv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-hsv.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-med-blur.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-med-blur.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-mirror.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-mirror.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0-rt-inv.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0-rt-inv.png -------------------------------------------------------------------------------- /docs/images/0_0-0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/docs/images/0_0-0_0.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/models/__init__.py -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class DoubleConv(nn.Module): 7 | def __init__(self, in_channels, out_channels, mid_channels=None): 8 | super().__init__() 9 | if not mid_channels: 10 | mid_channels = out_channels 11 | self.double_conv = nn.Sequential( 12 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 13 | nn.BatchNorm2d(mid_channels), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(inplace=True) 18 | ) 19 | 20 | def forward(self, x): 21 | return self.double_conv(x) 22 | 23 | 24 | class ConvLayer(nn.Module): 25 | def __init__(self, in_channels, out_channels): 26 | super(ConvLayer, self).__init__() 27 | self.conv_layer = nn.Sequential( 28 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 29 | nn.BatchNorm2d(out_channels), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | def forward(self, x): 34 | return self.conv_layer(x) 35 | 36 | 37 | class Down(nn.Module): 38 | def __init__(self, in_channels, out_channels): 39 | super().__init__() 40 | self.maxpool_conv = nn.Sequential( 41 | nn.MaxPool2d(2), 42 | DoubleConv(in_channels, out_channels) 43 | ) 44 | 45 | def forward(self, x): 46 | return self.maxpool_conv(x) 47 | 48 | 49 | class Up(nn.Module): 50 | def __init__(self, in_channels, out_channels, bilinear=True): 51 | super().__init__() 52 | if bilinear: 53 | self.up = nn.Upsample( 54 | scale_factor=2, mode='bilinear', align_corners=True 55 | ) 56 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 57 | else: 58 | self.up = nn.ConvTranspose2d( 59 | in_channels, in_channels // 2, kernel_size=2, stride=2 60 | ) 61 | self.conv = DoubleConv(in_channels, out_channels) 62 | 63 | def forward(self, x1, x2): 64 | x1 = self.up(x1) 65 | diffY = x2.size()[2] - x1.size()[2] 66 | diffX = x2.size()[3] - x1.size()[3] 67 | 68 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 69 | diffY // 2, diffY - diffY // 2]) 70 | x = torch.cat([x2, x1], dim=1) 71 | return self.conv(x) 72 | 73 | 74 | class OutConv(nn.Module): 75 | def __init__(self, in_channels, out_channels): 76 | super(OutConv, self).__init__() 77 | self.conv = nn.Conv2d( 78 | in_channels, out_channels, kernel_size=1 79 | ) 80 | 81 | def forward(self, x): 82 | return self.conv(x) 83 | 84 | 85 | class UNet(nn.Module): 86 | def __init__(self, n_channels, n_classes, bilinear=True): 87 | super(UNet, self).__init__() 88 | 
self.n_channels = n_channels 89 | self.n_classes = n_classes 90 | self.bilinear = bilinear 91 | 92 | self.first_layer = DoubleConv(n_channels, 64) 93 | self.down1 = Down(64, 128) 94 | self.down2 = Down(128, 256) 95 | self.down3 = Down(256, 512) 96 | 97 | factor = 2 if bilinear else 1 98 | self.down4 = Down(512, 1024 // factor) 99 | 100 | self.up1 = Up(1024, 512 // factor, bilinear) 101 | self.up2 = Up(512, 256 // factor, bilinear) 102 | self.up3 = Up(256, 128 // factor, bilinear) 103 | self.up4 = Up(128, 64, bilinear) 104 | 105 | self.out = OutConv(64, self.n_classes) 106 | 107 | def forward(self, x): 108 | x1 = self.first_layer(x) 109 | x2 = self.down1(x1) 110 | x3 = self.down2(x2) 111 | x4 = self.down3(x3) 112 | x5 = self.down4(x4) 113 | x = self.up1(x5, x4) 114 | x = self.up2(x, x3) 115 | x = self.up3(x, x2) 116 | x = self.up4(x, x1) 117 | out = self.out(x) 118 | return out 119 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | import os 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torchvision import transforms 9 | 10 | from data_utils.preprocessing import preprocess 11 | from data_utils import colors 12 | from models.unet import UNet 13 | 14 | colors_from_hex = { 15 | "0": (255, 255, 255), # background 16 | "1": colors.hex_to_rgb('#ff0000'), 17 | "2": colors.hex_to_rgb('#0037ff'), 18 | "3": colors.hex_to_rgb('#f900ff') 19 | } 20 | 21 | hex_labels = { 22 | "0": 'none', 23 | "1": '#ff0000', 24 | "2": '#0037ff', 25 | "3": '#f900ff' 26 | 27 | } 28 | 29 | category_labels = { 30 | "0": 'none', 31 | "1": 'Houses', 32 | "2": 'Buildings', 33 | "3": 'Sheds/Garages' 34 | 35 | } 36 | 37 | 38 | def predict_on_image(net, src_img, device, thresh=0.6): 39 | net.eval() 40 | 41 | img = torch.from_numpy(preprocess(src_img)) # hack 42 | 43 | img = img.unsqueeze(0) 44 | img = img.to(device=device, dtype=torch.float32) 45 | 46 | with torch.no_grad(): 47 | out = net(img) # tensor: [1, n_classes, height, width] 48 | 49 | if net.n_classes > 1: 50 | probs = F.softmax(out, dim=1) 51 | else: 52 | probs = torch.sigmoid(out) 53 | 54 | probs = probs.squeeze(0) 55 | 56 | tf = transforms.Compose( 57 | [ 58 | transforms.ToPILImage(), 59 | transforms.ToTensor() 60 | ] 61 | ) 62 | 63 | probs = tf(probs.cpu()) 64 | mask = probs.squeeze().cpu().numpy() # (n_classes, height, width) 65 | 66 | return mask > thresh 67 | 68 | 69 | def decode_seg_map(image) -> np.ndarray: 70 | """decode generated segmentation map into 3 channel RGB image.""" 71 | 72 | h, w, n_labels = image.shape 73 | rgb_mask = np.ones((h, w, 3), dtype=np.uint8) * 255 74 | 75 | for label in range(1, n_labels): 76 | idx = np.where(image[:, :, label].astype(int) == 1) 77 | rgb_mask[idx] = colors_from_hex[str(label)] 78 | 79 | return rgb_mask 80 | 81 | 82 | def prediction_to_json(image_path, chkp_path, net=None) -> dict: 83 | """ 84 | Convert mask prediction to json. The format matches the format in the training annotation data: 85 | 86 | {'filename':file_name, 'labels': 87 | [{'name': label_name, 'annotations': [{'id':some_unique_integer_id, 'segmentation':[x,y,x,y,x,y....]} 88 | ....] } 89 | ....] 
90 | } 91 | """ 92 | file_name = os.path.basename(image_path) 93 | annotation = {'filename': file_name, 'labels': []} 94 | 95 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 96 | if not net: 97 | net = UNet(n_channels=3, n_classes=4) 98 | 99 | net.to(device=device) 100 | net.load_state_dict( 101 | torch.load(chkp_path, map_location=device) 102 | ) 103 | 104 | img = Image.open(image_path) 105 | 106 | msk = predict_on_image(net=net, device=device, src_img=img) 107 | msk = msk.transpose((1, 2, 0)) 108 | 109 | h, w, n_labels = msk.shape 110 | rgb_mask = np.ones((h, w, 3), dtype=np.uint8) 111 | annotation['height'] = h 112 | annotation['width'] = w 113 | 114 | for label in range(1, n_labels): 115 | color = hex_labels[str(label)] 116 | category = category_labels[str(label)] 117 | c_label = {'color': color, 'name': category, 'annotations': []} 118 | 119 | label_mask = msk[:, :, label].astype(int).astype(np.uint8) 120 | contours, hierarchy = cv2.findContours(label_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 121 | 122 | for contour in contours: 123 | vector_points = [] 124 | for x, y in contour.reshape((len(contour), 2)): 125 | vector_points += [float(x), float(y)] 126 | 127 | c_label['annotations'].append({'segmentation': vector_points}) 128 | 129 | idx = np.where(msk[:, :, label].astype(int) == 1) 130 | rgb_mask[idx] = colors_from_hex[str(label)] 131 | 132 | annotation['labels'].append(c_label) 133 | 134 | return annotation 135 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pillow 3 | opencv-python-headless 4 | torch==1.7.0 5 | torchvision==0.8.1 6 | tensorboard 7 | tqdm -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.unet import UNet 4 | from training import training_loop 5 | from data_utils.common import EasyDict 6 | 7 | 8 | def setup_run_arguments(): 9 | args = EasyDict() 10 | args.epochs = 100 11 | args.batch = 4 12 | args.val_percent = 0.2 13 | args.n_classes = 4 14 | args.n_channels = 3 15 | args.num_workers = 8 16 | 17 | args.learning_rate = 0.001 18 | args.weight_decay = 1e-8 19 | args.momentum = 0.9 20 | args.save_cp = True 21 | args.loss = "CrossEntropy" 22 | 23 | args.checkpoint_path = 'checkpoints/' 24 | args.image_dir = 'data/train/images' 25 | args.mask_dir = 'data/train/masks' 26 | 27 | args.from_pretrained = False # or a path to a .pth checkpoint to resume from 28 | 29 | return args 30 | 31 | 32 | def train(): 33 | args = setup_run_arguments() 34 | 35 | # args = parse_args() 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | print(f"[INFO] Initializing UNet-model using: {device}") 38 | 39 | net = UNet(n_channels=args.n_channels, n_classes=args.n_classes, bilinear=True) 40 | 41 | if args.from_pretrained: 42 | net.load_state_dict(torch.load(args.from_pretrained, map_location=device)) 43 | 44 | net.to(device=device) 45 | 46 | training_loop.run(network=net, 47 | epochs=args.epochs, 48 | batch_size=args.batch, 49 | lr=args.learning_rate, 50 | device=device, 51 | n_classes=args.n_classes, 52 | val_percent=args.val_percent, 53 | image_dir=args.image_dir, 54 | mask_dir=args.mask_dir, 55 | checkpoint_path=args.checkpoint_path, 56 | loss=args.loss, 57 | num_workers=args.num_workers 58 | ) 59 | 60 | 61 | if __name__ == "__main__": 62 | train() 63 |
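Training is launched by running `run_training.py` directly. Once a checkpoint exists (or using the provided `checkpoints/unet-augment-final.pth`), the prediction-to-annotation step can be driven as in the sketch below; the tile name `8_8.png` is only a placeholder for one of the held-out images.

```python
# Sketch: turn a prediction on a held-out tile into the json annotation format.
# The image path is a placeholder; adjust it to the local data layout.
import json

from predict import prediction_to_json

annotation = prediction_to_json(
    image_path='data/raw/8_8.png',                    # placeholder held-out tile
    chkp_path='checkpoints/unet-augment-final.pth',   # provided trained model
)

with open('8_8.json', 'w') as f:
    json.dump(annotation, f, indent=2)
```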
-------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obravo7/satellite-segmentation-pytorch/299d9acb959adb4059f25a2aa965aaf9e3eac562/training/__init__.py -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from glob import glob 5 | from PIL import Image 6 | import torch 7 | 8 | from data_utils.preprocessing import preprocess 9 | 10 | 11 | class InferenceDataset(Dataset): 12 | """Basic PyTorch dataset""" 13 | def __init__(self, image_dir, masks_dir, n_classes, augmentation=None): 14 | self.image_dir = image_dir 15 | self.masks_dir = masks_dir 16 | self.n_classes = n_classes 17 | self.augmentations = augmentation 18 | 19 | self.ids = [os.path.splitext(file)[0] for file in os.listdir(self.image_dir) 20 | if not file.startswith('.')] 21 | 22 | def __len__(self): 23 | return len(self.ids) 24 | 25 | def __getitem__(self, idx): 26 | image_name = self.ids[idx] 27 | 28 | # get file path regardless of extension 29 | image_path = glob(os.path.join(self.image_dir, f'{image_name}.*')) 30 | mask_path = glob(os.path.join(self.masks_dir, f"{image_name}.*")) 31 | 32 | image = Image.open(image_path[0]) 33 | mask = Image.open(mask_path[0]) # single-channel label mask 34 | 35 | assert image.size == mask.size, f"Image and mask should be the same size, but are {image.size}, and {mask.size}" 36 | 37 | image = preprocess(image) 38 | 39 | mask = np.array(mask)[..., np.newaxis].transpose((2, 0, 1)) # (h, w) -> (1, h, w) 40 | 41 | n_labels = np.unique(mask) 42 | assert len(n_labels) <= self.n_classes, f"image has too many labels: {image_path[0]}" 43 | 44 | 45 | 46 | if self.augmentations: 47 | augment = self.augmentations(image=image, mask=mask) 48 | image = augment['image'] 49 | mask = augment['mask'] 50 | response = { 51 | "image": image, 52 | "mask": mask 53 | } 54 | else: 55 | response = { 56 | "image": torch.from_numpy(image).type(torch.FloatTensor), 57 | "mask": torch.from_numpy(mask).type(torch.FloatTensor) 58 | } 59 | 60 | return response 61 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class MultiClassCriterion(nn.Module): 7 | def __init__(self, loss_type='CrossEntropy', **kwargs): 8 | super().__init__() 9 | if loss_type == 'CrossEntropy': 10 | self.criterion = nn.CrossEntropyLoss(**kwargs) 11 | elif loss_type == 'Focal': 12 | self.criterion = FocalLoss(**kwargs) 13 | elif loss_type == 'Lovasz': 14 | self.criterion = LovaszSoftmax(**kwargs) 15 | elif loss_type == 'SoftIOU': 16 | self.criterion = SoftIoULoss(**kwargs) 17 | else: 18 | raise NotImplementedError 19 | 20 | def forward(self, preds, labels): 21 | loss = self.criterion(preds, labels) 22 | return loss 23 | 24 | 25 | class FocalLoss(nn.Module): 26 | """ 27 | https://arxiv.org/abs/1708.02002 28 | """ 29 | def __init__(self, alpha=0.5, gamma=2, weight=None, ignore_index=255): 30 | super().__init__() 31 | self.alpha = alpha 32 | self.gamma = gamma 33 | self.weight = weight 34 | self.ignore_index = ignore_index 35 | self.ce_fn =
nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index) 36 | 37 | def forward(self, preds, labels): 38 | logpt = -self.ce_fn(preds, labels) 39 | pt = torch.exp(logpt) 40 | loss = -((1 - pt) ** self.gamma) * self.alpha * logpt 41 | return loss 42 | 43 | 44 | class LovaszSoftmax(nn.Module): 45 | """ 46 | Multi-class Lovasz-Softmax loss 47 | logits: [B, C, H, W] class logits at each prediction (between -\infty and \infty) 48 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 49 | ignore_index: void class labels 50 | only_present: average only on classes present in ground truth 51 | """ 52 | def __init__(self, ignore_index=None, only_present=True): 53 | super().__init__() 54 | self.ignore_index = ignore_index 55 | self.only_present = only_present 56 | 57 | def forward(self, logits, labels): 58 | probas = F.softmax(logits, dim=1) 59 | total_loss = 0 60 | batch_size = logits.shape[0] 61 | for prb, lbl in zip(probas, labels): 62 | total_loss += lovasz_softmax_flat(prb, lbl, self.ignore_index, self.only_present) 63 | return total_loss / batch_size 64 | 65 | 66 | def lovasz_grad(gt_sorted): 67 | """ 68 | Computes gradient of the Lovasz extension w.r.t sorted errors 69 | See Alg. 1 in paper 70 | """ 71 | p = len(gt_sorted) 72 | gts = gt_sorted.sum() 73 | intersection = gts - gt_sorted.float().cumsum(0) 74 | union = gts + (1 - gt_sorted).float().cumsum(0) 75 | jaccard = 1 - intersection / union 76 | if p > 1: # cover 1-pixel case 77 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 78 | return jaccard 79 | 80 | 81 | def lovasz_softmax_flat(prb, lbl, ignore_index, only_present): 82 | """ 83 | Multi-class Lovasz-Softmax loss 84 | prb: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 85 | lbl: [P] Tensor, ground truth labels (between 0 and C - 1) 86 | ignore_index: void class labels 87 | only_present: average only on classes present in ground truth 88 | """ 89 | C = prb.shape[0] 90 | prb = prb.permute(1, 2, 0).contiguous().view(-1, C) # H * W, C 91 | lbl = lbl.view(-1) # H * W 92 | if ignore_index is not None: 93 | mask = lbl != ignore_index 94 | if mask.sum() == 0: 95 | return torch.mean(prb * 0) 96 | prb = prb[mask] 97 | lbl = lbl[mask] 98 | 99 | total_loss = 0 100 | cnt = 0 101 | for c in range(C): 102 | fg = (lbl == c).float() # foreground for class c 103 | if only_present and fg.sum() == 0: 104 | continue 105 | errors = (fg - prb[:, c]).abs() 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | fg_sorted = fg[perm] 109 | total_loss += torch.dot(errors_sorted, lovasz_grad(fg_sorted)) 110 | cnt += 1 111 | return total_loss / cnt 112 | 113 | 114 | class SoftIoULoss(nn.Module): 115 | def __init__(self, n_classes): 116 | super(SoftIoULoss, self).__init__() 117 | self.n_classes = n_classes 118 | 119 | @staticmethod 120 | def to_one_hot(tensor, n_classes): 121 | n, h, w = tensor.size() 122 | one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1) 123 | return one_hot 124 | 125 | def forward(self, logit, target): 126 | # logit => N x Classes x H x W 127 | # target => N x H x W 128 | 129 | N = len(logit) 130 | 131 | pred = F.softmax(logit, dim=1) 132 | target_onehot = self.to_one_hot(target, self.n_classes) 133 | 134 | # Numerator Product 135 | inter = pred * target_onehot 136 | # Sum over all pixels N x C x H x W => N x C 137 | inter = inter.view(N, self.n_classes, -1).sum(2) 138 | 139 | # Denominator 140 | union = pred + target_onehot - (pred * target_onehot) 141 | # Sum over 
all pixels N x C x H x W => N x C 142 | union = union.view(N, self.n_classes, -1).sum(2) 143 | 144 | loss = inter / (union + 1e-16) 145 | 146 | # Return average loss over classes and batch 147 | return -loss.mean() 148 | 149 | 150 | class DiceCoeff(torch.autograd.Function): 151 | """Dice coeff for individual examples""" 152 | 153 | def forward(self, input, target): 154 | self.save_for_backward(input, target) 155 | eps = 0.0001 156 | self.inter = torch.dot(input.view(-1), target.view(-1)) 157 | self.union = torch.sum(input) + torch.sum(target) + eps 158 | 159 | t = (2 * self.inter.float() + eps) / self.union.float() 160 | return t 161 | 162 | # This function has only a single output, so it gets only one gradient 163 | def backward(self, grad_output): 164 | 165 | input, target = self.saved_variables 166 | grad_input = grad_target = None 167 | 168 | if self.needs_input_grad[0]: 169 | grad_input = grad_output * 2 * (target * self.union - self.inter) \ 170 | / (self.union * self.union) 171 | if self.needs_input_grad[1]: 172 | grad_target = None 173 | 174 | return grad_input, grad_target 175 | 176 | 177 | def dice_coeff(input, target): 178 | """wrapper for Dice coeff that works with batches""" 179 | if input.is_cuda: 180 | s = torch.FloatTensor(1).cuda().zero_() 181 | else: 182 | s = torch.FloatTensor(1).zero_() 183 | 184 | for i, c in enumerate(zip(input, target)): 185 | s = s + DiceCoeff().forward(c[0], c[1]) 186 | 187 | x = len(input) 188 | return s / (x + 1) 189 | 190 | -------------------------------------------------------------------------------- /training/training_loop.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import optim 8 | from torch.utils.data import DataLoader, random_split 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from training.loss import MultiClassCriterion, dice_coeff 12 | from training.dataset import InferenceDataset 13 | 14 | 15 | def run( 16 | image_dir, 17 | mask_dir, 18 | network, 19 | device, 20 | n_classes, 21 | checkpoint_path, 22 | epochs=50, 23 | batch_size=4, 24 | val_percent=0.1, 25 | lr=0.001, 26 | momentum=0.9, 27 | weight_decay=1e-8, 28 | save_cp=True, 29 | loss="CrossEntropy", 30 | num_workers=8, 31 | ): 32 | 33 | dataset = InferenceDataset(image_dir=image_dir, masks_dir=mask_dir, n_classes=n_classes) 34 | n_val = int(len(dataset) * val_percent) # validation set size 35 | n_train = len(dataset) - n_val # training set size 36 | 37 | train_data, val_data = random_split(dataset, [n_train, n_val]) 38 | 39 | train_loader = DataLoader(train_data, 40 | batch_size=batch_size, 41 | shuffle=True, 42 | num_workers=num_workers, 43 | pin_memory=True 44 | ) 45 | 46 | val_loader = DataLoader(val_data, 47 | batch_size=batch_size, 48 | shuffle=True, 49 | num_workers=num_workers, 50 | pin_memory=True, 51 | drop_last=True 52 | ) 53 | 54 | writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}') 55 | global_step = 0 56 | 57 | # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9) 58 | optimizer = optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay) 59 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min' if network.n_classes > 1 else 'max', patience=2) 60 | 61 | if network.n_classes > 1: 62 | # loss function: Categorical cross entropy 63 | print(f'[INFO] using {loss} loss...') 64 | # criterion = nn.CrossEntropyLoss() 
65 | criterion = MultiClassCriterion(loss_type=loss) 66 | else: 67 | # binary cross entropy, where only two classes exist, including the background 68 | print('[INFO] using binary cross entropy...') 69 | criterion = nn.BCEWithLogitsLoss() 70 | 71 | for epoch in range(epochs): 72 | 73 | network.train() 74 | 75 | epoch_loss = 0 76 | with tqdm(total=n_train, desc=f"Epoch {epoch + 1}/{epochs}", unit='image') as progress_bar: 77 | 78 | for batch in train_loader: 79 | image = batch['image'] 80 | target = batch['mask'] 81 | 82 | image = image.to(device=device, dtype=torch.float32) 83 | mask_type = torch.float32 if network.n_classes == 1 else torch.long 84 | target = target.to(device=device, dtype=mask_type).squeeze(1) 85 | 86 | mask_pred = network(image) 87 | 88 | # prediction should be a FloadTensor of shape (batch, n_classes, h, w) 89 | # target should be a LongTensor of shape (batch, h, w) 90 | loss = criterion(mask_pred, target=target) 91 | 92 | epoch_loss += loss.item() 93 | writer.add_scalar('Loss/train', loss.item(), global_step) 94 | 95 | # update progress bar 96 | progress_bar.set_postfix(**{'loss (batch)': loss.item()}) 97 | 98 | optimizer.zero_grad() 99 | loss.backward() 100 | nn.utils.clip_grad_value_(network.parameters(), 0.1) 101 | optimizer.step() 102 | 103 | progress_bar.update(image.shape[0]) 104 | global_step += 1 105 | if global_step % (n_train // (10 * batch_size)) == 0: 106 | for tag, value in network.named_parameters(): 107 | tag = tag.replace('.', '/') 108 | writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step) 109 | writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step) 110 | 111 | validation_score = eval_net(network, val_loader, device) 112 | scheduler.step(validation_score) 113 | writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step) 114 | 115 | if network.n_classes > 1: 116 | writer.add_scalar('Loss/test', validation_score, global_step) 117 | else: 118 | writer.add_scalar('Dice/test', validation_score, global_step) 119 | 120 | writer.add_images('images', image, global_step) 121 | if network.n_classes == 1: 122 | writer.add_images('masks/true', mask_pred, global_step) 123 | writer.add_images('mask/pred', torch.sigmoid(mask_pred) > 0.5, global_step) 124 | 125 | if save_cp: 126 | os.makedirs(checkpoint_path, exist_ok=True) 127 | torch.save(network.state_dict(), 128 | os.path.join(checkpoint_path, f'CP_epoch{epoch + 1}.pth') 129 | ) 130 | 131 | writer.close() 132 | 133 | torch.save(network.state_dict(), 134 | os.path.join(checkpoint_path, f'CP_epochs-{epochs}-final.pth') 135 | ) 136 | 137 | 138 | def eval_net(net, loader, device): 139 | """Evaluation without the densecrf with the dice coefficient""" 140 | net.eval() 141 | mask_type = torch.float32 if net.n_classes == 1 else torch.long 142 | n_val = len(loader) # the number of batch 143 | tot = 0 144 | 145 | with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar: 146 | for batch in loader: 147 | imgs, true_masks = batch['image'], batch['mask'] 148 | imgs = imgs.to(device=device, dtype=torch.float32) 149 | true_masks = true_masks.to(device=device, dtype=mask_type) 150 | 151 | with torch.no_grad(): 152 | mask_pred = net(imgs) 153 | 154 | if net.n_classes > 1: 155 | tot += F.cross_entropy(mask_pred, true_masks.squeeze(1)).item() 156 | else: 157 | pred = torch.sigmoid(mask_pred) 158 | pred = (pred > 0.5).float() 159 | tot += dice_coeff(pred, true_masks).item() 160 | pbar.update() 161 | 162 | net.train() 163 | return tot / 
n_val 164 | --------------------------------------------------------------------------------
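As a quick sanity check of the shape convention used in `training_loop.run` — float predictions of shape `(batch, n_classes, h, w)` against integer targets of shape `(batch, h, w)` — the loss wrapper can be exercised with random tensors. This is purely illustrative and uses made-up data.

```python
# Sketch: shape sanity check for the multi-class loss used in the training loop (random data only).
import torch

from training.loss import MultiClassCriterion

criterion = MultiClassCriterion(loss_type='CrossEntropy')

preds = torch.randn(2, 4, 512, 512)            # (batch, n_classes, h, w) float logits
target = torch.randint(0, 4, (2, 512, 512))    # (batch, h, w) integer class indices

loss = criterion(preds, target)
print(loss.item())
```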