├── .gitignore
├── README.md
├── data_parsing
│   ├── LICENSE
│   ├── __init__.py
│   ├── requirements.txt
│   ├── setup.py
│   ├── voc_train.py
│   └── voc_utils.py
└── network
    ├── LICENSE
    ├── YOLO_small_tf.py
    ├── test
    │   └── person.jpg
    └── weights
        └── put_weight_file_here.txt

/.gitignore:
--------------------------------------------------------------------------------
 1 | .project
 2 | .pydevproject
 3 | .settings
 4 | data_parsing/__pycache__/
 5 | data_parsing/build
 6 | data_parsing/dist
 7 | data_parsing/voc_utils.egg-info
 8 | network/weights/YOLO_small.ckpt
 9 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Tensorflow implementation of You Only Look Once
 2 | 
 3 | [![Video](https://i.imgur.com/szGut1Z.png)](https://www.youtube.com/watch?v=EJy0EI3hfSg)
 4 | 
 5 | An improvement of the implementation from @gliese581gg, with added training, testing and video parsing. We also used the VOC tools from @mprat to parse the VOC dataset.
 6 | 
 7 | Paper: https://arxiv.org/abs/1506.02640
 8 | 
 9 | *Note:* the code needs cleaning (I didn't write it from scratch). I could not find the time and, honestly, at this point there are multiple better open-source implementations of YOLOv2, so I don't see the point of doing it anymore.
10 | 
11 | # Installation
12 | 
13 | ## Requirements
14 | - Tensorflow
15 | - OpenCV2
16 | 
17 | ## If you want to train the network yourself
18 | `cd data_parsing`
19 | 
20 | `python setup.py install`
21 | 
22 | ## If you want to use pre-trained weights
23 | Download the YOLO weight file from: [https://drive.google.com/file/d/0B2JbaJSrWLpza08yS2FSUnV2dlE/view?usp=sharing].
24 | 
25 | Put 'YOLO_small.ckpt' in the 'network/weights' folder of the downloaded code.
26 | 
27 | ## Uninstall (training)
28 | `cd data_parsing`
29 | 
30 | `python setup.py develop --uninstall`
31 | 
32 | # Usage
33 | ## Images
34 | `python network/YOLO_small_tf.py -fromfile "name of input file" -tofile_img "name of output file"`
35 | 
36 | ## Videos
37 | `python network/YOLO_small_tf.py -video "name of input file" -tofile_vid "name of output file"`
38 | 
39 | # License
40 | Refer to the LICENSE files of both *data_parsing* and *network*.
41 | 
--------------------------------------------------------------------------------
/data_parsing/LICENSE:
--------------------------------------------------------------------------------
 1 | The MIT License (MIT)
 2 | 
 3 | Copyright (c) 2015 Michele Pratusevich
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
23 | 
--------------------------------------------------------------------------------
/data_parsing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dshahrokhian/YOLO_tensorflow/f84fcc32fbead73b10da0a85c14d5ec7bd9de787/data_parsing/__init__.py
--------------------------------------------------------------------------------
/data_parsing/requirements.txt:
--------------------------------------------------------------------------------
 1 | beautifulsoup4==4.3.2
 2 | matplotlib>=1.4.3
 3 | more_itertools>=2.2
 4 | numpy>=1.10
 5 | pandas>=0.16
 6 | scikit-image>=0.11
 7 | 
--------------------------------------------------------------------------------
/data_parsing/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import find_packages, setup
 2 | 
 3 | setup(name="voc_utils",
 4 |       version="0.0",
 5 |       description="A python utility for loading data in Pascal VOC format",
 6 |       author="Michele Pratusevich",
 7 |       author_email='mprat@alum.mit.edu',
 8 |       platforms=["osx"],  # or more specific, e.g. "win32", "cygwin", "osx"
 9 |       license="BSD",
10 |       url="http://github.com/mprat/pascal-voc-python",
11 |       packages=find_packages(),
12 |       install_requires=[i.strip() for i in open("requirements.txt").readlines()]
13 |       )
14 | 
--------------------------------------------------------------------------------
/data_parsing/voc_train.py:
--------------------------------------------------------------------------------
 1 | import voc_utils as voc
 2 | import math
 3 | 
 4 | grid_size = 7
 5 | network_reshape = 448  # the network expects input images resized to 448x448
 6 | 
 7 | def get_training_data(img_filename):
 8 |     annotation = voc.load_annotation(img_filename)
 9 | 
10 |     img_width = int(annotation.find("width").text)
11 |     img_height = int(annotation.find("height").text)
12 | 
13 |     objects = annotation.find_all("object")
14 | 
15 |     return _getXYWHC(objects, img_width, img_height)
16 | 
17 | def _getXYWHC(objects, img_width, img_height):
18 |     '''
19 |     Return a (grid_size)x(grid_size) grid with the center, width, height and class of the objects
20 |     '''
21 |     grid = [[None for _ in range(grid_size)] for _ in range(grid_size)]
22 | 
23 |     for obj in objects:
24 |         bbox = obj.find("bndbox")
25 |         xmin = int(bbox.find("xmin").text)
26 |         xmax = int(bbox.find("xmax").text)
27 |         ymin = int(bbox.find("ymin").text)
28 |         ymax = int(bbox.find("ymax").text)
29 |         obj_class = obj.find("name").text
30 | 
31 |         width = xmax - xmin
32 |         height = ymax - ymin
33 |         center_x = float(xmin + xmax) / 2
34 |         center_y = float(ymin + ymax) / 2
35 | 
36 |         #bbox = reshape([center_x,center_y,width,height],
37 |         #               img_width, img_height,
38 |         #               network_reshape, network_reshape)
39 | 
40 |         row, col = getCell([center_x, center_y], img_width, img_height)
41 | 
42 |         if (grid[row][col] is None):
43 |             grid[row][col] = [[center_x, center_y, width, height, obj_class]]
44 |         else:
45 |             grid[row][col].append([center_x, center_y, width, height, obj_class])  # keep every object that falls in this cell
46 | 
47 |     return grid
48 | 
49 | def getCell(point, width, height):
50 |     '''
51 |     Determines where a point falls within the (grid_size)x(grid_size) grid
52 |     '''
53 |     col = min(int(math.floor(point[0] / width * grid_size)), grid_size - 1)   # clamp so points on the right/bottom border stay in the last cell
54 |     row = min(int(math.floor(point[1] / height * grid_size)), grid_size - 1)
55 | 
56 |     return [row, col]
57 | 
58 | def reshape(bbox, original_width, original_height, new_width, new_height):
59 | 
60 |     w_ratio = float(new_width) / original_width   # force true division (Python 2 would truncate int / int)
61 |     h_ratio = float(new_height) / original_height
62 | 
63 |     return [bbox[0] * w_ratio,
64 |             bbox[1] * h_ratio,
65 |             bbox[2] * w_ratio,
66 |             bbox[3] * h_ratio]
67 | 
--------------------------------------------------------------------------------
/data_parsing/voc_utils.py:
--------------------------------------------------------------------------------
 1 | import pandas as pd
 2 | import os
 3 | from bs4 import BeautifulSoup
 4 | from more_itertools import unique_everseen
 5 | import numpy as np
 6 | import matplotlib.pyplot as plt
 7 | import skimage
 8 | from skimage import io
 9 | 
10 | # Change root_dir depending on where you have stored the dataset
11 | root_dir = '/home/dani/Files/Data/VOC2012/'
12 | img_dir = os.path.join(root_dir, 'JPEGImages/')
13 | ann_dir = os.path.join(root_dir, 'Annotations/')
14 | set_dir = os.path.join(root_dir, 'ImageSets', 'Main')
15 | 
16 | def list_image_sets():
17 |     """
18 |     List all the image sets from Pascal VOC. Don't bother computing
19 |     this on the fly, just remember it. It's faster.
20 |     """
21 |     return [
22 |         'aeroplane', 'bicycle', 'bird', 'boat',
23 |         'bottle', 'bus', 'car', 'cat', 'chair',
24 |         'cow', 'diningtable', 'dog', 'horse',
25 |         'motorbike', 'person', 'pottedplant',
26 |         'sheep', 'sofa', 'train',
27 |         'tvmonitor']
28 | 
29 | def imgs_from_category(cat_name, dataset):
30 |     """
31 |     Load the image set file for a category and return it as a DataFrame.
32 | 
33 |     Args:
34 |         cat_name (string): Category name as a string (from list_image_sets())
35 |         dataset (string): "train", "val", "train_val", or "test" (if available)
36 | 
37 |     Returns:
38 |         pandas dataframe: pandas DataFrame of all filenames from that category
39 |     """
40 |     filename = os.path.join(set_dir, cat_name + "_" + dataset + ".txt")
41 |     df = pd.read_csv(
42 |         filename,
43 |         delim_whitespace=True,
44 |         header=None,
45 |         names=['filename', 'true'])
46 |     return df
47 | 
48 | def imgs_from_category_as_list(cat_name, dataset):
49 |     """
50 |     Get a list of filenames for images in a particular category
51 |     as a list rather than a pandas dataframe.
52 | 
53 |     Args:
54 |         cat_name (string): Category name as a string (from list_image_sets())
55 |         dataset (string): "train", "val", "train_val", or "test" (if available)
56 | 
57 |     Returns:
58 |         list of strings: all filenames from that category
59 |     """
60 |     df = imgs_from_category(cat_name, dataset)
61 |     df = df[df['true'] == 1]
62 |     return df['filename'].values
63 | 
64 | def annotation_file_from_img(img_name):
65 |     """
66 |     Given an image name, get the annotation file for that image
67 | 
68 |     Args:
69 |         img_name (string): string of the image name, relative to
70 |             the image directory.
71 | 
72 |     Returns:
73 |         string: file path to the annotation file
74 |     """
75 |     return os.path.join(ann_dir, img_name) + '.xml'
76 | 
77 | def load_annotation(img_filename):
78 |     """
79 |     Load annotation file for a given image.
80 | 
81 |     Args:
82 |         img_filename (string): string of the image name, relative to
83 |             the image directory.
84 | 
85 |     Returns:
86 |         BeautifulSoup structure: the annotation labels loaded as a
87 |             BeautifulSoup data structure
88 |     """
89 |     xml = ""
90 |     with open(annotation_file_from_img(img_filename)) as f:
91 |         xml = f.readlines()
92 |     xml = ''.join([line.strip('\t') for line in xml])
93 |     return BeautifulSoup(xml)
94 | 
95 | # TODO: implement this
96 | def get_all_obj_and_box(objname, img_set):
97 |     img_list = imgs_from_category_as_list(objname, img_set)
98 | 
99 |     for img in img_list:
100 |         annotation = load_annotation(img)
101 | 
102 | 
103 | def load_img(img_filename):
104 |     """
105 |     Load image from the filename. Default is to load in color if
106 |     possible.
107 | 
108 |     Args:
109 |         img_filename (string): string of the image name, relative to
110 |             the image directory.
111 | 
112 |     Returns:
113 |         np array of float32: an image as a numpy array of float32
114 |     """
115 |     img_filename = os.path.join(img_dir, img_filename + '.jpg')
116 |     img = skimage.img_as_float(io.imread(
117 |         img_filename)).astype(np.float32)
118 |     if img.ndim == 2:
119 |         img = img[:, :, np.newaxis]
120 |     elif img.shape[2] == 4:
121 |         img = img[:, :, :3]
122 |     return img
123 | 
124 | def load_imgs(img_filenames):
125 |     """
126 |     Load a bunch of images from disk as np array.
127 | 
128 |     Args:
129 |         img_filenames (list of strings): image names, relative to
130 |             the image directory.
131 | 
132 |     Returns:
133 |         np array of float32: a numpy array of images. each image is
134 |             a numpy array of float32
135 |     """
136 |     return np.array([load_img(fname) for fname in img_filenames])
137 | 
138 | def _load_data(category, data_type=None):
139 |     """
140 |     Loads all the data as a pandas DataFrame for a particular category.
141 | 
142 |     Args:
143 |         category (string): Category name as a string (from list_image_sets())
144 |         data_type (string, optional): "train" or "val"
145 | 
146 |     Raises:
147 |         ValueError: when you don't give "train" or "val" as data_type
148 | 
149 |     Returns:
150 |         pandas DataFrame: df of filenames and bounding boxes
151 |     """
152 |     if data_type is None:
153 |         raise ValueError('Must provide data_type = train or val')
154 |     to_find = category
155 |     filename = os.path.join(root_dir, 'csvs/') + \
156 |         data_type + '_' + \
157 |         category + '.csv'
158 |     if os.path.isfile(filename):
159 |         return pd.read_csv(filename)
160 |     else:
161 |         train_img_list = imgs_from_category_as_list(to_find, data_type)
162 |         data = []
163 |         for item in train_img_list:
164 |             anno = load_annotation(item)
165 |             objs = anno.findAll('object')
166 |             for obj in objs:
167 |                 obj_names = obj.findChildren('name')
168 |                 for name_tag in obj_names:
169 |                     if str(name_tag.contents[0]) == category:
170 |                         fname = anno.findChild('filename').contents[0]
171 |                         bbox = obj.findChildren('bndbox')[0]
172 |                         xmin = int(bbox.findChildren('xmin')[0].contents[0])
173 |                         ymin = int(bbox.findChildren('ymin')[0].contents[0])
174 |                         xmax = int(bbox.findChildren('xmax')[0].contents[0])
175 |                         ymax = int(bbox.findChildren('ymax')[0].contents[0])
176 |                         data.append([fname, xmin, ymin, xmax, ymax])
177 |         df = pd.DataFrame(
178 |             data, columns=['fname', 'xmin', 'ymin', 'xmax', 'ymax'])
179 |         df.to_csv(filename)
180 |         return df
181 | 
182 | def get_image_url_list(category, data_type=None):
183 |     """
184 |     For a given data type, returns a list of filenames.
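    Filenames are deduplicated with unique_everseen, since an image appears
    once per annotated object of the category in the DataFrame produced by
    _load_data().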
185 | 
186 |     Args:
187 |         category (string): Category name as a string (from list_image_sets())
188 |         data_type (string, optional): "train" or "val"
189 | 
190 |     Returns:
191 |         list of strings: list of all filenames for that particular category
192 |     """
193 |     df = _load_data(category, data_type=data_type)
194 |     image_url_list = list(
195 |         unique_everseen(list(img_dir + df['fname'])))
196 |     return image_url_list
197 | 
198 | 
199 | def get_masks(cat_name, data_type, mask_type=None):
200 |     """
201 |     Return a list of masks for a given category and data_type.
202 | 
203 |     Args:
204 |         cat_name (string): "train" or "val"? No - Category name as a string (from list_image_sets())
205 |         data_type (string): "train" or "val"
206 |         mask_type (string): "bbox1" or "bbox2" - whether overlapping boxes are
207 |             set to 1 ("bbox1") or accumulated ("bbox2") before normalization
208 | 
209 |     Raises:
210 |         ValueError: if mask_type is not valid
211 | 
212 |     Returns:
213 |         list of np arrays: list of np arrays that are masks for the images
214 |             in the particular category.
215 |     """
216 |     # change this to searching through the df
217 |     # for the bboxes instead of relying on the order
218 |     # so far, should be OK since I'm always loading
219 |     # the df from disk anyway
220 |     # mask_type should be bbox1 or bbox2
221 |     if mask_type is None:
222 |         raise ValueError('Must provide mask_type')
223 |     df = _load_data(cat_name, data_type=data_type)
224 |     # load each image, turn into a binary mask
225 |     masks = []
226 |     prev_url = ""
227 |     blank_img = None
228 |     for row_num, entry in df.iterrows():
229 |         img_url = os.path.join(img_dir, entry['fname'])
230 |         if img_url != prev_url:
231 |             if blank_img is not None:
232 |                 # TODO: options for how to process the masks
233 |                 # make sure the mask is from 0 to 1
234 |                 max_val = blank_img.max()
235 |                 if max_val > 0:
236 |                     min_val = blank_img.min()
237 |                     # print "min val before normalizing: ", min_val
238 |                     # start at zero
239 |                     blank_img -= min_val
240 |                     # print "max val before normalizing: ", max_val
241 |                     # max val at 1
242 |                     blank_img /= max_val
243 |                 masks.append(blank_img)
244 |             prev_url = img_url
245 |             img = load_img(img_url)
246 |             blank_img = np.zeros((img.shape[0], img.shape[1], 1))
247 |         bbox = [entry['xmin'], entry['ymin'], entry['xmax'], entry['ymax']]
248 |         if mask_type == 'bbox1':
249 |             blank_img[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 1.0
250 |         elif mask_type == 'bbox2':
251 |             blank_img[bbox[1]:bbox[3], bbox[0]:bbox[2]] += 1.0
252 |         else:
253 |             raise ValueError('Not a valid mask type')
254 |     # TODO: options for how to process the masks
255 |     # make sure the mask is from 0 to 1
256 |     max_val = blank_img.max()
257 |     if max_val > 0:
258 |         min_val = blank_img.min()
259 |         # print "min val before normalizing: ", min_val
260 |         # start at zero
261 |         blank_img -= min_val
262 |         # print "max val before normalizing: ", max_val
263 |         # max val at 1
264 |         blank_img /= max_val
265 |     masks.append(blank_img)
266 |     return np.array(masks)
267 | 
268 | 
269 | def get_imgs(cat_name, data_type=None):
270 |     """
271 |     Load and return all the images for a particular category.
272 | 
273 |     Args:
274 |         cat_name (string): Category name as a string (from list_image_sets())
275 |         data_type (string, optional): "train" or "val"
276 | 
277 |     Returns:
278 |         np array of images: np array of loaded images for the category
279 |             and data_type.
280 |     """
281 |     image_url_list = get_image_url_list(cat_name, data_type=data_type)
282 |     imgs = []
283 |     for url in image_url_list:
284 |         imgs.append(load_img(url))
285 |     return np.array(imgs)
286 | 
287 | 
288 | def display_image_and_mask(img, mask):
289 |     """
290 |     Display an image and its mask side by side.
291 | 
292 |     Args:
293 |         img (np array): the loaded image as a np array
294 |         mask (np array): the loaded mask as a np array
295 |     """
296 |     plt.figure(1)
297 |     plt.clf()
298 |     ax1 = plt.subplot(1, 2, 1)
299 |     ax2 = plt.subplot(1, 2, 2)
300 |     ax1.imshow(img)
301 |     ax1.set_title('Original image')
302 |     ax2.imshow(mask)
303 |     ax2.set_title('Mask')
304 |     plt.show(block=False)
305 | 
306 | 
307 | def cat_name_to_cat_id(cat_name):
308 |     """
309 |     Transform a category name to an id number alphabetically.
310 | 
311 |     Args:
312 |         cat_name (string): Category name as a string (from list_image_sets())
313 | 
314 |     Returns:
315 |         int: the integer that corresponds to the category name
316 |     """
317 |     cat_list = list_image_sets()
318 |     cat_id_dict = dict(zip(cat_list, range(len(cat_list))))
319 |     return cat_id_dict[cat_name]
320 | 
321 | 
322 | def display_img_and_masks(
323 |         img, true_mask, predicted_mask, block=False):
324 |     """
325 |     Display an image and its two masks side by side.
326 | 
327 |     Args:
328 |         img (np array): image as a np array
329 |         true_mask (np array): true mask as a np array
330 |         predicted_mask (np array): predicted_mask as a np array
331 |         block (bool, optional): whether to display in a blocking manner or not.
332 |             Defaults to False (non-blocking)
333 |     """
334 |     m_predicted_color = predicted_mask.reshape(
335 |         predicted_mask.shape[0], predicted_mask.shape[1])
336 |     m_true_color = true_mask.reshape(
337 |         true_mask.shape[0], true_mask.shape[1])
338 |     # m_predicted_color = predicted_mask
339 |     # m_true_color = true_mask
340 |     # plt.close(1)
341 |     plt.figure(1)
342 |     plt.clf()
343 |     plt.axis('off')
344 |     f, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, num=1)
345 |     # f.clf()
346 |     ax1.get_xaxis().set_ticks([])
347 |     ax2.get_xaxis().set_ticks([])
348 |     ax3.get_xaxis().set_ticks([])
349 |     ax1.get_yaxis().set_ticks([])
350 |     ax2.get_yaxis().set_ticks([])
351 |     ax3.get_yaxis().set_ticks([])
352 | 
353 |     ax1.imshow(img)
354 |     ax2.imshow(m_true_color)
355 |     ax3.imshow(m_predicted_color)
356 |     plt.draw()
357 |     plt.show(block=block)
358 | 
359 | 
360 | def load_data_multilabel(data_type=None):
361 |     """
362 |     Returns a DataFrame for all images in a given set in multilabel format.
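    The frame has one row per image: a 'filename' column plus one 0/1 column
    per category from list_image_sets(), set to 1 when the image's annotation
    contains at least one object of that category.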
363 | 364 | Args: 365 | data_type (string, optional): "train" or "val" 366 | 367 | Returns: 368 | pandas DataFrame: filenames in multilabel format 369 | """ 370 | if data_type is None: 371 | raise ValueError('Must provide data_type = train or val') 372 | filename = os.path.join(set_dir, data_type + ".txt") 373 | cat_list = list_image_sets() 374 | df = pd.read_csv( 375 | filename, 376 | delim_whitespace=True, 377 | header=None, 378 | names=['filename']) 379 | # add all the blank rows for the multilabel case 380 | for cat_name in cat_list: 381 | df[cat_name] = 0 382 | for info in df.itertuples(): 383 | index = info[0] 384 | fname = info[1] 385 | anno = load_annotation(fname) 386 | objs = anno.findAll('object') 387 | for obj in objs: 388 | obj_names = obj.findChildren('name') 389 | for name_tag in obj_names: 390 | tag_name = str(name_tag.contents[0]) 391 | if tag_name in cat_list: 392 | df.at[index, tag_name] = 1 393 | return df 394 | -------------------------------------------------------------------------------- /network/LICENSE: -------------------------------------------------------------------------------- 1 | YOLO_tensorflow LICENSE 2 | Version 0.1, FEB 15 2016 3 | 4 | ACCORDING TO ORIGINAL CODE'S LICENSE, 5 | 6 | DO NOT USE THIS ON COMMERCIAL! 7 | I OR ORIGINAL AUTHOR DO NOT HOLD LIABILITY FOR ANY DAMAGES! 8 | 9 | 10 | BELOW IS THE ORIGINAL CODE'S LICENSE 11 | { 12 | THIS SOFTWARE LICENSE IS PROVIDED "ALL CAPS" SO THAT YOU KNOW IT IS SUPER 13 | SERIOUS AND YOU DON'T MESS AROUND WITH COPYRIGHT LAW BECAUSE YOU WILL GET IN 14 | TROUBLE HERE ARE SOME OTHER BUZZWORDS COMMONLY IN THESE THINGS WARRANTIES 15 | LIABILITY CONTRACT TORT LIABLE CLAIMS RESTRICTION MERCHANTABILITY SUBJECT TO 16 | THE FOLLOWING CONDITIONS: 17 | 18 | 1. #yolo 19 | 2. #swag 20 | 3. #blazeit 21 | } 22 | -------------------------------------------------------------------------------- /network/YOLO_small_tf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import cv2 4 | import time 5 | import sys 6 | import itertools 7 | from data_parsing import voc_utils 8 | from data_parsing import voc_train 9 | 10 | 11 | class YOLO_TF: 12 | ''' 13 | Class representing the operations required for building the YOLO neural network. 
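    Command-line flags (see argv_parser below): -train, -fromfile, -tofile_img,
    -tofile_vid, -tofile_txt, -imshow, -disp_console and -video.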
14 |     '''
15 | 
16 |     # Control variables
17 |     fromfile = None
18 |     tofile_img = 'test/output.jpg'
19 |     tofile_txt = 'test/output.txt'
20 |     imshow = True
21 |     filewrite_img = False
22 |     writter = None
23 |     video = False
24 |     filewrite_txt = False
25 |     disp_console = True
26 |     weights_file = 'network/weights/YOLO_small.ckpt'
27 | 
28 |     # algorithm variables
29 |     alpha = 0.1
30 |     threshold = 0.2
31 |     iou_threshold = 0.5
32 |     num_class = 20
33 |     num_box = 2
34 |     grid_size = 7
35 |     classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
36 |                "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
37 | 
38 |     w_img = 640
39 |     h_img = 480
40 | 
41 |     # training variables
42 |     training = False
43 |     keep_prob = tf.placeholder(tf.float32)
44 |     lambdacoord = 5.0
45 |     lambdanoobj = 0.5
46 |     label = None
47 |     label_test = None
48 |     index_in_epoch = 0
49 |     epochs_completed = 0
50 | 
51 | 
52 | 
53 |     def __init__(self, argvs=[]):
54 |         self.argv_parser(argvs)
55 |         self.build_networks()
56 |         if self.training:
57 |             self.build_training()
58 |             self.train()
59 |         print("detection")
60 |         print(self.fromfile)
61 |         if self.fromfile is not None:
62 |             if self.video:
63 |                 print("video")
64 |                 self.detect_from_file_video(self.fromfile)
65 |             else:
66 |                 print("image")
67 |                 self.detect_from_file(self.fromfile)
68 | 
69 |     def argv_parser(self, argvs):
70 |         for i in range(1, len(argvs), 2):
71 |             print(argvs[i])
72 |             if argvs[i] == '-train': self.training = True
73 |             if argvs[i] == '-fromfile': self.fromfile = argvs[i + 1]
74 |             if argvs[i] == '-tofile_img': self.tofile_img = argvs[i + 1]; self.filewrite_img = True
75 |             if argvs[i] == '-tofile_vid': self.tofile_img = argvs[i + 1]; self.filewrite_img = True  # video output reuses the tofile_img path
76 |             if argvs[i] == '-tofile_txt': self.tofile_txt = argvs[i + 1]; self.filewrite_txt = True
77 |             if argvs[i] == '-imshow':
78 |                 if argvs[i + 1] == '1':
79 |                     self.imshow = True
80 |                 else:
81 |                     self.imshow = False
82 |             if argvs[i] == '-disp_console':
83 |                 if argvs[i + 1] == '1':
84 |                     self.disp_console = True
85 |                 else:
86 |                     self.disp_console = False
87 |             if argvs[i] == '-video':
88 |                 self.video = True
89 |                 self.fromfile = argvs[i + 1]
90 |                 self.filewrite_img = True
91 | 
92 |     def build_networks(self):
93 |         if self.disp_console: print "Building YOLO_small graph..."
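        # Overview (annotation added for readability): YOLO-small is 24
        # convolutional layers followed by three fully connected layers
        # (fc_29, fc_30, fc_32) over a 448x448x3 input. The final 1470-vector
        # is decoded in interpret_output() as 7*7*20 class probabilities
        # (output[0:980]), 7*7*2 box confidences (output[980:1078]) and
        # 7*7*2*4 box coordinates (output[1078:]).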
94 | self.x = tf.placeholder('float32', [None, 448, 448, 3]) 95 | self.conv_1 = self.conv_layer(1, self.x, 64, 7, 2) 96 | self.pool_2 = self.pooling_layer(2, self.conv_1, 2, 2) 97 | self.conv_3 = self.conv_layer(3, self.pool_2, 192, 3, 1) 98 | self.pool_4 = self.pooling_layer(4, self.conv_3, 2, 2) 99 | self.conv_5 = self.conv_layer(5, self.pool_4, 128, 1, 1) 100 | self.conv_6 = self.conv_layer(6, self.conv_5, 256, 3, 1) 101 | self.conv_7 = self.conv_layer(7, self.conv_6, 256, 1, 1) 102 | self.conv_8 = self.conv_layer(8, self.conv_7, 512, 3, 1) 103 | self.pool_9 = self.pooling_layer(9, self.conv_8, 2, 2) 104 | self.conv_10 = self.conv_layer(10, self.pool_9, 256, 1, 1) 105 | self.conv_11 = self.conv_layer(11, self.conv_10, 512, 3, 1) 106 | self.conv_12 = self.conv_layer(12, self.conv_11, 256, 1, 1) 107 | self.conv_13 = self.conv_layer(13, self.conv_12, 512, 3, 1) 108 | self.conv_14 = self.conv_layer(14, self.conv_13, 256, 1, 1) 109 | self.conv_15 = self.conv_layer(15, self.conv_14, 512, 3, 1) 110 | self.conv_16 = self.conv_layer(16, self.conv_15, 256, 1, 1) 111 | self.conv_17 = self.conv_layer(17, self.conv_16, 512, 3, 1) 112 | self.conv_18 = self.conv_layer(18, self.conv_17, 512, 1, 1) 113 | self.conv_19 = self.conv_layer(19, self.conv_18, 1024, 3, 1) 114 | self.pool_20 = self.pooling_layer(20, self.conv_19, 2, 2) 115 | self.conv_21 = self.conv_layer(21, self.pool_20, 512, 1, 1) 116 | self.conv_22 = self.conv_layer(22, self.conv_21, 1024, 3, 1) 117 | self.conv_23 = self.conv_layer(23, self.conv_22, 512, 1, 1) 118 | self.conv_24 = self.conv_layer(24, self.conv_23, 1024, 3, 1) 119 | self.conv_25 = self.conv_layer(25, self.conv_24, 1024, 3, 1,trainable=self.training) 120 | self.conv_26 = self.conv_layer(26, self.conv_25, 1024, 3, 2,trainable=self.training) 121 | self.conv_27 = self.conv_layer(27, self.conv_26, 1024, 3, 1,trainable=self.training) 122 | self.conv_28 = self.conv_layer(28, self.conv_27, 1024, 3, 1,trainable=self.training) 123 | self.fc_29 = self.fc_layer(29, self.conv_28, 512, flat=True, linear=False,trainable=self.training) 124 | self.fc_30 = self.fc_layer(30, self.fc_29, 4096, flat=False, linear=False,trainable=self.training) 125 | self.drop_31 = self.dropout(31, self.fc_30) 126 | self.fc_32 = self.fc_layer(32, self.drop_31, 1470, flat=False, linear=True,trainable=self.training) 127 | self.sess = tf.Session() 128 | self.sess.run(tf.global_variables_initializer()) 129 | self.saver = tf.train.Saver() 130 | self.saver.restore(self.sess, self.weights_file) 131 | if self.disp_console: print "Loading complete!" 
+ '\n' 132 | 133 | def conv_layer(self, idx, inputs, filters, size, stride, trainable=False): 134 | channels = inputs.get_shape()[3] 135 | weight = tf.Variable(tf.truncated_normal([size, size, int(channels), filters], stddev=0.1), trainable=trainable) 136 | biases = tf.Variable(tf.constant(0.1, shape=[filters]), trainable=trainable) 137 | 138 | pad_size = size // 2 139 | pad_mat = np.array([[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]]) 140 | inputs_pad = tf.pad(inputs, pad_mat) 141 | 142 | conv = tf.nn.conv2d(inputs_pad, weight, strides=[1, stride, stride, 1], padding='VALID', 143 | name=str(idx) + '_conv') 144 | conv_biased = tf.add(conv, biases, name=str(idx) + '_conv_biased') 145 | if self.disp_console: print ' Layer %d : Type = Conv, Size = %d * %d, Stride = %d, Filters = %d, Input channels = %d' % ( 146 | idx, size, size, stride, filters, int(channels)) 147 | return tf.maximum(tf.mul(self.alpha, conv_biased), conv_biased, name=str(idx) + '_leaky_relu') 148 | 149 | def pooling_layer(self, idx, inputs, size, stride): 150 | if self.disp_console: print ' Layer %d : Type = Pool, Size = %d * %d, Stride = %d' % ( 151 | idx, size, size, stride) 152 | return tf.nn.max_pool(inputs, ksize=[1, size, size, 1], strides=[1, stride, stride, 1], padding='SAME', 153 | name=str(idx) + '_pool') 154 | 155 | def fc_layer(self, idx, inputs, hiddens, flat=False, linear=False, trainable=False): 156 | input_shape = inputs.get_shape().as_list() 157 | if flat: 158 | dim = input_shape[1] * input_shape[2] * input_shape[3] 159 | inputs_transposed = tf.transpose(inputs, (0, 3, 1, 2)) 160 | inputs_processed = tf.reshape(inputs_transposed, [-1, dim]) 161 | else: 162 | dim = input_shape[1] 163 | inputs_processed = inputs 164 | #weight = tf.Variable(tf.truncated_normal([dim, hiddens], stddev=0.1), trainable=trainable) 165 | weight = tf.Variable(tf.zeros([dim, hiddens]), trainable=trainable) 166 | biases = tf.Variable(tf.constant(0.1, shape=[hiddens]), trainable=trainable) 167 | if self.disp_console: print ' Layer %d : Type = Full, Hidden = %d, Input dimension = %d, Flat = %d, Activation = %d' % ( 168 | idx, hiddens, int(dim), int(flat), 1 - int(linear)) 169 | if linear: return tf.add(tf.matmul(inputs_processed, weight), biases, name=str(idx) + '_fc') 170 | ip = tf.add(tf.matmul(inputs_processed, weight), biases) 171 | return tf.maximum(tf.mul(self.alpha, ip), ip, name=str(idx) + '_fc') 172 | 173 | def dropout(self, idx, inputs): 174 | if self.disp_console: print ' Layer %d : Type = DropOut' % (idx) 175 | return tf.nn.dropout(inputs, keep_prob=self.keep_prob) 176 | 177 | def detect_from_cvmat(self, img): 178 | s = time.time() 179 | self.h_img, self.w_img, _ = img.shape 180 | img_resized = cv2.resize(img, (448, 448)) 181 | img_RGB = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB) 182 | img_resized_np = np.asarray(img_RGB) 183 | inputs = np.zeros((1, 448, 448, 3), dtype='float32') 184 | inputs[0] = (img_resized_np / 255.0) * 2.0 - 1.0 185 | in_dict = {self.x: inputs, self.keep_prob: 1.0} 186 | net_output = self.sess.run(self.fc_32, feed_dict=in_dict) 187 | self.result = self.interpret_output(net_output[0]) 188 | strtime = str(time.time() - s) 189 | if self.disp_console: print 'Elapsed time : ' + strtime + ' secs' + '\n' 190 | 191 | def detect_from_file_video(self, filename): 192 | if self.disp_console: print 'Detect from ' + filename 193 | input = cv2.VideoCapture(filename) 194 | while not input.isOpened(): 195 | input = cv2.VideoCapture(filename) 196 | cv2.waitKey(1000) 197 | print "Wait for the header" 198 | 
if self.filewrite_img:
199 |             print input.get(cv2.CAP_PROP_FOURCC)
200 |             self.writter = cv2.VideoWriter(self.tofile_img,
201 |                                            int(input.get(cv2.CAP_PROP_FOURCC)),
202 |                                            int(input.get(cv2.CAP_PROP_FPS)),
203 |                                            (int(input.get(cv2.CAP_PROP_FRAME_WIDTH)),
204 |                                             int(input.get(cv2.CAP_PROP_FRAME_HEIGHT))))
205 |             while not self.writter.isOpened():
206 |                 self.writter = cv2.VideoWriter(self.tofile_img, -1, input.get(cv2.CAP_PROP_FPS),
207 |                                                (int(input.get(cv2.CAP_PROP_FRAME_WIDTH)),
208 |                                                 int(input.get(cv2.CAP_PROP_FRAME_HEIGHT))))  # frame size must be a (width, height) tuple
209 |                 cv2.waitKey(1000)
210 |                 print "Wait for the video writer"
211 |         pos_frame = input.get(cv2.CAP_PROP_POS_FRAMES)
212 |         while True:
213 |             flag, frame = input.read()
214 |             if flag:
215 |                 # The frame is ready and already captured
216 |                 cv2.imshow('video', frame)
217 |                 pos_frame = input.get(cv2.CAP_PROP_POS_FRAMES)
218 |                 print str(pos_frame) + " frames"
219 |                 self.detect_from_cvmat(frame)
220 |                 self.show_results(frame, self.result)
221 |             else:
222 |                 # The next frame is not ready, so we try to read it again
223 |                 input.set(cv2.CAP_PROP_POS_FRAMES, pos_frame - 1)
224 |                 print "frame is not ready"
225 |                 # It is better to wait for a while for the next frame to be ready
226 |                 cv2.waitKey(1000)
227 |             if cv2.waitKey(10) == 27:
228 |                 break
229 |             if input.get(cv2.CAP_PROP_POS_FRAMES) == input.get(cv2.CAP_PROP_FRAME_COUNT):
230 |                 # If the number of captured frames is equal to the total number of frames,
231 |                 # we stop
232 |                 break
233 |         input.release()
234 |         self.writter.release()
235 |         cv2.destroyAllWindows()
236 | 
237 | 
238 |     def detect_from_file(self, filename):
239 |         if self.disp_console: print 'Detect from ' + filename
240 |         img = cv2.imread(filename)
241 |         self.detect_from_cvmat(img)
242 |         self.show_results(img, self.result)
243 | 
244 |     def interpret_output(self, output):
245 |         probs = np.zeros((7, 7, 2, 20))
246 |         class_probs = np.reshape(output[0:980], (7, 7, 20))
247 |         scales = np.reshape(output[980:1078], (7, 7, 2))
248 |         boxes = np.reshape(output[1078:], (7, 7, 2, 4))
249 |         offset = np.transpose(np.reshape(np.array([np.arange(7)] * 14), (2, 7, 7)), (1, 2, 0))
250 | 
251 |         boxes[:, :, :, 0] += offset
252 |         boxes[:, :, :, 1] += np.transpose(offset, (1, 0, 2))
253 |         boxes[:, :, :, 0:2] = boxes[:, :, :, 0:2] / 7.0
254 |         boxes[:, :, :, 2] = np.multiply(boxes[:, :, :, 2], boxes[:, :, :, 2])
255 |         boxes[:, :, :, 3] = np.multiply(boxes[:, :, :, 3], boxes[:, :, :, 3])
256 | 
257 |         boxes[:, :, :, 0] *= self.w_img
258 |         boxes[:, :, :, 1] *= self.h_img
259 |         boxes[:, :, :, 2] *= self.w_img
260 |         boxes[:, :, :, 3] *= self.h_img
261 | 
262 |         for i in range(2):
263 |             for j in range(20):
264 |                 probs[:, :, i, j] = np.multiply(class_probs[:, :, j], scales[:, :, i])
265 | 
266 |         filter_mat_probs = np.array(probs >= self.threshold, dtype='bool')
267 |         filter_mat_boxes = np.nonzero(filter_mat_probs)
268 |         boxes_filtered = boxes[filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]
269 |         probs_filtered = probs[filter_mat_probs]
270 |         classes_num_filtered = np.argmax(filter_mat_probs, axis=3)[
271 |             filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]
272 | 
273 |         argsort = np.array(np.argsort(probs_filtered))[::-1]
274 |         boxes_filtered = boxes_filtered[argsort]
275 |         probs_filtered = probs_filtered[argsort]
276 |         classes_num_filtered = classes_num_filtered[argsort]
277 | 
278 |         for i in range(len(boxes_filtered)):
279 |             if probs_filtered[i] == 0: continue
280 |             for j in range(i + 1, len(boxes_filtered)):
281 |                 if self.iou(boxes_filtered[i], boxes_filtered[j]) > self.iou_threshold:
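                    # greedy non-maximum suppression: drop the lower-scoring of two heavily overlapping boxes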
282 |                     probs_filtered[j] = 0.0
283 | 
284 |         filter_iou = np.array(probs_filtered > 0.0, dtype='bool')
285 |         boxes_filtered = boxes_filtered[filter_iou]
286 |         probs_filtered = probs_filtered[filter_iou]
287 |         classes_num_filtered = classes_num_filtered[filter_iou]
288 | 
289 |         result = []
290 |         for i in range(len(boxes_filtered)):
291 |             result.append([self.classes[classes_num_filtered[i]], boxes_filtered[i][0], boxes_filtered[i][1],
292 |                            boxes_filtered[i][2], boxes_filtered[i][3], probs_filtered[i]])
293 | 
294 |         return result
295 | 
296 |     def show_results(self, img, results):
297 |         img_cp = img.copy()
298 |         if self.filewrite_txt:
299 |             ftxt = open(self.tofile_txt, 'w')
300 |         for i in range(len(results)):
301 |             x = int(results[i][1])
302 |             y = int(results[i][2])
303 |             w = int(results[i][3]) // 2
304 |             h = int(results[i][4]) // 2
305 |             if self.disp_console: print ' class : ' + results[i][0] + ' , [x,y,w,h]=[' + str(x) + ',' + str(
306 |                 y) + ',' + str(int(results[i][3])) + ',' + str(int(results[i][4])) + '], Confidence = ' + str(
307 |                 results[i][5])
308 |             if self.filewrite_img or self.imshow:
309 |                 cv2.rectangle(img_cp, (x - w, y - h), (x + w, y + h), (0, 255, 0), 2)
310 |                 cv2.rectangle(img_cp, (x - w, y - h - 20), (x + w, y - h), (125, 125, 125), -1)
311 |                 cv2.putText(img_cp, results[i][0] + ' : %.2f' % results[i][5], (x - w + 5, y - h - 7),
312 |                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
313 |             if self.filewrite_txt:
314 |                 ftxt.write(results[i][0] + ',' + str(x) + ',' + str(y) + ',' + str(w) + ',' + str(h) + ',' + str(
315 |                     results[i][5]) + '\n')
316 |         if self.filewrite_img:
317 |             if self.disp_console: print ' image file written : ' + self.tofile_img
318 |             if self.video:
319 |                 self.writter.write(img_cp)
320 |             else:
321 |                 cv2.imwrite(self.tofile_img, img_cp)
322 |         if self.imshow:
323 |             cv2.imshow('YOLO_small detection', img_cp)
324 |             cv2.waitKey(1)
325 |         if self.filewrite_txt:
326 |             if self.disp_console: print ' txt file written : ' + self.tofile_txt
327 |             ftxt.close()
328 | 
329 |     def iou(self, box1, box2):
330 |         tb = min(box1[0] + 0.5 * box1[2], box2[0] + 0.5 * box2[2]) - max(box1[0] - 0.5 * box1[2],
331 |                                                                          box2[0] - 0.5 * box2[2])  # horizontal overlap of the [x, y, w, h] boxes
332 |         lr = min(box1[1] + 0.5 * box1[3], box2[1] + 0.5 * box2[3]) - max(box1[1] - 0.5 * box1[3],
333 |                                                                          box2[1] - 0.5 * box2[3])  # vertical overlap
334 |         if tb < 0 or lr < 0:
335 |             intersection = 0
336 |         else:
337 |             intersection = tb * lr
338 |         return intersection / (box1[2] * box1[3] + box2[2] * box2[3] - intersection)
339 | 
340 | 
341 |     def build_training(self):  # builds the YOLO training loss and optimizer
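        # Annotated sketch of the loss built below (the sum-squared error from
        # the YOLO paper), using the obj/objI/noobj indicator masks that
        # build_label() fills in:
        #   lambdacoord * sum_obj[(x-x_)^2 + (y-y_)^2 + (sqrt(w)-sqrt(w_))^2 + (sqrt(h)-sqrt(h_))^2]
        #   + sum_obj[(C-C_)^2] + lambdanoobj * sum_noobj[(C-C_)^2]
        #   + sum_objI[(p-p_)^2]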
342 |         # placeholders for the ground-truth labels, filled by build_label()
343 |         self.x_ = tf.placeholder(tf.float32, [None, 7, 7, 2])  # the first dimension (None) will index the images
344 |         self.y_ = tf.placeholder(tf.float32, [None, 7, 7, 2])
345 |         self.w_ = tf.placeholder(tf.float32, [None, 7, 7, 2])
346 |         self.h_ = tf.placeholder(tf.float32, [None, 7, 7, 2])
347 |         self.C_ = tf.placeholder(tf.float32, [None, 7, 7, 2])
348 |         self.p_ = tf.placeholder(tf.float32, [None, 7, 7, 20])
349 |         self.obj = tf.placeholder(tf.float32, [None, 7, 7, 2])
350 |         self.objI = tf.placeholder(tf.float32, [None, 7, 7])
351 |         self.noobj = tf.placeholder(tf.float32, [None, 7, 7, 2])
352 | 
353 | 
354 |         # network output
355 |         output = self.fc_32
356 |         nb_image = tf.shape(self.x_)[0]
357 |         class_probs = tf.reshape(output[0:nb_image, 0:980], (nb_image, 7, 7, 20))
358 |         scales = tf.reshape(output[0:nb_image, 980:1078], (nb_image, 7, 7, 2))
359 |         boxes = tf.reshape(output[0:nb_image, 1078:], (nb_image, 7, 7, 2, 4))
360 | 
361 |         boxes0 = boxes[:, :, :, :, 0]
362 |         boxes1 = boxes[:, :, :, :, 1]
363 |         boxes2 = boxes[:, :, :, :, 2]
364 |         boxes3 = boxes[:, :, :, :, 3]
365 | 
366 |         # loss function
367 |         self.subX = tf.sub(boxes0, self.x_)
368 |         self.subY = tf.sub(boxes1, self.y_)
369 |         self.subW = tf.sub(tf.sqrt(tf.abs(boxes2)), tf.sqrt(self.w_))
370 |         self.subH = tf.sub(tf.sqrt(tf.abs(boxes3)), tf.sqrt(self.h_))
371 |         self.subC = tf.sub(scales, self.C_)
372 |         self.subP = tf.sub(class_probs, self.p_)
373 |         self.lossX = tf.mul(self.lambdacoord, tf.reduce_sum(tf.mul(self.obj, tf.mul(self.subX, self.subX)), axis=[1, 2, 3]))
374 |         self.lossY = tf.mul(self.lambdacoord, tf.reduce_sum(tf.mul(self.obj, tf.mul(self.subY, self.subY)), axis=[1, 2, 3]))
375 |         self.lossW = tf.mul(self.lambdacoord, tf.reduce_sum(tf.mul(self.obj, tf.mul(self.subW, self.subW)), axis=[1, 2, 3]))
376 |         self.lossH = tf.mul(self.lambdacoord, tf.reduce_sum(tf.mul(self.obj, tf.mul(self.subH, self.subH)), axis=[1, 2, 3]))
377 |         self.lossCObj = tf.reduce_sum(tf.mul(self.obj, tf.mul(self.subC, self.subC)), axis=[1, 2, 3])
378 |         self.lossCNobj = tf.mul(self.lambdanoobj, tf.reduce_sum(tf.mul(self.noobj, tf.mul(self.subC, self.subC)), axis=[1, 2, 3]))
379 |         self.lossP = tf.reduce_sum(tf.mul(self.objI, tf.reduce_sum(tf.mul(self.subP, self.subP), axis=3)), axis=[1, 2])
380 |         self.loss = tf.add_n((self.lossX, self.lossY, self.lossW, self.lossH, self.lossCObj, self.lossCNobj, self.lossP))
381 |         self.loss = tf.reduce_mean(self.loss)
382 | 
383 |         # variables for the training
384 |         global_step = tf.Variable(0, trainable=False)
385 |         starter_learning_rate = 0.001
386 |         decay = 0.0005  # note: passed to polynomial_decay as its decay_steps argument below
387 |         end_learning_rate = 0.01
388 |         self.epoch = tf.placeholder(tf.int32)
389 | 
390 |         # Learning-rate schedule from the paper: ramp up from 1e-3 to 1e-2, then 1e-2 for 75 epochs, 1e-3 for 30 and 1e-4 for 30
391 |         def lr1():
392 |             return tf.train.polynomial_decay(starter_learning_rate, global_step, decay, end_learning_rate=end_learning_rate,
393 |                                              power=1.0)
394 |         def lr2():
395 |             return tf.constant(0.01)
396 |         def lr3():
397 |             return tf.constant(0.001)
398 |         def lr4():
399 |             return tf.constant(0.0001)
400 |         lr = tf.case({tf.less_equal(self.epoch, 1): lr1,
401 |                       tf.logical_and(tf.greater(self.epoch, 1), tf.less_equal(self.epoch, 76)): lr2,
402 |                       tf.logical_and(tf.greater(self.epoch, 76), tf.less_equal(self.epoch, 106)): lr3,
403 |                       tf.greater(self.epoch, 106): lr4}, lr4, exclusive=True)
404 | 
405 |         self.train_step = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9).minimize(self.loss, global_step=global_step)
406 |         self.sess = tf.Session()
407 |         self.sess.run(tf.global_variables_initializer()); self.saver.restore(self.sess, self.weights_file)  # re-restore the checkpoint: the new session does not see the weights restored in build_networks()
408 | 
409 |     def build_label(self, img_filenames, epoch):
410 |         X_global = []; Y_global = [];
W_global = []; H_global = []; C_global = []; P_global = []; obj_global = []; objI_global = [];
411 |         noobj_global = []; Image = []
412 |         for img_filename in img_filenames:
413 |             prelabel = voc_train.get_training_data(img_filename)
414 |             x = np.zeros([7, 7, 2]); y = np.zeros([7, 7, 2]); w = np.zeros([7, 7, 2]); h = np.zeros([7, 7, 2])
415 |             C = np.zeros([7, 7, 2]); p = np.zeros([7, 7, 20]); obj = np.zeros([7, 7, 2]); objI = np.zeros([7, 7])
416 |             noobj = np.ones([7, 7, 2]); img = voc_utils.load_img(img_filename)
417 |             for i, j in itertools.product(range(0, 7), range(0, 7)):
418 |                 if prelabel[i][j] is not None:
419 |                     index = 0
420 |                     while (len(prelabel[i][j]) > index and index < 2):
421 |                         x[i][j][index] = (float(prelabel[i][j][index][0]) / len(img[0])) * 7 - j  # x normalized by image width (len(img[0])), offset by column j
422 |                         y[i][j][index] = (float(prelabel[i][j][index][1]) / len(img)) * 7 - i     # y normalized by image height (len(img)), offset by row i
423 |                         w[i][j][index] = float(prelabel[i][j][index][2]) / len(img[0])            # relative width; the sqrt is applied in the loss, not here
424 |                         h[i][j][index] = float(prelabel[i][j][index][3]) / len(img)               # relative height
425 |                         C[i][j][index] = 1.0
426 |                         p[i][j][self.classes.index(prelabel[i][j][index][4])] = 1.0 / float(len(prelabel[i][j]))
427 |                         obj[i][j][index] = 1.0
428 |                         objI[i][j] = 1.0
429 |                         noobj[i][j][index] = 0.0
430 |                         index = index + 1
431 |             X_global.append(x); Y_global.append(y); W_global.append(w); H_global.append(h); C_global.append(C)
432 |             P_global.append(p); obj_global.append(obj); objI_global.append(objI); noobj_global.append(noobj)
433 | 
434 |             # resize the image
435 |             img_resized = cv2.resize(img, (448, 448))
436 |             img_RGB = img_resized  # already RGB: voc_utils.load_img reads with skimage (RGB), unlike cv2.imread (BGR)
437 |             img_resized_np = np.asarray(img_RGB)
438 |             inputs = np.zeros((1, 448, 448, 3), dtype='float32')
439 |             inputs[0] = img_resized_np * 2.0 - 1.0  # load_img already returns floats in [0, 1]
440 |             Image.append(inputs[0])
441 |         X_global = np.array(X_global); Y_global = np.array(Y_global); W_global = np.array(W_global)
442 |         H_global = np.array(H_global); C_global = np.array(C_global); P_global = np.array(P_global)
443 |         obj_global = np.array(obj_global); objI_global = np.array(objI_global); noobj_global = np.array(noobj_global)
444 |         Image = np.array(Image)
445 |         return {self.x: Image, self.x_: X_global, self.y_: Y_global, self.w_: W_global, self.h_: H_global, self.C_: C_global,
446 |                 self.p_: P_global, self.obj: obj_global, self.objI: objI_global, self.noobj: noobj_global,
447 |                 self.keep_prob: 0.5, self.epoch: epoch}
448 | 
449 |     def next_batch(self, batch_size, num_examples):
450 |         start = self.index_in_epoch
451 |         self.index_in_epoch += batch_size
452 |         if self.index_in_epoch > num_examples:
453 |             # Finished epoch
454 |             self.epochs_completed += 1
455 |             # Shuffle the data
456 |             perm = np.arange(num_examples)
457 |             np.random.shuffle(perm)
458 |             self.label = self.label[perm]
459 |             # Start next epoch
460 |             start = 0
461 |             self.index_in_epoch = batch_size
462 |             assert batch_size <= num_examples
463 |         end = self.index_in_epoch
464 |         return self.label[start:end]
465 | 
466 | 
467 | 
468 | 
469 |     def training_step(self, i, update_test_data, update_train_data):
470 | 
471 |         # run one pass of mini-batch training, then optionally evaluate the train and test losses
472 |         for nbatch in range(0, len(self.label) / 64):
473 |             feed = self.build_label(self.next_batch(64, num_examples=len(self.label)), i)
474 |             self.sess.run(self.train_step, feed)
475 | 
476 |         train_l = []
477 |         test_l = []
478 | 
479 |         if update_train_data:
480 |             l = self.sess.run(self.loss, feed_dict=self.build_label(self.label, i))
481 |             train_l.append(l)
482 | 
483 |         if update_test_data:
484 |             l = self.sess.run(self.loss, feed_dict=self.build_label(self.label_test, i))
485 |             print("\r", i, "loss : ", l)
486 |             test_l.append(l)
487 | 
488 |         return (train_l, test_l)
489 | 
490 | 
491 |     def
train(self):
492 |         train_l = []
493 |         test_l = []
494 |         self.label = voc_utils.imgs_from_category_as_list("bird", "train")      # example setup: train on the 'bird' subset only
495 |         self.label_test = voc_utils.imgs_from_category_as_list("bird", "val")
496 |         training_iter = 137
497 |         epoch_size = 5  # evaluate the losses every 5 epochs
498 |         for i in range(training_iter):
499 |             test = False
500 |             if i % epoch_size == 0:
501 |                 test = True
502 |             l, tl = self.training_step(i, test, test)
503 |             train_l += l
504 |             test_l += tl
505 |         print("train loss")
506 |         print(train_l)
507 |         print("test loss")
508 |         print(test_l)
509 | 
510 | 
511 | 
512 | if __name__ == '__main__':
513 |     yolo = YOLO_TF(sys.argv)
514 |     cv2.waitKey(1000)
515 | 
--------------------------------------------------------------------------------
/network/test/person.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dshahrokhian/YOLO_tensorflow/f84fcc32fbead73b10da0a85c14d5ec7bd9de787/network/test/person.jpg
--------------------------------------------------------------------------------
/network/weights/put_weight_file_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dshahrokhian/YOLO_tensorflow/f84fcc32fbead73b10da0a85c14d5ec7bd9de787/network/weights/put_weight_file_here.txt
--------------------------------------------------------------------------------