├── .gitignore ├── Anchor-Kmeans ├── README.md ├── datasets.py ├── demo.ipynb ├── gen_anchors.py ├── imgs │ └── avgiou.png └── kmeans.py ├── README.md ├── cal_mean_std.py └── wtm ├── find_gt_files.sh └── img_fill.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/datasets.cpython-37.pyc 2 | __pycache__/kmeans.cpython-37.pyc 3 | .vscode/settings.json 4 | .idea/inspectionProfiles/profiles_settings.xml 5 | .idea/vcs.xml 6 | .idea/modules.xml 7 | .idea/misc.xml 8 | .idea/Anchor-Kmeans.iml 9 | .idea/.gitignore 10 | -------------------------------------------------------------------------------- /Anchor-Kmeans/README.md: -------------------------------------------------------------------------------- 1 | # Anchor-Kmeans 2 | Implementation of kmeans clustering on bounding boxes to generate anchors, as mentioned in the [YOLOv2](https://arxiv.org/abs/1612.08242). 3 | 4 | ## Usage 5 | Currently supports three types of annotation file: 6 | - [labelme json file](https://github.com/wkentaro/labelme) 7 | - [VOC xml file](https://pjreddie.com/projects/pascal-voc-dataset-mirror/) 8 | - csv file, each line is a coordinate values separated by a comma, form as `xmin, ymin, xmax, ymax` 9 | 10 | To generate anchors of your own dataset is very simple, just execute the `gen_anchors.py` script with 3 arguments: 11 | 12 | ```bash 13 | python gen_anchors.py -d /path to your/annotations-dir -t [annotation file type, defualt 'xml'] -k [num of clusters, default 5] 14 | ``` 15 | 16 | ## Test 17 | 18 | I have tested it on the VOC2012 dataset, the trend of average iou with k value is shown in the figure below 19 | 20 | ![](./imgs/avgiou.png) 21 | 22 | See the detailed test code in [demo.ipynb](./demo.ipynb) 23 | 24 | -------------------------------------------------------------------------------- /Anchor-Kmeans/datasets.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | import numpy as np 3 | import glob 4 | import os 5 | import json 6 | import cv2 7 | 8 | 9 | class AnnotParser(object): 10 | def __init__(self, file_type): 11 | assert file_type in ['csv', 'xml', 'json'], "Unsupported file type." 12 | self.file_type = file_type 13 | 14 | def parse(self, annot_dir): 15 | """ 16 | Parse annotation file, the file type must be csv or xml or json. 17 | 18 | :param annot_dir: directory path of annotation files 19 | :return: 2-d array, shape as (n, 2), each row represents a bbox, and each column 20 | represents the corresponding width and height after normalized 21 | """ 22 | if self.file_type == 'xml': 23 | return self.parse_xml(annot_dir) 24 | elif self.file_type == 'json': 25 | return self.parse_json(annot_dir) 26 | else: 27 | return self.parse_csv(annot_dir) 28 | 29 | @staticmethod 30 | def parse_xml(annot_dir): 31 | """ 32 | Parse xml annotation file in VOC. 33 | """ 34 | boxes = [] 35 | 36 | for xml_file in glob.glob(os.path.join(annot_dir, '*.xml')): 37 | tree = ET.parse(xml_file) 38 | 39 | h_img = int(tree.findtext('./size/height')) 40 | w_img = int(tree.findtext('./size/width')) 41 | 42 | for obj in tree.iter('object'): 43 | xmin = int(round(float(obj.findtext('bndbox/xmin')))) 44 | ymin = int(round(float(obj.findtext('bndbox/ymin')))) 45 | xmax = int(round(float(obj.findtext('bndbox/xmax')))) 46 | ymax = int(round(float(obj.findtext('bndbox/ymax')))) 47 | 48 | w_norm = (xmax - xmin) / w_img 49 | h_norm = (ymax - ymin) / h_img 50 | 51 | boxes.append([w_norm, h_norm]) 52 | 53 | return np.array(boxes) 54 | 55 | @staticmethod 56 | def parse_json(annot_dir): 57 | """ 58 | Parse labelme json annotation file. 59 | """ 60 | boxes = [] 61 | 62 | for js_file in glob.glob(os.path.join(annot_dir, '*.json')): 63 | with open(js_file) as f: 64 | data = json.load(f) 65 | 66 | h_img = data['imageHeight'] 67 | w_img = data['imageWidth'] 68 | 69 | for shape in data['shapes']: 70 | points = shape['points'] 71 | xmin = int(round(points[0][0])) 72 | ymin = int(round(points[0][1])) 73 | xmax = int(round(points[1][0])) 74 | ymax = int(round(points[1][1])) 75 | 76 | w_norm = (xmax - xmin) / w_img 77 | h_norm = (ymax - ymin) / h_img 78 | 79 | boxes.append([w_norm, h_norm]) 80 | 81 | return np.array(boxes) 82 | 83 | @staticmethod 84 | def parse_csv(annot_dir): 85 | """ 86 | Parse csv annotation file. 87 | """ 88 | boxes = [] 89 | 90 | for csv_file in glob.glob(os.path.join(annot_dir, '*.csv')): 91 | with open(csv_file) as f: 92 | lines = f.readlines() 93 | 94 | for line in lines: 95 | items = line.strip().split(',') 96 | img = cv2.imread(items[0]) 97 | h_img, w_img = img.shape[:2] 98 | xmin, ymin, xmax, ymax = list(map(int, items[1:-1])) 99 | 100 | w_norm = (xmax - xmin) / w_img 101 | h_norm = (ymax - ymin) / h_img 102 | 103 | boxes.append([w_norm, h_norm]) 104 | 105 | return np.array(boxes) 106 | -------------------------------------------------------------------------------- /Anchor-Kmeans/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from kmeans import AnchorKmeans\n", 11 | "from datasets import AnnotParser\n", 12 | "from matplotlib import pyplot as plt\n", 13 | "from matplotlib.patches import Rectangle\n", 14 | "%matplotlib inline\n", 15 | "\n", 16 | "plt.style.use('ggplot')" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": "[INFO] Load datas\nboxes shape : (40138, 2)\n" 28 | } 29 | ], 30 | "source": [ 31 | "print('[INFO] Load datas')\n", 32 | "annot_dir = \"/PATH TO YOUR/VOCdevkit/VOC2012/Annotations\"\n", 33 | "parser = AnnotParser('xml')\n", 34 | "boxes = parser.parse_xml(annot_dir)\n", 35 | "print('boxes shape : {}'.format(boxes.shape))" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": "[INFO] Run anchor k-means with k = 2,3,...,10\nK = 2, Avg IOU = 0.4646\nK = 3, Avg IOU = 0.5391\nK = 4, Avg IOU = 0.5801\nK = 5, Avg IOU = 0.6016\nK = 6, Avg IOU = 0.6252\nK = 7, Avg IOU = 0.6434\nK = 8, Avg IOU = 0.6596\nK = 9, Avg IOU = 0.6732\nK = 10, Avg IOU = 0.6838\n" 47 | } 48 | ], 49 | "source": [ 50 | "print('[INFO] Run anchor k-means with k = 2,3,...,10')\n", 51 | "results = {}\n", 52 | "for k in range(2, 11):\n", 53 | " model = AnchorKmeans(k, random_seed=333)\n", 54 | " model.fit(boxes)\n", 55 | " avg_iou = model.avg_iou()\n", 56 | " results[k] = {'anchors': model.anchors_, 'avg_iou': avg_iou}\n", 57 | " print(\"K = {}, Avg IOU = {:.4f}\".format(k, avg_iou))" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": "[INFO] Plot average IOU curve\n" 69 | }, 70 | { 71 | "data": { 72 | "image/png": "\n", 73 | "image/svg+xml": "\r\n\r\n\r\n\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", 74 | "text/plain": "
" 75 | }, 76 | "metadata": { 77 | "needs_background": "light" 78 | }, 79 | "output_type": "display_data" 80 | } 81 | ], 82 | "source": [ 83 | "print('[INFO] Plot average IOU curve')\n", 84 | "plt.figure()\n", 85 | "plt.plot(range(2, 11), [results[k][\"avg_iou\"] for k in range(2, 11)], \"o-\")\n", 86 | "plt.ylabel(\"Avg IOU\")\n", 87 | "plt.xlabel(\"K (#anchors)\")\n", 88 | "plt.show()" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": "[INFO] The result anchors:\n[[0.7794355 0.8338808 ]\n [0.33883529 0.68815335]\n [0.61044288 0.40655773]\n [0.19493034 0.35335266]\n [0.07805765 0.13006786]]\n" 100 | } 101 | ], 102 | "source": [ 103 | "print('[INFO] The result anchors:')\n", 104 | "best_k = 5\n", 105 | "anchors = results[best_k]['anchors']\n", 106 | "print(anchors)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 6, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": "[INFO] Visualizing anchors\n" 118 | }, 119 | { 120 | "data": { 121 | "image/png": "\n", 122 | "image/svg+xml": "\r\n\r\n\r\n\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", 123 | "text/plain": "
" 124 | }, 125 | "metadata": { 126 | "needs_background": "light" 127 | }, 128 | "output_type": "display_data" 129 | } 130 | ], 131 | "source": [ 132 | "print('[INFO] Visualizing anchors')\n", 133 | "w_img, h_img = 600, 600\n", 134 | "\n", 135 | "anchors[:, 0] *= w_img\n", 136 | "anchors[:, 1] *= h_img\n", 137 | "anchors = np.round(anchors).astype(np.int)\n", 138 | "\n", 139 | "rects = np.empty((5, 4), dtype=np.int)\n", 140 | "for i in range(len(anchors)):\n", 141 | " w, h = anchors[i]\n", 142 | " x1, y1 = -(w // 2), -(h // 2)\n", 143 | " rects[i] = [x1, y1, w, h]\n", 144 | "\n", 145 | "fig = plt.figure(figsize=(8, 6))\n", 146 | "ax = fig.add_subplot()\n", 147 | "for rect in rects:\n", 148 | " x1, y1, w, h = rect\n", 149 | " rect1 = Rectangle((x1, y1), w, h, color='royalblue', fill=False, linewidth=2)\n", 150 | " ax.add_patch(rect1)\n", 151 | "plt.xlim([-(w_img // 2), w_img // 2])\n", 152 | "plt.ylim([-(h_img // 2), h_img // 2])\n", 153 | "\n", 154 | "plt.show()" 155 | ] 156 | } 157 | ], 158 | "metadata": { 159 | "kernelspec": { 160 | "display_name": "Python 3", 161 | "language": "python", 162 | "name": "python3" 163 | }, 164 | "language_info": { 165 | "codemirror_mode": { 166 | "name": "ipython", 167 | "version": 3 168 | }, 169 | "file_extension": ".py", 170 | "mimetype": "text/x-python", 171 | "name": "python", 172 | "nbconvert_exporter": "python", 173 | "pygments_lexer": "ipython3", 174 | "version": "3.7.4-final" 175 | }, 176 | "toc": { 177 | "base_numbering": 1, 178 | "nav_menu": {}, 179 | "number_sections": true, 180 | "sideBar": true, 181 | "skip_h1_title": false, 182 | "title_cell": "Table of Contents", 183 | "title_sidebar": "Contents", 184 | "toc_cell": false, 185 | "toc_position": {}, 186 | "toc_section_display": true, 187 | "toc_window_display": false 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 2 192 | } -------------------------------------------------------------------------------- /Anchor-Kmeans/gen_anchors.py: -------------------------------------------------------------------------------- 1 | from kmeans import AnchorKmeans 2 | from datasets import AnnotParser 3 | import argparse 4 | 5 | 6 | def main(args): 7 | file_type = args["type"] 8 | k = args["k_clusters"] 9 | annot_dir = args["dir_path"] 10 | parser = AnnotParser(file_type) 11 | 12 | print("[INFO] Load datas from {}".format(annot_dir)) 13 | boxes = parser.parse(annot_dir) 14 | 15 | print("[INFO] Initialize model") 16 | model = AnchorKmeans(k) 17 | 18 | print("[INFO] Training...") 19 | model.fit(boxes) 20 | 21 | anchors = model.anchors_ 22 | print("[INFO] The results anchors:\n{}".format(anchors)) 23 | 24 | 25 | if __name__ == "__main__": 26 | ap = argparse.ArgumentParser() 27 | ap.add_argument("-d", 28 | "--dir_path", 29 | required=True, 30 | help="directory path of annotation files") 31 | ap.add_argument("-t", 32 | "--type", 33 | choices=['xml', 'json', 'csv'], 34 | default='xml', 35 | help="type of annotation file") 36 | ap.add_argument("-k", 37 | "--k_clusters", 38 | type=int, 39 | default=5, 40 | help="the number of clusters") 41 | args = vars(ap.parse_args()) 42 | main(args) 43 | -------------------------------------------------------------------------------- /Anchor-Kmeans/imgs/avgiou.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ybcc2015/DeepLearning-Utils/629bea84be257005dd3c331f14ba390c5ea59065/Anchor-Kmeans/imgs/avgiou.png -------------------------------------------------------------------------------- /Anchor-Kmeans/kmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AnchorKmeans(object): 5 | """ 6 | K-means clustering on bounding boxes to generate anchors 7 | """ 8 | def __init__(self, k, max_iter=300, random_seed=None): 9 | self.k = k 10 | self.max_iter = max_iter 11 | self.random_seed = random_seed 12 | self.n_iter = 0 13 | self.anchors_ = None 14 | self.labels_ = None 15 | self.ious_ = None 16 | 17 | def fit(self, boxes): 18 | """ 19 | Run K-means cluster on input boxes. 20 | 21 | :param boxes: 2-d array, shape(n, 2), form as (w, h) 22 | :return: None 23 | """ 24 | assert self.k < len(boxes), "K must be less than the number of data." 25 | 26 | # If the current number of iterations is greater than 0, then reset 27 | if self.n_iter > 0: 28 | self.n_iter = 0 29 | 30 | np.random.seed(self.random_seed) 31 | n = boxes.shape[0] 32 | 33 | # Initialize K cluster centers (i.e., K anchors) 34 | self.anchors_ = boxes[np.random.choice(n, self.k, replace=True)] 35 | 36 | self.labels_ = np.zeros((n,)) 37 | 38 | while True: 39 | self.n_iter += 1 40 | 41 | # If the current number of iterations is greater than max number of iterations , then break 42 | if self.n_iter > self.max_iter: 43 | break 44 | 45 | self.ious_ = self.iou(boxes, self.anchors_) 46 | distances = 1 - self.ious_ 47 | cur_labels = np.argmin(distances, axis=1) 48 | 49 | # If anchors not change any more, then break 50 | if (cur_labels == self.labels_).all(): 51 | break 52 | 53 | # Update K anchors 54 | for i in range(self.k): 55 | self.anchors_[i] = np.mean(boxes[cur_labels == i], axis=0) 56 | 57 | self.labels_ = cur_labels 58 | 59 | @staticmethod 60 | def iou(boxes, anchors): 61 | """ 62 | Calculate the IOU between boxes and anchors. 63 | 64 | :param boxes: 2-d array, shape(n, 2) 65 | :param anchors: 2-d array, shape(k, 2) 66 | :return: 2-d array, shape(n, k) 67 | """ 68 | # Calculate the intersection, 69 | # the new dimension are added to construct shape (n, 1) and shape (1, k), 70 | # so we can get (n, k) shape result by numpy broadcast 71 | w_min = np.minimum(boxes[:, 0, np.newaxis], anchors[np.newaxis, :, 0]) 72 | h_min = np.minimum(boxes[:, 1, np.newaxis], anchors[np.newaxis, :, 1]) 73 | inter = w_min * h_min 74 | 75 | # Calculate the union 76 | box_area = boxes[:, 0] * boxes[:, 1] 77 | anchor_area = anchors[:, 0] * anchors[:, 1] 78 | union = box_area[:, np.newaxis] + anchor_area[np.newaxis] 79 | 80 | return inter / (union - inter) 81 | 82 | def avg_iou(self): 83 | """ 84 | Calculate the average IOU with closest anchor. 85 | 86 | :return: None 87 | """ 88 | return np.mean(self.ious_[np.arange(len(self.labels_)), self.labels_]) 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepLearning-Utils 2 | This repository contains some commonly ulits in deep learning. 3 | -------------------------------------------------------------------------------- /cal_mean_std.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from imutils import paths 4 | 5 | 6 | def calculate_mean_std(img_root, channels=3): 7 | """ 8 | Calculate the mean and standard deviation of the training images. 9 | 10 | Arguments: 11 | img_root {str} -- the root directory of training images 12 | channels {int} -- the numbers of channles 13 | 14 | Returns: 15 | mean {1-d numpy array} -- mean value of each channel 16 | std {1-d numpy array} -- standard deviation of each channel 17 | """ 18 | total_pixel = 0 19 | channel_sum = np.zeros(channels) 20 | channel_square_sum = np.zeros(channels) 21 | 22 | for img_path in paths.list_images(img_root): 23 | img = cv2.imread(img_path) 24 | img = img / 255. 25 | channel_sum = np.sum(img, axis=(0, 1)) 26 | channel_square_sum = np.sum(img ** 2, axis=(0, 1)) 27 | total_pixel += img.shape[0] * img.shape[1] 28 | 29 | mean = channel_sum / total_pixel 30 | std = np.sqrt(channel_square_sum / total_pixel - mean ** 2) 31 | 32 | if channels == 3: # bgr -> rgb 33 | mean = mean[::-1] 34 | std = std[::-1] 35 | 36 | return mean, std 37 | -------------------------------------------------------------------------------- /wtm/find_gt_files.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DIR=$1 4 | find $DIR -name '*_gt.json' | wc -l 5 | 6 | -------------------------------------------------------------------------------- /wtm/img_fill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import json 4 | 5 | 6 | class ImgFill(object): 7 | def __init__(self, json_file): 8 | self.box = self.get_box(json_file) 9 | 10 | @staticmethod 11 | def get_box(json_file): 12 | """ 13 | 获取json文件中的box_b的坐标 14 | Args: 15 | json_file (str): json文件路径 16 | Returns: 17 | [list]: box_b的坐标,格式为[left, top, right, bottom] 18 | """ 19 | res = [] 20 | 21 | with open("./boxes.json") as f: 22 | json_data = json.load(f) 23 | 24 | for box in json_data["boxes"]: 25 | if box["name"] == "box_b": 26 | print(box["rectangle"]) 27 | res = box["rectangle"]["left_top"] 28 | res.extend(box["rectangle"]["right_bottom"]) 29 | 30 | return res 31 | 32 | def is_box_valid(self, img): 33 | """ 34 | 判断box_b指定的区域是否超出img的边界 35 | Args: 36 | img (numpy array): 目标图像 37 | Returns: 38 | [bool]: ture or false 39 | """ 40 | left, top = self.box[:2] 41 | h = self.box[3] - self.box[1] 42 | w = self.box[2] - self.box[0] 43 | h_img, w_img = img.shape[:2] 44 | 45 | cond1 = left >= 0 and (left + w) <= w_img 46 | cond2 = top >= 0 and (top + h) <= h_img 47 | return cond1 and cond2 48 | 49 | def fill(self, dst_img, src_img, mode="stretch"): 50 | """ 51 | 图像填充函数 52 | Args: 53 | dst_img (numpy array): 目标图像 54 | src_img (numpy array): 源图像 (待填充的图像) 55 | mode (str): 填充模式, "stretch"指拉伸填充, "keep"指保持比例填充 56 | Returns: 57 | [numpy array]: 填充后的图像 58 | """ 59 | ok = self.is_box_valid(dst_img) 60 | if not ok: 61 | return 62 | 63 | # 得到填充区域的左上角顶点, 以及宽和高 64 | left, top = self.box[:2] 65 | h = self.box[3] - self.box[1] 66 | w = self.box[2] - self.box[0] 67 | 68 | assert mode in ["stretch", "keep"], "当前仅支持'stretch'和'keep'填充模式!" 69 | 70 | if mode == "stretch": 71 | src_img = cv2.resize(src_img, (w, h)) 72 | dst_img[top: top + h, left: left + w] = src_img 73 | 74 | if mode == "keep": 75 | # 基于源图的长边得到缩放比例 76 | h_img, w_img = src_img.shape[:2] 77 | ratio = h / h_img if h_img >= w_img else w / w_img 78 | 79 | # 源图等比例缩放 80 | h_new = int(round(h_img * ratio)) 81 | w_new = int(round(w_img * ratio)) 82 | src_img = cv2.resize(src_img, (w_new, h_new)) 83 | 84 | # 如果缩放后的高小于填充区域的高, 则沿y轴方向进行pad 85 | if h_new < h: 86 | pad = h - h_new 87 | pad_size = (pad // 2, pad - pad // 2) 88 | np.pad(src_img, (pad_size, (0, 0))) 89 | h_new = h 90 | 91 | # 如果缩放后的宽小于填充区域的高宽, 则沿x轴方向进行pad 92 | if w_new < w: 93 | pad = w - w_new 94 | pad_size = (pad // 2, pad - pad // 2) 95 | np.pad(src_img, ((0, 0), pad_size)) 96 | w_new = w 97 | 98 | dst_img[top: top + h_new, left: left + w_new] = src_img 99 | 100 | return dst_img 101 | --------------------------------------------------------------------------------