# gobang-object-detection-dataset

# 1. Blog posts documenting the course project:

[BIT course project, AI + Gomoku in practice (1): stone object detection](https://blog.csdn.net/weixin_44936889/article/details/109862218)

[BIT course project, AI + Gomoku in practice (2): game-tree search](https://blog.csdn.net/weixin_44936889/article/details/110380769)

# 2. Dataset Introduction

This dataset was collected with a script for the Artificial Intelligence course project at Beijing Institute of Technology (BIT): an object detection dataset of black and white Gomoku stones.

Follow my WeChat official account and reply with "五子棋" (Gomoku) to get the dataset:

![WeChat official account](https://img-blog.csdnimg.cn/20210127153004430.jpg?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDkzNjg4OQ==,size_16,color_FFFFFF,t_70)

The images are screenshots of a pygame game window:

![pygame screenshot](https://img-blog.csdnimg.cn/20201120194534365.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDkzNjg4OQ==,size_16,color_FFFFFF,t_70#pic_center)

# 3. Object Detection Baseline:

The baseline uses the YOLOv3 detector provided by PaddleX.

Since the stones are easy to recognize, the lightweight MobileNet is used as the backbone network.

```python
!pip install paddlex -i https://mirror.baidu.com/pypi/simple
!pip install imgaug -i https://mirror.baidu.com/pypi/simple
!pip install paddlelite -i https://mirror.baidu.com/pypi/simple
```


```python
!unzip VOCData.zip
```

Generate the train/eval file lists (each line is `image.jpg image.xml`) and the label list in the format PaddleX expects:

```python
from random import shuffle
from tqdm import tqdm
import cv2
import os

base = 'VOCData'

# collect and shuffle all images in the VOC folder
imgs = [v for v in os.listdir(base) if v.endswith('.jpg')]
shuffle(imgs)

# write "image annotation" pairs, skipping images OpenCV cannot decode
with open('train.txt', 'w') as f:
    for im in tqdm(imgs):
        xml = im[:-4] + '.xml'
        info = im + ' ' + xml
        if cv2.imread(os.path.join(base, im)) is not None:
            f.write(info + '\n')

with open('eval.txt', 'w') as f:
    for im in tqdm(imgs):
        xml = im[:-4] + '.xml'
        info = im + ' ' + xml
        if cv2.imread(os.path.join(base, im)) is not None:
            f.write(info + '\n')

labels = ['black', 'white']
with open('labels.txt', 'w') as f:
    for lbl in labels:
        f.write(lbl + '\n')
```

    100%|██████████| 129/129 [00:00<00:00, 262.17it/s]
    100%|██████████| 129/129 [00:00<00:00, 263.15it/s]
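Note that the script above writes the same shuffled list of all 129 images to both `train.txt` and `eval.txt`, so evaluation runs on the training images and the bbox_map of 100.0 reported further down is not a held-out metric. Below is a minimal sketch of an actual split; the 80/20 ratio and the helper name are illustrative, not part of the original project:

```python
from random import shuffle
import os
import cv2

base = 'VOCData'

imgs = [v for v in os.listdir(base) if v.endswith('.jpg')]
shuffle(imgs)
split = int(len(imgs) * 0.8)  # assumed 80/20 split, adjust as needed


def write_list(path, names):
    # same "image.jpg image.xml" line format as above, skipping unreadable images
    with open(path, 'w') as f:
        for im in names:
            if cv2.imread(os.path.join(base, im)) is None:
                continue
            f.write(im + ' ' + im[:-4] + '.xml\n')


write_list('train.txt', imgs[:split])
write_list('eval.txt', imgs[split:])
```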
Data augmentation for training and preprocessing for evaluation:

```python
from paddlex.det import transforms
import imgaug.augmenters as iaa

# imgaug augmenters can be mixed into the PaddleX transform pipeline
train_transforms = transforms.Compose([
    transforms.RandomCrop(),
    # iaa.GaussianBlur(sigma=(0.0, 3.0)),
    iaa.MultiplyAndAddToBrightness(mul=(0.5, 1.5), add=(-30, 30)),
    iaa.Cutout(fill_mode="constant", cval=(0, 255), fill_per_channel=0.5),
    transforms.RandomHorizontalFlip(),
    iaa.SaltAndPepper(0.1),
    # transforms.RandomDistort(),
    transforms.RandomExpand(),
    transforms.Resize(target_size=608, interp='RANDOM'),
    # transforms.RandomCrop(608),
    transforms.Normalize(),
])

eval_transforms = transforms.Compose([
    transforms.Resize(target_size=608, interp='CUBIC'),
    transforms.Normalize(),
])
```

Some annotation files contain invalid `width`/`height` values, so they are patched with the actual image size:

```python
import os
import cv2
import xml.etree.ElementTree as ET

base = 'VOCData'

xmls = [v for v in os.listdir(base) if v.endswith('.xml')]

for x in xmls:
    updateTree = ET.parse(os.path.join(base, x))  # parse the annotation file to check
    root = updateTree.getroot()

    W = root.find("size").find("width")
    H = root.find("size").find("height")
    try:
        print(x, end='\t')
        print(float(W.text), end='\t')
        print(float(H.text), end='\r')
    except ValueError as e:
        # width/height is not a number: read the image and fix the annotation
        print(x)
        print(e)
        im = cv2.imread(os.path.join(base, x[:-4] + '.jpg'))
        h, w = im.shape[:2]
        W.text = str(w)
        H.text = str(h)
        updateTree.write(os.path.join(base, x))
```


```python
import paddlex as pdx

base = 'VOCData'

train_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list='train.txt',
    label_list='labels.txt',
    transforms=train_transforms,
    shuffle=True)

eval_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list='eval.txt',
    label_list='labels.txt',
    transforms=eval_transforms)
```

    2020-11-20 19:29:45 [INFO] Starting to read file list from dataset...
    2020-11-20 19:29:46 [INFO] 129 samples in file train.txt
    creating index...
    index created!
    2020-11-20 19:29:46 [INFO] Starting to read file list from dataset...
    2020-11-20 19:29:46 [INFO] 129 samples in file eval.txt
    creating index...
    index created!


```python
num_classes = len(train_dataset.labels)
print('class num:', num_classes)
model = pdx.det.YOLOv3(num_classes=num_classes, backbone='MobileNetV1')
model.train(
    num_epochs=48,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.00025,
    # both milestones lie beyond num_epochs=48, so the learning rate is never decayed here
    lr_decay_epochs=[60, 160],
    save_interval_epochs=16,
    log_interval_steps=100,
    save_dir='./YOLOv3',
    # resumes from an earlier run; on a first run this checkpoint does not exist yet,
    # so pretrain_weights='IMAGENET' can be used instead
    pretrain_weights='YOLOv3/best_model',
    use_vdl=True)
```


```python
model = pdx.load_model('YOLOv3/best_model')
model.evaluate(eval_dataset, batch_size=1, epoch_id=None, metric=None, return_details=False)
```

    2020-11-20 19:33:13 [INFO] Model[YOLOv3] loaded.
    2020-11-20 19:33:13 [INFO] Start to evaluating(total_samples=129, total_steps=129)...
    100%|██████████| 129/129 [00:09<00:00, 13.29it/s]

    OrderedDict([('bbox_map', 100.0)])


```python
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
import paddlex as pdx
%matplotlib inline

model = pdx.load_model('YOLOv3/best_model')

base = 'images'
for fname in os.listdir(base):
    if not fname.endswith('.png'):
        continue
    image_name = os.path.join(base, fname)
    start = time.time()
    result = model.predict(image_name)
    print('infer time:{:.6f}s'.format(time.time() - start))
    print('detected num:', len(result))

    im = cv2.imread(image_name)
    font = cv2.FONT_HERSHEY_SIMPLEX
    threshold = 0.2

    for value in result:
        # PaddleX returns boxes as [xmin, ymin, width, height]
        xmin, ymin, w, h = np.array(value['bbox']).astype(int)
        cls = value['category']
        score = value['score']
        if score < threshold:
            continue
        color = (0, 255, 0) if cls == 'white' else (0, 0, 255)
        cv2.rectangle(im, (xmin, ymin), (xmin + w, ymin + h), color, 3)
        cv2.putText(im, '{:s} {:.3f}'.format(cls, score),
                    (xmin, ymin), font, 0.5, (255, 0, 0), thickness=1)

    plt.figure(figsize=(15, 12))
    plt.imshow(im[:, :, [2, 1, 0]])  # BGR -> RGB for display
    plt.show()
```

    2020-11-20 19:34:48 [INFO] Model[YOLOv3] loaded.
    infer time:0.069679s
    detected num: 13

![png](output_8_1.png)

    infer time:0.035611s
    detected num: 51

![png](output_8_3.png)
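To actually play against the detector, as covered in the second blog post on game-tree search, the detections have to be converted into a board state first. Below is a minimal sketch of that step; the 15x15 board is standard Gomoku, but `origin`, `cell_size`, the function name and the example image path are hypothetical and would have to be measured from the actual pygame screenshots. This is not code from the original project.

```python
import numpy as np


def detections_to_board(result, origin=(40, 40), cell_size=40, size=15, threshold=0.5):
    """Turn PaddleX detection results into a size x size board array.

    0 = empty, 1 = black, 2 = white. `origin` is the pixel position of the
    top-left grid intersection and `cell_size` the spacing between grid
    lines; both are placeholder numbers here.
    """
    board = np.zeros((size, size), dtype=int)
    for det in result:
        if det['score'] < threshold:
            continue
        xmin, ymin, w, h = det['bbox']
        # use the center of the bounding box as the stone position
        cx, cy = xmin + w / 2, ymin + h / 2
        col = int(round((cx - origin[0]) / cell_size))
        row = int(round((cy - origin[1]) / cell_size))
        if 0 <= row < size and 0 <= col < size:
            board[row, col] = 1 if det['category'] == 'black' else 2
    return board


# board = detections_to_board(model.predict('images/example.png'))
```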
Writing all of this code took quite some effort, so if the repo helps you, please give it a star~