├── Detection ├── MtcnnDetector.py ├── __init__.py ├── detector.py ├── fcn_detector.py └── nms.py ├── LICENSE ├── README.md ├── client.py ├── core ├── preprocess.py └── visualization.py ├── crc16.py ├── crc8.py ├── data └── MTCNN_model │ ├── ONet_landmark │ ├── ONet-16.data-00000-of-00001 │ ├── ONet-16.index │ ├── ONet-16.meta │ └── checkpoint │ ├── PNet_landmark │ ├── PNet-18.data-00000-of-00001 │ ├── PNet-18.index │ ├── PNet-18.meta │ └── checkpoint │ └── RNet_landmark │ ├── RNet-14.data-00000-of-00001 │ ├── RNet-14.index │ ├── RNet-14.meta │ └── checkpoint ├── face.png ├── flight-face.png ├── frame.py ├── instruction.py ├── server.py ├── standard_fields.py ├── tello.py ├── temp.h264 ├── test_flight.py ├── test_flight_rotate.py ├── test_video.py ├── train_models ├── MTCNN_config.py ├── __init__.py └── mtcnn_model.py └── visualization_utils.py /Detection/MtcnnDetector.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import numpy as np 4 | import sys 5 | sys.path.append("../") 6 | from .nms import py_nms 7 | 8 | class MtcnnDetector(object): 9 | 10 | 11 | def __init__(self, 12 | detectors, 13 | min_face_size=25, 14 | stride=2, 15 | threshold=[0.6, 0.7, 0.7], 16 | scale_factor=0.79, 17 | #scale_factor=0.709,#change 18 | slide_window=False): 19 | 20 | self.pnet_detector = detectors[0] 21 | self.rnet_detector = detectors[1] 22 | self.onet_detector = detectors[2] 23 | self.min_face_size = min_face_size 24 | self.stride = stride 25 | self.thresh = threshold 26 | self.scale_factor = scale_factor 27 | self.slide_window = slide_window 28 | 29 | def convert_to_square(self, bbox): 30 | """ 31 | convert bbox to square 32 | Parameters: 33 | ---------- 34 | bbox: numpy array , shape n x 5 35 | input bbox 36 | Returns: 37 | ------- 38 | square bbox 39 | """ 40 | square_bbox = bbox.copy() 41 | 42 | h = bbox[:, 3] - bbox[:, 1] + 1 43 | w = bbox[:, 2] - bbox[:, 0] + 1 44 | max_side = np.maximum(h, w) 45 | square_bbox[:, 0] = bbox[:, 0] + w * 0.5 - max_side * 0.5 46 | square_bbox[:, 1] = bbox[:, 1] + h * 0.5 - max_side * 0.5 47 | square_bbox[:, 2] = square_bbox[:, 0] + max_side - 1 48 | square_bbox[:, 3] = square_bbox[:, 1] + max_side - 1 49 | return square_bbox 50 | 51 | def calibrate_box(self, bbox, reg): 52 | """ 53 | calibrate bboxes 54 | Parameters: 55 | ---------- 56 | bbox: numpy array, shape n x 5 57 | input bboxes 58 | reg: numpy array, shape n x 4 59 | bboxes adjustment 60 | Returns: 61 | ------- 62 | bboxes after refinement 63 | """ 64 | 65 | bbox_c = bbox.copy() 66 | w = bbox[:, 2] - bbox[:, 0] + 1 67 | w = np.expand_dims(w, 1) 68 | h = bbox[:, 3] - bbox[:, 1] + 1 69 | h = np.expand_dims(h, 1) 70 | reg_m = np.hstack([w, h, w, h]) 71 | aug = reg_m * reg 72 | bbox_c[:, 0:4] = bbox_c[:, 0:4] + aug 73 | return bbox_c 74 | 75 | def generate_bbox(self, cls_map, reg, scale, threshold): 76 | """ 77 | generate bbox from feature cls_map 78 | Parameters: 79 | ---------- 80 | cls_map: numpy array , n x m 81 | detect score for each position 82 | reg: numpy array , n x m x 4 83 | bbox 84 | scale: float number 85 | scale of this detection 86 | threshold: float number 87 | detect threshold 88 | Returns: 89 | ------- 90 | bbox array 91 | """ 92 | stride = 2 93 | #stride = 4 94 | cellsize = 12 95 | #cellsize = 25 96 | 97 | t_index = np.where(cls_map > threshold) 98 | 99 | # find nothing 100 | if t_index[0].size == 0: 101 | return np.array([]) 102 | #offset 103 | dx1, dy1, dx2, dy2 = [reg[t_index[0], t_index[1], i] for i in range(4)] 104 | 105 | 
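        # Each retained position (row, col) in the PNet score map corresponds to a
        # cellsize x cellsize (12x12) window placed at (stride*col, stride*row) in the
        # scaled image; dividing by `scale` maps that window back to original-image
        # coordinates. The regression offsets are stacked alongside each box below and
        # only applied later, in detect_pnet's refinement step.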
reg = np.array([dx1, dy1, dx2, dy2]) 106 | score = cls_map[t_index[0], t_index[1]] 107 | boundingbox = np.vstack([np.round((stride * t_index[1]) / scale), 108 | np.round((stride * t_index[0]) / scale), 109 | np.round((stride * t_index[1] + cellsize) / scale), 110 | np.round((stride * t_index[0] + cellsize) / scale), 111 | score, 112 | reg]) 113 | 114 | return boundingbox.T 115 | #pre-process images 116 | def processed_image(self, img, scale): 117 | height, width, channels = img.shape 118 | new_height = int(height * scale) # resized new height 119 | new_width = int(width * scale) # resized new width 120 | new_dim = (new_width, new_height) 121 | img_resized = cv2.resize(img, new_dim, interpolation=cv2.INTER_LINEAR) # resized image 122 | img_resized = (img_resized - 127.5) / 128 123 | return img_resized 124 | 125 | def pad(self, bboxes, w, h): 126 | """ 127 | pad the bboxes, also restrict their size to the image boundary 128 | Parameters: 129 | ---------- 130 | bboxes: numpy array, n x 5 131 | input bboxes 132 | w: float number 133 | width of the input image 134 | h: float number 135 | height of the input image 136 | Returns: 137 | ------ 138 | dy, dx : numpy array, n x 1 139 | start point of the bbox in target image 140 | edy, edx : numpy array, n x 1 141 | end point of the bbox in target image 142 | y, x : numpy array, n x 1 143 | start point of the bbox in original image 144 | ey, ex : numpy array, n x 1 145 | end point of the bbox in original image 146 | tmph, tmpw: numpy array, n x 1 147 | height and width of the bbox 148 | """ 149 | tmpw, tmph = bboxes[:, 2] - bboxes[:, 0] + 1, bboxes[:, 3] - bboxes[:, 1] + 1 150 | num_box = bboxes.shape[0] 151 | 152 | dx, dy = np.zeros((num_box,)), np.zeros((num_box,)) 153 | edx, edy = tmpw.copy() - 1, tmph.copy() - 1 154 | 155 | x, y, ex, ey = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] 156 | 157 | tmp_index = np.where(ex > w - 1) 158 | edx[tmp_index] = tmpw[tmp_index] + w - 2 - ex[tmp_index] 159 | ex[tmp_index] = w - 1 160 | 161 | tmp_index = np.where(ey > h - 1) 162 | edy[tmp_index] = tmph[tmp_index] + h - 2 - ey[tmp_index] 163 | ey[tmp_index] = h - 1 164 | 165 | tmp_index = np.where(x < 0) 166 | dx[tmp_index] = 0 - x[tmp_index] 167 | x[tmp_index] = 0 168 | 169 | tmp_index = np.where(y < 0) 170 | dy[tmp_index] = 0 - y[tmp_index] 171 | y[tmp_index] = 0 172 | 173 | return_list = [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] 174 | return_list = [item.astype(np.int32) for item in return_list] 175 | 176 | return return_list 177 | 178 | def detect_pnet(self, im): 179 | """Get face candidates through pnet 180 | 181 | Parameters: 182 | ---------- 183 | im: numpy array 184 | input image array 185 | 186 | Returns: 187 | ------- 188 | boxes: numpy array 189 | detected boxes before calibration 190 | boxes_c: numpy array 191 | boxes after calibration 192 | """ 193 | h, w, c = im.shape 194 | net_size = 12 195 | 196 | current_scale = float(net_size) / self.min_face_size # find initial scale 197 | # print("current_scale", net_size, self.min_face_size, current_scale) 198 | im_resized = self.processed_image(im, current_scale) 199 | current_height, current_width, _ = im_resized.shape 200 | # fcn 201 | all_boxes = list() 202 | while min(current_height, current_width) > net_size: 203 | #return the result predicted by pnet 204 | #cls_cls_map : H*w*2 205 | #reg: H*w*4 206 | cls_cls_map, reg = self.pnet_detector.predict(im_resized) 207 | #boxes: num*9(x1,y1,x2,y2,score,x1_offset,y1_offset,x2_offset,y2_offset) 208 | boxes = self.generate_bbox(cls_cls_map[:, :,1], reg, 
current_scale, self.thresh[0]) 209 | 210 | current_scale *= self.scale_factor 211 | im_resized = self.processed_image(im, current_scale) 212 | current_height, current_width, _ = im_resized.shape 213 | 214 | if boxes.size == 0: 215 | continue 216 | keep = py_nms(boxes[:, :5], 0.5, 'Union') 217 | boxes = boxes[keep] 218 | all_boxes.append(boxes) 219 | 220 | if len(all_boxes) == 0: 221 | return None, None, None 222 | 223 | all_boxes = np.vstack(all_boxes) 224 | 225 | # merge the detection from first stage 226 | keep = py_nms(all_boxes[:, 0:5], 0.7, 'Union') 227 | all_boxes = all_boxes[keep] 228 | boxes = all_boxes[:, :5] 229 | 230 | bbw = all_boxes[:, 2] - all_boxes[:, 0] + 1 231 | bbh = all_boxes[:, 3] - all_boxes[:, 1] + 1 232 | 233 | # refine the boxes 234 | boxes_c = np.vstack([all_boxes[:, 0] + all_boxes[:, 5] * bbw, 235 | all_boxes[:, 1] + all_boxes[:, 6] * bbh, 236 | all_boxes[:, 2] + all_boxes[:, 7] * bbw, 237 | all_boxes[:, 3] + all_boxes[:, 8] * bbh, 238 | all_boxes[:, 4]]) 239 | boxes_c = boxes_c.T 240 | 241 | return boxes, boxes_c, None 242 | def detect_rnet(self, im, dets): 243 | """Get face candidates using rnet 244 | 245 | Parameters: 246 | ---------- 247 | im: numpy array 248 | input image array 249 | dets: numpy array 250 | detection results of pnet 251 | 252 | Returns: 253 | ------- 254 | boxes: numpy array 255 | detected boxes before calibration 256 | boxes_c: numpy array 257 | boxes after calibration 258 | """ 259 | h, w, c = im.shape 260 | dets = self.convert_to_square(dets) 261 | dets[:, 0:4] = np.round(dets[:, 0:4]) 262 | 263 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h) 264 | num_boxes = dets.shape[0] 265 | cropped_ims = np.zeros((num_boxes, 24, 24, 3), dtype=np.float32) 266 | for i in range(num_boxes): 267 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 268 | tmp[dy[i]:edy[i] + 1, dx[i]:edx[i] + 1, :] = im[y[i]:ey[i] + 1, x[i]:ex[i] + 1, :] 269 | cropped_ims[i, :, :, :] = (cv2.resize(tmp, (24, 24))-127.5) / 128 270 | #cls_scores : num_data*2 271 | #reg: num_data*4 272 | #landmark: num_data*10 273 | cls_scores, reg, _ = self.rnet_detector.predict(cropped_ims) 274 | cls_scores = cls_scores[:,1] 275 | keep_inds = np.where(cls_scores > self.thresh[1])[0] 276 | if len(keep_inds) > 0: 277 | boxes = dets[keep_inds] 278 | boxes[:, 4] = cls_scores[keep_inds] 279 | reg = reg[keep_inds] 280 | #landmark = landmark[keep_inds] 281 | else: 282 | return None, None, None 283 | 284 | 285 | keep = py_nms(boxes, 0.6) 286 | boxes = boxes[keep] 287 | boxes_c = self.calibrate_box(boxes, reg[keep]) 288 | return boxes, boxes_c,None 289 | def detect_onet(self, im, dets): 290 | """Get face candidates using onet 291 | 292 | Parameters: 293 | ---------- 294 | im: numpy array 295 | input image array 296 | dets: numpy array 297 | detection results of rnet 298 | 299 | Returns: 300 | ------- 301 | boxes: numpy array 302 | detected boxes before calibration 303 | boxes_c: numpy array 304 | boxes after calibration 305 | """ 306 | h, w, c = im.shape 307 | dets = self.convert_to_square(dets) 308 | dets[:, 0:4] = np.round(dets[:, 0:4]) 309 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h) 310 | num_boxes = dets.shape[0] 311 | cropped_ims = np.zeros((num_boxes, 48, 48, 3), dtype=np.float32) 312 | for i in range(num_boxes): 313 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 314 | tmp[dy[i]:edy[i] + 1, dx[i]:edx[i] + 1, :] = im[y[i]:ey[i] + 1, x[i]:ex[i] + 1, :] 315 | cropped_ims[i, :, :, :] = (cv2.resize(tmp, (48, 48))-127.5) / 128 316 | 317 | 
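        # ONet outputs, per 48x48 crop: cls_scores (num*2 face probability), reg (num*4
        # bbox refinement) and landmark (num*10, five (x, y) points normalized to the
        # crop); the landmarks are projected back to image coordinates further below
        # using each box's width and height.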
cls_scores, reg,landmark = self.onet_detector.predict(cropped_ims) 318 | #prob belongs to face 319 | cls_scores = cls_scores[:,1] 320 | keep_inds = np.where(cls_scores > self.thresh[2])[0] 321 | if len(keep_inds) > 0: 322 | #pickout filtered box 323 | boxes = dets[keep_inds] 324 | boxes[:, 4] = cls_scores[keep_inds] 325 | reg = reg[keep_inds] 326 | landmark = landmark[keep_inds] 327 | else: 328 | return None, None, None 329 | 330 | #width 331 | w = boxes[:,2] - boxes[:,0] + 1 332 | #height 333 | h = boxes[:,3] - boxes[:,1] + 1 334 | landmark[:,0::2] = (np.tile(w,(5,1)) * landmark[:,0::2].T + np.tile(boxes[:,0],(5,1)) - 1).T 335 | landmark[:,1::2] = (np.tile(h,(5,1)) * landmark[:,1::2].T + np.tile(boxes[:,1],(5,1)) - 1).T 336 | boxes_c = self.calibrate_box(boxes, reg) 337 | 338 | 339 | boxes = boxes[py_nms(boxes, 0.6, "Minimum")] 340 | keep = py_nms(boxes_c, 0.6, "Minimum") 341 | boxes_c = boxes_c[keep] 342 | landmark = landmark[keep] 343 | return boxes, boxes_c,landmark 344 | #use for video 345 | def detect(self, img): 346 | """Detect face over image 347 | """ 348 | boxes = None 349 | t = time.time() 350 | 351 | # pnet 352 | t1 = 0 353 | if self.pnet_detector: 354 | boxes, boxes_c,_ = self.detect_pnet(img) 355 | if boxes_c is None: 356 | return np.array([]),np.array([]) 357 | 358 | t1 = time.time() - t 359 | t = time.time() 360 | 361 | # rnet 362 | t2 = 0 363 | if self.rnet_detector: 364 | boxes, boxes_c,_ = self.detect_rnet(img, boxes_c) 365 | if boxes_c is None: 366 | return np.array([]),np.array([]) 367 | 368 | t2 = time.time() - t 369 | t = time.time() 370 | 371 | # onet 372 | t3 = 0 373 | if self.onet_detector: 374 | boxes, boxes_c,landmark = self.detect_onet(img, boxes_c) 375 | if boxes_c is None: 376 | return np.array([]),np.array([]) 377 | 378 | t3 = time.time() - t 379 | t = time.time() 380 | 381 | return boxes_c,landmark 382 | def detect_face(self, test_data): 383 | all_boxes = []#save each image's bboxes 384 | landmarks = [] 385 | batch_idx = 0 386 | sum_time = 0 387 | #test_data is iter_ 388 | for databatch in test_data: 389 | #databatch(image returned) 390 | if batch_idx % 100 == 0: 391 | print("%d images done" % batch_idx) 392 | im = databatch 393 | # pnet 394 | t1 = 0 395 | if self.pnet_detector: 396 | t = time.time() 397 | #ignore landmark 398 | boxes, boxes_c, landmark = self.detect_pnet(im) 399 | t1 = time.time() - t 400 | sum_time += t1 401 | if boxes_c is None: 402 | print("boxes_c is None...") 403 | all_boxes.append(np.array([])) 404 | #pay attention 405 | landmarks.append(np.array([])) 406 | batch_idx += 1 407 | continue 408 | # rnet 409 | t2 = 0 410 | if self.rnet_detector: 411 | t = time.time() 412 | #ignore landmark 413 | boxes, boxes_c, landmark = self.detect_rnet(im, boxes_c) 414 | t2 = time.time() - t 415 | sum_time += t2 416 | if boxes_c is None: 417 | all_boxes.append(np.array([])) 418 | landmarks.append(np.array([])) 419 | batch_idx += 1 420 | continue 421 | # onet 422 | t3 = 0 423 | if self.onet_detector: 424 | t = time.time() 425 | boxes, boxes_c, landmark = self.detect_onet(im, boxes_c) 426 | t3 = time.time() - t 427 | sum_time += t3 428 | if boxes_c is None: 429 | all_boxes.append(np.array([])) 430 | landmarks.append(np.array([])) 431 | batch_idx += 1 432 | continue 433 | 434 | 435 | all_boxes.append(boxes_c) 436 | landmarks.append(landmark) 437 | batch_idx += 1 438 | #num_of_data*9,num_of_data*10 439 | return all_boxes,landmarks 440 | -------------------------------------------------------------------------------- /Detection/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/Detection/__init__.py -------------------------------------------------------------------------------- /Detection/detector.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class Detector(object): 5 | #net_factory:rnet or onet 6 | #datasize:24 or 48 7 | def __init__(self, net_factory, data_size, batch_size, model_path): 8 | graph = tf.Graph() 9 | with graph.as_default(): 10 | self.image_op = tf.placeholder(tf.float32, shape=[batch_size, data_size, data_size, 3], name='input_image') 11 | #figure out landmark 12 | self.cls_prob, self.bbox_pred, self.landmark_pred = net_factory(self.image_op, training=False) 13 | self.sess = tf.Session( 14 | config=tf.ConfigProto(allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True))) 15 | saver = tf.train.Saver() 16 | #check whether the dictionary is valid 17 | model_dict = '/'.join(model_path.split('/')[:-1]) 18 | ckpt = tf.train.get_checkpoint_state(model_dict) 19 | print(model_path) 20 | readstate = ckpt and ckpt.model_checkpoint_path 21 | assert readstate, "the params dictionary is not valid" 22 | print("restore models' param") 23 | saver.restore(self.sess, model_path) 24 | 25 | self.data_size = data_size 26 | self.batch_size = batch_size 27 | #rnet and onet minibatch(test) 28 | def predict(self, databatch): 29 | # access data 30 | # databatch: N x 3 x data_size x data_size 31 | scores = [] 32 | batch_size = self.batch_size 33 | 34 | minibatch = [] 35 | cur = 0 36 | #num of all_data 37 | n = databatch.shape[0] 38 | while cur < n: 39 | #split mini-batch 40 | minibatch.append(databatch[cur:min(cur + batch_size, n), :, :, :]) 41 | cur += batch_size 42 | #every batch prediction result 43 | cls_prob_list = [] 44 | bbox_pred_list = [] 45 | landmark_pred_list = [] 46 | for idx, data in enumerate(minibatch): 47 | m = data.shape[0] 48 | real_size = self.batch_size 49 | #the last batch 50 | if m < batch_size: 51 | keep_inds = np.arange(m) 52 | #gap (difference) 53 | gap = self.batch_size - m 54 | while gap >= len(keep_inds): 55 | gap -= len(keep_inds) 56 | keep_inds = np.concatenate((keep_inds, keep_inds)) 57 | if gap != 0: 58 | keep_inds = np.concatenate((keep_inds, keep_inds[:gap])) 59 | data = data[keep_inds] 60 | real_size = m 61 | #cls_prob batch*2 62 | #bbox_pred batch*4 63 | cls_prob, bbox_pred,landmark_pred = self.sess.run([self.cls_prob, self.bbox_pred,self.landmark_pred], feed_dict={self.image_op: data}) 64 | #num_batch * batch_size *2 65 | cls_prob_list.append(cls_prob[:real_size]) 66 | #num_batch * batch_size *4 67 | bbox_pred_list.append(bbox_pred[:real_size]) 68 | #num_batch * batch_size*10 69 | landmark_pred_list.append(landmark_pred[:real_size]) 70 | #num_of_data*2,num_of_data*4,num_of_data*10 71 | return np.concatenate(cls_prob_list, axis=0), np.concatenate(bbox_pred_list, axis=0), np.concatenate(landmark_pred_list, axis=0) 72 | -------------------------------------------------------------------------------- /Detection/fcn_detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import sys 4 | sys.path.append("../") 5 | 6 | class FcnDetector(object): 7 | #net_factory: which net 8 | #model_path: where the params'file is 9 | def __init__(self, net_factory, model_path): 10 | 
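        # Unlike Detector (fixed-size minibatches for RNet/ONet), this wraps PNet as a
        # fully convolutional network: the image placeholder has no fixed height/width,
        # so a whole resized frame can be scored in a single forward pass.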
#create a graph 11 | graph = tf.Graph() 12 | with graph.as_default(): 13 | #define tensor and op in graph(-1,1) 14 | self.image_op = tf.placeholder(tf.float32, name='input_image') 15 | self.width_op = tf.placeholder(tf.int32, name='image_width') 16 | self.height_op = tf.placeholder(tf.int32, name='image_height') 17 | image_reshape = tf.reshape(self.image_op, [1, self.height_op, self.width_op, 3]) 18 | #self.cls_prob batch*2 19 | #self.bbox_pred batch*4 20 | #construct model here 21 | #self.cls_prob, self.bbox_pred = net_factory(image_reshape, training=False) 22 | #contains landmark 23 | self.cls_prob, self.bbox_pred, _ = net_factory(image_reshape, training=False) 24 | 25 | #allow 26 | self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True))) 27 | saver = tf.train.Saver() 28 | #check whether the dictionary is valid 29 | model_dict = '/'.join(model_path.split('/')[:-1]) 30 | ckpt = tf.train.get_checkpoint_state(model_dict) 31 | print (model_path) 32 | readstate = ckpt and ckpt.model_checkpoint_path 33 | assert readstate, "the params dictionary is not valid" 34 | print ("restore models' param") 35 | saver.restore(self.sess, model_path) 36 | def predict(self, databatch): 37 | height, width, _ = databatch.shape 38 | # print(height, width) 39 | cls_prob, bbox_pred = self.sess.run([self.cls_prob, self.bbox_pred], 40 | feed_dict={self.image_op: databatch, self.width_op: width, 41 | self.height_op: height}) 42 | return cls_prob, bbox_pred 43 | -------------------------------------------------------------------------------- /Detection/nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def py_nms(dets, thresh, mode="Union"): 3 | """ 4 | greedily select boxes with high confidence 5 | keep boxes overlap <= thresh 6 | rule out overlap > thresh 7 | :param dets: [[x1, y1, x2, y2 score]] 8 | :param thresh: retain overlap <= thresh 9 | :return: indexes to keep 10 | """ 11 | x1 = dets[:, 0] 12 | y1 = dets[:, 1] 13 | x2 = dets[:, 2] 14 | y2 = dets[:, 3] 15 | scores = dets[:, 4] 16 | 17 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 18 | order = scores.argsort()[::-1] 19 | 20 | keep = [] 21 | while order.size > 0: 22 | i = order[0] 23 | keep.append(i) 24 | xx1 = np.maximum(x1[i], x1[order[1:]]) 25 | yy1 = np.maximum(y1[i], y1[order[1:]]) 26 | xx2 = np.minimum(x2[i], x2[order[1:]]) 27 | yy2 = np.minimum(y2[i], y2[order[1:]]) 28 | 29 | w = np.maximum(0.0, xx2 - xx1 + 1) 30 | h = np.maximum(0.0, yy2 - yy1 + 1) 31 | inter = w * h 32 | if mode == "Union": 33 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 34 | elif mode == "Minimum": 35 | ovr = inter / np.minimum(areas[i], areas[order[1:]]) 36 | #keep 37 | inds = np.where(ovr <= thresh)[0] 38 | order = order[inds + 1] 39 | 40 | return keep 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Devcon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above 
copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tello-Python 2 | DJI Tello face recognition using MTCNN, with a video-streaming prototype 3 | 4 | ## Requirements 5 | ```bash 6 | pip3 install opencv-contrib-python numpy scipy tensorflow==1.4 flask-socketio socketIO_client 7 | ``` 8 | 9 | ## How to run streaming video only 10 | ```bash 11 | python3 test_video.py 12 | ``` 13 | 14 | ## How to run streaming video + 5-second flight + land 15 | ```bash 16 | python3 test_flight.py 17 | ``` 18 | 19 | ## How to run streaming video + flight + 360-degree clockwise rotation + land 20 | ```bash 21 | python3 test_flight_rotate.py 22 | ``` 23 | 24 | ## How to run a client that listens to the video stream using OpenCV 25 | ```bash 26 | python3 client.py 27 | ``` 28 | 29 | ## Output 30 | 31 | ![alt text](face.png) 32 | 33 | Fly and detect a face; click the image below for a video 34 | 35 | [![IMAGE ALT TEXT](flight-face.png)](https://youtu.be/Q147BLIRcMs "Test flight with face recognition using Tensorflow") 36 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import base64 3 | from socketIO_client import SocketIO, BaseNamespace 4 | import numpy as np 5 | import time 6 | from PIL import Image 7 | from threading import Thread, ThreadError 8 | import io 9 | 10 | img_np = None 11 | # change port and IP 12 | socketIO = SocketIO('http://192.168.0.102', 8020) 13 | live_namespace = socketIO.define(BaseNamespace, '/live') 14 | 15 | def receive_events_thread(): 16 | socketIO.wait() 17 | 18 | def on_camera_response(*args): 19 | global img_np 20 | img_bytes = base64.b64decode(args[0]['data']) 21 | img_np = np.array(Image.open(io.BytesIO(img_bytes))) 22 | 23 | live_namespace.on('camera_update', on_camera_response) 24 | receive_events_thread = Thread(target=receive_events_thread) 25 | receive_events_thread.daemon = True 26 | receive_events_thread.start() 27 | 28 | while True: 29 | try: 30 | cv2.imshow('cam',img_np) 31 | if cv2.waitKey(30) & 0xFF == ord('q'): 32 | break 33 | except: 34 | continue 35 | -------------------------------------------------------------------------------- /core/preprocess.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | EVAL_SIZE = (512, 512) 4 | R, G, B = 123.0, 117.0, 104.0 5 | 6 | def preprocess_for_eval(image, out_shape = EVAL_SIZE, data_format = 'NHWC'): 7 | image = tf.to_float(image) 8 | mean = tf.constant([R, G, B], dtype = image.dtype) 9 | image = image - mean 10 | 11 | # Add image rectangle to bboxes. 
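    # (the constant box below spans the whole image in normalized coordinates; it is
    # returned as `bbox_img` so callers can keep track of the image extent after resizing)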
12 | bbox_img = tf.constant([[0., 0., 1., 1.]]) 13 | bboxes = bbox_img 14 | 15 | height, width, channels = image.shape 16 | image = tf.expand_dims(image, 0) 17 | image = tf.image.resize_images(image, out_shape, tf.image.ResizeMethod.BILINEAR, False) 18 | image = tf.reshape(image, tf.stack([out_shape[0], out_shape[1], channels])) 19 | 20 | # Split back bounding boxes. 21 | bbox_img = bboxes[0] 22 | bboxes = bboxes[1:] 23 | 24 | # Image data format. 25 | if data_format == 'NCHW': 26 | image = tf.transpose(image, perm = (2, 0, 1)) 27 | return image, bboxes, bbox_img 28 | 29 | 30 | -------------------------------------------------------------------------------- /core/visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import matplotlib.pyplot as plt 4 | import matplotlib.image as mpimg 5 | import matplotlib.cm as mpcm 6 | import tensorflow as tf 7 | import numpy as np 8 | from scipy import misc 9 | 10 | # =========================================================================== # 11 | # Some colormaps. 12 | # =========================================================================== # 13 | def colors_subselect(colors, num_classes=21): 14 | dt = len(colors) // num_classes 15 | sub_colors = [] 16 | for i in range(num_classes): 17 | color = colors[i*dt] 18 | if isinstance(color[0], float): 19 | sub_colors.append([int(c * 255) for c in color]) 20 | else: 21 | sub_colors.append([c for c in color]) 22 | return sub_colors 23 | 24 | colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21) 25 | colors_tableau = [(255, 255, 255), (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120), 26 | (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150), 27 | (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148), 28 | (227, 119, 194), (224, 26, 53), (127, 127, 127), (199, 199, 199), 29 | (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)] 30 | 31 | labels = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike' , 'person', 32 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 33 | 34 | accept_labels = ['person'] 35 | 36 | # =========================================================================== # 37 | # OpenCV drawing. 38 | # =========================================================================== # 39 | def draw_lines(img, lines, color=[255, 0, 0], thickness=2): 40 | """Draw a collection of lines on an image. 
41 | """ 42 | for line in lines: 43 | for x1, y1, x2, y2 in line: 44 | cv2.line(img, (x1, y1), (x2, y2), color, thickness) 45 | 46 | 47 | def draw_rectangle(img, p1, p2, color=[255, 0, 0], thickness=2): 48 | cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness) 49 | 50 | 51 | def draw_bbox(img, bbox, shape, label, color=[255, 0, 0], thickness=2): 52 | p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1])) 53 | p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1])) 54 | cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness) 55 | p1 = (p1[0]+15, p1[1]) 56 | cv2.putText(img, str(label), p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.5, color, 1) 57 | 58 | 59 | def bboxes_draw_on_img(img, classes, scores, bboxes, thickness=2): 60 | shape = img.shape 61 | for i in range(bboxes.shape[0]): 62 | label = labels[classes[i] - 1] 63 | if label not in accept_labels: 64 | continue 65 | bbox = bboxes[i] 66 | #color = colors_tableau[classes[i] - 1] 67 | p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1])) 68 | p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1])) 69 | cv2.rectangle(img, p1[::-1], p2[::-1], (0,0,255), 1) 70 | s = '%s' % (label) 71 | p1 = (p1[0]-5, p1[1]) 72 | cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.7, (0,0,255), 1) 73 | 74 | 75 | # =========================================================================== # 76 | # Matplotlib show... 77 | # =========================================================================== # 78 | def plt_bboxes(img, classes, scores, bboxes, sess, model, size_image, distancelabel = 'm', figsize=(20, 20), linewidth=1.5): 79 | fig = plt.figure(figsize=figsize) 80 | plt.imshow(img) 81 | height = img.shape[0] 82 | width = img.shape[1] 83 | colors = dict() 84 | for i in range(classes.shape[0]): 85 | cls_id = int(classes[i]) 86 | if cls_id >= 0: 87 | if labels[cls_id - 1] not in accept_labels: 88 | continue 89 | score = scores[i] 90 | if cls_id not in colors: 91 | colors[cls_id] = (random.random(), random.random(), random.random()) 92 | ymin = int(bboxes[i, 0] * height) 93 | xmin = int(bboxes[i, 1] * width) 94 | ymax = int(bboxes[i, 2] * height) 95 | xmax = int(bboxes[i, 3] * width) 96 | rect = plt.Rectangle((xmin, ymin), xmax - xmin, 97 | ymax - ymin, fill=False, 98 | edgecolor=colors[cls_id], 99 | linewidth=linewidth) 100 | plt.gca().add_patch(rect) 101 | scalefactor = (((xmax / (width * 1.0)) + (ymax / (height * 1.0))) / ((height / (width * 1.0)) * 10.666666667)) 102 | distance = ((height * width) / ((xmax * ymax * 1.0)) ) * scalefactor 103 | class_name = str(labels[cls_id - 1]) 104 | if class_name == 'person': 105 | probs = sess.run(model.outputs, 106 | feed_dict = {model.X: np.expand_dims(misc.imresize(img[ymin: ymax, xmin: xmax, :], (size_image, size_image)), axis = 0)})[0] 107 | plt.gca().text(xmin, ymin - 2, 108 | '{:s} | {:.3f} {:s} | prob_vand: {:.3f}'.format(class_name, distance, distancelabel, probs[1]), 109 | bbox=dict(facecolor=colors[cls_id], alpha=0.5), 110 | fontsize=12, color='white') 111 | else: 112 | plt.gca().text(xmin, ymin - 2, 113 | '{:s} | {:.3f} {:s}'.format(class_name, distance, distancelabel), 114 | bbox=dict(facecolor=colors[cls_id], alpha=0.5), 115 | fontsize=12, color='white') 116 | plt.savefig('output.png') 117 | -------------------------------------------------------------------------------- /crc16.py: -------------------------------------------------------------------------------- 1 | crc16table = [0x0000, 0x1189, 0x2312, 0x329b, 0x4624, 0x57ad, 0x6536, 0x74bf, 0x8c48, 0x9dc1, 0xaf5a, 0xbed3, 0xca6c, 0xdbe5, 0xe97e, 
0xf8f7, 2 | 0x1081, 0x0108, 0x3393, 0x221a, 0x56a5, 0x472c, 0x75b7, 0x643e, 0x9cc9, 0x8d40, 0xbfdb, 0xae52, 0xdaed, 0xcb64, 0xf9ff, 0xe876, 3 | 0x2102, 0x308b, 0x0210, 0x1399, 0x6726, 0x76af, 0x4434, 0x55bd, 0xad4a, 0xbcc3, 0x8e58, 0x9fd1, 0xeb6e, 0xfae7, 0xc87c, 0xd9f5, 4 | 0x3183, 0x200a, 0x1291, 0x0318, 0x77a7, 0x662e, 0x54b5, 0x453c, 0xbdcb, 0xac42, 0x9ed9, 0x8f50, 0xfbef, 0xea66, 0xd8fd, 0xc974, 5 | 0x4204, 0x538d, 0x6116, 0x709f, 0x0420, 0x15a9, 0x2732, 0x36bb, 0xce4c, 0xdfc5, 0xed5e, 0xfcd7, 0x8868, 0x99e1, 0xab7a, 0xbaf3, 6 | 0x5285, 0x430c, 0x7197, 0x601e, 0x14a1, 0x0528, 0x37b3, 0x263a, 0xdecd, 0xcf44, 0xfddf, 0xec56, 0x98e9, 0x8960, 0xbbfb, 0xaa72, 7 | 0x6306, 0x728f, 0x4014, 0x519d, 0x2522, 0x34ab, 0x0630, 0x17b9, 0xef4e, 0xfec7, 0xcc5c, 0xddd5, 0xa96a, 0xb8e3, 0x8a78, 0x9bf1, 8 | 0x7387, 0x620e, 0x5095, 0x411c, 0x35a3, 0x242a, 0x16b1, 0x0738, 0xffcf, 0xee46, 0xdcdd, 0xcd54, 0xb9eb, 0xa862, 0x9af9, 0x8b70, 9 | 0x8408, 0x9581, 0xa71a, 0xb693, 0xc22c, 0xd3a5, 0xe13e, 0xf0b7, 0x0840, 0x19c9, 0x2b52, 0x3adb, 0x4e64, 0x5fed, 0x6d76, 0x7cff, 10 | 0x9489, 0x8500, 0xb79b, 0xa612, 0xd2ad, 0xc324, 0xf1bf, 0xe036, 0x18c1, 0x0948, 0x3bd3, 0x2a5a, 0x5ee5, 0x4f6c, 0x7df7, 0x6c7e, 11 | 0xa50a, 0xb483, 0x8618, 0x9791, 0xe32e, 0xf2a7, 0xc03c, 0xd1b5, 0x2942, 0x38cb, 0x0a50, 0x1bd9, 0x6f66, 0x7eef, 0x4c74, 0x5dfd, 12 | 0xb58b, 0xa402, 0x9699, 0x8710, 0xf3af, 0xe226, 0xd0bd, 0xc134, 0x39c3, 0x284a, 0x1ad1, 0x0b58, 0x7fe7, 0x6e6e, 0x5cf5, 0x4d7c, 13 | 0xc60c, 0xd785, 0xe51e, 0xf497, 0x8028, 0x91a1, 0xa33a, 0xb2b3, 0x4a44, 0x5bcd, 0x6956, 0x78df, 0x0c60, 0x1de9, 0x2f72, 0x3efb, 14 | 0xd68d, 0xc704, 0xf59f, 0xe416, 0x90a9, 0x8120, 0xb3bb, 0xa232, 0x5ac5, 0x4b4c, 0x79d7, 0x685e, 0x1ce1, 0x0d68, 0x3ff3, 0x2e7a, 15 | 0xe70e, 0xf687, 0xc41c, 0xd595, 0xa12a, 0xb0a3, 0x8238, 0x93b1, 0x6b46, 0x7acf, 0x4854, 0x59dd, 0x2d62, 0x3ceb, 0x0e70, 0x1ff9, 16 | 0xf78f, 0xe606, 0xd49d, 0xc514, 0xb1ab, 0xa022, 0x92b9, 0x8330, 0x7bc7, 0x6a4e, 0x58d5, 0x495c, 0x3de3, 0x2c6a, 0x1ef1, 0x0f78] 17 | 18 | def calculate_crc16(arrays): 19 | crc = 0x3692 20 | for i in arrays: 21 | i = int.from_bytes(i,byteorder='big') 22 | crc = crc16table[(crc^i)&0xff]^(crc>>8) 23 | return crc 24 | -------------------------------------------------------------------------------- /crc8.py: -------------------------------------------------------------------------------- 1 | crc8table = [0x00, 0x5e, 0xbc, 0xe2, 0x61, 0x3f, 0xdd, 0x83, 0xc2, 0x9c, 0x7e, 0x20, 0xa3, 0xfd, 0x1f, 0x41, 2 | 0x9d, 0xc3, 0x21, 0x7f, 0xfc, 0xa2, 0x40, 0x1e, 0x5f, 0x01, 0xe3, 0xbd, 0x3e, 0x60, 0x82, 0xdc, 3 | 0x23, 0x7d, 0x9f, 0xc1, 0x42, 0x1c, 0xfe, 0xa0, 0xe1, 0xbf, 0x5d, 0x03, 0x80, 0xde, 0x3c, 0x62, 4 | 0xbe, 0xe0, 0x02, 0x5c, 0xdf, 0x81, 0x63, 0x3d, 0x7c, 0x22, 0xc0, 0x9e, 0x1d, 0x43, 0xa1, 0xff, 5 | 0x46, 0x18, 0xfa, 0xa4, 0x27, 0x79, 0x9b, 0xc5, 0x84, 0xda, 0x38, 0x66, 0xe5, 0xbb, 0x59, 0x07, 6 | 0xdb, 0x85, 0x67, 0x39, 0xba, 0xe4, 0x06, 0x58, 0x19, 0x47, 0xa5, 0xfb, 0x78, 0x26, 0xc4, 0x9a, 7 | 0x65, 0x3b, 0xd9, 0x87, 0x04, 0x5a, 0xb8, 0xe6, 0xa7, 0xf9, 0x1b, 0x45, 0xc6, 0x98, 0x7a, 0x24, 8 | 0xf8, 0xa6, 0x44, 0x1a, 0x99, 0xc7, 0x25, 0x7b, 0x3a, 0x64, 0x86, 0xd8, 0x5b, 0x05, 0xe7, 0xb9, 9 | 0x8c, 0xd2, 0x30, 0x6e, 0xed, 0xb3, 0x51, 0x0f, 0x4e, 0x10, 0xf2, 0xac, 0x2f, 0x71, 0x93, 0xcd, 10 | 0x11, 0x4f, 0xad, 0xf3, 0x70, 0x2e, 0xcc, 0x92, 0xd3, 0x8d, 0x6f, 0x31, 0xb2, 0xec, 0x0e, 0x50, 11 | 0xaf, 0xf1, 0x13, 0x4d, 0xce, 0x90, 0x72, 0x2c, 0x6d, 0x33, 0xd1, 0x8f, 0x0c, 0x52, 0xb0, 0xee, 12 | 0x32, 0x6c, 0x8e, 0xd0, 0x53, 0x0d, 0xef, 0xb1, 0xf0, 0xae, 0x4c, 0x12, 0x91, 
0xcf, 0x2d, 0x73, 13 | 0xca, 0x94, 0x76, 0x28, 0xab, 0xf5, 0x17, 0x49, 0x08, 0x56, 0xb4, 0xea, 0x69, 0x37, 0xd5, 0x8b, 14 | 0x57, 0x09, 0xeb, 0xb5, 0x36, 0x68, 0x8a, 0xd4, 0x95, 0xcb, 0x29, 0x77, 0xf4, 0xaa, 0x48, 0x16, 15 | 0xe9, 0xb7, 0x55, 0x0b, 0x88, 0xd6, 0x34, 0x6a, 0x2b, 0x75, 0x97, 0xc9, 0x4a, 0x14, 0xf6, 0xa8, 16 | 0x74, 0x2a, 0xc8, 0x96, 0x15, 0x4b, 0xa9, 0xf7, 0xb6, 0xe8, 0x0a, 0x54, 0xd7, 0x89, 0x6b, 0x35] 17 | 18 | def calculate_crc8(arrays): 19 | crc = 0x77 20 | for i in arrays: 21 | i = int.from_bytes(i,byteorder='big') 22 | crc = crc8table[(crc^i)&0xff] 23 | return crc 24 | -------------------------------------------------------------------------------- /data/MTCNN_model/ONet_landmark/ONet-16.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/ONet_landmark/ONet-16.data-00000-of-00001 -------------------------------------------------------------------------------- /data/MTCNN_model/ONet_landmark/ONet-16.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/ONet_landmark/ONet-16.index -------------------------------------------------------------------------------- /data/MTCNN_model/ONet_landmark/ONet-16.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/ONet_landmark/ONet-16.meta -------------------------------------------------------------------------------- /data/MTCNN_model/ONet_landmark/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "ONet-16" 2 | all_model_checkpoint_paths: "ONet-16" 3 | -------------------------------------------------------------------------------- /data/MTCNN_model/PNet_landmark/PNet-18.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/PNet_landmark/PNet-18.data-00000-of-00001 -------------------------------------------------------------------------------- /data/MTCNN_model/PNet_landmark/PNet-18.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/PNet_landmark/PNet-18.index -------------------------------------------------------------------------------- /data/MTCNN_model/PNet_landmark/PNet-18.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/PNet_landmark/PNet-18.meta -------------------------------------------------------------------------------- /data/MTCNN_model/PNet_landmark/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "PNet-18" 2 | all_model_checkpoint_paths: "PNet-18" 3 | -------------------------------------------------------------------------------- /data/MTCNN_model/RNet_landmark/RNet-14.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/RNet_landmark/RNet-14.data-00000-of-00001 -------------------------------------------------------------------------------- /data/MTCNN_model/RNet_landmark/RNet-14.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/RNet_landmark/RNet-14.index -------------------------------------------------------------------------------- /data/MTCNN_model/RNet_landmark/RNet-14.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/data/MTCNN_model/RNet_landmark/RNet-14.meta -------------------------------------------------------------------------------- /data/MTCNN_model/RNet_landmark/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "RNet-14" 2 | all_model_checkpoint_paths: "RNet-14" 3 | -------------------------------------------------------------------------------- /face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/face.png -------------------------------------------------------------------------------- /flight-face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/flight-face.png -------------------------------------------------------------------------------- /frame.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import time 4 | 5 | cap = cv2.VideoCapture('temp.h264') 6 | 7 | while(cap.isOpened()): 8 | 9 | ret, frame = cap.read() 10 | print(frame.shape) 11 | cv2.imshow('frame',frame) 12 | if cv2.waitKey(1) & 0xFF == ord('q'): 13 | break 14 | time.sleep(0.03) 15 | 16 | cap.release() 17 | cv2.destroyAllWindows() 18 | -------------------------------------------------------------------------------- /instruction.py: -------------------------------------------------------------------------------- 1 | import crc8 2 | import crc16 3 | import time 4 | import datetime 5 | 6 | messageStart = 0x00cc 7 | wifiMessage = 0x001a 8 | videoRateQuery = 0x0028 9 | lightMessage = 0x0035 10 | flightMessage = 0x0056 11 | logMessage = 0x1050 12 | 13 | videoEncoderRateCommand = 0x0020 14 | videoStartCommand = 0x0025 15 | exposureCommand = 0x0034 16 | timeCommand = 0x0046 17 | stickCommand = 0x0050 18 | takeoffCommand = 0x0054 19 | landCommand = 0x0055 20 | flipCommand = 0x005c 21 | throwtakeoffCommand = 0x005d 22 | palmLandCommand = 0x005e 23 | bounceCommand = 0x1053 24 | 25 | FlipFront = 0 26 | FlipLeft = 1 27 | FlipBack = 2 28 | FlipRight = 3 29 | FlipForwardLeft = 4 30 | FlipBackLeft = 5 31 | FlipBackRight = 6 32 | FlipForwardRight = 7 33 | 34 | seq = 0 35 | rx, ry, lx, ly = 0, 0, 0, 0 36 | throttle = 0 37 | 38 | def create_packet(cmd, pkt, length): 39 | l = length + 11 40 | array_bytes = [] 41 | array_bytes.append(messageStart.to_bytes(1,byteorder='little')) 42 | array_bytes.append((l << 3).to_bytes(1,byteorder='little')) 43 | array_bytes.append((0).to_bytes(1,byteorder='little')) 44 | digested= 
crc8.calculate_crc8(array_bytes).to_bytes(1,byteorder='little') 45 | array_bytes.append(digested) 46 | array_bytes.append((pkt).to_bytes(1,byteorder='little')) 47 | instructions=(cmd).to_bytes(2,byteorder='little') 48 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 49 | return array_bytes 50 | 51 | def take_off(): 52 | global seq 53 | array_bytes = create_packet(takeoffCommand,0x68, 0) 54 | seq += 1 55 | instructions=seq.to_bytes(2,byteorder='little') 56 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 57 | instructions=(crc16.calculate_crc16(array_bytes)).to_bytes(2,byteorder='little') 58 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 59 | return b''.join(array_bytes) 60 | 61 | def land(): 62 | global seq 63 | array_bytes = create_packet(landCommand,0x68, 1) 64 | seq += 1 65 | instructions=seq.to_bytes(2,byteorder='little') 66 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 67 | array_bytes.append((0).to_bytes(1,byteorder='little')) 68 | instructions=(crc16.calculate_crc16(array_bytes)).to_bytes(2,byteorder='little') 69 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 70 | return b''.join(array_bytes) 71 | 72 | def start_video(): 73 | array_bytes = create_packet(videoStartCommand,0x60, 0) 74 | instructions=(0).to_bytes(2,byteorder='little') 75 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 76 | instructions=(crc16.calculate_crc16(array_bytes)).to_bytes(2,byteorder='little') 77 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 78 | return b''.join(array_bytes) 79 | 80 | def set_videoencoder_rate(rate): 81 | global seq 82 | array_bytes = create_packet(videoEncoderRateCommand,0x68, 1) 83 | seq += 1 84 | instructions=seq.to_bytes(2,byteorder='little') 85 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 86 | array_bytes.append(rate.to_bytes(1,byteorder='little')) 87 | instructions=(crc16.calculate_crc16(array_bytes)).to_bytes(2,byteorder='little') 88 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 89 | return b''.join(array_bytes) 90 | 91 | def connection_string(port): 92 | instructions=port.to_bytes(2,byteorder='little') 93 | array_bytes=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 94 | array_bytes=b'conn_req:'+b''.join(array_bytes) 95 | return array_bytes 96 | 97 | def up(val): 98 | global ly 99 | ly = float(val) / 100.0 100 | 101 | def down(val): 102 | global ly 103 | ly = float(val) / 100.0 * -1 104 | 105 | def forward(val): 106 | global ry 107 | ry = float(val) / 100.0 108 | 109 | def backward(val): 110 | global ry 111 | ry = float(val) / 100.0 * -1 112 | 113 | def right(val): 114 | global rx 115 | rx = float(val) / 100.0 116 | 117 | def left(val): 118 | global rx 119 | rx = float(val) / 100.0 * -1 120 | 121 | def clockwise(val): 122 | global lx 123 | lx = float(val) / 100.0 124 | 125 | def counter_clockwise(val): 126 | global lx 127 | lx = float(val) / 100.0 * -1 128 | 129 | def flip(direction): 130 | global seq 131 | array_bytes = create_packet(flipCommand,0x70, 1) 132 | seq += 1 133 
| instructions=seq.to_bytes(2,byteorder='little') 134 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 135 | array_bytes.append(direction.to_bytes(1,byteorder='little')) 136 | instructions=(crc16.calculate_crc16(array_bytes)).to_bytes(2,byteorder='little') 137 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 138 | return b''.join(array_bytes) 139 | 140 | def front_flip(): 141 | return flip(FlipFront) 142 | 143 | def back_flip(): 144 | return flip(FlipBack) 145 | 146 | def right_flip(): 147 | return flip(FlipRight) 148 | 149 | def left_flip(): 150 | return flip(FlipLeft) 151 | 152 | def send_stickcommand(): 153 | array_bytes = create_packet(stickCommand, 0x60, 11) 154 | instructions=(0).to_bytes(2,byteorder='little') 155 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 156 | axis1 = int(660.0*rx + 1024.0) 157 | axis2 = int(660.0*ry + 1024.0) 158 | axis3 = int(660.0*ly + 1024.0) 159 | axis4 = int(660.0*lx + 1024.0) 160 | axis5 = int(throttle) 161 | packed = (axis1)&0x7FF | (axis2&0x7FF)<<11 | (0x7FF&axis3)<<22 | (0x7FF&axis4)<<33 | (axis5)<<44 162 | array_bytes.append((0xFF&packed).to_bytes(1,byteorder='little')) 163 | array_bytes.append((packed>>8&0xFF).to_bytes(1,byteorder='little')) 164 | array_bytes.append((packed>>16&0xFF).to_bytes(1,byteorder='little')) 165 | array_bytes.append((packed>>24&0xFF).to_bytes(1,byteorder='little')) 166 | array_bytes.append((packed>>32&0xFF).to_bytes(1,byteorder='little')) 167 | array_bytes.append((packed>>40&0xFF).to_bytes(1,byteorder='little')) 168 | now = datetime.datetime.now() 169 | array_bytes.append((now.hour).to_bytes(1,byteorder='little')) 170 | array_bytes.append((now.minute).to_bytes(1,byteorder='little')) 171 | array_bytes.append((now.second).to_bytes(1,byteorder='little')) 172 | array_bytes.append((int(time.time() * 10)&0xff).to_bytes(1,byteorder='little')) 173 | array_bytes.append((int(time.time() * 10)>>8).to_bytes(4,byteorder='little')[0].to_bytes(1,byteorder='little')) 174 | instructions=(crc16.calculate_crc16(array_bytes)).to_bytes(2,byteorder='little') 175 | array_bytes+=[instructions[0].to_bytes(1,byteorder='little'),instructions[1].to_bytes(1,byteorder='little')] 176 | return b''.join(array_bytes) 177 | 178 | print(take_off()) 179 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, Response 2 | from flask_socketio import SocketIO, send, emit 3 | from queue import Queue 4 | from os.path import abspath, dirname 5 | import base64 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | import io 10 | d = dirname(dirname(abspath(__file__))) 11 | 12 | app = Flask(__name__) 13 | app.queue = Queue() 14 | socketio = SocketIO(app) 15 | 16 | @socketio.on('connect', namespace='/live') 17 | def test_connect(): 18 | print('Client wants to connect.') 19 | emit('response', {'data': 'OK'},broadcast=True) 20 | 21 | @socketio.on('disconnect', namespace='/live') 22 | def test_disconnect(): 23 | print('Client disconnected') 24 | 25 | @socketio.on('livevideo', namespace='/live') 26 | def test_live(message): 27 | app.queue.put(message['data']) 28 | emit('camera_update', {'data': app.queue.get()},broadcast=True) 29 | 30 | # change port and IP 31 | if __name__ == '__main__': 32 | socketio.run(app, host = '0.0.0.0', port = 8020,debug=True) 33 | 
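# Usage sketch (not part of the original repo; names and the address below are illustrative
# only): a frame producer could push frames to this relay with socketIO_client, mirroring
# how client.py consumes them:
#
#   from socketIO_client import SocketIO, BaseNamespace
#   import base64, cv2
#   sio = SocketIO('192.168.0.102', 8020)        # use the server's real IP/port
#   live = sio.define(BaseNamespace, '/live')
#   ok, jpg = cv2.imencode('.jpg', frame)        # `frame`: a BGR image from the drone feed
#   if ok:
#       live.emit('livevideo', {'data': base64.b64encode(jpg.tobytes()).decode()})
#
# The 'livevideo' handler above queues the payload and immediately rebroadcasts it as
# 'camera_update', which client.py decodes back into an image and displays.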
-------------------------------------------------------------------------------- /standard_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains classes specifying naming conventions used for object detection. 17 | 18 | 19 | Specifies: 20 | InputDataFields: standard fields used by reader/preprocessor/batcher. 21 | DetectionResultFields: standard fields returned by object detector. 22 | BoxListFields: standard field used by BoxList 23 | TfExampleFields: standard fields for tf-example data format (go/tf-example). 24 | """ 25 | 26 | 27 | class InputDataFields(object): 28 | """Names for the input tensors. 29 | 30 | Holds the standard data field names to use for identifying input tensors. This 31 | should be used by the decoder to identify keys for the returned tensor_dict 32 | containing input tensors. And it should be used by the model to identify the 33 | tensors it needs. 34 | 35 | Attributes: 36 | image: image. 37 | original_image: image in the original input size. 38 | key: unique key corresponding to image. 39 | source_id: source of the original image. 40 | filename: original filename of the dataset (without common path). 41 | groundtruth_image_classes: image-level class labels. 42 | groundtruth_boxes: coordinates of the ground truth boxes in the image. 43 | groundtruth_classes: box-level class labels. 44 | groundtruth_label_types: box-level label types (e.g. explicit negative). 45 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] 46 | is the groundtruth a single object or a crowd. 47 | groundtruth_area: area of a groundtruth segment. 48 | groundtruth_difficult: is a `difficult` object 49 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the 50 | same class, forming a connected group, where instances are heavily 51 | occluding each other. 52 | proposal_boxes: coordinates of object proposal boxes. 53 | proposal_objectness: objectness score of each proposal. 54 | groundtruth_instance_masks: ground truth instance masks. 55 | groundtruth_instance_boundaries: ground truth instance boundaries. 56 | groundtruth_instance_classes: instance mask-level class labels. 57 | groundtruth_keypoints: ground truth keypoints. 58 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities. 59 | groundtruth_label_scores: groundtruth label scores. 60 | groundtruth_weights: groundtruth weight factor for bounding boxes. 61 | num_groundtruth_boxes: number of groundtruth boxes. 62 | true_image_shapes: true shapes of images in the resized images, as resized 63 | images can be padded with zeros. 64 | verified_labels: list of human-verified image-level labels (note, that a 65 | label can be verified both as positive and negative). 
66 | multiclass_scores: the label score per class for each box. 67 | """ 68 | image = 'image' 69 | original_image = 'original_image' 70 | key = 'key' 71 | source_id = 'source_id' 72 | filename = 'filename' 73 | groundtruth_image_classes = 'groundtruth_image_classes' 74 | groundtruth_boxes = 'groundtruth_boxes' 75 | groundtruth_classes = 'groundtruth_classes' 76 | groundtruth_label_types = 'groundtruth_label_types' 77 | groundtruth_is_crowd = 'groundtruth_is_crowd' 78 | groundtruth_area = 'groundtruth_area' 79 | groundtruth_difficult = 'groundtruth_difficult' 80 | groundtruth_group_of = 'groundtruth_group_of' 81 | proposal_boxes = 'proposal_boxes' 82 | proposal_objectness = 'proposal_objectness' 83 | groundtruth_instance_masks = 'groundtruth_instance_masks' 84 | groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' 85 | groundtruth_instance_classes = 'groundtruth_instance_classes' 86 | groundtruth_keypoints = 'groundtruth_keypoints' 87 | groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities' 88 | groundtruth_label_scores = 'groundtruth_label_scores' 89 | groundtruth_weights = 'groundtruth_weights' 90 | num_groundtruth_boxes = 'num_groundtruth_boxes' 91 | true_image_shape = 'true_image_shape' 92 | verified_labels = 'verified_labels' 93 | multiclass_scores = 'multiclass_scores' 94 | 95 | 96 | class DetectionResultFields(object): 97 | """Naming conventions for storing the output of the detector. 98 | 99 | Attributes: 100 | source_id: source of the original image. 101 | key: unique key corresponding to image. 102 | detection_boxes: coordinates of the detection boxes in the image. 103 | detection_scores: detection scores for the detection boxes in the image. 104 | detection_classes: detection-level class labels. 105 | detection_masks: contains a segmentation mask for each detection box. 106 | detection_boundaries: contains an object boundary for each detection box. 107 | detection_keypoints: contains detection keypoints for each detection box. 108 | num_detections: number of detections in the batch. 109 | """ 110 | 111 | source_id = 'source_id' 112 | key = 'key' 113 | detection_boxes = 'detection_boxes' 114 | detection_scores = 'detection_scores' 115 | detection_classes = 'detection_classes' 116 | detection_masks = 'detection_masks' 117 | detection_boundaries = 'detection_boundaries' 118 | detection_keypoints = 'detection_keypoints' 119 | num_detections = 'num_detections' 120 | 121 | 122 | class BoxListFields(object): 123 | """Naming conventions for BoxLists. 124 | 125 | Attributes: 126 | boxes: bounding box coordinates. 127 | classes: classes per bounding box. 128 | scores: scores per bounding box. 129 | weights: sample weights per bounding box. 130 | objectness: objectness score per bounding box. 131 | masks: masks per bounding box. 132 | boundaries: boundaries per bounding box. 133 | keypoints: keypoints per bounding box. 134 | keypoint_heatmaps: keypoint heatmaps per bounding box. 135 | is_crowd: is_crowd annotation per bounding box. 136 | """ 137 | boxes = 'boxes' 138 | classes = 'classes' 139 | scores = 'scores' 140 | weights = 'weights' 141 | objectness = 'objectness' 142 | masks = 'masks' 143 | boundaries = 'boundaries' 144 | keypoints = 'keypoints' 145 | keypoint_heatmaps = 'keypoint_heatmaps' 146 | is_crowd = 'is_crowd' 147 | 148 | 149 | class TfExampleFields(object): 150 | """TF-example proto feature names for object detection. 151 | 152 | Holds the standard feature names to load from an Example proto for object 153 | detection. 
154 | 155 | Attributes: 156 | image_encoded: JPEG encoded string 157 | image_format: image format, e.g. "JPEG" 158 | filename: filename 159 | channels: number of channels of image 160 | colorspace: colorspace, e.g. "RGB" 161 | height: height of image in pixels, e.g. 462 162 | width: width of image in pixels, e.g. 581 163 | source_id: original source of the image 164 | object_class_text: labels in text format, e.g. ["person", "cat"] 165 | object_class_label: labels in numbers, e.g. [16, 8] 166 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 167 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 168 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 169 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 170 | object_view: viewpoint of object, e.g. ["frontal", "left"] 171 | object_truncated: is object truncated, e.g. [true, false] 172 | object_occluded: is object occluded, e.g. [true, false] 173 | object_difficult: is object difficult, e.g. [true, false] 174 | object_group_of: is object a single object or a group of objects 175 | object_depiction: is object a depiction 176 | object_is_crowd: [DEPRECATED, use object_group_of instead] 177 | is the object a single object or a crowd 178 | object_segment_area: the area of the segment. 179 | object_weight: a weight factor for the object's bounding box. 180 | instance_masks: instance segmentation masks. 181 | instance_boundaries: instance boundaries. 182 | instance_classes: Classes for each instance segmentation mask. 183 | detection_class_label: class label in numbers. 184 | detection_bbox_ymin: ymin coordinates of a detection box. 185 | detection_bbox_xmin: xmin coordinates of a detection box. 186 | detection_bbox_ymax: ymax coordinates of a detection box. 187 | detection_bbox_xmax: xmax coordinates of a detection box. 188 | detection_score: detection score for the class label and box. 
189 | """ 190 | image_encoded = 'image/encoded' 191 | image_format = 'image/format' # format is reserved keyword 192 | filename = 'image/filename' 193 | channels = 'image/channels' 194 | colorspace = 'image/colorspace' 195 | height = 'image/height' 196 | width = 'image/width' 197 | source_id = 'image/source_id' 198 | object_class_text = 'image/object/class/text' 199 | object_class_label = 'image/object/class/label' 200 | object_bbox_ymin = 'image/object/bbox/ymin' 201 | object_bbox_xmin = 'image/object/bbox/xmin' 202 | object_bbox_ymax = 'image/object/bbox/ymax' 203 | object_bbox_xmax = 'image/object/bbox/xmax' 204 | object_view = 'image/object/view' 205 | object_truncated = 'image/object/truncated' 206 | object_occluded = 'image/object/occluded' 207 | object_difficult = 'image/object/difficult' 208 | object_group_of = 'image/object/group_of' 209 | object_depiction = 'image/object/depiction' 210 | object_is_crowd = 'image/object/is_crowd' 211 | object_segment_area = 'image/object/segment/area' 212 | object_weight = 'image/object/weight' 213 | instance_masks = 'image/segmentation/object' 214 | instance_boundaries = 'image/boundaries/object' 215 | instance_classes = 'image/segmentation/object/class' 216 | detection_class_label = 'image/detection/label' 217 | detection_bbox_ymin = 'image/detection/bbox/ymin' 218 | detection_bbox_xmin = 'image/detection/bbox/xmin' 219 | detection_bbox_ymax = 'image/detection/bbox/ymax' 220 | detection_bbox_xmax = 'image/detection/bbox/xmax' 221 | detection_score = 'image/detection/score' 222 | -------------------------------------------------------------------------------- /tello.py: -------------------------------------------------------------------------------- 1 | """License. 2 | 3 | Copyright 2018 Todd Mueller 4 | 5 | This program is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with this program. If not, see . 17 | 18 | """ 19 | 20 | import socket 21 | import threading 22 | import time 23 | import traceback 24 | 25 | class Tello: 26 | """Wrapper to simply interactions with the Ryze Tello drone.""" 27 | 28 | def __init__(self, local_ip, local_port, imperial=True, command_timeout=.3, tello_ip='192.168.10.1', tello_port=8889): 29 | """Binds to the local IP/port and puts the Tello into command mode. 30 | 31 | Args: 32 | local_ip (str): Local IP address to bind. 33 | local_port (int): Local port to bind. 34 | imperial (bool): If True, speed is MPH and distance is feet. 35 | If False, speed is KPH and distance is meters. 36 | command_timeout (int|float): Number of seconds to wait for a response to a command. 37 | tello_ip (str): Tello IP. 38 | tello_port (int): Tello port. 39 | 40 | Raises: 41 | RuntimeError: If the Tello rejects the attempt to enter command mode. 
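        Example (illustrative only; the local IP address and port are
        assumptions about the machine joined to the Tello's Wi-Fi network,
        not values taken from this repository):

            drone = Tello('192.168.10.2', 8889)
            print(drone.get_battery())
            drone.takeoff()
            drone.rotate_cw(90)
            drone.land()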
42 | 43 | """ 44 | 45 | self.abort_flag = False 46 | self.command_timeout = command_timeout 47 | self.imperial = imperial 48 | self.response = None 49 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 50 | self.tello_address = (tello_ip, tello_port) 51 | 52 | self.socket.bind((local_ip, local_port)) 53 | 54 | self.receive_thread = threading.Thread(target=self._receive_thread) 55 | self.receive_thread.daemon=True 56 | 57 | self.receive_thread.start() 58 | 59 | if self.send_command('command') != 'OK': 60 | raise RuntimeError('Tello rejected attempt to enter command mode') 61 | 62 | def __del__(self): 63 | """Closes the local socket.""" 64 | 65 | self.socket.close() 66 | 67 | def _receive_thread(self): 68 | """Listens for responses from the Tello. 69 | 70 | Runs as a thread, sets self.response to whatever the Tello last returned. 71 | 72 | """ 73 | while True: 74 | try: 75 | self.response, ip = self.socket.recvfrom(256) 76 | except Exception: 77 | break 78 | 79 | def flip(self, direction): 80 | """Flips. 81 | 82 | Args: 83 | direction (str): Direction to flip, 'l', 'r', 'f', 'b', 'lb', 'lf', 'rb' or 'rf'. 84 | 85 | Returns: 86 | str: Response from Tello, 'OK' or 'FALSE'. 87 | 88 | """ 89 | 90 | return self.send_command('flip %s' % direction) 91 | 92 | def get_battery(self): 93 | """Returns percent battery life remaining. 94 | 95 | Returns: 96 | int: Percent battery life remaining. 97 | 98 | """ 99 | 100 | battery = self.send_command('battery?') 101 | 102 | try: 103 | battery = int(battery) 104 | except: 105 | pass 106 | 107 | return battery 108 | 109 | 110 | def get_flight_time(self): 111 | """Returns the number of seconds elapsed during flight. 112 | 113 | Returns: 114 | int: Seconds elapsed during flight. 115 | 116 | """ 117 | 118 | flight_time = self.send_command('time?') 119 | 120 | try: 121 | flight_time = int(flight_time) 122 | except: 123 | pass 124 | 125 | return flight_time 126 | 127 | def get_speed(self): 128 | """Returns the current speed. 129 | 130 | Returns: 131 | int: Current speed in KPH or MPH. 132 | 133 | """ 134 | 135 | speed = self.send_command('speed?') 136 | 137 | try: 138 | speed = float(speed) 139 | 140 | if self.imperial is True: 141 | speed = round((speed / 44.704), 1) 142 | else: 143 | speed = round((speed / 27.7778), 1) 144 | except: 145 | pass 146 | 147 | return speed 148 | 149 | def land(self): 150 | """Initiates landing. 151 | 152 | Returns: 153 | str: Response from Tello, 'OK' or 'FALSE'. 154 | 155 | """ 156 | 157 | return self.send_command('land') 158 | 159 | def move(self, direction, distance): 160 | """Moves in a direction for a distance. 161 | 162 | This method expects meters or feet. The Tello API expects distances 163 | from 20 to 500 centimeters. 164 | 165 | Metric: .1 to 5 meters 166 | Imperial: .7 to 16.4 feet 167 | 168 | Args: 169 | direction (str): Direction to move, 'forward', 'back', 'right' or 'left'. 170 | distance (int|float): Distance to move. 171 | 172 | Returns: 173 | str: Response from Tello, 'OK' or 'FALSE'. 174 | 175 | """ 176 | 177 | distance = float(distance) 178 | 179 | if self.imperial is True: 180 | distance = int(round(distance * 30.48)) 181 | else: 182 | distance = int(round(distance * 100)) 183 | 184 | return self.send_command('%s %s' % (direction, distance)) 185 | 186 | def move_backward(self, distance): 187 | """Moves backward for a distance. 188 | 189 | See comments for Tello.move(). 190 | 191 | Args: 192 | distance (int): Distance to move. 193 | 194 | Returns: 195 | str: Response from Tello, 'OK' or 'FALSE'. 
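        Example of the unit conversion described in Tello.move() (values are
        illustrative): with imperial=False, move_backward(1.5) converts
        1.5 m -> int(round(1.5 * 100)) = 150 and sends 'back 150'; with
        imperial=True, 1.5 ft -> int(round(1.5 * 30.48)) = 46 and sends 'back 46'.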
196 | 197 | """ 198 | 199 | return self.move('back', distance) 200 | 201 | def move_down(self, distance): 202 | """Moves down for a distance. 203 | 204 | See comments for Tello.move(). 205 | 206 | Args: 207 | distance (int): Distance to move. 208 | 209 | Returns: 210 | str: Response from Tello, 'OK' or 'FALSE'. 211 | 212 | """ 213 | 214 | return self.move('down', distance) 215 | 216 | def move_forward(self, distance): 217 | """Moves forward for a distance. 218 | 219 | See comments for Tello.move(). 220 | 221 | Args: 222 | distance (int): Distance to move. 223 | 224 | Returns: 225 | str: Response from Tello, 'OK' or 'FALSE'. 226 | 227 | """ 228 | return self.move('forward', distance) 229 | 230 | def move_left(self, distance): 231 | """Moves left for a distance. 232 | 233 | See comments for Tello.move(). 234 | 235 | Args: 236 | distance (int): Distance to move. 237 | 238 | Returns: 239 | str: Response from Tello, 'OK' or 'FALSE'. 240 | 241 | """ 242 | return self.move('left', distance) 243 | 244 | def move_right(self, distance): 245 | """Moves right for a distance. 246 | 247 | See comments for Tello.move(). 248 | 249 | Args: 250 | distance (int): Distance to move. 251 | 252 | """ 253 | return self.move('right', distance) 254 | 255 | def move_up(self, distance): 256 | """Moves up for a distance. 257 | 258 | See comments for Tello.move(). 259 | 260 | Args: 261 | distance (int): Distance to move. 262 | 263 | Returns: 264 | str: Response from Tello, 'OK' or 'FALSE'. 265 | 266 | """ 267 | 268 | return self.move('up', distance) 269 | 270 | def send_command(self, command): 271 | """Sends a command to the Tello and waits for a response. 272 | 273 | If self.command_timeout is exceeded before a response is received, 274 | a RuntimeError exception is raised. 275 | 276 | Args: 277 | command (str): Command to send. 278 | 279 | Returns: 280 | str: Response from Tello. 281 | 282 | Raises: 283 | RuntimeError: If no response is received within self.timeout seconds. 284 | 285 | """ 286 | 287 | self.abort_flag = False 288 | timer = threading.Timer(self.command_timeout, self.set_abort_flag) 289 | 290 | self.socket.sendto(command.encode('utf-8'), self.tello_address) 291 | 292 | timer.start() 293 | 294 | while self.response is None: 295 | if self.abort_flag is True: 296 | raise RuntimeError('No response to command') 297 | 298 | timer.cancel() 299 | 300 | response = self.response.decode('utf-8') 301 | self.response = None 302 | 303 | return response 304 | 305 | def set_abort_flag(self): 306 | """Sets self.abort_flag to True. 307 | 308 | Used by the timer in Tello.send_command() to indicate to that a response 309 | timeout has occurred. 310 | 311 | """ 312 | 313 | self.abort_flag = True 314 | 315 | def set_speed(self, speed): 316 | """Sets speed. 317 | 318 | This method expects KPH or MPH. The Tello API expects speeds from 319 | 1 to 100 centimeters/second. 320 | 321 | Metric: .1 to 3.6 KPH 322 | Imperial: .1 to 2.2 MPH 323 | 324 | Args: 325 | speed (int|float): Speed. 326 | 327 | Returns: 328 | str: Response from Tello, 'OK' or 'FALSE'. 329 | 330 | """ 331 | 332 | speed = float(speed) 333 | 334 | if self.imperial is True: 335 | speed = int(round(speed * 44.704)) 336 | else: 337 | speed = int(round(speed * 27.7778)) 338 | 339 | return self.send_command('speed %s' % speed) 340 | 341 | def takeoff(self): 342 | """Initiates take-off. 343 | 344 | Returns: 345 | str: Response from Tello, 'OK' or 'FALSE'. 
346 | 347 | """ 348 | 349 | return self.send_command('takeoff') 350 | 351 | def rotate_cw(self, degrees): 352 | """Rotates clockwise. 353 | 354 | Args: 355 | degrees (int): Degrees to rotate, 1 to 360. 356 | 357 | Returns: 358 | str: Response from Tello, 'OK' or 'FALSE'. 359 | 360 | """ 361 | 362 | return self.send_command('cw %s' % degrees) 363 | 364 | def rotate_ccw(self, degrees): 365 | """Rotates counter-clockwise. 366 | 367 | Args: 368 | degrees (int): Degrees to rotate, 1 to 360. 369 | 370 | Returns: 371 | str: Response from Tello, 'OK' or 'FALSE'. 372 | 373 | """ 374 | return self.send_command('ccw %s' % degrees) 375 | -------------------------------------------------------------------------------- /temp.h264: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/temp.h264 -------------------------------------------------------------------------------- /test_flight.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import threading 3 | import time 4 | import pickle 5 | import cv2 6 | from Detection.MtcnnDetector import MtcnnDetector 7 | from Detection.detector import Detector 8 | from Detection.fcn_detector import FcnDetector 9 | from train_models.mtcnn_model import P_Net, R_Net, O_Net 10 | import visualization_utils 11 | from instruction import * 12 | 13 | thresh = [0.7, 0.1, 0.1] 14 | min_face_size = 24 15 | stride = 2 16 | slide_window = True 17 | shuffle = False 18 | detectors = [None, None, None] 19 | prefix = ['data/MTCNN_model/PNet_landmark/PNet', 'data/MTCNN_model/RNet_landmark/RNet', 'data/MTCNN_model/ONet_landmark/ONet'] 20 | epoch = [18, 14, 16] 21 | model_path = ['%s-%s' % (x, y) for x, y in zip(prefix, epoch)] 22 | PNet = FcnDetector(P_Net, model_path[0]) 23 | detectors[0] = PNet 24 | RNet = Detector(R_Net, 24, 1, model_path[1]) 25 | detectors[1] = RNet 26 | ONet = Detector(O_Net, 48, 1, model_path[2]) 27 | detectors[2] = ONet 28 | mtcnn_detector = MtcnnDetector(detectors=detectors, min_face_size=min_face_size,stride=stride, threshold=thresh, slide_window=slide_window) 29 | # in cm, etc, assumed same width during flight 30 | initial_flight_height = 20 31 | # focal length = (P * D) / W 32 | # assume D = W 33 | focal_length = initial_flight_height * 10 34 | 35 | def midpoint(x, y): 36 | return ((x[0] + y[0]) * 0.5, (x[1] + y[1]) * 0.5) 37 | 38 | def distance_to_camera(initial_width, focal_length, virtual_width): 39 | return (initial_width * focalLength) / virtual_width 40 | 41 | header = b'\x00\x00\x00\x01gM@(\x95\xa0<\x05\xb9\x00\x00\x00\x01h\xee8\x80' 42 | h264 = [] 43 | port_video = 6038 44 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) 45 | sock_video = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) 46 | sock_video.bind(('192.168.10.2', port_video)) 47 | 48 | def receive_video(): 49 | global h264 50 | while True: 51 | data, addr = sock_video.recvfrom(2048) 52 | if len(data[2:]) not in [8,13]: 53 | h264.append(data[2:]) 54 | 55 | receive_thread = threading.Thread(target=receive_video) 56 | receive_thread.daemon=True 57 | receive_thread.start() 58 | 59 | sock.sendto(connection_string(port_video), ('192.168.10.1',8889)) 60 | sock.sendto(start_video(), ('192.168.10.1',8889)) 61 | sock.sendto(set_videoencoder_rate(4), ('192.168.10.1',8889)) 62 | 63 | def send_video(): 64 | try: 65 | while True: 66 | time.sleep(0.1) 67 | 
sock.sendto(start_video(), ('192.168.10.1',8889)) 68 | except KeyboardInterrupt: 69 | pass 70 | 71 | send_thread = threading.Thread(target=send_video) 72 | send_thread.daemon=True 73 | send_thread.start() 74 | 75 | def send_control(): 76 | try: 77 | while True: 78 | time.sleep(0.01) 79 | stick = send_stickcommand() 80 | sock.sendto(stick, ('192.168.10.1',8889)) 81 | except KeyboardInterrupt: 82 | pass 83 | 84 | control_thread = threading.Thread(target=send_control) 85 | control_thread.daemon=True 86 | control_thread.start() 87 | 88 | def send_flight(): 89 | #sock.sendto(take_off(), ('192.168.10.1',8889)) 90 | time.sleep(2) 91 | print(land()) 92 | sock.sendto(land(), ('192.168.10.1',8889)) 93 | 94 | flight_thread = threading.Thread(target=send_flight) 95 | flight_thread.daemon=True 96 | flight_thread.start() 97 | 98 | while True: 99 | k = sum([int(len(i) < 1000) for i in h264]) 100 | temp = [] 101 | for i in reversed(range(len(h264))): 102 | if len(h264[i]) < 1000: 103 | count, temp = 0, [h264[i]] 104 | for n in reversed(range(len(h264[:i]))): 105 | if len(h264[n]) < 1000: 106 | count += 1 107 | if count == 3: 108 | break 109 | temp.append(h264[n]) 110 | break 111 | if k > 2: 112 | with open('temp.h264','wb') as fopen: 113 | fopen.write(header+b''.join(temp[::-1])) 114 | h264.clear() 115 | cap = cv2.VideoCapture('temp.h264') 116 | while True: 117 | try: 118 | last_time = time.time() 119 | ret, img = cap.read() 120 | boxes_c,_ = mtcnn_detector.detect(img) 121 | for u in range(boxes_c.shape[0]): 122 | bbox = boxes_c[u, :4] 123 | # tl,tr,bl,br = [int(bbox[0]),int(bbox[1])],[int(bbox[2]),int(bbox[1])],[int(bbox[0]),int(bbox[3])],[int(bbox[2]),int(bbox[3])] 124 | # (tltrX, tltrY) = midpoint(tl, tr) 125 | # (blbrX, blbrY) = midpoint(bl, br) 126 | # (tlblX, tlblY) = midpoint(tl, bl) 127 | # (trbrX, trbrY) = midpoint(tr, br) 128 | # print(land()) 129 | # # virtual width 130 | # dA = dist.euclidean((tltrX, tltrY), (blbrX, blbrY)) 131 | # # virtual height 132 | # dB = dist.euclidean((tlblX, tlblY), (trbrX, trbrY)) 133 | # distance = distance_to_camera(initial_flight_height, focal_length, dA) 134 | # print(distance) 135 | visualization_utils.draw_bounding_box_on_image_array(img,int(bbox[1]),int(bbox[0]), 136 | int(bbox[3]), 137 | int(bbox[2]), 138 | 'YellowGreen',display_str_list=['face'], 139 | use_normalized_coordinates=False) 140 | cv2.putText(img,'%.1f FPS'%(1/(time.time() - last_time)), (0,20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 255) 141 | cv2.imshow('cam',img) 142 | if cv2.waitKey(30) & 0xFF == ord('q'): 143 | break 144 | except: 145 | break 146 | -------------------------------------------------------------------------------- /test_flight_rotate.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import threading 3 | import time 4 | import pickle 5 | import cv2 6 | from Detection.MtcnnDetector import MtcnnDetector 7 | from Detection.detector import Detector 8 | from Detection.fcn_detector import FcnDetector 9 | from train_models.mtcnn_model import P_Net, R_Net, O_Net 10 | import visualization_utils 11 | from instruction import * 12 | 13 | thresh = [0.7, 0.1, 0.1] 14 | min_face_size = 24 15 | stride = 2 16 | slide_window = True 17 | shuffle = False 18 | detectors = [None, None, None] 19 | prefix = ['data/MTCNN_model/PNet_landmark/PNet', 'data/MTCNN_model/RNet_landmark/RNet', 'data/MTCNN_model/ONet_landmark/ONet'] 20 | epoch = [18, 14, 16] 21 | model_path = ['%s-%s' % (x, y) for x, y in zip(prefix, epoch)] 22 | PNet = FcnDetector(P_Net, 
model_path[0]) 23 | detectors[0] = PNet 24 | RNet = Detector(R_Net, 24, 1, model_path[1]) 25 | detectors[1] = RNet 26 | ONet = Detector(O_Net, 48, 1, model_path[2]) 27 | detectors[2] = ONet 28 | mtcnn_detector = MtcnnDetector(detectors=detectors, min_face_size=min_face_size,stride=stride, threshold=thresh, slide_window=slide_window) 29 | # in cm, etc, assumed same width during flight 30 | initial_flight_height = 20 31 | # focal length = (P * D) / W 32 | # assume D = W 33 | focal_length = initial_flight_height * 10 34 | 35 | def midpoint(x, y): 36 | return ((x[0] + y[0]) * 0.5, (x[1] + y[1]) * 0.5) 37 | 38 | def distance_to_camera(initial_width, focal_length, virtual_width): 39 | return (initial_width * focalLength) / virtual_width 40 | 41 | header = b'\x00\x00\x00\x01gM@(\x95\xa0<\x05\xb9\x00\x00\x00\x01h\xee8\x80' 42 | h264 = [] 43 | port_video = 6038 44 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) 45 | sock_video = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) 46 | sock_video.bind(('192.168.10.2', port_video)) 47 | 48 | def receive_video(): 49 | global h264 50 | while True: 51 | data, addr = sock_video.recvfrom(2048) 52 | if len(data[2:]) not in [8,13]: 53 | h264.append(data[2:]) 54 | 55 | def run_cam(): 56 | global h264 57 | while True: 58 | k = sum([int(len(i) < 1000) for i in h264]) 59 | temp = [] 60 | for i in reversed(range(len(h264))): 61 | if len(h264[i]) < 1000: 62 | count, temp = 0, [h264[i]] 63 | for n in reversed(range(len(h264[:i]))): 64 | if len(h264[n]) < 1000: 65 | count += 1 66 | if count == 3: 67 | break 68 | temp.append(h264[n]) 69 | break 70 | if k > 2: 71 | with open('temp.h264','wb') as fopen: 72 | fopen.write(header+b''.join(temp[::-1])) 73 | h264.clear() 74 | cap = cv2.VideoCapture('temp.h264') 75 | while True: 76 | try: 77 | last_time = time.time() 78 | ret, img = cap.read() 79 | boxes_c,_ = mtcnn_detector.detect(img) 80 | for u in range(boxes_c.shape[0]): 81 | bbox = boxes_c[u, :4] 82 | # tl,tr,bl,br = [int(bbox[0]),int(bbox[1])],[int(bbox[2]),int(bbox[1])],[int(bbox[0]),int(bbox[3])],[int(bbox[2]),int(bbox[3])] 83 | # (tltrX, tltrY) = midpoint(tl, tr) 84 | # (blbrX, blbrY) = midpoint(bl, br) 85 | # (tlblX, tlblY) = midpoint(tl, bl) 86 | # (trbrX, trbrY) = midpoint(tr, br) 87 | # # virtual width 88 | # dA = dist.euclidean((tltrX, tltrY), (blbrX, blbrY)) 89 | # # virtual height 90 | # dB = dist.euclidean((tlblX, tlblY), (trbrX, trbrY)) 91 | # distance = distance_to_camera(initial_flight_height, focal_length, dA) 92 | distance = 1 93 | visualization_utils.draw_bounding_box_on_image_array(img,int(bbox[1]),int(bbox[0]), 94 | int(bbox[3]), 95 | int(bbox[2]), 96 | 'YellowGreen',display_str_list=['face','','distance: %.2fcm'%(distance)], 97 | use_normalized_coordinates=False) 98 | cv2.putText(img,'%.1f FPS'%(1/(time.time() - last_time)), (0,20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 255) 99 | cv2.imshow('cam',img) 100 | if cv2.waitKey(30) & 0xFF == ord('q'): 101 | break 102 | except: 103 | break 104 | 105 | receive_thread = threading.Thread(target=receive_video) 106 | receive_thread.daemon=True 107 | receive_thread.start() 108 | receive_cam_thread = threading.Thread(target=run_cam) 109 | receive_cam_thread.daemon = True 110 | receive_cam_thread.start() 111 | 112 | sock.sendto(connection_string(port_video), ('192.168.10.1',8889)) 113 | sock.sendto(start_video(), ('192.168.10.1',8889)) 114 | sock.sendto(set_videoencoder_rate(4), ('192.168.10.1',8889)) 115 | 116 | def send_video(): 117 | try: 118 | while True: 119 | 
time.sleep(0.1) 120 | sock.sendto(start_video(), ('192.168.10.1',8889)) 121 | except KeyboardInterrupt: 122 | pass 123 | 124 | send_thread = threading.Thread(target=send_video) 125 | send_thread.daemon=True 126 | send_thread.start() 127 | 128 | def send_control(): 129 | try: 130 | while True: 131 | time.sleep(0.01) 132 | stick = send_stickcommand() 133 | print(stick.hex()) 134 | sock.sendto(stick, ('192.168.10.1',8889)) 135 | except KeyboardInterrupt: 136 | pass 137 | 138 | control_thread = threading.Thread(target=send_control) 139 | control_thread.daemon=True 140 | control_thread.start() 141 | 142 | sock.sendto(take_off(), ('192.168.10.1',8889)) 143 | rotation_seconds = 20 144 | time.sleep(1) 145 | for i in range(100): 146 | clockwise(i) 147 | time.sleep(rotation_seconds/100) 148 | sock.sendto(land(), ('192.168.10.1',8889)) 149 | -------------------------------------------------------------------------------- /test_video.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import threading 3 | import time 4 | import pickle 5 | import cv2 6 | from Detection.MtcnnDetector import MtcnnDetector 7 | from Detection.detector import Detector 8 | from Detection.fcn_detector import FcnDetector 9 | from train_models.mtcnn_model import P_Net, R_Net, O_Net 10 | import visualization_utils 11 | 12 | thresh = [0.7, 0.1, 0.1] 13 | min_face_size = 24 14 | stride = 2 15 | slide_window = True 16 | shuffle = False 17 | detectors = [None, None, None] 18 | prefix = ['data/MTCNN_model/PNet_landmark/PNet', 'data/MTCNN_model/RNet_landmark/RNet', 'data/MTCNN_model/ONet_landmark/ONet'] 19 | epoch = [18, 14, 16] 20 | model_path = ['%s-%s' % (x, y) for x, y in zip(prefix, epoch)] 21 | PNet = FcnDetector(P_Net, model_path[0]) 22 | detectors[0] = PNet 23 | RNet = Detector(R_Net, 24, 1, model_path[1]) 24 | detectors[1] = RNet 25 | ONet = Detector(O_Net, 48, 1, model_path[2]) 26 | detectors[2] = ONet 27 | mtcnn_detector = MtcnnDetector(detectors=detectors, min_face_size=min_face_size,stride=stride, threshold=thresh, slide_window=slide_window) 28 | # in cm, etc, assumed same width during flight 29 | initial_flight_height = 20 30 | # focal length = (P * D) / W 31 | # assume D = W 32 | focal_length = initial_flight_height * 10 33 | 34 | def midpoint(x, y): 35 | return ((x[0] + y[0]) * 0.5, (x[1] + y[1]) * 0.5) 36 | 37 | def distance_to_camera(initial_width, focal_length, virtual_width): 38 | return (initial_width * focalLength) / virtual_width 39 | 40 | header = b'\x00\x00\x00\x01gM@(\x95\xa0<\x05\xb9\x00\x00\x00\x01h\xee8\x80' 41 | h264 = [] 42 | 43 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) 44 | sock_video = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) 45 | sock_video.bind(('192.168.10.3', 6038)) 46 | 47 | def receive_video(): 48 | global h264 49 | while True: 50 | data, addr = sock_video.recvfrom(2048) 51 | if len(data[2:]) not in [8,13]: 52 | h264.append(data[2:]) 53 | 54 | def run_cam(): 55 | global h264 56 | while True: 57 | k = sum([int(len(i) < 1000) for i in h264]) 58 | temp = [] 59 | for i in reversed(range(len(h264))): 60 | if len(h264[i]) < 1000: 61 | count, temp = 0, [h264[i]] 62 | for n in reversed(range(len(h264[:i]))): 63 | if len(h264[n]) < 1000: 64 | count += 1 65 | if count == 3: 66 | break 67 | temp.append(h264[n]) 68 | break 69 | if k > 2: 70 | with open('temp.h264','wb') as fopen: 71 | fopen.write(header+b''.join(temp[::-1])) 72 | h264.clear() 73 | cap = cv2.VideoCapture('temp.h264') 74 | 
while True: 75 | try: 76 | last_time = time.time() 77 | ret, img = cap.read() 78 | boxes_c,_ = mtcnn_detector.detect(img) 79 | for u in range(boxes_c.shape[0]): 80 | bbox = boxes_c[u, :4] 81 | tl,tr,bl,br = [int(bbox[0]),int(bbox[1])],[int(bbox[2]),int(bbox[1])],[int(bbox[0]),int(bbox[3])],[int(bbox[2]),int(bbox[3])] 82 | (tltrX, tltrY) = midpoint(tl, tr) 83 | (blbrX, blbrY) = midpoint(bl, br) 84 | (tlblX, tlblY) = midpoint(tl, bl) 85 | (trbrX, trbrY) = midpoint(tr, br) 86 | # virtual width 87 | dA = dist.euclidean((tltrX, tltrY), (blbrX, blbrY)) 88 | # virtual height 89 | dB = dist.euclidean((tlblX, tlblY), (trbrX, trbrY)) 90 | distance = distance_to_camera(initial_flight_height, focal_length, dA) 91 | visualization_utils.draw_bounding_box_on_image_array(img,int(bbox[1]),int(bbox[0]), 92 | int(bbox[3]), 93 | int(bbox[2]), 94 | 'YellowGreen',display_str_list=['face','','distance: %.2fcm'%(distance)], 95 | use_normalized_coordinates=False) 96 | cv2.putText(img,'%.1f FPS'%(1/(time.time() - last_time)), (0,20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 255) 97 | cv2.imshow('cam',img) 98 | if cv2.waitKey(30) & 0xFF == ord('q'): 99 | break 100 | except: 101 | break 102 | 103 | receive_thread = threading.Thread(target=receive_video) 104 | receive_thread.daemon=True 105 | receive_thread.start() 106 | receive_cam_thread = threading.Thread(target=run_cam) 107 | receive_cam_thread.daemon = True 108 | receive_cam_thread.start() 109 | 110 | command = bytearray.fromhex('9617') 111 | sock.sendto(b'conn_req:'+bytes(command), ('192.168.10.1',8889)) 112 | sock.sendto(bytes(bytearray.fromhex('cc58007c60250000006c95')), ('192.168.10.1',8889)) 113 | sock.sendto(bytes(bytearray.fromhex('cc600027682000000004fd9b')), ('192.168.10.1',8889)) 114 | try: 115 | while True: 116 | time.sleep(0.1) 117 | sock.sendto(bytes(bytearray.fromhex('cc58007c60250000006c95')), ('192.168.10.1',8889)) 118 | except KeyboardInterrupt: 119 | pass 120 | -------------------------------------------------------------------------------- /train_models/MTCNN_config.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | from easydict import EasyDict as edict 4 | 5 | config = edict() 6 | 7 | config.BATCH_SIZE = 384 8 | config.CLS_OHEM = True 9 | config.CLS_OHEM_RATIO = 0.7 10 | config.BBOX_OHEM = False 11 | config.BBOX_OHEM_RATIO = 0.7 12 | 13 | config.EPS = 1e-14 14 | config.LR_EPOCH = [6,14,20] 15 | -------------------------------------------------------------------------------- /train_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DevconX/Tello-Python/0e7ef8375e6904a536ff274ec7c868388424327e/train_models/__init__.py -------------------------------------------------------------------------------- /train_models/mtcnn_model.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import tensorflow as tf 3 | from tensorflow.contrib import slim 4 | import numpy as np 5 | num_keep_radio = 0.7 6 | #define prelu 7 | def prelu(inputs): 8 | alphas = tf.get_variable("alphas", shape=inputs.get_shape()[-1], dtype=tf.float32, initializer=tf.constant_initializer(0.25)) 9 | pos = tf.nn.relu(inputs) 10 | neg = alphas * (inputs-abs(inputs))*0.5 11 | return pos + neg 12 | def dense_to_one_hot(labels_dense,num_classes): 13 | num_labels = labels_dense.shape[0] 14 | index_offset = np.arange(num_labels)*num_classes 15 | #num_sample*num_classes 16 | labels_one_hot = 
np.zeros((num_labels,num_classes)) 17 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 18 | return labels_one_hot 19 | #cls_prob:batch*2 20 | #label:batch 21 | 22 | def cls_ohem(cls_prob, label): 23 | zeros = tf.zeros_like(label) 24 | #label=-1 --> label=0net_factory 25 | label_filter_invalid = tf.where(tf.less(label,0), zeros, label) 26 | num_cls_prob = tf.size(cls_prob) 27 | cls_prob_reshape = tf.reshape(cls_prob,[num_cls_prob,-1]) 28 | label_int = tf.cast(label_filter_invalid,tf.int32) 29 | num_row = tf.to_int32(cls_prob.get_shape()[0]) 30 | row = tf.range(num_row)*2 31 | indices_ = row + label_int 32 | label_prob = tf.squeeze(tf.gather(cls_prob_reshape, indices_)) 33 | loss = -tf.log(label_prob+1e-10) 34 | zeros = tf.zeros_like(label_prob, dtype=tf.float32) 35 | ones = tf.ones_like(label_prob,dtype=tf.float32) 36 | valid_inds = tf.where(label < zeros,zeros,ones) 37 | num_valid = tf.reduce_sum(valid_inds) 38 | keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32) 39 | #set 0 to invalid sample 40 | loss = loss * valid_inds 41 | loss,_ = tf.nn.top_k(loss, k=keep_num) 42 | return tf.reduce_mean(loss) 43 | def bbox_ohem_smooth_L1_loss(bbox_pred,bbox_target,label): 44 | sigma = tf.constant(1.0) 45 | threshold = 1.0/(sigma**2) 46 | zeros_index = tf.zeros_like(label, dtype=tf.float32) 47 | valid_inds = tf.where(label!=zeros_index,tf.ones_like(label,dtype=tf.float32),zeros_index) 48 | abs_error = tf.abs(bbox_pred-bbox_target) 49 | loss_smaller = 0.5*((abs_error*sigma)**2) 50 | loss_larger = abs_error-0.5/(sigma**2) 51 | smooth_loss = tf.reduce_sum(tf.where(abs_error total_display_str_height: 186 | text_bottom = top 187 | else: 188 | text_bottom = bottom + total_display_str_height 189 | # Reverse list and print from bottom to top. 190 | for display_str in display_str_list[::-1]: 191 | text_width, text_height = font.getsize(display_str) 192 | margin = np.ceil(0.05 * text_height) 193 | draw.rectangle( 194 | [(left, text_bottom - text_height - 2 * margin), (left + text_width, 195 | text_bottom)], 196 | fill=color) 197 | draw.text( 198 | (left + margin, text_bottom - text_height - margin), 199 | display_str, 200 | fill='black', 201 | font=font) 202 | text_bottom -= text_height - 2 * margin 203 | 204 | 205 | def draw_bounding_boxes_on_image_array(image, 206 | boxes, 207 | color='red', 208 | thickness=4, 209 | display_str_list_list=()): 210 | """Draws bounding boxes on image (numpy array). 211 | 212 | Args: 213 | image: a numpy array object. 214 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 215 | The coordinates are in normalized format between [0, 1]. 216 | color: color to draw bounding box. Default is red. 217 | thickness: line thickness. Default value is 4. 218 | display_str_list_list: list of list of strings. 219 | a list of strings for each bounding box. 220 | The reason to pass a list of strings for a 221 | bounding box is that it might contain 222 | multiple labels. 223 | 224 | Raises: 225 | ValueError: if boxes is not a [N, 4] array 226 | """ 227 | image_pil = Image.fromarray(image) 228 | draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, 229 | display_str_list_list) 230 | np.copyto(image, np.array(image_pil)) 231 | 232 | 233 | def draw_bounding_boxes_on_image(image, 234 | boxes, 235 | color='red', 236 | thickness=4, 237 | display_str_list_list=()): 238 | """Draws bounding boxes on image. 239 | 240 | Args: 241 | image: a PIL.Image object. 242 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 
243 | The coordinates are in normalized format between [0, 1]. 244 | color: color to draw bounding box. Default is red. 245 | thickness: line thickness. Default value is 4. 246 | display_str_list_list: list of list of strings. 247 | a list of strings for each bounding box. 248 | The reason to pass a list of strings for a 249 | bounding box is that it might contain 250 | multiple labels. 251 | 252 | Raises: 253 | ValueError: if boxes is not a [N, 4] array 254 | """ 255 | boxes_shape = boxes.shape 256 | if not boxes_shape: 257 | return 258 | if len(boxes_shape) != 2 or boxes_shape[1] != 4: 259 | raise ValueError('Input must be of size [N, 4]') 260 | for i in range(boxes_shape[0]): 261 | display_str_list = () 262 | if display_str_list_list: 263 | display_str_list = display_str_list_list[i] 264 | draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2], 265 | boxes[i, 3], color, thickness, display_str_list) 266 | 267 | 268 | def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs): 269 | return visualize_boxes_and_labels_on_image_array( 270 | image, boxes, classes, scores, category_index=category_index, **kwargs) 271 | 272 | 273 | def _visualize_boxes_and_masks(image, boxes, classes, scores, masks, 274 | category_index, **kwargs): 275 | return visualize_boxes_and_labels_on_image_array( 276 | image, 277 | boxes, 278 | classes, 279 | scores, 280 | category_index=category_index, 281 | instance_masks=masks, 282 | **kwargs) 283 | 284 | 285 | def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints, 286 | category_index, **kwargs): 287 | return visualize_boxes_and_labels_on_image_array( 288 | image, 289 | boxes, 290 | classes, 291 | scores, 292 | category_index=category_index, 293 | keypoints=keypoints, 294 | **kwargs) 295 | 296 | 297 | def _visualize_boxes_and_masks_and_keypoints( 298 | image, boxes, classes, scores, masks, keypoints, category_index, **kwargs): 299 | return visualize_boxes_and_labels_on_image_array( 300 | image, 301 | boxes, 302 | classes, 303 | scores, 304 | category_index=category_index, 305 | instance_masks=masks, 306 | keypoints=keypoints, 307 | **kwargs) 308 | 309 | 310 | def draw_bounding_boxes_on_image_tensors(images, 311 | boxes, 312 | classes, 313 | scores, 314 | category_index, 315 | instance_masks=None, 316 | keypoints=None, 317 | max_boxes_to_draw=20, 318 | min_score_thresh=0.2): 319 | """Draws bounding boxes, masks, and keypoints on batch of image tensors. 320 | 321 | Args: 322 | images: A 4D uint8 image tensor of shape [N, H, W, C]. 323 | boxes: [N, max_detections, 4] float32 tensor of detection boxes. 324 | classes: [N, max_detections] int tensor of detection classes. Note that 325 | classes are 1-indexed. 326 | scores: [N, max_detections] float32 tensor of detection scores. 327 | category_index: a dict that maps integer ids to category dicts. e.g. 328 | {1: {1: 'dog'}, 2: {2: 'cat'}, ...} 329 | instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with 330 | instance masks. 331 | keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2] 332 | with keypoints. 333 | max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20. 334 | min_score_thresh: Minimum score threshold for visualization. Default 0.2. 335 | 336 | Returns: 337 | 4D image tensor of type uint8, with boxes drawn on top. 
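  Example (illustrative sketch; the input tensors and the single-entry
  category index are placeholders, not values used in this project):

    category_index = {1: {'id': 1, 'name': 'face'}}
    annotated = draw_bounding_boxes_on_image_tensors(
        images, boxes, classes, scores, category_index, min_score_thresh=0.5)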
338 | """ 339 | visualization_keyword_args = { 340 | 'use_normalized_coordinates': True, 341 | 'max_boxes_to_draw': max_boxes_to_draw, 342 | 'min_score_thresh': min_score_thresh, 343 | 'agnostic_mode': False, 344 | 'line_thickness': 4 345 | } 346 | 347 | if instance_masks is not None and keypoints is None: 348 | visualize_boxes_fn = functools.partial( 349 | _visualize_boxes_and_masks, 350 | category_index=category_index, 351 | **visualization_keyword_args) 352 | elems = [images, boxes, classes, scores, instance_masks] 353 | elif instance_masks is None and keypoints is not None: 354 | visualize_boxes_fn = functools.partial( 355 | _visualize_boxes_and_keypoints, 356 | category_index=category_index, 357 | **visualization_keyword_args) 358 | elems = [images, boxes, classes, scores, keypoints] 359 | elif instance_masks is not None and keypoints is not None: 360 | visualize_boxes_fn = functools.partial( 361 | _visualize_boxes_and_masks_and_keypoints, 362 | category_index=category_index, 363 | **visualization_keyword_args) 364 | elems = [images, boxes, classes, scores, instance_masks, keypoints] 365 | else: 366 | visualize_boxes_fn = functools.partial( 367 | _visualize_boxes, 368 | category_index=category_index, 369 | **visualization_keyword_args) 370 | elems = [images, boxes, classes, scores] 371 | 372 | def draw_boxes(image_and_detections): 373 | """Draws boxes on image.""" 374 | image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections, 375 | tf.uint8) 376 | return image_with_boxes 377 | 378 | images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False) 379 | return images 380 | 381 | 382 | def draw_side_by_side_evaluation_image(eval_dict, 383 | category_index, 384 | max_boxes_to_draw=20, 385 | min_score_thresh=0.2): 386 | """Creates a side-by-side image with detections and groundtruth. 387 | 388 | Bounding boxes (and instance masks, if available) are visualized on both 389 | subimages. 390 | 391 | Args: 392 | eval_dict: The evaluation dictionary returned by 393 | eval_util.result_dict_for_single_example(). 394 | category_index: A category index (dictionary) produced from a labelmap. 395 | max_boxes_to_draw: The maximum number of boxes to draw for detections. 396 | min_score_thresh: The minimum score threshold for showing detections. 397 | 398 | Returns: 399 | A [1, H, 2 * W, C] uint8 tensor. The subimage on the left corresponds to 400 | detections, while the subimage on the right corresponds to groundtruth. 
401 | """ 402 | detection_fields = fields.DetectionResultFields() 403 | input_data_fields = fields.InputDataFields() 404 | instance_masks = None 405 | if detection_fields.detection_masks in eval_dict: 406 | instance_masks = tf.cast( 407 | tf.expand_dims(eval_dict[detection_fields.detection_masks], axis=0), 408 | tf.uint8) 409 | keypoints = None 410 | if detection_fields.detection_keypoints in eval_dict: 411 | keypoints = tf.expand_dims( 412 | eval_dict[detection_fields.detection_keypoints], axis=0) 413 | groundtruth_instance_masks = None 414 | if input_data_fields.groundtruth_instance_masks in eval_dict: 415 | groundtruth_instance_masks = tf.cast( 416 | tf.expand_dims( 417 | eval_dict[input_data_fields.groundtruth_instance_masks], axis=0), 418 | tf.uint8) 419 | images_with_detections = draw_bounding_boxes_on_image_tensors( 420 | eval_dict[input_data_fields.original_image], 421 | tf.expand_dims(eval_dict[detection_fields.detection_boxes], axis=0), 422 | tf.expand_dims(eval_dict[detection_fields.detection_classes], axis=0), 423 | tf.expand_dims(eval_dict[detection_fields.detection_scores], axis=0), 424 | category_index, 425 | instance_masks=instance_masks, 426 | keypoints=keypoints, 427 | max_boxes_to_draw=max_boxes_to_draw, 428 | min_score_thresh=min_score_thresh) 429 | images_with_groundtruth = draw_bounding_boxes_on_image_tensors( 430 | eval_dict[input_data_fields.original_image], 431 | tf.expand_dims(eval_dict[input_data_fields.groundtruth_boxes], axis=0), 432 | tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], axis=0), 433 | tf.expand_dims( 434 | tf.ones_like( 435 | eval_dict[input_data_fields.groundtruth_classes], 436 | dtype=tf.float32), 437 | axis=0), 438 | category_index, 439 | instance_masks=groundtruth_instance_masks, 440 | keypoints=None, 441 | max_boxes_to_draw=None, 442 | min_score_thresh=0.0) 443 | return tf.concat([images_with_detections, images_with_groundtruth], axis=2) 444 | 445 | 446 | def draw_keypoints_on_image_array(image, 447 | keypoints, 448 | color='red', 449 | radius=2, 450 | use_normalized_coordinates=True): 451 | """Draws keypoints on an image (numpy array). 452 | 453 | Args: 454 | image: a numpy array with shape [height, width, 3]. 455 | keypoints: a numpy array with shape [num_keypoints, 2]. 456 | color: color to draw the keypoints with. Default is red. 457 | radius: keypoint radius. Default value is 2. 458 | use_normalized_coordinates: if True (default), treat keypoint values as 459 | relative to the image. Otherwise treat them as absolute. 460 | """ 461 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 462 | draw_keypoints_on_image(image_pil, keypoints, color, radius, 463 | use_normalized_coordinates) 464 | np.copyto(image, np.array(image_pil)) 465 | 466 | 467 | def draw_keypoints_on_image(image, 468 | keypoints, 469 | color='red', 470 | radius=2, 471 | use_normalized_coordinates=True): 472 | """Draws keypoints on an image. 473 | 474 | Args: 475 | image: a PIL.Image object. 476 | keypoints: a numpy array with shape [num_keypoints, 2]. 477 | color: color to draw the keypoints with. Default is red. 478 | radius: keypoint radius. Default value is 2. 479 | use_normalized_coordinates: if True (default), treat keypoint values as 480 | relative to the image. Otherwise treat them as absolute. 
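  Example (illustrative; the image size and the keypoint values are placeholders):

    image = Image.new('RGB', (200, 100))
    draw_keypoints_on_image(image, [[0.5, 0.25], [0.5, 0.75]], color='blue', radius=3)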
481 | """ 482 | draw = ImageDraw.Draw(image) 483 | im_width, im_height = image.size 484 | keypoints_x = [k[1] for k in keypoints] 485 | keypoints_y = [k[0] for k in keypoints] 486 | if use_normalized_coordinates: 487 | keypoints_x = tuple([im_width * x for x in keypoints_x]) 488 | keypoints_y = tuple([im_height * y for y in keypoints_y]) 489 | for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): 490 | draw.ellipse([(keypoint_x - radius, keypoint_y - radius), 491 | (keypoint_x + radius, keypoint_y + radius)], 492 | outline=color, fill=color) 493 | 494 | 495 | def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): 496 | """Draws mask on an image. 497 | 498 | Args: 499 | image: uint8 numpy array with shape (img_height, img_height, 3) 500 | mask: a uint8 numpy array of shape (img_height, img_height) with 501 | values between either 0 or 1. 502 | color: color to draw the keypoints with. Default is red. 503 | alpha: transparency value between 0 and 1. (default: 0.4) 504 | 505 | Raises: 506 | ValueError: On incorrect data type for image or masks. 507 | """ 508 | if image.dtype != np.uint8: 509 | raise ValueError('`image` not of type np.uint8') 510 | if mask.dtype != np.uint8: 511 | raise ValueError('`mask` not of type np.uint8') 512 | if np.any(np.logical_and(mask != 1, mask != 0)): 513 | raise ValueError('`mask` elements should be in [0, 1]') 514 | if image.shape[:2] != mask.shape: 515 | raise ValueError('The image has spatial dimensions %s but the mask has ' 516 | 'dimensions %s' % (image.shape[:2], mask.shape)) 517 | rgb = ImageColor.getrgb(color) 518 | pil_image = Image.fromarray(image) 519 | 520 | solid_color = np.expand_dims( 521 | np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) 522 | pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') 523 | pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L') 524 | pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) 525 | np.copyto(image, np.array(pil_image.convert('RGB'))) 526 | 527 | 528 | def visualize_boxes_and_labels_on_image_array( 529 | image, 530 | boxes, 531 | classes, 532 | scores, 533 | category_index, 534 | instance_masks=None, 535 | instance_boundaries=None, 536 | keypoints=None, 537 | use_normalized_coordinates=False, 538 | max_boxes_to_draw=20, 539 | min_score_thresh=.5, 540 | agnostic_mode=False, 541 | line_thickness=4, 542 | groundtruth_box_visualization_color='black', 543 | skip_scores=False, 544 | skip_labels=False): 545 | """Overlay labeled boxes on an image with formatted scores and label names. 546 | 547 | This function groups boxes that correspond to the same location 548 | and creates a display string for each detection and overlays these 549 | on the image. Note that this function modifies the image in place, and returns 550 | that same image. 551 | 552 | Args: 553 | image: uint8 numpy array with shape (img_height, img_width, 3) 554 | boxes: a numpy array of shape [N, 4] 555 | classes: a numpy array of shape [N]. Note that class indices are 1-based, 556 | and match the keys in the label map. 557 | scores: a numpy array of shape [N] or None. If scores=None, then 558 | this function assumes that the boxes to be plotted are groundtruth 559 | boxes and plot all boxes as black with no classes or scores. 560 | category_index: a dict containing category dictionaries (each holding 561 | category index `id` and category name `name`) keyed by category indices. 
562 | instance_masks: a numpy array of shape [N, image_height, image_width] with 563 | values ranging between 0 and 1, can be None. 564 | instance_boundaries: a numpy array of shape [N, image_height, image_width] 565 | with values ranging between 0 and 1, can be None. 566 | keypoints: a numpy array of shape [N, num_keypoints, 2], can 567 | be None 568 | use_normalized_coordinates: whether boxes is to be interpreted as 569 | normalized coordinates or not. 570 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw 571 | all boxes. 572 | min_score_thresh: minimum score threshold for a box to be visualized 573 | agnostic_mode: boolean (default: False) controlling whether to evaluate in 574 | class-agnostic mode or not. This mode will display scores but ignore 575 | classes. 576 | line_thickness: integer (default: 4) controlling line width of the boxes. 577 | groundtruth_box_visualization_color: box color for visualizing groundtruth 578 | boxes 579 | skip_scores: whether to skip score when drawing a single detection 580 | skip_labels: whether to skip label when drawing a single detection 581 | 582 | Returns: 583 | uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes. 584 | """ 585 | # Create a display string (and color) for every box location, group any boxes 586 | # that correspond to the same location. 587 | box_to_display_str_map = collections.defaultdict(list) 588 | box_to_color_map = collections.defaultdict(str) 589 | box_to_instance_masks_map = {} 590 | box_to_instance_boundaries_map = {} 591 | box_to_keypoints_map = collections.defaultdict(list) 592 | if not max_boxes_to_draw: 593 | max_boxes_to_draw = boxes.shape[0] 594 | for i in range(min(max_boxes_to_draw, boxes.shape[0])): 595 | if scores is None or scores[i] > min_score_thresh: 596 | box = tuple(boxes[i].tolist()) 597 | if instance_masks is not None: 598 | box_to_instance_masks_map[box] = instance_masks[i] 599 | if instance_boundaries is not None: 600 | box_to_instance_boundaries_map[box] = instance_boundaries[i] 601 | if keypoints is not None: 602 | box_to_keypoints_map[box].extend(keypoints[i]) 603 | if scores is None: 604 | box_to_color_map[box] = groundtruth_box_visualization_color 605 | else: 606 | display_str = '' 607 | if not skip_labels: 608 | if not agnostic_mode: 609 | if classes[i] in category_index.keys(): 610 | class_name = category_index[classes[i]]['name'] 611 | else: 612 | class_name = 'N/A' 613 | display_str = str(class_name) 614 | if not skip_scores: 615 | if not display_str: 616 | display_str = '{}%'.format(int(100*scores[i])) 617 | else: 618 | display_str = '{}: {}%'.format(display_str, int(100*scores[i])) 619 | box_to_display_str_map[box].append(display_str) 620 | if agnostic_mode: 621 | box_to_color_map[box] = 'DarkOrange' 622 | else: 623 | box_to_color_map[box] = STANDARD_COLORS[ 624 | classes[i] % len(STANDARD_COLORS)] 625 | 626 | # Draw all boxes onto image. 
627 | for box, color in box_to_color_map.items(): 628 | ymin, xmin, ymax, xmax = box 629 | if instance_masks is not None: 630 | draw_mask_on_image_array( 631 | image, 632 | box_to_instance_masks_map[box], 633 | color=color 634 | ) 635 | if instance_boundaries is not None: 636 | draw_mask_on_image_array( 637 | image, 638 | box_to_instance_boundaries_map[box], 639 | color='red', 640 | alpha=1.0 641 | ) 642 | draw_bounding_box_on_image_array( 643 | image, 644 | ymin, 645 | xmin, 646 | ymax, 647 | xmax, 648 | color=color, 649 | thickness=line_thickness, 650 | display_str_list=box_to_display_str_map[box], 651 | use_normalized_coordinates=use_normalized_coordinates) 652 | if keypoints is not None: 653 | draw_keypoints_on_image_array( 654 | image, 655 | box_to_keypoints_map[box], 656 | color=color, 657 | radius=line_thickness / 2, 658 | use_normalized_coordinates=use_normalized_coordinates) 659 | 660 | return image 661 | 662 | 663 | def add_cdf_image_summary(values, name): 664 | """Adds a tf.summary.image for a CDF plot of the values. 665 | 666 | Normalizes `values` such that they sum to 1, plots the cumulative distribution 667 | function and creates a tf image summary. 668 | 669 | Args: 670 | values: a 1-D float32 tensor containing the values. 671 | name: name for the image summary. 672 | """ 673 | def cdf_plot(values): 674 | """Numpy function to plot CDF.""" 675 | normalized_values = values / np.sum(values) 676 | sorted_values = np.sort(normalized_values) 677 | cumulative_values = np.cumsum(sorted_values) 678 | fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32) 679 | / cumulative_values.size) 680 | fig = plt.figure(frameon=False) 681 | ax = fig.add_subplot('111') 682 | ax.plot(fraction_of_examples, cumulative_values) 683 | ax.set_ylabel('cumulative normalized values') 684 | ax.set_xlabel('fraction of examples') 685 | fig.canvas.draw() 686 | width, height = fig.get_size_inches() * fig.get_dpi() 687 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( 688 | 1, int(height), int(width), 3) 689 | return image 690 | cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) 691 | tf.summary.image(name, cdf_plot) 692 | 693 | 694 | def add_hist_image_summary(values, bins, name): 695 | """Adds a tf.summary.image for a histogram plot of the values. 696 | 697 | Plots the histogram of values and creates a tf image summary. 698 | 699 | Args: 700 | values: a 1-D float32 tensor containing the values. 701 | bins: bin edges which will be directly passed to np.histogram. 702 | name: name for the image summary. 703 | """ 704 | 705 | def hist_plot(values, bins): 706 | """Numpy function to plot hist.""" 707 | fig = plt.figure(frameon=False) 708 | ax = fig.add_subplot('111') 709 | y, x = np.histogram(values, bins=bins) 710 | ax.plot(x[:-1], y) 711 | ax.set_ylabel('count') 712 | ax.set_xlabel('value') 713 | fig.canvas.draw() 714 | width, height = fig.get_size_inches() * fig.get_dpi() 715 | image = np.fromstring( 716 | fig.canvas.tostring_rgb(), dtype='uint8').reshape( 717 | 1, int(height), int(width), 3) 718 | return image 719 | hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) 720 | tf.summary.image(name, hist_plot) 721 | --------------------------------------------------------------------------------
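A note on the distance-estimation helper shared by test_flight.py, test_flight_rotate.py and test_video.py: the body of distance_to_camera() returns `(initial_width * focalLength) / virtual_width`, but `focalLength` is never defined (the parameter is spelled `focal_length`), so the call in test_video.py raises a NameError; test_video.py also uses `dist.euclidean` without importing it. A corrected sketch, keeping the scripts' own names and the pinhole-camera assumption stated in their comments (focal length F = (P * D) / W, with the real width W assumed equal to the reference distance D):

    from scipy.spatial import distance as dist  # provides dist.euclidean used in test_video.py

    def distance_to_camera(initial_width, focal_length, virtual_width):
        # similar-triangles estimate: distance = (real width * focal length) / perceived width
        return (initial_width * focal_length) / virtual_width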