├── .gitignore ├── README.md ├── imgs ├── simple检测结果.gif ├── 标注示例.png └── 物体检测结果.gif ├── label.zip ├── simple ├── simple.js ├── simple.py └── simple_ios.py ├── tensorflow ├── frozen_inference_graph_frcnn_inception_v2_coco.pb ├── utils │ ├── __pycache__ │ │ ├── label_map_util.cpython-36.pyc │ │ └── visualization_utils.cpython-36.pyc │ ├── dataset_util.py │ ├── label_map_util.py │ └── visualization_utils.py ├── wechat_auto_jump.py └── wechat_jump_label_map.pbtxt └── 物体检测标注说明.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | simple/screenshot.png 3 | simple/detection.png 4 | tensorflow/screenshot.png 5 | tensorflow/detection.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 深度学习 - 微信跳一跳 2 | 3 | ### 2018.01.14更新 4 | 5 | `simple`目录下增加了`simple.js`,思路同`simple.py`,使用`JavaScript`编写,在安卓上安装`Auto.js`之后运行该脚本即可,好处是直接在手机上运行,不需要连电脑 6 | 7 | ### 2018.01.05更新 8 | 9 | 标注数据增加到1200张图片,并且用更准的`faster_rcnn_inception_v2_coco`模型重新训练了一遍 10 | 11 | ### 项目介绍 12 | 13 | 知乎文章:[https://zhuanlan.zhihu.com/p/32553763](https://zhuanlan.zhihu.com/p/32553763) 14 | 15 | 感谢[Chao](https://github.com/loveu520)、[奋逗逗](https://github.com/liuzhenhui)对于标注数据做出的贡献 16 | 17 | ### 所需环境 18 | 19 | - `Python3.6`、`OpenCV2`、`TensorFlow`等 20 | - `adb`,用于调试安卓手机,参考[https://github.com/wangshub/wechat_jump_game](https://github.com/wangshub/wechat_jump_game) 21 | 22 | ### 文件介绍 23 | 24 | `simple`目录中的`simple.py`使用`OpenCV2`检测棋子和目标块的位置,简单粗暴,`simple_ios.py`是对应的IOS版本 25 | 26 | ![simple检测结果](imgs/simple检测结果.gif) 27 | 28 | `tensorflow`目录包括以下文件: 29 | 30 | - `wechat_jump_label_map.pbtxt`:物体类别映射文件; 31 | - `utils`:提供辅助功能的文件; 32 | - `frozen_inference_graph_frcnn_inception_v2_coco.pb`:训练好的物体检测模型,共1200张标注数据,使用`faster_rcnn_inception_v2_coco`训练; 33 | - `wechat_auto_jump.py`:自动跳一跳的代码 34 | 35 | ![物体检测结果](imgs/物体检测结果.gif) 36 | 37 | `label.zip`提供了标注的工具,使用[labelImg](https://github.com/tzutalin/labelImg)进行物体检测标注,使用方法可以参考`物体检测标注说明.pdf` 38 | 39 | ![labelImg标注示例](imgs/标注示例.png) -------------------------------------------------------------------------------- /imgs/simple检测结果.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/imgs/simple检测结果.gif -------------------------------------------------------------------------------- /imgs/标注示例.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/imgs/标注示例.png -------------------------------------------------------------------------------- /imgs/物体检测结果.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/imgs/物体检测结果.gif -------------------------------------------------------------------------------- /label.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/label.zip -------------------------------------------------------------------------------- /simple/simple.js: -------------------------------------------------------------------------------- 1 | console.show(); 2 | 3 | // 设备信息 4 | var WIDTH = device.width, 5 | HEIGHT = device.height, 6 | TYPE = device.brand + ' ' + device.model; 7 | log('设备信息:', TYPE, '\n分辨率:', WIDTH, '*', HEIGHT); 8 | 9 | // 获取截图权限 10 | if (!requestScreenCapture()) { 11 | toast('请求截图失败,程序结束'); 12 | exit(); 13 | } 14 | 15 | // 启动微信 16 | launchApp('微信'); 17 | 18 | // 提示用户进入跳一跳页面 19 | new java.lang.Thread(function() { 20 | packageName('com.stardust.scriptdroid').className('android.widget.EditText').setText('准备好后点击 确定'); 21 | }).start(); 22 | console.rawInput('进入微信跳一跳,点击 开始游戏\n点击 确定 开始自动游戏'); 23 | 24 | do { 25 | // 获取截图 26 | var img = captureScreen(); 27 | 28 | // 触按位置 29 | var bx1 = parseInt(WIDTH / 2 + random(-10, 10)), 30 | bx2 = parseInt(WIDTH / 2 + random(-10, 10)), 31 | by1 = parseInt(HEIGHT * 0.785 + random(-4, 4)), 32 | by2 = parseInt(HEIGHT * 0.785 + random(-4, 4)); 33 | 34 | // 棋子底部中心找色 35 | var CHESS_X, CHESS_Y; 36 | var linemax = 0; 37 | for (let r = parseInt(HEIGHT * 0.7); r > parseInt(HEIGHT * 0.5);) { 38 | var line = []; 39 | for (let c = parseInt(WIDTH * 0.15); c < parseInt(WIDTH * 0.85); c++) { 40 | var point = images.pixel(img, c, r); 41 | var red = colors.red(point), 42 | green = colors.green(point), 43 | blue = colors.blue(point); 44 | if (red >= 40 && red <= 70 && green >= 40 && green <= 60 && blue >= 70 && blue <= 105) { 45 | line.push(c); 46 | } 47 | } 48 | if (line.length > linemax) { 49 | linemax = line.length; 50 | CHESS_X = line[Math.floor(line.length / 2)]; 51 | CHESS_Y = r; 52 | } 53 | else if (line.length < linemax) { 54 | break; 55 | } 56 | r -= 5; 57 | } 58 | log('棋子X坐标:', CHESS_X); 59 | 60 | // 目标块顶部中心X坐标 61 | var TARGET_X, TARGET_Y; 62 | for (let r = parseInt(HEIGHT * 0.3); r <= parseInt(HEIGHT * 0.5);) { 63 | var flag = false; 64 | for (let c = parseInt(WIDTH * 0.15); c < parseInt(WIDTH * 0.85); c++) { 65 | if (Math.abs(c - CHESS_X) <= linemax) { 66 | continue 67 | } 68 | var c0 = images.pixel(img, c, r); 69 | var c1 = images.pixel(img, c, r - 5); 70 | if (Math.abs(colors.red(c0) - colors.red(c1)) + Math.abs(colors.green(c0) - colors.green(c1)) + Math.abs(colors.blue(c0) - colors.blue(c1)) >= 30) { 71 | TARGET_X = c; 72 | TARGET_Y = r; 73 | flag = true; 74 | break; 75 | } 76 | } 77 | if (flag) { 78 | break; 79 | } 80 | r += 5; 81 | } 82 | // 寻找白点 83 | var whitepoint = images.findColor(img, '#f5f5f5', { 84 | region: [TARGET_X - 20, TARGET_Y, 40, 250], 85 | threshold: 2 86 | }); 87 | if (whitepoint) { 88 | TARGET_X = whitepoint.x; 89 | } 90 | log('目标块X坐标:', TARGET_X); 91 | 92 | // 跳! 93 | swipe(bx1, by1, bx2, by2, Math.abs(CHESS_X - TARGET_X) / WIDTH * 1900); 94 | sleep(random(1500, 2000)); 95 | } while (true); -------------------------------------------------------------------------------- /simple/simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import cv2 5 | import os 6 | import time 7 | import re 8 | 9 | # 屏幕截图 10 | def pull_screenshot(path): 11 | os.system('adb shell screencap -p /sdcard/%s' % path) 12 | os.system('adb pull /sdcard/%s .' % path) 13 | 14 | # 根据x距离跳跃 15 | def jump(distance, alpha): 16 | press_time = max(int(distance * alpha), 200) 17 | 18 | cmd = 'adb shell input swipe {} {} {} {} {}'.format(bx1, by1, bx2, by2, press_time) 19 | os.system(cmd) 20 | 21 | screenshot = 'screenshot.png' 22 | alpha = 0 23 | bx1, by1, bx2, by2 = 0, 0, 0, 0 24 | chess_x = 0 25 | target_x = 0 26 | 27 | fix = 1.6667 28 | # 检查分辨率是否是960x540 29 | size_str = os.popen('adb shell wm size').read() 30 | if size_str: 31 | m = re.search(r'(\d+)x(\d+)', size_str) 32 | if m: 33 | hxw = "{height}x{width}".format(height=m.group(2), width=m.group(1)) 34 | if hxw == "960x540": 35 | fix = 3.16 36 | 37 | while True: 38 | pull_screenshot(screenshot) 39 | image_np = cv2.imread(screenshot) 40 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 41 | gray = cv2.Canny(image_np, 20, 80) 42 | 43 | HEIGHT = image_np.shape[0] 44 | WIDTH = image_np.shape[1] 45 | 46 | bx1 = WIDTH / 2 + int(np.random.rand() * 10 - 5) 47 | bx2 = WIDTH / 2 + int(np.random.rand() * 10 - 5) 48 | by1 = HEIGHT * 0.785 + int(np.random.rand() * 4 - 2) 49 | by2 = HEIGHT * 0.785 + int(np.random.rand() * 4 - 2) 50 | alpha = WIDTH * fix 51 | 52 | # 获取棋子x坐标 53 | linemax = [] 54 | for i in range(int(HEIGHT * 0.4), int(HEIGHT * 0.6)): 55 | line = [] 56 | for j in range(int(WIDTH * 0.15), int(WIDTH * 0.85)): 57 | if image_np[i, j, 0] > 40 and image_np[i, j, 0] < 70 and image_np[i, j, 1] > 40 and image_np[i, j, 1] < 70 and image_np[i, j, 2] > 60 and image_np[i, j, 2] < 110: 58 | gray[i, j] = 255 59 | if len(line) > 0 and j - line[-1] > 1: 60 | break 61 | else: 62 | line.append(j) 63 | 64 | if len(line) > 5 and len(line) > len(linemax): 65 | linemax = line 66 | if len(linemax) > 20 and len(line) == 0: 67 | break 68 | 69 | chess_x = int(np.mean(linemax)) 70 | 71 | # 获取目标x坐标 72 | for i in range(int(HEIGHT * 0.3), int(HEIGHT * 0.5)): 73 | flag = False 74 | for j in range(WIDTH): 75 | # 超过朋友时棋子上方的图案 76 | if np.abs(j - chess_x) < len(linemax): 77 | continue 78 | if not gray[i, j] == 0: 79 | target_x = j 80 | flag = True 81 | break 82 | if flag: 83 | break 84 | 85 | # 修改检测图 86 | gray[:, chess_x] = 255 87 | gray[:, target_x] = 255 88 | # 保存检测图 89 | cv2.imwrite('detection.png', gray) 90 | 91 | print(chess_x, target_x) 92 | jump(float(np.abs(chess_x - target_x)) / WIDTH, alpha) 93 | 94 | # 等棋子落稳 95 | time.sleep(np.random.random() + 1) 96 | 97 | -------------------------------------------------------------------------------- /simple/simple_ios.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # 已测试MacOS + IOS,要正确运行下面代码,请先配置: 4 | # 使用开发者账号Xcode真机调试WDA(WebDriverAgent) 5 | # 参考网址 https://testerhome.com/topics/7220 6 | # 安装openatx/facebook-wda支持使用Python调用WDA 7 | # 参考网址 https://github.com/openatx/facebook-wda 8 | 9 | import numpy as np 10 | import cv2 11 | import time 12 | import wda 13 | 14 | # iPhone 6s 按压时间参数修正,其它型号iPhone请自行修改 15 | fixtime = 2.255 16 | 17 | c = wda.Client() 18 | s = c.session() 19 | 20 | screenshot = 'jump_ios.png' 21 | 22 | # 屏幕截图 23 | def pull_screenshot(): 24 | c.screenshot(screenshot) 25 | 26 | # 根据x距离跳跃 27 | def jump(distance, alpha): 28 | press_time = max(int(distance * alpha), 200) / 1000.0 29 | print('press time: {}'.format(press_time)) 30 | s.tap_hold(200, 200, press_time) 31 | 32 | alpha = 0 33 | bx1, by1, bx2, by2 = 0, 0, 0, 0 34 | chess_x = 0 35 | target_x = 0 36 | 37 | while True: 38 | pull_screenshot() 39 | image_np = cv2.imread(screenshot) 40 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 41 | gray = cv2.Canny(image_np, 20, 80) 42 | 43 | HEIGHT = image_np.shape[0] 44 | WIDTH = image_np.shape[1] 45 | 46 | bx1 = WIDTH / 2 + int(np.random.rand() * 10 - 5) 47 | bx2 = WIDTH / 2 + int(np.random.rand() * 10 - 5) 48 | by1 = HEIGHT * 0.785 + int(np.random.rand() * 4 - 2) 49 | by2 = HEIGHT * 0.785 + int(np.random.rand() * 4 - 2) 50 | alpha = WIDTH * fixtime 51 | 52 | # 获取棋子x坐标 53 | linemax = [] 54 | for i in range(int(HEIGHT * 0.4), int(HEIGHT * 0.6)): 55 | line = [] 56 | for j in range(int(WIDTH * 0.15), int(WIDTH * 0.85)): 57 | if image_np[i, j, 0] > 40 and image_np[i, j, 0] < 70 and image_np[i, j, 1] > 40 and image_np[i, j, 1] < 70 and image_np[i, j, 2] > 60 and image_np[i, j, 2] < 110: 58 | gray[i, j] = 255 59 | if len(line) > 0 and j - line[-1] > 1: 60 | break 61 | else: 62 | line.append(j) 63 | 64 | if len(line) > 5 and len(line) > len(linemax): 65 | linemax = line 66 | 67 | if len(linemax) > 50 and len(line) == 0: 68 | break 69 | 70 | chess_x = int(np.mean(linemax)) 71 | 72 | # 获取目标x坐标 73 | for i in range(int(HEIGHT * 0.3), int(HEIGHT * 0.5)): 74 | flag = False 75 | for j in range(WIDTH): 76 | # 超过朋友时棋子上方的图案 77 | if np.abs(j - chess_x) < len(linemax): 78 | continue 79 | if not gray[i, j] == 0: 80 | target_x = j 81 | flag = True 82 | break 83 | if flag: 84 | break 85 | 86 | # 修改检测图 87 | gray[:, chess_x] = 255 88 | gray[:, target_x] = 255 89 | # 保存检测图 90 | cv2.imwrite('detection_ios.png', gray) 91 | 92 | print(chess_x, target_x) 93 | jump(float(np.abs(chess_x - target_x)) / WIDTH, alpha) 94 | 95 | # 等棋子落稳 96 | time.sleep(np.random.random() + 1.4) 97 | -------------------------------------------------------------------------------- /tensorflow/frozen_inference_graph_frcnn_inception_v2_coco.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/tensorflow/frozen_inference_graph_frcnn_inception_v2_coco.pb -------------------------------------------------------------------------------- /tensorflow/utils/__pycache__/label_map_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/tensorflow/utils/__pycache__/label_map_util.cpython-36.pyc -------------------------------------------------------------------------------- /tensorflow/utils/__pycache__/visualization_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/tensorflow/utils/__pycache__/visualization_utils.cpython-36.pyc -------------------------------------------------------------------------------- /tensorflow/utils/dataset_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions for creating TFRecord data sets.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def int64_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 23 | 24 | 25 | def int64_list_feature(value): 26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 27 | 28 | 29 | def bytes_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | 33 | def bytes_list_feature(value): 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | def read_examples_list(path): 42 | """Read list of training or validation examples. 43 | 44 | The file is assumed to contain a single example per line where the first 45 | token in the line is an identifier that allows us to find the image and 46 | annotation xml for that example. 47 | 48 | For example, the line: 49 | xyz 3 50 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored). 51 | 52 | Args: 53 | path: absolute path to examples list file. 54 | 55 | Returns: 56 | list of example identifiers (strings). 57 | """ 58 | with tf.gfile.GFile(path) as fid: 59 | lines = fid.readlines() 60 | return [line.strip().split(' ')[0] for line in lines] 61 | 62 | 63 | def recursive_parse_xml_to_dict(xml): 64 | """Recursively parses XML contents to python dict. 65 | 66 | We assume that `object` tags are the only ones that can appear 67 | multiple times at the same level of a tree. 68 | 69 | Args: 70 | xml: xml tree obtained by parsing XML file contents using lxml.etree 71 | 72 | Returns: 73 | Python dictionary holding XML contents. 74 | """ 75 | if not xml: 76 | return {xml.tag: xml.text} 77 | result = {} 78 | for child in xml: 79 | child_result = recursive_parse_xml_to_dict(child) 80 | if child.tag != 'object': 81 | result[child.tag] = child_result[child.tag] 82 | else: 83 | if child.tag not in result: 84 | result[child.tag] = [] 85 | result[child.tag].append(child_result[child.tag]) 86 | return {xml.tag: result} 87 | -------------------------------------------------------------------------------- /tensorflow/utils/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Label map utility functions.""" 17 | 18 | import logging 19 | 20 | import tensorflow as tf 21 | from google.protobuf import text_format 22 | # from object_detection.protos import string_int_label_map_pb2 23 | 24 | 25 | def _validate_label_map(label_map): 26 | """Checks if a label map is valid. 27 | 28 | Args: 29 | label_map: StringIntLabelMap to validate. 30 | 31 | Raises: 32 | ValueError: if label map is invalid. 33 | """ 34 | for item in label_map.item: 35 | if item.id < 1: 36 | raise ValueError('Label map ids should be >= 1.') 37 | 38 | 39 | def create_category_index(categories): 40 | """Creates dictionary of COCO compatible categories keyed by category id. 41 | 42 | Args: 43 | categories: a list of dicts, each of which has the following keys: 44 | 'id': (required) an integer id uniquely identifying this category. 45 | 'name': (required) string representing category name 46 | e.g., 'cat', 'dog', 'pizza'. 47 | 48 | Returns: 49 | category_index: a dict containing the same entries as categories, but keyed 50 | by the 'id' field of each category. 51 | """ 52 | category_index = {} 53 | for cat in categories: 54 | category_index[cat['id']] = cat 55 | return category_index 56 | 57 | 58 | def convert_label_map_to_categories(label_map, 59 | max_num_classes, 60 | use_display_name=True): 61 | """Loads label map proto and returns categories list compatible with eval. 62 | 63 | This function loads a label map and returns a list of dicts, each of which 64 | has the following keys: 65 | 'id': (required) an integer id uniquely identifying this category. 66 | 'name': (required) string representing category name 67 | e.g., 'cat', 'dog', 'pizza'. 68 | We only allow class into the list if its id-label_id_offset is 69 | between 0 (inclusive) and max_num_classes (exclusive). 70 | If there are several items mapping to the same id in the label map, 71 | we will only keep the first one in the categories list. 72 | 73 | Args: 74 | label_map: a StringIntLabelMapProto or None. If None, a default categories 75 | list is created with max_num_classes categories. 76 | max_num_classes: maximum number of (consecutive) label indices to include. 77 | use_display_name: (boolean) choose whether to load 'display_name' field 78 | as category name. If False or if the display_name field does not exist, 79 | uses 'name' field as category names instead. 80 | Returns: 81 | categories: a list of dictionaries representing all possible categories. 82 | """ 83 | categories = [] 84 | list_of_ids_already_added = [] 85 | if not label_map: 86 | label_id_offset = 1 87 | for class_id in range(max_num_classes): 88 | categories.append({ 89 | 'id': class_id + label_id_offset, 90 | 'name': 'category_{}'.format(class_id + label_id_offset) 91 | }) 92 | return categories 93 | for item in label_map.item: 94 | if not 0 < item.id <= max_num_classes: 95 | logging.info('Ignore item %d since it falls outside of requested ' 96 | 'label range.', item.id) 97 | continue 98 | if use_display_name and item.HasField('display_name'): 99 | name = item.display_name 100 | else: 101 | name = item.name 102 | if item.id not in list_of_ids_already_added: 103 | list_of_ids_already_added.append(item.id) 104 | categories.append({'id': item.id, 'name': name}) 105 | return categories 106 | 107 | 108 | def load_labelmap(path): 109 | """Loads label map proto. 110 | 111 | Args: 112 | path: path to StringIntLabelMap proto text file. 113 | Returns: 114 | a StringIntLabelMapProto 115 | """ 116 | with tf.gfile.GFile(path, 'r') as fid: 117 | label_map_string = fid.read() 118 | label_map = StringIntLabelMap() 119 | try: 120 | text_format.Merge(label_map_string, label_map) 121 | except text_format.ParseError: 122 | label_map.ParseFromString(label_map_string) 123 | _validate_label_map(label_map) 124 | return label_map 125 | 126 | 127 | def get_label_map_dict(label_map_path, use_display_name=False): 128 | """Reads a label map and returns a dictionary of label names to id. 129 | 130 | Args: 131 | label_map_path: path to label_map. 132 | use_display_name: whether to use the label map items' display names as keys. 133 | 134 | Returns: 135 | A dictionary mapping label names to id. 136 | """ 137 | label_map = load_labelmap(label_map_path) 138 | label_map_dict = {} 139 | for item in label_map.item: 140 | if use_display_name: 141 | label_map_dict[item.display_name] = item.id 142 | else: 143 | label_map_dict[item.name] = item.id 144 | return label_map_dict 145 | 146 | 147 | def create_category_index_from_labelmap(label_map_path): 148 | """Reads a label map and returns a category index. 149 | 150 | Args: 151 | label_map_path: Path to `StringIntLabelMap` proto text file. 152 | 153 | Returns: 154 | A category index, which is a dictionary that maps integer ids to dicts 155 | containing categories, e.g. 156 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} 157 | """ 158 | label_map = load_labelmap(label_map_path) 159 | max_num_classes = max(item.id for item in label_map.item) 160 | categories = convert_label_map_to_categories(label_map, max_num_classes) 161 | return create_category_index(categories) 162 | 163 | 164 | def create_class_agnostic_category_index(): 165 | """Creates a category index with a single `object` class.""" 166 | return {1: {'id': 1, 'name': 'object'}} 167 | 168 | # Generated by the protocol buffer compiler. DO NOT EDIT! 169 | # source: object_detection/protos/string_int_label_map.proto 170 | 171 | import sys 172 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 173 | from google.protobuf import descriptor as _descriptor 174 | from google.protobuf import message as _message 175 | from google.protobuf import reflection as _reflection 176 | from google.protobuf import symbol_database as _symbol_database 177 | from google.protobuf import descriptor_pb2 178 | # @@protoc_insertion_point(imports) 179 | 180 | _sym_db = _symbol_database.Default() 181 | 182 | 183 | 184 | 185 | DESCRIPTOR = _descriptor.FileDescriptor( 186 | name='object_detection/protos/string_int_label_map.proto', 187 | package='object_detection.protos', 188 | syntax='proto2', 189 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') 190 | ) 191 | 192 | 193 | 194 | 195 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 196 | name='StringIntLabelMapItem', 197 | full_name='object_detection.protos.StringIntLabelMapItem', 198 | filename=None, 199 | file=DESCRIPTOR, 200 | containing_type=None, 201 | fields=[ 202 | _descriptor.FieldDescriptor( 203 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, 204 | number=1, type=9, cpp_type=9, label=1, 205 | has_default_value=False, default_value=_b("").decode('utf-8'), 206 | message_type=None, enum_type=None, containing_type=None, 207 | is_extension=False, extension_scope=None, 208 | options=None, file=DESCRIPTOR), 209 | _descriptor.FieldDescriptor( 210 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, 211 | number=2, type=5, cpp_type=1, label=1, 212 | has_default_value=False, default_value=0, 213 | message_type=None, enum_type=None, containing_type=None, 214 | is_extension=False, extension_scope=None, 215 | options=None, file=DESCRIPTOR), 216 | _descriptor.FieldDescriptor( 217 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, 218 | number=3, type=9, cpp_type=9, label=1, 219 | has_default_value=False, default_value=_b("").decode('utf-8'), 220 | message_type=None, enum_type=None, containing_type=None, 221 | is_extension=False, extension_scope=None, 222 | options=None, file=DESCRIPTOR), 223 | ], 224 | extensions=[ 225 | ], 226 | nested_types=[], 227 | enum_types=[ 228 | ], 229 | options=None, 230 | is_extendable=False, 231 | syntax='proto2', 232 | extension_ranges=[], 233 | oneofs=[ 234 | ], 235 | serialized_start=79, 236 | serialized_end=150, 237 | ) 238 | 239 | 240 | _STRINGINTLABELMAP = _descriptor.Descriptor( 241 | name='StringIntLabelMap', 242 | full_name='object_detection.protos.StringIntLabelMap', 243 | filename=None, 244 | file=DESCRIPTOR, 245 | containing_type=None, 246 | fields=[ 247 | _descriptor.FieldDescriptor( 248 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, 249 | number=1, type=11, cpp_type=10, label=3, 250 | has_default_value=False, default_value=[], 251 | message_type=None, enum_type=None, containing_type=None, 252 | is_extension=False, extension_scope=None, 253 | options=None, file=DESCRIPTOR), 254 | ], 255 | extensions=[ 256 | ], 257 | nested_types=[], 258 | enum_types=[ 259 | ], 260 | options=None, 261 | is_extendable=False, 262 | syntax='proto2', 263 | extension_ranges=[], 264 | oneofs=[ 265 | ], 266 | serialized_start=152, 267 | serialized_end=233, 268 | ) 269 | 270 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM 271 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM 272 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 273 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 274 | 275 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( 276 | DESCRIPTOR = _STRINGINTLABELMAPITEM, 277 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 278 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 279 | )) 280 | _sym_db.RegisterMessage(StringIntLabelMapItem) 281 | 282 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( 283 | DESCRIPTOR = _STRINGINTLABELMAP, 284 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 285 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 286 | )) 287 | _sym_db.RegisterMessage(StringIntLabelMap) 288 | 289 | 290 | # @@protoc_insertion_point(module_scope) 291 | -------------------------------------------------------------------------------- /tensorflow/utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A set of functions that are used for visualization. 17 | 18 | These functions often receive an image, perform some visualization on the image. 19 | The functions do not return a value, instead they modify the image itself. 20 | 21 | """ 22 | import collections 23 | import functools 24 | import matplotlib.pyplot as plt 25 | import numpy as np 26 | import PIL.Image as Image 27 | import PIL.ImageColor as ImageColor 28 | import PIL.ImageDraw as ImageDraw 29 | import PIL.ImageFont as ImageFont 30 | import six 31 | import tensorflow as tf 32 | 33 | 34 | _TITLE_LEFT_MARGIN = 10 35 | _TITLE_TOP_MARGIN = 10 36 | STANDARD_COLORS = [ 37 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', 38 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', 39 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', 40 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', 41 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', 42 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', 43 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', 44 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', 45 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', 46 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', 47 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', 48 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', 49 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', 50 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', 51 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', 52 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', 53 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', 54 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', 55 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', 56 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', 57 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', 58 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', 59 | 'WhiteSmoke', 'Yellow', 'YellowGreen' 60 | ] 61 | 62 | 63 | def save_image_array_as_png(image, output_path): 64 | """Saves an image (represented as a numpy array) to PNG. 65 | 66 | Args: 67 | image: a numpy array with shape [height, width, 3]. 68 | output_path: path to which image should be written. 69 | """ 70 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 71 | with tf.gfile.Open(output_path, 'w') as fid: 72 | image_pil.save(fid, 'PNG') 73 | 74 | 75 | def encode_image_array_as_png_str(image): 76 | """Encodes a numpy array into a PNG string. 77 | 78 | Args: 79 | image: a numpy array with shape [height, width, 3]. 80 | 81 | Returns: 82 | PNG encoded image string. 83 | """ 84 | image_pil = Image.fromarray(np.uint8(image)) 85 | output = six.BytesIO() 86 | image_pil.save(output, format='PNG') 87 | png_string = output.getvalue() 88 | output.close() 89 | return png_string 90 | 91 | 92 | def draw_bounding_box_on_image_array(image, 93 | ymin, 94 | xmin, 95 | ymax, 96 | xmax, 97 | color='red', 98 | thickness=4, 99 | display_str_list=(), 100 | use_normalized_coordinates=True): 101 | """Adds a bounding box to an image (numpy array). 102 | 103 | Args: 104 | image: a numpy array with shape [height, width, 3]. 105 | ymin: ymin of bounding box in normalized coordinates (same below). 106 | xmin: xmin of bounding box. 107 | ymax: ymax of bounding box. 108 | xmax: xmax of bounding box. 109 | color: color to draw bounding box. Default is red. 110 | thickness: line thickness. Default value is 4. 111 | display_str_list: list of strings to display in box 112 | (each to be shown on its own line). 113 | use_normalized_coordinates: If True (default), treat coordinates 114 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 115 | coordinates as absolute. 116 | """ 117 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 118 | draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, 119 | thickness, display_str_list, 120 | use_normalized_coordinates) 121 | np.copyto(image, np.array(image_pil)) 122 | 123 | 124 | def draw_bounding_box_on_image(image, 125 | ymin, 126 | xmin, 127 | ymax, 128 | xmax, 129 | color='red', 130 | thickness=4, 131 | display_str_list=(), 132 | use_normalized_coordinates=True): 133 | """Adds a bounding box to an image. 134 | 135 | Each string in display_str_list is displayed on a separate line above the 136 | bounding box in black text on a rectangle filled with the input 'color'. 137 | If the top of the bounding box extends to the edge of the image, the strings 138 | are displayed below the bounding box. 139 | 140 | Args: 141 | image: a PIL.Image object. 142 | ymin: ymin of bounding box. 143 | xmin: xmin of bounding box. 144 | ymax: ymax of bounding box. 145 | xmax: xmax of bounding box. 146 | color: color to draw bounding box. Default is red. 147 | thickness: line thickness. Default value is 4. 148 | display_str_list: list of strings to display in box 149 | (each to be shown on its own line). 150 | use_normalized_coordinates: If True (default), treat coordinates 151 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 152 | coordinates as absolute. 153 | """ 154 | draw = ImageDraw.Draw(image) 155 | im_width, im_height = image.size 156 | if use_normalized_coordinates: 157 | (left, right, top, bottom) = (xmin * im_width, xmax * im_width, 158 | ymin * im_height, ymax * im_height) 159 | else: 160 | (left, right, top, bottom) = (xmin, xmax, ymin, ymax) 161 | draw.line([(left, top), (left, bottom), (right, bottom), 162 | (right, top), (left, top)], width=thickness, fill=color) 163 | try: 164 | font = ImageFont.truetype('arial.ttf', 24) 165 | except IOError: 166 | font = ImageFont.load_default() 167 | 168 | # If the total height of the display strings added to the top of the bounding 169 | # box exceeds the top of the image, stack the strings below the bounding box 170 | # instead of above. 171 | display_str_heights = [font.getsize(ds)[1] for ds in display_str_list] 172 | # Each display_str has a top and bottom margin of 0.05x. 173 | total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) 174 | 175 | if top > total_display_str_height: 176 | text_bottom = top 177 | else: 178 | text_bottom = bottom + total_display_str_height 179 | # Reverse list and print from bottom to top. 180 | for display_str in display_str_list[::-1]: 181 | text_width, text_height = font.getsize(display_str) 182 | margin = np.ceil(0.05 * text_height) 183 | draw.rectangle( 184 | [(left, text_bottom - text_height - 2 * margin), (left + text_width, 185 | text_bottom)], 186 | fill=color) 187 | draw.text( 188 | (left + margin, text_bottom - text_height - margin), 189 | display_str, 190 | fill='black', 191 | font=font) 192 | text_bottom -= text_height - 2 * margin 193 | 194 | 195 | def draw_bounding_boxes_on_image_array(image, 196 | boxes, 197 | color='red', 198 | thickness=4, 199 | display_str_list_list=()): 200 | """Draws bounding boxes on image (numpy array). 201 | 202 | Args: 203 | image: a numpy array object. 204 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 205 | The coordinates are in normalized format between [0, 1]. 206 | color: color to draw bounding box. Default is red. 207 | thickness: line thickness. Default value is 4. 208 | display_str_list_list: list of list of strings. 209 | a list of strings for each bounding box. 210 | The reason to pass a list of strings for a 211 | bounding box is that it might contain 212 | multiple labels. 213 | 214 | Raises: 215 | ValueError: if boxes is not a [N, 4] array 216 | """ 217 | image_pil = Image.fromarray(image) 218 | draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, 219 | display_str_list_list) 220 | np.copyto(image, np.array(image_pil)) 221 | 222 | 223 | def draw_bounding_boxes_on_image(image, 224 | boxes, 225 | color='red', 226 | thickness=4, 227 | display_str_list_list=()): 228 | """Draws bounding boxes on image. 229 | 230 | Args: 231 | image: a PIL.Image object. 232 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 233 | The coordinates are in normalized format between [0, 1]. 234 | color: color to draw bounding box. Default is red. 235 | thickness: line thickness. Default value is 4. 236 | display_str_list_list: list of list of strings. 237 | a list of strings for each bounding box. 238 | The reason to pass a list of strings for a 239 | bounding box is that it might contain 240 | multiple labels. 241 | 242 | Raises: 243 | ValueError: if boxes is not a [N, 4] array 244 | """ 245 | boxes_shape = boxes.shape 246 | if not boxes_shape: 247 | return 248 | if len(boxes_shape) != 2 or boxes_shape[1] != 4: 249 | raise ValueError('Input must be of size [N, 4]') 250 | for i in range(boxes_shape[0]): 251 | display_str_list = () 252 | if display_str_list_list: 253 | display_str_list = display_str_list_list[i] 254 | draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2], 255 | boxes[i, 3], color, thickness, display_str_list) 256 | 257 | 258 | def draw_bounding_boxes_on_image_tensors(images, 259 | boxes, 260 | classes, 261 | scores, 262 | category_index, 263 | max_boxes_to_draw=20, 264 | min_score_thresh=0.2): 265 | """Draws bounding boxes on batch of image tensors. 266 | 267 | Args: 268 | images: A 4D uint8 image tensor of shape [N, H, W, C]. 269 | boxes: [N, max_detections, 4] float32 tensor of detection boxes. 270 | classes: [N, max_detections] int tensor of detection classes. Note that 271 | classes are 1-indexed. 272 | scores: [N, max_detections] float32 tensor of detection scores. 273 | category_index: a dict that maps integer ids to category dicts. e.g. 274 | {1: {1: 'dog'}, 2: {2: 'cat'}, ...} 275 | max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20. 276 | min_score_thresh: Minimum score threshold for visualization. Default 0.2. 277 | 278 | Returns: 279 | 4D image tensor of type uint8, with boxes drawn on top. 280 | """ 281 | visualize_boxes_fn = functools.partial( 282 | visualize_boxes_and_labels_on_image_array, 283 | category_index=category_index, 284 | instance_masks=None, 285 | keypoints=None, 286 | use_normalized_coordinates=True, 287 | max_boxes_to_draw=max_boxes_to_draw, 288 | min_score_thresh=min_score_thresh, 289 | agnostic_mode=False, 290 | line_thickness=4) 291 | 292 | def draw_boxes(image_boxes_classes_scores): 293 | """Draws boxes on image.""" 294 | (image, boxes, classes, scores) = image_boxes_classes_scores 295 | image_with_boxes = tf.py_func(visualize_boxes_fn, 296 | [image, boxes, classes, scores], tf.uint8) 297 | return image_with_boxes 298 | 299 | images = tf.map_fn( 300 | draw_boxes, (images, boxes, classes, scores), 301 | dtype=tf.uint8, 302 | back_prop=False) 303 | return images 304 | 305 | 306 | def draw_keypoints_on_image_array(image, 307 | keypoints, 308 | color='red', 309 | radius=2, 310 | use_normalized_coordinates=True): 311 | """Draws keypoints on an image (numpy array). 312 | 313 | Args: 314 | image: a numpy array with shape [height, width, 3]. 315 | keypoints: a numpy array with shape [num_keypoints, 2]. 316 | color: color to draw the keypoints with. Default is red. 317 | radius: keypoint radius. Default value is 2. 318 | use_normalized_coordinates: if True (default), treat keypoint values as 319 | relative to the image. Otherwise treat them as absolute. 320 | """ 321 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 322 | draw_keypoints_on_image(image_pil, keypoints, color, radius, 323 | use_normalized_coordinates) 324 | np.copyto(image, np.array(image_pil)) 325 | 326 | 327 | def draw_keypoints_on_image(image, 328 | keypoints, 329 | color='red', 330 | radius=2, 331 | use_normalized_coordinates=True): 332 | """Draws keypoints on an image. 333 | 334 | Args: 335 | image: a PIL.Image object. 336 | keypoints: a numpy array with shape [num_keypoints, 2]. 337 | color: color to draw the keypoints with. Default is red. 338 | radius: keypoint radius. Default value is 2. 339 | use_normalized_coordinates: if True (default), treat keypoint values as 340 | relative to the image. Otherwise treat them as absolute. 341 | """ 342 | draw = ImageDraw.Draw(image) 343 | im_width, im_height = image.size 344 | keypoints_x = [k[1] for k in keypoints] 345 | keypoints_y = [k[0] for k in keypoints] 346 | if use_normalized_coordinates: 347 | keypoints_x = tuple([im_width * x for x in keypoints_x]) 348 | keypoints_y = tuple([im_height * y for y in keypoints_y]) 349 | for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): 350 | draw.ellipse([(keypoint_x - radius, keypoint_y - radius), 351 | (keypoint_x + radius, keypoint_y + radius)], 352 | outline=color, fill=color) 353 | 354 | 355 | def draw_mask_on_image_array(image, mask, color='red', alpha=0.7): 356 | """Draws mask on an image. 357 | 358 | Args: 359 | image: uint8 numpy array with shape (img_height, img_height, 3) 360 | mask: a uint8 numpy array of shape (img_height, img_height) with 361 | values between either 0 or 1. 362 | color: color to draw the keypoints with. Default is red. 363 | alpha: transparency value between 0 and 1. (default: 0.7) 364 | 365 | Raises: 366 | ValueError: On incorrect data type for image or masks. 367 | """ 368 | if image.dtype != np.uint8: 369 | raise ValueError('`image` not of type np.uint8') 370 | if mask.dtype != np.uint8: 371 | raise ValueError('`mask` not of type np.uint8') 372 | if np.any(np.logical_and(mask != 1, mask != 0)): 373 | raise ValueError('`mask` elements should be in [0, 1]') 374 | rgb = ImageColor.getrgb(color) 375 | pil_image = Image.fromarray(image) 376 | 377 | solid_color = np.expand_dims( 378 | np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) 379 | pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') 380 | pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L') 381 | pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) 382 | np.copyto(image, np.array(pil_image.convert('RGB'))) 383 | 384 | 385 | def visualize_boxes_and_labels_on_image_array(image, 386 | boxes, 387 | classes, 388 | scores, 389 | category_index, 390 | instance_masks=None, 391 | keypoints=None, 392 | use_normalized_coordinates=False, 393 | max_boxes_to_draw=20, 394 | min_score_thresh=.5, 395 | agnostic_mode=False, 396 | line_thickness=4): 397 | """Overlay labeled boxes on an image with formatted scores and label names. 398 | 399 | This function groups boxes that correspond to the same location 400 | and creates a display string for each detection and overlays these 401 | on the image. Note that this function modifies the image in place, and returns 402 | that same image. 403 | 404 | Args: 405 | image: uint8 numpy array with shape (img_height, img_width, 3) 406 | boxes: a numpy array of shape [N, 4] 407 | classes: a numpy array of shape [N]. Note that class indices are 1-based, 408 | and match the keys in the label map. 409 | scores: a numpy array of shape [N] or None. If scores=None, then 410 | this function assumes that the boxes to be plotted are groundtruth 411 | boxes and plot all boxes as black with no classes or scores. 412 | category_index: a dict containing category dictionaries (each holding 413 | category index `id` and category name `name`) keyed by category indices. 414 | instance_masks: a numpy array of shape [N, image_height, image_width], can 415 | be None 416 | keypoints: a numpy array of shape [N, num_keypoints, 2], can 417 | be None 418 | use_normalized_coordinates: whether boxes is to be interpreted as 419 | normalized coordinates or not. 420 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw 421 | all boxes. 422 | min_score_thresh: minimum score threshold for a box to be visualized 423 | agnostic_mode: boolean (default: False) controlling whether to evaluate in 424 | class-agnostic mode or not. This mode will display scores but ignore 425 | classes. 426 | line_thickness: integer (default: 4) controlling line width of the boxes. 427 | 428 | Returns: 429 | uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes. 430 | """ 431 | # Create a display string (and color) for every box location, group any boxes 432 | # that correspond to the same location. 433 | box_to_display_str_map = collections.defaultdict(list) 434 | box_to_color_map = collections.defaultdict(str) 435 | box_to_instance_masks_map = {} 436 | box_to_keypoints_map = collections.defaultdict(list) 437 | if not max_boxes_to_draw: 438 | max_boxes_to_draw = boxes.shape[0] 439 | for i in range(min(max_boxes_to_draw, boxes.shape[0])): 440 | if scores is None or scores[i] > min_score_thresh: 441 | box = tuple(boxes[i].tolist()) 442 | if instance_masks is not None: 443 | box_to_instance_masks_map[box] = instance_masks[i] 444 | if keypoints is not None: 445 | box_to_keypoints_map[box].extend(keypoints[i]) 446 | if scores is None: 447 | box_to_color_map[box] = 'black' 448 | else: 449 | if not agnostic_mode: 450 | if classes[i] in category_index.keys(): 451 | class_name = category_index[classes[i]]['name'] 452 | else: 453 | class_name = 'N/A' 454 | display_str = '{}: {}%'.format( 455 | class_name, 456 | int(100*scores[i])) 457 | else: 458 | display_str = 'score: {}%'.format(int(100 * scores[i])) 459 | box_to_display_str_map[box].append(display_str) 460 | if agnostic_mode: 461 | box_to_color_map[box] = 'DarkOrange' 462 | else: 463 | box_to_color_map[box] = STANDARD_COLORS[ 464 | classes[i] % len(STANDARD_COLORS)] 465 | 466 | # Draw all boxes onto image. 467 | for box, color in box_to_color_map.items(): 468 | ymin, xmin, ymax, xmax = box 469 | if instance_masks is not None: 470 | draw_mask_on_image_array( 471 | image, 472 | box_to_instance_masks_map[box], 473 | color=color 474 | ) 475 | draw_bounding_box_on_image_array( 476 | image, 477 | ymin, 478 | xmin, 479 | ymax, 480 | xmax, 481 | color=color, 482 | thickness=line_thickness, 483 | display_str_list=box_to_display_str_map[box], 484 | use_normalized_coordinates=use_normalized_coordinates) 485 | if keypoints is not None: 486 | draw_keypoints_on_image_array( 487 | image, 488 | box_to_keypoints_map[box], 489 | color=color, 490 | radius=line_thickness / 2, 491 | use_normalized_coordinates=use_normalized_coordinates) 492 | 493 | return image 494 | 495 | 496 | def add_cdf_image_summary(values, name): 497 | """Adds a tf.summary.image for a CDF plot of the values. 498 | 499 | Normalizes `values` such that they sum to 1, plots the cumulative distribution 500 | function and creates a tf image summary. 501 | 502 | Args: 503 | values: a 1-D float32 tensor containing the values. 504 | name: name for the image summary. 505 | """ 506 | def cdf_plot(values): 507 | """Numpy function to plot CDF.""" 508 | normalized_values = values / np.sum(values) 509 | sorted_values = np.sort(normalized_values) 510 | cumulative_values = np.cumsum(sorted_values) 511 | fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32) 512 | / cumulative_values.size) 513 | fig = plt.figure(frameon=False) 514 | ax = fig.add_subplot('111') 515 | ax.plot(fraction_of_examples, cumulative_values) 516 | ax.set_ylabel('cumulative normalized values') 517 | ax.set_xlabel('fraction of examples') 518 | fig.canvas.draw() 519 | width, height = fig.get_size_inches() * fig.get_dpi() 520 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( 521 | 1, height, width, 3) 522 | return image 523 | cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) 524 | tf.summary.image(name, cdf_plot) 525 | -------------------------------------------------------------------------------- /tensorflow/wechat_auto_jump.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import time 6 | import os 7 | from utils import label_map_util 8 | from utils import visualization_utils as vis_util 9 | import cv2 10 | 11 | if tf.__version__ != '1.4.0': 12 | raise ImportError('Please upgrade your tensorflow installation to v1.4.0!') 13 | 14 | # 模型配置 15 | PATH_TO_CKPT = 'frozen_inference_graph_frcnn_inception_v2_coco.pb' 16 | PATH_TO_LABELS = 'wechat_jump_label_map.pbtxt' 17 | NUM_CLASSES = 7 18 | 19 | # 加载模型 20 | detection_graph = tf.Graph() 21 | with detection_graph.as_default(): 22 | od_graph_def = tf.GraphDef() 23 | with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 24 | od_graph_def.ParseFromString(fid.read()) 25 | tf.import_graph_def(od_graph_def, name='') 26 | 27 | # 加载类别 28 | label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 29 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 30 | category_index = label_map_util.create_category_index(categories) 31 | 32 | # 屏幕截图 33 | def pull_screenshot(path): 34 | os.system('adb shell screencap -p /sdcard/%s' % path) 35 | os.system('adb pull /sdcard/%s .' % path) 36 | 37 | # 读取图片 38 | def read_image(path): 39 | image_np = cv2.imread(path) 40 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 41 | 42 | WIDTH = image_np.shape[1] 43 | HEIGHT = image_np.shape[0] 44 | image_np_expanded = np.expand_dims(image_np, axis=0) 45 | 46 | return image_np, image_np_expanded, WIDTH, HEIGHT 47 | 48 | # 获取物体识别结果 49 | def get_positions(boxes, classes, scores, category_index): 50 | cp = [1, 1, 1, 1] 51 | tp = [1, 1, 1, 1] 52 | target_type = '' 53 | min_score_thresh = .5 54 | 55 | for i in range(boxes.shape[0]): 56 | if scores[i] > min_score_thresh: 57 | if boxes[i][0] < 0.3 or boxes[i][2] > 0.8: 58 | continue 59 | if category_index[classes[i]]['name'] == 'chess': 60 | cp = boxes[i] 61 | elif boxes[i][0] < tp[0]: 62 | tp = boxes[i] 63 | target_type = category_index[classes[i]]['name'] 64 | 65 | return cp, tp, target_type 66 | 67 | # 一些变量 68 | loop = 1 69 | alpha = 1800 70 | chess_x = 0 71 | target_x = 0 72 | distance = 0 73 | screenshot = 'screenshot.png' 74 | 75 | # 根据x距离跳跃 76 | def jump(distance, target_type, alpha, bx1, by1, bx2, by2): 77 | press_time = max(int(distance * alpha), 200) 78 | 79 | cmd = 'adb shell input swipe {} {} {} {} {}'.format(bx1, by1, bx2, by2, press_time) 80 | os.system(cmd) 81 | 82 | if target_type in ['waste', 'magic', 'shop', 'music']: 83 | print('=' * 10, target_type , '=' * 10) 84 | 85 | with detection_graph.as_default(): 86 | with tf.Session(graph=detection_graph) as sess: 87 | image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 88 | detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 89 | detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') 90 | detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') 91 | num_detections = detection_graph.get_tensor_by_name('num_detections:0') 92 | while True: 93 | pull_screenshot(screenshot) 94 | image_np, image_np_expanded, WIDTH, HEIGHT = read_image(screenshot) 95 | 96 | bx1 = WIDTH / 2 + int(np.random.rand() * 10 - 5) 97 | bx2 = WIDTH / 2 + int(np.random.rand() * 10 - 5) 98 | by1 = HEIGHT * 0.785 + int(np.random.rand() * 4 - 2) 99 | by2 = HEIGHT * 0.785 + int(np.random.rand() * 4 - 2) 100 | 101 | (boxes, scores, classes, num) = sess.run( 102 | [detection_boxes, detection_scores, detection_classes, num_detections], 103 | feed_dict={image_tensor: image_np_expanded}) 104 | 105 | boxes = np.reshape(boxes, (-1, boxes.shape[-1])) 106 | scores = np.reshape(scores, (-1)) 107 | classes = np.reshape(classes, (-1)).astype(np.int32) 108 | 109 | vis_util.visualize_boxes_and_labels_on_image_array(image_np, boxes, classes, scores, category_index, use_normalized_coordinates=True, line_thickness=8) 110 | cv2.imwrite('detection.png', cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)) 111 | 112 | # 计算棋子和目标块位置 113 | cp, tp, target_type = get_positions(boxes, classes, scores, category_index) 114 | chess_x = (cp[1] + cp[3]) / 2 115 | target_x = (tp[1] + tp[3]) / 2 116 | distance = np.abs(chess_x - target_x) 117 | 118 | # 跳! 119 | jump(distance, target_type, alpha, bx1, by1, bx2, by2) 120 | print(distance, target_type) 121 | 122 | # 等棋子落稳 123 | loop += 1 124 | time.sleep(np.random.rand() + 1) 125 | 126 | # 跳累了休息一会 127 | rest_jump = np.random.rand() * 50 + 50 128 | rest_time = np.random.rand() * 5 + 5 129 | if loop > rest_jump: 130 | loop = 1 131 | print('已经跳了 %d 下,休息 %d 秒' % (rest_jump, rest_time)) 132 | time.sleep(rest_time) 133 | -------------------------------------------------------------------------------- /tensorflow/wechat_jump_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'chess' 4 | } 5 | item { 6 | id: 2 7 | name: 'rect' 8 | } 9 | item { 10 | id: 3 11 | name: 'circle' 12 | } 13 | item { 14 | id: 4 15 | name: 'music' 16 | } 17 | item { 18 | id: 5 19 | name: 'shop' 20 | } 21 | item { 22 | id: 6 23 | name: 'magic' 24 | } 25 | item { 26 | id: 7 27 | name: 'waste' 28 | } -------------------------------------------------------------------------------- /物体检测标注说明.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Honlan/wechat_jump_tensorflow/75ee565c2eae489be841734c09cdeb703dcde251/物体检测标注说明.pdf --------------------------------------------------------------------------------