├── README.md
├── data_segmentation
│   └── challenge-2019-classes-description-segmentable.csv
├── img
│   └── mask_rcnn_prediction_example.jpg
├── inference_example.py
└── training
    ├── csv_generator.py
    ├── eval_map_generator.py
    └── train_maskrcnn.py

/README.md:
--------------------------------------------------------------------------------

## Keras Mask R-CNN for Open Images Challenge 2019: Instance Segmentation

This repository contains Mask R-CNN models that were trained on the Open Images Dataset (OID) during the Kaggle competition:
https://www.kaggle.com/c/open-images-2019-instance-segmentation/leaderboard

The repository contains the following:
* Pre-trained Mask R-CNN models (ResNet50, ResNet101 and ResNet152 backbones)
* Example code to get predictions with these models for any set of images
* Code to train a model based on Keras Mask R-CNN and the OID dataset

## Requirements

Python 3.\*, Keras 2.\*, [keras-maskrcnn 0.2.2](https://github.com/fizyr/keras-maskrcnn), cv2, numpy, pandas

## Pretrained models

There are 3 Mask R-CNN models, based on ResNet50, ResNet101 and ResNet152 backbones, covering the [300 segmentable classes](https://github.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/blob/master/data_segmentation/challenge-2019-classes-description-segmentable.csv).

| Backbone | Image Size (px) | Model | Small validation mAP | LB (Public) |
| --- | --- | --- | --- | --- |
| ResNet50 | 800 - 1024 | [521 MB](https://github.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/releases/download/v1.0/mask_rcnn_resnet50_oid_v1.0.h5) | 0.5745 | 0.4259 |
| ResNet101 | 800 - 1024 | [739 MB](https://github.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/releases/download/v1.0/mask_rcnn_resnet101_oid_v1.0.h5) | 0.5917 | 0.4345 |
| ResNet152 | 800 - 1024 | [918 MB](https://github.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/releases/download/v1.0/mask_rcnn_resnet152_oid_v1.0.h5) | 0.5899 | 0.4404 |

* Model: each checkpoint can be used to resume training, or as a pretrained starting point for your own instance segmentation model.

## Inference

A simple example can be found here: [inference_example.py](https://github.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/blob/master/inference_example.py)

![Example of predictions](https://github.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/blob/master/img/mask_rcnn_prediction_example.jpg)

## Training

For training you need to download the OID dataset (~500 GB of images): https://storage.googleapis.com/openimages/web/download.html
You need all images, all masks and all CSV files related to the Instance Segmentation track.

Then run the script (change parameters and file locations at the bottom of the script):
* [training/train_maskrcnn.py](https://github.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/blob/master/training/train_maskrcnn.py)
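The training generator builds all image and mask paths itself, so the dataset folder is expected to follow the layout below (this is a sketch inferred from the path handling in `training/csv_generator.py` and the defaults at the bottom of `training/train_maskrcnn.py`; adjust it to your own setup):

```
<dataset_location>/
├── train/<first 3 chars of ImageID>/<ImageID>.jpg
├── validation/<ImageID>.jpg
├── masks-train-rescaled/<first char of MaskPath>/<MaskPath>
├── masks-validation-rescaled/<first char of MaskPath>/<MaskPath>
└── data_segmentation/
    ├── challenge-2019-train-segmentation-masks.csv
    ├── challenge-2019-validation-segmentation-masks.csv
    └── challenge-2019-classes-description-segmentable.csv
```

Note the `*-rescaled` mask folders: the generator assumes the challenge masks have already been rescaled and grouped by the first character of their file name.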
--------------------------------------------------------------------------------
/data_segmentation/challenge-2019-classes-description-segmentable.csv:
--------------------------------------------------------------------------------
/m/01bms0,Screwdriver
/m/03jbxj,Light switch
/m/0jy4k,Doughnut
/m/09gtd,Toilet paper
/m/01j5ks,Wrench
/m/01k6s3,Toaster
/m/05ctyq,Tennis ball
/m/015x5n,Radish
/m/0jwn_,Pomegranate
/m/02zt3,Kite
/m/05_5p_0,Table tennis racket
/m/03qrc,Hamster
/m/01btn,Barge
/m/02f9f_,Shower
/m/01m4t,Printer
/m/01x3jk,Snowmobile
/m/01pns0,Fire hydrant
/m/01lcw4,Limousine
/m/084zz,Whale
/m/0fx9l,Microwave oven
/m/0cjs7,Asparagus
/m/096mb,Lion
/m/02d1br,Spatula
/m/07dd4,Torch
/m/02rgn06,Volleyball
/m/012n7d,Ambulance
/m/01_5g,Chopsticks
/m/0dq75,Raccoon
/m/01f8m5,Blue jay
/m/04g2r,Lynx
/m/029b3,Dice
/m/047j0r,Filing cabinet
/m/0hdln,Ruler
/m/03bbps,Power plugs and sockets
/m/0jg57,Bell pepper
/m/0lt4_,Binoculars
/m/01f91_,Pretzel
/m/01b9xk,Hot dog
/m/04ylt,Missile
/m/043nyj,Common fig
/m/015wgc,Croissant
/m/03m3vtv,Adhesive tape
/m/02tsc9,Slow cooker
/m/0h8n6f9,Dog bed
/m/03q5t,Harpsichord
/m/04p0qw,Billiard table
/m/0pcr,Alpaca
/m/02l8p9,Harbor seal
/m/0388q,Grape
/m/05bm6,Nail
/m/02w3r3,Paper towel
/m/046dlr,Alarm clock
/m/02g30s,Guacamole
/m/01h8tj,Starfish
/m/0898b,Zebra
/m/076bq,Segway
/m/0120dh,Sea turtle
/m/01lsmm,Scissors
/m/03d443,Rhinoceros
/m/04c0y,Kangaroo
/m/0449p,Jaguar
/m/0c29q,Leopard
/m/04h8sr,Dumbbell
/m/0frqm,Envelope
/m/02cvgx,Winter melon
/m/01fh4r,Teapot
/m/01x_v,Camel
/m/0d20w4,Beaker
/m/01dxs,Brown bear
/m/09g1w,Toilet
/m/0kmg4,Teddy bear
/m/0584n8,Briefcase
/m/02pv19,Stop sign
/m/07dm6,Tiger
/m/0fbw6,Cabbage
/m/03bk1,Giraffe
/m/0633h,Polar bear
/m/0by6g,Shark
/m/06mf6,Rabbit
/m/04tn4x,Swim cap
/m/0h8ntjv,Pressure cooker
/m/058qzx,Kitchen knife
/m/06pcq,Submarine sandwich
/m/01kb5b,Flashlight
/m/05z6w,Penguin
/m/078jl,Snake
/m/027pcv,Zucchini
/m/01h44,Bat
/m/03y6mg,Food processor
/m/05n4y,Ostrich
/m/0gd36,Sea lion
/m/03fj2,Goldfish
/m/0bwd_0j,Elephant
/m/09rvcxw,Rocket
/m/04rmv,Mouse
/m/0_cp5,Oyster
/m/06_72j,Digital clock
/m/0cn6p,Otter
/m/02hj4,Dolphin
/m/0420v5,Punching bag
/m/0h8lkj8,Corded phone
/m/0h8my_4,Tennis racket
/m/01dwwc,Pancake
/m/0fldg,Mango
/m/09f_2,Crocodile
/m/01dwsz,Waffle
/m/020lf,Computer mouse
/m/03s_tn,Kettle
/m/02zvsm,Tart
/m/029bxz,Oven
/m/09qck,Banana
/m/0cd4d,Cheetah
/m/06j2d,Raven
/m/04v6l4,Frying pan
/m/061_f,Pear
/m/0306r,Fox
/m/06_fw,Skateboard
/m/0wdt60w,Rugby ball
/m/0kpqd,Watermelon
/m/0l14j_,Flute
/m/0ccs93,Canary
/m/03c7gz,Door handle
/m/06ncr,Saxophone
/m/01j3zr,Burrito
/m/01s55n,Suitcase
/m/02p3w7d,Roller skates
/m/02gzp,Dagger
/m/0dkzw,Seat belt
/m/0174k2,Washing machine
/m/01xs3r,Jet ski
/m/02jfl0,Sombrero
/m/068zj,Pig
/m/03v5tg,Drinking straw
/m/0dj6p,Peach
/m/011k07,Tortoise
/m/0162_1,Towel
/m/0bh9flk,Tablet computer
/m/015x4r,Cucumber
/m/0dbzx,Mule
/m/05vtc,Potato
/m/09ld4,Frog
/m/01dws,Bear
/m/04h7h,Lighthouse
/m/0176mf,Belt
/m/03g8mr,Baseball bat
/m/0dv9c,Racket
/m/06y5r,Sword
/m/01fb_0,Bagel
/m/03fwl,Goat
/m/04m9y,Lizard
/m/0gv1x,Parrot
/m/09d5_,Owl
/m/0jly1,Turkey
/m/01xqw,Cello
/m/04ctx,Knife
/m/0gxl3,Handgun
/m/0fj52s,Carrot
/m/0cdn1,Hamburger
/m/0hqkz,Grapefruit
/m/02jz0l,Tap
/m/07clx,Tea
/m/0cnyhnx,Bull
/m/09dzg,Turtle
/m/04yqq2,Bust
/m/08pbxl,Monkey
/m/084rd,Wok
/m/0hkxq,Broccoli
/m/054fyh,Pitcher
/m/02d9qx,Whiteboard
/m/071qp,Squirrel
/m/08hvt4,Jug
/m/01dy8n,Woodpecker
/m/0663v,Pizza
/m/019w40,Surfboard
/m/03m3pdh,Sofa bed
/m/07bgp,Sheep
/m/0c06p,Candle
/m/01tcjp,Muffin
/m/021mn,Cookie
/m/014j1m,Apple
/m/05kyg_,Chest of drawers
/m/016m2d,Skull
/m/09b5t,Chicken
/m/0703r8,Loveseat
/m/03grzl,Baseball glove
/m/05r5c,Piano
/m/0bjyj5,Waste container
/m/02zn6n,Barrel
/m/0dftk,Swan
/m/0pg52,Taxi
/m/09k_b,Lemon
/m/05zsy,Pumpkin
/m/0h23m,Sparrow
/m/0cyhj_,Orange
/m/07cmd,Tank
/m/0l515,Sandwich
/m/02vqfm,Coffee
/m/01z1kdw,Juice
/m/0242l,Coin
/m/0k1tl,Pen
/m/0gjkl,Watch
/m/09csl,Eagle
/m/0dbvp,Goose
/m/0f6wt,Falcon
/m/025nd,Christmas tree
/m/0ftb8,Sunflower
/m/02s195,Vase
/m/01226z,Football
/m/0ph39,Canoe
/m/06k2mb,High heels
/m/0cmx8,Spoon
/m/02jvh9,Mug
/m/01gkx_,Swimwear
/m/09ddx,Duck
/m/01yrx,Cat
/m/07j87,Tomato
/m/024g6,Cocktail
/m/01x3z,Clock
/m/025rp__,Cowboy hat
/m/01cmb2,Miniskirt
/m/01xq0k1,Cattle
/m/07fbm7,Strawberry
/m/01yx86,Bronze sculpture
/m/034c16,Pillow
/m/0dv77,Squash
/m/015qff,Traffic light
/m/03q5c7,Saucer
/m/06bt6,Reptile
/m/0fszt,Cake
/m/05gqfk,Plastic bag
/m/026qbn5,Studio couch
/m/01599,Beer
/m/02h19r,Scarf
/m/02p5f1q,Coffee cup
/m/081qc,Wine
/m/052sf,Mushroom
/m/01mqdt,Traffic sign
/m/0dv5r,Camera
/m/06m11,Rose
/m/02crq1,Couch
/m/080hkjn,Handbag
/m/02fq_6,Fedora
/m/01nq26,Sock
/m/01m2v,Computer keyboard
/m/050k8,Mobile phone
/m/018xm,Ball
/m/01j51,Balloon
/m/03k3r,Horse
/m/01b638,Boot
/m/0ch_cf,Fish
/m/01940j,Backpack
/m/02wv6h6,Skirt
/m/0h2r6,Van
/m/09728,Bread
/m/0174n1,Glove
/m/0bt9lr,Dog
/m/0cmf2,Airplane
/m/04_sv,Motorcycle
/m/0271t,Drink
/m/0bt_c3,Book
/m/07jdr,Train
/m/0c9ph5,Flower
/m/01lrl,Carnivore
/m/039xj_,Human ear
/m/0138tl,Toy
/m/025dyy,Box
/m/07r04,Truck
/m/083wq,Wheel
/m/0k5j,Aircraft
/m/01bjv,Bus
/m/0283dt1,Human mouth
/m/06msq,Sculpture
/m/01n4qj,Shirt
/m/02dl1y,Hat
/m/01jfm_,Vehicle registration plate
/m/0342h,Guitar
/m/02wbtzl,Sun hat
/m/04dr76w,Bottle
/m/0hf58v5,Luggage and bags
/m/07mhn,Trousers
/m/01bqk0,Bicycle wheel
/m/01xyhv,Suit
/m/04kkgm,Bowl
/m/04yx4,Man
/m/0fm3zh,Flowerpot
/m/01c648,Laptop
/m/01bl7v,Boy
/m/06z37_,Picture frame
/m/015p6,Bird
/m/0k4j,Car
/m/01bfm9,Shorts
/m/03bt1vf,Woman
/m/099ssp,Platter
/m/01rkbr,Tie
/m/05r655,Girl
/m/079cl,Skyscraper
/m/01g317,Person
/m/03120,Flag
/m/0fly7,Jeans
/m/01d40f,Dress
--------------------------------------------------------------------------------
/img/mask_rcnn_prediction_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZFTurbo/Keras-Mask-RCNN-for-Open-Images-2019-Instance-Segmentation/82b01e60cf734c1f0a9692163f1fb838215b5ea1/img/mask_rcnn_prediction_example.jpg
--------------------------------------------------------------------------------
/inference_example.py:
--------------------------------------------------------------------------------
# coding: utf-8
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'


if __name__ == '__main__':
    import os
    gpu_use = 0
    print('GPU use: {}'.format(gpu_use))
    os.environ["KERAS_BACKEND"] = "tensorflow"
    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_use)


from keras_maskrcnn import models
import cv2
import time
import glob
import pandas as pd
import numpy as np
import base64
from pycocotools import mask as coco_mask
import zlib


def show_image(im, name='image'):
    cv2.imshow(name, im.astype(np.uint8))
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def read_single_image(path):
    # cv2.imread returns None for unreadable files; propagate that instead of
    # crashing in cvtColor, so callers can skip the image.
    img = cv2.imread(path)
    if img is None:
        return None
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def get_class_arr(path, type='name'):
    s = pd.read_csv(path, names=['google_name', 'name'], header=None)[type].values
    return s


def encode_binary_mask(mask):
    """Converts a binary mask into OID challenge encoding ascii text."""

    # convert input mask to expected COCO API input --
    mask_to_encode = np.expand_dims(mask, axis=2)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    # binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_SPEED)
    base64_str = base64.b64encode(binary_str)
    return base64_str


def decode_binary_mask(mask, width, height):
    """Converts OID challenge encoding ascii text back into a binary mask."""

    compressed_mask = base64.b64decode(mask)
    rle_encoded_mask = zlib.decompress(compressed_mask)
    # print(rle_encoded_mask)
    decoding_dict = {
        'size': [height, width],  # [im_height, im_width],
        'counts': rle_encoded_mask
    }
    mask_tensor = coco_mask.decode(decoding_dict)
    return mask_tensor
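
# Added sketch, not part of the original script: a small sanity check that the
# two helpers above are inverses of each other on a toy mask. Useful when
# upgrading pycocotools or porting the encoding elsewhere.
def _mask_roundtrip_check():
    toy = np.zeros((4, 4), dtype=np.uint8)
    toy[1:3, 1:3] = 1
    encoded = encode_binary_mask(toy)
    decoded = decode_binary_mask(encoded, width=4, height=4)
    assert (decoded == toy).all()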
def show_image_debug(draw, boxes, scores, labels, masks, classes):
    from keras_retinanet.utils.visualization import draw_box, draw_caption
    from keras_maskrcnn.utils.visualization import draw_mask
    from keras_retinanet.utils.colors import label_color

    # visualize detections
    limit_conf = 0.2
    for box, score, label, mask in zip(boxes, scores, labels, masks):
        # scores are sorted so we can break
        if score < limit_conf:
            break

        color = label_color(label)
        color_mask = (255, 0, 0)

        b = box.astype(int)
        draw_box(draw, b, color=color)

        draw_mask(draw, b, mask, color=color_mask)

        caption = "{} {:.3f}".format(classes[label], score)
        print(caption)
        draw_caption(draw, b, caption)
    draw = cv2.cvtColor(draw, cv2.COLOR_RGB2BGR)
    show_image(draw)
    # cv2.imwrite('debug.png', draw)


def get_maskrcnn_single_predictions(model, input_image, classes, show_debug_images):
    from keras_retinanet.utils.image import preprocess_image, resize_image

    image_init = input_image.copy()

    # preprocess image for network
    image = preprocess_image(image_init)

    # Resize image
    image, image_scale = resize_image(image, min_side=800, max_side=1024)
    if show_debug_images:
        # copy to draw on
        draw, draw_scale = resize_image(image_init, min_side=800, max_side=1024)

    start = time.time()
    print('Image shape: {}'.format(image.shape))
    img_rot = image.copy()
    img_rot = np.expand_dims(img_rot, axis=0)
    outputs = model.predict_on_batch(img_rot)

    # Save only needed mask
    boxes = outputs[-4][0].copy()
    masks = outputs[-1][0].copy()
    scores = outputs[-3][0].copy()
    labels = outputs[-2][0].copy()

    # Only save needed mask to save space
    masks_reduced = []
    for i in range(masks.shape[0]):
        masks_reduced.append(masks[i, :, :, labels[i]])
    masks = np.array(masks_reduced)

    print('Detections shape: {} {} {} {}'.format(boxes.shape, scores.shape, labels.shape, masks.shape))
    print("Processing time: {:.2f} sec".format(time.time() - start))

    if show_debug_images:
        boxes_init = boxes.copy()

    # normalize boxes to [0, 1] relative to the resized image
    boxes[:, 0] /= image.shape[1]
    boxes[:, 2] /= image.shape[1]
    boxes[:, 1] /= image.shape[0]
    boxes[:, 3] /= image.shape[0]

    if show_debug_images:
        show_image_debug(draw.astype(np.uint8), boxes_init, scores, labels, masks, classes)

    return boxes, scores, labels, masks


def get_preds_as_string(id, input_image, boxes, scores, labels, masks, classes_google):
    thr_keep_in_predictions = 0.01
    thr_mask = 0.5
    shape0, shape1 = input_image.shape[0], input_image.shape[1]
    s1 = '{},{},{},'.format(id, shape1, shape0)

    for i in range(scores.shape[0]):
        score = scores[i]

        if score < thr_keep_in_predictions:
            continue

        box = boxes[i]
        label = classes_google[labels[i]]
        mask = masks[i]

        x1 = int(box[0] * shape1)
        y1 = int(box[1] * shape0)
        x2 = int(box[2] * shape1)
        y2 = int(box[3] * shape0)
        mask = cv2.resize(mask, (x2 - x1, y2 - y1), interpolation=cv2.INTER_LINEAR)

        mask[mask > thr_mask] = 1
        mask[mask <= thr_mask] = 0
        mask_complete = np.zeros((shape0, shape1), dtype=np.uint8)
        mask_complete[y1:y2, x1:x2] = mask

        enc_mask = encode_binary_mask(mask_complete)
        s1 += '{} {:.8f} {} '.format(label, score, str(enc_mask)[2:-1])

    s1 += '\n'
    return s1
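
# Added sketch, not part of the original script: the inverse of
# get_preds_as_string() for a single output CSV line (header excluded),
# handy for checking a submission file locally.
def parse_prediction_line(line):
    image_id, width, height, preds = line.rstrip('\n').split(',', 3)
    tokens = preds.split()
    # tokens come in (label, score, encoded_mask) triples
    for i in range(0, len(tokens) - 2, 3):
        label = tokens[i]
        score = float(tokens[i + 1])
        mask = decode_binary_mask(tokens[i + 2], int(width), int(height))
        yield label, score, mask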
def get_maskrcnn_predictions(model_path, backbone, image_files, classes_description, output_csv, show_debug_image):
    model = models.load_model(model_path, backbone_name=backbone)
    classes = get_class_arr(classes_description, type='name')
    classes_google = get_class_arr(classes_description, type='google_name')
    print('Image files to process: {}'.format(len(image_files)))

    out = open(output_csv, 'w')
    out.write('ImageID,ImageWidth,ImageHeight,PredictionString\n')
    for i in range(len(image_files)):
        inp_file = image_files[i]
        id = os.path.basename(inp_file)
        img = read_single_image(inp_file)
        if img is None:
            print('Problem reading image: {}'.format(inp_file))
            continue
        boxes, scores, labels, masks = get_maskrcnn_single_predictions(model, img, classes, show_debug_image)
        s1 = get_preds_as_string(id, img, boxes, scores, labels, masks, classes_google)
        out.write(s1)

    out.close()


if __name__ == '__main__':
    backbone = 'resnet50'
    model_path = 'mask_rcnn_resnet50_oid_v1.0.h5'
    classes_description = 'data_segmentation/challenge-2019-classes-description-segmentable.csv'

    show_debug_images = True
    image_files = glob.glob('img/*.jpg')
    output_csv = 'output.csv'

    get_maskrcnn_predictions(model_path, backbone, image_files, classes_description, output_csv, show_debug_images)
--------------------------------------------------------------------------------
/training/csv_generator.py:
--------------------------------------------------------------------------------
# coding: utf-8
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'

from keras_maskrcnn.preprocessing.generator import Generator

import os.path
import numpy as np
import time
import pandas as pd
import cv2
import random
from PIL import Image
from keras import backend as K


def get_image_size(image_filename):
    im = Image.open(image_filename)
    w, h = im.size
    return w, h


def read_single_image(path):
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    return img


def rle_decode(mask_rle, shape=(1024, 1024)):
    '''
    mask_rle: run-length as a string formatted as (start length) pairs
    shape: (height, width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T  # Needed to align to RLE direction

def read_oid_segmentation_annotations(dataset_path, csv_path, classes):
    result = {}
    start_time = time.time()
    print('Reading annotations {}'.format(csv_path))
    csv = pd.read_csv(csv_path, usecols=['MaskPath', 'ImageID', 'LabelName', 'BoxXMin', 'BoxXMax', 'BoxYMin', 'BoxYMax'])
    image_size_cache = dict()
    csv = csv[['MaskPath', 'ImageID', 'LabelName', 'BoxXMin', 'BoxXMax', 'BoxYMin', 'BoxYMax']].values
    for i in range(csv.shape[0]):
        mask_id, img_id, class_name, x1, x2, y1, y2 = csv[i]

        if 'validation' in csv_path:
            img_path = dataset_path + 'validation/' + img_id + '.jpg'
            mask_path = dataset_path + 'masks-validation-rescaled/' + mask_id[0] + '/' + mask_id
        else:
            img_path = dataset_path + 'train/' + img_id[:3] + '/' + img_id + '.jpg'
            mask_path = dataset_path + 'masks-train-rescaled/' + mask_id[0] + '/' + mask_id

        if img_path not in result:
            result[img_path] = []

        # Check that the bounding box is valid.
        if x1 < 0:
            # raise ValueError('line {}: negative x1 ({})'.format(i, x1))
            print('line {}: negative x1 ({})'.format(i, x1))
            x1 = 0
        if y1 < 0:
            # raise ValueError('line {}: negative y1 ({})'.format(i, y1))
            print('line {}: negative y1 ({})'.format(i, y1))
            y1 = 0
        if x2 > 1:
            # raise ValueError('line {}: invalid x2 ({})'.format(i, x2))
            print('line {}: invalid x2 ({})'.format(i, x2))
            x2 = 1
        if y2 > 1:
            # raise ValueError('line {}: invalid y2 ({})'.format(i, y2))
            print('line {}: invalid y2 ({})'.format(i, y2))
            y2 = 1

        if x2 <= x1:
            raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(i, x2, x1))
        if y2 <= y1:
            raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(i, y2, y1))

        if class_name not in classes:
            raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(i, class_name, classes))

        result[img_path].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name, 'mask_path': mask_path})
    print('Total images: {} Reading time: {:.2f} sec'.format(len(result), time.time() - start_time))
    return result


def get_class_index_arrays(classes_dict, image_data):
    classes = dict()
    # classes['empty'] = []
    for name in classes_dict:
        classes[classes_dict[name]] = set()

    for key in image_data:
        for entry in image_data[key]:
            c = classes_dict[entry['class']]
            classes[c] |= set([key])

    for c in classes:
        classes[c] = list(classes[c])
        print('Class ID: {} Images: {}'.format(c, len(classes[c])))

    return classes
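
# Added note, not part of the original file: the annotations CSV is the official
# challenge-2019-*-segmentation-masks.csv file; only the columns listed in
# usecols above are read. An illustrative row (values are made up):
#   MaskPath,ImageID,LabelName,BoxXMin,BoxXMax,BoxYMin,BoxYMax
#   0a1b2c3d4e5f6789_m04yx4_a1b2c3d4.png,0a1b2c3d4e5f6789,/m/04yx4,0.25,0.75,0.10,0.90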

class CSVGenerator(Generator):
    def __init__(
            self,
            csv_data_file,
            csv_class_file,
            dataset_path,
            base_dir=None,
            is_rle=False,
            **kwargs
    ):
        self.image_names = []
        self.image_data = {}
        self.dataset_path = dataset_path
        self.base_dir = base_dir
        self.is_rle = is_rle

        # Take base_dir from annotations file if not explicitly specified.
        if self.base_dir is None:
            self.base_dir = os.path.dirname(csv_data_file)

        # parse the provided class file
        clss_table = pd.read_csv(csv_class_file, header=None, names=['id', 'name'])
        self.classes = dict()
        for index, row in clss_table.iterrows():
            self.classes[row['id']] = index

        self.labels = {}
        for key, value in self.classes.items():
            self.labels[value] = key

        # csv with MaskPath,LabelName,BoxID,BoxXMin,BoxXMax,BoxYMin,BoxYMax
        self.image_data = read_oid_segmentation_annotations(self.dataset_path, csv_data_file, self.classes)
        self.image_names = list(self.image_data.keys())

        self.id_to_image_id = dict([(i, k) for i, k in enumerate(self.image_names)])
        self.image_id_to_id = dict([(k, i) for i, k in enumerate(self.image_names)])
        self.class_index_array = get_class_index_arrays(self.classes, self.image_data)

        super(CSVGenerator, self).__init__(**kwargs)

    def size(self):
        return len(self.image_names)

    def num_classes(self):
        return max(self.classes.values()) + 1

    def name_to_label(self, name):
        return self.classes[name]

    def label_to_name(self, label):
        return self.labels[label]

    def image_path(self, image_index):
        return os.path.join(self.base_dir, self.image_names[image_index])

    def image_aspect_ratio(self, image_index):
        # PIL is fast for metadata
        image = Image.open(self.image_path(image_index))
        return float(image.width) / float(image.height)

    def load_image(self, image_index):
        return read_single_image(self.image_path(image_index))

    def load_annotations(self, image_index):
        path = self.image_names[image_index]
        annots = self.image_data[path]

        annotations = {
            'labels': np.empty((len(annots),)),
            'bboxes': np.empty((len(annots), 4)),
            'masks': [],
        }

        for idx, annot in enumerate(annots):
            if self.is_rle is False:
                mask = cv2.imread(annot['mask_path'], cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    print('Invalid mask: {}'.format(annot['mask_path']))
                    w, h = get_image_size(path)
                    mask = np.zeros((h, w), dtype=np.uint8)
            else:
                mask = rle_decode(annot['mask_path'], (1200, 1200))

            # Boxes in the OID CSVs are normalized; convert them to absolute
            # pixel coordinates of the mask image.
            annotations['bboxes'][idx, 0] = float(annot['x1'] * mask.shape[1])
            annotations['bboxes'][idx, 1] = float(annot['y1'] * mask.shape[0])
            annotations['bboxes'][idx, 2] = float(annot['x2'] * mask.shape[1])
            annotations['bboxes'][idx, 3] = float(annot['y2'] * mask.shape[0])
            annotations['labels'][idx] = self.name_to_label(annot['class'])

            mask = (mask > 0).astype(np.uint8)  # convert from 0-255 to binary mask
            annotations['masks'].append(np.expand_dims(mask, axis=-1))

        return annotations
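
    # Added note, not part of the original file: load_annotations() returns a dict of
    #   'labels': float array of shape (N,) with class indices,
    #   'bboxes': float array of shape (N, 4) in absolute pixel coordinates of the mask image,
    #   'masks' : list of N uint8 arrays of shape (mask_h, mask_w, 1) with values in {0, 1}.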
215 | """ 216 | 217 | # randomly transform image and annotations 218 | image, annotations = self.random_transform_group_entry(image, annotations) 219 | 220 | # preprocess the image 221 | image = self.preprocess_image(image) 222 | 223 | # resize image 224 | image, image_scale = self.resize_image(image) 225 | 226 | # resize masks 227 | for i in range(len(annotations['masks'])): 228 | annotations['masks'][i], _ = self.resize_image(annotations['masks'][i]) 229 | 230 | # apply resizing to annotations too 231 | annotations['bboxes'] *= image_scale 232 | 233 | return image, annotations 234 | 235 | def random_transform_group_entry(self, image, annotations, transform=None): 236 | """ Randomly transforms image and annotation. 237 | """ 238 | # randomly transform both image and annotations 239 | # show_image(image) 240 | # print(image.min(), image.max()) 241 | # print(annotations) 242 | 243 | if self.transform_generator and len(annotations['masks']) > 0: 244 | ann = dict() 245 | ann['image'] = image.copy() 246 | ann['masks'] = np.array(annotations['masks'])[:, :, :, 0].copy() 247 | ann['labels'] = annotations['labels'].copy() 248 | ann['bboxes'] = list(annotations['bboxes']) 249 | augm = self.transform_generator(**ann) 250 | image = augm['image'] 251 | for i in range(len(annotations['masks'])): 252 | # show_image(255*annotations['masks'][i][:, :, 0]) 253 | annotations['masks'][i][:, :, 0] = augm['masks'][i] 254 | # show_image(255*annotations['masks'][i][:, :, 0]) 255 | annotations['bboxes'] = np.array(augm['bboxes']) 256 | 257 | # show_image(image) 258 | # print(annotations) 259 | # exit() 260 | 261 | return image, annotations 262 | 263 | def group_images(self): 264 | print('Group images. Method: {}...'.format(self.group_method)) 265 | # determine the order of the images 266 | order = list(range(self.size())) 267 | if self.group_method == 'random': 268 | random.shuffle(order) 269 | elif self.group_method == 'ratio': 270 | order.sort(key=lambda x: self.image_aspect_ratio(x)) 271 | elif self.group_method == 'random_classes': 272 | classes = list(range(self.num_classes())) 273 | self.groups = [] 274 | while 1: 275 | if len(self.groups) > 1000000: 276 | break 277 | self.groups.append([]) 278 | for i in range(self.batch_size): 279 | zz = 1000 280 | while zz > 0: 281 | random_class = random.choice(classes) 282 | # print(random_class, len(self.class_index_array[random_class])) 283 | if len(self.class_index_array[random_class]) > 0: 284 | random_image = random.choice(self.class_index_array[random_class]) 285 | break 286 | zz -= 1 287 | random_image_index = self.image_id_to_id[random_image] 288 | self.groups[-1].append(random_image_index) 289 | print('Grouped by random classes: {}'.format(len(self.groups))) 290 | return 291 | 292 | # divide into groups, one group = one batch 293 | self.groups = [[order[x % len(order)] for x in range(i, i + self.batch_size)] for i in range(0, len(order), self.batch_size)] 294 | 295 | def compute_targets(self, image_group, annotations_group): 296 | """ Compute target outputs for the network using images and their annotations. 
297 | """ 298 | # get the max image shape 299 | max_shape = tuple(max(image.shape[x] for image in image_group) for x in range(3)) 300 | anchors = self.generate_anchors(max_shape) 301 | 302 | batches = self.compute_anchor_targets( 303 | anchors, 304 | image_group, 305 | annotations_group, 306 | self.num_classes() 307 | ) 308 | 309 | # copy all annotations / masks to the batch 310 | max_annotations = max(len(a['masks']) for a in annotations_group) 311 | # masks_batch has shape: (batch size, max_annotations, bbox_x1 + bbox_y1 + bbox_x2 + bbox_y2 + label + width + height + max_image_dimension) 312 | masks_batch = np.zeros((self.batch_size, max_annotations, 5 + 2 + max_shape[0] * max_shape[1]), dtype=K.floatx()) 313 | for index, annotations in enumerate(annotations_group): 314 | try: 315 | masks_batch[index, :annotations['bboxes'].shape[0], :4] = annotations['bboxes'] 316 | except: 317 | print('Error in compute targets!') 318 | print(index, annotations_group) 319 | 320 | masks_batch[index, :annotations['labels'].shape[0], 4] = annotations['labels'] 321 | masks_batch[index, :, 5] = max_shape[1] # width 322 | masks_batch[index, :, 6] = max_shape[0] # height 323 | 324 | # add flattened mask 325 | for mask_index, mask in enumerate(annotations['masks']): 326 | masks_batch[index, mask_index, 7:7 + (mask.shape[0] * mask.shape[1])] = mask.flatten() 327 | 328 | return list(batches) + [masks_batch] -------------------------------------------------------------------------------- /training/eval_map_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from keras_maskrcnn.utils.overlap import compute_overlap 19 | from keras_maskrcnn.utils.visualization import draw_masks 20 | 21 | import numpy as np 22 | import os 23 | import time 24 | 25 | import cv2 26 | import progressbar 27 | assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." 28 | 29 | 30 | def _compute_ap(recall, precision): 31 | """ Compute the average precision, given the recall and precision curves. 32 | 33 | Code originally from https://github.com/rbgirshick/py-faster-rcnn. 34 | 35 | # Arguments 36 | recall: The recall curve (list). 37 | precision: The precision curve (list). 38 | # Returns 39 | The average precision as computed in py-faster-rcnn. 
40 | """ 41 | # correct AP calculation 42 | # first append sentinel values at the end 43 | mrec = np.concatenate(([0.], recall, [1.])) 44 | mpre = np.concatenate(([0.], precision, [0.])) 45 | 46 | # compute the precision envelope 47 | for i in range(mpre.size - 1, 0, -1): 48 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 49 | 50 | # to calculate area under PR curve, look for points 51 | # where X axis (recall) changes value 52 | i = np.where(mrec[1:] != mrec[:-1])[0] 53 | 54 | # and sum (\Delta recall) * prec 55 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 56 | return ap 57 | 58 | 59 | def _get_detections(generator, model, score_threshold=0.05, max_detections=100, save_path=None): 60 | """ Get the detections from the model using the generator. 61 | 62 | The result is a list of lists such that the size is: 63 | all_detections[num_images][num_classes] = detections[num_detections, 4 + num_classes] 64 | 65 | # Arguments 66 | generator : The generator used to run images through the model. 67 | model : The model to run on the images. 68 | score_threshold : The score confidence threshold to use. 69 | max_detections : The maximum number of detections to use per image. 70 | save_path : The path to save the images with visualized detections to. 71 | # Returns 72 | A list of lists containing the detections for each image in the generator. 73 | """ 74 | all_detections = [[None for i in range(generator.num_classes())] for j in range(generator.size())] 75 | all_masks = [[None for i in range(generator.num_classes())] for j in range(generator.size())] 76 | 77 | for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '): 78 | raw_image = generator.load_image(i) 79 | image = generator.preprocess_image(raw_image.copy()) 80 | image, scale = generator.resize_image(image) 81 | 82 | # run network 83 | outputs = model.predict_on_batch(np.expand_dims(image, axis=0)) 84 | boxes = outputs[-4] 85 | scores = outputs[-3] 86 | labels = outputs[-2] 87 | masks = outputs[-1] 88 | 89 | # correct boxes for image scale 90 | boxes /= scale 91 | 92 | # select indices which have a score above the threshold 93 | indices = np.where(scores[0, :] > score_threshold)[0] 94 | 95 | # select those scores 96 | scores = scores[0][indices] 97 | 98 | # find the order with which to sort the scores 99 | scores_sort = np.argsort(-scores)[:max_detections] 100 | 101 | # select detections 102 | image_boxes = boxes[0, indices[scores_sort], :] 103 | image_scores = scores[scores_sort] 104 | image_labels = labels[0, indices[scores_sort]] 105 | image_masks = masks[0, indices[scores_sort], :, :, image_labels] 106 | image_detections = np.concatenate([image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1) 107 | 108 | if save_path is not None: 109 | # draw_annotations(raw_image, generator.load_annotations(i)[0], label_to_name=generator.label_to_name) 110 | #draw_detections(raw_image, image_boxes, image_scores, image_labels, score_threshold=score_threshold, label_to_name=generator.label_to_name) 111 | draw_masks(raw_image, image_boxes.astype(int), image_masks, labels=image_labels) 112 | 113 | cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image) 114 | 115 | # copy detections to all_detections 116 | for label in range(generator.num_classes()): 117 | all_detections[i][label] = image_detections[image_detections[:, -1] == label, :-1] 118 | all_masks[i][label] = image_masks[image_detections[:, -1] == label, ...] 
        print('{}/{}'.format(i + 1, generator.size()), end='\r')

    return all_detections, all_masks


def _get_annotations(generator):
    """ Get the ground truth annotations from the generator.

    The result is a list of lists such that the size is:
        all_detections[num_images][num_classes] = annotations[num_detections, 5]

    # Arguments
        generator : The generator used to retrieve ground truth annotations.
    # Returns
        A list of lists containing the annotations for each image in the generator.
    """
    all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())]
    all_masks = [[None for i in range(generator.num_classes())] for j in range(generator.size())]

    for i in range(generator.size()):
        # load the annotations
        annotations = generator.load_annotations(i)
        annotations['masks'] = np.stack(annotations['masks'], axis=0)

        # copy detections to all_annotations
        for label in range(generator.num_classes()):
            all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy()
            all_masks[i][label] = annotations['masks'][annotations['labels'] == label, ..., 0].copy()

        print('{}/{}'.format(i + 1, generator.size()), end='\r')

    return all_annotations, all_masks


def evaluate(
        generator,
        model,
        iou_threshold=0.5,
        score_threshold=0.05,
        max_detections=100,
        binarize_threshold=0.5,
        save_path=None
):
    """ Evaluate a given dataset using a given model.

    # Arguments
        generator          : The generator that represents the dataset to evaluate.
        model              : The model to evaluate.
        iou_threshold      : The threshold used to consider when a detection is positive or negative.
        score_threshold    : The score confidence threshold to use for detections.
        max_detections     : The maximum number of detections to use per image.
        binarize_threshold : Threshold to binarize the masks with.
        save_path          : The path to save images with visualized detections to.
    # Returns
        A dict mapping each class label to a (average precision, num annotations) tuple.
175 | """ 176 | # gather all detections and annotations 177 | all_detections, all_masks = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections, save_path=save_path) 178 | all_annotations, all_gt_masks = _get_annotations(generator) 179 | average_precisions = {} 180 | 181 | # import pickle 182 | # pickle.dump(all_detections, open('all_detections.pkl', 'wb')) 183 | # pickle.dump(all_masks, open('all_masks.pkl', 'wb')) 184 | # pickle.dump(all_annotations, open('all_annotations.pkl', 'wb')) 185 | # pickle.dump(all_gt_masks, open('all_gt_masks.pkl', 'wb')) 186 | 187 | # process detections and annotations 188 | for label in range(generator.num_classes()): 189 | false_positives = [] 190 | true_positives = [] 191 | scores = [] 192 | num_annotations = 0.0 193 | 194 | for i in range(generator.size()): 195 | detections = all_detections[i][label] 196 | masks = all_masks[i][label] 197 | annotations = all_annotations[i][label] 198 | gt_masks = all_gt_masks[i][label] 199 | num_annotations += annotations.shape[0] 200 | detected_annotations = [] 201 | 202 | for d, mask in zip(detections, masks): 203 | box = d[:4].astype(int) 204 | scores.append(d[4]) 205 | 206 | if annotations.shape[0] == 0: 207 | false_positives.append(1) 208 | true_positives.append(0) 209 | continue 210 | 211 | if box[3] > gt_masks[0].shape[0]: 212 | print('Box 3 error: {} Fix {} -> {}'.format(box, box[3], gt_masks[0].shape[0])) 213 | box[3] = gt_masks[0].shape[0] 214 | if box[2] > gt_masks[0].shape[1]: 215 | print('Box 2 error: {} Fix {} -> {}'.format(box, box[2], gt_masks[0].shape[1])) 216 | box[2] = gt_masks[0].shape[1] 217 | if box[0] < 0: 218 | print('Box 0 error: {} Fix {} -> {}'.format(box, box[0], 0)) 219 | box[0] = 0 220 | if box[1] < 0: 221 | print('Box 1 error: {} Fix {} -> {}'.format(box, box[1], 0)) 222 | box[1] = 0 223 | 224 | # resize to fit the box 225 | mask = cv2.resize(mask, (box[2] - box[0], box[3] - box[1])) 226 | 227 | # binarize the mask 228 | mask = (mask > binarize_threshold).astype(np.uint8) 229 | 230 | # place mask in image frame 231 | mask_image = np.zeros_like(gt_masks[0]) 232 | mask_image[box[1]:box[3], box[0]:box[2]] = mask 233 | mask = mask_image 234 | 235 | overlaps = compute_overlap(np.expand_dims(mask, axis=0), gt_masks) 236 | assigned_annotation = np.argmax(overlaps, axis=1) 237 | max_overlap = overlaps[0, assigned_annotation] 238 | 239 | if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations: 240 | false_positives.append(0) 241 | true_positives.append(1) 242 | detected_annotations.append(assigned_annotation) 243 | else: 244 | false_positives.append(1) 245 | true_positives.append(0) 246 | 247 | # no annotations -> AP for this class is 0 (is this correct?) 
        # no annotations -> AP for this class is 0 (is this correct?)
        if num_annotations == 0:
            average_precisions[label] = 0, 0
            continue

        false_positives = np.array(false_positives, dtype=np.uint8)
        true_positives = np.array(true_positives, dtype=np.uint8)
        scores = np.array(scores)

        # sort by score
        indices = np.argsort(-scores)
        false_positives = false_positives[indices]
        true_positives = true_positives[indices]

        # compute false positives and true positives
        false_positives = np.cumsum(false_positives)
        true_positives = np.cumsum(true_positives)

        # compute recall and precision
        recall = true_positives / num_annotations
        precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)

        # compute average precision
        average_precision = _compute_ap(recall, precision)
        average_precisions[label] = average_precision, num_annotations

    return average_precisions
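
# Added sketch, not part of the original file: running the evaluation standalone,
# assuming `gen` is a CSVGenerator over the validation CSV and `model` is a
# loaded prediction model:
#   aps = evaluate(gen, model, iou_threshold=0.5, score_threshold=0.05, max_detections=100)
#   for label, (ap, num_annotations) in aps.items():
#       print(gen.label_to_name(label), ap, num_annotations)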
301 | """ 302 | self.generator = generator 303 | self.iou_threshold = iou_threshold 304 | self.score_threshold = score_threshold 305 | self.max_detections = max_detections 306 | self.save_map_path = save_map_path 307 | self.tensorboard = tensorboard 308 | self.weighted_average = weighted_average 309 | self.verbose = verbose 310 | 311 | super(Evaluate, self).__init__() 312 | 313 | def on_epoch_end(self, epoch, logs=None): 314 | # run evaluation 315 | start_time = time.time() 316 | average_precisions = evaluate( 317 | self.generator, 318 | self.model, 319 | iou_threshold=self.iou_threshold, 320 | score_threshold=self.score_threshold, 321 | max_detections=self.max_detections, 322 | save_path=None 323 | ) 324 | 325 | total_instances = [] 326 | precisions = [] 327 | for label, (average_precision, num_annotations) in average_precisions.items(): 328 | if self.verbose == 1: 329 | print('{:.0f} instances of class'.format(num_annotations), 330 | self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) 331 | total_instances.append(num_annotations) 332 | precisions.append(average_precision) 333 | if self.weighted_average: 334 | mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances) 335 | else: 336 | mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) 337 | 338 | if self.tensorboard is not None and self.tensorboard.writer is not None: 339 | import tensorflow as tf 340 | summary = tf.Summary() 341 | summary_value = summary.value.add() 342 | summary_value.simple_value = mean_ap 343 | summary_value.tag = "mAP" 344 | self.tensorboard.writer.add_summary(summary, epoch) 345 | 346 | if self.verbose == 1: 347 | print('Time: {:.2f} mAP: {:.4f}'.format(time.time() - start_time, mean_ap)) 348 | 349 | if self.save_map_path is not None: 350 | out = open(self.save_map_path, 'a') 351 | out.write('Ep {}: mAP: {:.4f}\n'.format(epoch + 1, mean_ap)) 352 | out.close() 353 | 354 | logs['mAP'] = mean_ap 355 | -------------------------------------------------------------------------------- /training/train_maskrcnn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | __author__ = 'ZFTurbo: https://kaggle.com/zfturbo' 3 | 4 | 5 | import argparse 6 | import os 7 | import sys 8 | import cv2 9 | 10 | import keras.preprocessing.image 11 | import tensorflow as tf 12 | 13 | import keras_retinanet.losses 14 | from keras_retinanet.callbacks import RedirectModel 15 | from keras_retinanet.utils.keras_version import check_keras_version 16 | from keras_retinanet.utils.model import freeze as freeze_model 17 | 18 | 19 | # Allow relative imports when being executed as script. 20 | if __name__ == "__main__" and __package__ is None: 21 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) 22 | import keras_maskrcnn.bin 23 | __package__ = "keras_maskrcnn.bin" 24 | 25 | 26 | # Change these to absolute imports if you copy this script outside the keras_retinanet package. 
from keras_maskrcnn import losses
from keras_maskrcnn import models
from training.eval_map_generator import Evaluate
from training.csv_generator import CSVGenerator
from albumentations import (
    BboxParams, Blur, CLAHE, Compose, GaussNoise, HorizontalFlip, HueSaturationValue,
    IAAAdditiveGaussianNoise, IAAEmboss, IAASharpen, ImageCompression, MedianBlur,
    MotionBlur, OneOf, RandomBrightnessContrast, RGBShift, ShiftScaleRotate, ToGray,
)


def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.allocator_type = 'BFC'
    config.allow_soft_placement = True
    config.log_device_placement = False
    return tf.Session(config=config)


def model_with_weights(model, weights, skip_mismatch):
    if weights is not None:
        model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch)
    return model


def create_models(backbone_retinanet, num_classes, weights, args, freeze_backbone=False, class_specific_filter=True, anchor_params=None):
    modifier = freeze_model if freeze_backbone else None

    model = model_with_weights(
        backbone_retinanet(
            num_classes,
            nms=True,
            class_specific_filter=class_specific_filter,
            modifier=modifier,
            anchor_params=anchor_params
        ), weights=weights, skip_mismatch=True)
    training_model = model
    prediction_model = model

    # Adam with a low learning rate and gradient norm clipping
    opt = keras.optimizers.adam(lr=1e-5, clipnorm=0.001)

    # compile model
    training_model.compile(
        loss={
            'regression'    : keras_retinanet.losses.smooth_l1(),
            'classification': keras_retinanet.losses.focal(),
            'masks'         : losses.mask(),
        },
        optimizer=opt
    )

    return model, training_model, prediction_model

def create_callbacks(model, training_model, prediction_model, validation_generator, args):
    callbacks = []

    # save the last prediction model
    if args.snapshots:
        # ensure directory created first; otherwise h5py will error after epoch.
        os.makedirs(args.snapshot_path, exist_ok=True)
        checkpoint = keras.callbacks.ModelCheckpoint(
            os.path.join(
                args.snapshot_path,
                '{backbone}_fold_{fold}_last.h5'.format(backbone=args.backbone, fold=args.fold)
            ),
            verbose=1,
        )
        checkpoint = RedirectModel(checkpoint, model)
        callbacks.append(checkpoint)

    tensorboard_callback = None
    if args.tensorboard_dir:
        tensorboard_callback = keras.callbacks.TensorBoard(
            log_dir=args.tensorboard_dir,
            histogram_freq=0,
            batch_size=args.batch_size,
            write_graph=True,
            write_grads=False,
            write_images=False,
            embeddings_freq=0,
            embeddings_layer_names=None,
            embeddings_metadata=None
        )
        callbacks.append(tensorboard_callback)

    # Calculate mAP
    if args.evaluation and validation_generator:
        evaluation = Evaluate(validation_generator,
                              tensorboard=tensorboard_callback,
                              weighted_average=args.weighted_average,
                              save_map_path=args.snapshot_path + '/mask_rcnn_fold_{}.txt'.format(args.fold))
        evaluation = RedirectModel(evaluation, prediction_model)
        callbacks.append(evaluation)

        # save prediction model with mAP in the file name
        # (this checkpoint needs the mAP metric from the evaluation callback above)
        if args.snapshots:
            checkpoint = keras.callbacks.ModelCheckpoint(
                os.path.join(
                    args.snapshot_path,
                    '{backbone}_fold_{fold}_{{mAP:.4f}}_ep_{{epoch:02d}}.h5'.format(backbone=args.backbone, fold=args.fold)
                ),
                verbose=1,
                save_best_only=False,
                monitor="mAP",
                mode='max'
            )
            checkpoint = RedirectModel(checkpoint, prediction_model)
            callbacks.append(checkpoint)

    callbacks.append(keras.callbacks.ReduceLROnPlateau(
        monitor  = 'loss',
        factor   = 0.9,
        patience = 3,
        verbose  = 1,
        mode     = 'auto',
        epsilon  = 0.0001,
        cooldown = 0,
        min_lr   = 0
    ))

    return callbacks
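
# Compatibility note (added): the augmentation pipeline below targets the
# albumentations 0.x API; the IAA* transforms were removed in albumentations 1.0
# and the quality_lower/quality_upper arguments of ImageCompression were
# deprecated later, so pin an old albumentations release (or adapt the
# transforms) when reproducing the training.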

def create_generators(args):
    bbox_params = BboxParams(format='pascal_voc', min_area=1.0, min_visibility=0.1, label_fields=['labels'])

    transform_generator = Compose([
        HorizontalFlip(p=0.5),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.1),
        OneOf([
            MotionBlur(p=.1),
            MedianBlur(blur_limit=3, p=.1),
            Blur(blur_limit=3, p=.1),
        ], p=0.1),
        ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=20, p=0.1, border_mode=cv2.BORDER_CONSTANT),
        OneOf([
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),
        ], p=0.1),
        OneOf([
            RGBShift(p=1.0, r_shift_limit=(-10, 10), g_shift_limit=(-10, 10), b_shift_limit=(-10, 10)),
            HueSaturationValue(p=1.0),
        ], p=0.1),
        ToGray(p=0.01),
        ImageCompression(p=0.05, quality_lower=50, quality_upper=99),
    ], bbox_params=bbox_params, p=1.0)

    train_generator = CSVGenerator(
        args.annotations,
        args.classes,
        dataset_path=args.dataset_location,
        transform_generator=transform_generator,
        batch_size=args.batch_size,
        config=args.config,
        image_min_side=800,
        image_max_side=1024,
        group_method=args.group_method,
        is_rle=False
    )

    if args.val_annotations:
        validation_generator = CSVGenerator(
            args.val_annotations,
            args.classes,
            dataset_path=args.dataset_location,
            batch_size=args.batch_size,
            config=args.config,
            image_min_side=800,
            image_max_side=1024,
            group_method=args.group_method,
            is_rle=False
        )
    else:
        validation_generator = None

    return train_generator, validation_generator


def check_args(parsed_args):
    """
    Function to check for inherent contradictions within parsed arguments.
    For example, batch_size < num_gpus.
    Intended to raise errors prior to backend initialisation.

    :param parsed_args: parser.parse_args()
    :return: parsed_args
    """

    return parsed_args

def parse_args(args):
    parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet mask network.')
    subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
    subparsers.required = True

    coco_parser = subparsers.add_parser('coco')
    coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')

    csv_parser = subparsers.add_parser('csv')
    csv_parser.add_argument('dataset_location', help='Path to OID training images.')
    csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')
    csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
    csv_parser.add_argument('--val-annotations', help='Path to CSV file containing annotations for validation (optional).')

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--snapshot', help='Resume training from a snapshot.')
    group.add_argument('--imagenet-weights', help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True)
    group.add_argument('--weights', help='Initialize the model with weights from a file.')
    group.add_argument('--no-weights', help='Don\'t initialize the model with any weights.', dest='imagenet_weights', action='store_const', const=False)

    parser.add_argument('--backbone', help='Backbone model used by retinanet.', default='resnet50', type=str)
    parser.add_argument('--batch-size', help='Size of the batches.', default=1, type=int)
    parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
    parser.add_argument('--epochs', help='Number of epochs to train.', type=int, default=50)
    parser.add_argument('--fold', help='Fold number.', type=int, default=1)
    parser.add_argument('--steps', help='Number of steps per epoch.', type=int, default=10000)
    parser.add_argument('--lr', help='Learning rate.', type=float, default=1e-5)
    parser.add_argument('--accum_iters', help='Gradient accumulation iterations. If more than 1, an accumulating Adam optimizer (AdamAccum) is expected; the parameter is parsed but not used by this script.', type=int, default=1)
    parser.add_argument('--snapshot-path', help='Path to store snapshots of models during training (defaults to \'./snapshots\')', default='./snapshots')
    parser.add_argument('--tensorboard-dir', help='Log directory for Tensorboard output', default='./logs')
    parser.add_argument('--no-snapshots', help='Disable saving snapshots.', dest='snapshots', action='store_false')
    parser.add_argument('--no-evaluation', help='Disable per epoch evaluation.', dest='evaluation', action='store_false')
    parser.add_argument('--freeze-backbone', help='Freeze training of backbone layers.', action='store_true')
    parser.add_argument('--no-class-specific-filter', help='Disables class specific filtering.', dest='class_specific_filter', action='store_false')
    parser.add_argument('--config', help='Path to a configuration parameters .ini file.')
    parser.add_argument('--weighted-average', help='Compute the mAP using the weighted average of precisions among classes.', action='store_true')
    parser.add_argument('--group_method', help='How to form batches', default='random')

    return check_args(parser.parse_args(args))


def main(args=None):
    from keras import backend as K

    # parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # make sure keras is the minimum required version
    check_keras_version()

    # create object that stores backbone information
    backbone = models.backbone(args.backbone)

    # optionally choose specific GPU
    if args.gpu:
        print('Use GPU: {}'.format(args.gpu))
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    keras.backend.tensorflow_backend.set_session(get_session())

    # create the generators
    train_generator, validation_generator = create_generators(args)

    # create the model
    if args.snapshot is not None:
        print('Loading model {}, this may take a second...'.format(args.snapshot))
        model = models.load_model(args.snapshot, backbone_name=args.backbone)
        training_model = model
        prediction_model = model
    else:
        weights = args.weights
        # default to imagenet if nothing else is specified
        if weights is None and args.imagenet_weights:
            weights = backbone.download_imagenet()

        anchor_params = None

        print('Creating model, this may take a second...')
        model, training_model, prediction_model = create_models(
            backbone_retinanet=backbone.maskrcnn,
            num_classes=train_generator.num_classes(),
            weights=weights,
            args=args,
            freeze_backbone=args.freeze_backbone,
            class_specific_filter=args.class_specific_filter,
            anchor_params=anchor_params
        )

    # print model summary
    print(model.summary())

    print('Learning rate: {}'.format(K.get_value(model.optimizer.lr)))
    if args.lr > 0.0:
        K.set_value(model.optimizer.lr, args.lr)
        print('Updated learning rate: {}'.format(K.get_value(model.optimizer.lr)))

    # create the callbacks
    callbacks = create_callbacks(
        model,
        training_model,
        prediction_model,
        validation_generator,
        args,
    )

    initial_epoch = 0
    if args.snapshot is not None:
        initial_epoch = int((args.snapshot.split('_')[-1]).split('.')[0])

    # start training
    training_model.fit_generator(
        generator=train_generator,
        steps_per_epoch=args.steps,
        epochs=args.epochs,
        verbose=1,
        callbacks=callbacks,
        max_queue_size=1,
        initial_epoch=initial_epoch,
    )


if __name__ == '__main__':
    models_path = './maskrcnn_training_models/'
    if not os.path.isdir(models_path):
        os.mkdir(models_path)
    DATASET_PATH = 'C:/Projects/2019_Google_Open_Images/input/'

    params = [
        # '--snapshot', models_path + 'mask_rcnn_resnet50_oid_v1.0.h5',
        # '--imagenet-weights',
        # '--freeze-backbone',
        '--weights', '../mask_rcnn_resnet50_oid_v1.0.h5',
        '--epochs', '1000',
        '--gpu', '2',
        '--steps', '5',
        '--snapshot-path', models_path,
        '--lr', '1e-5',
        '--backbone', 'resnet50',
        '--group_method', 'random',
        '--batch-size', '1',
        'csv',
        DATASET_PATH,
        DATASET_PATH + 'data_segmentation/challenge-2019-train-segmentation-masks.csv',
        DATASET_PATH + 'data_segmentation/challenge-2019-classes-description-segmentable.csv',
        '--val-annotations', DATASET_PATH + 'data_segmentation/challenge-2019-validation-segmentation-masks.csv',
    ]
    main(params)
--------------------------------------------------------------------------------
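
The `params` list at the bottom of training/train_maskrcnn.py encodes the equivalent of the command line below (a sketch with placeholder paths; note that the script as shipped calls `main(params)` rather than reading `sys.argv`, so in practice you edit the list in place):

    python training/train_maskrcnn.py \
        --weights mask_rcnn_resnet50_oid_v1.0.h5 \
        --backbone resnet50 --batch-size 1 --epochs 1000 --steps 5 --lr 1e-5 \
        --snapshot-path ./maskrcnn_training_models/ --group_method random \
        csv <dataset_location>/ \
        <dataset_location>/data_segmentation/challenge-2019-train-segmentation-masks.csv \
        <dataset_location>/data_segmentation/challenge-2019-classes-description-segmentable.csv \
        --val-annotations <dataset_location>/data_segmentation/challenge-2019-validation-segmentation-masks.csv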