├── README.md ├── generate-dataset.py ├── main.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | humble-yolo is a minimal implementation of YOLO v1 I wrote to learn about the amazing YOLO algorithm. 2 | 3 | To test it run : 4 | 5 | 1. generate-dataset.py to generate data 6 | 2. main.py --train --epoch 100 for training the network 7 | 8 | You should see a list of images with bounding boxes. The first 10 images are test data not used for training. You can evaluate the performance of the network on those. The remaining images have been used for the training. 9 | 10 | main.py saves weights when it complete training. If you want to run the network without training and just see the result, running main.py alone will load last weights and redisplay results. 11 | -------------------------------------------------------------------------------- /generate-dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import random 3 | import string 4 | import numpy as np 5 | import os 6 | 7 | def one_hot(x, length): 8 | return [1 if x==i else 0 for i in range(length)] 9 | 10 | def get_word(c): 11 | words = ["chat", "rat", "none"] 12 | return (words[c], one_hot(c,len(words))) 13 | 14 | cell_w = 32 15 | cell_h = 32 16 | grid_w = 2 17 | grid_h = 2 18 | 19 | if not os.path.exists('Labels'): 20 | os.mkdir('Labels') 21 | if not os.path.exists('Images'): 22 | os.mkdir('Images') 23 | 24 | for j in range(0,5000): 25 | img = Image.new('RGB', (grid_w*cell_w,grid_h*cell_h)) 26 | d = ImageDraw.Draw(img) 27 | 28 | with open('Labels/%d.txt' % j,'w+') as f: 29 | 30 | for row in range(grid_w): 31 | for col in range(grid_h): 32 | 33 | (digits, cat) = get_word(random.randint(0,2)) 34 | 35 | width = len(digits)*6 36 | 37 | if(digits=='none'): 38 | f.write('%d %d %d\n' % (cat[0],cat[1],cat[2]) ) 39 | f.write('%d %d %d %d\n' % ( col*cell_w+cell_w/2, row*cell_h+cell_h/2, cell_w, cell_h )) 40 | f.write('0\n') # confidence of object 41 | print("None", (col,row), (col*cell_w+cell_w/2, row*cell_h+cell_h/2, cell_w, cell_h), 0) 42 | else: 43 | x = random.randrange(col*cell_w, (col+1)*cell_w) 44 | y = random.randrange(row*cell_w, min(67, (row+1)*cell_h)) 45 | 46 | d.text((x-width/2, y-10/2), digits, fill=(255,255,255)) 47 | f.write('%d %d %d\n' % (cat[0],cat[1],cat[2])) 48 | f.write('%d %d %d %d\n' % (x, y, width, 10) ) 49 | f.write('1\n') # confidence of object 50 | print("Objt", (col,row), (x, y, width, 10), 1) 51 | 52 | f.write('---\n') 53 | 54 | img.save('Images/%d.PNG' % j) 55 | 56 | 57 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import keras 3 | import tensorflow as tf 4 | from keras.datasets import mnist 5 | from keras.models import Sequential, Model 6 | from keras.layers import Dense, Dropout, Flatten, Input, Reshape 7 | from keras.layers import Conv2D, MaxPooling2D 8 | 9 | from keras.layers.advanced_activations import LeakyReLU, PReLU 10 | 11 | from keras import backend as K 12 | from keras.models import load_model 13 | import numpy as np 14 | import sys 15 | import cv2 16 | import argparse 17 | 18 | from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img 19 | 20 | import matplotlib.pyplot as plt 21 | import matplotlib.patches as patches 22 | 23 | 24 | x_train = [] 25 | y_train = [] 26 | 27 | nb_boxes=1 28 | grid_w=2 29 | grid_h=2 30 | cell_w=32 31 | cell_h=32 32 | img_w=grid_w*cell_w 33 | img_h=grid_h*cell_h 34 | 35 | # 36 | # Read input image and output prediction 37 | # 38 | def load_image(j): 39 | img = cv2.imread('Images/%d.PNG' % j) 40 | # img = cv2.resize(img,(64,64)) 41 | 42 | x_t = img_to_array(img) 43 | 44 | with open("Labels/%d.txt" % j, "r") as f: 45 | y_t = [] 46 | for row in range(grid_w): 47 | for col in range(grid_h): 48 | c_t = [float(i) for i in f.readline().split()] 49 | [x, y, w, h] = [float(i) for i in f.readline().split()] 50 | conf_t = [float(i) for i in f.readline().split()] 51 | elt = [] 52 | elt += c_t 53 | for b in range(nb_boxes): 54 | elt += [x/cell_w, y/cell_h, w/img_w, h/img_h] + conf_t 55 | y_t.append(elt) 56 | assert(f.readline()=="---\n") 57 | 58 | return [x_t, y_t] 59 | 60 | # 61 | # Load all images and append to vector 62 | # 63 | #for j in range(0, 10): 64 | for j in range(10, 5000): 65 | [x,y] = load_image(j) 66 | x_train.append(x) 67 | y_train.append(y) 68 | 69 | x_train = np.array(x_train) 70 | y_train = np.array(y_train) 71 | 72 | # 73 | # Define the deep learning network 74 | # 75 | 76 | # model 2 77 | i = Input(shape=(img_h,img_w,3)) 78 | 79 | x = Conv2D(16, (1, 1))(i) 80 | x = Conv2D(32, (3, 3))(x) 81 | x = keras.layers.LeakyReLU(alpha=0.3)(x) 82 | x = MaxPooling2D(pool_size=(2, 2))(x) 83 | x = Conv2D(16, (3, 3))(x) 84 | x = Conv2D(32, (3, 3))(x) 85 | x = keras.layers.LeakyReLU(alpha=0.3)(x) 86 | x = MaxPooling2D(pool_size=(2, 2))(x) 87 | #x = Dropout(0.25)(x) 88 | 89 | x = Flatten()(x) 90 | x = Dense(256, activation='sigmoid')(x) 91 | x = Dense(grid_w*grid_h*(3+nb_boxes*5), activation='sigmoid')(x) 92 | x = Reshape((grid_w*grid_h,(3+nb_boxes*5)))(x) 93 | 94 | model = Model(i, x) 95 | 96 | # 97 | # The loss function orient the backpropagation algorithm toward the best direction. 98 | #It does so by outputting a number. The larger the number, the further we are from a correct solution. 99 | #Keras also accept that we output a tensor. In that case it will just sum all the numbers to get a single number. 100 | # 101 | # y_true is training data 102 | # y_pred is value predicted by the network 103 | def custom_loss(y_true, y_pred): 104 | # define a grid of offsets 105 | # [[[ 0. 0.]] 106 | # [[ 1. 0.]] 107 | # [[ 0. 1.]] 108 | # [[ 1. 1.]]] 109 | grid = np.array([ [[float(x),float(y)]]*nb_boxes for y in range(grid_h) for x in range(grid_w)]) 110 | 111 | # first three values are classes : cat, rat, and none. 112 | # However yolo doesn't predict none as a class, none is everything else and is just not predicted 113 | # so I don't use it in the loss 114 | y_true_class = y_true[...,0:2] 115 | y_pred_class = y_pred[...,0:2] 116 | 117 | # reshape array as a list of grid / grid cells / boxes / of 5 elements 118 | pred_boxes = K.reshape(y_pred[...,3:], (-1,grid_w*grid_h,nb_boxes,5)) 119 | true_boxes = K.reshape(y_true[...,3:], (-1,grid_w*grid_h,nb_boxes,5)) 120 | 121 | # sum coordinates of center of boxes with cell offsets. 122 | # as pred boxes are limited to 0 to 1 range, pred x,y + offset is limited to predicting elements inside a cell 123 | y_pred_xy = pred_boxes[...,0:2] + K.variable(grid) 124 | # w and h predicted are 0 to 1 with 1 being image size 125 | y_pred_wh = pred_boxes[...,2:4] 126 | # probability that there is something to predict here 127 | y_pred_conf = pred_boxes[...,4] 128 | 129 | # same as predicate except that we don't need to add an offset, coordinate are already between 0 and cell count 130 | y_true_xy = true_boxes[...,0:2] 131 | # with and height 132 | y_true_wh = true_boxes[...,2:4] 133 | # probability that there is something in that cell. 0 or 1 here as it's a certitude. 134 | y_true_conf = true_boxes[...,4] 135 | 136 | clss_loss = K.sum(K.square(y_true_class - y_pred_class), axis=-1) 137 | xy_loss = K.sum(K.sum(K.square(y_true_xy - y_pred_xy),axis=-1)*y_true_conf, axis=-1) 138 | wh_loss = K.sum(K.sum(K.square(K.sqrt(y_true_wh) - K.sqrt(y_pred_wh)), axis=-1)*y_true_conf, axis=-1) 139 | 140 | # when we add the confidence the box prediction lower in quality but we gain the estimation of the quality of the box 141 | # however the training is a bit unstable 142 | 143 | # compute the intersection of all boxes at once (the IOU) 144 | intersect_wh = K.maximum(K.zeros_like(y_pred_wh), (y_pred_wh + y_true_wh)/2 - K.square(y_pred_xy - y_true_xy) ) 145 | intersect_area = intersect_wh[...,0] * intersect_wh[...,1] 146 | true_area = y_true_wh[...,0] * y_true_wh[...,1] 147 | pred_area = y_pred_wh[...,0] * y_pred_wh[...,1] 148 | union_area = pred_area + true_area - intersect_area 149 | iou = intersect_area / union_area 150 | 151 | conf_loss = K.sum(K.square(y_true_conf*iou - y_pred_conf), axis=-1) 152 | 153 | # final loss function 154 | d = xy_loss + wh_loss + conf_loss + clss_loss 155 | 156 | if False: 157 | d = tf.Print(d, [d], "loss") 158 | d = tf.Print(d, [xy_loss], "xy_loss") 159 | d = tf.Print(d, [wh_loss], "wh_loss") 160 | d = tf.Print(d, [clss_loss], "clss_loss") 161 | d = tf.Print(d, [conf_loss], "conf_loss") 162 | 163 | return d 164 | 165 | model = Model(i, x) 166 | 167 | adam = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, decay=0.01) 168 | model.compile(loss=custom_loss, optimizer=adam) # better 169 | 170 | print(model.summary()) 171 | 172 | # 173 | # Training the network 174 | # 175 | parser = argparse.ArgumentParser(description='Process some integers.') 176 | parser.add_argument('--train', help='train', action='store_true') 177 | parser.add_argument('--epoch', help='epoch', const='int', nargs='?', default=1) 178 | args = parser.parse_args() 179 | 180 | if args.train: 181 | model.fit(x_train, y_train, batch_size=64, epochs=int(args.epoch)) 182 | model.save_weights('weights_006.h5') 183 | else: 184 | model.load_weights('weights_006.h5') 185 | 186 | axes=[0 for _ in range(100)] 187 | fig, axes = plt.subplots(5,5) 188 | 189 | # 190 | # Predict bounding box and classes for the first 25 images 191 | # 192 | for j in range(0,25): 193 | im = load_image(j) 194 | 195 | # 196 | # Predict bounding box and classes 197 | # 198 | img = cv2.imread('Images/%d.PNG' % j) 199 | #img = cv2.resize(img, (img_w,img_h)) 200 | #data = img_to_array(img) 201 | P = model.predict(np.array([ img_to_array(img) ])) 202 | 203 | # 204 | # Draw each boxes and class score over each images using pyplot 205 | # 206 | col = 0 207 | for row in range(grid_w): 208 | for col in range(grid_h): 209 | p = P[0][col*grid_h+row] 210 | 211 | boxes = p[3:].reshape(nb_boxes,5) 212 | clss = np.argmax(p[0:2]) 213 | 214 | ax = plt.subplot(5,5,j+1) 215 | imgplot = plt.imshow(img) 216 | 217 | i = 0 218 | for b in boxes: 219 | x = b[0]+float(row) 220 | y = b[1]+float(col) 221 | w = b[2] 222 | h = b[3] 223 | conf = b[4] 224 | if conf < 0.5: 225 | continue 226 | 227 | color = ['r','g','b','0'][clss] 228 | rect = patches.Rectangle((x*cell_w-w/2*img_w, y*cell_h-h/2*img_h), w*img_h, h*img_h, linewidth=1,edgecolor=color,facecolor='none') 229 | ax.add_patch(rect) 230 | 231 | ax.text( (x*cell_w-w/2*img_w) / img_w, 1-(y*cell_h-h/2*img_h)/img_h-i*0.15, "%0.2f" % (conf), transform=ax.transAxes, color=color) 232 | i+=1 233 | 234 | plt.show() 235 | 236 | 237 | 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.0 2 | backports.functools-lru-cache==1.4 3 | backports.shutil-get-terminal-size==1.0.0 4 | backports.weakref==1.0rc1 5 | bleach==1.5.0 6 | boto==2.48.0 7 | boto3==1.5.29 8 | botocore==1.8.43 9 | bz2file==0.98 10 | certifi==2017.11.5 11 | chardet==3.0.4 12 | click==6.7 13 | cycler==0.10.0 14 | cymem==1.31.2 15 | cytoolz==0.8.2 16 | decorator==4.3.2 17 | dill==0.2.7.1 18 | docutils==0.14 19 | enum34==1.1.6 20 | Flask==1.0.2 21 | ftfy==4.4.3 22 | funcsigs==1.0.2 23 | futures==3.2.0 24 | gensim==3.3.0 25 | h5py==2.7.1 26 | html5lib==0.9999999 27 | idna==2.6 28 | imageio==2.4.1 29 | ipython==5.8.0 30 | ipython-genutils==0.2.0 31 | itsdangerous==0.24 32 | Jinja2==2.10 33 | jmespath==0.9.3 34 | kaggle==1.3.6 35 | Keras==2.1.2 36 | keras-text==0.1 37 | Markdown==2.6.9 38 | MarkupSafe==1.0 39 | matplotlib==2.1.0 40 | mock==2.0.0 41 | mpmath==1.1.0 42 | msgpack-numpy==0.4.1 43 | msgpack-python==0.4.8 44 | murmurhash==0.28.0 45 | nltk==3.2.4 46 | numpy==1.13.1 47 | opencv-python==3.4.4.19 48 | pandas==0.21.1 49 | pathlib==1.0.1 50 | pathlib2==2.3.3 51 | pbr==3.1.1 52 | pexpect==4.6.0 53 | pickleshare==0.7.5 54 | Pillow==5.2.0 55 | pipenv==2018.5.18 56 | plac==0.9.6 57 | preshed==1.0.0 58 | prompt-toolkit==1.0.15 59 | protobuf==3.4.0 60 | ptyprocess==0.6.0 61 | Pygments==2.3.1 62 | pyparsing==2.2.0 63 | python-dateutil==2.6.1 64 | pytz==2017.3 65 | PyYAML==3.12 66 | regex==2017.4.5 67 | requests==2.18.4 68 | s3transfer==0.1.13 69 | scandir==1.10.0 70 | scikit-learn==0.19.1 71 | scipy==1.0.0 72 | simplegeneric==0.8.1 73 | six==1.10.0 74 | sklearn==0.0 75 | smart-open==1.5.6 76 | spacy==2.0.4 77 | subprocess32==3.2.7 78 | sympy==1.3 79 | tensorboard==1.8.0 80 | tensorflow==1.3.0 81 | tensorflow-tensorboard==0.1.6 82 | termcolor==1.1.0 83 | thinc==6.10.2 84 | toolz==0.8.2 85 | tqdm==4.19.4 86 | traitlets==4.3.2 87 | ujson==1.35 88 | urllib3==1.22 89 | virtualenv==16.0.0 90 | virtualenv-clone==0.3.0 91 | wcwidth==0.1.7 92 | Werkzeug==0.14.1 93 | wrapt==1.10.11 94 | --------------------------------------------------------------------------------