├── Detection-Example-A3.ipynb
├── Detection-Example.ipynb
├── README.md
├── extras
│   ├── graphics
│   │   ├── 3D_Convolution_Animation.gif
│   │   ├── PPE_logo.jpg
│   │   ├── dataset.png
│   │   ├── demo.gif
│   │   ├── methods.jpg
│   │   ├── results.jpg
│   │   ├── yolo.jpg
│   │   ├── yolo_complete.jpg
│   │   ├── yolo_conv_block.jpg
│   │   ├── yolo_output_block.jpg
│   │   └── yolo_residual_block.jpg
│   └── sample-images
│       ├── 0.JPG
│       ├── 1.JPG
│       ├── 10.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       ├── 4.jpg
│       ├── 5.jpg
│       ├── 6.JPG
│       ├── 7.JPG
│       ├── 8.jpg
│       └── 9.jpg
├── model-data
│   └── weights
│       └── readme.md
├── src
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── datagen.py
│   │   ├── fixes.py
│   │   └── image.py
│   └── yolo3
│       ├── __init__.py
│       ├── detect.py
│       └── model.py
└── tutorials
    ├── Build-YOLO-Model.ipynb
    └── Object-Detection.ipynb

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Logo

# **Real-time Detection of Personal Protective Equipment (PPE)**

Demo

## **Table of Contents**
1. [Introduction](#introduction)
2. [Implementation](#implementation)
3. [Dataset](#dataset)
4. [Pre-trained models](#pre-trained-models)
5. [Tutorials](#tutorials)

## **Introduction**

This repository presents a TensorFlow 2.0 (Keras) implementation of real-time detection of workers' PPE (e.g., hard hat, safety vest) compliance. The detailed procedures can be found in the following paper.

### Article

[**Deep Learning for Site Safety: Real-Time Detection of Personal Protective Equipment**](https://www.sciencedirect.com/science/article/abs/pii/S0926580519308325)\
Nipun D. Nath, Amir H. Behzadan, Stephanie G. Paal \
Automation in Construction 112, pp. 103085

Please cite the article if you use the dataset, model, or method(s), or if you find the article useful in your research. Thank you!

### LaTeX citation:

    @article{Nath2020,
      author  = {Nipun D. Nath and Amir H. Behzadan and Stephanie G. Paal},
      title   = {Deep Learning for Site Safety: Real-Time Detection of Personal Protective Equipment},
      journal = {Automation in Construction},
      volume  = {112},
      year    = {2020},
      pages   = {103085}
    }

## **Implementation**

### Dependencies
- `tensorflow 2.0`
- `numpy`
- `opencv`
- `matplotlib`
- `pandas`

Please follow the [`Detection-Example.ipynb`](https://github.com/nipundebnath/pictor-ppe/blob/master/Detection-Example.ipynb) notebook to test the models.

## **Dataset**

### Dataset statistics

The dataset (named **Pictor-v3**) contains 774 crowd-sourced and 698 web-mined images. The crowd-sourced and web-mined images contain 2,496 and 2,230 instances of workers, respectively. The dataset is available [here](https://drive.google.com/drive/folders/1akhyTNVrkqMMcIFUQCEbW5ehfmG0CdYH?usp=sharing). Brief statistics of the dataset are shown in the following figure.

Dataset

### Annotation example

**@TODO:** Please stay tuned!

### Download the crowd-sourced dataset

The crowd-sourced images and annotated labels can be accessed from the corresponding subfolders in the dataset.

## **Methods/Approaches**

The paper presents three different approaches for verifying PPE compliance:

Methods/Approaches

**Approach-1**: The YOLO-v3-A1 model detects workers, hats, and vests (three object classes) individually. Next, ML classifiers (Neural Network, Decision Tree) classify each detected worker as W (wearing no hat or vest), WH (wearing only a hat), WV (wearing only a vest), or WHV (wearing both), as sketched below.
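To make the two-stage idea concrete, here is a minimal sketch of how the second stage could be wired up, assuming simple box-overlap features. `overlap_ratio` and `worker_features` are hypothetical helpers, and the actual features and training procedure are the ones described in the paper:

```python
from sklearn.tree import DecisionTreeClassifier  # one of the ML classifiers listed above

def overlap_ratio(worker_box, item_box):
    # fraction of the item (hat/vest) box that falls inside the worker box;
    # all boxes are (x1, y1, x2, y2)
    x1, y1 = max(worker_box[0], item_box[0]), max(worker_box[1], item_box[1])
    x2, y2 = min(worker_box[2], item_box[2]), min(worker_box[3], item_box[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    item_area = (item_box[2] - item_box[0]) * (item_box[3] - item_box[1])
    return inter / item_area if item_area > 0 else 0.0

def worker_features(worker_box, hat_boxes, vest_boxes):
    # hypothetical 2-feature vector: best overlap with any detected hat / vest
    f_hat = max((overlap_ratio(worker_box, b) for b in hat_boxes), default=0.0)
    f_vest = max((overlap_ratio(worker_box, b) for b in vest_boxes), default=0.0)
    return [f_hat, f_vest]

# with X (feature vectors) and y (labels in {'W', 'WH', 'WV', 'WHV'}) built
# from the training detections:
# clf = DecisionTreeClassifier().fit(X, y)
# clf.predict([worker_features(worker_box, hat_boxes, vest_boxes)])
```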
**Approach-2**: The YOLO-v3-A2 model localizes workers in the input image and directly classifies each detected worker as W, WH, WV, or WHV.

**Approach-3**: The YOLO-v3-A3 model first detects all workers in the input image; then a CNN-based classifier (VGG-16, ResNet-50, or Xception) is applied to the cropped worker images to classify each detected worker as W, WH, WV, or WHV.

## **Results**

- The mean average precision (mAP) of the YOLO models in detecting individual classes:

Results

- The overall mAP in verifying workers' PPE compliance:

Results

- The best-performing approach is Approach-2 (72.3% mAP).

## **Pre-trained Models**

You can download the models trained on the Pictor-v3 dataset from the "Models" folder in the dataset. These models include:

- YOLO-v3-A1
- YOLO-v3-A2
- YOLO-v3-A3
- ML Classifiers (Approach-1)
- CNN Classifiers (Approach-3)

## **Tutorials**

Please follow the notebooks in the [tutorials](https://github.com/nipundebnath/pictor-ppe/blob/master/tutorials/) folder to learn more about:
- building a YOLO model from scratch using TensorFlow 2.0
- interpreting the YOLO output and converting it to bounding boxes with class labels and confidence scores

**@TODO:** Please stay tuned! More tutorials are coming soon.

--------------------------------------------------------------------------------
/extras/graphics/3D_Convolution_Animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/3D_Convolution_Animation.gif

--------------------------------------------------------------------------------
/extras/graphics/PPE_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/PPE_logo.jpg

--------------------------------------------------------------------------------
/extras/graphics/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/dataset.png

--------------------------------------------------------------------------------
/extras/graphics/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/demo.gif

--------------------------------------------------------------------------------
/extras/graphics/methods.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/methods.jpg

--------------------------------------------------------------------------------
/extras/graphics/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/results.jpg

--------------------------------------------------------------------------------
/extras/graphics/yolo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/yolo.jpg

--------------------------------------------------------------------------------
/extras/graphics/yolo_complete.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/yolo_complete.jpg

--------------------------------------------------------------------------------
/extras/graphics/yolo_conv_block.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/yolo_conv_block.jpg

--------------------------------------------------------------------------------
/extras/graphics/yolo_output_block.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/yolo_output_block.jpg

--------------------------------------------------------------------------------
/extras/graphics/yolo_residual_block.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/graphics/yolo_residual_block.jpg

--------------------------------------------------------------------------------
/extras/sample-images/0.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/0.JPG

--------------------------------------------------------------------------------
/extras/sample-images/1.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/1.JPG

--------------------------------------------------------------------------------
/extras/sample-images/10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/10.jpg

--------------------------------------------------------------------------------
/extras/sample-images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/2.jpg

--------------------------------------------------------------------------------
/extras/sample-images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/3.jpg

--------------------------------------------------------------------------------
/extras/sample-images/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/4.jpg

--------------------------------------------------------------------------------
/extras/sample-images/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/5.jpg

--------------------------------------------------------------------------------
/extras/sample-images/6.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/6.JPG

--------------------------------------------------------------------------------
/extras/sample-images/7.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/7.JPG

--------------------------------------------------------------------------------
/extras/sample-images/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/8.jpg

--------------------------------------------------------------------------------
/extras/sample-images/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/extras/sample-images/9.jpg

--------------------------------------------------------------------------------
/model-data/weights/readme.md:
--------------------------------------------------------------------------------
Download the trained weights of the YOLO models ([Google Drive folder](https://drive.google.com/drive/folders/13tCdROHnS0c5VibW1VO8pOEj0rXEvvGj?usp=sharing)) and put them in this folder.
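For reference, once the weights are in place they can be loaded with Keras exactly as in `tutorials/Build-YOLO-Model.ipynb` (the file name below is the A1 model from the Drive folder; adjust the relative path to where you run the code from):

```python
# assumes `model` was built with yolo_body() from src/yolo3/model.py
weight_path = 'model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5'
model.load_weights(weight_path)
```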
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/src/__init__.py

--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/src/utils/__init__.py

--------------------------------------------------------------------------------
/src/utils/datagen.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np

from matplotlib.colors import rgb_to_hsv, hsv_to_rgb


def rand(a=0, b=1):
    return np.random.rand()*(b-a) + a


def get_random_data(
    annotation_line,
    input_shape,
    max_boxes=25,
    scale=.3,
    hue=.1,
    sat=1.5,
    val=1.5,
    random=True
):
    '''
    random preprocessing for real-time data augmentation
    '''
    line = annotation_line.split('\t')
    h, w = input_shape
    box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

    image = cv2.imread(line[0])
    ih, iw, ic = image.shape

    if not random:
        # deterministic letterboxing (no augmentation)
        resize_scale = min(h/ih, w/iw)

        nw = int(iw * resize_scale)
        nh = int(ih * resize_scale)

        max_offx = w - nw
        max_offy = h - nh

        dx = max_offx//2
        dy = max_offy//2

        to_x0, to_y0 = max(0, dx), max(0, dy)
        from_x0, from_y0 = max(0, -dx), max(0, -dy)
        wx, hy = min(w, dx+nw) - to_x0, min(h, dy+nh) - to_y0

        # place the image on a gray canvas
        image_data = np.zeros((*input_shape, ic), dtype='uint8') + 128
        image_data[to_y0:to_y0+hy, to_x0:to_x0+wx, :] = cv2.resize(image, (nw, nh))[from_y0:from_y0+hy, from_x0:from_x0+wx, :]

        flip = False
        image_data = image_data/255.
    else:
        if np.random.uniform() >= 0.5:
            # scale up
            resize_scale = 1. + scale * np.random.uniform()
            resize_scale = max(h*resize_scale/ih, w*resize_scale/iw)

            nw = int(iw * resize_scale)
            nh = int(ih * resize_scale)

            max_offx = nw - w
            max_offy = nh - h

            dx = int(np.random.uniform() * max_offx)
            dy = int(np.random.uniform() * max_offy)

            # resize and crop
            image = cv2.resize(image, (nw, nh))
            image_data = image[dy : (dy + h), dx : (dx + w), :]

            dx, dy = (-dx, -dy)
        else:
            # scale down
            mul = 1 if np.random.uniform() >= 0.5 else -1

            resize_scale = 1. + mul * scale * np.random.uniform()
            resize_scale = min(h*resize_scale/ih, w*resize_scale/iw)

            nw = int(iw * resize_scale)
            nh = int(ih * resize_scale)

            max_offx = w - nw
            max_offy = h - nh

            dx = int(np.random.uniform() * max_offx)
            dy = int(np.random.uniform() * max_offy)

            to_x0, to_y0 = max(0, dx), max(0, dy)
            from_x0, from_y0 = max(0, -dx), max(0, -dy)
            wx, hy = min(w, dx+nw) - to_x0, min(h, dy+nh) - to_y0

            # place the image on a gray canvas
            image_data = np.zeros((*input_shape, ic), dtype='uint8') + 128
            image_data[to_y0:to_y0+hy, to_x0:to_x0+wx, :] = cv2.resize(image, (nw, nh))[from_y0:from_y0+hy, from_x0:from_x0+wx, :]

        # random horizontal flip
        flip = np.random.uniform() >= 0.5
        if flip: image_data = image_data[:, ::-1, :]

        # distort the color of the image
        hue = rand(-hue, hue)
        sat = rand(1, sat) if rand() < .5 else 1/rand(1, sat)
        val = rand(1, val) if rand() < .5 else 1/rand(1, val)
        x = rgb_to_hsv(np.array(image_data)/255.)
        x[..., 0] += hue
        x[..., 0][x[..., 0] > 1] -= 1
        x[..., 0][x[..., 0] < 0] += 1
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x > 1] = 1
        x[x < 0] = 0
        image_data = hsv_to_rgb(x)  # numpy array, 0 to 1

    # correct the boxes for the applied scaling, offset, and flip
    box_data = np.zeros((max_boxes, 5))
    if len(box) > 0:
        np.random.shuffle(box)
        box[:, [0, 2]] = box[:, [0, 2]]*nw/iw + dx
        box[:, [1, 3]] = box[:, [1, 3]]*nh/ih + dy
        if flip: box[:, [0, 2]] = w - box[:, [2, 0]]
        box[:, 0:2][box[:, 0:2] < 0] = 0
        box[:, 2][box[:, 2] > w] = w
        box[:, 3][box[:, 3] > h] = h
        box_w = box[:, 2] - box[:, 0]
        box_h = box[:, 3] - box[:, 1]
        box = box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid boxes
        if len(box) > max_boxes: box = box[:max_boxes]
        box_data[:len(box)] = box

    return image_data, box_data


def data_generator(annotation_lines, batch_size, input_shape, random):
    '''
    data generator for fit_generator
    '''
    n = len(annotation_lines)
    i = 0
    while True:
        image_data = []
        box_data = []
        for _ in range(batch_size):
            image, box = get_random_data(annotation_lines[i], input_shape, max_boxes=50, random=random)
            image_data.append(image)
            box = box[np.sum(box, axis=1) != 0, :]  # drop all-zero padding rows
            box_data.append(box)
            i = (i+1) % n
        image_data = np.array(image_data)
        # note: each image may keep a different number of boxes, so this can
        # become a ragged (object-dtype) array on newer NumPy versions
        box_data = np.array(box_data)

        yield image_data, box_data


def data_generator_wrapper(annotation_lines, batch_size, input_shape, random):
    n = len(annotation_lines)
    if n == 0 or batch_size <= 0: return None
    return data_generator(annotation_lines, batch_size, input_shape, random)
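A minimal sketch of how these generators might be driven. The annotation format is inferred from `get_random_data`: one image per line, with the image path and each `x1,y1,x2,y2,class` box separated by tabs; `annotations.txt` is a hypothetical file name:

```python
from src.utils.datagen import data_generator_wrapper  # assumes the repo root is on sys.path

with open('annotations.txt') as f:   # hypothetical annotation file
    lines = f.read().splitlines()

gen = data_generator_wrapper(lines, batch_size=8, input_shape=(416, 416), random=True)

# each step yields a batch of augmented images (pixel values in [0, 1]) and,
# for every image, its surviving [x1, y1, x2, y2, class] boxes
images, boxes = next(gen)
```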
--------------------------------------------------------------------------------
/src/utils/fixes.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def fix_tf_gpu():
    '''
    Fix for the following error message:
        UnknownError: Failed to get convolution algorithm.
        This is probably because cuDNN failed to initialize...

    More:
        https://www.tensorflow.org/api_docs/python/tf/config/experimental/set_memory_growth
    '''
    physical_devices = tf.config.experimental.list_physical_devices('GPU')

    try:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
    except Exception:
        # no GPU available, or memory growth must be set before the GPU is initialized
        pass

--------------------------------------------------------------------------------
/src/utils/image.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import matplotlib as mpl


def letterbox_image(image, size):
    '''
    Resize image with unchanged aspect ratio using padding
    '''

    # original image size
    ih, iw, ic = image.shape

    # given size
    h, w = size

    # scale and new size of the image
    scale = min(w/iw, h/ih)
    nw = int(iw*scale)
    nh = int(ih*scale)

    # placeholder letter box
    new_image = np.zeros((h, w, ic), dtype='uint8') + 128

    # top-left corner
    top, left = (h - nh)//2, (w - nw)//2

    # paste the scaled image into the placeholder, anchored at the top-left corner
    new_image[top:top+nh, left:left+nw, :] = cv2.resize(image, (nw, nh))

    return new_image


def draw_detection(
    img,
    boxes,
    class_names,
    # drawing configs
    font=cv2.FONT_HERSHEY_DUPLEX,
    font_scale=0.5,
    box_thickness=2,
    border=5,
    text_color=(255, 255, 255),
    text_weight=1
):
    '''
    Draw the bounding boxes on the image
    '''
    # generate some colors for the different classes
    num_classes = len(class_names)  # number of classes
    colors = [mpl.colors.hsv_to_rgb((i/num_classes, 1, 1)) * 255 for i in range(num_classes)]

    # draw the detections
    for box in boxes:
        x1, y1, x2, y2 = box[:4].astype(int)
        score = box[-2]
        label = int(box[-1])

        clr = colors[label]

        # draw the bounding box
        img = cv2.rectangle(img, (x1, y1), (x2, y2), clr, box_thickness)

        # text: <class name> (<score>%)
        text = f'{class_names[label]} ({score*100:.0f}%)'

        # get the width (tw) and height (th) of the text
        (tw, th), _ = cv2.getTextSize(text, font, font_scale, 1)

        # background rectangle for the text
        tb_x1 = x1 - box_thickness//2
        tb_y1 = y1 - box_thickness//2 - th - 2*border
        tb_x2 = x1 + tw + 2*border
        tb_y2 = y1

        # draw the background rectangle
        img = cv2.rectangle(img, (tb_x1, tb_y1), (tb_x2, tb_y2), clr, -1)

        # put the text
        img = cv2.putText(img, text, (x1 + border, y1 - border), font, font_scale, text_color, text_weight, cv2.LINE_AA)

    return img
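A short usage sketch for these two helpers, assuming the repo root is on `sys.path`. The `boxes` row below is made up just to exercise the drawing code; its `[x1, y1, x2, y2, score, label]` layout matches what `src/yolo3/detect.py` produces, and the class names shown are the Approach-2/3 compliance labels:

```python
import cv2
import numpy as np
from src.utils.image import letterbox_image, draw_detection

img = cv2.imread('extras/sample-images/0.JPG')   # a sample image shipped with the repo
sized = letterbox_image(img, (416, 416))         # padded to the YOLO input size

boxes = np.array([[50, 40, 210, 380, 0.87, 3]])  # one fake worker wearing hat and vest
img = draw_detection(img, boxes, class_names=['W', 'WH', 'WV', 'WHV'])
cv2.imwrite('detection.jpg', img)
```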
--------------------------------------------------------------------------------
/src/yolo3/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciber-lab/pictor-ppe/8f41c9525bbd1f45245f593c7502a1934eef06fa/src/yolo3/__init__.py

--------------------------------------------------------------------------------
/src/yolo3/detect.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf


def detection(
    prediction,
    anchor_boxes,
    num_classes,
    image_shape,
    input_shape,
    max_boxes=20,
    score_threshold=0.3,
    iou_threshold=0.45,
    classes_can_overlap=True,
):
    '''
    INPUT:
        prediction          : list of raw YOLO output tensors, one per output layer
        anchor_boxes        : anchor (width, height) pairs, grouped per output layer
        num_classes         : number of object classes
        image_shape         : (height, width) of the actual image
        input_shape         : (height, width) of the YOLO input, e.g., (416, 416)
        max_boxes           : maximum number of boxes kept per NMS run
        score_threshold     : minimum confidence score to keep a box
        iou_threshold       : IoU threshold for non-max suppression
        classes_can_overlap : if True, class probabilities use sigmoid and NMS runs
                              per class; otherwise softmax and a single NMS run
    OUTPUT:
        a list with one tensor per image in the batch; each row is
        [ left(x1), top(y1), right(x2), bottom(y2), confidence, label ]
    '''

    all_boxes = []

    '''For each output layer'''
    for output, anchors in zip(prediction, anchor_boxes):

        '''Preprocessing'''
        '''-------------'''
        # shapes
        batch_size = output.shape[0]
        grid_h, grid_w = output.shape[1:3]

        # reshape to [batch_size, grid_height, grid_width, num_anchors, box_params]
        output = tf.reshape(output, [-1, grid_h, grid_w, len(anchors), num_classes+5])

        # create a tensor for the anchor boxes
        anchors_tensor = tf.constant(anchors, dtype=output.dtype)

        '''Scaling factors'''
        '''---------------'''
        image_shape_tensor = tf.cast(image_shape, output.dtype)         # actual image's shape
        grids_shape_tensor = tf.cast(output.shape[1:3], output.dtype)   # grid_height, grid_width @ output layer
        input_shape_tensor = tf.cast(input_shape, output.dtype)         # yolo input image's shape

        # reshape
        image_shape_tensor = tf.reshape(image_shape_tensor, [-1, 1, 1, 1, 2])
        grids_shape_tensor = tf.reshape(grids_shape_tensor, [-1, 1, 1, 1, 2])
        input_shape_tensor = tf.reshape(input_shape_tensor, [-1, 1, 1, 1, 2])

        ### Scaling factors
        sized_shape_tensor = tf.round(image_shape_tensor * tf.reshape(tf.reduce_min(input_shape_tensor / image_shape_tensor, axis=-1), [-1, 1, 1, 1, 1]))
        # to scale the boxes from the grid's unit to the actual image's pixel unit
        box_scaling = input_shape_tensor * image_shape_tensor / sized_shape_tensor / grids_shape_tensor
        # to offset the boxes
        box_offsets = (tf.expand_dims(tf.reduce_max(image_shape_tensor, axis=-1), axis=-1) - image_shape_tensor) / 2.

        '''Box geometric properties'''
        '''------------------------'''
        grid_h, grid_w = output.shape[1:3]  # grid_height, grid_width @ output layer

        grid_i = tf.reshape(np.arange(grid_h), [-1, 1, 1, 1])
        grid_i = tf.tile(grid_i, [1, grid_w, 1, 1])

        grid_j = tf.reshape(np.arange(grid_w), [1, -1, 1, 1])
        grid_j = tf.tile(grid_j, [grid_h, 1, 1, 1])

        grid_ji = tf.concat([grid_j, grid_i], axis=-1)
        grid_ji = tf.cast(grid_ji, output.dtype)

        # box centers
        box_xy = output[..., 0:2]
        box_xy = tf.sigmoid(box_xy) + grid_ji

        # box sizes
        box_wh = output[..., 2:4]
        box_wh = tf.exp(box_wh) * anchors_tensor

        # scale to actual pixel units
        box_xy = box_xy * box_scaling - box_offsets[..., ::-1]
        box_wh = box_wh * box_scaling

        # calculate the top-left corner (x1, y1) and bottom-right corner (x2, y2) of the boxes
        box_x1_y1 = box_xy - box_wh / 2
        box_x2_y2 = box_xy + box_wh / 2

        # the top-left corner cannot be negative
        box_x1_y1 = tf.maximum(0, box_x1_y1)
        # the bottom-right corner cannot exceed the actual image size
        box_x2_y2 = tf.minimum(box_x2_y2, image_shape_tensor[..., ::-1])

        '''Box labels and confidences'''
        '''--------------------------'''
        # class probabilities = objectness score * conditional class probabilities
        if classes_can_overlap:
            # use sigmoid for the conditional class probabilities
            class_probs = tf.sigmoid(output[..., 4:5]) * tf.sigmoid(output[..., 5:])
        else:
            # use softmax for the conditional class probabilities
            class_probs = tf.sigmoid(output[..., 4:5]) * tf.nn.softmax(output[..., 5:])

        box_cl = tf.argmax(class_probs, axis=-1)      # final classes
        box_sc = tf.reduce_max(class_probs, axis=-1)  # confidence scores

        '''Organize'''
        '''--------'''
        # take care of dtype and dimensions
        box_cl = tf.cast(box_cl, output.dtype)
        box_cl = tf.expand_dims(box_cl, axis=-1)
        box_sc = tf.expand_dims(box_sc, axis=-1)

        # store all information as: [ left(x1), top(y1), right(x2), bottom(y2), confidence, label ]
        boxes = tf.reshape(tf.concat([box_x1_y1, box_x2_y2, box_sc, box_cl], axis=-1),
                           [batch_size, -1, 6])

        all_boxes.append(boxes)

    # merge across all output layers
    all_boxes = tf.concat(all_boxes, axis=1)

    # to store the final results of all images in the batch
    all_final_boxes = []

    '''For each image in the batch'''
    for _boxes_ in all_boxes:

        if classes_can_overlap:
            '''Perform NMS for each class individually'''

            # to store the final results of this image
            final_boxes = []

            for class_id in range(num_classes):

                # get the boxes and scores for this class
                class_boxes = _boxes_[_boxes_[..., -1] == class_id]

                '''Non-max suppression'''
                selected_idc = tf.image.non_max_suppression(
                    class_boxes[..., :4],   # box corners (x1,y1,x2,y2); IoU is unaffected by the consistent x/y swap
                    class_boxes[..., -2],   # box scores
                    max_output_size=max_boxes,
                    iou_threshold=iou_threshold,
                    score_threshold=score_threshold
                )

                # boxes selected by NMS
                class_boxes = tf.gather(class_boxes, selected_idc)
                final_boxes.append(class_boxes)

            # concatenate the boxes of all classes in the image
            final_boxes = tf.concat(final_boxes, axis=0)

        else:
            '''Perform NMS across all classes at once'''

            # nms indices
            selected_idc = tf.image.non_max_suppression(
                _boxes_[..., :4],   # box corners (x1,y1,x2,y2)
                _boxes_[..., -2],   # box scores
                max_output_size=max_boxes,
                iou_threshold=iou_threshold,
                score_threshold=score_threshold
            )

            # boxes selected by NMS
            final_boxes = tf.gather(_boxes_, selected_idc)

        # append the final boxes of each image in the batch
        all_final_boxes.append(final_boxes)

    return all_final_boxes
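A rough sketch of wiring the model output into `detection()`. Here `model`, `batch`, `ih`, and `iw` are assumed to exist, and the anchor values are the standard YOLO-v3 anchors shown only as placeholders — the anchors actually used for the Pictor-v3 models are not listed in this snapshot:

```python
# one group of (width, height) anchors per output layer, largest first
# (matching the 13x13, 26x26, and 52x52 output grids of a 416x416 input)
anchor_boxes = [
    [(116, 90), (156, 198), (373, 326)],
    [(30, 61), (62, 45), (59, 119)],
    [(10, 13), (16, 30), (33, 23)],
]

prediction = model.predict(batch)   # list of 3 raw output tensors
boxes_per_image = detection(
    prediction, anchor_boxes, num_classes=3,
    image_shape=(ih, iw),           # original image size (assumed uniform across the batch)
    input_shape=(416, 416),
)
```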
--------------------------------------------------------------------------------
/src/yolo3/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras.layers import Input, Conv2D, Add, ZeroPadding2D, UpSampling2D, Concatenate, MaxPooling2D
from tensorflow.keras.layers import LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.regularizers import l2

'''===================================================================================================='''
'''BLOCKS'''

'''Convolutional Block'''
def yolo_ConvBlock(input_tensor, num_filters, filter_size, strides=(1, 1)):
    padding = 'valid' if strides == (2, 2) else 'same'

    ### Layers
    x = Conv2D(num_filters, filter_size, strides, padding, use_bias=False, kernel_regularizer=l2(5e-4))(input_tensor)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.1)(x)

    return x

'''Residual Block'''
def yolo_ResidualBlocks(input_tensor, num_filters, num_blocks):

    ### Layers
    x = ZeroPadding2D(((1, 0), (1, 0)))(input_tensor)  # left & top padding
    x = yolo_ConvBlock(x, num_filters, filter_size=(3, 3), strides=(2, 2))

    for _ in range(num_blocks):
        y = yolo_ConvBlock(x, num_filters//2, filter_size=(1, 1), strides=(1, 1))
        y = yolo_ConvBlock(y, num_filters, filter_size=(3, 3), strides=(1, 1))
        x = Add()([x, y])

    return x

'''Output Block'''
def yolo_OutputBlock(x, num_filters, out_filters):

    ### Layers
    x = yolo_ConvBlock(x, 1*num_filters, filter_size=(1, 1), strides=(1, 1))
    x = yolo_ConvBlock(x, 2*num_filters, filter_size=(3, 3), strides=(1, 1))
    x = yolo_ConvBlock(x, 1*num_filters, filter_size=(1, 1), strides=(1, 1))
    x = yolo_ConvBlock(x, 2*num_filters, filter_size=(3, 3), strides=(1, 1))
    x = yolo_ConvBlock(x, 1*num_filters, filter_size=(1, 1), strides=(1, 1))

    y = yolo_ConvBlock(x, 2*num_filters, filter_size=(3, 3), strides=(1, 1))
    y = Conv2D(filters=out_filters, kernel_size=(1, 1), strides=(1, 1),
               padding='same', use_bias=True, kernel_regularizer=l2(5e-4))(y)

    return x, y

'''===================================================================================================='''
'''COMPLETE MODEL'''

def yolo_body(input_tensor, num_out_filters):
    '''
    Input:
        input_tensor = Input( shape=( *input_shape, 3 ) )
        num_out_filters = ( num_anchors // 3 ) * ( 5 + num_classes )
    Output:
        the complete YOLO-v3 model
    '''

    # 1st Conv block
    x = yolo_ConvBlock(input_tensor, num_filters=32, filter_size=(3, 3), strides=(1, 1))

    # 5 residual blocks
    x = yolo_ResidualBlocks(x, num_filters=64, num_blocks=1)
    x = yolo_ResidualBlocks(x, num_filters=128, num_blocks=2)
    x = yolo_ResidualBlocks(x, num_filters=256, num_blocks=8)
    x = yolo_ResidualBlocks(x, num_filters=512, num_blocks=8)
    x = yolo_ResidualBlocks(x, num_filters=1024, num_blocks=4)

    darknet = Model(input_tensor, x)  # will use it just in a moment

    # 1st output block
    x, y1 = yolo_OutputBlock(x, num_filters=512, out_filters=num_out_filters)

    # 2nd output block
    x = yolo_ConvBlock(x, num_filters=256, filter_size=(1, 1), strides=(1, 1))
    x = UpSampling2D(2)(x)
    x = Concatenate()([x, darknet.layers[152].output])
    x, y2 = yolo_OutputBlock(x, num_filters=256, out_filters=num_out_filters)

    # 3rd output block
    x = yolo_ConvBlock(x, num_filters=128, filter_size=(1, 1), strides=(1, 1))
    x = UpSampling2D(2)(x)
    x = Concatenate()([x, darknet.layers[92].output])
    x, y3 = yolo_OutputBlock(x, num_filters=128, out_filters=num_out_filters)

    # final model
    model = Model(input_tensor, [y1, y2, y3])

    return model
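The `num_out_filters` argument follows the usual YOLO-v3 arithmetic: each output layer predicts `num_anchors // 3` boxes per grid cell, and each box carries 4 coordinates, 1 objectness score, and `num_classes` class scores. For the three-class A1 model built in the tutorial notebook below:

```python
num_anchors = 9   # 3 anchors at each of the 3 output scales
num_classes = 3   # e.g., worker, hat, vest for YOLO-v3-A1

# 3 boxes per cell * (4 box coords + 1 objectness + 3 class scores) = 24 filters
num_out_filters = (num_anchors // 3) * (5 + num_classes)
```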
--------------------------------------------------------------------------------
/tutorials/Build-YOLO-Model.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **BUILD YOLO MODEL**\n",
    "FROM SCRATCH, USING TENSORFLOW 2.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.backend as K\n",
    "\n",
    "from tensorflow.keras.layers import Input, Conv2D, Add, ZeroPadding2D, UpSampling2D, Concatenate, MaxPooling2D\n",
    "from tensorflow.keras.layers import LeakyReLU, BatchNormalization\n",
    "from tensorflow.keras.models import Sequential, Model\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from tensorflow.keras.utils import plot_model, model_to_dot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Acknowledgement:\n",
    "- Thanks to the authors of YOLO.\n",
    "- Many ideas for the implementation are derived from [qqwweee](https://github.com/qqwweee/keras-yolo3/blob/master/yolo3/model.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **YOLO Architecture**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Conv Block**\n",
    "(figure: Conv Block)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def yolo_ConvBlock (input_tensor, num_filters, filter_size, strides = (1,1) ):\n",
    "    padding = 'valid' if strides == (2,2) else 'same'\n",
    "    \n",
    "    '''Layers'''\n",
    "    x = Conv2D( num_filters, filter_size, strides, padding, use_bias=False, kernel_regularizer=l2(5e-4) ) (input_tensor)\n",
    "    x = BatchNormalization() (x)\n",
    "    x = LeakyReLU(alpha=0.1) (x)\n",
    "    \n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Residual Block**\n",
    "(figure: Residual Block)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def yolo_ResidualBlocks (input_tensor, num_filters, num_blocks ):\n",
    "    \n",
    "    '''Layers'''\n",
    "    x = ZeroPadding2D( ((1,0),(1,0)) ) (input_tensor) # left & top padding\n",
    "    x = yolo_ConvBlock ( x, num_filters, filter_size=(3,3), strides = (2,2) )\n",
    "    for _ in range( num_blocks ):\n",
    "        y = yolo_ConvBlock ( x, num_filters//2, filter_size=(1,1), strides = (1,1) )\n",
    "        y = yolo_ConvBlock ( y, num_filters  , filter_size=(3,3), strides = (1,1) )\n",
    "        x = Add() ([x, y])\n",
    "    \n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **Output Block**\n",
    "(figure: Output Block)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def yolo_OutputBlock (x, num_filters, out_filters ):\n",
    "    \n",
    "    '''Layers'''\n",
    "    x = yolo_ConvBlock ( x, 1*num_filters, filter_size=(1,1), strides = (1,1) )\n",
    "    x = yolo_ConvBlock ( x, 2*num_filters, filter_size=(3,3), strides = (1,1) )\n",
    "    x = yolo_ConvBlock ( x, 1*num_filters, filter_size=(1,1), strides = (1,1) )\n",
    "    x = yolo_ConvBlock ( x, 2*num_filters, filter_size=(3,3), strides = (1,1) )\n",
    "    x = yolo_ConvBlock ( x, 1*num_filters, filter_size=(1,1), strides = (1,1) )\n",
    "    \n",
    "    y = yolo_ConvBlock ( x, 2*num_filters, filter_size=(3,3), strides = (1,1) )\n",
    "    y = Conv2D ( filters=out_filters, kernel_size=(1,1), strides=(1,1), \n",
    "                 padding='same', use_bias=True, kernel_regularizer=l2(5e-4) )(y)\n",
    "    \n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **The Complete Model**\n",
    "(figure: The Complete Model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def yolo_body (input_tensor, num_out_filters):\n",
    "    '''\n",
    "    Input: \n",
    "        input_tensor = Input( shape=( *input_shape, 3 ) )\n",
    "        num_out_filters = ( num_anchors // 3 ) * ( 5 + num_classes )\n",
    "    Output:\n",
    "        complete YOLO-v3 model\n",
    "    '''\n",
    "\n",
    "    # 1st Conv block\n",
    "    x = yolo_ConvBlock( input_tensor, num_filters=32, filter_size=(3,3), strides=(1,1) )\n",
    "\n",
    "    # 5 Resblocks\n",
    "    x = yolo_ResidualBlocks ( x, num_filters=  64, num_blocks=1 )\n",
    "    x = yolo_ResidualBlocks ( x, num_filters= 128, num_blocks=2 )\n",
    "    x = yolo_ResidualBlocks ( x, num_filters= 256, num_blocks=8 )\n",
    "    x = yolo_ResidualBlocks ( x, num_filters= 512, num_blocks=8 )\n",
    "    x = yolo_ResidualBlocks ( x, num_filters=1024, num_blocks=4 )\n",
    "\n",
    "    darknet = Model( input_tensor, x ) # will use it just in a moment\n",
    "\n",
    "    # 1st output block\n",
    "    x, y1 = yolo_OutputBlock( x, num_filters= 512, out_filters=num_out_filters )\n",
    "\n",
    "    # 2nd output block\n",
    "    x = yolo_ConvBlock( x, num_filters=256, filter_size=(1,1), strides=(1,1) )\n",
    "    x = UpSampling2D(2) (x)\n",
    "    x = Concatenate() ( [x, darknet.layers[152].output] )\n",
    "    x, y2 = yolo_OutputBlock( x, num_filters= 256, out_filters=num_out_filters )\n",
    "\n",
    "    # 3rd output block\n",
    "    x = yolo_ConvBlock( x, num_filters=128, filter_size=(1,1), strides=(1,1) )\n",
    "    x = UpSampling2D(2) (x)\n",
    "    x = Concatenate() ( [x, darknet.layers[92].output] )\n",
    "    x, y3 = yolo_OutputBlock( x, num_filters= 128, out_filters=num_out_filters )\n",
    "\n",
    "    # Final model\n",
    "    model = Model( input_tensor, [y1, y2, y3] )\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **EXAMPLE**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Build a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model is built!\n"
     ]
    }
   ],
   "source": [
    "# configurations\n",
    "input_shape = (416, 416)\n",
    "num_classes = 3\n",
    "num_anchors = 9\n",
    "\n",
    "# input and output\n",
    "input_tensor = Input( shape=(input_shape[0], input_shape[1], 3) ) # Input\n",
    "num_out_filters = ( num_anchors//3 ) * ( 5 + num_classes ) # Output\n",
    "\n",
    "# build the model\n",
    "model = yolo_body (input_tensor, num_out_filters)\n",
    "\n",
    "print('model is built!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check if a pre-trained weight can be loaded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weights are loaded!\n"
     ]
    }
   ],
   "source": [
    "# check if trained weights can be loaded onto the model\n",
    "\n",
    "weight_path = '../model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5'\n",
    "model.load_weights( weight_path )\n",
    "\n",
    "print('weights are loaded!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
--------------------------------------------------------------------------------