├── README.md
├── coco.names
├── detect.py
├── detect_image.py
├── detect_video.py
├── detections
│   ├── cars_yolo.jpg
│   ├── driving_yolo.mp4
│   ├── office_yolo.jpg
│   └── zebra_yolo.jpg
├── input
│   ├── cars.jpg
│   ├── office.jpg
│   └── zebra.jpg
├── load_weights.py
├── openh264-1.8.0-win64.dll
├── requirements.txt
├── utils.py
└── yolo_v3.py

/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow YOLO v3 Tutorial
2 | If you are hearing about "You Only Look Once" for the first time, you should know that it is an algorithm that uses convolutional neural networks for object detection.
3 | You Only Look Once, or YOLO, is one of the fastest object detection algorithms out there.
4 | Although it is not the most accurate object detection algorithm, it is a very good choice when we need real-time detection without losing too much accuracy.
5 |
6 | To learn more about YOLO v3, please read my tutorial on how it works before moving on to the code:

7 | [YOLO v3 theory explained](https://pylessons.com/YOLOv3-introduction/)

8 |
9 | You can also find a detailed code explanation on my website:

10 | [YOLO v3 code explained](https://pylessons.com/YOLOv3-code-explanation/)

11 |
12 |
13 | ## Getting started
14 |
15 | ### Prerequisites
16 | This tutorial was written in Python 3.7 using the TensorFlow (deep learning), NumPy (numerical computing), OpenCV (computer vision) and seaborn (visualization) packages. It's wonderful that you can run object detection with just four simple libraries! Note that the scripts use the TensorFlow 1.x graph API (`tf.placeholder`, `tf.Session`), so they need a 1.x release of TensorFlow. First of all, download all files from this tutorial. To install the required libraries, run:
17 | ```
18 | pip install -r requirements.txt
19 | ```
20 |
21 |
22 | ### Downloading official pretrained weights
23 | Next we need to download the official weights pretrained on the COCO dataset. You can do this in two ways. You can download the file manually from the link below, create a `weights` folder in the repository and copy the downloaded weights into that folder. Or you can simply do it with this command:
24 | ```
25 | wget -P weights https://pjreddie.com/media/files/yolov3.weights
26 | ```
27 |
28 | ### Convert weights into TensorFlow format
29 | Now you need to run the `load_weights.py` script to convert the downloaded weights to TensorFlow format.
30 | ```
31 | python load_weights.py
32 | ```
33 |
34 | ## Running the model
35 | Now you are ready to run the model using the `detect_image.py` or `detect_video.py` script.
36 | You can play around with the `iou_threshold` and `confidence_threshold` parameters.
37 | My example images and video are in the `input` folder, so you can put your own examples there as well or use a different location.
38 |
39 | ### Image usage example
40 | If you open the `detect_image.py` script, you'll see on the last line:
41 | ```
42 | main(0.5, 0.5, "input/office.jpg")
43 | ```
44 | Here you can change the `iou_threshold` and `confidence_threshold` parameters (the first and second arguments) and try your own image for detection.

45 | Here are a few examples:
46 | ```
47 | main(0.5, 0.5, "input/office.jpg")
48 | ```
49 | ![alt text](https://github.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/blob/master/detections/office_yolo.jpg)
50 | ```
51 | main(0.5, 0.5, "input/cars.jpg")
52 | ```
53 | ![alt text](https://github.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/blob/master/detections/cars_yolo.jpg)
54 | ```
55 | main(0.5, 0.5, "input/zebra.jpg")
56 | ```
57 | ![alt text](https://github.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/blob/master/detections/zebra_yolo.jpg)
58 |
59 | ### Video usage example
60 | If you open the `detect_video.py` script, you'll see on the last line:
61 | ```
62 | main(0.5, 0.5, "input/driving.mp4")
63 | ```
64 | The detections will be saved as the `driving_yolo.mp4` file in the `detections` folder. Example video:
65 | [![Example detection video](https://img.youtube.com/vi/wEmhflE7vmg/0.jpg)](https://youtu.be/wEmhflE7vmg)
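To get a feeling for what the two threshold arguments of `main` do, here is a minimal pure-Python sketch of confidence filtering followed by greedy non-max suppression for the boxes of a single class. It is only an illustration: the repository itself does this step inside the TensorFlow graph with `tf.image.non_max_suppression`, and the helper names `iou` and `filter_detections`, as well as the sample boxes, are made up for this example.

```python
def iou(box_a, box_b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    # Width and height of the overlap are clamped at zero for disjoint boxes.
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_detections(detections, iou_threshold, confidence_threshold):
    """Filter (x1, y1, x2, y2, confidence) boxes that belong to ONE class.

    Hypothetical helper, not part of this repository.
    """
    # Step 1: drop boxes the model is not confident about.
    candidates = [d for d in detections if d[4] > confidence_threshold]
    # Step 2: greedy non-max suppression, most confident box first.
    candidates.sort(key=lambda d: d[4], reverse=True)
    kept = []
    for det in candidates:
        # Keep a box only if it does not overlap too much with a kept one.
        if all(iou(det[:4], k[:4]) <= iou_threshold for k in kept):
            kept.append(det)
    return kept


# Made-up boxes: two overlapping, one separate, one with low confidence.
boxes = [
    (10, 10, 110, 110, 0.9),
    (20, 20, 120, 120, 0.8),
    (200, 200, 260, 260, 0.7),
    (0, 0, 50, 50, 0.3),
]

print(len(filter_detections(boxes, 0.5, 0.5)))  # overlapping pair is merged
print(len(filter_detections(boxes, 0.7, 0.5)))  # overlapping pair survives
```

The first two boxes overlap with an IoU of about 0.68, so a lower `iou_threshold` removes the weaker one, while a higher value keeps both. Raising `confidence_threshold` gives fewer but more reliable boxes.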
66 | 67 | ## Future To-Do List 68 | * Write YOLOv3 in Keras 69 | * Train custom YOLOv3 detection model 70 | * Test YOLOv3 FPS performance on CS:GO 71 | -------------------------------------------------------------------------------- /coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | """Yolo v3 detection script. 2 | 3 | Saves the detections in the `detection` folder. 4 | 5 | Usage: 6 | python detect.py 7 | 8 | Example: 9 | python detect.py images 0.5 0.5 data/images/dog.jpg data/images/office.jpg 10 | python detect.py video 0.5 0.5 data/video/shinjuku.mp4 11 | 12 | Note that only one video can be processed at one run. 
13 | """ 14 | 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | 18 | import tensorflow as tf 19 | import sys 20 | import cv2 21 | 22 | from yolo_v3 import Yolo_v3 23 | from utils import load_images, load_class_names, draw_boxes, draw_frame 24 | 25 | _MODEL_SIZE = (416, 416) 26 | _CLASS_NAMES_FILE = './data/labels/coco.names' 27 | _MAX_OUTPUT_SIZE = 20 28 | 29 | detection_result = {} 30 | 31 | 32 | def main(type, iou_threshold, confidence_threshold, input_names): 33 | global detection_result 34 | class_names = load_class_names(_CLASS_NAMES_FILE) 35 | n_classes = len(class_names) 36 | 37 | model = Yolo_v3(n_classes=n_classes, model_size=_MODEL_SIZE, 38 | max_output_size=_MAX_OUTPUT_SIZE, 39 | iou_threshold=iou_threshold, 40 | confidence_threshold=confidence_threshold) 41 | 42 | if type == 'images': 43 | #batch_size = len(input_names) 44 | batch = load_images(input_names, model_size=_MODEL_SIZE) 45 | inputs = tf.placeholder(tf.float32, [1, *_MODEL_SIZE, 3]) 46 | detections = model(inputs, training=False) 47 | saver = tf.train.Saver(tf.global_variables(scope='yolo_v3_model')) 48 | 49 | with tf.Session() as sess: 50 | saver.restore(sess, './weights/model.ckpt') 51 | detection_result = sess.run(detections, feed_dict={inputs: batch}) 52 | 53 | #detection_result = detection_result[0] 54 | print("detection_result", detection_result) 55 | draw_boxes(input_names, detection_result, class_names, _MODEL_SIZE) 56 | 57 | print('Detections have been saved successfully.') 58 | 59 | elif type == 'video': 60 | inputs = tf.placeholder(tf.float32, [1, *_MODEL_SIZE, 3]) 61 | detections = model(inputs, training=False) 62 | saver = tf.train.Saver(tf.global_variables(scope='yolo_v3_model')) 63 | 64 | with tf.Session() as sess: 65 | saver.restore(sess, './weights/model.ckpt') 66 | 67 | win_name = 'Video detection' 68 | cv2.namedWindow(win_name) 69 | cap = cv2.VideoCapture(input_names[0]) 70 | frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), 71 | 
cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 72 | fourcc = cv2.VideoWriter_fourcc(*'X264') 73 | fps = cap.get(cv2.CAP_PROP_FPS) 74 | out = cv2.VideoWriter('./detections/detections.mp4', fourcc, fps, 75 | (int(frame_size[0]), int(frame_size[1]))) 76 | 77 | try: 78 | while True: 79 | ret, frame = cap.read() 80 | if not ret: 81 | break 82 | resized_frame = cv2.resize(frame, dsize=_MODEL_SIZE[::-1], 83 | interpolation=cv2.INTER_NEAREST) 84 | detection_result = sess.run(detections, 85 | feed_dict={inputs: [resized_frame]}) 86 | 87 | draw_frame(frame, frame_size, detection_result, 88 | class_names, _MODEL_SIZE) 89 | 90 | cv2.imshow(win_name, frame) 91 | 92 | key = cv2.waitKey(1) & 0xFF 93 | 94 | if key == ord('q'): 95 | break 96 | 97 | out.write(frame) 98 | finally: 99 | cv2.destroyAllWindows() 100 | cap.release() 101 | print('Detections have been saved successfully.') 102 | 103 | else: 104 | raise ValueError("Inappropriate data type. Please choose either 'video' or 'images'.") 105 | 106 | 107 | if __name__ == '__main__': 108 | #main(sys.argv[1], float(sys.argv[2]), float(sys.argv[3]), sys.argv[4:]) 109 | main("images", 0.5, 0.5, "road.jpg") 110 | -------------------------------------------------------------------------------- /detect_image.py: -------------------------------------------------------------------------------- 1 | # Yolo v3 image detection 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | import tensorflow as tf 5 | import sys 6 | import cv2 7 | 8 | from yolo_v3 import Yolo_v3 9 | from utils import load_images, load_class_names, draw_boxes 10 | 11 | _MODEL_SIZE = (416, 416) 12 | _CLASS_NAMES_FILE = 'coco.names' 13 | _MAX_OUTPUT_SIZE = 50 14 | 15 | detection_result = {} 16 | 17 | 18 | def main(iou_threshold, confidence_threshold, input_names): 19 | global detection_result 20 | class_names = load_class_names(_CLASS_NAMES_FILE) 21 | n_classes = len(class_names) 22 | 23 | model = Yolo_v3(n_classes=n_classes, model_size=_MODEL_SIZE, 24 | 
max_output_size=_MAX_OUTPUT_SIZE, 25 | iou_threshold=iou_threshold, 26 | confidence_threshold=confidence_threshold) 27 | 28 | 29 | batch = load_images(input_names, model_size=_MODEL_SIZE) 30 | inputs = tf.placeholder(tf.float32, [1, *_MODEL_SIZE, 3]) 31 | detections = model(inputs, training=False) 32 | saver = tf.train.Saver(tf.global_variables(scope='yolo_v3_model')) 33 | 34 | with tf.Session() as sess: 35 | saver.restore(sess, './weights/model.ckpt') 36 | detection_result = sess.run(detections, feed_dict={inputs: batch}) 37 | 38 | draw_boxes(input_names, detection_result, class_names, _MODEL_SIZE) 39 | 40 | print('Detections have been saved successfully.') 41 | 42 | 43 | if __name__ == '__main__': 44 | main(0.5, 0.5, "input/office.jpg") 45 | -------------------------------------------------------------------------------- /detect_video.py: -------------------------------------------------------------------------------- 1 | # Yolo v3 video detection 2 | 3 | import os 4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 5 | import tensorflow as tf 6 | import sys 7 | import cv2 8 | 9 | from yolo_v3 import Yolo_v3 10 | from utils import load_images, load_class_names, draw_boxes, draw_frame 11 | 12 | _MODEL_SIZE = (416, 416) 13 | _CLASS_NAMES_FILE = 'coco.names' 14 | _MAX_OUTPUT_SIZE = 50 15 | 16 | detection_result = {} 17 | 18 | 19 | def main(iou_threshold, confidence_threshold, input_names): 20 | global detection_result 21 | class_names = load_class_names(_CLASS_NAMES_FILE) 22 | n_classes = len(class_names) 23 | 24 | model = Yolo_v3(n_classes=n_classes, model_size=_MODEL_SIZE, 25 | max_output_size=_MAX_OUTPUT_SIZE, 26 | iou_threshold=iou_threshold, 27 | confidence_threshold=confidence_threshold) 28 | 29 | inputs = tf.placeholder(tf.float32, [1, *_MODEL_SIZE, 3]) 30 | detections = model(inputs, training=False) 31 | saver = tf.train.Saver(tf.global_variables(scope='yolo_v3_model')) 32 | 33 | with tf.Session() as sess: 34 | saver.restore(sess, './weights/model.ckpt') 35 | 36 | 
win_name = 'Video detection' 37 | cv2.namedWindow(win_name) 38 | cap = cv2.VideoCapture(input_names) 39 | frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 40 | fourcc = int(cap.get(cv2.CAP_PROP_FOURCC)) 41 | fps = cap.get(cv2.CAP_PROP_FPS) 42 | if not os.path.exists('detections'): 43 | os.mkdir('detections') 44 | head, tail = os.path.split(input_names) 45 | name = './detections/'+tail[:-4]+'_yolo.mp4' 46 | out = cv2.VideoWriter(name, fourcc, fps, (int(frame_size[0]), int(frame_size[1]))) 47 | 48 | try: 49 | print("Show video") 50 | while(cap.isOpened()): 51 | ret, frame = cap.read() 52 | if not ret: 53 | break 54 | resized_frame = cv2.resize(frame, dsize=_MODEL_SIZE[::-1], interpolation=cv2.INTER_NEAREST) 55 | detection_result = sess.run(detections, feed_dict={inputs: [resized_frame]}) 56 | draw_frame(frame, frame_size, detection_result, class_names, _MODEL_SIZE) 57 | if ret == True: 58 | cv2.imshow(win_name, frame) 59 | out.write(frame) 60 | 61 | if cv2.waitKey(1) & 0xFF == ord('q'): 62 | break 63 | 64 | finally: 65 | cv2.destroyAllWindows() 66 | cap.release() 67 | print('Detections have been saved successfully.') 68 | 69 | 70 | if __name__ == '__main__': 71 | main(0.5, 0.5, "input/driving.mp4") 72 | -------------------------------------------------------------------------------- /detections/cars_yolo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/detections/cars_yolo.jpg -------------------------------------------------------------------------------- /detections/driving_yolo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/detections/driving_yolo.mp4 
-------------------------------------------------------------------------------- /detections/office_yolo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/detections/office_yolo.jpg -------------------------------------------------------------------------------- /detections/zebra_yolo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/detections/zebra_yolo.jpg -------------------------------------------------------------------------------- /input/cars.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/input/cars.jpg -------------------------------------------------------------------------------- /input/office.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/input/office.jpg -------------------------------------------------------------------------------- /input/zebra.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/input/zebra.jpg -------------------------------------------------------------------------------- /load_weights.py: -------------------------------------------------------------------------------- 1 | """Loads Yolo v3 pretrained weights and saves them in tensorflow format.""" 2 | 3 | import os 4 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from 
yolo_v3 import Yolo_v3 10 | 11 | 12 | # Reshapes and loads official pretrained Yolo weights. 13 | def load_weights(variables, file_name): 14 | with open(file_name, "rb") as f: 15 | # Skip first 5 values containing irrelevant info 16 | np.fromfile(f, dtype=np.int32, count=5) 17 | weights = np.fromfile(f, dtype=np.float32) 18 | 19 | assign_ops = [] 20 | ptr = 0 21 | 22 | # Load weights for Darknet part. 23 | # Each convolution layer has batch normalization. 24 | for i in range(52): 25 | conv_var = variables[5 * i] 26 | gamma, beta, mean, variance = variables[5 * i + 1:5 * i + 5] 27 | batch_norm_vars = [beta, gamma, mean, variance] 28 | 29 | for var in batch_norm_vars: 30 | shape = var.shape.as_list() 31 | num_params = np.prod(shape) 32 | var_weights = weights[ptr:ptr + num_params].reshape(shape) 33 | ptr += num_params 34 | assign_ops.append(tf.assign(var, var_weights)) 35 | 36 | shape = conv_var.shape.as_list() 37 | num_params = np.prod(shape) 38 | var_weights = weights[ptr:ptr + num_params].reshape( 39 | (shape[3], shape[2], shape[0], shape[1])) 40 | var_weights = np.transpose(var_weights, (2, 3, 1, 0)) 41 | ptr += num_params 42 | assign_ops.append(tf.assign(conv_var, var_weights)) 43 | 44 | # Loading weights for Yolo part. 45 | # 7th, 15th and 23rd convolution layer has biases and no batch norm. 
46 | ranges = [range(0, 6), range(6, 13), range(13, 20)] 47 | unnormalized = [6, 13, 20] 48 | for j in range(3): 49 | for i in ranges[j]: 50 | current = 52 * 5 + 5 * i + j * 2 51 | conv_var = variables[current] 52 | gamma, beta, mean, variance = \ 53 | variables[current + 1:current + 5] 54 | batch_norm_vars = [beta, gamma, mean, variance] 55 | 56 | for var in batch_norm_vars: 57 | shape = var.shape.as_list() 58 | num_params = np.prod(shape) 59 | var_weights = weights[ptr:ptr + num_params].reshape(shape) 60 | ptr += num_params 61 | assign_ops.append(tf.assign(var, var_weights)) 62 | 63 | shape = conv_var.shape.as_list() 64 | num_params = np.prod(shape) 65 | var_weights = weights[ptr:ptr + num_params].reshape( 66 | (shape[3], shape[2], shape[0], shape[1])) 67 | var_weights = np.transpose(var_weights, (2, 3, 1, 0)) 68 | ptr += num_params 69 | assign_ops.append(tf.assign(conv_var, var_weights)) 70 | 71 | bias = variables[52 * 5 + unnormalized[j] * 5 + j * 2 + 1] 72 | shape = bias.shape.as_list() 73 | num_params = np.prod(shape) 74 | var_weights = weights[ptr:ptr + num_params].reshape(shape) 75 | ptr += num_params 76 | assign_ops.append(tf.assign(bias, var_weights)) 77 | 78 | conv_var = variables[52 * 5 + unnormalized[j] * 5 + j * 2] 79 | shape = conv_var.shape.as_list() 80 | num_params = np.prod(shape) 81 | var_weights = weights[ptr:ptr + num_params].reshape( 82 | (shape[3], shape[2], shape[0], shape[1])) 83 | var_weights = np.transpose(var_weights, (2, 3, 1, 0)) 84 | ptr += num_params 85 | assign_ops.append(tf.assign(conv_var, var_weights)) 86 | 87 | return assign_ops 88 | 89 | 90 | def main(): 91 | model = Yolo_v3(n_classes=80, model_size=(416, 416), 92 | max_output_size=5, 93 | iou_threshold=0.5, 94 | confidence_threshold=0.5) 95 | 96 | inputs = tf.placeholder(tf.float32, [1, 416, 416, 3]) 97 | 98 | model(inputs, training=False) 99 | 100 | model_vars = tf.global_variables(scope='yolo_v3_model') 101 | assign_ops = load_weights(model_vars, './weights/yolov3.weights') 
102 | 103 | saver = tf.train.Saver(tf.global_variables(scope='yolo_v3_model')) 104 | 105 | with tf.Session() as sess: 106 | sess.run(assign_ops) 107 | saver.save(sess, './weights/model.ckpt') 108 | print('Model has been saved successfully.') 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /openh264-1.8.0-win64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/TensorFlow-YOLO-v3-Tutorial/a4e81d6cc5a790476bef497a024c6070f37d51c2/openh264-1.8.0-win64.dll -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow 2 | numpy 3 | Pillow 4 | opencv-python 5 | seaborn -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # this file contains utility functions for Yolo v3 model 2 | 3 | import os 4 | import numpy as np 5 | from seaborn import color_palette 6 | import cv2 7 | 8 | # Loads images in a 4D array 9 | def load_images(img_name, model_size): 10 | imgs = [] 11 | 12 | img = cv2.imread(img_name) 13 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 14 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 15 | img = cv2.resize(img, model_size) 16 | img = np.array(img, dtype=np.float32) 17 | img = np.expand_dims(img[:, :, :3], axis=0) 18 | imgs.append(img) 19 | 20 | imgs = np.concatenate(imgs) 21 | 22 | return imgs 23 | 24 | # Returns a list of class names read from `file_name` 25 | def load_class_names(file_name): 26 | with open(file_name, 'r') as f: 27 | class_names = f.read().splitlines() 28 | return class_names 29 | 30 | # Draws detected boxes 31 | def draw_boxes(img_names, boxes_dicts, class_names, model_size): 32 | colors = 
((np.array(color_palette("hls", 80)) * 255)).astype(np.uint8) 33 | for num, img_name, boxes_dict in zip(range(len(img_names)), img_names, boxes_dicts): 34 | 35 | img = cv2.imread(img_names) 36 | img = np.array(img, dtype=np.float32) 37 | resize_factor = (img.shape[1] / model_size[0], img.shape[0] / model_size[1]) 38 | for cls in range(len(class_names)): 39 | boxes = boxes_dict[cls] 40 | color = colors[cls] 41 | color = tuple([int(x) for x in color]) 42 | if np.size(boxes) != 0: 43 | for box in boxes: 44 | xy, confidence = box[:4], box[4] 45 | confidence = ' '+str(confidence*100)[:2] 46 | xy = [int(xy[i] * resize_factor[i % 2]) for i in range(4)] 47 | cv2.rectangle(img, (xy[0], xy[1]), (xy[2], xy[3]), color[::-1], 2) 48 | (test_width, text_height), baseline = cv2.getTextSize(class_names[cls]+confidence, 49 | cv2.FONT_HERSHEY_SIMPLEX, 50 | 0.75, 1) 51 | cv2.rectangle(img, (xy[0], xy[1]), 52 | (xy[0] + test_width, xy[1] - text_height - baseline), 53 | color[::-1], thickness=cv2.FILLED) 54 | cv2.putText(img, class_names[cls]+confidence, (xy[0], xy[1] - baseline), 55 | cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 0), 1) 56 | 57 | if not os.path.exists('detections'): 58 | os.mkdir('detections') 59 | 60 | head, tail = os.path.split(img_names) 61 | name = './detections/'+tail[:-4]+'_yolo.jpg' 62 | cv2.imwrite(name, img) 63 | 64 | # Draws detected boxes in a video frame 65 | def draw_frame(frame, frame_size, boxes_dicts, class_names, model_size): 66 | boxes_dict = boxes_dicts[0] 67 | resize_factor = (frame_size[0] / model_size[1], frame_size[1] / model_size[0]) 68 | colors = ((np.array(color_palette("hls", 80)) * 255)).astype(np.uint8) 69 | for cls in range(len(class_names)): 70 | boxes = boxes_dict[cls] 71 | color = colors[cls] 72 | color = tuple([int(x) for x in color]) 73 | if np.size(boxes) != 0: 74 | for box in boxes: 75 | xy, confidence = box[:4], box[4] 76 | confidence = '' 77 | #confidence = ' '+str(confidence*100)[:2] 78 | xy = [int(xy[i] * resize_factor[i % 2]) for i in 
range(4)] 79 | cv2.rectangle(frame, (xy[0], xy[1]), (xy[2], xy[3]), color[::-1], 2) 80 | (test_width, text_height), baseline = cv2.getTextSize(class_names[cls]+confidence, 81 | cv2.FONT_HERSHEY_SIMPLEX, 82 | 0.75, 1) 83 | cv2.rectangle(frame, (xy[0], xy[1]), 84 | (xy[0] + test_width, xy[1] - text_height - baseline), 85 | color[::-1], thickness=cv2.FILLED) 86 | cv2.putText(frame, class_names[cls]+confidence, (xy[0], xy[1] - baseline), 87 | cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 0), 1) 88 | -------------------------------------------------------------------------------- /yolo_v3.py: -------------------------------------------------------------------------------- 1 | # Contains Yolo core definitions 2 | 3 | import tensorflow as tf 4 | 5 | _BATCH_NORM_DECAY = 0.9 6 | _BATCH_NORM_EPSILON = 1e-05 7 | _LEAKY_RELU = 0.1 8 | _ANCHORS = [(10, 13), (16, 30), (33, 23), 9 | (30, 61), (62, 45), (59, 119), 10 | (116, 90), (156, 198), (373, 326)] 11 | 12 | 13 | # Performs a batch normalization using a standard set of parameters 14 | def batch_norm(inputs, training, data_format): 15 | return tf.layers.batch_normalization( 16 | inputs=inputs, axis=1 if data_format == 'channels_first' else 3, 17 | momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, 18 | scale=True, training=training) 19 | 20 | 21 | # ResNet implementation of fixed padding 22 | def fixed_padding(inputs, kernel_size, data_format): 23 | pad_total = kernel_size - 1 24 | pad_beg = pad_total // 2 25 | pad_end = pad_total - pad_beg 26 | 27 | if data_format == 'channels_first': 28 | padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], 29 | [pad_beg, pad_end], 30 | [pad_beg, pad_end]]) 31 | else: 32 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], 33 | [pad_beg, pad_end], [0, 0]]) 34 | return padded_inputs 35 | 36 | # Strided 2-D convolution with explicit padding 37 | def conv2d_fixed_padding(inputs, filters, kernel_size, data_format, strides=1): 38 | if strides > 1: 39 | inputs = fixed_padding(inputs, 
kernel_size, data_format) 40 | 41 | return tf.layers.conv2d( 42 | inputs=inputs, filters=filters, kernel_size=kernel_size, 43 | strides=strides, padding=('SAME' if strides == 1 else 'VALID'), 44 | use_bias=False, data_format=data_format) 45 | 46 | 47 | # Creates a residual block for Darknet 48 | def darknet53_residual_block(inputs, filters, training, data_format, strides=1): 49 | shortcut = inputs 50 | 51 | inputs = conv2d_fixed_padding(inputs, filters=filters, kernel_size=1, strides=strides, data_format=data_format) 52 | inputs = batch_norm(inputs, training=training, data_format=data_format) 53 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 54 | 55 | inputs = conv2d_fixed_padding( inputs, filters=2 * filters, kernel_size=3, strides=strides, data_format=data_format) 56 | inputs = batch_norm(inputs, training=training, data_format=data_format) 57 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 58 | 59 | inputs += shortcut 60 | 61 | return inputs 62 | 63 | 64 | # Creates Darknet53 model for feature extraction 65 | def darknet53(inputs, training, data_format): 66 | inputs = conv2d_fixed_padding(inputs, filters=32, kernel_size=3, data_format=data_format) 67 | inputs = batch_norm(inputs, training=training, data_format=data_format) 68 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 69 | inputs = conv2d_fixed_padding(inputs, filters=64, kernel_size=3, strides=2, data_format=data_format) 70 | inputs = batch_norm(inputs, training=training, data_format=data_format) 71 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 72 | 73 | inputs = darknet53_residual_block(inputs, filters=32, training=training, data_format=data_format) 74 | 75 | inputs = conv2d_fixed_padding(inputs, filters=128, kernel_size=3, strides=2, data_format=data_format) 76 | inputs = batch_norm(inputs, training=training, data_format=data_format) 77 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 78 | 79 | for _ in range(2): 80 | inputs = darknet53_residual_block(inputs, 
filters=64, training=training, data_format=data_format) 81 | 82 | inputs = conv2d_fixed_padding(inputs, filters=256, kernel_size=3, strides=2, data_format=data_format) 83 | inputs = batch_norm(inputs, training=training, data_format=data_format) 84 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 85 | 86 | for _ in range(8): 87 | inputs = darknet53_residual_block(inputs, filters=128, training=training, data_format=data_format) 88 | 89 | route1 = inputs 90 | 91 | inputs = conv2d_fixed_padding(inputs, filters=512, kernel_size=3, strides=2, data_format=data_format) 92 | inputs = batch_norm(inputs, training=training, data_format=data_format) 93 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 94 | 95 | for _ in range(8): 96 | inputs = darknet53_residual_block(inputs, filters=256, training=training, data_format=data_format) 97 | 98 | route2 = inputs 99 | 100 | inputs = conv2d_fixed_padding(inputs, filters=1024, kernel_size=3, strides=2, data_format=data_format) 101 | inputs = batch_norm(inputs, training=training, data_format=data_format) 102 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 103 | 104 | for _ in range(4): 105 | inputs = darknet53_residual_block(inputs, filters=512, training=training, data_format=data_format) 106 | 107 | return route1, route2, inputs 108 | 109 | 110 | # Creates convolution operations layer used after Darknet 111 | def yolo_convolution_block(inputs, filters, training, data_format): 112 | inputs = conv2d_fixed_padding(inputs, filters=filters, kernel_size=1, data_format=data_format) 113 | inputs = batch_norm(inputs, training=training, data_format=data_format) 114 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 115 | 116 | inputs = conv2d_fixed_padding(inputs, filters=2 * filters, kernel_size=3, data_format=data_format) 117 | inputs = batch_norm(inputs, training=training, data_format=data_format) 118 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 119 | 120 | inputs = conv2d_fixed_padding(inputs, filters=filters, 
kernel_size=1, data_format=data_format) 121 | inputs = batch_norm(inputs, training=training, data_format=data_format) 122 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 123 | 124 | inputs = conv2d_fixed_padding(inputs, filters=2 * filters, kernel_size=3, data_format=data_format) 125 | inputs = batch_norm(inputs, training=training, data_format=data_format) 126 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 127 | 128 | inputs = conv2d_fixed_padding(inputs, filters=filters, kernel_size=1, data_format=data_format) 129 | inputs = batch_norm(inputs, training=training, data_format=data_format) 130 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 131 | 132 | route = inputs 133 | 134 | inputs = conv2d_fixed_padding(inputs, filters=2 * filters, kernel_size=3, data_format=data_format) 135 | inputs = batch_norm(inputs, training=training, data_format=data_format) 136 | inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU) 137 | 138 | return route, inputs 139 | 140 | 141 | # Creates Yolo final detection layer 142 | def yolo_layer(inputs, n_classes, anchors, img_size, data_format): 143 | n_anchors = len(anchors) 144 | 145 | inputs = tf.layers.conv2d(inputs, filters=n_anchors * (5 + n_classes), 146 | kernel_size=1, strides=1, use_bias=True, 147 | data_format=data_format) 148 | 149 | shape = inputs.get_shape().as_list() 150 | grid_shape = shape[2:4] if data_format == 'channels_first' else shape[1:3] 151 | if data_format == 'channels_first': 152 | inputs = tf.transpose(inputs, [0, 2, 3, 1]) 153 | inputs = tf.reshape(inputs, [-1, n_anchors * grid_shape[0] * grid_shape[1], 5 + n_classes]) 154 | 155 | strides = (img_size[0] // grid_shape[0], img_size[1] // grid_shape[1]) 156 | 157 | box_centers, box_shapes, confidence, classes = \ 158 | tf.split(inputs, [2, 2, 1, n_classes], axis=-1) 159 | 160 | x = tf.range(grid_shape[0], dtype=tf.float32) 161 | y = tf.range(grid_shape[1], dtype=tf.float32) 162 | x_offset, y_offset = tf.meshgrid(x, y) 163 | x_offset = 
tf.reshape(x_offset, (-1, 1)) 164 | y_offset = tf.reshape(y_offset, (-1, 1)) 165 | x_y_offset = tf.concat([x_offset, y_offset], axis=-1) 166 | x_y_offset = tf.tile(x_y_offset, [1, n_anchors]) 167 | x_y_offset = tf.reshape(x_y_offset, [1, -1, 2]) 168 | box_centers = tf.nn.sigmoid(box_centers) 169 | box_centers = (box_centers + x_y_offset) * strides 170 | 171 | anchors = tf.tile(anchors, [grid_shape[0] * grid_shape[1], 1]) 172 | box_shapes = tf.exp(box_shapes) * tf.to_float(anchors) 173 | 174 | confidence = tf.nn.sigmoid(confidence) 175 | 176 | classes = tf.nn.sigmoid(classes) 177 | 178 | inputs = tf.concat([box_centers, box_shapes, confidence, classes], axis=-1) 179 | 180 | return inputs 181 | 182 | 183 | # Upsamples to `out_shape` using nearest neighbor interpolation 184 | def upsample(inputs, out_shape, data_format): 185 | if data_format == 'channels_first': 186 | inputs = tf.transpose(inputs, [0, 2, 3, 1]) 187 | new_height = out_shape[3] 188 | new_width = out_shape[2] 189 | else: 190 | new_height = out_shape[2] 191 | new_width = out_shape[1] 192 | 193 | inputs = tf.image.resize_nearest_neighbor(inputs, (new_height, new_width)) 194 | 195 | if data_format == 'channels_first': 196 | inputs = tf.transpose(inputs, [0, 3, 1, 2]) 197 | 198 | return inputs 199 | 200 | 201 | # Performs non-max suppression separately for each class 202 | def non_max_suppression(inputs, n_classes, max_output_size, iou_threshold, confidence_threshold): 203 | batch = tf.unstack(inputs) 204 | boxes_dicts = [] 205 | for boxes in batch: 206 | boxes = tf.boolean_mask(boxes, boxes[:, 4] > confidence_threshold) 207 | classes = tf.argmax(boxes[:, 5:], axis=-1) 208 | classes = tf.expand_dims(tf.to_float(classes), axis=-1) 209 | boxes = tf.concat([boxes[:, :5], classes], axis=-1) 210 | 211 | boxes_dict = dict() 212 | for cls in range(n_classes): 213 | mask = tf.equal(boxes[:, 5], cls) 214 | mask_shape = mask.get_shape() 215 | if mask_shape.ndims != 0: 216 | class_boxes = tf.boolean_mask(boxes, mask) 
                boxes_coords, boxes_conf_scores, _ = tf.split(class_boxes,
                                                              [4, 1, -1],
                                                              axis=-1)
                boxes_conf_scores = tf.reshape(boxes_conf_scores, [-1])
                indices = tf.image.non_max_suppression(boxes_coords,
                                                       boxes_conf_scores,
                                                       max_output_size,
                                                       iou_threshold)
                class_boxes = tf.gather(class_boxes, indices)
                boxes_dict[cls] = class_boxes[:, :5]

        boxes_dicts.append(boxes_dict)

    return boxes_dicts


# Computes top left and bottom right points of the boxes
def build_boxes(inputs):
    center_x, center_y, width, height, confidence, classes = \
        tf.split(inputs, [1, 1, 1, 1, 1, -1], axis=-1)

    top_left_x = center_x - width / 2
    top_left_y = center_y - height / 2
    bottom_right_x = center_x + width / 2
    bottom_right_y = center_y + height / 2

    boxes = tf.concat([top_left_x, top_left_y,
                       bottom_right_x, bottom_right_y,
                       confidence, classes], axis=-1)

    return boxes


# Yolo v3 model class
class Yolo_v3:
    def __init__(self, n_classes, model_size, max_output_size, iou_threshold,
                 confidence_threshold, data_format=None):
        if not data_format:
            if tf.test.is_built_with_cuda():
                data_format = 'channels_first'
            else:
                data_format = 'channels_last'

        self.n_classes = n_classes
        self.model_size = model_size
        self.max_output_size = max_output_size
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.data_format = data_format

    # Add operations to detect boxes for a batch of input images
    def __call__(self, inputs, training):
        with tf.variable_scope('yolo_v3_model'):
            if self.data_format == 'channels_first':
                inputs = tf.transpose(inputs, [0, 3, 1, 2])

            inputs = inputs / 255

            route1, route2, inputs = darknet53(inputs, training=training,
                                               data_format=self.data_format)

            route, inputs = yolo_convolution_block(
                inputs, filters=512, training=training,
                data_format=self.data_format)
            detect1 = yolo_layer(inputs, n_classes=self.n_classes,
                                 anchors=_ANCHORS[6:9],
                                 img_size=self.model_size,
                                 data_format=self.data_format)

            inputs = conv2d_fixed_padding(route, filters=256, kernel_size=1,
                                          data_format=self.data_format)
            inputs = batch_norm(inputs, training=training,
                                data_format=self.data_format)
            inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU)
            upsample_size = route2.get_shape().as_list()
            inputs = upsample(inputs, out_shape=upsample_size,
                              data_format=self.data_format)
            axis = 1 if self.data_format == 'channels_first' else 3
            inputs = tf.concat([inputs, route2], axis=axis)
            route, inputs = yolo_convolution_block(
                inputs, filters=256, training=training,
                data_format=self.data_format)
            detect2 = yolo_layer(inputs, n_classes=self.n_classes,
                                 anchors=_ANCHORS[3:6],
                                 img_size=self.model_size,
                                 data_format=self.data_format)

            inputs = conv2d_fixed_padding(route, filters=128, kernel_size=1,
                                          data_format=self.data_format)
            inputs = batch_norm(inputs, training=training,
                                data_format=self.data_format)
            inputs = tf.nn.leaky_relu(inputs, alpha=_LEAKY_RELU)
            upsample_size = route1.get_shape().as_list()
            inputs = upsample(inputs, out_shape=upsample_size,
                              data_format=self.data_format)
            inputs = tf.concat([inputs, route1], axis=axis)
            route, inputs = yolo_convolution_block(
                inputs, filters=128, training=training,
                data_format=self.data_format)
            detect3 = yolo_layer(inputs, n_classes=self.n_classes,
                                 anchors=_ANCHORS[0:3],
                                 img_size=self.model_size,
                                 data_format=self.data_format)

            inputs = tf.concat([detect1, detect2, detect3], axis=1)

            inputs = build_boxes(inputs)

            boxes_dicts = non_max_suppression(
                inputs, n_classes=self.n_classes,
                max_output_size=self.max_output_size,
                iou_threshold=self.iou_threshold,
                confidence_threshold=self.confidence_threshold)

            return boxes_dicts

--------------------------------------------------------------------------------
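The box decoding done in `yolo_layer` and the corner conversion done in `build_boxes` can be checked by hand on a tiny grid. The following is a minimal NumPy sketch, not part of the repository: the function names `decode_cell_predictions` and `centers_to_corners` are hypothetical, only the four box values per anchor are handled, and `img_size` is assumed to be `(height, width)`.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def decode_cell_predictions(raw, anchors, img_size):
    """Decode raw box outputs of shape (grid_h, grid_w, n_anchors, 4).

    Returns an array of shape (grid_h * grid_w * n_anchors, 4) holding
    (center_x, center_y, width, height) in pixels.
    """
    grid_h, grid_w, _, _ = raw.shape
    stride_y = img_size[0] // grid_h
    stride_x = img_size[1] // grid_w

    # The column index is the x offset, the row index is the y offset.
    x_offset, y_offset = np.meshgrid(np.arange(grid_w), np.arange(grid_h))
    # Shape (grid_h, grid_w, 1, 2) so the offsets broadcast over the anchors.
    offsets = np.stack([x_offset, y_offset], axis=-1)[:, :, None, :]

    # Centers: sigmoid offset inside the cell plus cell index, scaled to pixels.
    centers = (sigmoid(raw[..., 0:2]) + offsets) * np.array([stride_x, stride_y])
    # Sizes: exponential scaling of the anchor dimensions.
    sizes = np.exp(raw[..., 2:4]) * np.asarray(anchors, dtype=np.float64)
    return np.concatenate([centers, sizes], axis=-1).reshape(-1, 4)


def centers_to_corners(boxes):
    """Convert (cx, cy, w, h) rows into (x1, y1, x2, y2) rows."""
    cx, cy, w, h = np.split(boxes, 4, axis=-1)
    return np.concatenate([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)


# A 2x2 grid with one hypothetical 10x20 anchor on a 64x64 image (stride 32).
# All-zero raw outputs place each box at the middle of its cell.
raw = np.zeros((2, 2, 1, 4))
boxes = decode_cell_predictions(raw, anchors=[(10, 20)], img_size=(64, 64))
corners = centers_to_corners(boxes)
print(boxes)
print(corners)
```

With all-zero inputs the sigmoid evaluates to 0.5 and the exponential to 1, so every box is centered in its cell and has exactly the anchor's size, which makes the result easy to verify by hand.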