├── .gitignore
├── Anchor-Kmeans
├── README.md
├── datasets.py
├── demo.ipynb
├── gen_anchors.py
├── imgs
│ └── avgiou.png
└── kmeans.py
├── README.md
├── cal_mean_std.py
└── wtm
├── find_gt_files.sh
└── img_fill.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/datasets.cpython-37.pyc
2 | __pycache__/kmeans.cpython-37.pyc
3 | .vscode/settings.json
4 | .idea/inspectionProfiles/profiles_settings.xml
5 | .idea/vcs.xml
6 | .idea/modules.xml
7 | .idea/misc.xml
8 | .idea/Anchor-Kmeans.iml
9 | .idea/.gitignore
10 |
--------------------------------------------------------------------------------
/Anchor-Kmeans/README.md:
--------------------------------------------------------------------------------
1 | # Anchor-Kmeans
2 | An implementation of k-means clustering on bounding boxes to generate anchors, as described in the [YOLOv2 paper](https://arxiv.org/abs/1612.08242).
3 |
4 | ## Usage
5 | Currently supports three types of annotation file:
6 | - [labelme json file](https://github.com/wkentaro/labelme)
7 | - [VOC xml file](https://pjreddie.com/projects/pascal-voc-dataset-mirror/)
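8 | - csv file: each line holds comma-separated values of the form `image_path, xmin, ymin, xmax, ymax, extra`, where the image at `image_path` is read to obtain its size and the trailing field (e.g. a class label) is ignored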
9 |
10 | Generating anchors for your own dataset is simple: run the `gen_anchors.py` script with three arguments:
11 |
12 | ```bash
13 | python gen_anchors.py -d /path to your/annotations-dir -t [annotation file type, default 'xml'] -k [num of clusters, default 5]
14 | ```
15 |
16 | ## Test
17 |
18 | I have tested it on the VOC2012 dataset. The figure below shows how the average IOU changes with the value of k.
19 |
20 | ![avgiou](./imgs/avgiou.png)
21 |
22 | See the detailed test code in [demo.ipynb](./demo.ipynb)
23 |
24 |
--------------------------------------------------------------------------------
/Anchor-Kmeans/datasets.py:
--------------------------------------------------------------------------------
1 | import xml.etree.ElementTree as ET
2 | import numpy as np
3 | import glob
4 | import os
5 | import json
6 | import cv2
7 |
8 |
class AnnotParser(object):
    """
    Parse bounding-box annotation files into normalized (width, height) pairs.

    Three annotation types are supported: VOC-style ``xml``, labelme ``json``
    and ``csv``. Every parser scans one directory (non-recursively) and
    returns one row per annotated box.
    """

    def __init__(self, file_type):
        """
        :param file_type: annotation file type, one of 'csv', 'xml' or 'json'
        """
        assert file_type in ['csv', 'xml', 'json'], "Unsupported file type."
        self.file_type = file_type

    def parse(self, annot_dir):
        """
        Parse annotation file, the file type must be csv or xml or json.

        :param annot_dir: directory path of annotation files
        :return: 2-d array, shape as (n, 2), each row represents a bbox, and each column
                 represents the corresponding width and height after normalized
        """
        if self.file_type == 'xml':
            return self.parse_xml(annot_dir)
        elif self.file_type == 'json':
            return self.parse_json(annot_dir)
        else:
            return self.parse_csv(annot_dir)

    @staticmethod
    def _to_array(boxes):
        """
        Convert a list of [w, h] pairs to a float array of shape (n, 2).

        The reshape keeps the second dimension even when no box was found, so
        callers can always rely on a 2-d result.
        """
        return np.array(boxes, dtype=float).reshape(-1, 2)

    @staticmethod
    def parse_xml(annot_dir):
        """
        Parse xml annotation file in VOC.

        :param annot_dir: directory containing the ``*.xml`` files
        :return: 2-d array, shape (n, 2), normalized (width, height) per box
        """
        boxes = []

        # glob order is OS-dependent; sorting keeps the row order (and thus
        # any seeded clustering run on it) reproducible.
        for xml_file in sorted(glob.glob(os.path.join(annot_dir, '*.xml'))):
            tree = ET.parse(xml_file)

            h_img = int(tree.findtext('./size/height'))
            w_img = int(tree.findtext('./size/width'))

            for obj in tree.iter('object'):
                xmin = int(round(float(obj.findtext('bndbox/xmin'))))
                ymin = int(round(float(obj.findtext('bndbox/ymin'))))
                xmax = int(round(float(obj.findtext('bndbox/xmax'))))
                ymax = int(round(float(obj.findtext('bndbox/ymax'))))

                w_norm = (xmax - xmin) / w_img
                h_norm = (ymax - ymin) / h_img

                boxes.append([w_norm, h_norm])

        return AnnotParser._to_array(boxes)

    @staticmethod
    def parse_json(annot_dir):
        """
        Parse labelme json annotation file.

        Only the first two points of every shape are used; they are treated as
        two opposite corners of the box, in either order.

        :param annot_dir: directory containing the ``*.json`` files
        :return: 2-d array, shape (n, 2), normalized (width, height) per box
        """
        boxes = []

        for js_file in sorted(glob.glob(os.path.join(annot_dir, '*.json'))):
            with open(js_file, encoding='utf-8') as f:
                data = json.load(f)

            h_img = data['imageHeight']
            w_img = data['imageWidth']

            for shape in data['shapes']:
                points = shape['points']
                x0 = int(round(points[0][0]))
                y0 = int(round(points[0][1]))
                x1 = int(round(points[1][0]))
                y1 = int(round(points[1][1]))

                # The two corners are not guaranteed to be stored as
                # (top-left, bottom-right); abs() keeps the size positive.
                w_norm = abs(x1 - x0) / w_img
                h_norm = abs(y1 - y0) / h_img

                boxes.append([w_norm, h_norm])

        return AnnotParser._to_array(boxes)

    @staticmethod
    def parse_csv(annot_dir):
        """
        Parse csv annotation file.

        Each non-blank line is split on commas: the first field is the path of
        the image (read to obtain its size), the next four fields are
        ``xmin, ymin, xmax, ymax`` and the last field is ignored.

        :param annot_dir: directory containing the ``*.csv`` files
        :return: 2-d array, shape (n, 2), normalized (width, height) per box
        :raises ValueError: if an image referenced by a line cannot be read
        """
        boxes = []
        # Cache (height, width) per image path so an image with many boxes is
        # decoded only once.
        img_sizes = {}

        for csv_file in sorted(glob.glob(os.path.join(annot_dir, '*.csv'))):
            with open(csv_file) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    items = line.split(',')
                    img_path = items[0]
                    if img_path not in img_sizes:
                        img = cv2.imread(img_path)
                        # cv2.imread signals failure by returning None.
                        if img is None:
                            raise ValueError(
                                "Cannot read image '{}' referenced in '{}'.".format(img_path, csv_file))
                        img_sizes[img_path] = img.shape[:2]

                    h_img, w_img = img_sizes[img_path]
                    xmin, ymin, xmax, ymax = map(int, items[1:-1])

                    w_norm = (xmax - xmin) / w_img
                    h_norm = (ymax - ymin) / h_img

                    boxes.append([w_norm, h_norm])

        return AnnotParser._to_array(boxes)
106 |
--------------------------------------------------------------------------------
/Anchor-Kmeans/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "from kmeans import AnchorKmeans\n",
11 | "from datasets import AnnotParser\n",
12 | "from matplotlib import pyplot as plt\n",
13 | "from matplotlib.patches import Rectangle\n",
14 | "%matplotlib inline\n",
15 | "\n",
16 | "plt.style.use('ggplot')"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": "[INFO] Load datas\nboxes shape : (40138, 2)\n"
28 | }
29 | ],
30 | "source": [
31 | "print('[INFO] Load datas')\n",
32 | "annot_dir = \"/PATH TO YOUR/VOCdevkit/VOC2012/Annotations\"\n",
33 | "parser = AnnotParser('xml')\n",
34 | "boxes = parser.parse_xml(annot_dir)\n",
35 | "print('boxes shape : {}'.format(boxes.shape))"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 3,
41 | "metadata": {},
42 | "outputs": [
43 | {
44 | "name": "stdout",
45 | "output_type": "stream",
46 | "text": "[INFO] Run anchor k-means with k = 2,3,...,10\nK = 2, Avg IOU = 0.4646\nK = 3, Avg IOU = 0.5391\nK = 4, Avg IOU = 0.5801\nK = 5, Avg IOU = 0.6016\nK = 6, Avg IOU = 0.6252\nK = 7, Avg IOU = 0.6434\nK = 8, Avg IOU = 0.6596\nK = 9, Avg IOU = 0.6732\nK = 10, Avg IOU = 0.6838\n"
47 | }
48 | ],
49 | "source": [
50 | "print('[INFO] Run anchor k-means with k = 2,3,...,10')\n",
51 | "results = {}\n",
52 | "for k in range(2, 11):\n",
53 | " model = AnchorKmeans(k, random_seed=333)\n",
54 | " model.fit(boxes)\n",
55 | " avg_iou = model.avg_iou()\n",
56 | " results[k] = {'anchors': model.anchors_, 'avg_iou': avg_iou}\n",
57 | " print(\"K = {}, Avg IOU = {:.4f}\".format(k, avg_iou))"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 4,
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "name": "stdout",
67 | "output_type": "stream",
68 | "text": "[INFO] Plot average IOU curve\n"
69 | },
70 | {
71 | "data": {
72 | "image/png": "\n",
73 | "image/svg+xml": "\r\n\r\n\r\n\r\n",
74 | "text/plain": ""
75 | },
76 | "metadata": {
77 | "needs_background": "light"
78 | },
79 | "output_type": "display_data"
80 | }
81 | ],
82 | "source": [
83 | "print('[INFO] Plot average IOU curve')\n",
84 | "plt.figure()\n",
85 | "plt.plot(range(2, 11), [results[k][\"avg_iou\"] for k in range(2, 11)], \"o-\")\n",
86 | "plt.ylabel(\"Avg IOU\")\n",
87 | "plt.xlabel(\"K (#anchors)\")\n",
88 | "plt.show()"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 5,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": "[INFO] The result anchors:\n[[0.7794355 0.8338808 ]\n [0.33883529 0.68815335]\n [0.61044288 0.40655773]\n [0.19493034 0.35335266]\n [0.07805765 0.13006786]]\n"
100 | }
101 | ],
102 | "source": [
103 | "print('[INFO] The result anchors:')\n",
104 | "best_k = 5\n",
105 | "anchors = results[best_k]['anchors']\n",
106 | "print(anchors)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 6,
112 | "metadata": {},
113 | "outputs": [
114 | {
115 | "name": "stdout",
116 | "output_type": "stream",
117 | "text": "[INFO] Visualizing anchors\n"
118 | },
119 | {
120 | "data": {
121 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfgAAAFpCAYAAABwEjqZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXaklEQVR4nO3dfYxd9X3n8ffUN0qrdlVKbh4YYwmkWiucKKVKBUj9o92AEhOl67QbvoUsKaFIViVQhNo/AgWVatdZuYpEhJqH1ShBhSWp8xVtBErYpQ5NhFZamkQs2pZ4t+skNExsB015SCskWJPZP86Z+jK+Y5/xfZr7nfdLsube3zn3nt/5zpn7ub/z5IXV1VUkSVItPzXrDkiSpPEz4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKqg36htExE8DjwNvbN/vwcy8KyIuBg4B5wNPAh/OzFcj4o3A/cC7gH8Efjsznxm1H5Ik6ZRxjOBfAd6dmb8EXArsjYgrgD8BPpmZu4EXgJva+W8CXsjMXwQ+2c4nSZLGaOSAz8zVzPzn9ukb2n+rwLuBB9v2+4APtI/3tc9pp18ZEQuj9kOSJJ0ylmPwEbEjIp4CngMOA98FXszMk+0sy8DO9vFO4FmAdvpLwJvG0Q9JktQY+Rg8QGa+BlwaEecBXwYuGTLb2j1xh43WT7tfbkTsB/a37/+ucfRTkqQ5c857uMcS8Gsy88WI+AZwBXBeRPTaUfqFwLF2tmVgF7AcET3g54Hnh7zXErDUPl09duzY+lk0RL/fZ2VlZdbd2PKsU3fWqhvr1I116m5xcXGk14+8iz4i3tyO3ImInwGuAo4AXwc+2M52A/BQ+/jh9jnt9L/OTP/HG0mSxmgcx+AvAL4eEf8L+BZwODO/AnwM+P2IOEpzjP3z7fyfB97Utv8+cNsY+iBJkgYszMl/F+su+o7c/dWNderOWnVjnbqxTt21u+jP+Ri8d7KTJKkgA16SpIIMeEmSCjLgJUkqyICXJKkgA16SpIIMeEmSCjLgJUkqyICXJKkgA16SpIIMeEmSCjLgJUkqyICXJKkgA16SpIIMeEmSCjLgJUkqyICXJKkgA16SpIIMeEmSCjLgJUkqyICXJKmg3qw7sF1cf+DklJZ0YkrLmXfWqTtr1Y116mZ6dXrgzu0dcY7gJUkqaHt/vZmBSX+j7Pf7rKysTHQZFVin7qxVN9apm2nUaXp7TLc2R/CSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkF9WbdAWnarj9wEjgx627MEWvVzak6PXCnH62aPUfwkiQVNPLXzIjYBdwPvA34CbCUmfdExPnAl4CLgGeAyMwXImIBuAd4H/Ay8JHMfHLUfkib5Sirm36/z8rKyqy7seX1+3323ureDm0d4xjBnwT+IDMvAa4Abo6IPcBtwGOZuRt4rH0OcDWwu/23H/jsGPogSZIGjBzwmXl8bQSemf8EHAF2AvuA+9rZ7gM+0D7eB9yfmauZ+QRwXkRcMGo/JEnSKWPdRxkRFwG/DPwN8NbMPA7Nl4CIeEs7207g2YGXLbdtx9e9136aET6ZSb/fH2dXZ6DZdTfp9ej1egVqNWnT+V1U4TbVTa936uPUem1sOtuTf+MwxoCPiJ8D/gK4NTN/HBEbzbowpG11fUNmLgFLa9OrHAOc9Hp4vLQ769SN21Q3g2FivTY2ze1p3n8Pi4uLI71+LGfRR8QbaML9C5n5l23zj9Z2vbc/n2vbl4FdAy+/EDg2jn5IkqTGOM6iXwA+DxzJzLsHJj0M3AAcbH8+NNB+S0QcAi4HXlrblS9JksZjHLvofxX4MPC3EfFU2/aHNMGeEXET8APgmnbaIzSX
yB2luUzuxjH0QZIkDRg54DPzvzP8uDrAlUPmXwVuHnW5kiRpY97JTpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSqoN+sOaHauP3By1l2Yqe2+/t2dmHUH5sSpOm23beuBO42SrcgRvCRJBfm1S9vu2/fa6Gq7rfe56vf7rKyszLobW16/32fvrc0ofrtsW9ttT8W8cQQvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQVtj2s5pC1iPi8rmr8b3WyXy9SkMxnLX0FE3Au8H3guM9/Rtp0PfAm4CHgGiMx8ISIWgHuA9wEvAx/JzCfH0Q9JktQY19fcPwM+Bdw/0HYb8FhmHoyI29rnHwOuBna3/y4HPtv+lLaNeRphztONbuZzD4k0GWM5Bp+ZjwPPr2veB9zXPr4P+MBA+/2ZuZqZTwDnRcQF4+iHJElqTHIY8dbMPA6Qmccj4i1t+07g2YH5ltu244Mvjoj9wP729fT7/Ql2dRqa45iTXo9er7eJZUynT1vPLNd7/mq+uW1q1mZX317v1Mfp/NRrVJuv93S2p/n7O5uEWewnXBjStrq+ITOXgKW16fOyi/BsJr0e57I7tUptN2uW6z1PNZ+nXfRrZtHfwTCZt3qNajPrO83tad5/D4uLiyO9fpKXyf1obdd7+/O5tn0Z2DUw34XAsQn2Q5KkbWeSI/iHgRuAg+3Phwbab4mIQzQn1720titfkiSNx7guk/tz4NeBfkQsA3fRBHtGxE3AD4Br2tkfoblE7ijNZXI3jqMPkiTplLEEfGZet8GkK4fMuwrcPI7lSpKk4bxVrSRJBRnwkiQVZMBLklSQAS9JUkEGvCRJBRnwkiQVZMBLklSQAS9JUkEGvCRJBRnwkiQVZMBLklSQAS9JUkEGvCRJBRnwkiQVZMBLklSQAS9JUkEGvCRJBRnwkiQVZMBLklSQAS9JUkEGvCRJBRnwkiQVZMBLklSQAS9JUkEGvCRJBRnwkiQVZMBLklRQb9YdkDQZ1x84OaZ3OjGm94EH7vQjR5oWR/CSJBXk12mpuFFHzf1+n5WVlZHeY3x7EyR15QhekqSCDHhJkgoy4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKmhm/x98ROwF7gF2AJ/LzIOz6oskSdXMZAQfETuATwNXA3uA6yJizyz6IklSRbPaRX8ZcDQzv5eZrwKHgH0z6oskSeXMKuB3As8OPF9u2yRJ0hjM6hj8wpC21cEnEbEf2A+QmfT7/Wn0a4JOAEx8PXq93iaWMZ0+bT2zXO9pLns8y9rcNjXZvmyd5Zyu1zv1cbp9/qY2X+/xbE9ns10/215vVgG/DOwaeH4hcGxwhsxcApbap6srKytT6tpkTXo9+v3+ppdRpbabNcv1nuayR13WuWxTk+rLVlvOoMEw2W5/U5tZ33FuT2cz77+HxcXFkV4/q4D/FrA7Ii4GfghcC3xoRn2RJKmcmRyDz8yTwC3Ao8CRpimfnkVfJEmqaGbXwWfmI8Ajs1q+JEmVeSc7SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqSADXpKkggx4SZIKMuAlSSrIgJckqaCZ/X/wkqbj+gMnR3yHE2Pph6TpcgQvSVJBjuCloh64czx/3v1+n5WVlbG8l6TpcQQvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIk
FWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUUG+UF0fENcAfA5cAl2Xmtwem3Q7cBLwGfDQzH23b9wL3ADuAz2XmwVH6IEmSTjfqCP7vgN8CHh9sjIg9wLXA24G9wGciYkdE7AA+DVwN7AGua+eVJEljNNIIPjOPAETE+kn7gEOZ+Qrw/Yg4ClzWTjuamd9rX3eonfc7o/RDkiS93kgBfwY7gScGni+3bQDPrmu/fNgbRMR+YD9AZtLv9yfQzWk6ATDx9ej1eptYxnT6tPXMcr3nr+ab26ZmbXb17fVOfZzOT71Gtfl6T2d7mr+/s0k4a8BHxNeAtw2ZdEdmPrTByxaGtK0y/JDA6rA3yMwlYGltnpWVlbN1dS5Mej36/f6ml1Gltps1y/Wep5qfyzY1a7Po72CYzFu9RrWZ9Z3m9jTvv4fFxcWRXn/WgM/Mq87hfZeBXQPPLwSOtY83ape2jesPnJx1FzbhxKw7IOkcTGoX/cPAFyPibmAR2A18k2ZkvzsiLgZ+SHMi3ocm1AdJkratUS+T+03gT4E3A1+NiKcy872Z+XREJM3JcyeBmzPztfY1twCP0lwmd29mPj3SGkhz5IE7J/WdenLmcRe9pNHPov8y8OUNpn0c+PiQ9keAR0ZZriRJOjPvZCdJUkEGvCRJBRnwkiQVZMBLklSQAS9JUkHzd82Oxm6+broyPtt1vTfPG910c6pOblvaChzBS5JUkCP4bWweb7oyDmujq+26/pvljW666ff77L21GcW7bWkrcAQvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBvVFeHBGfAH4DeBX4LnBjZr7YTrsduAl4DfhoZj7atu8F7gF2AJ/LzIOj9EGSJJ1u1BH8YeAdmflO4O+B2wEiYg9wLfB2YC/wmYjYERE7gE8DVwN7gOvaeSVJ0hiNNILPzL8aePoE8MH28T7gUGa+Anw/Io4Cl7XTjmbm9wAi4lA773dG6YckSXq9kQJ+nd8FvtQ+3kkT+GuW2zaAZ9e1Xz7szSJiP7AfIDPp9/tj7OosnACY+Hr0er0CtZq06fwuqnCb6qbXO/Vxar02Np3tyb9x6BDwEfE14G1DJt2RmQ+189wBnAS+0E5bGDL/KsMPCawOW25mLgFLa/OsrKycratzYdLr0e/3J76MKqxTN25T3QyGifXa2DS3p3n/PSwuLo70+rMGfGZedabpEXED8H7gysxcC+tlYNfAbBcCx9rHG7VLkqQxGfUs+r3Ax4Bfy8yXByY9DHwxIu4GFoHdwDdpRva7I+Ji4Ic0J+J9aJQ+SOfq+gMnZ92FOXFi1h2YE9ZJW8uoZ9F/CvhXwOGIeCoi/jNAZj4NJM3Jc/8NuDkzX8vMk8AtwKPAkWbWfHrEPkiSpHUWVleHHgLfalaPHZvvPflro8UH7hzneY2n83hpN9apO2vVjXXqZhp1mtbn7aS1x+CHndPWiXeykySpIANekqSCDHhJkgoy4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKsiAlySpIANekqSCDHhJkgoy4CVJKsiAlySpoN6sO7DdXH/g5ISXcGLC71+FderOWnVjnbqxTtPiCF6SpIIWVldXZ92H
LlaPHTs26z7MhX6/z8rKyqy7seVZp+6sVTfWqRvr1N3i4iLAwrm+3hG8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkFGfCSJBVkwEuSVJABL0lSQQa8JEkF9UZ5cUT8R2Af8BPgOeAjmXksIhaAe4D3AS+37U+2r7kBuLN9iwOZed8ofZAkSacbdQT/icx8Z2ZeCnwF+KO2/Wpgd/tvP/BZgIg4H7gLuBy4DLgrIn5hxD5IkqR1Rgr4zPzxwNOfBVbbx/uA+zNzNTOfAM6LiAuA9wKHM/P5zHwBOAzsHaUPkiTpdCPtogeIiI8DvwO8BPybtnkn8OzAbMtt20btkiRpjM4a8BHxNeBtQybdkZkPZeYdwB0RcTtwC80u+IUh86+eoX3YcvfT7N4nM1lcXDxbV9WyVt1Yp+6sVTfWqRvrNB1nDfjMvKrje30R+CpNwC8DuwamXQgca9t/fV37NzZY7hKwBBAR387MX+nYj23NWnVjnbqzVt1Yp26sU3dtrc759SMdg4+I3QNP/y3wv9vHDwO/ExELEXEF8FJmHgceBd4TEb/Qnlz3nrZNkiSN0ajH4A9GxL+muUzuH4Dfa9sfoblE7ijNZXI3AmTm8+2ldd9q5/sPmfn8iH2QJEnrjBTwmfnvNmhfBW7eYNq9wL2bXNTSJuffzqxVN9apO2vVjXXqxjp1N1KtFlZXh57jJkmS5pi3qpUkqaCRr4MfN29/201EfAL4DeBV4LvAjZn5YjvtduAm4DXgo5n5aNu+l6aGO4DPZebBWfR92iLiGuCPgUuAyzLz2wPTrNUGrMHrRcS9wPuB5zLzHW3b+cCXgIuAZ4DIzBfO9HlVXUTsAu6nubz6J8BSZt5jrV4vIn4aeBx4I00WP5iZd0XExcAh4HzgSeDDmflqRLyRpq7vAv4R+O3MfOZMy9iKI3hvf9vNYeAdmflO4O+B2wEiYg9wLfB2mrsEfiYidkTEDuDTNHXcA1zXzrsd/B3wWzR/TP/CWm3MGgz1Z5x+583bgMcyczfwWPscNvi82iZOAn+QmZcAVwA3t9uOtXq9V4B3Z+YvAZcCe9urzv4E+GRbpxdoBiC0P1/IzF8EPtnOd0ZbLuC9/W03mflXmXmyffoEzT0FoKnTocx8JTO/T3Mlw2Xtv6OZ+b3MfJXmG+K+afd7FjLzSGb+nyGTrNXGrME6mfk4sP6qn33A2h7D+4APDLQP+7wqLzOPr43AM/OfgCM0dyy1VgPa9f3n9ukb2n+rwLuBB9v29XVaq9+DwJXt3o8NbbmAh+b2txHxLPDvOTWC9/a3G/td4L+2j61Td9ZqY9agm7e29/ig/fmWtt36ARFxEfDLwN9grU7T7jF8iuZw9GGaw60vDgzeBmvxL3Vqp78EvOlM7z+TY/Czuv3tvDlbndp57qDZJfaFdtpG9Rj2Za5EnaBbrYbYlrXqqOzf1ZRs+/pFxM8BfwHcmpk/joiNZt22tcrM14BLI+I84Ms05wmtt1aLTddpJgE/q9vfzpuz1ak9ufD9wJXtvQdg4zpxhva5t4ltatC2rFVHZ6qNTvlRRFyQmcfb3crPte3bun4R8QaacP9CZv5l22ytNpCZL0bEN2jOWTgvInrtKH2wFmt1Wo6IHvDznH7I6HW24ln0uzPz/7ZP19/+9paIOERzQt1L7YbyKPCfBk6sew/tCWeVtWc4fwz4tcx8eWDSw8AXI+JuYJHmxJVv0nz7292eoflDmpPLPjTdXm851mpj38IadPEwcANwsP350ED7aZ9Xs+nidLXHhT8PHMnMuwcmWasBEfFm4P+14f4zwFU0J859HfggzXkv6+t0A/A/2ul/PTCwG2rLBTze/rarT9FcXnG43fX1RGb+XmY+HREJfIdm1/3N7W4gIuIWmnv/7wDuzcynZ9P16YqI3wT+FHgz
8NWIeCoz32utNpaZJ7d7DdaLiD+n2VvYj4hlmj2LB4GMiJuAHwDXtLMP/bzaJn4V+DDwt+3xZYA/xFqtdwFwX3vFyk8BmZlfiYjvAIci4gDwP2m+LNH+/C8RcZRm5H7t2RbgnewkSSpoS55FL0mSRmPAS5JUkAEvSVJBBrwkSQUZ8JIkFWTAS5JUkAEvSVJBBrwkSQX9f73xA4985BUVAAAAAElFTkSuQmCC\n",
122 | "image/svg+xml": "\r\n\r\n\r\n\r\n",
123 | "text/plain": ""
124 | },
125 | "metadata": {
126 | "needs_background": "light"
127 | },
128 | "output_type": "display_data"
129 | }
130 | ],
131 | "source": [
132 | "print('[INFO] Visualizing anchors')\n",
133 | "w_img, h_img = 600, 600\n",
134 | "\n",
135 | "anchors[:, 0] *= w_img\n",
136 | "anchors[:, 1] *= h_img\n",
137 | "anchors = np.round(anchors).astype(np.int)\n",
138 | "\n",
139 | "rects = np.empty((5, 4), dtype=np.int)\n",
140 | "for i in range(len(anchors)):\n",
141 | " w, h = anchors[i]\n",
142 | " x1, y1 = -(w // 2), -(h // 2)\n",
143 | " rects[i] = [x1, y1, w, h]\n",
144 | "\n",
145 | "fig = plt.figure(figsize=(8, 6))\n",
146 | "ax = fig.add_subplot()\n",
147 | "for rect in rects:\n",
148 | " x1, y1, w, h = rect\n",
149 | " rect1 = Rectangle((x1, y1), w, h, color='royalblue', fill=False, linewidth=2)\n",
150 | " ax.add_patch(rect1)\n",
151 | "plt.xlim([-(w_img // 2), w_img // 2])\n",
152 | "plt.ylim([-(h_img // 2), h_img // 2])\n",
153 | "\n",
154 | "plt.show()"
155 | ]
156 | }
157 | ],
158 | "metadata": {
159 | "kernelspec": {
160 | "display_name": "Python 3",
161 | "language": "python",
162 | "name": "python3"
163 | },
164 | "language_info": {
165 | "codemirror_mode": {
166 | "name": "ipython",
167 | "version": 3
168 | },
169 | "file_extension": ".py",
170 | "mimetype": "text/x-python",
171 | "name": "python",
172 | "nbconvert_exporter": "python",
173 | "pygments_lexer": "ipython3",
174 | "version": "3.7.4-final"
175 | },
176 | "toc": {
177 | "base_numbering": 1,
178 | "nav_menu": {},
179 | "number_sections": true,
180 | "sideBar": true,
181 | "skip_h1_title": false,
182 | "title_cell": "Table of Contents",
183 | "title_sidebar": "Contents",
184 | "toc_cell": false,
185 | "toc_position": {},
186 | "toc_section_display": true,
187 | "toc_window_display": false
188 | }
189 | },
190 | "nbformat": 4,
191 | "nbformat_minor": 2
192 | }
--------------------------------------------------------------------------------
/Anchor-Kmeans/gen_anchors.py:
--------------------------------------------------------------------------------
1 | from kmeans import AnchorKmeans
2 | from datasets import AnnotParser
3 | import argparse
4 |
5 |
def main(args):
    """
    Cluster the annotated boxes of a dataset and print the resulting anchors.

    :param args: dict with the keys "type" (annotation file type),
                 "k_clusters" (number of anchors) and "dir_path"
                 (directory of the annotation files)
    :return: None
    """
    file_type, k, annot_dir = args["type"], args["k_clusters"], args["dir_path"]
    annotation_parser = AnnotParser(file_type)

    # Step 1: read every annotation into an array of box sizes.
    print("[INFO] Load datas from {}".format(annot_dir))
    boxes = annotation_parser.parse(annot_dir)

    # Step 2: create the clustering model with the requested anchor count.
    print("[INFO] Initialize model")
    kmeans_model = AnchorKmeans(k)

    # Step 3: run k-means on the boxes.
    print("[INFO] Training...")
    kmeans_model.fit(boxes)

    # Step 4: report the cluster centers, i.e. the anchors.
    print("[INFO] The results anchors:\n{}".format(kmeans_model.anchors_))
23 |
24 |
if __name__ == "__main__":
    # Command-line interface: annotation directory, annotation type and the
    # number of clusters.
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--dir_path", required=True,
                    help="directory path of annotation files")
    ap.add_argument("-t", "--type", default='xml', choices=['xml', 'json', 'csv'],
                    help="type of annotation file")
    ap.add_argument("-k", "--k_clusters", default=5, type=int,
                    help="the number of clusters")

    # main() reads its options by key, so hand it a plain dict.
    args = vars(ap.parse_args())
    main(args)
43 |
--------------------------------------------------------------------------------
/Anchor-Kmeans/imgs/avgiou.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ybcc2015/DeepLearning-Utils/629bea84be257005dd3c331f14ba390c5ea59065/Anchor-Kmeans/imgs/avgiou.png
--------------------------------------------------------------------------------
/Anchor-Kmeans/kmeans.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
class AnchorKmeans(object):
    """
    K-means clustering on bounding boxes to generate anchors.

    Boxes and anchors are (width, height) pairs and the distance between a box
    and an anchor is ``1 - IOU``, computed as if both shared the same corner.

    After ``fit``:
      - ``anchors_`` is the (k, 2) array of cluster centers,
      - ``labels_`` is the (n,) integer array of the anchor index of each box,
      - ``ious_`` is the (n, k) IOU matrix between boxes and anchors,
      - ``n_iter`` is the number of loop passes that were started.
    """
    def __init__(self, k, max_iter=300, random_seed=None):
        """
        :param k: number of anchors (clusters)
        :param max_iter: maximum number of assignment/update iterations
        :param random_seed: seed for the initial anchor selection, None for a random one
        """
        self.k = k
        self.max_iter = max_iter
        self.random_seed = random_seed
        self.n_iter = 0
        self.anchors_ = None
        self.labels_ = None
        self.ious_ = None

    def fit(self, boxes):
        """
        Run K-means cluster on input boxes.

        :param boxes: 2-d array (or nested sequence), shape(n, 2), form as (w, h)
        :return: None
        :raises ValueError: if boxes is not of shape (n, 2)
        """
        # Work on a float array so that integer input is not truncated when
        # cluster means are written back into the anchors.
        boxes = np.asarray(boxes, dtype=float)
        assert self.k < len(boxes), "K must be less than the number of data."
        if boxes.ndim != 2 or boxes.shape[1] != 2:
            raise ValueError("boxes must be a 2-d array of shape (n, 2).")

        # Reset the iteration counter left over from a previous fit.
        self.n_iter = 0

        # A private generator keeps the run reproducible for a given seed
        # without reseeding NumPy's global random state.
        rng = np.random.RandomState(self.random_seed)
        n = boxes.shape[0]

        # Initialize K cluster centers (i.e., K anchors) from K *distinct*
        # boxes; sampling with replacement could pick the same box twice and
        # leave a cluster empty from the start.
        self.anchors_ = boxes[rng.choice(n, self.k, replace=False)]

        # -1 never matches a real label, so the first pass always updates.
        self.labels_ = np.full(n, -1, dtype=int)

        while True:
            self.n_iter += 1

            # If the current number of iterations is greater than max number of iterations , then break
            if self.n_iter > self.max_iter:
                # The anchors may have moved since the last assignment, so
                # refresh ious_/labels_ to describe the final anchors.
                self.ious_ = self.iou(boxes, self.anchors_)
                self.labels_ = np.argmin(1 - self.ious_, axis=1)
                break

            self.ious_ = self.iou(boxes, self.anchors_)
            distances = 1 - self.ious_
            cur_labels = np.argmin(distances, axis=1)

            # If anchors not change any more, then break
            if (cur_labels == self.labels_).all():
                break

            # Update K anchors
            for i in range(self.k):
                members = boxes[cur_labels == i]
                # An empty cluster keeps its previous anchor; the mean of an
                # empty selection would turn the anchor into NaN.
                if len(members) > 0:
                    self.anchors_[i] = np.mean(members, axis=0)

            self.labels_ = cur_labels

    @staticmethod
    def iou(boxes, anchors):
        """
        Calculate the IOU between boxes and anchors.

        :param boxes: 2-d array, shape(n, 2)
        :param anchors: 2-d array, shape(k, 2)
        :return: 2-d array, shape(n, k)
        """
        # Calculate the intersection,
        # the new dimension are added to construct shape (n, 1) and shape (1, k),
        # so we can get (n, k) shape result by numpy broadcast
        w_min = np.minimum(boxes[:, 0, np.newaxis], anchors[np.newaxis, :, 0])
        h_min = np.minimum(boxes[:, 1, np.newaxis], anchors[np.newaxis, :, 1])
        inter = w_min * h_min

        # Calculate the union
        box_area = boxes[:, 0] * boxes[:, 1]
        anchor_area = anchors[:, 0] * anchors[:, 1]
        union = box_area[:, np.newaxis] + anchor_area[np.newaxis]

        # `union` above is the plain sum of both areas, so the overlap has to
        # be subtracted once to obtain the true union.
        return inter / (union - inter)

    def avg_iou(self):
        """
        Calculate the average IOU with closest anchor.

        Must be called after ``fit``.

        :return: float, mean IOU between every box and its assigned anchor
        """
        return np.mean(self.ious_[np.arange(len(self.labels_)), self.labels_])
89 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepLearning-Utils
2 | This repository contains some commonly used utilities for deep learning.
3 |
--------------------------------------------------------------------------------
/cal_mean_std.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from imutils import paths
4 |
5 |
def calculate_mean_std(img_root, channels=3):
    """
    Calculate the mean and standard deviation of the training images.

    Pixel values are scaled by 1/255 before the statistics are accumulated
    over all pixels of all images found under ``img_root``.

    Arguments:
        img_root {str} -- the root directory of training images
        channels {int} -- the numbers of channels

    Returns:
        mean {1-d numpy array} -- mean value of each channel
        std {1-d numpy array} -- standard deviation of each channel

    Raises:
        ValueError -- if an image cannot be read, if its channel count differs
                      from ``channels``, or if no image is found
    """
    # Ask OpenCV for the layout that matches `channels`; the default flag
    # always decodes to 3-channel BGR.
    if channels == 1:
        read_flag = cv2.IMREAD_GRAYSCALE
    elif channels == 3:
        read_flag = cv2.IMREAD_COLOR
    else:
        read_flag = cv2.IMREAD_UNCHANGED

    total_pixel = 0
    channel_sum = np.zeros(channels)
    channel_square_sum = np.zeros(channels)

    for img_path in paths.list_images(img_root):
        img = cv2.imread(img_path, read_flag)
        # cv2.imread signals failure by returning None.
        if img is None:
            raise ValueError("Cannot read image '{}'.".format(img_path))

        img_channels = 1 if img.ndim == 2 else img.shape[2]
        if img_channels != channels:
            raise ValueError("Image '{}' has {} channel(s), expected {}.".format(
                img_path, img_channels, channels))

        # Flatten to (num_pixels, channels) so one code path serves both
        # grayscale (H, W) and multi-channel (H, W, C) images.
        pixels = img.reshape(-1, channels) / 255.

        # Accumulate over ALL images; assigning here would keep only the sums
        # of the last image while total_pixel counts every image.
        channel_sum += np.sum(pixels, axis=0)
        channel_square_sum += np.sum(pixels ** 2, axis=0)
        total_pixel += pixels.shape[0]

    if total_pixel == 0:
        raise ValueError("No images found under '{}'.".format(img_root))

    mean = channel_sum / total_pixel
    # Var[x] = E[x^2] - E[x]^2; clamp at 0 because rounding can push the
    # difference slightly negative for near-constant channels.
    std = np.sqrt(np.maximum(channel_square_sum / total_pixel - mean ** 2, 0))

    if channels == 3:  # bgr -> rgb
        mean = mean[::-1]
        std = std[::-1]

    return mean, std
37 |
--------------------------------------------------------------------------------
/wtm/find_gt_files.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Count the entries named '*_gt.json' under a directory tree.
#
# Usage: find_gt_files.sh [DIR]
#   DIR  directory to search (defaults to the current directory)

# Fall back to "." so that running without an argument still searches the
# current directory, as the unquoted original did.
DIR="${1:-.}"

# Quote the path so directories containing spaces or glob characters work.
find "$DIR" -name '*_gt.json' | wc -l
5 |
6 |
--------------------------------------------------------------------------------
/wtm/img_fill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import json
4 |
5 |
class ImgFill(object):
    """
    Paste a source image into the ``box_b`` rectangle of a destination image.

    The rectangle is read once from a json file and stored in ``self.box`` as
    [left, top, right, bottom].
    """

    def __init__(self, json_file):
        """
        Args:
            json_file (str): path of the json file describing the boxes
        """
        self.box = self.get_box(json_file)

    @staticmethod
    def get_box(json_file):
        """
        Read the coordinates of box_b from the json file.
        Args:
            json_file (str): path of the json file
        Returns:
            [list]: coordinates of box_b as [left, top, right, bottom];
                    an empty list if the file has no box named "box_b"
        """
        res = []

        # Read the file the caller asked for, not a hard-coded path.
        with open(json_file, encoding="utf-8") as f:
            json_data = json.load(f)

        for box in json_data["boxes"]:
            if box["name"] == "box_b":
                rectangle = box["rectangle"]
                # Build a new list instead of extending the parsed one in place.
                res = list(rectangle["left_top"]) + list(rectangle["right_bottom"])

        return res

    def is_box_valid(self, img):
        """
        Check whether the region given by box_b lies inside the image.
        Args:
            img (numpy array): destination image
        Returns:
            [bool]: True if the box is fully inside img, otherwise False
        """
        left, top = self.box[:2]
        h = self.box[3] - self.box[1]
        w = self.box[2] - self.box[0]
        h_img, w_img = img.shape[:2]

        cond1 = left >= 0 and (left + w) <= w_img
        cond2 = top >= 0 and (top + h) <= h_img
        return cond1 and cond2

    @staticmethod
    def _pad_to_size(img, h, w):
        """
        Zero-pad img symmetrically along y and x up to (h, w).

        Works for both 2-d (grayscale) and 3-d (multi-channel) images; the
        channel axis, if any, is never padded.
        """
        pad_h = h - img.shape[0]
        pad_w = w - img.shape[1]
        pad_width = [(pad_h // 2, pad_h - pad_h // 2),
                     (pad_w // 2, pad_w - pad_w // 2)]
        pad_width += [(0, 0)] * (img.ndim - 2)
        return np.pad(img, pad_width, mode="constant")

    def fill(self, dst_img, src_img, mode="stretch"):
        """
        Fill the box_b region of the destination image with the source image.
        Args:
            dst_img (numpy array): destination image, modified in place
            src_img (numpy array): source image (the image to paste)
            mode (str): fill mode, "stretch" resizes the source to the box,
                        "keep" preserves the aspect ratio and pads the rest
        Returns:
            [numpy array]: the filled destination image, or None if box_b
                           does not fit inside dst_img
        """
        ok = self.is_box_valid(dst_img)
        if not ok:
            return

        # Top-left corner, width and height of the region to fill
        left, top = self.box[:2]
        h = self.box[3] - self.box[1]
        w = self.box[2] - self.box[0]

        assert mode in ["stretch", "keep"], "当前仅支持'stretch'和'keep'填充模式!"

        if mode == "stretch":
            src_img = cv2.resize(src_img, (w, h))
            dst_img[top: top + h, left: left + w] = src_img

        if mode == "keep":
            # Use the smaller of the two ratios so the scaled image fits the
            # box in BOTH directions.
            h_img, w_img = src_img.shape[:2]
            ratio = min(h / h_img, w / w_img)

            # Scale the source image, keeping its aspect ratio; clamp so
            # rounding can neither exceed the box nor collapse to zero.
            h_new = min(h, max(1, int(round(h_img * ratio))))
            w_new = min(w, max(1, int(round(w_img * ratio))))
            src_img = cv2.resize(src_img, (w_new, h_new))

            # Pad the shorter direction with zeros up to the box size; the
            # padded result must be kept, np.pad does not work in place.
            src_img = self._pad_to_size(src_img, h, w)

            dst_img[top: top + h, left: left + w] = src_img

        return dst_img
101 |
--------------------------------------------------------------------------------